[PATCH v2 06/12] secur32: Prepare schan_send() buffers on PE side.

Nikolay Sivov wine at gitlab.winehq.org
Thu Jun 2 06:45:46 CDT 2022


From: Nikolay Sivov <nsivov at codeweavers.com>

Signed-off-by: Nikolay Sivov <nsivov at codeweavers.com>
---
 dlls/secur32/schannel.c        | 45 +++++++++++++++---
 dlls/secur32/schannel_gnutls.c | 84 ++++------------------------------
 2 files changed, 46 insertions(+), 83 deletions(-)

diff --git a/dlls/secur32/schannel.c b/dlls/secur32/schannel.c
index 289802c5b32..4508bc2f4b6 100644
--- a/dlls/secur32/schannel.c
+++ b/dlls/secur32/schannel.c
@@ -1299,8 +1299,12 @@ static SECURITY_STATUS SEC_ENTRY schan_EncryptMessage(PCtxtHandle context_handle
     SIZE_T data_size;
     SIZE_T length;
     char *data;
-    int idx, output_buffer_idx = -1;
+    int output_buffer_idx = -1;
     ULONG output_offset = 0;
+    SecBufferDesc output_desc = { 0 };
+    SecBuffer output_buffers[3];
+    int header_idx, data_idx, trailer_idx = -1;
+    int buffer_index[3];
 
     TRACE("context_handle %p, quality %ld, message %p, message_seq_no %ld\n",
             context_handle, quality, message, message_seq_no);
@@ -1310,29 +1314,56 @@ static SECURITY_STATUS SEC_ENTRY schan_EncryptMessage(PCtxtHandle context_handle
 
     dump_buffer_desc(message);
 
-    idx = schan_find_sec_buffer_idx(message, 0, SECBUFFER_DATA);
-    if (idx == -1)
+    data_idx = schan_find_sec_buffer_idx(message, 0, SECBUFFER_DATA);
+    if (data_idx == -1)
     {
         WARN("No data buffer passed\n");
         return SEC_E_INTERNAL_ERROR;
     }
-    buffer = &message->pBuffers[idx];
+    buffer = &message->pBuffers[data_idx];
 
     data_size = buffer->cbBuffer;
     data = malloc(data_size);
     memcpy(data, buffer->pvBuffer, data_size);
 
+    /* Use { STREAM_HEADER, DATA, STREAM_TRAILER } or { TOKEN, DATA, TOKEN } buffers. */
+
+    output_desc.pBuffers = output_buffers;
+    if ((header_idx = schan_find_sec_buffer_idx(message, 0, SECBUFFER_STREAM_HEADER)) == -1)
+    {
+        if ((header_idx = schan_find_sec_buffer_idx(message, 0, SECBUFFER_TOKEN)) != -1)
+        {
+            output_buffers[output_desc.cBuffers++] = message->pBuffers[header_idx];
+            output_buffers[output_desc.cBuffers++] = message->pBuffers[data_idx];
+            trailer_idx = schan_find_sec_buffer_idx(message, header_idx + 1, SECBUFFER_TOKEN);
+            if (trailer_idx != -1)
+                output_buffers[output_desc.cBuffers++] = message->pBuffers[trailer_idx];
+        }
+    }
+    else
+    {
+        output_buffers[output_desc.cBuffers++] = message->pBuffers[header_idx];
+        output_buffers[output_desc.cBuffers++] = message->pBuffers[data_idx];
+        trailer_idx = schan_find_sec_buffer_idx(message, 0, SECBUFFER_STREAM_TRAILER);
+        if (trailer_idx != -1)
+            output_buffers[output_desc.cBuffers++] = message->pBuffers[trailer_idx];
+    }
+
+    buffer_index[0] = header_idx;
+    buffer_index[1] = data_idx;
+    buffer_index[2] = trailer_idx;
+
     length = data_size;
     params.session = ctx->session;
-    params.output = message;
+    params.output = &output_desc;
     params.buffer = data;
     params.length = &length;
     params.output_buffer_idx = &output_buffer_idx;
     params.output_offset = &output_offset;
     status = GNUTLS_CALL( send, &params );
 
-    if (!status && output_buffer_idx != -1)
-        message->pBuffers[output_buffer_idx].cbBuffer = output_offset;
+    if (!status)
+        message->pBuffers[buffer_index[output_buffer_idx]].cbBuffer = output_offset;
 
     TRACE("Sent %Id bytes.\n", length);
 
diff --git a/dlls/secur32/schannel_gnutls.c b/dlls/secur32/schannel_gnutls.c
index ebce6bacde0..44ad7c2c1da 100644
--- a/dlls/secur32/schannel_gnutls.c
+++ b/dlls/secur32/schannel_gnutls.c
@@ -236,78 +236,13 @@ static void init_schan_buffers(struct schan_buffers *s, const PSecBufferDesc des
     s->get_next_buffer = get_next_buffer;
 }
 
-static int schan_find_sec_buffer_idx(const SecBufferDesc *desc, unsigned int start_idx, ULONG buffer_type)
+static int common_get_next_buffer(struct schan_buffers *s)
 {
-    unsigned int i;
-    PSecBuffer buffer;
-
-    for (i = start_idx; i < desc->cBuffers; ++i)
-    {
-        buffer = &desc->pBuffers[i];
-        if ((buffer->BufferType | SECBUFFER_ATTRMASK) == (buffer_type | SECBUFFER_ATTRMASK))
-            return i;
-    }
-
-    return -1;
-}
-
-static int handshake_get_next_buffer(struct schan_buffers *s)
-{
-    if (s->current_buffer_idx != -1)
-        return -1;
-    return s->desc->cBuffers ? 0 : -1;
-}
-
-static int send_message_get_next_buffer(struct schan_buffers *s)
-{
-    SecBuffer *b;
-
-    if (s->current_buffer_idx == -1)
-        return schan_find_sec_buffer_idx(s->desc, 0, SECBUFFER_STREAM_HEADER);
-
-    b = &s->desc->pBuffers[s->current_buffer_idx];
-
-    if (b->BufferType == SECBUFFER_STREAM_HEADER)
-        return schan_find_sec_buffer_idx(s->desc, 0, SECBUFFER_DATA);
-
-    if (b->BufferType == SECBUFFER_DATA)
-        return schan_find_sec_buffer_idx(s->desc, 0, SECBUFFER_STREAM_TRAILER);
-
-    return -1;
-}
-
-static int send_message_get_next_buffer_token(struct schan_buffers *s)
-{
-    SecBuffer *b;
-
     if (s->current_buffer_idx == -1)
-        return schan_find_sec_buffer_idx(s->desc, 0, SECBUFFER_TOKEN);
-
-    b = &s->desc->pBuffers[s->current_buffer_idx];
-
-    if (b->BufferType == SECBUFFER_TOKEN)
-    {
-        int idx = schan_find_sec_buffer_idx(s->desc, 0, SECBUFFER_TOKEN);
-        if (idx != s->current_buffer_idx) return -1;
-        return schan_find_sec_buffer_idx(s->desc, 0, SECBUFFER_DATA);
-    }
-
-    if (b->BufferType == SECBUFFER_DATA)
-    {
-        int idx = schan_find_sec_buffer_idx(s->desc, 0, SECBUFFER_TOKEN);
-        if (idx != -1)
-            idx = schan_find_sec_buffer_idx(s->desc, idx + 1, SECBUFFER_TOKEN);
-        return idx;
-    }
-
-    return -1;
-}
-
-static int recv_message_get_next_buffer(struct schan_buffers *s)
-{
-    if (s->current_buffer_idx != -1)
+        return s->desc->cBuffers ? 0 : -1;
+    if (s->current_buffer_idx == s->desc->cBuffers - 1)
         return -1;
-    return s->desc->cBuffers ? 0 : -1;
+    return s->current_buffer_idx + 1;
 }
 
 static char *get_buffer(struct schan_buffers *s, SIZE_T *count)
@@ -583,9 +518,9 @@ static NTSTATUS schan_handshake( void *args )
     NTSTATUS status;
     int err;
 
-    init_schan_buffers(&t->in, params->input, handshake_get_next_buffer);
+    init_schan_buffers(&t->in, params->input, common_get_next_buffer);
     t->in.limit = params->input_size;
-    init_schan_buffers(&t->out, params->output, handshake_get_next_buffer);
+    init_schan_buffers(&t->out, params->output, common_get_next_buffer);
 
     while (1)
     {
@@ -850,10 +785,7 @@ static NTSTATUS schan_send( void *args )
     struct schan_transport *t = (struct schan_transport *)pgnutls_transport_get_ptr(s);
     SSIZE_T ret, total = 0;
 
-    if (schan_find_sec_buffer_idx(params->output, 0, SECBUFFER_STREAM_HEADER) != -1)
-        init_schan_buffers(&t->out, params->output, send_message_get_next_buffer);
-    else
-        init_schan_buffers(&t->out, params->output, send_message_get_next_buffer_token);
+    init_schan_buffers(&t->out, params->output, common_get_next_buffer);
 
     for (;;)
     {
@@ -893,7 +825,7 @@ static NTSTATUS schan_recv( void *args )
     ssize_t ret;
     SECURITY_STATUS status = SEC_E_OK;
 
-    init_schan_buffers(&t->in, params->input, recv_message_get_next_buffer);
+    init_schan_buffers(&t->in, params->input, common_get_next_buffer);
     t->in.limit = params->input_size;
 
     while (received < data_size)
-- 
GitLab


https://gitlab.winehq.org/wine/wine/-/merge_requests/160



More information about the wine-devel mailing list