Bernhard Loos : msi: Protected primary keys against modification.

Alexandre Julliard julliard at winehq.org
Fri Aug 26 10:40:54 CDT 2011


Module: wine
Branch: master
Commit: abd1174941b7e074d3cd14c3a121febd6eed96b7
URL:    http://source.winehq.org/git/wine.git/?a=commit;h=abd1174941b7e074d3cd14c3a121febd6eed96b7

Author: Bernhard Loos <bernhardloos at googlemail.com>
Date:   Fri Aug 26 04:53:39 2011 +0200

msi: Protected primary keys against modification.

---

 dlls/msi/msipriv.h |    1 +
 dlls/msi/record.c  |   50 ++++++++++++++++++++++++++-------------------
 dlls/msi/where.c   |   56 +++++++++++++++++++++++++++++++++++++++++++++------
 3 files changed, 79 insertions(+), 28 deletions(-)

diff --git a/dlls/msi/msipriv.h b/dlls/msi/msipriv.h
index 2524752..998970b 100644
--- a/dlls/msi/msipriv.h
+++ b/dlls/msi/msipriv.h
@@ -809,6 +809,7 @@ extern UINT MSI_RecordSetStreamFromFileW( MSIRECORD *, UINT, LPCWSTR ) DECLSPEC_
 extern UINT MSI_RecordCopyField( MSIRECORD *, UINT, MSIRECORD *, UINT ) DECLSPEC_HIDDEN;
 extern MSIRECORD *MSI_CloneRecord( MSIRECORD * ) DECLSPEC_HIDDEN;
 extern BOOL MSI_RecordsAreEqual( MSIRECORD *, MSIRECORD * ) DECLSPEC_HIDDEN;
+extern BOOL MSI_RecordsAreFieldsEqual(MSIRECORD *a, MSIRECORD *b, UINT field) DECLSPEC_HIDDEN;
 
 /* stream internals */
 extern void enum_stream_names( IStorage *stg ) DECLSPEC_HIDDEN;
diff --git a/dlls/msi/record.c b/dlls/msi/record.c
index 0e4fb8a..7acbfc7 100644
--- a/dlls/msi/record.c
+++ b/dlls/msi/record.c
@@ -994,6 +994,34 @@ MSIRECORD *MSI_CloneRecord(MSIRECORD *rec)
     return clone;
 }
 
+BOOL MSI_RecordsAreFieldsEqual(MSIRECORD *a, MSIRECORD *b, UINT field)
+{
+    if (a->fields[field].type != b->fields[field].type)
+        return FALSE;
+
+    switch (a->fields[field].type)
+    {
+        case MSIFIELD_NULL:
+            break;
+
+        case MSIFIELD_INT:
+            if (a->fields[field].u.iVal != b->fields[field].u.iVal)
+                return FALSE;
+            break;
+
+        case MSIFIELD_WSTR:
+            if (strcmpW(a->fields[field].u.szwVal, b->fields[field].u.szwVal))
+                return FALSE;
+            break;
+
+        case MSIFIELD_STREAM:
+        default:
+            return FALSE;
+    }
+    return TRUE;
+}
+
+
 BOOL MSI_RecordsAreEqual(MSIRECORD *a, MSIRECORD *b)
 {
     UINT i;
@@ -1003,28 +1031,8 @@ BOOL MSI_RecordsAreEqual(MSIRECORD *a, MSIRECORD *b)
 
     for (i = 0; i <= a->count; i++)
     {
-        if (a->fields[i].type != b->fields[i].type)
+        if (!MSI_RecordsAreFieldsEqual( a, b, i ))
             return FALSE;
-
-        switch (a->fields[i].type)
-        {
-            case MSIFIELD_NULL:
-                break;
-
-            case MSIFIELD_INT:
-                if (a->fields[i].u.iVal != b->fields[i].u.iVal)
-                    return FALSE;
-                break;
-
-            case MSIFIELD_WSTR:
-                if (strcmpW(a->fields[i].u.szwVal, b->fields[i].u.szwVal))
-                    return FALSE;
-                break;
-
-            case MSIFIELD_STREAM:
-            default:
-                return FALSE;
-        }
     }
 
     return TRUE;
diff --git a/dlls/msi/where.c b/dlls/msi/where.c
index d8ac5e8..0742244 100644
--- a/dlls/msi/where.c
+++ b/dlls/msi/where.c
@@ -262,9 +262,10 @@ static UINT WHERE_get_row( struct tagMSIVIEW *view, UINT row, MSIRECORD **rec )
 static UINT WHERE_set_row( struct tagMSIVIEW *view, UINT row, MSIRECORD *rec, UINT mask )
 {
     MSIWHEREVIEW *wv = (MSIWHEREVIEW*)view;
-    UINT r, offset = 0, reduced_mask = 0;
+    UINT i, r, offset = 0;
     JOINTABLE *table = wv->tables;
     UINT *rows;
+    UINT mask_copy = mask;
 
     TRACE("%p %d %p %08x\n", wv, row, rec, mask );
 
@@ -275,28 +276,54 @@ static UINT WHERE_set_row( struct tagMSIVIEW *view, UINT row, MSIRECORD *rec, UI
     if (r != ERROR_SUCCESS)
         return r;
 
-    if(mask >= 1 << wv->col_count)
+    if (mask >= 1 << wv->col_count)
         return ERROR_INVALID_PARAMETER;
 
     do
     {
+        for (i = 0; i < table->col_count; i++) {
+            UINT type;
+
+            if (!(mask_copy & (1 << i)))
+                continue;
+            r = table->view->ops->get_column_info(table->view, i + 1, NULL,
+                                            &type, NULL, NULL );
+            if (r != ERROR_SUCCESS)
+                return r;
+            if (type & MSITYPE_KEY)
+                return ERROR_FUNCTION_FAILED;
+        }
+        mask_copy >>= table->col_count;
+    }
+    while (mask_copy && (table = table->next));
+
+    table = wv->tables;
+
+    do
+    {
         const UINT col_count = table->col_count;
         UINT i;
         MSIRECORD *reduced;
+        UINT reduced_mask = (mask >> offset) & ((1 << col_count) - 1);
+
+        if (!reduced_mask)
+        {
+            offset += col_count;
+            continue;
+        }
 
         reduced = MSI_CreateRecord(col_count);
         if (!reduced)
             return ERROR_FUNCTION_FAILED;
 
-        for (i = 0; i < col_count; i++)
+        for (i = 1; i <= col_count; i++)
         {
-            r = MSI_RecordCopyField(rec, i + offset + 1, reduced, i + 1);
+            r = MSI_RecordCopyField(rec, i + offset, reduced, i);
             if (r != ERROR_SUCCESS)
                 break;
         }
 
         offset += col_count;
-        reduced_mask = mask >> (wv->col_count - offset) & ((1 << col_count) - 1);
 
         if (r == ERROR_SUCCESS)
             r = table->view->ops->set_row(table->view, rows[table->table_index], reduced, reduced_mask);
@@ -644,13 +671,28 @@ static UINT join_find_row( MSIWHEREVIEW *wv, MSIRECORD *rec, UINT *row )
 static UINT join_modify_update( struct tagMSIVIEW *view, MSIRECORD *rec )
 {
     MSIWHEREVIEW *wv = (MSIWHEREVIEW *)view;
-    UINT r, row;
+    UINT r, row, i, mask = 0;
+    MSIRECORD *current;
+
 
     r = join_find_row( wv, rec, &row );
     if (r != ERROR_SUCCESS)
         return r;
 
-    return WHERE_set_row( view, row, rec, (1 << wv->col_count) - 1 );
+    r = msi_view_get_row( wv->db, view, row, &current );
+    if (r != ERROR_SUCCESS)
+        return r;
+
+    assert(MSI_RecordGetFieldCount(rec) == MSI_RecordGetFieldCount(current));
+
+    for (i = MSI_RecordGetFieldCount(rec); i > 0; i--)
+    {
+        if (!MSI_RecordsAreFieldsEqual(rec, current, i))
+            mask |= 1 << (i - 1);
+    }
+     msiobj_release(&current->hdr);
+
+    return WHERE_set_row( view, row, rec, mask );
 }
 
 static UINT WHERE_modify( struct tagMSIVIEW *view, MSIMODIFY eModifyMode,




More information about the wine-cvs mailing list