ole32: Marshal the ORPCTHAT structure prefixed to the server data.
Robert Shearman
rob at codeweavers.com
Wed Dec 27 07:04:07 CST 2006
Unmarshal the data on the client side (in
ClientRpcChannelBuffer_SendReceive) and pass it to the channel hooks' ClientNotify methods.
---
dlls/ole32/compobj.c | 1
dlls/ole32/rpc.c | 368
+++++++++++++++++++++++++++++++++++++++++++-------
2 files changed, 317 insertions(+), 52 deletions(-)
Thanks to Dmitry for spotting a typo in the previous patch.
-------------- next part --------------
diff --git a/dlls/ole32/compobj.c b/dlls/ole32/compobj.c
index ca51a9a..f5e2789 100644
--- a/dlls/ole32/compobj.c
+++ b/dlls/ole32/compobj.c
@@ -36,7 +36,6 @@
*
* - Make all ole interface marshaling use NDR to be wire compatible with
* native DCOM
- * - Use & interpret ORPCTHIS & ORPCTHAT.
*
*/
diff --git a/dlls/ole32/rpc.c b/dlls/ole32/rpc.c
index c31aa4a..83cf9a5 100644
--- a/dlls/ole32/rpc.c
+++ b/dlls/ole32/rpc.c
@@ -154,6 +154,9 @@ struct channel_hook_buffer_data
};
+static HRESULT unmarshal_ORPCTHAT(RPC_MESSAGE *msg, ORPCTHAT *orpcthat,
+ ORPC_EXTENT_ARRAY *orpc_ext_array, WIRE_ORPC_EXTENT **first_wire_orpc_extent);
+
/* Channel Hook Functions */
static ULONG ChannelHooks_ClientGetSize(SChannelHookCallInfo *info,
@@ -172,7 +175,7 @@ static ULONG ChannelHooks_ClientGetSize(
LIST_FOR_EACH_ENTRY(entry, &channel_hooks, struct channel_hook_entry, entry)
(*hook_count)++;
- if (hook_count)
+ if (*hook_count)
*data = HeapAlloc(GetProcessHeap(), 0, *hook_count * sizeof(struct channel_hook_buffer_data));
else
*data = NULL;
@@ -276,6 +279,128 @@ static void ChannelHooks_ServerNotify(SC
LeaveCriticalSection(&csChannelHook);
}
+/* Asks every registered channel hook how much data it wants to send back to
+ * the client in the ORPCTHAT extensions of a reply.
+ *
+ * On return *data holds a heap-allocated array of *hook_count entries
+ * recording each hook's GUID and its extension size rounded up to a multiple
+ * of 8, and *extension_count is the number of hooks with a non-zero size
+ * (only those are put on the wire). The return value is the total number of
+ * bytes the WIRE_ORPC_EXTENT records for those hooks will occupy.
+ *
+ * If no hook has data to send, or the array cannot be allocated, 0 is
+ * returned with *data = NULL and *hook_count = 0: the array is only handed
+ * on to ChannelHooks_ServerFillBuffer (which frees it) when there is
+ * extension data, so it must not be left allocated otherwise. */
+static ULONG ChannelHooks_ServerGetSize(SChannelHookCallInfo *info,
+    struct channel_hook_buffer_data **data, unsigned int *hook_count,
+    ULONG *extension_count)
+{
+    struct channel_hook_entry *entry;
+    ULONG total_size = 0;
+    unsigned int hook_index = 0;
+
+    *hook_count = 0;
+    *extension_count = 0;
+    *data = NULL;
+
+    EnterCriticalSection(&csChannelHook);
+
+    LIST_FOR_EACH_ENTRY(entry, &channel_hooks, struct channel_hook_entry, entry)
+        (*hook_count)++;
+
+    if (*hook_count)
+        *data = HeapAlloc(GetProcessHeap(), 0, *hook_count * sizeof(struct channel_hook_buffer_data));
+
+    if (!*data)
+    {
+        /* no hooks, or out of memory: send no extensions rather than
+         * writing through a NULL pointer */
+        *hook_count = 0;
+        LeaveCriticalSection(&csChannelHook);
+        return 0;
+    }
+
+    LIST_FOR_EACH_ENTRY(entry, &channel_hooks, struct channel_hook_entry, entry)
+    {
+        ULONG extension_size = 0;
+
+        IChannelHook_ServerGetSize(entry->hook, &entry->id, &info->iid, S_OK,
+                                   &extension_size);
+
+        TRACE("%s: extension_size = %u\n", debugstr_guid(&entry->id), extension_size);
+
+        /* extension data is padded to a multiple of 8 bytes on the wire */
+        extension_size = (extension_size+7)&~7;
+        (*data)[hook_index].id = entry->id;
+        (*data)[hook_index].extension_size = extension_size;
+
+        /* an extension is only put onto the wire if it has data to write */
+        if (extension_size)
+        {
+            total_size += FIELD_OFFSET(WIRE_ORPC_EXTENT, data[extension_size]);
+            (*extension_count)++;
+        }
+
+        hook_index++;
+    }
+
+    LeaveCriticalSection(&csChannelHook);
+
+    if (!total_size)
+    {
+        /* nothing to marshal, so the array will never reach
+         * ChannelHooks_ServerFillBuffer to be freed - free it now */
+        HeapFree(GetProcessHeap(), 0, *data);
+        *data = NULL;
+        *hook_count = 0;
+    }
+
+    return total_size;
+}
+
+/* Lets each channel hook that reserved space in ChannelHooks_ServerGetSize
+ * write its extension data into buffer as consecutive WIRE_ORPC_EXTENT
+ * records. Hooks are matched to their reservation by GUID because the hook
+ * list may have changed since the size query; hooks with no reserved space
+ * are skipped. Takes ownership of data and frees it.
+ * Returns the position just past the last record written. */
+static unsigned char * ChannelHooks_ServerFillBuffer(SChannelHookCallInfo *info,
+    unsigned char *buffer, struct channel_hook_buffer_data *data,
+    unsigned int hook_count)
+{
+    struct channel_hook_entry *entry;
+
+    EnterCriticalSection(&csChannelHook);
+
+    LIST_FOR_EACH_ENTRY(entry, &channel_hooks, struct channel_hook_entry, entry)
+    {
+        unsigned int i;
+        ULONG reserved_size = 0;
+        ULONG extension_size;
+        WIRE_ORPC_EXTENT *wire_orpc_extent = (WIRE_ORPC_EXTENT *)buffer;
+
+        for (i = 0; i < hook_count; i++)
+            if (IsEqualGUID(&entry->id, &data[i].id))
+                reserved_size = data[i].extension_size;
+
+        /* an extension is only put onto the wire if it has data to write */
+        if (!reserved_size)
+            continue;
+
+        /* in: space available to the hook; out: bytes it actually wrote */
+        extension_size = reserved_size;
+        IChannelHook_ServerFillBuffer(entry->hook, &entry->id, &info->iid,
+            &extension_size, buffer + FIELD_OFFSET(WIRE_ORPC_EXTENT, data[0]),
+            S_OK);
+
+        TRACE("%s: extension_size = %u\n", debugstr_guid(&entry->id), extension_size);
+
+        /* never let a misbehaving hook make us advance past the space that
+         * was reserved for it (and accounted for in the buffer length) */
+        if (extension_size > reserved_size)
+        {
+            ERR("%s: hook returned size %u, more than the %u bytes reserved\n",
+                debugstr_guid(&entry->id), extension_size, reserved_size);
+            extension_size = reserved_size;
+        }
+
+        wire_orpc_extent->conformance = (extension_size+7)&~7;
+        wire_orpc_extent->size = extension_size;
+        memcpy(&wire_orpc_extent->id, &entry->id, sizeof(wire_orpc_extent->id));
+
+        /* zero the alignment padding so uninitialised memory is not sent
+         * over the wire; reserved_size is a multiple of 8, so the padding
+         * lies within the reserved space */
+        memset(wire_orpc_extent->data + extension_size, 0,
+               wire_orpc_extent->conformance - extension_size);
+
+        buffer += FIELD_OFFSET(WIRE_ORPC_EXTENT, data[wire_orpc_extent->conformance]);
+    }
+
+    LeaveCriticalSection(&csChannelHook);
+
+    HeapFree(GetProcessHeap(), 0, data);
+
+    return buffer;
+}
+
+/* Hands the ORPCTHAT extension data received from the server to every
+ * registered channel hook. Each hook receives the size and data of the
+ * extent whose GUID equals its own id, or 0 and NULL when the server sent
+ * no such extent; hrFault is forwarded unchanged.
+ * The extension_count records starting at first_wire_orpc_extent are walked
+ * via each record's conformance field - NOTE(review): this trusts that the
+ * unmarshalling code has already validated them. */
+static void ChannelHooks_ClientNotify(SChannelHookCallInfo *info,
+    DWORD lDataRep, WIRE_ORPC_EXTENT *first_wire_orpc_extent,
+    ULONG extension_count, HRESULT hrFault)
+{
+    struct channel_hook_entry *hook_entry;
+
+    EnterCriticalSection(&csChannelHook);
+
+    LIST_FOR_EACH_ENTRY(hook_entry, &channel_hooks, struct channel_hook_entry, entry)
+    {
+        WIRE_ORPC_EXTENT *candidate = first_wire_orpc_extent;
+        WIRE_ORPC_EXTENT *match = NULL;
+        ULONG remaining = extension_count;
+        ULONG data_size = 0;
+        void *data_buffer = NULL;
+
+        /* linear search of the contiguous records for this hook's GUID */
+        while (remaining-- > 0)
+        {
+            if (IsEqualGUID(&hook_entry->id, &candidate->id))
+            {
+                match = candidate;
+                break;
+            }
+            candidate = (WIRE_ORPC_EXTENT *)&candidate->data[candidate->conformance];
+        }
+
+        if (match)
+        {
+            data_size = match->size;
+            data_buffer = match->data;
+        }
+
+        IChannelHook_ClientNotify(hook_entry->hook, &hook_entry->id, &info->iid,
+                                  data_size, data_buffer, lDataRep, hrFault);
+    }
+
+    LeaveCriticalSection(&csChannelHook);
+}
+
HRESULT RPC_RegisterChannelHook(REFGUID rguid, IChannelHook *hook)
{
struct channel_hook_entry *entry;
@@ -361,7 +486,12 @@ static HRESULT WINAPI ServerRpcChannelBu
RpcChannelBuffer *This = (RpcChannelBuffer *)iface;
RPC_MESSAGE *msg = (RPC_MESSAGE *)olemsg;
RPC_STATUS status;
+ ORPCTHAT *orpcthat;
struct message_state *message_state;
+ ULONG extensions_size;
+ struct channel_hook_buffer_data *channel_hook_data;
+ unsigned int channel_hook_count;
+ ULONG extension_count;
TRACE("(%p)->(%p,%s)\n", This, olemsg, debugstr_guid(riid));
@@ -370,11 +500,62 @@ static HRESULT WINAPI ServerRpcChannelBu
msg->Handle = message_state->binding_handle;
msg->Buffer = (char *)msg->Buffer - message_state->prefix_data_len;
+ extensions_size = ChannelHooks_ServerGetSize(&message_state->channel_hook_info,
+ &channel_hook_data, &channel_hook_count, &extension_count);
+
+ msg->BufferLength += FIELD_OFFSET(ORPCTHAT, extensions) + 4;
+ if (extensions_size)
+ {
+ msg->BufferLength += FIELD_OFFSET(ORPC_EXTENT_ARRAY, extent) + 2*sizeof(DWORD) + extensions_size;
+ if (extension_count & 1)
+ msg->BufferLength += FIELD_OFFSET(WIRE_ORPC_EXTENT, data[0]);
+ }
+
status = I_RpcGetBuffer(msg);
+ orpcthat = (ORPCTHAT *)msg->Buffer;
+ msg->Buffer = (char *)msg->Buffer + FIELD_OFFSET(ORPCTHAT, extensions);
+
+ orpcthat->flags = ORPCF_NULL /* FIXME? */;
+
+ /* NDR representation of orpcthat->extensions */
+ *(DWORD *)msg->Buffer = extensions_size ? 1 : 0;
+ msg->Buffer = (char *)msg->Buffer + sizeof(DWORD);
+
+ if (extensions_size)
+ {
+ ORPC_EXTENT_ARRAY *orpc_extent_array = msg->Buffer;
+ orpc_extent_array->size = extension_count;
+ orpc_extent_array->reserved = 0;
+ msg->Buffer = (char *)msg->Buffer + FIELD_OFFSET(ORPC_EXTENT_ARRAY, extent);
+ /* NDR representation of orpc_extent_array->extent */
+ *(DWORD *)msg->Buffer = 1;
+ msg->Buffer = (char *)msg->Buffer + sizeof(DWORD);
+ /* NDR representation of [size_is] attribute of orpc_extent_array->extent */
+ *(DWORD *)msg->Buffer = (extension_count + 1) & ~1;
+ msg->Buffer = (char *)msg->Buffer + sizeof(DWORD);
+
+ msg->Buffer = ChannelHooks_ServerFillBuffer(&message_state->channel_hook_info,
+ msg->Buffer, channel_hook_data, channel_hook_count);
+
+ /* we must add a dummy extension if there is an odd extension
+ * count to meet the contract specified by the size_is attribute */
+ if (extension_count & 1)
+ {
+ WIRE_ORPC_EXTENT *wire_orpc_extent = msg->Buffer;
+ wire_orpc_extent->conformance = 0;
+ memcpy(&wire_orpc_extent->id, &GUID_NULL, sizeof(wire_orpc_extent->id));
+ wire_orpc_extent->size = 0;
+ msg->Buffer = (char *)msg->Buffer + FIELD_OFFSET(WIRE_ORPC_EXTENT, data[0]);
+ }
+ }
+
+ /* store the prefixed data length so that we can restore the real buffer
+ * later */
+ message_state->prefix_data_len = (char *)msg->Buffer - (char *)orpcthat;
+ msg->BufferLength -= message_state->prefix_data_len;
/* save away the message state again */
msg->Handle = message_state;
- message_state->prefix_data_len = 0;
TRACE("-- %ld\n", status);
@@ -556,6 +737,9 @@ static HRESULT WINAPI ClientRpcChannelBu
APARTMENT *apt = NULL;
IPID ipid;
struct message_state *message_state;
+ ORPCTHAT orpcthat;
+ ORPC_EXTENT_ARRAY orpc_ext_array;
+ WIRE_ORPC_EXTENT *first_wire_orpc_extent = NULL;
TRACE("(%p) iMethod=%d\n", olemsg, olemsg->iMethod);
@@ -652,18 +836,43 @@ static HRESULT WINAPI ClientRpcChannelBu
}
ClientRpcChannelBuffer_ReleaseEventHandle(This, params->handle);
- /* save away the message state again */
- msg->Handle = message_state;
- message_state->prefix_data_len = 0;
-
if (hr == S_OK) hr = params->hr;
status = params->status;
HeapFree(GetProcessHeap(), 0, params);
params = NULL;
- if (hr) return hr;
+ orpcthat.flags = ORPCF_NULL;
+ orpcthat.extensions = NULL;
+
+ if (status == RPC_S_OK && msg->BufferLength > FIELD_OFFSET(ORPCTHAT, extensions) + 4)
+ {
+ HRESULT hr2;
+ char *original_buffer = msg->Buffer;
+
+ /* handle ORPCTHAT and client extensions */
+
+ hr2 = unmarshal_ORPCTHAT(msg, &orpcthat, &orpc_ext_array, &first_wire_orpc_extent);
+ if (FAILED(hr2))
+ hr = hr2;
+
+ message_state->prefix_data_len = original_buffer - (char *)msg->Buffer;
+ msg->BufferLength -= message_state->prefix_data_len;
+ }
+ else
+ message_state->prefix_data_len = 0;
+
+ ChannelHooks_ClientNotify(&message_state->channel_hook_info,
+ msg->DataRepresentation,
+ first_wire_orpc_extent,
+ orpcthat.extensions && first_wire_orpc_extent ? orpcthat.extensions->size : 0,
+ status == RPC_S_CALL_FAILED && msg->BufferLength >= sizeof(HRESULT) ? *(HRESULT *)msg->Buffer : S_OK);
+ /* save away the message state again */
+ msg->Handle = message_state;
+
+ if (hr) return hr;
+
if (pstatus) *pstatus = status;
TRACE("RPC call status: 0x%lx\n", status);
@@ -856,6 +1065,60 @@ HRESULT RPC_CreateServerChannel(IRpcChan
return S_OK;
}
+/* unmarshals ORPC_EXTENT_ARRAY according to NDR rules, but doesn't allocate
+ * any memory */
+static HRESULT unmarshal_ORPC_EXTENT_ARRAY(RPC_MESSAGE *msg, const char *end,
+ ORPC_EXTENT_ARRAY *extensions,
+ WIRE_ORPC_EXTENT **first_wire_orpc_extent)
+{
+ DWORD pointer_id;
+ DWORD i;
+
+ memcpy(extensions, msg->Buffer, FIELD_OFFSET(ORPC_EXTENT_ARRAY, extent));
+ msg->Buffer = (char *)msg->Buffer + FIELD_OFFSET(ORPC_EXTENT_ARRAY, extent);
+
+ if ((const char *)msg->Buffer + 2 * sizeof(DWORD) > end)
+ return RPC_E_INVALID_HEADER;
+
+ pointer_id = *(DWORD *)msg->Buffer;
+ msg->Buffer = (char *)msg->Buffer + sizeof(DWORD);
+ extensions->extent = NULL;
+
+ if (pointer_id)
+ {
+ WIRE_ORPC_EXTENT *wire_orpc_extent;
+
+ /* conformance */
+ if (*(DWORD *)msg->Buffer != ((extensions->size+1)&~1))
+ return RPC_S_INVALID_BOUND;
+
+ msg->Buffer = (char *)msg->Buffer + sizeof(DWORD);
+
+ /* arbritary limit for security (don't know what native does) */
+ if (extensions->size > 256)
+ {
+ ERR("too many extensions: %ld\n", extensions->size);
+ return RPC_S_INVALID_BOUND;
+ }
+
+ *first_wire_orpc_extent = wire_orpc_extent = (WIRE_ORPC_EXTENT *)msg->Buffer;
+ for (i = 0; i < ((extensions->size+1)&~1); i++)
+ {
+ if ((const char *)&wire_orpc_extent->data[0] > end)
+ return RPC_S_INVALID_BOUND;
+ if (wire_orpc_extent->conformance != ((wire_orpc_extent->size+7)&~7))
+ return RPC_S_INVALID_BOUND;
+ if ((const char *)&wire_orpc_extent->data[wire_orpc_extent->conformance] > end)
+ return RPC_S_INVALID_BOUND;
+ TRACE("size %u, guid %s\n", wire_orpc_extent->size, debugstr_guid(&wire_orpc_extent->id));
+ wire_orpc_extent = (WIRE_ORPC_EXTENT *)&wire_orpc_extent->data[wire_orpc_extent->conformance];
+ }
+ msg->Buffer = wire_orpc_extent;
+ }
+
+ return S_OK;
+}
+
/* unmarshals ORPCTHIS according to NDR rules, but doesn't allocate any memory */
static HRESULT unmarshal_ORPCTHIS(RPC_MESSAGE *msg, ORPCTHIS *orpcthis,
ORPC_EXTENT_ARRAY *orpc_ext_array, WIRE_ORPC_EXTENT **first_wire_orpc_extent)
@@ -885,50 +1148,10 @@ static HRESULT unmarshal_ORPCTHIS(RPC_ME
if (orpcthis->extensions)
{
- DWORD pointer_id;
- DWORD i;
-
- memcpy(orpcthis->extensions, msg->Buffer, FIELD_OFFSET(ORPC_EXTENT_ARRAY, extent));
- msg->Buffer = (char *)msg->Buffer + FIELD_OFFSET(ORPC_EXTENT_ARRAY, extent);
-
- if ((const char *)msg->Buffer + 2 * sizeof(DWORD) > end)
- return RPC_E_INVALID_HEADER;
-
- pointer_id = *(DWORD *)msg->Buffer;
- msg->Buffer = (char *)msg->Buffer + sizeof(DWORD);
- orpcthis->extensions->extent = NULL;
-
- if (pointer_id)
- {
- WIRE_ORPC_EXTENT *wire_orpc_extent;
-
- /* conformance */
- if (*(DWORD *)msg->Buffer != ((orpcthis->extensions->size+1)&~1))
- return RPC_S_INVALID_BOUND;
-
- msg->Buffer = (char *)msg->Buffer + sizeof(DWORD);
-
- /* arbritary limit for security (don't know what native does) */
- if (orpcthis->extensions->size > 256)
- {
- ERR("too many extensions: %ld\n", orpcthis->extensions->size);
- return RPC_S_INVALID_BOUND;
- }
-
- *first_wire_orpc_extent = wire_orpc_extent = (WIRE_ORPC_EXTENT *)msg->Buffer;
- for (i = 0; i < ((orpcthis->extensions->size+1)&~1); i++)
- {
- if ((const char *)&wire_orpc_extent->data[0] > end)
- return RPC_S_INVALID_BOUND;
- if (wire_orpc_extent->conformance != ((wire_orpc_extent->size+7)&~7))
- return RPC_S_INVALID_BOUND;
- if ((const char *)&wire_orpc_extent->data[wire_orpc_extent->conformance] > end)
- return RPC_S_INVALID_BOUND;
- TRACE("size %u, guid %s\n", wire_orpc_extent->size, debugstr_guid(&wire_orpc_extent->id));
- wire_orpc_extent = (WIRE_ORPC_EXTENT *)&wire_orpc_extent->data[wire_orpc_extent->conformance];
- }
- msg->Buffer = wire_orpc_extent;
- }
+ HRESULT hr = unmarshal_ORPC_EXTENT_ARRAY(msg, end, orpc_ext_array,
+ first_wire_orpc_extent);
+ if (FAILED(hr))
+ return hr;
}
if ((orpcthis->version.MajorVersion != COM_MAJOR_VERSION) ||
@@ -948,6 +1171,49 @@ static HRESULT unmarshal_ORPCTHIS(RPC_ME
return S_OK;
}
+/* unmarshals ORPCTHAT according to NDR rules, but doesn't allocate any memory
+ *
+ * On success msg->Buffer is advanced past the ORPCTHAT, orpcthat->extensions
+ * points at orpc_ext_array if the server sent extensions (NULL otherwise) and
+ * *first_wire_orpc_extent points at the first extent, which is left in place
+ * in the message buffer.
+ *
+ * If the extensions fail validation, orpcthat->extensions and
+ * *first_wire_orpc_extent are both reset to NULL so that callers never walk
+ * unvalidated extents, and a FAILED() HRESULT is returned. */
+static HRESULT unmarshal_ORPCTHAT(RPC_MESSAGE *msg, ORPCTHAT *orpcthat,
+    ORPC_EXTENT_ARRAY *orpc_ext_array, WIRE_ORPC_EXTENT **first_wire_orpc_extent)
+{
+    const char *end = (char *)msg->Buffer + msg->BufferLength;
+
+    *first_wire_orpc_extent = NULL;
+
+    /* the fixed fields plus the NDR pointer id of the extensions field; this
+     * also guarantees the pointer id can be read below */
+    if (msg->BufferLength < FIELD_OFFSET(ORPCTHAT, extensions) + 4)
+    {
+        ERR("invalid buffer length\n");
+        return RPC_E_INVALID_HEADER;
+    }
+
+    memcpy(orpcthat, msg->Buffer, FIELD_OFFSET(ORPCTHAT, extensions));
+    msg->Buffer = (char *)msg->Buffer + FIELD_OFFSET(ORPCTHAT, extensions);
+
+    /* a non-zero NDR pointer id means an ORPC_EXTENT_ARRAY follows */
+    orpcthat->extensions = *(DWORD *)msg->Buffer ? orpc_ext_array : NULL;
+    msg->Buffer = (char *)msg->Buffer + sizeof(DWORD);
+
+    if (orpcthat->extensions)
+    {
+        HRESULT hr = unmarshal_ORPC_EXTENT_ARRAY(msg, end, orpc_ext_array,
+                                                 first_wire_orpc_extent);
+        /* the helper can return positive RPC_S_* status codes, which
+         * FAILED() does not catch, so treat anything but S_OK as an error */
+        if (hr != S_OK)
+        {
+            orpcthat->extensions = NULL;
+            *first_wire_orpc_extent = NULL;
+            return HRESULT_FROM_WIN32(hr);
+        }
+    }
+
+    if (orpcthat->flags & ~(ORPCF_LOCAL|ORPCF_RESERVED1|ORPCF_RESERVED2|ORPCF_RESERVED3|ORPCF_RESERVED4))
+    {
+        ERR("invalid flags 0x%lx\n", orpcthat->flags & ~(ORPCF_LOCAL|ORPCF_RESERVED1|ORPCF_RESERVED2|ORPCF_RESERVED3|ORPCF_RESERVED4));
+        return RPC_E_INVALID_HEADER;
+    }
+
+    return S_OK;
+}
+
void RPC_ExecuteCall(struct dispatch_params *params)
{
struct message_state *message_state = NULL;
More information about the wine-patches
mailing list