Jacek Caban : itss: Support COM aggregation in its protocol handler.

Alexandre Julliard julliard at winehq.org
Tue May 22 15:37:14 CDT 2018


Module: wine
Branch: master
Commit: cd5570d9efd0cea19e5863fb334da3d40dd9658d
URL:    https://source.winehq.org/git/wine.git/?a=commit;h=cd5570d9efd0cea19e5863fb334da3d40dd9658d

Author: Jacek Caban <jacek at codeweavers.com>
Date:   Tue May 22 11:31:59 2018 +0200

itss: Support COM aggregation in its protocol handler.

Signed-off-by: Jacek Caban <jacek at codeweavers.com>
Signed-off-by: Alexandre Julliard <julliard at winehq.org>

---

 dlls/itss/itss.c           | 29 ++++++++++++------
 dlls/itss/protocol.c       | 76 ++++++++++++++++++++++++++++++++--------------
 dlls/itss/tests/protocol.c | 64 ++++++++++++++++++++++++++++++++++++++
 3 files changed, 138 insertions(+), 31 deletions(-)

diff --git a/dlls/itss/itss.c b/dlls/itss/itss.c
index f47a0e9..1d236a5 100644
--- a/dlls/itss/itss.c
+++ b/dlls/itss/itss.c
@@ -105,20 +105,31 @@ static ULONG WINAPI ITSSCF_Release(LPCLASSFACTORY iface)
 }
 
 
-static HRESULT WINAPI ITSSCF_CreateInstance(LPCLASSFACTORY iface, LPUNKNOWN pOuter,
-					  REFIID riid, LPVOID *ppobj)
+static HRESULT WINAPI ITSSCF_CreateInstance(IClassFactory *iface, IUnknown *outer,
+                                            REFIID riid, void **ppv)
 {
     IClassFactoryImpl *This = impl_from_IClassFactory(iface);
+    IUnknown *unk;
     HRESULT hres;
-    LPUNKNOWN punk;
 
-    TRACE("(%p)->(%p,%s,%p)\n", This, pOuter, debugstr_guid(riid), ppobj);
+    TRACE("(%p)->(%p %s %p)\n", This, outer, debugstr_guid(riid), ppv);
 
-    *ppobj = NULL;
-    hres = This->pfnCreateInstance(pOuter, (LPVOID *) &punk);
-    if (SUCCEEDED(hres)) {
-        hres = IUnknown_QueryInterface(punk, riid, ppobj);
-        IUnknown_Release(punk);
+    if(outer && !IsEqualGUID(riid, &IID_IUnknown)) {
+        *ppv = NULL;
+        return CLASS_E_NOAGGREGATION;
+    }
+
+    hres = This->pfnCreateInstance(outer, (void**)&unk);
+    if(FAILED(hres)) {
+        *ppv = NULL;
+        return hres;
+    }
+
+    if(!IsEqualGUID(riid, &IID_IUnknown)) {
+        hres = IUnknown_QueryInterface(unk, riid, ppv);
+        IUnknown_Release(unk);
+    }else {
+        *ppv = unk;
     }
     return hres;
 }
diff --git a/dlls/itss/protocol.c b/dlls/itss/protocol.c
index 1463518..1cdb365 100644
--- a/dlls/itss/protocol.c
+++ b/dlls/itss/protocol.c
@@ -36,16 +36,23 @@
 WINE_DEFAULT_DEBUG_CHANNEL(itss);
 
 typedef struct {
+    IUnknown              IUnknown_inner;
     IInternetProtocol     IInternetProtocol_iface;
     IInternetProtocolInfo IInternetProtocolInfo_iface;
 
     LONG ref;
+    IUnknown *outer;
 
     ULONG offset;
     struct chmFile *chm_file;
     struct chmUnitInfo chm_object;
 } ITSProtocol;
 
+static inline ITSProtocol *impl_from_IUnknown(IUnknown *iface)
+{
+    return CONTAINING_RECORD(iface, ITSProtocol, IUnknown_inner);
+}
+
 static inline ITSProtocol *impl_from_IInternetProtocol(IInternetProtocol *iface)
 {
     return CONTAINING_RECORD(iface, ITSProtocol, IInternetProtocol_iface);
@@ -65,14 +72,13 @@ static void release_chm(ITSProtocol *This)
     This->offset = 0;
 }
 
-static HRESULT WINAPI ITSProtocol_QueryInterface(IInternetProtocol *iface, REFIID riid, void **ppv)
+static HRESULT WINAPI ITSProtocol_QueryInterface(IUnknown *iface, REFIID riid, void **ppv)
 {
-    ITSProtocol *This = impl_from_IInternetProtocol(iface);
+    ITSProtocol *This = impl_from_IUnknown(iface);
 
-    *ppv = NULL;
     if(IsEqualGUID(&IID_IUnknown, riid)) {
         TRACE("(%p)->(IID_IUnknown %p)\n", This, ppv);
-        *ppv = &This->IInternetProtocol_iface;
+        *ppv = &This->IUnknown_inner;
     }else if(IsEqualGUID(&IID_IInternetProtocolRoot, riid)) {
         TRACE("(%p)->(IID_IInternetProtocolRoot %p)\n", This, ppv);
         *ppv = &This->IInternetProtocol_iface;
@@ -82,28 +88,27 @@ static HRESULT WINAPI ITSProtocol_QueryInterface(IInternetProtocol *iface, REFII
     }else if(IsEqualGUID(&IID_IInternetProtocolInfo, riid)) {
         TRACE("(%p)->(IID_IInternetProtocolInfo %p)\n", This, ppv);
         *ppv = &This->IInternetProtocolInfo_iface;
+    }else {
+        *ppv = NULL;
+        WARN("not supported interface %s\n", debugstr_guid(riid));
+        return E_NOINTERFACE;
     }
 
-    if(*ppv) {
-        IInternetProtocol_AddRef(iface);
-        return S_OK;
-    }
-
-    WARN("not supported interface %s\n", debugstr_guid(riid));
-    return E_NOINTERFACE;
+    IUnknown_AddRef((IUnknown*)*ppv);
+    return S_OK;
 }
 
-static ULONG WINAPI ITSProtocol_AddRef(IInternetProtocol *iface)
+static ULONG WINAPI ITSProtocol_AddRef(IUnknown *iface)
 {
-    ITSProtocol *This = impl_from_IInternetProtocol(iface);
+    ITSProtocol *This = impl_from_IUnknown(iface);
     LONG ref = InterlockedIncrement(&This->ref);
     TRACE("(%p) ref=%d\n", This, ref);
     return ref;
 }
 
-static ULONG WINAPI ITSProtocol_Release(IInternetProtocol *iface)
+static ULONG WINAPI ITSProtocol_Release(IUnknown *iface)
 {
-    ITSProtocol *This = impl_from_IInternetProtocol(iface);
+    ITSProtocol *This = impl_from_IUnknown(iface);
     LONG ref = InterlockedDecrement(&This->ref);
 
     TRACE("(%p) ref=%d\n", This, ref);
@@ -118,6 +123,30 @@ static ULONG WINAPI ITSProtocol_Release(IInternetProtocol *iface)
     return ref;
 }
 
+static const IUnknownVtbl ITSProtocolUnkVtbl = {
+    ITSProtocol_QueryInterface,
+    ITSProtocol_AddRef,
+    ITSProtocol_Release
+};
+
+static HRESULT WINAPI ITSInternetProtocol_QueryInterface(IInternetProtocol *iface, REFIID riid, void **ppv)
+{
+    ITSProtocol *This = impl_from_IInternetProtocol(iface);
+    return IUnknown_QueryInterface(This->outer, riid, ppv);
+}
+
+static ULONG WINAPI ITSInternetProtocol_AddRef(IInternetProtocol *iface)
+{
+    ITSProtocol *This = impl_from_IInternetProtocol(iface);
+    return IUnknown_AddRef(This->outer);
+}
+
+static ULONG WINAPI ITSInternetProtocol_Release(IInternetProtocol *iface)
+{
+    ITSProtocol *This = impl_from_IInternetProtocol(iface);
+    return IUnknown_Release(This->outer);
+}
+
 static LPCWSTR skip_schema(LPCWSTR url)
 {
     static const WCHAR its_schema[] = {'i','t','s',':'};
@@ -387,9 +416,9 @@ static HRESULT WINAPI ITSProtocol_UnlockRequest(IInternetProtocol *iface)
 }
 
 static const IInternetProtocolVtbl ITSProtocolVtbl = {
-    ITSProtocol_QueryInterface,
-    ITSProtocol_AddRef,
-    ITSProtocol_Release,
+    ITSInternetProtocol_QueryInterface,
+    ITSInternetProtocol_AddRef,
+    ITSInternetProtocol_Release,
     ITSProtocol_Start,
     ITSProtocol_Continue,
     ITSProtocol_Abort,
@@ -520,21 +549,24 @@ static const IInternetProtocolInfoVtbl ITSProtocolInfoVtbl = {
     ITSProtocolInfo_QueryInfo
 };
 
-HRESULT ITSProtocol_create(IUnknown *pUnkOuter, LPVOID *ppobj)
+HRESULT ITSProtocol_create(IUnknown *outer, void **ppv)
 {
     ITSProtocol *ret;
 
-    TRACE("(%p %p)\n", pUnkOuter, ppobj);
+    TRACE("(%p %p)\n", outer, ppv);
 
     ITSS_LockModule();
 
     ret = HeapAlloc(GetProcessHeap(), HEAP_ZERO_MEMORY, sizeof(ITSProtocol));
+    if(!ret)
+        return E_OUTOFMEMORY;
 
+    ret->IUnknown_inner.lpVtbl = &ITSProtocolUnkVtbl;
     ret->IInternetProtocol_iface.lpVtbl = &ITSProtocolVtbl;
     ret->IInternetProtocolInfo_iface.lpVtbl = &ITSProtocolInfoVtbl;
     ret->ref = 1;
+    ret->outer = outer ? outer : &ret->IUnknown_inner;
 
-    *ppobj = &ret->IInternetProtocol_iface;
-
+    *ppv = &ret->IUnknown_inner;
     return S_OK;
 }
diff --git a/dlls/itss/tests/protocol.c b/dlls/itss/tests/protocol.c
index 663411f..3028da4 100644
--- a/dlls/itss/tests/protocol.c
+++ b/dlls/itss/tests/protocol.c
@@ -60,6 +60,7 @@ DEFINE_EXPECT(ReportProgress_CACHEFILENAMEAVAILABLE);
 DEFINE_EXPECT(ReportProgress_DIRECTBIND);
 DEFINE_EXPECT(ReportData);
 DEFINE_EXPECT(ReportResult);
+DEFINE_EXPECT(outer_QI_test);
 
 static HRESULT expect_hrResult;
 static IInternetProtocol *read_protocol = NULL;
@@ -660,6 +661,68 @@ static void delete_chm(void)
     ok(ret, "DeleteFileA failed: %d\n", GetLastError());
 }
 
+static const IID outer_test_iid = {0xabcabc00,0,0,{0,0,0,0,0,0,0,0x66}};
+
+static HRESULT WINAPI outer_QueryInterface(IUnknown *iface, REFIID riid, void **ppv)
+{
+    if(IsEqualGUID(riid, &outer_test_iid)) {
+        CHECK_EXPECT(outer_QI_test);
+        *ppv = (IUnknown*)0xdeadbeef;
+        return S_OK;
+    }
+    ok(0, "unexpected call %s\n", wine_dbgstr_guid(riid));
+    return E_NOINTERFACE;
+}
+
+static ULONG WINAPI outer_AddRef(IUnknown *iface)
+{
+    return 2;
+}
+
+static ULONG WINAPI outer_Release(IUnknown *iface)
+{
+    return 1;
+}
+
+static const IUnknownVtbl outer_vtbl = {
+    outer_QueryInterface,
+    outer_AddRef,
+    outer_Release
+};
+
+static void test_com_aggregation(const CLSID *clsid)
+{
+    IUnknown outer = { &outer_vtbl };
+    IClassFactory *class_factory;
+    IUnknown *unk, *unk2, *unk3;
+    HRESULT hres;
+
+    hres = CoGetClassObject(clsid, CLSCTX_INPROC_SERVER, NULL, &IID_IClassFactory, (void**)&class_factory);
+    ok(hres == S_OK, "CoGetClassObject failed: %08x\n", hres);
+
+    hres = IClassFactory_CreateInstance(class_factory, &outer, &IID_IUnknown, (void**)&unk);
+    ok(hres == S_OK, "CreateInstance returned: %08x\n", hres);
+
+    hres = IUnknown_QueryInterface(unk, &IID_IInternetProtocol, (void**)&unk2);
+    ok(hres == S_OK, "Could not get IInternetProtocol iface: %08x\n", hres);
+
+    SET_EXPECT(outer_QI_test);
+    hres = IUnknown_QueryInterface(unk2, &outer_test_iid, (void**)&unk3);
+    CHECK_CALLED(outer_QI_test);
+    ok(hres == S_OK, "Could not get IInternetProtocol iface: %08x\n", hres);
+    ok(unk3 == (IUnknown*)0xdeadbeef, "unexpected unk2\n");
+
+    IUnknown_Release(unk2);
+    IUnknown_Release(unk);
+
+    unk = (void*)0xdeadbeef;
+    hres = IClassFactory_CreateInstance(class_factory, &outer, &IID_IInternetProtocol, (void**)&unk);
+    ok(hres == CLASS_E_NOAGGREGATION, "CreateInstance returned: %08x\n", hres);
+    ok(!unk, "unk = %p\n", unk);
+
+    IClassFactory_Release(class_factory);
+}
+
 START_TEST(protocol)
 {
     OleInitialize(NULL);
@@ -669,6 +732,7 @@ START_TEST(protocol)
 
     test_its_protocol();
     test_mk_protocol();
+    test_com_aggregation(&CLSID_ITSProtocol);
 
     delete_chm();
     OleUninitialize();




More information about the wine-cvs mailing list