[PATCH 5/5] ole32: Fix ref counting in GetDataHere Proxy.
Huw Davies
huw at codeweavers.com
Tue Oct 27 09:42:34 CDT 2015
Signed-off-by: Huw Davies <huw at codeweavers.com>
---
dlls/ole32/tests/usrmarshal.c | 236 +++++++++++++++++++++++++++++++++++++++++-
dlls/ole32/usrmarshal.c | 38 +++++--
2 files changed, 267 insertions(+), 7 deletions(-)
diff --git a/dlls/ole32/tests/usrmarshal.c b/dlls/ole32/tests/usrmarshal.c
index c68689c..529ad75 100644
--- a/dlls/ole32/tests/usrmarshal.c
+++ b/dlls/ole32/tests/usrmarshal.c
@@ -85,6 +85,91 @@ static void init_user_marshal_cb(USER_MARSHAL_CB *umcb,
umcb->CBType = buffer ? USER_MARSHAL_CB_UNMARSHALL : USER_MARSHAL_CB_BUFFER_SIZE;
}
+#define RELEASEMARSHALDATA WM_USER /* thread message: host thread calls CoReleaseMarshalData on its stream, then signals the event passed in lParam */
+
+struct host_object_data /* arguments passed to host_object_proc as the thread parameter */
+{
+ IStream *stream; /* stream handed to CoMarshalInterface / CoReleaseMarshalData */
+ IID iid; /* interface ID passed to CoMarshalInterface */
+ IUnknown *object; /* object to marshal */
+ MSHLFLAGS marshal_flags; /* flags passed to CoMarshalInterface */
+ HANDLE marshal_event; /* set by the host thread once marshaling has been attempted */
+ IMessageFilter *filter; /* message filter to register on the host thread; skipped when NULL */
+};
+
+static DWORD CALLBACK host_object_proc(LPVOID p) /* thread proc: marshals data->object into data->stream, then pumps messages until GetMessageA returns 0 */
+{
+ struct host_object_data *data = p;
+ HRESULT hr;
+ MSG msg;
+
+ CoInitializeEx(NULL, COINIT_APARTMENTTHREADED);
+
+ if (data->filter)
+ {
+ IMessageFilter * prev_filter = NULL;
+ hr = CoRegisterMessageFilter(data->filter, &prev_filter);
+ if (prev_filter) IMessageFilter_Release(prev_filter); /* the previously registered filter is not used here; drop the returned reference */
+ ok(hr == S_OK, "got %08x\n", hr);
+ }
+
+ hr = CoMarshalInterface(data->stream, &data->iid, data->object, MSHCTX_INPROC, NULL, data->marshal_flags);
+ ok(hr == S_OK, "got %08x\n", hr);
+
+ /* force the message queue to be created before signaling parent thread */
+ PeekMessageA(&msg, NULL, WM_USER, WM_USER, PM_NOREMOVE);
+
+ SetEvent(data->marshal_event); /* start_host_object2 is waiting on this event */
+
+ while (GetMessageA(&msg, NULL, 0, 0))
+ {
+ if (msg.hwnd == NULL && msg.message == RELEASEMARSHALDATA) /* thread message (no window) requesting release of the marshal data */
+ {
+ CoReleaseMarshalData(data->stream);
+ SetEvent((HANDLE)msg.lParam); /* lParam is treated as an event handle supplied by the poster */
+ }
+ else
+ DispatchMessageA(&msg);
+ }
+
+ HeapFree(GetProcessHeap(), 0, data); /* data was heap-allocated by start_host_object2; this thread frees it */
+
+ CoUninitialize();
+
+ return hr; /* result of the CoMarshalInterface call above */
+}
+
+static DWORD start_host_object2(IStream *stream, REFIID riid, IUnknown *object, MSHLFLAGS marshal_flags, IMessageFilter *filter, HANDLE *thread) /* starts a host_object_proc thread for object, waits for it to signal that marshaling was attempted; returns the thread id and stores the thread handle in *thread */
+{
+ DWORD tid = 0;
+ HANDLE marshal_event = CreateEventA(NULL, FALSE, FALSE, NULL); /* auto-reset, initially non-signaled */
+ struct host_object_data *data = HeapAlloc(GetProcessHeap(), 0, sizeof(*data)); /* freed by host_object_proc before it exits */
+
+ data->stream = stream;
+ data->iid = *riid;
+ data->object = object;
+ data->marshal_flags = marshal_flags;
+ data->marshal_event = marshal_event;
+ data->filter = filter;
+
+ *thread = CreateThread(NULL, 0, host_object_proc, data, 0, &tid);
+
+ /* wait for marshaling to complete before returning */
+ ok( !WaitForSingleObject(marshal_event, 10000), "wait timed out\n" ); /* 10 second timeout; a zero return (WAIT_OBJECT_0) means the event was signaled */
+ CloseHandle(marshal_event);
+
+ return tid;
+}
+
+static void end_host_object(DWORD tid, HANDLE thread) /* posts WM_QUIT to thread tid, waits up to 10 seconds for the thread handle to be signaled, then closes the handle */
+{
+ BOOL ret = PostThreadMessageA(tid, WM_QUIT, 0, 0); /* ends the GetMessageA loop in host_object_proc */
+ ok(ret, "PostThreadMessage failed with error %d\n", GetLastError());
+ /* be careful of races - don't return until hosting thread has terminated */
+ ok( !WaitForSingleObject(thread, 10000), "wait timed out\n" );
+ CloseHandle(thread);
+}
+
static const char cf_marshaled[] =
{
0x9, 0x0, 0x0, 0x0,
@@ -1105,9 +1190,156 @@ static void test_marshal_HBRUSH(void)
DeleteObject(hBrush);
}
+struct obj /* test object whose only member is an IDataObject interface */
+{
+ IDataObject IDataObject_iface;
+};
+
+static HRESULT WINAPI obj_QueryInterface(IDataObject *iface, REFIID iid, void **obj) /* returns iface for IID_IUnknown and IID_IDataObject; otherwise sets *obj to NULL and returns E_NOINTERFACE */
+{
+ *obj = NULL;
+
+ if (IsEqualGUID(iid, &IID_IUnknown) ||
+ IsEqualGUID(iid, &IID_IDataObject))
+ *obj = iface;
+
+ if (*obj)
+ {
+ IDataObject_AddRef(iface); /* the returned interface pointer carries its own reference */
+ return S_OK;
+ }
+
+ return E_NOINTERFACE;
+}
+
+static ULONG WINAPI obj_AddRef(IDataObject *iface) /* no reference count is kept; always returns the constant 2 */
+{
+ return 2;
+}
+
+static ULONG WINAPI obj_Release(IDataObject *iface) /* no reference count is kept and nothing is freed; always returns the constant 1 */
+{
+ return 1;
+}
+
+static HRESULT WINAPI obj_DO_GetDataHere(IDataObject *iface, FORMATETC *fmt,
+ STGMEDIUM *med) /* object-side GetDataHere used by the test: verifies pUnkForRelease arrives as NULL and always returns S_OK */
+{
+ ok( med->pUnkForRelease == NULL, "got %p\n", med->pUnkForRelease );
+
+ if (fmt->cfFormat == 2) /* format 2 is the test case where the object replaces the stream in the medium */
+ {
+ IStream_Release(U(med)->pstm); /* release the stream that was passed in ... */
+ U(med)->pstm = &Test_Stream2.IStream_iface; /* ... and substitute Test_Stream2 (declared outside this hunk) */
+ }
+
+ return S_OK;
+}
+
+static const IDataObjectVtbl obj_data_object_vtbl = /* only the IUnknown methods and GetDataHere are implemented; every other slot is NULL */
+{
+ obj_QueryInterface,
+ obj_AddRef,
+ obj_Release,
+ NULL, /* GetData */
+ obj_DO_GetDataHere,
+ NULL, /* QueryGetData */
+ NULL, /* GetCanonicalFormatEtc */
+ NULL, /* SetData */
+ NULL, /* EnumFormatEtc */
+ NULL, /* DAdvise */
+ NULL, /* DUnadvise */
+ NULL /* EnumDAdvise */
+};
+
+static struct obj obj = /* static instance marshaled by test_GetDataHere_Proxy */
+{
+ {&obj_data_object_vtbl}
+};
+
+static void test_GetDataHere_Proxy(void) /* marshals obj on a host thread, unmarshals it here, and checks the results and reference counts of GetDataHere calls made through the unmarshaled interface */
+{
+ HRESULT hr;
+ IStream *stm;
+ HANDLE thread;
+ DWORD tid;
+ static const LARGE_INTEGER zero; /* zero-initialized static, used as the seek offset */
+ IDataObject *data;
+ FORMATETC fmt;
+ STGMEDIUM med;
+
+ hr = CreateStreamOnHGlobal( NULL, TRUE, &stm );
+ ok( hr == S_OK, "got %08x\n", hr );
+ tid = start_host_object2( stm, &IID_IDataObject, (IUnknown *)&obj.IDataObject_iface, MSHLFLAGS_NORMAL, NULL, &thread );
+
+ IStream_Seek( stm, zero, STREAM_SEEK_SET, NULL ); /* rewind to the start of the marshaled data */
+ hr = CoUnmarshalInterface( stm, &IID_IDataObject, (void **)&data );
+ ok( hr == S_OK, "got %08x\n", hr );
+ IStream_Release( stm );
+
+ Test_Stream.refs = 1; /* reset the ref counts of the shared test objects (declared outside this hunk) */
+ Test_Stream2.refs = 1;
+ Test_Unknown.refs = 1;
+
+ fmt.cfFormat = 1;
+ fmt.ptd = NULL;
+ fmt.dwAspect = DVASPECT_CONTENT;
+ fmt.lindex = -1;
+ U(med).pstm = NULL;
+ med.pUnkForRelease = &Test_Unknown.IUnknown_iface;
+
+ fmt.tymed = med.tymed = TYMED_NULL; /* TYMED_NULL is expected to be rejected with DV_E_TYMED */
+ hr = IDataObject_GetDataHere( data, &fmt, &med );
+ ok( hr == DV_E_TYMED, "got %08x\n", hr );
+
+ for (fmt.tymed = TYMED_HGLOBAL; fmt.tymed <= TYMED_ENHMF; fmt.tymed <<= 1) /* walk each single-bit tymed value from HGLOBAL up to ENHMF */
+ {
+ med.tymed = fmt.tymed;
+ hr = IDataObject_GetDataHere( data, &fmt, &med );
+ ok( hr == (fmt.tymed <= TYMED_ISTORAGE ? S_OK : DV_E_TYMED), "got %08x for tymed %d\n", hr, fmt.tymed ); /* values up to ISTORAGE succeed, larger ones fail */
+ ok( Test_Unknown.refs == 1, "got %d\n", Test_Unknown.refs );
+ }
+
+ fmt.tymed = TYMED_ISTREAM; /* mismatched format and medium tymed is expected to fail */
+ med.tymed = TYMED_ISTORAGE;
+ hr = IDataObject_GetDataHere( data, &fmt, &med );
+ ok( hr == DV_E_TYMED, "got %08x\n", hr );
+
+ fmt.tymed = med.tymed = TYMED_ISTREAM; /* stream case where the object leaves the medium alone (cfFormat is still 1) */
+ U(med).pstm = &Test_Stream.IStream_iface;
+ med.pUnkForRelease = &Test_Unknown.IUnknown_iface;
+
+ hr = IDataObject_GetDataHere( data, &fmt, &med );
+ ok( hr == S_OK, "got %08x\n", hr );
+
+ ok( U(med).pstm == &Test_Stream.IStream_iface, "stm changed\n" ); /* the caller's pointers must come back unchanged */
+ ok( med.pUnkForRelease == &Test_Unknown.IUnknown_iface, "punk changed\n" );
+
+ ok( Test_Stream.refs == 1, "got %d\n", Test_Stream.refs ); /* and the ref counts must be back at their starting value */
+ ok( Test_Unknown.refs == 1, "got %d\n", Test_Unknown.refs );
+
+ fmt.cfFormat = 2; /* makes obj_DO_GetDataHere release the stream it receives and substitute Test_Stream2 */
+ fmt.tymed = med.tymed = TYMED_ISTREAM;
+ U(med).pstm = &Test_Stream.IStream_iface;
+ med.pUnkForRelease = &Test_Unknown.IUnknown_iface;
+
+ hr = IDataObject_GetDataHere( data, &fmt, &med );
+ ok( hr == S_OK, "got %08x\n", hr );
+
+ ok( U(med).pstm == &Test_Stream.IStream_iface, "stm changed\n" ); /* the caller still sees its original stream */
+ ok( med.pUnkForRelease == &Test_Unknown.IUnknown_iface, "punk changed\n" );
+
+ ok( Test_Stream.refs == 1, "got %d\n", Test_Stream.refs );
+ ok( Test_Unknown.refs == 1, "got %d\n", Test_Unknown.refs );
+ ok( Test_Stream2.refs == 0, "got %d\n", Test_Stream2.refs ); /* the substituted stream is expected to have lost the one reference it was reset to */
+
+ IDataObject_Release( data );
+ end_host_object( tid, thread );
+}
+
START_TEST(usrmarshal)
{
- CoInitialize(NULL);
+ CoInitializeEx(NULL, COINIT_APARTMENTTHREADED);
test_marshal_CLIPFORMAT();
test_marshal_HWND();
@@ -1122,5 +1354,7 @@ START_TEST(usrmarshal)
test_marshal_HICON();
test_marshal_HBRUSH();
+ test_GetDataHere_Proxy();
+
CoUninitialize();
}
diff --git a/dlls/ole32/usrmarshal.c b/dlls/ole32/usrmarshal.c
index 1a3f6af..89d0675 100644
--- a/dlls/ole32/usrmarshal.c
+++ b/dlls/ole32/usrmarshal.c
@@ -2783,13 +2783,39 @@ HRESULT __RPC_STUB IDataObject_GetData_Stub(
return IDataObject_GetData(This, pformatetcIn, pRemoteMedium);
}
-HRESULT CALLBACK IDataObject_GetDataHere_Proxy(
- IDataObject* This,
- FORMATETC *pformatetc,
- STGMEDIUM *pmedium)
+HRESULT CALLBACK IDataObject_GetDataHere_Proxy(IDataObject *iface, FORMATETC *fmt, STGMEDIUM *med) /* validates the medium type, calls the remote proxy with pUnkForRelease cleared, then restores the caller's pUnkForRelease and stream/storage pointer */
{
- TRACE("(%p)->(%p, %p)\n", This, pformatetc, pmedium);
- return IDataObject_RemoteGetDataHere_Proxy(This, pformatetc, pmedium);
+ IUnknown *release;
+ IStorage *stg = NULL;
+ HRESULT hr;
+
+ TRACE("(%p)->(%p, %p)\n", iface, fmt, med);
+
+ if ((med->tymed & (TYMED_HGLOBAL | TYMED_FILE | TYMED_ISTREAM | TYMED_ISTORAGE)) == 0) /* only these four medium types are accepted */
+ return DV_E_TYMED;
+ if (med->tymed != fmt->tymed) /* the medium type must equal the requested format's tymed */
+ return DV_E_TYMED;
+
+ release = med->pUnkForRelease; /* remember the caller's value ... */
+ med->pUnkForRelease = NULL; /* ... and pass NULL to the remote call instead */
+
+ if (med->tymed == TYMED_ISTREAM || med->tymed == TYMED_ISTORAGE)
+ {
+ stg = med->u.pstg; /* This may actually be a stream, but that's ok */
+ if (stg) IStorage_AddRef( stg ); /* hold a reference across the remote call */
+ }
+
+ hr = IDataObject_RemoteGetDataHere_Proxy(iface, fmt, med);
+
+ med->pUnkForRelease = release; /* restored whether or not the call succeeded */
+ if (stg)
+ {
+ if (med->u.pstg)
+ IStorage_Release( med->u.pstg ); /* release whatever pointer the call left in the medium */
+ med->u.pstg = stg; /* hand the caller back its original pointer */
+ }
+
+ return hr;
}
HRESULT __RPC_STUB IDataObject_GetDataHere_Stub(
--
1.8.0
More information about the wine-patches
mailing list