msi [2/2]: Reimplement joins to allow joining any number of tables, each of arbitrary size

James Hawkins truiken at gmail.com
Wed Aug 1 16:25:45 CDT 2007


Hi,

Another fix for the Circuit Design Suite installer.  This
implementation also makes two tests that were marked todo_wine
(failing) under the previous implementation pass.

Changelog:
* Reimplement joins to allow joining any number of tables, each of
arbitrary size.

 dlls/msi/join.c     |  308 ++++++++++++++++++++++++++-------------------------
 dlls/msi/query.h    |    3
 dlls/msi/sql.y      |   36 +++++-
 dlls/msi/tests/db.c |   25 +---
 4 files changed, 200 insertions(+), 172 deletions(-)

-- 
James Hawkins
-------------- next part --------------
diff --git a/dlls/msi/join.c b/dlls/msi/join.c
index e071078..325f58e 100644
--- a/dlls/msi/join.c
+++ b/dlls/msi/join.c
@@ -23,7 +23,6 @@ #include <stdarg.h>
 #include "windef.h"
 #include "winbase.h"
 #include "winerror.h"
-#include "wine/debug.h"
 #include "msi.h"
 #include "msiquery.h"
 #include "objbase.h"
@@ -31,157 +30,150 @@ #include "objidl.h"
 #include "msipriv.h"
 #include "query.h"
 
+#include "wine/debug.h"
+#include "wine/unicode.h"
+
 WINE_DEFAULT_DEBUG_CHANNEL(msidb);
 
+typedef struct tagJOINTABLE
+{
+    struct list entry;
+    MSIVIEW *view;
+    UINT columns;
+    UINT rows;
+    UINT next_rows;
+} JOINTABLE;
+
 typedef struct tagMSIJOINVIEW
 {
     MSIVIEW        view;
     MSIDATABASE   *db;
-    MSIVIEW       *left, *right;
-    UINT           left_count, right_count;
-    UINT           left_rows, right_rows;
+    struct list    tables;
+    UINT           columns;
+    UINT           rows;
 } MSIJOINVIEW;
 
 static UINT JOIN_fetch_int( struct tagMSIVIEW *view, UINT row, UINT col, UINT *val )
 {
     MSIJOINVIEW *jv = (MSIJOINVIEW*)view;
-    MSIVIEW *table;
+    JOINTABLE *table;
+    UINT cols = 0;
+    UINT prev_rows = 1;
 
-    TRACE("%p %d %d %p\n", jv, row, col, val );
+    TRACE("%d, %d\n", row, col);
 
-    if( !jv->left || !jv->right )
+    if (col == 0 || col > jv->columns)
          return ERROR_FUNCTION_FAILED;
 
-    if( (col==0) || (col>(jv->left_count + jv->right_count)) )
+    if (row >= jv->rows)
          return ERROR_FUNCTION_FAILED;
 
-    if( row >= (jv->left_rows * jv->right_rows) )
-         return ERROR_FUNCTION_FAILED;
-
-    if( col <= jv->left_count )
-    {
-        table = jv->left;
-        row = (row/jv->right_rows);
-    }
-    else
+    LIST_FOR_EACH_ENTRY(table, &jv->tables, JOINTABLE, entry)
     {
-        table = jv->right;
-        row = (row % jv->right_rows);
-        col -= jv->left_count;
+        if (col <= cols + table->columns)
+        {
+            row = (row % (jv->rows / table->next_rows)) / prev_rows; /* index into this table */
+            col -= cols;
+            break;
+        }
+
+        prev_rows *= table->rows; /* product of the row counts of all earlier tables */
+        cols += table->columns;
     }
 
-    return table->ops->fetch_int( table, row, col, val );
+    return table->view->ops->fetch_int( table->view, row, col, val );
 }
 
 static UINT JOIN_fetch_stream( struct tagMSIVIEW *view, UINT row, UINT col, IStream **stm)
 {
     MSIJOINVIEW *jv = (MSIJOINVIEW*)view;
-    MSIVIEW *table;
+    JOINTABLE *table;
+    UINT cols = 0;
+    UINT prev_rows = 1;
 
     TRACE("%p %d %d %p\n", jv, row, col, stm );
 
-    if( !jv->left || !jv->right )
+    if (col == 0 || col > jv->columns)
          return ERROR_FUNCTION_FAILED;
 
-    if( (col==0) || (col>(jv->left_count + jv->right_count)) )
+    if (row >= jv->rows)
          return ERROR_FUNCTION_FAILED;
 
-    if( row >= jv->left_rows * jv->right_rows )
-         return ERROR_FUNCTION_FAILED;
-
-    if( row <= jv->left_count )
-    {
-        table = jv->left;
-        row = (row/jv->right_rows);
-    }
-    else
+    LIST_FOR_EACH_ENTRY(table, &jv->tables, JOINTABLE, entry)
     {
-        table = jv->right;
-        row = (row % jv->right_rows);
-        col -= jv->left_count;
+        if (col <= cols + table->columns)
+        {
+            row = (row % (jv->rows / table->next_rows)) / prev_rows; /* index into this table */
+            col -= cols;
+            break;
+        }
+
+        prev_rows *= table->rows; /* product of the row counts of all earlier tables */
+        cols += table->columns;
     }
 
-    return table->ops->fetch_stream( table, row, col, stm );
+    return table->view->ops->fetch_stream( table->view, row, col, stm );
 }
 
 static UINT JOIN_get_row( struct tagMSIVIEW *view, UINT row, MSIRECORD **rec )
 {
-    MSIJOINVIEW *jv = (MSIJOINVIEW*)view;
-    MSIVIEW *table;
-
-    TRACE("%p %d %p\n", jv, row, rec );
-
-    if( !jv->left || !jv->right )
-             return ERROR_FUNCTION_FAILED;
-
-    if( row >= jv->left_rows * jv->right_rows )
-         return ERROR_FUNCTION_FAILED;
-
-    if( row <= jv->left_count )
-    {
-        table = jv->left;
-        row = (row/jv->right_rows);
-    }
-    else
-    {
-        table = jv->right;
-        row = (row % jv->right_rows);
-    }
-
-    return table->ops->get_row(table, row, rec);
+    FIXME("(%p, %d, %p): stub!\n", view, row, rec);
+    return ERROR_FUNCTION_FAILED;
 }
 
 static UINT JOIN_execute( struct tagMSIVIEW *view, MSIRECORD *record )
 {
     MSIJOINVIEW *jv = (MSIJOINVIEW*)view;
-    UINT r, *ldata = NULL, *rdata = NULL;
+    JOINTABLE *table;
+    UINT r, rows;
 
     TRACE("%p %p\n", jv, record);
 
-    if( !jv->left || !jv->right )
-         return ERROR_FUNCTION_FAILED;
-
-    r = jv->left->ops->execute( jv->left, NULL );
-    if (r != ERROR_SUCCESS)
-        return r;
-
-    r = jv->right->ops->execute( jv->right, NULL );
-    if (r != ERROR_SUCCESS)
-        return r;
-
-    /* get the number of rows in each table */
-    r = jv->left->ops->get_dimensions( jv->left, &jv->left_rows, NULL );
-    if( r != ERROR_SUCCESS )
+    LIST_FOR_EACH_ENTRY(table, &jv->tables, JOINTABLE, entry)
     {
-        ERR("can't get left table dimensions\n");
-        goto end;
+        table->view->ops->execute(table->view, NULL);
+
+        r = table->view->ops->get_dimensions(table->view, &table->rows, NULL);
+        if (r != ERROR_SUCCESS)
+        {
+            ERR("failed to get table dimensions\n");
+            return r;
+        }
+
+        /* each table must have at least one row */
+        if (table->rows == 0)
+        {
+            jv->rows = 0;
+            return ERROR_SUCCESS;
+        }
+        /* the first table restarts the product so re-executing doesn't compound jv->rows */
+        if (&table->entry == list_head(&jv->tables))
+            jv->rows = table->rows;
+        else
+            jv->rows *= table->rows;
     }
 
-    r = jv->right->ops->get_dimensions( jv->right, &jv->right_rows, NULL );
-    if( r != ERROR_SUCCESS )
+    rows = jv->rows;
+    LIST_FOR_EACH_ENTRY(table, &jv->tables, JOINTABLE, entry)
     {
-        ERR("can't get right table dimensions\n");
-        goto end;
+        rows /= table->rows;
+        table->next_rows = rows;
     }
 
-end:
-    msi_free( ldata );
-    msi_free( rdata );
-
-    return r;
+    return ERROR_SUCCESS;
 }
 
 static UINT JOIN_close( struct tagMSIVIEW *view )
 {
     MSIJOINVIEW *jv = (MSIJOINVIEW*)view;
+    JOINTABLE *table;
 
     TRACE("%p\n", jv );
 
-    if( !jv->left || !jv->right )
-        return ERROR_FUNCTION_FAILED;
-
-    jv->left->ops->close( jv->left );
-    jv->right->ops->close( jv->right );
+    LIST_FOR_EACH_ENTRY(table, &jv->tables, JOINTABLE, entry)
+    {
+        table->view->ops->close(table->view);
+    }
 
     return ERROR_SUCCESS;
 }
@@ -192,16 +184,11 @@ static UINT JOIN_get_dimensions( struct 
 
     TRACE("%p %p %p\n", jv, rows, cols );
 
-    if( cols )
-        *cols = jv->left_count + jv->right_count;
+    if (cols)
+        *cols = jv->columns;
 
-    if( rows )
-    {
-        if( !jv->left || !jv->right )
-            return ERROR_FUNCTION_FAILED;
-
-        *rows = jv->left_rows * jv->right_rows;
-    }
+    if (rows)
+        *rows = jv->rows;
 
     return ERROR_SUCCESS;
 }
@@ -210,48 +197,46 @@ static UINT JOIN_get_column_info( struct
                 UINT n, LPWSTR *name, UINT *type )
 {
     MSIJOINVIEW *jv = (MSIJOINVIEW*)view;
+    JOINTABLE *table;
+    UINT cols = 0;
 
     TRACE("%p %d %p %p\n", jv, n, name, type );
 
-    if( !jv->left || !jv->right )
-        return ERROR_FUNCTION_FAILED;
-
-    if( (n==0) || (n>(jv->left_count + jv->right_count)) )
+    if (n == 0 || n > jv->columns)
         return ERROR_FUNCTION_FAILED;
 
-    if( n <= jv->left_count )
-        return jv->left->ops->get_column_info( jv->left, n, name, type );
+    LIST_FOR_EACH_ENTRY(table, &jv->tables, JOINTABLE, entry)
+    {
+        if (n <= cols + table->columns)
+            return table->view->ops->get_column_info(table->view, n - cols, name, type);
 
-    n = n - jv->left_count;
+        cols += table->columns;
+    }
 
-    return jv->right->ops->get_column_info( jv->right, n, name, type );
+    return ERROR_FUNCTION_FAILED;
 }
 
 static UINT JOIN_modify( struct tagMSIVIEW *view, MSIMODIFY eModifyMode,
                          MSIRECORD *rec, UINT row )
 {
-    MSIJOINVIEW *jv = (MSIJOINVIEW*)view;
-
-    TRACE("%p %d %p\n", jv, eModifyMode, rec );
-
+    TRACE("%p %d %p\n", view, eModifyMode, rec);
     return ERROR_FUNCTION_FAILED;
 }
 
 static UINT JOIN_delete( struct tagMSIVIEW *view )
 {
     MSIJOINVIEW *jv = (MSIJOINVIEW*)view;
+    JOINTABLE *table, *next;
 
     TRACE("%p\n", jv );
 
-    if( jv->left )
-        jv->left->ops->delete( jv->left );
-    jv->left = NULL;
-
-    if( jv->right )
-        jv->right->ops->delete( jv->right );
-    jv->right = NULL;
+    LIST_FOR_EACH_ENTRY_SAFE(table, next, &jv->tables, JOINTABLE, entry)
+    {
+        table->view->ops->delete(table->view);
+        msi_free(table); /* node was msi_alloc'ed in JOIN_CreateView */
+    }
 
-    msi_free( jv );
+    msi_free(jv);
 
     return ERROR_SUCCESS;
 }
@@ -260,10 +245,27 @@ static UINT JOIN_find_matching_rows( str
     UINT val, UINT *row, MSIITERHANDLE *handle )
 {
     MSIJOINVIEW *jv = (MSIJOINVIEW*)view;
+    UINT i, row_value;
 
-    FIXME("%p, %d, %u, %p\n", jv, col, val, *handle);
+    TRACE("%p, %d, %u, %p\n", view, col, val, *handle);
 
-    return ERROR_FUNCTION_FAILED;
+    if (col == 0 || col > jv->columns)
+        return ERROR_INVALID_PARAMETER;
+
+    for (i = (UINT)*handle; i < jv->rows; i++)
+    {
+        if (view->ops->fetch_int( view, i, col, &row_value ) != ERROR_SUCCESS)
+            continue;
+
+        if (row_value == val)
+        {
+            *row = i;
+            (*(UINT *)handle) = i + 1;
+            return ERROR_SUCCESS;
+        }
+    }
+
+    return ERROR_NO_MORE_ITEMS;
 }
 
 static const MSIVIEWOPS join_ops =
@@ -287,13 +289,14 @@ static const MSIVIEWOPS join_ops =
     NULL,
 };
 
-UINT JOIN_CreateView( MSIDATABASE *db, MSIVIEW **view,
-                      LPCWSTR left, LPCWSTR right )
+UINT JOIN_CreateView( MSIDATABASE *db, MSIVIEW **view, LPWSTR tables )
 {
     MSIJOINVIEW *jv = NULL;
     UINT r = ERROR_SUCCESS;
+    JOINTABLE *table;
+    LPWSTR ptr;
 
-    TRACE("%p (%s,%s)\n", jv, debugstr_w(left), debugstr_w(right) );
+    TRACE("%p (%s)\n", jv, debugstr_w(tables) );
 
     jv = msi_alloc_zero( sizeof *jv );
     if( !jv )
@@ -302,35 +305,42 @@ UINT JOIN_CreateView( MSIDATABASE *db, M
     /* fill the structure */
     jv->view.ops = &join_ops;
     jv->db = db;
+    jv->columns = 0;
+    jv->rows = 0;
 
-    /* create the tables to join */
-    r = TABLE_CreateView( db, left, &jv->left );
-    if( r != ERROR_SUCCESS )
-    {
-        ERR("can't create left table\n");
-        goto end;
-    }
+    list_init(&jv->tables);
 
-    r = TABLE_CreateView( db, right, &jv->right );
-    if( r != ERROR_SUCCESS )
+    while (*tables)
     {
-        ERR("can't create right table\n");
-        goto end;
-    }
+        if ((ptr = strchrW(tables, ' ')))
+            *ptr = '\0';
 
-    /* get the number of columns in each table */
-    r = jv->left->ops->get_dimensions( jv->left, NULL, &jv->left_count );
-    if( r != ERROR_SUCCESS )
-    {
-        ERR("can't get left table dimensions\n");
-        goto end;
-    }
+        table = msi_alloc(sizeof(JOINTABLE));
+        if (!table)
+            return ERROR_OUTOFMEMORY;
 
-    r = jv->right->ops->get_dimensions( jv->right, NULL, &jv->right_count );
-    if( r != ERROR_SUCCESS )
-    {
-        ERR("can't get right table dimensions\n");
-        goto end;
+        r = TABLE_CreateView( db, tables, &table->view );
+        if( r != ERROR_SUCCESS )
+        {
+            ERR("can't create table\n");
+            goto end;
+        }
+
+        r = table->view->ops->get_dimensions( table->view, NULL, &table->columns );
+        if( r != ERROR_SUCCESS )
+        {
+            ERR("can't get table dimensions\n");
+            goto end;
+        }
+
+        jv->columns += table->columns;
+
+        list_add_head( &jv->tables, &table->entry );
+
+        if (!ptr)
+            break;
+
+        tables = ptr + 1;
     }
 
     *view = &jv->view;
diff --git a/dlls/msi/query.h b/dlls/msi/query.h
index 268989d..4cf47c7 100644
--- a/dlls/msi/query.h
+++ b/dlls/msi/query.h
@@ -119,8 +119,7 @@ UINT UPDATE_CreateView( MSIDATABASE *db,
 
 UINT DELETE_CreateView( MSIDATABASE *db, MSIVIEW **view, MSIVIEW *table );
 
-UINT JOIN_CreateView( MSIDATABASE *db, MSIVIEW **view,
-                      LPCWSTR left, LPCWSTR right );
+UINT JOIN_CreateView( MSIDATABASE *db, MSIVIEW **view, LPWSTR tables );
 
 UINT ALTER_CreateView( MSIDATABASE *db, MSIVIEW **view, LPCWSTR name, column_info *colinfo, int hold );
 
diff --git a/dlls/msi/sql.y b/dlls/msi/sql.y
index f5e1cf9..59a67b5 100644
--- a/dlls/msi/sql.y
+++ b/dlls/msi/sql.y
@@ -53,6 +53,7 @@ static LPWSTR SQL_getstring( void *info,
 static INT SQL_getint( void *info );
 static int sql_lex( void *SQL_lval, SQL_input *info );
 
+static LPWSTR parser_add_table( LPWSTR list, LPWSTR table );
 static void *parser_alloc( void *info, unsigned int sz );
 static column_info *parser_alloc_column( void *info, LPCWSTR table, LPCWSTR column );
 
@@ -101,7 +102,7 @@ static struct expr * EXPR_wildcard( void
 %nonassoc END_OF_FILE ILLEGAL SPACE UNCLOSED_STRING COMMENT FUNCTION
           COLUMN AGG_FUNCTION.
 
-%type <string> table id
+%type <string> table tablelist id
 %type <column_list> selcollist column column_and_type column_def table_def
 %type <column_list> column_assignment update_assign_list constlist
 %type <query> query from fromtable selectfrom unorderedsel
@@ -466,18 +467,32 @@ fromtable:
             if( r != ERROR_SUCCESS || !$$ )
                 YYABORT;
         }
-  | TK_FROM table TK_COMMA table
+  | TK_FROM tablelist
         {
             SQL_input* sql = (SQL_input*) info;
             UINT r;
 
-            /* only support inner joins on two tables */
-            r = JOIN_CreateView( sql->db, &$$, $2, $4 );
+            r = JOIN_CreateView( sql->db, &$$, $2 );
+            msi_free( $2 );
             if( r != ERROR_SUCCESS )
                 YYABORT;
         }
     ;
 
+tablelist:
+    table
+        {
+            $$ = strdupW($1);
+        }
+  |
+    table TK_COMMA tablelist
+        {
+            $$ = parser_add_table($3, $1);
+            if (!$$)
+                YYABORT;
+        }
+    ;
+
 expr:
     TK_LP expr TK_RP
         {
@@ -663,6 +678,19 @@ number:
 
 %%
 
+static LPWSTR parser_add_table(LPWSTR list, LPWSTR table)
+{
+    DWORD size = lstrlenW(list) + lstrlenW(table) + 2;
+    static const WCHAR space[] = {' ',0};
+
+    list = msi_realloc(list, size * sizeof(WCHAR));
+    if (!list) return NULL;
+
+    lstrcatW(list, space);
+    lstrcatW(list, table);
+    return list;
+}
+
 static void *parser_alloc( void *info, unsigned int sz )
 {
     SQL_input* sql = (SQL_input*) info;
diff --git a/dlls/msi/tests/db.c b/dlls/msi/tests/db.c
index a857dbf..293cf0c 100644
--- a/dlls/msi/tests/db.c
+++ b/dlls/msi/tests/db.c
@@ -2890,10 +2890,7 @@ static void test_join(void)
     ok( r == ERROR_SUCCESS, "failed to open view: %d\n", r );
 
     r = MsiViewExecute(hview, 0);
-    todo_wine
-    {
-        ok( r == ERROR_SUCCESS, "failed to execute view: %d\n", r );
-    }
+    ok( r == ERROR_SUCCESS, "failed to execute view: %d\n", r );
 
     i = 0;
     data_correct = TRUE;
@@ -2919,10 +2916,7 @@ static void test_join(void)
     }
     ok( data_correct, "data returned in the wrong order\n");
 
-    todo_wine
-    {
-        ok( i == 6, "Expected 6 rows, got %d\n", i );
-    }
+    ok( i == 6, "Expected 6 rows, got %d\n", i );
     ok( r == ERROR_NO_MORE_ITEMS, "expected no more items: %d\n", r );
 
     MsiViewClose(hview);
@@ -3000,7 +2994,7 @@ static void test_join(void)
         MsiCloseHandle(hrec);
     }
 
-    todo_wine ok( data_correct, "data returned in the wrong order\n");
+    ok( data_correct, "data returned in the wrong order\n");
     ok( i == 6, "Expected 6 rows, got %d\n", i );
     ok( r == ERROR_NO_MORE_ITEMS, "expected no more items: %d\n", r );
 
@@ -3048,7 +3042,7 @@ static void test_join(void)
         i++;
         MsiCloseHandle(hrec);
     }
-    todo_wine ok( data_correct, "data returned in the wrong order\n");
+    ok( data_correct, "data returned in the wrong order\n");
 
     ok( i == 6, "Expected 6 rows, got %d\n", i );
     ok( r == ERROR_NO_MORE_ITEMS, "expected no more items: %d\n", r );
@@ -3058,10 +3052,10 @@ static void test_join(void)
 
     query = "SELECT * FROM `One`, `Two`, `Three` ";
     r = MsiDatabaseOpenView(hdb, query, &hview);
-    todo_wine ok( r == ERROR_SUCCESS, "failed to open view: %d\n", r );
+    ok( r == ERROR_SUCCESS, "failed to open view: %d\n", r );
 
     r = MsiViewExecute(hview, 0);
-    todo_wine ok( r == ERROR_SUCCESS, "failed to execute view: %d\n", r );
+    ok( r == ERROR_SUCCESS, "failed to execute view: %d\n", r );
 
     i = 0;
     data_correct = TRUE;
@@ -3099,11 +3093,8 @@ static void test_join(void)
     }
     ok( data_correct, "data returned in the wrong order\n");
 
-    todo_wine
-    {
-        ok( i == 6, "Expected 6 rows, got %d\n", i );
-        ok( r == ERROR_NO_MORE_ITEMS, "expected no more items: %d\n", r );
-    }
+    ok( i == 6, "Expected 6 rows, got %d\n", i );
+    ok( r == ERROR_NO_MORE_ITEMS, "expected no more items: %d\n", r );
 
     MsiViewClose(hview);
     MsiCloseHandle(hview);
-- 
1.4.1


More information about the wine-patches mailing list