Nikolay Sivov : combase: Move CoGetTreatAsClass().

Alexandre Julliard julliard at winehq.org
Mon Aug 10 16:16:29 CDT 2020


Module: wine
Branch: master
Commit: ea8d11a8e54989cb3f98562d48a6a7054f387226
URL:    https://source.winehq.org/git/wine.git/?a=commit;h=ea8d11a8e54989cb3f98562d48a6a7054f387226

Author: Nikolay Sivov <nsivov at codeweavers.com>
Date:   Mon Aug 10 11:12:49 2020 +0300

combase: Move CoGetTreatAsClass().

Signed-off-by: Nikolay Sivov <nsivov at codeweavers.com>
Signed-off-by: Huw Davies <huw at codeweavers.com>
Signed-off-by: Alexandre Julliard <julliard at winehq.org>

---

 dlls/combase/combase.c    | 203 ++++++++++++++++++++++++++++++++++++++++++++++
 dlls/combase/combase.spec |   2 +-
 dlls/ole32/compobj.c      |  50 ------------
 dlls/ole32/ole32.spec     |   2 +-
 4 files changed, 205 insertions(+), 52 deletions(-)

diff --git a/dlls/combase/combase.c b/dlls/combase/combase.c
index bcf9912cbc..a757407c71 100644
--- a/dlls/combase/combase.c
+++ b/dlls/combase/combase.c
@@ -20,14 +20,179 @@
 #define COBJMACROS
 #define NONAMELESSUNION
 
+#include "ntstatus.h"
+#define WIN32_NO_STATUS
 #define USE_COM_CONTEXT_DEF
 #include "objbase.h"
 #include "oleauto.h"
+#include "winternl.h"
 
 #include "wine/debug.h"
 
 WINE_DEFAULT_DEBUG_CHANNEL(ole);
 
+#define CHARS_IN_GUID 39  /* WCHARs for a "{xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx}" string plus NUL */
+/* Create or open the key described by attr, creating missing intermediate keys on demand. */
+static NTSTATUS create_key(HKEY *retkey, ACCESS_MASK access, OBJECT_ATTRIBUTES *attr)
+{
+    NTSTATUS status = NtCreateKey((HANDLE *)retkey, access, attr, 0, NULL, 0, NULL);
+    /* Slow path: on STATUS_OBJECT_NAME_NOT_FOUND, create the path one component at a time. */
+    if (status == STATUS_OBJECT_NAME_NOT_FOUND)
+    {
+        HANDLE subkey, root = attr->RootDirectory;
+        WCHAR *buffer = attr->ObjectName->Buffer;
+        DWORD attrs, pos = 0, i = 0, len = attr->ObjectName->Length / sizeof(WCHAR);
+        UNICODE_STRING str;
+        /* A name with no backslash has no intermediate keys to create, so just fail. */
+        while (i < len && buffer[i] != '\\') i++;
+        if (i == len) return status;
+        /* NOTE(review): attr is not restored on return: ObjectName points at local str, RootDirectory may be closed. */
+        attrs = attr->Attributes;
+        attr->ObjectName = &str;
+        /* Create each intermediate component under the previous one, closing handles opened here. */
+        while (i < len)
+        {
+            str.Buffer = buffer + pos;
+            str.Length = (i - pos) * sizeof(WCHAR);
+            status = NtCreateKey(&subkey, access, attr, 0, NULL, 0, NULL);
+            if (attr->RootDirectory != root) NtClose(attr->RootDirectory);
+            if (status) return status;
+            attr->RootDirectory = subkey;
+            while (i < len && buffer[i] == '\\') i++;
+            pos = i;
+            while (i < len && buffer[i] != '\\') i++;
+        }
+        str.Buffer = buffer + pos;
+        str.Length = (i - pos) * sizeof(WCHAR);
+        attr->Attributes = attrs;  /* Attributes is not changed above, so this restore is a no-op */
+        status = NtCreateKey((HANDLE *)retkey, access, attr, 0, NULL, 0, NULL);
+        if (attr->RootDirectory != root) NtClose(attr->RootDirectory);
+    }
+    return status;
+}
+
+static HKEY classes_root_hkey;  /* Classes root handle cached by create_classes_root_hkey() */
+/* Open \Registry\Machine\Software\Classes; returns the key handle, or 0 on failure. */
+static HKEY create_classes_root_hkey(DWORD access)
+{
+    HKEY hkey, ret = 0;
+    OBJECT_ATTRIBUTES attr;
+    UNICODE_STRING name;
+
+    attr.Length = sizeof(attr);
+    attr.RootDirectory = 0;
+    attr.ObjectName = &name;
+    attr.Attributes = 0;
+    attr.SecurityDescriptor = NULL;
+    attr.SecurityQualityOfService = NULL;
+    RtlInitUnicodeString(&name, L"\\Registry\\Machine\\Software\\Classes");
+    /* create_key() may repoint attr.ObjectName at one of its own locals, so trace the name directly. */
+    if (create_key( &hkey, access, &attr )) return 0;
+    TRACE( "%s -> %p\n", debugstr_w(name.Buffer), hkey );
+    /* Only handles opened without KEY_WOW64_64KEY are cached; if another thread won, use its handle. */
+    if (!(access & KEY_WOW64_64KEY))
+    {
+        if (!(ret = InterlockedCompareExchangePointer( (void **)&classes_root_hkey, hkey, 0 )))
+            ret = hkey;
+        else
+            NtClose( hkey );  /* somebody beat us to it */
+    }
+    else
+        ret = hkey;
+    return ret;
+}
+
+static HKEY get_classes_root_hkey(HKEY hkey, REGSAM access);  /* it and create_classes_key call each other */
+/* Create or open name under hkey (HKEY_CLASSES_ROOT is resolved first); returns a Win32 error code. */
+static LSTATUS create_classes_key(HKEY hkey, const WCHAR *name, REGSAM access, HKEY *retkey)
+{
+    OBJECT_ATTRIBUTES attr;
+    UNICODE_STRING nameW;
+
+    if (!(hkey = get_classes_root_hkey(hkey, access)))
+        return ERROR_INVALID_HANDLE;
+
+    attr.Length = sizeof(attr);
+    attr.RootDirectory = hkey;
+    attr.ObjectName = &nameW;
+    attr.Attributes = 0;
+    attr.SecurityDescriptor = NULL;
+    attr.SecurityQualityOfService = NULL;
+    RtlInitUnicodeString( &nameW, name );
+    /* NOTE(review): a handle newly created by get_classes_root_hkey() is not closed here - possible leak. */
+    return RtlNtStatusToDosError(create_key(retkey, access, &attr));
+}
+/* Resolve hkey to a handle usable as a root: HKEY_CLASSES_ROOT becomes the Classes key. Returns 0 on failure. */
+static HKEY get_classes_root_hkey(HKEY hkey, REGSAM access)
+{
+    HKEY ret = hkey;
+    const BOOL is_win64 = sizeof(void*) > sizeof(int);
+    const BOOL force_wow32 = is_win64 && (access & KEY_WOW64_32KEY);
+    /* KEY_WOW64_64KEY always opens a new root handle; otherwise the cached one is used when available. */
+    if (hkey == HKEY_CLASSES_ROOT &&
+        ((access & KEY_WOW64_64KEY) || !(ret = classes_root_hkey)))
+        ret = create_classes_root_hkey(MAXIMUM_ALLOWED | (access & KEY_WOW64_64KEY));
+    if (force_wow32 && ret && ret == classes_root_hkey)  /* 64-bit build asking for the 32-bit view */
+    {
+        access &= ~KEY_WOW64_32KEY;  /* create_classes_key() calls back into this function */
+        if (create_classes_key(classes_root_hkey, L"Wow6432Node", access, &hkey))
+            return 0;
+        ret = hkey;
+    }
+    /* NOTE(review): callers in this file do not close the returned handle when it was newly created. */
+    return ret;
+}
+/* Open the existing key name under hkey (HKEY_CLASSES_ROOT is resolved first); returns a Win32 error code. */
+static LSTATUS open_classes_key(HKEY hkey, const WCHAR *name, REGSAM access, HKEY *retkey)
+{
+    OBJECT_ATTRIBUTES attr;
+    UNICODE_STRING nameW;
+
+    if (!(hkey = get_classes_root_hkey(hkey, access)))
+        return ERROR_INVALID_HANDLE;
+
+    attr.Length = sizeof(attr);
+    attr.RootDirectory = hkey;
+    attr.ObjectName = &nameW;
+    attr.Attributes = 0;
+    attr.SecurityDescriptor = NULL;
+    attr.SecurityQualityOfService = NULL;
+    RtlInitUnicodeString( &nameW, name );
+    /* NOTE(review): a handle newly created by get_classes_root_hkey() is not closed here - possible leak. */
+    return RtlNtStatusToDosError(NtOpenKey((HANDLE *)retkey, access, &attr));
+}
+/* Open HKCR\CLSID\{clsid}, or its keyname subkey when keyname is non-NULL, into *subkey. */
+static HRESULT open_key_for_clsid(REFCLSID clsid, const WCHAR *keyname, REGSAM access, HKEY *subkey)
+{
+    static const WCHAR clsidW[] = L"CLSID\\";
+    WCHAR path[CHARS_IN_GUID + ARRAY_SIZE(clsidW) - 1];  /* prefix + GUID string, sharing one NUL */
+    LONG res;
+    HKEY key;
+    /* When a subkey is wanted, the CLSID key itself is only opened for reading. */
+    lstrcpyW(path, clsidW);
+    StringFromGUID2(clsid, path + lstrlenW(clsidW), CHARS_IN_GUID);
+    res = open_classes_key(HKEY_CLASSES_ROOT, path, keyname ? KEY_READ : access, &key);
+    if (res == ERROR_FILE_NOT_FOUND)
+        return REGDB_E_CLASSNOTREG;
+    else if (res != ERROR_SUCCESS)
+        return REGDB_E_READREGDB;
+    /* Without keyname the caller receives the CLSID key itself. */
+    if (!keyname)
+    {
+        *subkey = key;
+        return S_OK;
+    }
+    /* The intermediate CLSID key is closed whether or not the subkey opens. */
+    res = open_classes_key(key, keyname, access, subkey);
+    RegCloseKey(key);
+    if (res == ERROR_FILE_NOT_FOUND)
+        return REGDB_E_KEYMISSING;
+    else if (res != ERROR_SUCCESS)
+        return REGDB_E_READREGDB;
+
+    return S_OK;
+}
+
 /***********************************************************************
  *           FreePropVariantArray    (combase.@)
  */
@@ -628,6 +793,44 @@ HRESULT WINAPI CoGetActivationState(GUID guid, DWORD arg2, DWORD *arg3)
     return E_NOTIMPL;
 }
 
+/******************************************************************************
+ *          CoGetTreatAsClass       (combase.@)
+ */
+HRESULT WINAPI CoGetTreatAsClass(REFCLSID clsidOld, CLSID *clsidNew)
+{
+    WCHAR buffW[CHARS_IN_GUID];
+    LONG len = sizeof(buffW);
+    HRESULT hr = S_OK;
+    HKEY hkey = NULL;
+    /* S_OK: *clsidNew is the TreatAs class; S_FALSE: no readable TreatAs value, *clsidNew = *clsidOld. */
+    TRACE("%s, %p.\n", debugstr_guid(clsidOld), clsidNew);
+    /* Otherwise E_INVALIDARG for NULL arguments, or the failure code from CLSIDFromString(). */
+    if (!clsidOld || !clsidNew)
+        return E_INVALIDARG;
+    /* Start from the class itself so every S_FALSE path leaves a valid result. */
+    *clsidNew = *clsidOld;
+
+    hr = open_key_for_clsid(clsidOld, L"TreatAs", KEY_READ, &hkey);
+    if (FAILED(hr))
+    {
+        hr = S_FALSE;
+        goto done;
+    }
+    /* Any RegQueryValueW failure (e.g. a missing or over-long default value) is treated as S_FALSE. */
+    if (RegQueryValueW(hkey, NULL, buffW, &len))
+    {
+        hr = S_FALSE;
+        goto done;
+    }
+    /* NOTE(review): if CLSIDFromString fails, *clsidNew is not reset to *clsidOld. */
+    hr = CLSIDFromString(buffW, clsidNew);
+    if (FAILED(hr))
+        ERR("Failed to get CLSID from string %s, hr %#x.\n", debugstr_w(buffW), hr);
+done:
+    if (hkey) RegCloseKey(hkey);
+    return hr;
+}
+
 static void init_multi_qi(DWORD count, MULTI_QI *mqi, HRESULT hr)
 {
     ULONG i;
diff --git a/dlls/combase/combase.spec b/dlls/combase/combase.spec
index b8e374924a..497af6da8c 100644
--- a/dlls/combase/combase.spec
+++ b/dlls/combase/combase.spec
@@ -119,7 +119,7 @@
 @ stdcall CoGetStandardMarshal(ptr ptr long ptr long ptr) ole32.CoGetStandardMarshal
 @ stub CoGetStdMarshalEx
 @ stub CoGetSystemSecurityPermissions
-@ stdcall CoGetTreatAsClass(ptr ptr) ole32.CoGetTreatAsClass
+@ stdcall CoGetTreatAsClass(ptr ptr)
 @ stdcall CoImpersonateClient()
 @ stdcall CoIncrementMTAUsage(ptr) ole32.CoIncrementMTAUsage
 @ stdcall CoInitializeEx(ptr long) ole32.CoInitializeEx
diff --git a/dlls/ole32/compobj.c b/dlls/ole32/compobj.c
index cc64d7c036..1f6a952f15 100644
--- a/dlls/ole32/compobj.c
+++ b/dlls/ole32/compobj.c
@@ -3707,56 +3707,6 @@ done:
     return res;
 }
 
-/******************************************************************************
- *              CoGetTreatAsClass        [OLE32.@]
- *
- * Gets the TreatAs value of a class.
- *
- * PARAMS
- *  clsidOld [I] Class to get the TreatAs value of.
- *  clsidNew [I] The class the clsidOld should be treated as.
- *
- * RETURNS
- *  Success: S_OK.
- *  Failure: HRESULT code.
- *
- * SEE ALSO
- *  CoSetTreatAsClass
- */
-HRESULT WINAPI CoGetTreatAsClass(REFCLSID clsidOld, LPCLSID clsidNew)
-{
-    static const WCHAR wszTreatAs[] = {'T','r','e','a','t','A','s',0};
-    HKEY hkey = NULL;
-    WCHAR szClsidNew[CHARS_IN_GUID];
-    HRESULT res = S_OK;
-    LONG len = sizeof(szClsidNew);
-
-    TRACE("(%s,%p)\n", debugstr_guid(clsidOld), clsidNew);
-
-    if (!clsidOld || !clsidNew)
-        return E_INVALIDARG;
-
-    *clsidNew = *clsidOld; /* copy over old value */
-
-    res = COM_OpenKeyForCLSID(clsidOld, wszTreatAs, KEY_READ, &hkey);
-    if (FAILED(res))
-    {
-        res = S_FALSE;
-        goto done;
-    }
-    if (RegQueryValueW(hkey, NULL, szClsidNew, &len))
-    {
-        res = S_FALSE;
-	goto done;
-    }
-    res = CLSIDFromString(szClsidNew,clsidNew);
-    if (FAILED(res))
-        ERR("Failed CLSIDFromStringA(%s), hres 0x%08x\n", debugstr_w(szClsidNew), res);
-done:
-    if (hkey) RegCloseKey(hkey);
-    return res;
-}
-
 /******************************************************************************
  *		CoGetCurrentProcess	[OLE32.@]
  */
diff --git a/dlls/ole32/ole32.spec b/dlls/ole32/ole32.spec
index 550b3149d4..e51b2b5887 100644
--- a/dlls/ole32/ole32.spec
+++ b/dlls/ole32/ole32.spec
@@ -46,7 +46,7 @@
 @ stdcall CoGetStandardMarshal(ptr ptr long ptr long ptr)
 @ stdcall CoGetState(ptr)
 @ stub CoGetTIDFromIPID
-@ stdcall CoGetTreatAsClass(ptr ptr)
+@ stdcall CoGetTreatAsClass(ptr ptr) combase.CoGetTreatAsClass
 @ stdcall CoImpersonateClient() combase.CoImpersonateClient
 @ stdcall CoIncrementMTAUsage(ptr)
 @ stdcall CoInitialize(ptr)




More information about the wine-cvs mailing list