[PATCH v2 3/4] netio.sys: Implement wsk_accept() function.

Paul Gofman pgofman at codeweavers.com
Wed Jun 17 12:21:59 CDT 2020


Signed-off-by: Paul Gofman <pgofman at codeweavers.com>
---
 dlls/netio.sys/netio.c | 226 +++++++++++++++++++++++++++++++++++++++--
 1 file changed, 220 insertions(+), 6 deletions(-)

diff --git a/dlls/netio.sys/netio.c b/dlls/netio.sys/netio.c
index 36a6a41b994..04cedb7396a 100644
--- a/dlls/netio.sys/netio.c
+++ b/dlls/netio.sys/netio.c
@@ -44,14 +44,41 @@ struct _WSK_CLIENT
     WSK_CLIENT_NPI *client_npi;
 };
 
+struct listen_socket_callback_context /* State of the accept request pending on a listening socket. */
+{
+    SOCKADDR *remote_address; /* Caller buffer from wsk_accept(); not filled in yet. */
+    const void *client_dispatch; /* Client dispatch table for the accepted socket. */
+    void *client_context; /* Client context for the accepted socket. */
+    char addr_buffer[2 * (sizeof(SOCKADDR) + 16)]; /* AcceptEx() local + remote address output. */
+    SOCKET acceptor; /* Socket handed to AcceptEx(); 0 when none is held. */
+};
+
 struct wsk_socket_internal
 {
     WSK_SOCKET wsk_socket;
     SOCKET s;
     const void *client_dispatch;
     void *client_context;
+    ULONG flags; /* WSK_FLAG_* value given to wsk_socket(). */
+    ADDRESS_FAMILY address_family; /* Creation parameters, reused by wsk_accept() to create acceptors. */
+    USHORT socket_type;
+    ULONG protocol;
+    OVERLAPPED ovr; /* Overlapped state of the pending async operation. */
+    TP_WAIT *tp_wait; /* Waits on ovr.hEvent; NULL for sockets without an async callback. */
+    IRP *pending_irp; /* IRP to complete when the async operation finishes. */
+
+    CRITICAL_SECTION cs_socket; /* Taken by lock_socket() / unlock_socket(). */
+
+    union /* Per socket type async operation state. */
+    {
+        struct listen_socket_callback_context listen_socket_callback_context;
+    }
+    callback_context;
 };
 
+static LPFN_ACCEPTEX pAcceptEx; /* Resolved once by init_accept_functions(). */
+static const WSK_PROVIDER_CONNECTION_DISPATCH wsk_provider_connection_dispatch; /* Forward declaration, used by create_accept_socket(). */
+
 static inline struct wsk_socket_internal *wsk_socket_internal_from_wsk_socket(WSK_SOCKET *wsk_socket)
 {
     return CONTAINING_RECORD(wsk_socket, struct wsk_socket_internal, wsk_socket);
@@ -79,7 +106,7 @@ static NTSTATUS sock_error_to_ntstatus(DWORD err)
         case WSAEAFNOSUPPORT:
         case WSAEPROTOTYPE:        return STATUS_NOT_SUPPORTED;
         case WSAENOPROTOOPT:       return STATUS_INVALID_PARAMETER;
-        case WSAEOPNOTSUPP:        return STATUS_NOT_SUPPORTED;
+        case WSAEOPNOTSUPP:        return STATUS_NOT_IMPLEMENTED; /* NOTE(review): was STATUS_NOT_SUPPORTED; presumably matches native WSK - confirm. */
         case WSAEADDRINUSE:        return STATUS_ADDRESS_ALREADY_ASSOCIATED;
         case WSAEADDRNOTAVAIL:     return STATUS_INVALID_PARAMETER;
         case WSAECONNREFUSED:      return STATUS_CONNECTION_REFUSED;
@@ -97,6 +124,26 @@ static NTSTATUS sock_error_to_ntstatus(DWORD err)
     }
 }
 
+static inline void lock_socket(struct wsk_socket_internal *socket)
+{
+    EnterCriticalSection(&socket->cs_socket); /* Serializes socket state with the async wait callback. */
+}
+
+static inline void unlock_socket(struct wsk_socket_internal *socket)
+{
+    LeaveCriticalSection(&socket->cs_socket);
+}
+
+static void socket_init(struct wsk_socket_internal *socket, PTP_WAIT_CALLBACK socket_async_callback)
+{
+    InitializeCriticalSection(&socket->cs_socket);
+    if (socket_async_callback) /* Only sockets with async operations get an event and a wait. */
+    {
+        socket->ovr.hEvent = CreateEventA(NULL, FALSE, FALSE, NULL); /* Auto-reset, initially non-signaled. */
+        socket->tp_wait = CreateThreadpoolWait(socket_async_callback, socket, NULL); /* NOTE(review): neither this nor CreateEventA() is checked for failure. */
+    }
+}
+
 static void dispatch_irp(IRP *irp, NTSTATUS status)
 {
     irp->IoStatus.u.Status = status;
@@ -124,7 +171,33 @@ static NTSTATUS WINAPI wsk_close_socket(WSK_SOCKET *socket, IRP *irp)
 
     TRACE("socket %p, irp %p.\n", socket, irp);
 
+    lock_socket(s);
+
+    if (s->tp_wait)
+    {
+        CancelIoEx((HANDLE)s->s, &s->ovr); /* Abort the operation pending on s->ovr, if any. */
+        unlock_socket(s); /* The wait callback takes the same lock. */
+        WaitForThreadpoolWaitCallbacks(s->tp_wait, FALSE); /* NOTE(review): only waits for already queued callbacks; a later completion may race with the free below - confirm. */
+        lock_socket(s);
+        CloseThreadpoolWait(s->tp_wait);
+    }
+
+    if (s->flags & WSK_FLAG_LISTEN_SOCKET && s->callback_context.listen_socket_callback_context.acceptor)
+        closesocket(s->callback_context.listen_socket_callback_context.acceptor); /* Acceptor still recorded for this listening socket. */
+
     status = closesocket(s->s) ? sock_error_to_ntstatus(WSAGetLastError()) : STATUS_SUCCESS;
+
+    if (s->ovr.hEvent)
+        CloseHandle(s->ovr.hEvent);
+
+    if (s->pending_irp) /* Not completed by the wait callback: report it as cancelled. */
+    {
+        s->pending_irp->IoStatus.Information = 0;
+        dispatch_irp(s->pending_irp, STATUS_CANCELLED);
+    }
+
+    unlock_socket(s);
+    DeleteCriticalSection(&s->cs_socket);
     heap_free(socket);
 
     irp->IoStatus.Information = 0;
@@ -146,6 +219,8 @@ static NTSTATUS WINAPI wsk_bind(WSK_SOCKET *socket, SOCKADDR *local_address, ULO
 
     if (bind(s->s, local_address, sizeof(*local_address)))
         status = sock_error_to_ntstatus(WSAGetLastError());
+    else if (s->flags & WSK_FLAG_LISTEN_SOCKET && listen(s->s, SOMAXCONN)) /* Listening sockets start listening once bound. */
+        status = sock_error_to_ntstatus(WSAGetLastError());
     else
         status = STATUS_SUCCESS;
 
@@ -155,16 +230,185 @@ static NTSTATUS WINAPI wsk_bind(WSK_SOCKET *socket, SOCKADDR *local_address, ULO
     return STATUS_PENDING;
 }
 
+/* Completes the pending accept IRP of a listening socket after AcceptEx() succeeded: wraps
+ * context->acceptor into a new connection WSK socket returned in IoStatus.Information.
+ * Both callers hold the listening socket's lock. */
+static void create_accept_socket(struct wsk_socket_internal *socket)
+{
+    struct listen_socket_callback_context *context
+            = &socket->callback_context.listen_socket_callback_context;
+    struct wsk_socket_internal *accept_socket;
+    IRP *irp = socket->pending_irp;
+    NTSTATUS status;
+
+    if (!(accept_socket = heap_alloc_zero(sizeof(*accept_socket))))
+    {
+        ERR("No memory.\n");
+        status = STATUS_NO_MEMORY;
+        /* Nothing can take ownership of the accepted connection; drop it instead of leaking it. */
+        closesocket(context->acceptor);
+        irp->IoStatus.Information = 0;
+    }
+    else
+    {
+        TRACE("accept_socket %p.\n", accept_socket);
+        accept_socket->wsk_socket.Dispatch = &wsk_provider_connection_dispatch;
+        accept_socket->s = context->acceptor;
+        accept_socket->client_dispatch = context->client_dispatch;
+        accept_socket->client_context = context->client_context;
+        accept_socket->socket_type = socket->socket_type;
+        accept_socket->address_family = socket->address_family;
+        accept_socket->protocol = socket->protocol;
+        accept_socket->flags = WSK_FLAG_CONNECTION_SOCKET;
+        socket_init(accept_socket, NULL);
+        /* TODO: fill local and remote addresses. */
+
+        irp->IoStatus.Information = (ULONG_PTR)&accept_socket->wsk_socket;
+        status = STATUS_SUCCESS;
+    }
+    TRACE("status %#x.\n", status);
+
+    /* The acceptor is now either closed or owned by accept_socket: forget it so that closing
+     * the listening socket does not close it again. Clear the pending state before completing
+     * the IRP so that its completion routine may start the next accept. */
+    context->acceptor = 0;
+    socket->pending_irp = NULL;
+    dispatch_irp(irp, status);
+}
+
+/* Thread pool wait callback of listening sockets, run once the AcceptEx() event is signaled:
+ * completes the pending accept IRP with the accepted socket or with the failure status. */
+static void WINAPI accept_callback(TP_CALLBACK_INSTANCE *instance, void *socket_, TP_WAIT *wait,
+        TP_WAIT_RESULT wait_result)
+{
+    struct listen_socket_callback_context *context;
+    struct wsk_socket_internal *socket = socket_;
+    DWORD size;
+
+    TRACE("instance %p, socket %p, wait %p, wait_result %#x.\n", instance, socket, wait, wait_result);
+
+    lock_socket(socket);
+    context = &socket->callback_context.listen_socket_callback_context;
+
+    if (GetOverlappedResult((HANDLE)socket->s, &socket->ovr, &size, FALSE))
+    {
+        create_accept_socket(socket);
+    }
+    else
+    {
+        IRP *irp = socket->pending_irp;
+
+        closesocket(context->acceptor);
+        context->acceptor = 0;
+        /* Clear the pending state first so that the completion routine may accept again. */
+        socket->pending_irp = NULL;
+        irp->IoStatus.Information = 0;
+        /* ovr.Internal holds the NTSTATUS of the failed AcceptEx(). */
+        dispatch_irp(irp, socket->ovr.Internal);
+    }
+    unlock_socket(socket);
+}
+
+/* One-time initializer: resolves the AcceptEx() extension function pointer through
+ * SIO_GET_EXTENSION_FUNCTION_POINTER on the socket passed in param. */
+static BOOL WINAPI init_accept_functions(INIT_ONCE *once, void *param, void **context)
+{
+    GUID acceptex_guid = WSAID_ACCEPTEX;
+    SOCKET s = (SOCKET)param;
+    DWORD size;
+
+    if (WSAIoctl(s, SIO_GET_EXTENSION_FUNCTION_POINTER, &acceptex_guid, sizeof(acceptex_guid),
+            &pAcceptEx, sizeof(pAcceptEx), &size, NULL, NULL))
+    {
+        ERR("Could not get AcceptEx address, error %u.\n", WSAGetLastError());
+        return FALSE;
+    }
+    return TRUE;
+}
+
+/* Accepts one incoming connection on a listening socket. The IRP is completed either before
+ * returning or later from accept_callback(); on success its IoStatus.Information receives the
+ * new connection WSK_SOCKET. Only one accept request may be pending per listening socket. */
 static NTSTATUS WINAPI wsk_accept(WSK_SOCKET *listen_socket, ULONG flags, void *accept_socket_context,
         const WSK_CLIENT_CONNECTION_DISPATCH *accept_socket_dispatch, SOCKADDR *local_address,
         SOCKADDR *remote_address, IRP *irp)
 {
-    FIXME("listen_socket %p, flags %#x, accept_socket_context %p, accept_socket_dispatch %p, "
-            "local_address %p, remote_address %p, irp %p stub.\n",
+    struct wsk_socket_internal *s = wsk_socket_internal_from_wsk_socket(listen_socket);
+    static INIT_ONCE init_once = INIT_ONCE_STATIC_INIT;
+    struct listen_socket_callback_context *context;
+    SOCKET acceptor;
+    NTSTATUS status;
+    DWORD size;
+    int error;
+
+    TRACE("listen_socket %p, flags %#x, accept_socket_context %p, accept_socket_dispatch %p, "
+            "local_address %p, remote_address %p, irp %p.\n",
             listen_socket, flags, accept_socket_context, accept_socket_dispatch, local_address,
             remote_address, irp);
 
-    return STATUS_NOT_IMPLEMENTED;
+    if (!irp)
+        return STATUS_INVALID_PARAMETER;
+
+    if (!InitOnceExecuteOnce(&init_once, init_accept_functions, (void *)s->s, NULL))
+    {
+        status = STATUS_UNSUCCESSFUL;
+        irp->IoStatus.Information = 0;
+        dispatch_irp(irp, status);
+        return status;
+    }
+
+    lock_socket(s);
+    if (s->pending_irp)
+    {
+        /* Accepting again would overwrite pending_irp and the acceptor: the first IRP
+         * would never be completed and its acceptor socket would leak. */
+        FIXME("Multiple pending accept requests are not supported.\n");
+        status = STATUS_INVALID_DEVICE_STATE;
+        irp->IoStatus.Information = 0;
+        dispatch_irp(irp, status);
+        unlock_socket(s);
+        return status;
+    }
+
+    context = &s->callback_context.listen_socket_callback_context;
+    if ((acceptor = WSASocketW(s->address_family, s->socket_type, s->protocol, NULL, 0, WSA_FLAG_OVERLAPPED))
+            == INVALID_SOCKET)
+    {
+        status = sock_error_to_ntstatus(WSAGetLastError());
+        irp->IoStatus.Information = 0;
+        dispatch_irp(irp, status);
+        unlock_socket(s);
+        return status;
+    }
+
+    s->pending_irp = irp;
+    context->remote_address = remote_address;
+    context->client_dispatch = accept_socket_dispatch;
+    context->client_context = accept_socket_context;
+    context->acceptor = acceptor;
+
+    /* Receive no data; addr_buffer is split evenly between the local and remote address. */
+    if (pAcceptEx(s->s, acceptor, context->addr_buffer, 0,
+            sizeof(context->addr_buffer) / 2, sizeof(context->addr_buffer) / 2, &size, &s->ovr))
+    {
+        create_accept_socket(s);
+    }
+    else if ((error = WSAGetLastError()) == WSA_IO_PENDING)
+    {
+        /* Completed later by accept_callback(). */
+        SetThreadpoolWait(s->tp_wait, s->ovr.hEvent, NULL);
+    }
+    else
+    {
+        closesocket(acceptor);
+        context->acceptor = 0;
+        s->pending_irp = NULL;
+        irp->IoStatus.Information = 0;
+        dispatch_irp(irp, sock_error_to_ntstatus(error));
+    }
+    unlock_socket(s);
+
+    return STATUS_PENDING;
 }
 
 static NTSTATUS WINAPI wsk_inspect_complete(WSK_SOCKET *listen_socket, WSK_INSPECT_ID *inspect_id,
@@ -284,6 +490,7 @@ static NTSTATUS WINAPI wsk_socket(WSK_CLIENT *client, ADDRESS_FAMILY address_fam
         PETHREAD owning_thread, SECURITY_DESCRIPTOR *security_descriptor, IRP *irp)
 {
     struct wsk_socket_internal *socket;
+    PTP_WAIT_CALLBACK async_callback = NULL; /* Only set for socket types with async operations. */
     NTSTATUS status;
     SOCKET s;
 
@@ -300,13 +507,13 @@ static NTSTATUS WINAPI wsk_socket(WSK_CLIENT *client, ADDRESS_FAMILY address_fam
 
     irp->IoStatus.Information = 0;
 
-    if ((s = WSASocketW(address_family, socket_type, protocol, NULL, 0, 0)) == INVALID_SOCKET)
+    if ((s = WSASocketW(address_family, socket_type, protocol, NULL, 0, WSA_FLAG_OVERLAPPED)) == INVALID_SOCKET)
     {
         status = sock_error_to_ntstatus(WSAGetLastError());
         goto done;
     }
 
-    if (!(socket = heap_alloc(sizeof(*socket))))
+    if (!(socket = heap_alloc_zero(sizeof(*socket)))) /* ovr, tp_wait, pending_irp and callback_context start zeroed. */
     {
         status = STATUS_NO_MEMORY;
         closesocket(s);
@@ -316,11 +523,16 @@ static NTSTATUS WINAPI wsk_socket(WSK_CLIENT *client, ADDRESS_FAMILY address_fam
     socket->s = s;
     socket->client_dispatch = dispatch;
     socket->client_context = socket_context;
+    socket->socket_type = socket_type;
+    socket->flags = flags;
+    socket->address_family = address_family;
+    socket->protocol = protocol;
 
     switch (flags)
     {
         case WSK_FLAG_LISTEN_SOCKET:
             socket->wsk_socket.Dispatch = &wsk_provider_listen_dispatch;
+            async_callback = accept_callback;
             break;
 
         case WSK_FLAG_CONNECTION_SOCKET:
@@ -335,6 +547,8 @@ static NTSTATUS WINAPI wsk_socket(WSK_CLIENT *client, ADDRESS_FAMILY address_fam
             goto done;
     }
 
+    socket_init(socket, async_callback); /* NULL callback: no event or thread pool wait is created. */
+
     irp->IoStatus.Information = (ULONG_PTR)&socket->wsk_socket;
     status = STATUS_SUCCESS;
 
-- 
2.26.2




More information about the wine-devel mailing list