[PATCH 2/2] crypt32: Support CERT_NAME_SEARCH_ALL_NAMES_FLAG in CertGetNameStringW().

Paul Gofman wine at gitlab.winehq.org
Tue May 3 15:47:15 CDT 2022


From: Paul Gofman <pgofman at codeweavers.com>

---
 dlls/crypt32/str.c       | 105 +++++++++++++++++++++++++++++----------
 dlls/crypt32/tests/str.c |  55 +++++++++++++++-----
 include/wincrypt.h       |   6 ++-
 3 files changed, 126 insertions(+), 40 deletions(-)

diff --git a/dlls/crypt32/str.c b/dlls/crypt32/str.c
index 0906ad883db..d74df308e4a 100644
--- a/dlls/crypt32/str.c
+++ b/dlls/crypt32/str.c
@@ -905,17 +905,7 @@ DWORD WINAPI CertGetNameStringA(PCCERT_CONTEXT cert, DWORD type,
     return ret;
 }
 
-/* Searches cert's extensions for the alternate name extension with OID
- * altNameOID, and if found, searches it for the alternate name type entryType.
- * If found, returns a pointer to the entry, otherwise returns NULL.
- * Regardless of whether an entry of the desired type is found, if the
- * alternate name extension is present, sets *info to the decoded alternate
- * name extension, which you must free using LocalFree.
- * The return value is a pointer within *info, so don't free *info before
- * you're done with the return value.
- */
-static PCERT_ALT_NAME_ENTRY cert_find_alt_name_entry(PCCERT_CONTEXT cert, BOOL alt_name_issuer,
-                                                     DWORD entryType, PCERT_ALT_NAME_INFO *info)
+static BOOL cert_get_alt_name_info(PCCERT_CONTEXT cert, BOOL alt_name_issuer, PCERT_ALT_NAME_INFO *info)
 {
     static const char *oids[][2] =
     {
@@ -924,24 +914,48 @@ static PCERT_ALT_NAME_ENTRY cert_find_alt_name_entry(PCCERT_CONTEXT cert, BOOL a
     };
     PCERT_EXTENSION ext;
     DWORD bytes = 0;
-    unsigned int i;
 
     ext = CertFindExtension(oids[!!alt_name_issuer][0], cert->pCertInfo->cExtension, cert->pCertInfo->rgExtension);
     if (!ext)
         ext = CertFindExtension(oids[!!alt_name_issuer][1], cert->pCertInfo->cExtension, cert->pCertInfo->rgExtension);
-    if (!ext) return NULL;
+    if (!ext) return FALSE;
 
-    if (!CryptDecodeObjectEx(cert->dwCertEncodingType, X509_ALTERNATE_NAME, ext->Value.pbData, ext->Value.cbData,
-                             CRYPT_DECODE_ALLOC_FLAG, NULL, info, &bytes))
-        return NULL;
+    return CryptDecodeObjectEx(cert->dwCertEncodingType, X509_ALTERNATE_NAME, ext->Value.pbData, ext->Value.cbData,
+                             CRYPT_DECODE_ALLOC_FLAG, NULL, info, &bytes);
+}
 
-    for (i = 0; i < (*info)->cAltEntry; ++i)
-        if ((*info)->rgAltEntry[i].dwAltNameChoice == entryType)
-            return &(*info)->rgAltEntry[i];
+static PCERT_ALT_NAME_ENTRY cert_find_next_alt_name_entry(PCERT_ALT_NAME_INFO info, DWORD entry_type,
+                                                          unsigned int *index)
+{
+    unsigned int i;
 
+    for (i = *index; i < info->cAltEntry; ++i)
+        if (info->rgAltEntry[i].dwAltNameChoice == entry_type)
+        {
+            *index = i + 1; /* resume after this entry on the next call */
+            return &info->rgAltEntry[i];
+        }
     return NULL;
 }
 
+/* Decodes the alternate name extension of cert selected by alt_name_issuer
+ * (see cert_get_alt_name_info()) and returns its first entry of type
+ * entry_type, or NULL if the extension is missing, fails to decode or holds
+ * no such entry. Regardless of whether an entry of the desired type is found,
+ * if the alternate name extension is present, sets *info to the decoded
+ * alternate name extension, which you must free using LocalFree.
+ * The return value is a pointer within *info, so don't free *info before
+ * you're done with the return value.
+ */
+static PCERT_ALT_NAME_ENTRY cert_find_alt_name_entry(PCCERT_CONTEXT cert, BOOL alt_name_issuer,
+                                                     DWORD entry_type, PCERT_ALT_NAME_INFO *info)
+{
+    unsigned int index = 0;
+
+    if (!cert_get_alt_name_info(cert, alt_name_issuer, info)) return NULL;
+    return cert_find_next_alt_name_entry(*info, entry_type, &index);
+}
+
 static DWORD cert_get_name_from_rdn_attr(DWORD encodingType,
  const CERT_NAME_BLOB *name, LPCSTR oid, LPWSTR pszNameString, DWORD cchNameString)
 {
@@ -978,9 +992,10 @@ static DWORD copy_output_str(WCHAR *dst, const WCHAR *src, DWORD dst_size)
 DWORD WINAPI CertGetNameStringW(PCCERT_CONTEXT cert, DWORD type, DWORD flags, void *type_para,
                                 LPWSTR name_string, DWORD name_len)
 {
+    static const DWORD supported_flags = CERT_NAME_ISSUER_FLAG | CERT_NAME_SEARCH_ALL_NAMES_FLAG;
+    BOOL alt_name_issuer, search_all_names;
     CERT_ALT_NAME_INFO *info = NULL;
     PCERT_ALT_NAME_ENTRY entry;
-    BOOL alt_name_issuer;
     PCERT_NAME_BLOB name;
     DWORD ret = 0;
 
@@ -989,6 +1004,16 @@ DWORD WINAPI CertGetNameStringW(PCCERT_CONTEXT cert, DWORD type, DWORD flags, vo
     if (!cert)
         goto done;
 
+    if (flags & ~supported_flags) /* anything else only triggers a FIXME here */
+        FIXME("Unsupported flags %#lx.\n", flags);
+
+    search_all_names = flags & CERT_NAME_SEARCH_ALL_NAMES_FLAG;
+    if (search_all_names && type != CERT_NAME_DNS_TYPE) /* the flag is only implemented for DNS names */
+    {
+        WARN("CERT_NAME_SEARCH_ALL_NAMES_FLAG used with type %lu.\n", type);
+        goto done;
+    }
+
     alt_name_issuer = flags & CERT_NAME_ISSUER_FLAG;
     name = alt_name_issuer ? &cert->pCertInfo->Issuer : &cert->pCertInfo->Subject;
 
@@ -1077,15 +1102,43 @@ DWORD WINAPI CertGetNameStringW(PCCERT_CONTEXT cert, DWORD type, DWORD flags, vo
     }
     case CERT_NAME_DNS_TYPE:
     {
-        entry = cert_find_alt_name_entry(cert, alt_name_issuer, CERT_ALT_NAME_DNS_NAME, &info);
+        unsigned int index = 0, len; /* index: next alt name entry to search from */
 
-        if (entry)
+        if (cert_get_alt_name_info(cert, alt_name_issuer, &info)
+            && (entry = cert_find_next_alt_name_entry(info, CERT_ALT_NAME_DNS_NAME, &index)))
         {
-            ret = copy_output_str(name_string, entry->u.pwszDNSName, name_len);
-            break;
+            if (search_all_names) /* list every DNS name, each NUL-terminated */
+            {
+                do
+                {
+                    if (name_string && name_len == 1) break; /* only room left for the final NUL */
+                    ret += len = copy_output_str(name_string, entry->u.pwszDNSName, name_len ? name_len - 1 : 0);
+                    if (name_string && name_len)
+                    {
+                        name_string += len;
+                        name_len -= len;
+                    }
+                }
+                while ((entry = cert_find_next_alt_name_entry(info, CERT_ALT_NAME_DNS_NAME, &index)));
+            }
+            else ret = copy_output_str(name_string, entry->u.pwszDNSName, name_len);
+        }
+        else /* no DNS alt name: use the common name */
+        {
+            if (!search_all_names || name_len != 1) /* a 1-char buffer only fits the final NUL */
+            {
+                len = search_all_names && name_len ? name_len - 1 : name_len; /* keep one char for the final NUL */
+                ret = cert_get_name_from_rdn_attr(cert->dwCertEncodingType, name, szOID_COMMON_NAME,
+                                                  name_string, len);
+                if (name_string) name_string += ret;
+            }
+        }
+
+        if (search_all_names) /* a multi-string ends with an extra NUL */
+        {
+            if (name_string && name_len) *name_string = 0;
+            ++ret;
         }
-        ret = cert_get_name_from_rdn_attr(cert->dwCertEncodingType, name, szOID_COMMON_NAME,
-                                          name_string, name_len);
         break;
     }
     case CERT_NAME_URL_TYPE:
diff --git a/dlls/crypt32/tests/str.c b/dlls/crypt32/tests/str.c
index 9fa9efff6b9..5fb05bdb836 100644
--- a/dlls/crypt32/tests/str.c
+++ b/dlls/crypt32/tests/str.c
@@ -847,39 +847,63 @@ static void test_CertStrToNameW(void)
 static void test_CertGetNameString_value_(unsigned int line, PCCERT_CONTEXT context, DWORD type, DWORD flags,
         void *type_para, const char *expected)
 {
+    DWORD len, retlen, expected_len;
     WCHAR expectedW[512];
-    DWORD len, retlen;
     WCHAR strW[512];
-    unsigned int i;
     char str[512];
 
-    for (i = 0; expected[i]; ++i)
-        expectedW[i] = expected[i];
-    expectedW[i] = 0;
+    expected_len = 0;
+    while (expected[expected_len])
+    {
+        while ((expectedW[expected_len] = expected[expected_len]))
+            ++expected_len;
+        if (!(flags & CERT_NAME_SEARCH_ALL_NAMES_FLAG))
+            break;
+        expectedW[expected_len++] = 0;
+    }
+    expectedW[expected_len++] = 0;
 
     len = CertGetNameStringA(context, type, flags, type_para, NULL, 0);
-    ok(len == strlen(expected) + 1, "line %u: unexpected length %ld.\n", line, len);
+    ok(len == expected_len, "line %u: unexpected length %lu, expected %lu.\n", line, len, expected_len);
+    memset(str, 0xcc, len);
     retlen = CertGetNameStringA(context, type, flags, type_para, str, len);
     ok(retlen == len, "line %u: unexpected len %lu, expected %lu.\n", line, retlen, len);
-    ok(!strcmp(str, expected), "line %u: unexpected value %s.\n", line, str);
+    ok(!memcmp(str, expected, expected_len), "line %u: unexpected value %s.\n", line, debugstr_an(str, expected_len));
     str[0] = str[1] = 0xcc;
     retlen = CertGetNameStringA(context, type, flags, type_para, str, len - 1);
     ok(retlen == 1, "line %u: Unexpected len %lu, expected 1.\n", line, retlen);
     if (len == 1) return;
     ok(!str[0], "line %u: unexpected str[0] %#x.\n", line, str[0]);
     ok(str[1] == expected[1], "line %u: unexpected str[1] %#x.\n", line, str[1]);
-
+    ok(!memcmp(str + 1, expected + 1, len - 2),
+            "line %u: str %s, string data mismatch.\n", line, debugstr_an(str + 1, len - 2));
     retlen = CertGetNameStringA(context, type, flags, type_para, str, 0);
     ok(retlen == len, "line %u: Unexpected len %lu, expected 1.\n", line, retlen);
 
+    memset(strW, 0xcc, len * sizeof(*strW));
     retlen = CertGetNameStringW(context, type, flags, type_para, strW, len);
-    ok(retlen == len, "line %u: unexpected len %lu, expected 1.\n", line, retlen);
-    ok(!wcscmp(strW, expectedW), "line %u: unexpected value %s.\n", line, debugstr_w(strW));
+    ok(retlen == expected_len, "line %u: unexpected len %lu, expected %lu.\n", line, retlen, expected_len);
+    ok(!memcmp(strW, expectedW, len * sizeof(*strW)), "line %u: unexpected value %s.\n", line, debugstr_wn(strW, len));
     strW[0] = strW[1] = 0xcccc;
     retlen = CertGetNameStringW(context, type, flags, type_para, strW, len - 1);
     ok(retlen == len - 1, "line %u: unexpected len %lu, expected %lu.\n", line, retlen, len - 1);
-    ok(!wcsncmp(strW, expectedW, retlen - 1), "line %u: string data mismatch.\n", line);
-    ok(!strW[retlen - 1], "line %u: string is not zero terminated.\n", line);
+    if (flags & CERT_NAME_SEARCH_ALL_NAMES_FLAG)
+    {
+        ok(!memcmp(strW, expectedW, (retlen - 2) * sizeof(*strW)),
+                "line %u: str %s, string data mismatch.\n", line, debugstr_wn(strW, retlen - 2));
+        ok(!strW[retlen - 2], "line %u: string is not zero terminated.\n", line);
+        ok(!strW[retlen - 1], "line %u: string sequence is not zero terminated.\n", line);
+
+        retlen = CertGetNameStringW(context, type, flags, type_para, strW, 1);
+        ok(retlen == 1, "line %u: unexpected len %lu, expected 1.\n", line, retlen);
+        ok(!strW[0], "line %u: string sequence is not zero terminated.\n", line);
+    }
+    else
+    {
+        ok(!memcmp(strW, expectedW, (retlen - 1) * sizeof(*strW)),
+                "line %u: str %s, string data mismatch.\n", line, debugstr_wn(strW, retlen - 1));
+        ok(!strW[retlen - 1], "line %u: string is not zero terminated.\n", line);
+    }
     retlen = CertGetNameStringA(context, type, flags, type_para, NULL, len - 1);
     ok(retlen == len, "line %u: unexpected len %lu, expected %lu\n", line, retlen, len);
     retlen = CertGetNameStringW(context, type, flags, type_para, NULL, len - 1);
@@ -924,6 +948,9 @@ static void test_CertGetNameString(void)
     test_CertGetNameString_value(context, CERT_NAME_SIMPLE_DISPLAY_TYPE, 0, NULL, localhost);
     test_CertGetNameString_value(context, CERT_NAME_FRIENDLY_DISPLAY_TYPE, 0, NULL, localhost);
     test_CertGetNameString_value(context, CERT_NAME_DNS_TYPE, 0, NULL, localhost);
+    test_CertGetNameString_value(context, CERT_NAME_DNS_TYPE, CERT_NAME_SEARCH_ALL_NAMES_FLAG, NULL, "localhost\0");
+    test_CertGetNameString_value(context, CERT_NAME_EMAIL_TYPE, CERT_NAME_SEARCH_ALL_NAMES_FLAG, NULL, "");
+    test_CertGetNameString_value(context, CERT_NAME_SIMPLE_DISPLAY_TYPE, CERT_NAME_SEARCH_ALL_NAMES_FLAG, NULL, "");
 
     CertFreeCertificateContext(context);
 
@@ -945,6 +972,10 @@ static void test_CertGetNameString(void)
     test_CertGetNameString_value(context, CERT_NAME_DNS_TYPE, CERT_NAME_ISSUER_FLAG, NULL, "ex3.org");
     test_CertGetNameString_value(context, CERT_NAME_SIMPLE_DISPLAY_TYPE, 0, NULL, "server_cn.org");
     test_CertGetNameString_value(context, CERT_NAME_ATTR_TYPE, 0, (void *)szOID_SUR_NAME, "");
+    test_CertGetNameString_value(context, CERT_NAME_DNS_TYPE, CERT_NAME_SEARCH_ALL_NAMES_FLAG,
+            NULL, "ex1.org\0*.ex2.org\0"); /* NUL-terminated names; the list ends at the double NUL */
+    test_CertGetNameString_value(context, CERT_NAME_DNS_TYPE, CERT_NAME_SEARCH_ALL_NAMES_FLAG | CERT_NAME_ISSUER_FLAG,
+            NULL, "ex3.org\0*.ex4.org\0"); /* same, for the issuer's names */
     CertFreeCertificateContext(context);
 }
 
diff --git a/include/wincrypt.h b/include/wincrypt.h
index 2c1e3f0d4c3..77b2fb5d7cf 100644
--- a/include/wincrypt.h
+++ b/include/wincrypt.h
@@ -3351,8 +3351,10 @@ typedef struct _CTL_FIND_SUBJECT_PARA
 #define CERT_NAME_URL_TYPE              7
 #define CERT_NAME_UPN_TYPE              8
 
-#define CERT_NAME_ISSUER_FLAG           0x00000001
-#define CERT_NAME_DISABLE_IE4_UTF8_FLAG 0x00010000
+#define CERT_NAME_ISSUER_FLAG              0x00000001
+#define CERT_NAME_SEARCH_ALL_NAMES_FLAG    0x00000002 /* CERT_NAME_DNS_TYPE: return all names as a multi-string */
+#define CERT_NAME_DISABLE_IE4_UTF8_FLAG    0x00010000
+#define CERT_NAME_STR_ENABLE_PUNYCODE_FLAG 0x00200000
 
 /* CryptFormatObject flags */
 #define CRYPT_FORMAT_STR_MULTI_LINE 0x0001
-- 
GitLab

https://gitlab.winehq.org/wine/wine/-/merge_requests/30



More information about the wine-devel mailing list