ole32: Marshal the ORPCTHAT structure prefixed to the server data.

Robert Shearman rob at codeweavers.com
Wed Dec 27 07:04:07 CST 2006


Unmarshal the data on the client side (in
ClientRpcChannelBuffer_SendReceive) and call the channel hooks' ClientNotify methods.
---
  dlls/ole32/compobj.c |    1
  dlls/ole32/rpc.c     |  368 +++++++++++++++++++++++++++++++++++++++++++-------
  2 files changed, 317 insertions(+), 52 deletions(-)

Thanks to Dmitry for spotting a typo in the previous patch.
-------------- next part --------------
diff --git a/dlls/ole32/compobj.c b/dlls/ole32/compobj.c
index ca51a9a..f5e2789 100644
--- a/dlls/ole32/compobj.c
+++ b/dlls/ole32/compobj.c
@@ -36,7 +36,6 @@
  *
  *   - Make all ole interface marshaling use NDR to be wire compatible with
  *     native DCOM
- *   - Use & interpret ORPCTHIS & ORPCTHAT.
  *
  */
 
diff --git a/dlls/ole32/rpc.c b/dlls/ole32/rpc.c
index c31aa4a..83cf9a5 100644
--- a/dlls/ole32/rpc.c
+++ b/dlls/ole32/rpc.c
@@ -154,6 +154,9 @@ struct channel_hook_buffer_data
 };
 
 
+static HRESULT unmarshal_ORPCTHAT(RPC_MESSAGE *msg, ORPCTHAT *orpcthat,
+                                  ORPC_EXTENT_ARRAY *orpc_ext_array, WIRE_ORPC_EXTENT **first_wire_orpc_extent); /* forward declaration; defined after unmarshal_ORPCTHIS */
+
 /* Channel Hook Functions */
 
 static ULONG ChannelHooks_ClientGetSize(SChannelHookCallInfo *info,
@@ -172,7 +175,7 @@ static ULONG ChannelHooks_ClientGetSize(
     LIST_FOR_EACH_ENTRY(entry, &channel_hooks, struct channel_hook_entry, entry)
         (*hook_count)++;
 
-    if (hook_count)
+    if (*hook_count)
         *data = HeapAlloc(GetProcessHeap(), 0, *hook_count * sizeof(struct channel_hook_buffer_data));
     else
         *data = NULL;
@@ -276,6 +279,155 @@ static void ChannelHooks_ServerNotify(SC
     LeaveCriticalSection(&csChannelHook);
 }
 
+/* Queries each registered channel hook for the size of the extension data
+ * it wants to send back to the client (IChannelHook_ServerGetSize).
+ *
+ * Returns the number of bytes needed for all non-empty extensions, each as a
+ * WIRE_ORPC_EXTENT header followed by its data padded to a multiple of 8
+ * bytes; *extension_count receives the number of such extensions.
+ * *data receives a HeapAlloc'ed array with one (id, padded size) entry per
+ * hook and *hook_count its length.  The array is only returned when the
+ * result is non-zero (ChannelHooks_ServerFillBuffer then frees it);
+ * otherwise *data is NULL. */
+static ULONG ChannelHooks_ServerGetSize(SChannelHookCallInfo *info,
+                                        struct channel_hook_buffer_data **data, unsigned int *hook_count,
+                                        ULONG *extension_count)
+{
+    struct channel_hook_entry *entry;
+    ULONG total_size = 0;
+    unsigned int hook_index = 0;
+
+    *hook_count = 0;
+    *extension_count = 0;
+    *data = NULL;
+
+    EnterCriticalSection(&csChannelHook);
+
+    LIST_FOR_EACH_ENTRY(entry, &channel_hooks, struct channel_hook_entry, entry)
+        (*hook_count)++;
+
+    if (*hook_count)
+    {
+        *data = HeapAlloc(GetProcessHeap(), 0, *hook_count * sizeof(struct channel_hook_buffer_data));
+        if (!*data)
+        {
+            /* out of memory: behave as if no hook has data to send */
+            *hook_count = 0;
+            LeaveCriticalSection(&csChannelHook);
+            return 0;
+        }
+    }
+
+    LIST_FOR_EACH_ENTRY(entry, &channel_hooks, struct channel_hook_entry, entry)
+    {
+        ULONG extension_size = 0;
+
+        IChannelHook_ServerGetSize(entry->hook, &entry->id, &info->iid, S_OK,
+                                   &extension_size);
+
+        TRACE("%s: extension_size = %u\n", debugstr_guid(&entry->id), extension_size);
+
+        /* extension data is padded to a multiple of 8 bytes on the wire */
+        extension_size = (extension_size+7)&~7;
+        (*data)[hook_index].id = entry->id;
+        (*data)[hook_index].extension_size = extension_size;
+
+        /* an extension is only put onto the wire if it has data to write */
+        if (extension_size)
+        {
+            total_size += FIELD_OFFSET(WIRE_ORPC_EXTENT, data[extension_size]);
+            (*extension_count)++;
+        }
+
+        hook_index++;
+    }
+
+    LeaveCriticalSection(&csChannelHook);
+
+    /* ChannelHooks_ServerFillBuffer, which frees the array, is only called
+     * when there is something to write, so free it here otherwise */
+    if (!total_size && *data)
+    {
+        HeapFree(GetProcessHeap(), 0, *data);
+        *data = NULL;
+    }
+
+    return total_size;
+}
+/* Writes each hook's extension (via IChannelHook_ServerFillBuffer) into buffer as a WIRE_ORPC_EXTENT; returns the end of the written data and frees data. */
+static unsigned char * ChannelHooks_ServerFillBuffer(SChannelHookCallInfo *info,
+                                                     unsigned char *buffer, struct channel_hook_buffer_data *data,
+                                                     unsigned int hook_count)
+{
+    struct channel_hook_entry *entry;
+
+    EnterCriticalSection(&csChannelHook);
+
+    LIST_FOR_EACH_ENTRY(entry, &channel_hooks, struct channel_hook_entry, entry)
+    {
+        unsigned int i;
+        ULONG extension_size = 0;
+        WIRE_ORPC_EXTENT *wire_orpc_extent = (WIRE_ORPC_EXTENT *)buffer;
+
+        for (i = 0; i < hook_count; i++) /* find the padded size reserved for this hook; stays 0 if none */
+            if (IsEqualGUID(&entry->id, &data[i].id))
+                extension_size = data[i].extension_size;
+
+        /* an extension is only put onto the wire if it has data to write */
+        if (!extension_size)
+            continue;
+
+        IChannelHook_ServerFillBuffer(entry->hook, &entry->id, &info->iid,
+                                      &extension_size, buffer + FIELD_OFFSET(WIRE_ORPC_EXTENT, data[0]),
+                                      S_OK);
+
+        TRACE("%s: extension_size = %u\n", debugstr_guid(&entry->id), extension_size);
+
+        /* FIXME: set unused (padding) portion of wire_orpc_extent->data to 0? NOTE(review): a returned size larger than the reserved one is not checked */
+
+        wire_orpc_extent->conformance = (extension_size+7)&~7;
+        wire_orpc_extent->size = extension_size;
+        memcpy(&wire_orpc_extent->id, &entry->id, sizeof(wire_orpc_extent->id));
+        buffer += FIELD_OFFSET(WIRE_ORPC_EXTENT, data[wire_orpc_extent->conformance]);
+    }
+
+    LeaveCriticalSection(&csChannelHook);
+
+    HeapFree(GetProcessHeap(), 0, data); /* array allocated by ChannelHooks_ServerGetSize */
+
+    return buffer;
+}
+/* Calls IChannelHook_ClientNotify on every registered hook, passing the extension whose id matches the hook's, or no data if there is none. */
+static void ChannelHooks_ClientNotify(SChannelHookCallInfo *info,
+                                      DWORD lDataRep, WIRE_ORPC_EXTENT *first_wire_orpc_extent,
+                                      ULONG extension_count, HRESULT hrFault)
+{
+    struct channel_hook_entry *entry;
+    ULONG i;
+    /* NOTE(review): walks extension_count extents via their conformance, so they must already be bounds-checked */
+    EnterCriticalSection(&csChannelHook);
+
+    LIST_FOR_EACH_ENTRY(entry, &channel_hooks, struct channel_hook_entry, entry)
+    {
+        WIRE_ORPC_EXTENT *wire_orpc_extent; /* this hook's extension, matched by GUID */
+        for (i = 0, wire_orpc_extent = first_wire_orpc_extent;
+             i < extension_count;
+             i++, wire_orpc_extent = (WIRE_ORPC_EXTENT *)&wire_orpc_extent->data[wire_orpc_extent->conformance])
+        {
+            if (IsEqualGUID(&entry->id, &wire_orpc_extent->id))
+                break;
+        }
+        if (i == extension_count) wire_orpc_extent = NULL; /* no extension for this hook */
+
+        IChannelHook_ClientNotify(entry->hook, &entry->id, &info->iid,
+                                  wire_orpc_extent ? wire_orpc_extent->size : 0,
+                                  wire_orpc_extent ? wire_orpc_extent->data : NULL,
+                                  lDataRep, hrFault);
+    }
+
+    LeaveCriticalSection(&csChannelHook);
+}
+
 HRESULT RPC_RegisterChannelHook(REFGUID rguid, IChannelHook *hook)
 {
     struct channel_hook_entry *entry;
@@ -361,7 +486,12 @@ static HRESULT WINAPI ServerRpcChannelBu
     RpcChannelBuffer *This = (RpcChannelBuffer *)iface;
     RPC_MESSAGE *msg = (RPC_MESSAGE *)olemsg;
     RPC_STATUS status;
+    ORPCTHAT *orpcthat;
     struct message_state *message_state;
+    ULONG extensions_size;
+    struct channel_hook_buffer_data *channel_hook_data;
+    unsigned int channel_hook_count;
+    ULONG extension_count;
 
     TRACE("(%p)->(%p,%s)\n", This, olemsg, debugstr_guid(riid));
 
@@ -370,11 +500,62 @@ static HRESULT WINAPI ServerRpcChannelBu
     msg->Handle = message_state->binding_handle;
     msg->Buffer = (char *)msg->Buffer - message_state->prefix_data_len;
 
+    extensions_size = ChannelHooks_ServerGetSize(&message_state->channel_hook_info,
+                                                 &channel_hook_data, &channel_hook_count, &extension_count);
+    
+    msg->BufferLength += FIELD_OFFSET(ORPCTHAT, extensions) + 4;
+    if (extensions_size)
+    {
+        msg->BufferLength += FIELD_OFFSET(ORPC_EXTENT_ARRAY, extent) + 2*sizeof(DWORD) + extensions_size;
+        if (extension_count & 1)
+            msg->BufferLength += FIELD_OFFSET(WIRE_ORPC_EXTENT, data[0]);
+    }
+    
     status = I_RpcGetBuffer(msg);
 
+    orpcthat = (ORPCTHAT *)msg->Buffer;
+    msg->Buffer = (char *)msg->Buffer + FIELD_OFFSET(ORPCTHAT, extensions);
+
+    orpcthat->flags = ORPCF_NULL /* FIXME? */;
+
+    /* NDR representation of orpcthat->extensions */
+    *(DWORD *)msg->Buffer = extensions_size ? 1 : 0;
+    msg->Buffer = (char *)msg->Buffer + sizeof(DWORD);
+
+    if (extensions_size)
+    {
+        ORPC_EXTENT_ARRAY *orpc_extent_array = msg->Buffer;
+        orpc_extent_array->size = extension_count;
+        orpc_extent_array->reserved = 0;
+        msg->Buffer = (char *)msg->Buffer + FIELD_OFFSET(ORPC_EXTENT_ARRAY, extent);
+        /* NDR representation of orpc_extent_array->extent */
+        *(DWORD *)msg->Buffer = 1;
+        msg->Buffer = (char *)msg->Buffer + sizeof(DWORD);
+        /* NDR representation of [size_is] attribute of orpc_extent_array->extent */
+        *(DWORD *)msg->Buffer = (extension_count + 1) & ~1;
+        msg->Buffer = (char *)msg->Buffer + sizeof(DWORD);
+
+        msg->Buffer = ChannelHooks_ServerFillBuffer(&message_state->channel_hook_info,
+                                                    msg->Buffer, channel_hook_data, channel_hook_count);
+
+        /* we must add a dummy extension if there is an odd extension
+         * count to meet the contract specified by the size_is attribute */
+        if (extension_count & 1)
+        {
+            WIRE_ORPC_EXTENT *wire_orpc_extent = msg->Buffer;
+            wire_orpc_extent->conformance = 0;
+            memcpy(&wire_orpc_extent->id, &GUID_NULL, sizeof(wire_orpc_extent->id));
+            wire_orpc_extent->size = 0;
+            msg->Buffer = (char *)msg->Buffer + FIELD_OFFSET(WIRE_ORPC_EXTENT, data[0]);
+        }
+    }
+
+    /* store the prefixed data length so that we can restore the real buffer
+     * later */
+    message_state->prefix_data_len = (char *)msg->Buffer - (char *)orpcthat;
+    msg->BufferLength -= message_state->prefix_data_len;
     /* save away the message state again */
     msg->Handle = message_state;
-    message_state->prefix_data_len = 0;
 
     TRACE("-- %ld\n", status);
 
@@ -556,6 +737,9 @@ static HRESULT WINAPI ClientRpcChannelBu
     APARTMENT *apt = NULL;
     IPID ipid;
     struct message_state *message_state;
+    ORPCTHAT orpcthat;
+    ORPC_EXTENT_ARRAY orpc_ext_array;
+    WIRE_ORPC_EXTENT *first_wire_orpc_extent = NULL;
 
     TRACE("(%p) iMethod=%d\n", olemsg, olemsg->iMethod);
 
@@ -652,18 +836,52 @@ static HRESULT WINAPI ClientRpcChannelBu
     }
     ClientRpcChannelBuffer_ReleaseEventHandle(This, params->handle);
 
-    /* save away the message state again */
-    msg->Handle = message_state;
-    message_state->prefix_data_len = 0;
-
     if (hr == S_OK) hr = params->hr;
 
     status = params->status;
     HeapFree(GetProcessHeap(), 0, params);
     params = NULL;
 
-    if (hr) return hr;
+    orpcthat.flags = ORPCF_NULL;
+    orpcthat.extensions = NULL;
+
+    /* don't parse the buffer as a reply ORPCTHAT unless the call succeeded */
+    if (hr == S_OK && status == RPC_S_OK && msg->BufferLength > FIELD_OFFSET(ORPCTHAT, extensions) + 4)
+    {
+        HRESULT hr2;
+        char *original_buffer = msg->Buffer;
+
+        /* handle ORPCTHAT and client extensions */
+
+        hr2 = unmarshal_ORPCTHAT(msg, &orpcthat, &orpc_ext_array, &first_wire_orpc_extent);
+        if (FAILED(hr2))
+        {
+            hr = hr2;
+            /* the extents may be only partly validated, so don't pass
+             * any of them to the channel hooks */
+            orpcthat.extensions = NULL;
+            first_wire_orpc_extent = NULL;
+        }
+
+        /* record how far unmarshal_ORPCTHAT advanced msg->Buffer past the
+         * prefix, mirroring the server side's use of prefix_data_len */
+        message_state->prefix_data_len = (char *)msg->Buffer - original_buffer;
+        msg->BufferLength -= message_state->prefix_data_len;
+    }
+    else
+        message_state->prefix_data_len = 0;
+
+    ChannelHooks_ClientNotify(&message_state->channel_hook_info,
+                              msg->DataRepresentation,
+                              first_wire_orpc_extent,
+                              orpcthat.extensions && first_wire_orpc_extent ? orpcthat.extensions->size : 0,
+                              status == RPC_S_CALL_FAILED && msg->BufferLength >= sizeof(HRESULT) ? *(HRESULT *)msg->Buffer : S_OK);
     
+    /* save away the message state again */
+    msg->Handle = message_state;
+
+    if (hr) return hr;
+
     if (pstatus) *pstatus = status;
 
     TRACE("RPC call status: 0x%lx\n", status);
@@ -856,6 +1065,71 @@ HRESULT RPC_CreateServerChannel(IRpcChan
     return S_OK;
 }
 
+/* unmarshals ORPC_EXTENT_ARRAY according to NDR rules, but doesn't allocate
+ * any memory.
+ *
+ * On entry msg->Buffer points at the array's fixed fields and end is the
+ * first byte past the received data.  extensions receives the fixed fields
+ * and has extent set to NULL; if the extent pointer id is non-zero, every
+ * wire extent is bounds-checked against end, *first_wire_orpc_extent is set
+ * to the first one and msg->Buffer is left just past the last one.  Returns
+ * S_OK, or RPC_E_INVALID_HEADER / RPC_S_INVALID_BOUND for malformed data. */
+static HRESULT unmarshal_ORPC_EXTENT_ARRAY(RPC_MESSAGE *msg, const char *end,
+                                           ORPC_EXTENT_ARRAY *extensions,
+                                           WIRE_ORPC_EXTENT **first_wire_orpc_extent)
+{
+    DWORD pointer_id;
+    DWORD i;
+
+    /* the fixed fields, the extent pointer id and its conformance must all
+     * be present before anything is read from the buffer */
+    if ((const char *)msg->Buffer + FIELD_OFFSET(ORPC_EXTENT_ARRAY, extent) + 2 * sizeof(DWORD) > end)
+        return RPC_E_INVALID_HEADER;
+
+    memcpy(extensions, msg->Buffer, FIELD_OFFSET(ORPC_EXTENT_ARRAY, extent));
+    msg->Buffer = (char *)msg->Buffer + FIELD_OFFSET(ORPC_EXTENT_ARRAY, extent);
+
+    pointer_id = *(DWORD *)msg->Buffer;
+    msg->Buffer = (char *)msg->Buffer + sizeof(DWORD);
+    extensions->extent = NULL;
+
+    if (pointer_id)
+    {
+        WIRE_ORPC_EXTENT *wire_orpc_extent;
+
+        /* conformance must equal size rounded up to an even element count */
+        if (*(DWORD *)msg->Buffer != ((extensions->size+1)&~1))
+            return RPC_S_INVALID_BOUND;
+
+        msg->Buffer = (char *)msg->Buffer + sizeof(DWORD);
+
+        /* arbitrary limit for security (don't know what native does) */
+        if (extensions->size > 256)
+        {
+            ERR("too many extensions: %ld\n", extensions->size);
+            return RPC_S_INVALID_BOUND;
+        }
+
+        *first_wire_orpc_extent = wire_orpc_extent = (WIRE_ORPC_EXTENT *)msg->Buffer;
+        for (i = 0; i < ((extensions->size+1)&~1); i++)
+        {
+            /* the header, a consistent 8-byte-padded conformance and the
+             * data itself must all lie within the buffer */
+            if ((const char *)&wire_orpc_extent->data[0] > end)
+                return RPC_S_INVALID_BOUND;
+            if (wire_orpc_extent->conformance != ((wire_orpc_extent->size+7)&~7))
+                return RPC_S_INVALID_BOUND;
+            if ((const char *)&wire_orpc_extent->data[wire_orpc_extent->conformance] > end)
+                return RPC_S_INVALID_BOUND;
+            TRACE("size %u, guid %s\n", wire_orpc_extent->size, debugstr_guid(&wire_orpc_extent->id));
+            wire_orpc_extent = (WIRE_ORPC_EXTENT *)&wire_orpc_extent->data[wire_orpc_extent->conformance];
+        }
+        msg->Buffer = wire_orpc_extent;
+    }
+
+    return S_OK;
+}
+
 /* unmarshals ORPCTHIS according to NDR rules, but doesn't allocate any memory */
 static HRESULT unmarshal_ORPCTHIS(RPC_MESSAGE *msg, ORPCTHIS *orpcthis,
     ORPC_EXTENT_ARRAY *orpc_ext_array, WIRE_ORPC_EXTENT **first_wire_orpc_extent)
@@ -885,50 +1148,10 @@ static HRESULT unmarshal_ORPCTHIS(RPC_ME
 
     if (orpcthis->extensions)
     {
-        DWORD pointer_id;
-        DWORD i;
-
-        memcpy(orpcthis->extensions, msg->Buffer, FIELD_OFFSET(ORPC_EXTENT_ARRAY, extent));
-        msg->Buffer = (char *)msg->Buffer + FIELD_OFFSET(ORPC_EXTENT_ARRAY, extent);
-
-        if ((const char *)msg->Buffer + 2 * sizeof(DWORD) > end)
-            return RPC_E_INVALID_HEADER;
-
-        pointer_id = *(DWORD *)msg->Buffer;
-        msg->Buffer = (char *)msg->Buffer + sizeof(DWORD);
-        orpcthis->extensions->extent = NULL;
-
-        if (pointer_id)
-        {
-            WIRE_ORPC_EXTENT *wire_orpc_extent;
-
-            /* conformance */
-            if (*(DWORD *)msg->Buffer != ((orpcthis->extensions->size+1)&~1))
-                return RPC_S_INVALID_BOUND;
-
-            msg->Buffer = (char *)msg->Buffer + sizeof(DWORD);
-
-            /* arbritary limit for security (don't know what native does) */
-            if (orpcthis->extensions->size > 256)
-            {
-                ERR("too many extensions: %ld\n", orpcthis->extensions->size);
-                return RPC_S_INVALID_BOUND;
-            }
-
-            *first_wire_orpc_extent = wire_orpc_extent = (WIRE_ORPC_EXTENT *)msg->Buffer;
-            for (i = 0; i < ((orpcthis->extensions->size+1)&~1); i++)
-            {
-                if ((const char *)&wire_orpc_extent->data[0] > end)
-                    return RPC_S_INVALID_BOUND;
-                if (wire_orpc_extent->conformance != ((wire_orpc_extent->size+7)&~7))
-                    return RPC_S_INVALID_BOUND;
-                if ((const char *)&wire_orpc_extent->data[wire_orpc_extent->conformance] > end)
-                    return RPC_S_INVALID_BOUND;
-                TRACE("size %u, guid %s\n", wire_orpc_extent->size, debugstr_guid(&wire_orpc_extent->id));
-                wire_orpc_extent = (WIRE_ORPC_EXTENT *)&wire_orpc_extent->data[wire_orpc_extent->conformance];
-            }
-            msg->Buffer = wire_orpc_extent;
-        }
+        HRESULT hr = unmarshal_ORPC_EXTENT_ARRAY(msg, end, orpc_ext_array,
+                                                 first_wire_orpc_extent);
+        if (FAILED(hr))
+            return hr;
     }
 
     if ((orpcthis->version.MajorVersion != COM_MAJOR_VERSION) ||
@@ -948,6 +1171,49 @@ static HRESULT unmarshal_ORPCTHIS(RPC_ME
     return S_OK;
 }
 
+static HRESULT unmarshal_ORPCTHAT(RPC_MESSAGE *msg, ORPCTHAT *orpcthat,
+                                  ORPC_EXTENT_ARRAY *orpc_ext_array, WIRE_ORPC_EXTENT **first_wire_orpc_extent)
+{
+    const char *end = (char *)msg->Buffer + msg->BufferLength;
+    /* unmarshals ORPCTHAT according to NDR rules, but doesn't allocate any memory */
+    *first_wire_orpc_extent = NULL;
+    /* need at least the fixed fields plus the NDR pointer id of extensions */
+    if (msg->BufferLength < FIELD_OFFSET(ORPCTHAT, extensions) + 4)
+    {
+        ERR("invalid buffer length\n");
+        return RPC_E_INVALID_HEADER;
+    }
+    /* copy the fields that precede the extensions pointer */
+    memcpy(orpcthat, msg->Buffer, FIELD_OFFSET(ORPCTHAT, extensions));
+    msg->Buffer = (char *)msg->Buffer + FIELD_OFFSET(ORPCTHAT, extensions);
+    /* NDR pointer id of orpcthat->extensions: non-zero means the array follows */
+    if ((const char *)msg->Buffer + sizeof(DWORD) > end)
+        return RPC_E_INVALID_HEADER;
+
+    if (*(DWORD *)msg->Buffer)
+        orpcthat->extensions = orpc_ext_array;
+    else
+        orpcthat->extensions = NULL;
+
+    msg->Buffer = (char *)msg->Buffer + sizeof(DWORD);
+
+    if (orpcthat->extensions)
+    {
+        HRESULT hr = unmarshal_ORPC_EXTENT_ARRAY(msg, end, orpc_ext_array,
+                                                 first_wire_orpc_extent);
+        if (FAILED(hr))
+            return hr;
+    }
+    /* reject any flag bits outside the accepted set */
+    if (orpcthat->flags & ~(ORPCF_LOCAL|ORPCF_RESERVED1|ORPCF_RESERVED2|ORPCF_RESERVED3|ORPCF_RESERVED4))
+    {
+        ERR("invalid flags 0x%lx\n", orpcthat->flags & ~(ORPCF_LOCAL|ORPCF_RESERVED1|ORPCF_RESERVED2|ORPCF_RESERVED3|ORPCF_RESERVED4));
+        return RPC_E_INVALID_HEADER;
+    }
+    /* on success msg->Buffer points just past the ORPCTHAT and any extents */
+    return S_OK;
+}
+
 void RPC_ExecuteCall(struct dispatch_params *params)
 {
     struct message_state *message_state = NULL;


More information about the wine-patches mailing list