[PATCH 1/3] netio.sys: Support multiple simultaneous async operations per socket.

Paul Gofman pgofman at codeweavers.com
Fri Jun 19 04:58:02 CDT 2020


Signed-off-by: Paul Gofman <pgofman at codeweavers.com>
---
 dlls/netio.sys/netio.c | 207 ++++++++++++++++++++++++++++++-----------
 1 file changed, 151 insertions(+), 56 deletions(-)

diff --git a/dlls/netio.sys/netio.c b/dlls/netio.sys/netio.c
index ed99a087a58..6cadff2e87c 100644
--- a/dlls/netio.sys/netio.c
+++ b/dlls/netio.sys/netio.c
@@ -53,6 +53,16 @@ struct listen_socket_callback_context
     SOCKET acceptor;
 };
 
+#define MAX_PENDING_IO 10
+
+struct wsk_pending_io
+{
+    OVERLAPPED ovr;
+    TP_WAIT *tp_wait;
+    PTP_WAIT_CALLBACK callback;
+    IRP *irp;
+};
+
 struct wsk_socket_internal
 {
     WSK_SOCKET wsk_socket;
@@ -63,12 +73,11 @@ struct wsk_socket_internal
     ADDRESS_FAMILY address_family;
     USHORT socket_type;
     ULONG protocol;
-    OVERLAPPED ovr;
-    TP_WAIT *tp_wait;
-    IRP *pending_irp;
 
     CRITICAL_SECTION cs_socket;
 
+    struct wsk_pending_io pending_io[MAX_PENDING_IO];
+
     union
     {
         struct listen_socket_callback_context listen_socket_callback_context;
@@ -134,14 +143,9 @@ static inline void unlock_socket(struct wsk_socket_internal *socket)
     LeaveCriticalSection(&socket->cs_socket);
 }
 
-static void socket_init(struct wsk_socket_internal *socket, PTP_WAIT_CALLBACK socket_async_callback)
+static void socket_init(struct wsk_socket_internal *socket)
 {
     InitializeCriticalSection(&socket->cs_socket);
-    if (socket_async_callback)
-    {
-        socket->ovr.hEvent = CreateEventA(NULL, FALSE, FALSE, NULL);
-        socket->tp_wait = CreateThreadpoolWait(socket_async_callback, socket, NULL);
-    }
 }
 
 static void dispatch_irp(IRP *irp, NTSTATUS status)
@@ -152,6 +156,90 @@ static void dispatch_irp(IRP *irp, NTSTATUS status)
     IoCompleteRequest(irp, IO_NO_INCREMENT);
 }
 
+/* Claims a free pending io slot for irp, creating its event and thread pool
+ * wait when needed. The caller holds the socket lock. Returns NULL if every
+ * slot is busy or the wait objects cannot be created. */
+static struct wsk_pending_io *allocate_pending_io(struct wsk_socket_internal *socket,
+        PTP_WAIT_CALLBACK socket_async_callback, IRP *irp)
+{
+    struct wsk_pending_io *io;
+    unsigned int i, free_index = ~0u;
+
+    for (i = 0; i < ARRAY_SIZE(socket->pending_io); ++i)
+    {
+        io = &socket->pending_io[i];
+        if (io->irp)
+            continue;
+
+        /* Prefer a free slot whose wait object already uses this callback. */
+        if (io->tp_wait && io->callback == socket_async_callback)
+        {
+            io->irp = irp;
+            return io;
+        }
+
+        if (free_index == ~0u)
+            free_index = i;
+    }
+
+    if (free_index == ~0u)
+    {
+        FIXME("Pending io requests count exceeds limit.\n");
+        return NULL;
+    }
+
+    io = &socket->pending_io[free_index];
+
+    /* The callback is bound when the wait object is created, so a slot last
+     * used with another callback needs a new wait object. */
+    if (io->tp_wait)
+    {
+        CloseThreadpoolWait(io->tp_wait);
+        io->tp_wait = NULL;
+    }
+
+    if (!io->ovr.hEvent && !(io->ovr.hEvent = CreateEventA(NULL, FALSE, FALSE, NULL)))
+    {
+        ERR("Failed to create event, error %u.\n", GetLastError());
+        return NULL;
+    }
+
+    if (!(io->tp_wait = CreateThreadpoolWait(socket_async_callback, socket, NULL)))
+    {
+        ERR("Failed to create thread pool wait, error %u.\n", GetLastError());
+        return NULL;
+    }
+
+    io->callback = socket_async_callback;
+    io->irp = irp;
+    return io;
+}
+
+/* Returns the pending io slot that owns tp_wait, or NULL if there is none. */
+static struct wsk_pending_io *find_pending_io(struct wsk_socket_internal *socket, TP_WAIT *tp_wait)
+{
+    unsigned int i;
+
+    for (i = 0; i < ARRAY_SIZE(socket->pending_io); ++i)
+    {
+        if (socket->pending_io[i].tp_wait == tp_wait)
+            return &socket->pending_io[i];
+    }
+
+    FIXME("Pending io not found for tp_wait %p.\n", tp_wait);
+    return NULL;
+}
+
+/* Completes the IRP tracked by io and frees the slot for reuse. */
+static void dispatch_pending_io(struct wsk_pending_io *io, NTSTATUS status, ULONG_PTR information)
+{
+    TRACE("io %p, status %#x, information %#lx.\n", io, status, information);
+
+    io->irp->IoStatus.Information = information;
+    dispatch_irp(io->irp, status);
+    io->irp = NULL;
+}
+
 static NTSTATUS WINAPI wsk_control_socket(WSK_SOCKET *socket, WSK_CONTROL_SOCKET_TYPE request_type,
         ULONG control_code, ULONG level, SIZE_T input_size, void *input_buffer, SIZE_T output_size,
         void *output_buffer, SIZE_T *output_size_returned, IRP *irp)
@@ -168,18 +256,32 @@ static NTSTATUS WINAPI wsk_close_socket(WSK_SOCKET *socket, IRP *irp)
 {
     struct wsk_socket_internal *s = wsk_socket_internal_from_wsk_socket(socket);
     NTSTATUS status;
+    unsigned int i;
 
     TRACE("socket %p, irp %p.\n", socket, irp);
 
     lock_socket(s);
 
-    if (s->tp_wait)
+    for (i = 0; i < ARRAY_SIZE(s->pending_io); ++i)
     {
-        CancelIoEx((HANDLE)s->s, &s->ovr);
-        unlock_socket(s);
-        WaitForThreadpoolWaitCallbacks(s->tp_wait, FALSE);
-        lock_socket(s);
-        CloseThreadpoolWait(s->tp_wait);
+        struct wsk_pending_io *io = &s->pending_io[i];
+
+        if (io->tp_wait)
+        {
+            CancelIoEx((HANDLE)s->s, &io->ovr);
+            /* Disarm the wait, then let an already queued callback complete. */
+            SetThreadpoolWait(io->tp_wait, NULL, NULL);
+            unlock_socket(s);
+            WaitForThreadpoolWaitCallbacks(io->tp_wait, FALSE);
+            lock_socket(s);
+            CloseThreadpoolWait(io->tp_wait);
+        }
+        /* The event is created separately and may exist without a wait object. */
+        if (io->ovr.hEvent)
+            CloseHandle(io->ovr.hEvent);
+
+        if (io->irp)
+            dispatch_pending_io(io, STATUS_CANCELLED, 0);
     }
 
     if (s->flags & WSK_FLAG_LISTEN_SOCKET && s->callback_context.listen_socket_callback_context.acceptor)
@@ -187,15 +289,6 @@ static NTSTATUS WINAPI wsk_close_socket(WSK_SOCKET *socket, IRP *irp)
 
     status = closesocket(s->s) ? sock_error_to_ntstatus(WSAGetLastError()) : STATUS_SUCCESS;
 
-    if (s->ovr.hEvent)
-        CloseHandle(s->ovr.hEvent);
-
-    if (s->pending_irp)
-    {
-        s->pending_irp->IoStatus.Information = 0;
-        dispatch_irp(s->pending_irp, STATUS_CANCELLED);
-    }
-
     unlock_socket(s);
     DeleteCriticalSection(&s->cs_socket);
     heap_free(socket);
@@ -230,18 +323,16 @@ static NTSTATUS WINAPI wsk_bind(WSK_SOCKET *socket, SOCKADDR *local_address, ULO
     return STATUS_PENDING;
 }
 
-static void create_accept_socket(struct wsk_socket_internal *socket)
+static void create_accept_socket(struct wsk_socket_internal *socket, struct wsk_pending_io *io)
 {
     struct listen_socket_callback_context *context
             = &socket->callback_context.listen_socket_callback_context;
     struct wsk_socket_internal *accept_socket;
-    NTSTATUS status;
 
     if (!(accept_socket = heap_alloc_zero(sizeof(*accept_socket))))
     {
         ERR("No memory.\n");
-        status = STATUS_NO_MEMORY;
-        socket->pending_irp->IoStatus.Information = 0;
+        dispatch_pending_io(io, STATUS_NO_MEMORY, 0);
     }
     else
     {
@@ -254,15 +345,11 @@ static void create_accept_socket(struct wsk_socket_internal *socket)
         accept_socket->address_family = socket->address_family;
         accept_socket->protocol = socket->protocol;
         accept_socket->flags = WSK_FLAG_CONNECTION_SOCKET;
-        socket_init(accept_socket, NULL);
+        socket_init(accept_socket);
         /* TODO: fill local and remote addresses. */
 
-        socket->pending_irp->IoStatus.Information = (ULONG_PTR)&accept_socket->wsk_socket;
-        status = STATUS_SUCCESS;
+        dispatch_pending_io(io, STATUS_SUCCESS, (ULONG_PTR)&accept_socket->wsk_socket);
     }
-    TRACE("status %#x.\n", status);
-    dispatch_irp(socket->pending_irp, status);
-    socket->pending_irp = NULL;
 }
 
 static void WINAPI accept_callback(TP_CALLBACK_INSTANCE *instance, void *socket_, TP_WAIT *wait,
@@ -270,24 +357,28 @@ static void WINAPI accept_callback(TP_CALLBACK_INSTANCE *instance, void *socket_
 {
     struct listen_socket_callback_context *context;
     struct wsk_socket_internal *socket = socket_;
+    struct wsk_pending_io *io;
     DWORD size;
 
     TRACE("instance %p, socket %p, wait %p, wait_result %#x.\n", instance, socket, wait, wait_result);
 
     lock_socket(socket);
     context = &socket->callback_context.listen_socket_callback_context;
+    if (!(io = find_pending_io(socket, wait)))
+    {
+        unlock_socket(socket);
+        return;
+    }
 
-    if (GetOverlappedResult((HANDLE)socket->s, &socket->ovr, &size, FALSE))
+    if (GetOverlappedResult((HANDLE)socket->s, &io->ovr, &size, FALSE))
     {
-        create_accept_socket(socket);
+        create_accept_socket(socket, io);
     }
     else
     {
         closesocket(context->acceptor);
         context->acceptor = 0;
-        socket->pending_irp->IoStatus.Information = 0;
-        dispatch_irp(socket->pending_irp, socket->ovr.Internal);
-        socket->pending_irp = NULL;
+        dispatch_pending_io(io, io->ovr.Internal, 0);
     }
     unlock_socket(socket);
 }
@@ -314,8 +405,8 @@ static NTSTATUS WINAPI wsk_accept(WSK_SOCKET *listen_socket, ULONG flags, void *
     struct wsk_socket_internal *s = wsk_socket_internal_from_wsk_socket(listen_socket);
     static INIT_ONCE init_once = INIT_ONCE_STATIC_INIT;
     struct listen_socket_callback_context *context;
+    struct wsk_pending_io *io;
     SOCKET acceptor;
-    NTSTATUS status;
     DWORD size;
     int error;
 
@@ -329,44 +420,50 @@ static NTSTATUS WINAPI wsk_accept(WSK_SOCKET *listen_socket, ULONG flags, void *
 
     if (!InitOnceExecuteOnce(&init_once, init_accept_functions, (void *)s->s, NULL))
     {
-        status = STATUS_UNSUCCESSFUL;
-        dispatch_irp(irp, status);
-        return status;
+        irp->IoStatus.Information = 0;
+        dispatch_irp(irp, STATUS_UNSUCCESSFUL);
+        return STATUS_PENDING;
     }
 
     lock_socket(s);
+    if (!(io = allocate_pending_io(s, accept_callback, irp)))
+    {
+        irp->IoStatus.Information = 0;
+        dispatch_irp(irp, STATUS_UNSUCCESSFUL);
+        unlock_socket(s);
+        return STATUS_PENDING;
+    }
+
     context = &s->callback_context.listen_socket_callback_context;
     if ((acceptor = WSASocketW(s->address_family, s->socket_type, s->protocol, NULL, 0, WSA_FLAG_OVERLAPPED))
             == INVALID_SOCKET)
     {
-        status = sock_error_to_ntstatus(WSAGetLastError());
-        dispatch_irp(irp, status);
+        dispatch_pending_io(io, sock_error_to_ntstatus(WSAGetLastError()), 0);
         unlock_socket(s);
-        return status;
+        return STATUS_PENDING;
     }
 
-    s->pending_irp = irp;
+    /* FIXME: this listen context is shared by every pending accept on the
+     * socket, so concurrent accepts overwrite each other's acceptor. */
     context->remote_address = remote_address;
     context->client_dispatch = accept_socket_dispatch;
     context->client_context = accept_socket_context;
     context->acceptor = acceptor;
 
     if (pAcceptEx(s->s, acceptor, context->addr_buffer, 0,
-            sizeof(SOCKADDR) + 16, sizeof(SOCKADDR) + 16, &size, &s->ovr))
+            sizeof(SOCKADDR) + 16, sizeof(SOCKADDR) + 16, &size, &io->ovr))
     {
-        create_accept_socket(s);
+        create_accept_socket(s, io);
     }
     else if ((error = WSAGetLastError()) == ERROR_IO_PENDING)
     {
-        SetThreadpoolWait(s->tp_wait, s->ovr.hEvent, NULL);
+        SetThreadpoolWait(io->tp_wait, io->ovr.hEvent, NULL);
     }
     else
     {
         closesocket(acceptor);
         context->acceptor = 0;
-        irp->IoStatus.Information = 0;
-        dispatch_irp(irp, sock_error_to_ntstatus(error));
-        s->pending_irp = NULL;
+        dispatch_pending_io(io, sock_error_to_ntstatus(error), 0);
     }
     unlock_socket(s);
 
@@ -490,7 +587,6 @@ static NTSTATUS WINAPI wsk_socket(WSK_CLIENT *client, ADDRESS_FAMILY address_fam
         PETHREAD owning_thread, SECURITY_DESCRIPTOR *security_descriptor, IRP *irp)
 {
     struct wsk_socket_internal *socket;
-    PTP_WAIT_CALLBACK async_callback;
     NTSTATUS status;
     SOCKET s;
 
@@ -532,7 +628,6 @@ static NTSTATUS WINAPI wsk_socket(WSK_CLIENT *client, ADDRESS_FAMILY address_fam
     {
         case WSK_FLAG_LISTEN_SOCKET:
             socket->wsk_socket.Dispatch = &wsk_provider_listen_dispatch;
-            async_callback = accept_callback;
             break;
 
         case WSK_FLAG_CONNECTION_SOCKET:
@@ -547,7 +642,7 @@ static NTSTATUS WINAPI wsk_socket(WSK_CLIENT *client, ADDRESS_FAMILY address_fam
             goto done;
     }
 
-    socket_init(socket, async_callback);
+    socket_init(socket);
 
     irp->IoStatus.Information = (ULONG_PTR)&socket->wsk_socket;
     status = STATUS_SUCCESS;
-- 
2.26.2




More information about the wine-devel mailing list