[PATCH 4/4] winhttp: Allow synchronous nonblocking send in WinHttpWebSocketSend() when possible.

Paul Gofman pgofman at codeweavers.com
Mon Jan 24 03:43:14 CST 2022


Signed-off-by: Paul Gofman <pgofman at codeweavers.com>
---
     I have a patch later in the queue which also allows trying a sync send even
     if the send size exceeds the maximum frame buffer size or the SSL buffer size.
     Those changes have almost no overlap with this patch (except that the
     WSAEWOULDBLOCK status generation and handling go away), so I thought it
     easier to keep these (mostly independent) changes separate.

 dlls/winhttp/net.c                | 37 +++++++-----
 dlls/winhttp/request.c            | 96 ++++++++++++++++++++++---------
 dlls/winhttp/tests/notification.c | 16 +++++-
 dlls/winhttp/winhttp_private.h    |  4 +-
 4 files changed, 111 insertions(+), 42 deletions(-)

diff --git a/dlls/winhttp/net.c b/dlls/winhttp/net.c
index 68aee036734..cdb67f74481 100644
--- a/dlls/winhttp/net.c
+++ b/dlls/winhttp/net.c
@@ -32,15 +32,23 @@
 
 WINE_DEFAULT_DEBUG_CHANNEL(winhttp);
 
-static int sock_send(int fd, const void *msg, size_t len, int flags)
+static int sock_send(int fd, const void *msg, size_t len, WSAOVERLAPPED *ovr)
 {
-    int ret;
-    do
+    WSABUF wsabuf;
+    DWORD size;
+    DWORD err;
+    /* Returns bytes sent or -1; with ovr set, -1 + WSA_IO_PENDING means the send is in flight. */
+    wsabuf.len = len;
+    wsabuf.buf = (void *)msg;
+
+    if (!WSASend( (SOCKET)fd, &wsabuf, 1, &size, 0, ovr, NULL ))
     {
-        if ((ret = send(fd, msg, len, flags)) == -1) WARN("send error %u\n", WSAGetLastError());
+        assert( size == len );  /* an immediately completed send is expected to be whole */
+        return size;
     }
-    while(ret == -1 && WSAGetLastError() == WSAEINTR);
-    return ret;
+    err = WSAGetLastError();
+    if (!(ovr && err == WSA_IO_PENDING)) WARN( "send error %u\n", err );  /* pending is not an error */
+    return -1;
 }
 
 static int sock_recv(int fd, void *msg, size_t len, int flags)
@@ -190,7 +198,7 @@ DWORD netconn_create( struct hostdata *host, const struct sockaddr_storage *sock
     if (!(conn = calloc( 1, sizeof(*conn) ))) return ERROR_OUTOFMEMORY;
     conn->host = host;
     conn->sockaddr = *sockaddr;
-    if ((conn->socket = socket( sockaddr->ss_family, SOCK_STREAM, 0 )) == -1)
+    if ((conn->socket = WSASocketW( sockaddr->ss_family, SOCK_STREAM, 0, NULL, 0, WSA_FLAG_OVERLAPPED )) == -1)
     {
         ret = WSAGetLastError();
         WARN("unable to create socket (%u)\n", ret);
@@ -290,7 +298,7 @@ DWORD netconn_secure_connect( struct netconn *conn, WCHAR *hostname, DWORD secur
 
             TRACE("sending %u bytes\n", out_buf.cbBuffer);
 
-            size = sock_send(conn->socket, out_buf.pvBuffer, out_buf.cbBuffer, 0);
+            size = sock_send(conn->socket, out_buf.pvBuffer, out_buf.cbBuffer, NULL);
             if(size != out_buf.cbBuffer) {
                 ERR("send failed\n");
                 res = ERROR_WINHTTP_SECURE_CHANNEL_ERROR;
@@ -398,7 +406,7 @@ DWORD netconn_secure_connect( struct netconn *conn, WCHAR *hostname, DWORD secur
     return ERROR_SUCCESS;
 }
 
-static DWORD send_ssl_chunk( struct netconn *conn, const void *msg, size_t size )
+static DWORD send_ssl_chunk( struct netconn *conn, const void *msg, size_t size, WSAOVERLAPPED *ovr )
 {
     SecBuffer bufs[4] = {
         {conn->ssl_sizes.cbHeader, SECBUFFER_STREAM_HEADER, conn->ssl_write_buf},
@@ -416,7 +424,8 @@ static DWORD send_ssl_chunk( struct netconn *conn, const void *msg, size_t size
         return res;
     }
 
-    if (sock_send( conn->socket, conn->ssl_write_buf, bufs[0].cbBuffer + bufs[1].cbBuffer + bufs[2].cbBuffer, 0 ) < 1)
+    if (sock_send( conn->socket, conn->ssl_write_buf,
+                   bufs[0].cbBuffer + bufs[1].cbBuffer + bufs[2].cbBuffer, ovr ) < 1)
     {
         WARN("send failed\n");
         return WSAGetLastError();
@@ -425,7 +434,7 @@ static DWORD send_ssl_chunk( struct netconn *conn, const void *msg, size_t size
     return ERROR_SUCCESS;
 }
 
-DWORD netconn_send( struct netconn *conn, const void *msg, size_t len, int *sent )
+DWORD netconn_send( struct netconn *conn, const void *msg, size_t len, int *sent, WSAOVERLAPPED *ovr )
 {
     if (conn->secure)
     {
@@ -433,11 +442,13 @@ DWORD netconn_send( struct netconn *conn, const void *msg, size_t len, int *sent
         size_t chunk_size;
         DWORD res;
 
+        if (ovr && len > conn->ssl_sizes.cbMaximumMessage) return WSAEWOULDBLOCK;
+
         *sent = 0;
         while (len)
         {
             chunk_size = min( len, conn->ssl_sizes.cbMaximumMessage );
-            if ((res = send_ssl_chunk( conn, ptr, chunk_size )))
+            if ((res = send_ssl_chunk( conn, ptr, chunk_size, ovr )))
                 return res;
 
             *sent += chunk_size;
@@ -448,7 +459,7 @@ DWORD netconn_send( struct netconn *conn, const void *msg, size_t len, int *sent
         return ERROR_SUCCESS;
     }
 
-    if ((*sent = sock_send( conn->socket, msg, len, 0 )) < 0) return WSAGetLastError();
+    if ((*sent = sock_send( conn->socket, msg, len, ovr )) < 0) return WSAGetLastError();
     return ERROR_SUCCESS;
 }
 
diff --git a/dlls/winhttp/request.c b/dlls/winhttp/request.c
index c12816b1ad4..31e65a68359 100644
--- a/dlls/winhttp/request.c
+++ b/dlls/winhttp/request.c
@@ -1286,7 +1286,7 @@ static DWORD secure_proxy_connect( struct request *request )
     if (!strA) return ERROR_OUTOFMEMORY;
 
     len = strlen( strA );
-    ret = netconn_send( request->netconn, strA, len, &bytes_sent );
+    ret = netconn_send( request->netconn, strA, len, &bytes_sent, NULL );
     free( strA );
     if (!ret) ret = read_reply( request );
 
@@ -2138,13 +2138,13 @@ static DWORD send_request( struct request *request, const WCHAR *headers, DWORD
 
     send_callback( &request->hdr, WINHTTP_CALLBACK_STATUS_SENDING_REQUEST, NULL, 0 );
 
-    ret = netconn_send( request->netconn, wire_req, len, &bytes_sent );
+    ret = netconn_send( request->netconn, wire_req, len, &bytes_sent, NULL );
     free( wire_req );
     if (ret) goto end;
 
     if (optional_len)
     {
-        if ((ret = netconn_send( request->netconn, optional, optional_len, &bytes_sent ))) goto end;
+        if ((ret = netconn_send( request->netconn, optional, optional_len, &bytes_sent, NULL ))) goto end;
         request->optional = optional;
         request->optional_len = optional_len;
         len += optional_len;
@@ -2972,7 +2972,7 @@ static DWORD write_data( struct request *request, const void *buffer, DWORD to_w
     DWORD ret;
     int num_bytes;
 
-    ret = netconn_send( request->netconn, buffer, to_write, &num_bytes );
+    ret = netconn_send( request->netconn, buffer, to_write, &num_bytes, NULL );
 
     if (async)
     {
@@ -3127,11 +3127,11 @@ HINTERNET WINAPI WinHttpWebSocketCompleteUpgrade( HINTERNET hrequest, DWORD_PTR
     return hsocket;
 }
 
-static DWORD send_bytes( struct socket *socket, char *bytes, int len )
+static DWORD send_bytes( struct socket *socket, char *bytes, int len, WSAOVERLAPPED *ovr )
 {
     int count;
     DWORD err;
-    if ((err = netconn_send( socket->request->netconn, bytes, len, &count ))) return err;
+    if ((err = netconn_send( socket->request->netconn, bytes, len, &count, ovr ))) return err; /* e.g. WSA_IO_PENDING */
     return (count == len) ? ERROR_SUCCESS : ERROR_INTERNAL_ERROR;
 }
 
@@ -3141,7 +3141,7 @@ static DWORD send_bytes( struct socket *socket, char *bytes, int len )
 #define CONTROL_BIT (1 << 3)
 
 static DWORD send_frame( struct socket *socket, enum socket_opcode opcode, USHORT status, const char *buf,
-                         DWORD buflen, BOOL final )
+                         DWORD buflen, BOOL final, WSAOVERLAPPED *ovr )
 {
     DWORD i = 0, j, offset = 2, len = buflen;
     DWORD buffer_size, ret = 0, send_size;
@@ -3177,6 +3177,7 @@ static DWORD send_frame( struct socket *socket, enum socket_opcode opcode, USHOR
     buffer_size = len + offset;
     if (len) buffer_size += 4;
     assert( buffer_size - len < MAX_FRAME_BUFFER_SIZE );
+    if (ovr && buffer_size > MAX_FRAME_BUFFER_SIZE) return WSAEWOULDBLOCK;
     if (buffer_size > socket->send_frame_buffer_size && socket->send_frame_buffer_size < MAX_FRAME_BUFFER_SIZE)
     {
         DWORD new_size;
@@ -3217,7 +3218,7 @@ static DWORD send_frame( struct socket *socket, enum socket_opcode opcode, USHOR
         while (j < buflen && offset < MAX_FRAME_BUFFER_SIZE)
             socket->send_frame_buffer[offset++] = buf[j++] ^ mask[i++ % 4];
 
-        if ((ret = send_bytes( socket, socket->send_frame_buffer, offset ))) return ret;
+        if ((ret = send_bytes( socket, socket->send_frame_buffer, offset, ovr ))) return ret;
 
         if (!(send_size -= offset)) break;
         offset = 0;
@@ -3227,6 +3228,16 @@ static DWORD send_frame( struct socket *socket, enum socket_opcode opcode, USHOR
     return ERROR_SUCCESS;
 }
 
+static DWORD complete_send_frame( struct socket *socket, WSAOVERLAPPED *ovr )
+{
+    DWORD retflags, len;  /* WSAGetOverlappedResult() outputs; not used afterwards */
+    /* Wait (fWait = TRUE) for the overlapped send on ovr; return 0 or the Winsock error. */
+    if (!WSAGetOverlappedResult( socket->request->netconn->socket, ovr, &len, TRUE, &retflags ))
+        return WSAGetLastError();
+
+    return ERROR_SUCCESS;
+}
+
 static void send_io_complete( struct object_header *hdr )
 {
     LONG count = InterlockedDecrement( &hdr->pending_sends );
@@ -3265,11 +3276,12 @@ static void socket_send_complete( struct socket *socket, DWORD ret, WINHTTP_WEB_
     }
 }
 
-static DWORD socket_send( struct socket *socket, WINHTTP_WEB_SOCKET_BUFFER_TYPE type, const void *buf, DWORD len )
+static DWORD socket_send( struct socket *socket, WINHTTP_WEB_SOCKET_BUFFER_TYPE type, const void *buf, DWORD len,
+                          WSAOVERLAPPED *ovr )  /* with ovr: may return WSAEWOULDBLOCK or WSA_IO_PENDING */
 {
     enum socket_opcode opcode = map_buffer_type( type );
 
-    return send_frame( socket, opcode, 0, buf, len, TRUE );
+    return send_frame( socket, opcode, 0, buf, len, TRUE, ovr );
 }
 
 static void CALLBACK task_socket_send( TP_CALLBACK_INSTANCE *instance, void *ctx, TP_WORK *work )
@@ -3278,7 +3290,10 @@ static void CALLBACK task_socket_send( TP_CALLBACK_INSTANCE *instance, void *ctx
     DWORD ret;
 
     TRACE("running %p\n", work);
-    ret = socket_send( s->socket, s->type, s->buf, s->len );
+
+    if (s->complete_async) ret = complete_send_frame( s->socket, &s->ovr );
+    else                   ret = socket_send( s->socket, s->type, s->buf, s->len, NULL );
+
     send_io_complete( &s->socket->hdr );
     socket_send_complete( s->socket, ret, s->type, s->len );
 
@@ -3289,7 +3304,7 @@ static void CALLBACK task_socket_send( TP_CALLBACK_INSTANCE *instance, void *ctx
 DWORD WINAPI WinHttpWebSocketSend( HINTERNET hsocket, WINHTTP_WEB_SOCKET_BUFFER_TYPE type, void *buf, DWORD len )
 {
     struct socket *socket;
-    DWORD ret;
+    DWORD ret = 0;
 
     TRACE("%p, %u, %p, %u\n", hsocket, type, buf, len);
 
@@ -3314,24 +3329,53 @@ DWORD WINAPI WinHttpWebSocketSend( HINTERNET hsocket, WINHTTP_WEB_SOCKET_BUFFER_
 
     if (socket->request->connect->hdr.flags & WINHTTP_FLAG_ASYNC)
     {
+        BOOL async_send, complete_async = FALSE;
         struct socket_send *s;
 
-        if (!(s = malloc( sizeof(*s) ))) return FALSE;
-        s->socket = socket;
-        s->type   = type;
-        s->buf    = buf;
-        s->len    = len;
+        if (!(s = malloc( sizeof(*s) )))
+        {
+            release_object( &socket->hdr );
+            return ERROR_OUTOFMEMORY;
+        }
 
-        addref_object( &socket->hdr );
-        InterlockedIncrement( &socket->hdr.pending_sends );
-        if ((ret = queue_task( &socket->send_q, task_socket_send, s )))
+        async_send = InterlockedIncrement( &socket->hdr.pending_sends ) > 1 || socket->hdr.recursion_count >= 3;
+        if (!async_send)
+        {
+            memset( &s->ovr, 0, sizeof(s->ovr) );
+            if ((ret = socket_send( socket, type, buf, len, &s->ovr )) == WSA_IO_PENDING)
+            {
+                async_send = TRUE;
+                complete_async = TRUE;
+            }
+            else if (ret == WSAEWOULDBLOCK) async_send = TRUE;
+        }
+
+        if (async_send)
+        {
+            s->complete_async = complete_async; /* TRUE: task only waits for the in-flight overlapped send */
+            s->socket = socket;
+            s->type   = type;
+            s->buf    = buf;
+            s->len    = len;
+            /* reference for the queued task; dropped here if queueing fails */
+            addref_object( &socket->hdr );
+            if ((ret = queue_task( &socket->send_q, task_socket_send, s )))
+            {
+                InterlockedDecrement( &socket->hdr.pending_sends );
+                release_object( &socket->hdr );
+                free( s );
+            }
+            /* on success pending_sends stays incremented until the task calls send_io_complete() */
+        }
+        else
         {
             InterlockedDecrement( &socket->hdr.pending_sends );
-            release_object( &socket->hdr );
             free( s );
+            socket_send_complete( socket, ret, type, len );
+            ret = ERROR_SUCCESS;
         }
     }
-    else ret = socket_send( socket, type, buf, len );
+    else ret = socket_send( socket, type, buf, len, NULL );
 
     release_object( &socket->hdr );
     return ret;
@@ -3418,7 +3462,7 @@ static void CALLBACK task_socket_send_pong( TP_CALLBACK_INSTANCE *instance, void
     struct socket_send *s = ctx;
 
     TRACE("running %p\n", work);
-    send_frame( s->socket, SOCKET_OPCODE_PONG, 0, NULL, 0, TRUE );
+    send_frame( s->socket, SOCKET_OPCODE_PONG, 0, NULL, 0, TRUE, NULL );
     send_io_complete( &s->socket->hdr );
 
     release_object( &s->socket->hdr );
@@ -3445,7 +3489,7 @@ static DWORD socket_send_pong( struct socket *socket )
         }
         return ret;
     }
-    return send_frame( socket, SOCKET_OPCODE_PONG, 0, NULL, 0, TRUE );
+    return send_frame( socket, SOCKET_OPCODE_PONG, 0, NULL, 0, TRUE, NULL );
 }
 
 static DWORD socket_drain( struct socket *socket )
@@ -3611,7 +3655,7 @@ static DWORD socket_shutdown( struct socket *socket, USHORT status, const void *
     DWORD ret;
 
     stop_queue( &socket->send_q );
-    if (!(ret = send_frame( socket, SOCKET_OPCODE_CLOSE, status, reason, len, TRUE )))
+    if (!(ret = send_frame( socket, SOCKET_OPCODE_CLOSE, status, reason, len, TRUE, NULL )))
     {
         socket->state = SOCKET_STATE_SHUTDOWN;
     }
@@ -3697,7 +3741,7 @@ static DWORD socket_close( struct socket *socket, USHORT status, const void *rea
     if (socket->state < SOCKET_STATE_SHUTDOWN)
     {
         stop_queue( &socket->send_q );
-        if ((ret = send_frame( socket, SOCKET_OPCODE_CLOSE, status, reason, len, TRUE ))) goto done;
+        if ((ret = send_frame( socket, SOCKET_OPCODE_CLOSE, status, reason, len, TRUE, NULL ))) goto done;
         socket->state = SOCKET_STATE_SHUTDOWN;
     }
 
diff --git a/dlls/winhttp/tests/notification.c b/dlls/winhttp/tests/notification.c
index 1573dd37faa..e919159ba98 100644
--- a/dlls/winhttp/tests/notification.c
+++ b/dlls/winhttp/tests/notification.c
@@ -62,6 +62,7 @@ struct notification
 #define NF_ALLOW       0x0001  /* notification may or may not happen */
 #define NF_WINE_ALLOW  0x0002  /* wine sends notification when it should not */
 #define NF_SIGNAL      0x0004  /* signal wait handle when notified */
+#define NF_MAIN_THREAD 0x0008  /* the operation completes synchronously and callback is called from the main thread */
 
 struct info
 {
@@ -71,6 +72,7 @@ struct info
     unsigned int index;
     HANDLE wait;
     unsigned int line;
+    DWORD main_thread_id;
     DWORD last_thread_id;
     DWORD last_status;
 };
@@ -111,6 +113,12 @@ static void CALLBACK check_notification( HINTERNET handle, DWORD_PTR context, DW
     ok(status_ok, "%u: expected status 0x%08x got 0x%08x\n", info->line, info->test[info->index].status, status);
     ok(function_ok, "%u: expected function %u got %u\n", info->line, info->test[info->index].function, info->function);
 
+    if (info->test[info->index].flags & NF_MAIN_THREAD)
+    {
+        ok(GetCurrentThreadId() == info->main_thread_id, "%u: expected callback to be called from the same thread\n",
+                info->line);
+    }
+
     if (status_ok && function_ok && info->test[info->index++].flags & NF_SIGNAL)
     {
         SetEvent( info->wait );
@@ -184,6 +192,7 @@ static void setup_test( struct info *info, enum api function, unsigned int line
                        info->test[info->index].function, function);
     info->last_thread_id = 0xdeadbeef;
     info->last_status = 0xdeadbeef;
+    info->main_thread_id = GetCurrentThreadId();
 }
 
 static void end_test( struct info *info, unsigned int line )
@@ -658,8 +667,8 @@ static const struct notification websocket_test[] =
     { winhttp_receive_response,           WINHTTP_CALLBACK_STATUS_RESPONSE_RECEIVED },
     { winhttp_receive_response,           WINHTTP_CALLBACK_STATUS_HEADERS_AVAILABLE, NF_SIGNAL },
     { winhttp_websocket_complete_upgrade, WINHTTP_CALLBACK_STATUS_HANDLE_CREATED, NF_SIGNAL },
-    { winhttp_websocket_send,             WINHTTP_CALLBACK_STATUS_WRITE_COMPLETE, NF_SIGNAL },
-    { winhttp_websocket_send,             WINHTTP_CALLBACK_STATUS_WRITE_COMPLETE, NF_SIGNAL },
+    { winhttp_websocket_send,             WINHTTP_CALLBACK_STATUS_WRITE_COMPLETE, NF_MAIN_THREAD | NF_SIGNAL },
+    { winhttp_websocket_send,             WINHTTP_CALLBACK_STATUS_WRITE_COMPLETE, NF_MAIN_THREAD | NF_SIGNAL },
     { winhttp_websocket_shutdown,         WINHTTP_CALLBACK_STATUS_SHUTDOWN_COMPLETE, NF_SIGNAL },
     { winhttp_websocket_receive,          WINHTTP_CALLBACK_STATUS_READ_COMPLETE, NF_SIGNAL },
     { winhttp_websocket_receive,          WINHTTP_CALLBACK_STATUS_READ_COMPLETE, NF_SIGNAL },
@@ -792,6 +801,9 @@ static void test_websocket(BOOL secure)
 
     for (i = 0; i < 2; ++i)
     {
+        /* The send is executed synchronously (even if sending a reasonably big buffer exceeding SSL buffer size).
+         * It is possible to trigger queueing the send into another thread but that involves sending a considerable
+         * amount of big enough buffers. */
         setup_test( &info, winhttp_websocket_send, __LINE__ );
         err = pWinHttpWebSocketSend( socket, WINHTTP_WEB_SOCKET_BINARY_MESSAGE_BUFFER_TYPE,
                                      (void*)"hello", sizeof("hello") );
diff --git a/dlls/winhttp/winhttp_private.h b/dlls/winhttp/winhttp_private.h
index db235fbf622..7c904071a14 100644
--- a/dlls/winhttp/winhttp_private.h
+++ b/dlls/winhttp/winhttp_private.h
@@ -297,6 +297,8 @@ struct socket_send
     WINHTTP_WEB_SOCKET_BUFFER_TYPE type;
     const void *buf;
     DWORD len;
+    WSAOVERLAPPED ovr;
+    BOOL complete_async;
 };
 
 struct socket_receive
@@ -331,7 +333,7 @@ ULONG netconn_query_data_available( struct netconn * ) DECLSPEC_HIDDEN;
 DWORD netconn_recv( struct netconn *, void *, size_t, int, int * ) DECLSPEC_HIDDEN;
 DWORD netconn_resolve( WCHAR *, INTERNET_PORT, struct sockaddr_storage *, int ) DECLSPEC_HIDDEN;
 DWORD netconn_secure_connect( struct netconn *, WCHAR *, DWORD, CredHandle *, BOOL ) DECLSPEC_HIDDEN;
-DWORD netconn_send( struct netconn *, const void *, size_t, int * ) DECLSPEC_HIDDEN;
+DWORD netconn_send( struct netconn *, const void *, size_t, int *, WSAOVERLAPPED * ) DECLSPEC_HIDDEN;
 DWORD netconn_set_timeout( struct netconn *, BOOL, int ) DECLSPEC_HIDDEN;
 BOOL netconn_is_alive( struct netconn * ) DECLSPEC_HIDDEN;
 const void *netconn_get_certificate( struct netconn * ) DECLSPEC_HIDDEN;
-- 
2.34.1




More information about the wine-devel mailing list