Huw Davies : ole32: Fix ref counting in GetDataHere Proxy.

Alexandre Julliard julliard at wine.codeweavers.com
Tue Oct 27 11:06:37 CDT 2015


Module: wine
Branch: master
Commit: 28b916b26eceb35a071e9aad76da1fbbaa857ba1
URL:    http://source.winehq.org/git/wine.git/?a=commit;h=28b916b26eceb35a071e9aad76da1fbbaa857ba1

Author: Huw Davies <huw at codeweavers.com>
Date:   Tue Oct 27 14:42:34 2015 +0000

ole32: Fix ref counting in GetDataHere Proxy.

Signed-off-by: Huw Davies <huw at codeweavers.com>
Signed-off-by: Alexandre Julliard <julliard at winehq.org>

---

 dlls/ole32/tests/usrmarshal.c | 236 +++++++++++++++++++++++++++++++++++++++++-
 dlls/ole32/usrmarshal.c       |  38 +++++--
 2 files changed, 267 insertions(+), 7 deletions(-)

diff --git a/dlls/ole32/tests/usrmarshal.c b/dlls/ole32/tests/usrmarshal.c
index c68689c..529ad75 100644
--- a/dlls/ole32/tests/usrmarshal.c
+++ b/dlls/ole32/tests/usrmarshal.c
@@ -85,6 +85,91 @@ static void init_user_marshal_cb(USER_MARSHAL_CB *umcb,
     umcb->CBType = buffer ? USER_MARSHAL_CB_UNMARSHALL : USER_MARSHAL_CB_BUFFER_SIZE;
 }
 
+#define RELEASEMARSHALDATA WM_USER
+
+struct host_object_data
+{
+    IStream *stream;
+    IID iid;
+    IUnknown *object;
+    MSHLFLAGS marshal_flags;
+    HANDLE marshal_event;
+    IMessageFilter *filter;
+};
+
+static DWORD CALLBACK host_object_proc(LPVOID p)
+{
+    struct host_object_data *data = p;
+    HRESULT hr;
+    MSG msg;
+
+    CoInitializeEx(NULL, COINIT_APARTMENTTHREADED);
+
+    if (data->filter)
+    {
+        IMessageFilter * prev_filter = NULL;
+        hr = CoRegisterMessageFilter(data->filter, &prev_filter);
+        if (prev_filter) IMessageFilter_Release(prev_filter);
+        ok(hr == S_OK, "got %08x\n", hr);
+    }
+
+    hr = CoMarshalInterface(data->stream, &data->iid, data->object, MSHCTX_INPROC, NULL, data->marshal_flags);
+    ok(hr == S_OK, "got %08x\n", hr);
+
+    /* force the message queue to be created before signaling parent thread */
+    PeekMessageA(&msg, NULL, WM_USER, WM_USER, PM_NOREMOVE);
+
+    SetEvent(data->marshal_event);
+
+    while (GetMessageA(&msg, NULL, 0, 0))
+    {
+        if (msg.hwnd == NULL && msg.message == RELEASEMARSHALDATA)
+        {
+            CoReleaseMarshalData(data->stream);
+            SetEvent((HANDLE)msg.lParam);
+        }
+        else
+            DispatchMessageA(&msg);
+    }
+
+    HeapFree(GetProcessHeap(), 0, data);
+
+    CoUninitialize();
+
+    return hr;
+}
+
+static DWORD start_host_object2(IStream *stream, REFIID riid, IUnknown *object, MSHLFLAGS marshal_flags, IMessageFilter *filter, HANDLE *thread)
+{
+    DWORD tid = 0;
+    HANDLE marshal_event = CreateEventA(NULL, FALSE, FALSE, NULL);
+    struct host_object_data *data = HeapAlloc(GetProcessHeap(), 0, sizeof(*data));
+
+    data->stream = stream;
+    data->iid = *riid;
+    data->object = object;
+    data->marshal_flags = marshal_flags;
+    data->marshal_event = marshal_event;
+    data->filter = filter;
+
+    *thread = CreateThread(NULL, 0, host_object_proc, data, 0, &tid);
+
+    /* wait for marshaling to complete before returning */
+    ok( !WaitForSingleObject(marshal_event, 10000), "wait timed out\n" );
+    CloseHandle(marshal_event);
+
+    return tid;
+}
+
+static void end_host_object(DWORD tid, HANDLE thread)
+{
+    BOOL ret = PostThreadMessageA(tid, WM_QUIT, 0, 0);
+    ok(ret, "PostThreadMessage failed with error %d\n", GetLastError());
+    /* be careful of races - don't return until hosting thread has terminated */
+    ok( !WaitForSingleObject(thread, 10000), "wait timed out\n" );
+    CloseHandle(thread);
+}
+
 static const char cf_marshaled[] =
 {
     0x9, 0x0, 0x0, 0x0,
@@ -1105,9 +1190,156 @@ static void test_marshal_HBRUSH(void)
     DeleteObject(hBrush);
 }
 
+struct obj
+{
+    IDataObject IDataObject_iface;
+};
+
+static HRESULT WINAPI obj_QueryInterface(IDataObject *iface, REFIID iid, void **obj)
+{
+    *obj = NULL;
+
+    if (IsEqualGUID(iid, &IID_IUnknown) ||
+        IsEqualGUID(iid, &IID_IDataObject))
+        *obj = iface;
+
+    if (*obj)
+    {
+        IDataObject_AddRef(iface);
+        return S_OK;
+    }
+
+    return E_NOINTERFACE;
+}
+
+static ULONG WINAPI obj_AddRef(IDataObject *iface)
+{
+    return 2;
+}
+
+static ULONG WINAPI obj_Release(IDataObject *iface)
+{
+    return 1;
+}
+
+static HRESULT WINAPI obj_DO_GetDataHere(IDataObject *iface, FORMATETC *fmt,
+                                         STGMEDIUM *med)
+{
+    ok( med->pUnkForRelease == NULL, "got %p\n", med->pUnkForRelease );
+
+    if (fmt->cfFormat == 2)
+    {
+        IStream_Release(U(med)->pstm);
+        U(med)->pstm = &Test_Stream2.IStream_iface;
+    }
+
+    return S_OK;
+}
+
+static const IDataObjectVtbl obj_data_object_vtbl =
+{
+    obj_QueryInterface,
+    obj_AddRef,
+    obj_Release,
+    NULL, /* GetData */
+    obj_DO_GetDataHere,
+    NULL, /* QueryGetData */
+    NULL, /* GetCanonicalFormatEtc */
+    NULL, /* SetData */
+    NULL, /* EnumFormatEtc */
+    NULL, /* DAdvise */
+    NULL, /* DUnadvise */
+    NULL  /* EnumDAdvise */
+};
+
+static struct obj obj =
+{
+    {&obj_data_object_vtbl}
+};
+
+static void test_GetDataHere_Proxy(void)
+{
+    HRESULT hr;
+    IStream *stm;
+    HANDLE thread;
+    DWORD tid;
+    static const LARGE_INTEGER zero;
+    IDataObject *data;
+    FORMATETC fmt;
+    STGMEDIUM med;
+
+    hr = CreateStreamOnHGlobal( NULL, TRUE, &stm );
+    ok( hr == S_OK, "got %08x\n", hr );
+    tid = start_host_object2( stm, &IID_IDataObject, (IUnknown *)&obj.IDataObject_iface, MSHLFLAGS_NORMAL, NULL, &thread );
+
+    IStream_Seek( stm, zero, STREAM_SEEK_SET, NULL );
+    hr = CoUnmarshalInterface( stm, &IID_IDataObject, (void **)&data );
+    ok( hr == S_OK, "got %08x\n", hr );
+    IStream_Release( stm );
+
+    Test_Stream.refs = 1;
+    Test_Stream2.refs = 1;
+    Test_Unknown.refs = 1;
+
+    fmt.cfFormat = 1;
+    fmt.ptd = NULL;
+    fmt.dwAspect = DVASPECT_CONTENT;
+    fmt.lindex = -1;
+    U(med).pstm = NULL;
+    med.pUnkForRelease = &Test_Unknown.IUnknown_iface;
+
+    fmt.tymed = med.tymed = TYMED_NULL;
+    hr = IDataObject_GetDataHere( data, &fmt, &med );
+    ok( hr == DV_E_TYMED, "got %08x\n", hr );
+
+    for (fmt.tymed = TYMED_HGLOBAL; fmt.tymed <= TYMED_ENHMF; fmt.tymed <<= 1)
+    {
+        med.tymed = fmt.tymed;
+        hr = IDataObject_GetDataHere( data, &fmt, &med );
+        ok( hr == (fmt.tymed <= TYMED_ISTORAGE ? S_OK : DV_E_TYMED), "got %08x for tymed %d\n", hr, fmt.tymed );
+        ok( Test_Unknown.refs == 1, "got %d\n", Test_Unknown.refs );
+    }
+
+    fmt.tymed = TYMED_ISTREAM;
+    med.tymed = TYMED_ISTORAGE;
+    hr = IDataObject_GetDataHere( data, &fmt, &med );
+    ok( hr == DV_E_TYMED, "got %08x\n", hr );
+
+    fmt.tymed = med.tymed = TYMED_ISTREAM;
+    U(med).pstm = &Test_Stream.IStream_iface;
+    med.pUnkForRelease = &Test_Unknown.IUnknown_iface;
+
+    hr = IDataObject_GetDataHere( data, &fmt, &med );
+    ok( hr == S_OK, "got %08x\n", hr );
+
+    ok( U(med).pstm == &Test_Stream.IStream_iface, "stm changed\n" );
+    ok( med.pUnkForRelease == &Test_Unknown.IUnknown_iface, "punk changed\n" );
+
+    ok( Test_Stream.refs == 1, "got %d\n", Test_Stream.refs );
+    ok( Test_Unknown.refs == 1, "got %d\n", Test_Unknown.refs );
+
+    fmt.cfFormat = 2;
+    fmt.tymed = med.tymed = TYMED_ISTREAM;
+    U(med).pstm = &Test_Stream.IStream_iface;
+    med.pUnkForRelease = &Test_Unknown.IUnknown_iface;
+
+    hr = IDataObject_GetDataHere( data, &fmt, &med );
+    ok( hr == S_OK, "got %08x\n", hr );
+
+    ok( U(med).pstm == &Test_Stream.IStream_iface, "stm changed\n" );
+    ok( med.pUnkForRelease == &Test_Unknown.IUnknown_iface, "punk changed\n" );
+
+    ok( Test_Stream.refs == 1, "got %d\n", Test_Stream.refs );
+    ok( Test_Unknown.refs == 1, "got %d\n", Test_Unknown.refs );
+    ok( Test_Stream2.refs == 0, "got %d\n", Test_Stream2.refs );
+
+    IDataObject_Release( data );
+    end_host_object( tid, thread );
+}
+
 START_TEST(usrmarshal)
 {
-    CoInitialize(NULL);
+    CoInitializeEx(NULL, COINIT_APARTMENTTHREADED);
 
     test_marshal_CLIPFORMAT();
     test_marshal_HWND();
@@ -1122,5 +1354,7 @@ START_TEST(usrmarshal)
     test_marshal_HICON();
     test_marshal_HBRUSH();
 
+    test_GetDataHere_Proxy();
+
     CoUninitialize();
 }
diff --git a/dlls/ole32/usrmarshal.c b/dlls/ole32/usrmarshal.c
index 1a3f6af..89d0675 100644
--- a/dlls/ole32/usrmarshal.c
+++ b/dlls/ole32/usrmarshal.c
@@ -2783,13 +2783,39 @@ HRESULT __RPC_STUB IDataObject_GetData_Stub(
     return IDataObject_GetData(This, pformatetcIn, pRemoteMedium);
 }
 
-HRESULT CALLBACK IDataObject_GetDataHere_Proxy(
-    IDataObject* This,
-    FORMATETC *pformatetc,
-    STGMEDIUM *pmedium)
+HRESULT CALLBACK IDataObject_GetDataHere_Proxy(IDataObject *iface, FORMATETC *fmt, STGMEDIUM *med)
 {
-    TRACE("(%p)->(%p, %p)\n", This, pformatetc, pmedium);
-    return IDataObject_RemoteGetDataHere_Proxy(This, pformatetc, pmedium);
+    IUnknown *release;
+    IStorage *stg = NULL;
+    HRESULT hr;
+
+    TRACE("(%p)->(%p, %p)\n", iface, fmt, med);
+
+    if ((med->tymed & (TYMED_HGLOBAL | TYMED_FILE | TYMED_ISTREAM | TYMED_ISTORAGE)) == 0)
+        return DV_E_TYMED;
+    if (med->tymed != fmt->tymed)
+        return DV_E_TYMED;
+
+    release = med->pUnkForRelease;
+    med->pUnkForRelease = NULL;
+
+    if (med->tymed == TYMED_ISTREAM || med->tymed == TYMED_ISTORAGE)
+    {
+        stg = med->u.pstg; /* This may actually be a stream, but that's ok */
+        if (stg) IStorage_AddRef( stg );
+    }
+
+    hr = IDataObject_RemoteGetDataHere_Proxy(iface, fmt, med);
+
+    med->pUnkForRelease = release;
+    if (stg)
+    {
+        if (med->u.pstg)
+            IStorage_Release( med->u.pstg );
+        med->u.pstg = stg;
+    }
+
+    return hr;
 }
 
 HRESULT __RPC_STUB IDataObject_GetDataHere_Stub(




More information about the wine-cvs mailing list