[PATCH 3/5] secur32: Add DTLS support.

Hans Leidekker hans at codeweavers.com
Wed Mar 10 08:50:26 CST 2021


Signed-off-by: Hans Leidekker <hans at codeweavers.com>
---
 dlls/secur32/schannel.c        | 89 +++++++++++++++++++++-------------
 dlls/secur32/schannel_gnutls.c | 39 ++++++++++++++-
 2 files changed, 92 insertions(+), 36 deletions(-)

diff --git a/dlls/secur32/schannel.c b/dlls/secur32/schannel.c
index 3a34007a8cb..9a1dfd42152 100644
--- a/dlls/secur32/schannel.c
+++ b/dlls/secur32/schannel.c
@@ -60,6 +60,7 @@ struct schan_context
     struct schan_transport transport;
     ULONG req_ctx_attr;
     const CERT_CONTEXT *cert;
+    SIZE_T header_size;
 };
 
 static struct schan_handle *schan_handle_table;
@@ -184,7 +185,9 @@ static void read_config(void)
         {{'S','S','L',' ','3','.','0',0}, SP_PROT_SSL3_CLIENT, TRUE, FALSE},
         {{'T','L','S',' ','1','.','0',0}, SP_PROT_TLS1_0_CLIENT, TRUE, FALSE},
         {{'T','L','S',' ','1','.','1',0}, SP_PROT_TLS1_1_CLIENT, TRUE, FALSE /* NOTE: not enabled by default on Windows */ },
-        {{'T','L','S',' ','1','.','2',0}, SP_PROT_TLS1_2_CLIENT, TRUE, FALSE /* NOTE: not enabled by default on Windows */ }
+        {{'T','L','S',' ','1','.','2',0}, SP_PROT_TLS1_2_CLIENT, TRUE, FALSE /* NOTE: not enabled by default on Windows */ },
+        {{'D','T','L','S',' ','1','.','0',0}, SP_PROT_DTLS1_0_CLIENT, TRUE, TRUE },
+        {{'D','T','L','S',' ','1','.','2',0}, SP_PROT_DTLS1_2_CLIENT, TRUE, TRUE },
     };
 
     /* No need for thread safety */
@@ -399,10 +402,17 @@ static SECURITY_STATUS schan_AcquireClientCredentials(const SCHANNEL_CRED *schan
 
     if (schanCred)
     {
+        const unsigned dtls_protocols = SP_PROT_DTLS_CLIENT | SP_PROT_DTLS1_2_CLIENT;
+        const unsigned tls_protocols = SP_PROT_TLS1_CLIENT | SP_PROT_TLS1_0_CLIENT | SP_PROT_TLS1_1_CLIENT |
+                                       SP_PROT_TLS1_2_CLIENT | SP_PROT_TLS1_3_CLIENT;
+
         status = get_cert(schanCred, &cert);
         if (status != SEC_E_OK && status != SEC_E_NO_CREDENTIALS)
             return status;
 
+        if ((schanCred->grbitEnabledProtocols & tls_protocols) &&
+            (schanCred->grbitEnabledProtocols & dtls_protocols)) return SEC_E_ALGORITHM_MISMATCH;
+
         status = SEC_E_OK;
     }
 
@@ -773,12 +783,18 @@ static void dump_buffer_desc(SecBufferDesc *desc)
 }
 
 #define HEADER_SIZE_TLS  5
+#define HEADER_SIZE_DTLS 13
 
 static inline SIZE_T read_record_size(const BYTE *buf, SIZE_T header_size)
 {
     return (buf[header_size - 2] << 8) | buf[header_size - 1];
 }
 
+static inline BOOL is_dtls_context(const struct schan_context *ctx)
+{
+    return (ctx->header_size == HEADER_SIZE_DTLS); /* header_size doubles as the TLS/DTLS discriminator */
+}
+
 /***********************************************************************
  *              InitializeSecurityContextW
  */
@@ -836,6 +852,11 @@ static SECURITY_STATUS SEC_ENTRY schan_InitializeSecurityContextW(
             return SEC_E_INTERNAL_ERROR;
         }
 
+        if (cred->enabled_protocols & (SP_PROT_DTLS1_0_CLIENT | SP_PROT_DTLS1_2_CLIENT))
+            ctx->header_size = HEADER_SIZE_DTLS;
+        else
+            ctx->header_size = HEADER_SIZE_TLS;
+
         ctx->transport.ctx = ctx;
         schan_imp_set_session_transport(ctx->session, &ctx->transport);
 
@@ -866,38 +887,38 @@ static SECURITY_STATUS SEC_ENTRY schan_InitializeSecurityContextW(
         SIZE_T record_size = 0;
         unsigned char *ptr;
 
-        if (!pInput)
-            return SEC_E_INCOMPLETE_MESSAGE;
-
-        idx = schan_find_sec_buffer_idx(pInput, 0, SECBUFFER_TOKEN);
-        if (idx == -1)
-            return SEC_E_INCOMPLETE_MESSAGE;
+        ctx = schan_get_object(phContext->dwLower, SCHAN_HANDLE_CTX);
+        if (pInput)
+        {
+            idx = schan_find_sec_buffer_idx(pInput, 0, SECBUFFER_TOKEN);
+            if (idx == -1)
+                return SEC_E_INCOMPLETE_MESSAGE;
 
-        buffer = &pInput->pBuffers[idx];
-        ptr = buffer->pvBuffer;
-        expected_size = 0;
+            buffer = &pInput->pBuffers[idx];
+            ptr = buffer->pvBuffer;
+            expected_size = 0;
 
-        while (buffer->cbBuffer > expected_size + HEADER_SIZE_TLS)
-        {
-            record_size = HEADER_SIZE_TLS + read_record_size(ptr, HEADER_SIZE_TLS);
+            while (buffer->cbBuffer > expected_size + ctx->header_size)
+            {
+                record_size = ctx->header_size + read_record_size(ptr, ctx->header_size);
 
-            if (buffer->cbBuffer < expected_size + record_size)
-                break;
+                if (buffer->cbBuffer < expected_size + record_size)
+                    break;
 
-            expected_size += record_size;
-            ptr += record_size;
-        }
+                expected_size += record_size;
+                ptr += record_size;
+            }
 
-        if (!expected_size)
-        {
-            TRACE("Expected at least %lu bytes, but buffer only contains %u bytes.\n",
-                    max(6, record_size), buffer->cbBuffer);
-            return SEC_E_INCOMPLETE_MESSAGE;
+            if (!expected_size)
+            {
+                TRACE("Expected at least %lu bytes, but buffer only contains %u bytes.\n",
+                      max(6, record_size), buffer->cbBuffer);
+                return SEC_E_INCOMPLETE_MESSAGE;
+            }
         }
+        else if (!is_dtls_context(ctx)) return SEC_E_INCOMPLETE_MESSAGE;
 
         TRACE("Using expected_size %lu.\n", expected_size);
-
-        ctx = schan_get_object(phContext->dwLower, SCHAN_HANDLE_CTX);
     }
 
     ctx->req_ctx_attr = fContextReq;
@@ -1038,11 +1059,11 @@ static SECURITY_STATUS SEC_ENTRY schan_QueryContextAttributesW(
                 unsigned int block_size = schan_imp_get_session_cipher_block_size(ctx->session);
                 unsigned int message_size = schan_imp_get_max_message_size(ctx->session);
 
-                TRACE("Using %lu mac bytes, message size %u, block size %u\n",
-                        mac_size, message_size, block_size);
+                TRACE("Using header size %lu mac bytes %lu, message size %u, block size %u\n",
+                      ctx->header_size, mac_size, message_size, block_size);
 
                 /* These are defined by the TLS RFC */
-                stream_sizes->cbHeader = HEADER_SIZE_TLS;
+                stream_sizes->cbHeader = ctx->header_size;
                 stream_sizes->cbTrailer = mac_size + 256; /* Max 255 bytes padding + 1 for padding size */
                 stream_sizes->cbMaximumMessage = message_size;
                 stream_sizes->cbBuffers = 4;
@@ -1367,7 +1388,7 @@ static SECURITY_STATUS SEC_ENTRY schan_DecryptMessage(PCtxtHandle context_handle
     buffer = &message->pBuffers[idx];
     buf_ptr = buffer->pvBuffer;
 
-    expected_size = HEADER_SIZE_TLS + read_record_size(buf_ptr, HEADER_SIZE_TLS);
+    expected_size = ctx->header_size + read_record_size(buf_ptr, ctx->header_size);
     if(buffer->cbBuffer < expected_size)
     {
         TRACE("Expected %u bytes, but buffer only contains %u bytes\n", expected_size, buffer->cbBuffer);
@@ -1384,7 +1405,7 @@ static SECURITY_STATUS SEC_ENTRY schan_DecryptMessage(PCtxtHandle context_handle
         return SEC_E_INCOMPLETE_MESSAGE;
     }
 
-    data_size = expected_size - HEADER_SIZE_TLS;
+    data_size = expected_size - ctx->header_size;
     data = heap_alloc(data_size);
 
     init_schan_buffers(&ctx->transport.in, message, schan_decrypt_message_get_next_buffer);
@@ -1419,21 +1440,21 @@ static SECURITY_STATUS SEC_ENTRY schan_DecryptMessage(PCtxtHandle context_handle
 
     TRACE("Received %ld bytes\n", received);
 
-    memcpy(buf_ptr + HEADER_SIZE_TLS, data, received);
+    memcpy(buf_ptr + ctx->header_size, data, received);
     heap_free(data);
 
     schan_decrypt_fill_buffer(message, SECBUFFER_DATA,
-        buf_ptr + HEADER_SIZE_TLS, received);
+        buf_ptr + ctx->header_size, received);
 
     schan_decrypt_fill_buffer(message, SECBUFFER_STREAM_TRAILER,
-        buf_ptr + HEADER_SIZE_TLS + received, buffer->cbBuffer - HEADER_SIZE_TLS - received);
+        buf_ptr + ctx->header_size + received, buffer->cbBuffer - ctx->header_size - received);
 
     if(buffer->cbBuffer > expected_size)
         schan_decrypt_fill_buffer(message, SECBUFFER_EXTRA,
             buf_ptr + expected_size, buffer->cbBuffer - expected_size);
 
     buffer->BufferType = SECBUFFER_STREAM_HEADER;
-    buffer->cbBuffer = HEADER_SIZE_TLS;
+    buffer->cbBuffer = ctx->header_size;
 
     return status;
 }
diff --git a/dlls/secur32/schannel_gnutls.c b/dlls/secur32/schannel_gnutls.c
index e342df3874d..fbf9277a39c 100644
--- a/dlls/secur32/schannel_gnutls.c
+++ b/dlls/secur32/schannel_gnutls.c
@@ -50,6 +50,10 @@ WINE_DECLARE_DEBUG_CHANNEL(winediag);
 /* Not present in gnutls version < 2.9.10. */
 static int (*pgnutls_cipher_get_block_size)(gnutls_cipher_algorithm_t);
 
+/* Not present in gnutls version < 3.0. */
+static void (*pgnutls_transport_set_pull_timeout_function)(gnutls_session_t,
+                                                           int (*)(gnutls_transport_ptr_t, unsigned int));
+
 /* Not present in gnutls version < 3.2.0. */
 static int (*pgnutls_alpn_get_selected_protocol)(gnutls_session_t, gnutls_datum_t *);
 static int (*pgnutls_alpn_set_protocols)(gnutls_session_t, const gnutls_datum_t *,
@@ -147,6 +151,12 @@ static int compat_cipher_get_block_size(gnutls_cipher_algorithm_t cipher)
     }
 }
 
+static void compat_gnutls_transport_set_pull_timeout_function(gnutls_session_t session,
+                                                              int (*func)(gnutls_transport_ptr_t, unsigned int))
+{
+    FIXME("\n"); /* fallback used when libgnutls lacks the symbol: func is never installed */
+}
+
 static int compat_gnutls_privkey_export_x509(gnutls_privkey_t privkey, gnutls_x509_privkey_t *key)
 {
     FIXME("\n");
@@ -212,6 +222,8 @@ static const struct {
     DWORD enable_flag;
     const char *gnutls_flag;
 } protocol_priority_flags[] = {
+    {SP_PROT_DTLS1_2_CLIENT, "VERS-DTLS1.2"},
+    {SP_PROT_DTLS1_0_CLIENT, "VERS-DTLS1.0"},
     {SP_PROT_TLS1_3_CLIENT, "VERS-TLS1.3"},
     {SP_PROT_TLS1_2_CLIENT, "VERS-TLS1.2"},
     {SP_PROT_TLS1_1_CLIENT, "VERS-TLS1.1"},
@@ -257,14 +269,29 @@ DWORD schan_imp_enabled_protocols(void)
     return supported_protocols;
 }
 
+/* Pull-timeout callback installed on DTLS sessions; the timeout argument is ignored.
+ * Yields 1 when schan_get_buffer() returns non-zero for the transport's `in` buffers, else 0. */
+static int schan_pull_timeout(gnutls_transport_ptr_t transport, unsigned int timeout)
+{
+    struct schan_transport *xport = (struct schan_transport *)transport;
+    SIZE_T avail = 0;
+    return schan_get_buffer(xport, &xport->in, &avail) ? 1 : 0;
+}
+
 BOOL schan_imp_create_session(schan_imp_session *session, schan_credentials *cred)
 {
     gnutls_session_t *s = (gnutls_session_t*)session;
     char priority[128] = "NORMAL:%LATEST_RECORD_VERSION", *p;
     BOOL using_vers_all = FALSE, disabled;
-    unsigned i;
+    unsigned int i, flags = (cred->credential_use == SECPKG_CRED_INBOUND) ? GNUTLS_SERVER : GNUTLS_CLIENT;
+    int err;
+
+    if (cred->enabled_protocols & (SP_PROT_DTLS1_0_CLIENT | SP_PROT_DTLS1_2_CLIENT))
+    {
+        flags |= GNUTLS_DATAGRAM | GNUTLS_NONBLOCK;
+    }
 
-    int err = pgnutls_init(s, cred->credential_use == SECPKG_CRED_INBOUND ? GNUTLS_SERVER : GNUTLS_CLIENT);
+    err = pgnutls_init(s, flags);
     if (err != GNUTLS_E_SUCCESS)
     {
         pgnutls_perror(err);
@@ -315,6 +342,7 @@ BOOL schan_imp_create_session(schan_imp_session *session, schan_credentials *cre
     }
 
     pgnutls_transport_set_pull_function(*s, schan_pull_adapter);
+    if (flags & GNUTLS_DATAGRAM) pgnutls_transport_set_pull_timeout_function(*s, schan_pull_timeout);
     pgnutls_transport_set_push_function(*s, schan_push_adapter);
 
     return TRUE;
@@ -400,6 +428,8 @@ static DWORD schannel_get_protocol(gnutls_protocol_t proto)
     case GNUTLS_TLS1_0: return SP_PROT_TLS1_0_CLIENT;
     case GNUTLS_TLS1_1: return SP_PROT_TLS1_1_CLIENT;
     case GNUTLS_TLS1_2: return SP_PROT_TLS1_2_CLIENT;
+    case GNUTLS_DTLS1_0: return SP_PROT_DTLS1_0_CLIENT;
+    case GNUTLS_DTLS1_2: return SP_PROT_DTLS1_2_CLIENT;
     default:
         FIXME("unknown protocol %d\n", proto);
         return 0;
@@ -1085,6 +1115,11 @@ BOOL schan_imp_init(void)
         WARN("gnutls_cipher_get_block_size not found\n");
         pgnutls_cipher_get_block_size = compat_cipher_get_block_size;
     }
+    if (!(pgnutls_transport_set_pull_timeout_function = dlsym(libgnutls_handle, "gnutls_transport_set_pull_timeout_function")))
+    {
+        WARN("gnutls_transport_set_pull_timeout_function not found\n");
+        pgnutls_transport_set_pull_timeout_function = compat_gnutls_transport_set_pull_timeout_function;
+    }
     if (!(pgnutls_alpn_set_protocols = dlsym(libgnutls_handle, "gnutls_alpn_set_protocols")))
     {
         WARN("gnutls_alpn_set_protocols not found\n");
-- 
2.30.1




More information about the wine-devel mailing list