[PATCH 1/3] ws2_32: Use server-side async I/O in AcceptEx().

Zebediah Figura z.figura12 at gmail.com
Fri Oct 16 23:13:35 CDT 2020

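Instead of chaining two client-side asyncs (one for the accept and one for the
initial receive), pass the accept socket and the output buffer layout to the
server with IOCTL_AFD_ACCEPT_INTO. The server queues the request on the
listening socket, accepts the connection into the target socket, performs the
optional initial recv(), and fills in the local and remote addresses itself.
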
Signed-off-by: Zebediah Figura <z.figura12 at gmail.com>
---
 dlls/ws2_32/socket.c     | 183 +++-------------------
 dlls/ws2_32/tests/sock.c |  27 ++--
 include/wine/afd.h       |   7 +
 server/sock.c            | 329 +++++++++++++++++++++++++++++++++++++--
 4 files changed, 352 insertions(+), 194 deletions(-)

diff --git a/dlls/ws2_32/socket.c b/dlls/ws2_32/socket.c
index 2eb1e1a7307..e93c4ccf589 100644
--- a/dlls/ws2_32/socket.c
+++ b/dlls/ws2_32/socket.c
@@ -2481,99 +2481,6 @@ static NTSTATUS WS2_async_recv( void *user, IO_STATUS_BLOCK *iosb, NTSTATUS stat
     return status;
 }
 
-/***********************************************************************
- *              WS2_async_accept_recv            (INTERNAL)
- *
- * This function is used to finish the read part of an accept request. It is
- * needed to place the completion on the correct socket (listener).
- */
-static NTSTATUS WS2_async_accept_recv( void *user, IO_STATUS_BLOCK *iosb, NTSTATUS status )
-{
-    struct ws2_accept_async *wsa = user;
-
-    status = WS2_async_recv( wsa->read, iosb, status );
-    if (status == STATUS_PENDING)
-        return status;
-
-    if (wsa->cvalue)
-        WS_AddCompletion( HANDLE2SOCKET(wsa->listen_socket), wsa->cvalue, iosb->u.Status, iosb->Information, TRUE );
-
-    release_async_io( &wsa->io );
-    return status;
-}
-
-/***********************************************************************
- *              WS2_async_accept                (INTERNAL)
- *
- * This is the function called to satisfy the AcceptEx callback
- */
-static NTSTATUS WS2_async_accept( void *user, IO_STATUS_BLOCK *iosb, NTSTATUS status )
-{
-    struct ws2_accept_async *wsa = user;
-    int len;
-    char *addr;
-
-    TRACE("status: 0x%x listen: %p, accept: %p\n", status, wsa->listen_socket, wsa->accept_socket);
-
-    if (status == STATUS_ALERTED)
-    {
-        obj_handle_t accept_handle = wine_server_obj_handle( wsa->accept_socket );
-        IO_STATUS_BLOCK io;
-
-        status = NtDeviceIoControlFile( wsa->listen_socket, NULL, NULL, NULL, &io, IOCTL_AFD_ACCEPT_INTO,
-                                        &accept_handle, sizeof(accept_handle), NULL, 0 );
-
-        if (NtStatusToWSAError( status ) == WSAEWOULDBLOCK)
-            return STATUS_PENDING;
-
-        if (status == STATUS_INVALID_HANDLE)
-        {
-            FIXME("AcceptEx accepting socket closed but request was not cancelled\n");
-            status = STATUS_CANCELLED;
-        }
-    }
-    else if (status == STATUS_HANDLES_CLOSED)
-        status = STATUS_CANCELLED;  /* strange windows behavior */
-
-    if (status != STATUS_SUCCESS)
-        goto finish;
-
-    /* WS2 Spec says size param is extra 16 bytes long...what do we put in it? */
-    addr = ((char *)wsa->buf) + wsa->data_len;
-    len = wsa->local_len - sizeof(int);
-    WS_getsockname(HANDLE2SOCKET(wsa->accept_socket),
-                   (struct WS_sockaddr *)(addr + sizeof(int)), &len);
-    *(int *)addr = len;
-
-    addr += wsa->local_len;
-    len = wsa->remote_len - sizeof(int);
-    WS_getpeername(HANDLE2SOCKET(wsa->accept_socket),
-                   (struct WS_sockaddr *)(addr + sizeof(int)), &len);
-    *(int *)addr = len;
-
-    if (!wsa->read)
-        goto finish;
-
-    wsa->io.callback = WS2_async_accept_recv;
-    status = register_async( ASYNC_TYPE_READ, wsa->accept_socket, &wsa->io,
-                             wsa->user_overlapped->hEvent, NULL, NULL, iosb);
-
-    if (status != STATUS_PENDING)
-        goto finish;
-
-    /* The APC has finished but no completion should be sent for the operation yet, additional processing
-     * needs to be performed by WS2_async_accept_recv() first. */
-    return STATUS_MORE_PROCESSING_REQUIRED;
-
-finish:
-    iosb->u.Status = status;
-    iosb->Information = 0;
-
-    if (wsa->read) release_async_io( &wsa->read->io );
-    release_async_io( &wsa->io );
-    return status;
-}
-
 /***********************************************************************
  *              WS2_send                (INTERNAL)
  *
@@ -2820,23 +2727,30 @@ error:
 /***********************************************************************
  *     AcceptEx
  */
-static BOOL WINAPI WS2_AcceptEx(SOCKET listener, SOCKET acceptor, PVOID dest, DWORD dest_len,
-                         DWORD local_addr_len, DWORD rem_addr_len, LPDWORD received,
-                         LPOVERLAPPED overlapped)
+static BOOL WINAPI WS2_AcceptEx( SOCKET listener, SOCKET acceptor, void *dest, DWORD recv_len,
+                                 DWORD local_len, DWORD remote_len, DWORD *ret_len, OVERLAPPED *overlapped)
 {
-    DWORD status;
-    struct ws2_accept_async *wsa;
-    int fd;
+    struct afd_accept_into_params params =
+    {
+        .accept_handle = acceptor,
+        .recv_len = recv_len,
+        .local_len = local_len,  /* the remote address gets the rest of the output buffer */
+    };
+    void *cvalue = NULL;
+    NTSTATUS status;
 
-    TRACE("(%04lx, %04lx, %p, %d, %d, %d, %p, %p)\n", listener, acceptor, dest, dest_len, local_addr_len,
-                                                  rem_addr_len, received, overlapped);
+    TRACE( "listener %#lx, acceptor %#lx, dest %p, recv_len %u, local_len %u, remote_len %u, ret_len %p, "
+           "overlapped %p\n", listener, acceptor, dest, recv_len, local_len, remote_len, ret_len, overlapped );
 
     if (!overlapped)
     {
         SetLastError(WSA_INVALID_PARAMETER);
         return FALSE;
     }
+    /* a set low bit in hEvent means no completion port notification is wanted */
+    if (!((ULONG_PTR)overlapped->hEvent & 1)) cvalue = overlapped;
     overlapped->Internal = STATUS_PENDING;
+    overlapped->InternalHigh = 0;
 
     if (!dest)
     {
@@ -2844,72 +2758,19 @@ static BOOL WINAPI WS2_AcceptEx(SOCKET listener, SOCKET acceptor, PVOID dest, DW
         return FALSE;
     }
 
-    if (!rem_addr_len)
+    if (!remote_len)
     {
         SetLastError(WSAEFAULT);
         return FALSE;
     }
 
-    fd = get_sock_fd( listener, FILE_READ_DATA, NULL );
-    if (fd == -1) return FALSE;
-    release_sock_fd( listener, fd );
-
-    fd = get_sock_fd( acceptor, FILE_READ_DATA, NULL );
-    if (fd == -1) return FALSE;
-    release_sock_fd( acceptor, fd );
+    status = NtDeviceIoControlFile( SOCKET2HANDLE(listener), overlapped->hEvent, NULL, cvalue,
+                                    (IO_STATUS_BLOCK *)overlapped, IOCTL_AFD_ACCEPT_INTO, &params, sizeof(params),
+                                    dest, recv_len + local_len + remote_len );
 
-    wsa = (struct ws2_accept_async *)alloc_async_io( sizeof(*wsa), WS2_async_accept );
-    if(!wsa)
-    {
-        SetLastError(WSAEFAULT);
-        return FALSE;
-    }
-
-    wsa->listen_socket   = SOCKET2HANDLE(listener);
-    wsa->accept_socket   = SOCKET2HANDLE(acceptor);
-    wsa->user_overlapped = overlapped;
-    wsa->cvalue          = !((ULONG_PTR)overlapped->hEvent & 1) ? (ULONG_PTR)overlapped : 0;
-    wsa->buf             = dest;
-    wsa->data_len        = dest_len;
-    wsa->local_len       = local_addr_len;
-    wsa->remote_len      = rem_addr_len;
-    wsa->read            = NULL;
-
-    if (wsa->data_len)
-    {
-        /* set up a read request if we need it */
-        wsa->read = (struct ws2_async *)alloc_async_io( offsetof(struct ws2_async, iovec[1]), WS2_async_accept_recv );
-        if (!wsa->read)
-        {
-            HeapFree( GetProcessHeap(), 0, wsa );
-            SetLastError(WSAEFAULT);
-            return FALSE;
-        }
-
-        wsa->read->hSocket     = wsa->accept_socket;
-        wsa->read->flags       = 0;
-        wsa->read->lpFlags     = &wsa->read->flags;
-        wsa->read->addr        = NULL;
-        wsa->read->addrlen.ptr = NULL;
-        wsa->read->control     = NULL;
-        wsa->read->n_iovecs    = 1;
-        wsa->read->first_iovec = 0;
-        wsa->read->completion_func = NULL;
-        wsa->read->iovec[0].iov_base = wsa->buf;
-        wsa->read->iovec[0].iov_len  = wsa->data_len;
-    }
-
-    status = register_async( ASYNC_TYPE_READ, SOCKET2HANDLE(listener), &wsa->io,
-                             overlapped->hEvent, NULL, (void *)wsa->cvalue, (IO_STATUS_BLOCK *)overlapped );
-
-    if(status != STATUS_PENDING)
-    {
-        HeapFree( GetProcessHeap(), 0, wsa->read );
-        HeapFree( GetProcessHeap(), 0, wsa );
-    }
-
-    SetLastError( NtStatusToWSAError(status) );
-    return FALSE;
+    if (ret_len) *ret_len = overlapped->InternalHigh;  /* zeroed above unless the request already completed */
+    WSASetLastError( NtStatusToWSAError(status) );
+    return !status;  /* a queued request fails here with ERROR_IO_PENDING */
 }
 
 /***********************************************************************
diff --git a/dlls/ws2_32/tests/sock.c b/dlls/ws2_32/tests/sock.c
index 5fb85bfaa60..c20a83b4261 100644
--- a/dlls/ws2_32/tests/sock.c
+++ b/dlls/ws2_32/tests/sock.c
@@ -7379,6 +7379,8 @@ todo_wine
     ok(bret == FALSE && WSAGetLastError() == WSAEINVAL, "AcceptEx on a non-listening socket "
         "returned %d + errno %d\n", bret, WSAGetLastError());
     ok(overlapped.Internal == STATUS_PENDING, "got %08x\n", (ULONG)overlapped.Internal);
+    if (!bret && WSAGetLastError() == ERROR_IO_PENDING) /* if Wine queued the request instead of failing, cancel it */
+        CancelIo((HANDLE)listener);
 
     iret = listen(listener, 5);
     ok(!iret, "failed to listen, error %u\n", GetLastError());
@@ -7452,9 +7454,9 @@ todo_wine
     bytesReturned = 0xdeadbeef;
     SetLastError(0xdeadbeef);
     bret = GetOverlappedResult((HANDLE)listener, &overlapped, &bytesReturned, FALSE);
-    todo_wine ok(!bret, "expected failure\n");
-    todo_wine ok(GetLastError() == ERROR_INSUFFICIENT_BUFFER, "got error %u\n", GetLastError());
-    todo_wine ok((NTSTATUS)overlapped.Internal == STATUS_BUFFER_TOO_SMALL, "got %#lx\n", overlapped.Internal);
+    ok(!bret, "expected failure\n");
+    ok(GetLastError() == ERROR_INSUFFICIENT_BUFFER, "got error %u\n", GetLastError());
+    ok((NTSTATUS)overlapped.Internal == STATUS_BUFFER_TOO_SMALL, "got %#lx\n", overlapped.Internal);
     ok(!bytesReturned, "got size %u\n", bytesReturned);
 
     closesocket(acceptor);
@@ -7796,19 +7798,10 @@ todo_wine
     closesocket(acceptor);
 
     dwret = WaitForSingleObject(overlapped.hEvent, 1000);
-    todo_wine ok(dwret == WAIT_OBJECT_0,
+    ok(dwret == WAIT_OBJECT_0,
        "Waiting for accept event failed with %d + errno %d\n", dwret, GetLastError());
-
-    if (dwret != WAIT_TIMEOUT) {
-        bret = GetOverlappedResult((HANDLE)listener, &overlapped, &bytesReturned, FALSE);
-        ok(!bret && GetLastError() == ERROR_OPERATION_ABORTED, "GetOverlappedResult failed, error %d\n", GetLastError());
-    }
-    else {
-        bret = CancelIo((HANDLE) listener);
-        ok(bret, "Failed to cancel failed test. Bailing...\n");
-        if (!bret) return;
-        WaitForSingleObject(overlapped.hEvent, 0);
-    }
+    bret = GetOverlappedResult((HANDLE)listener, &overlapped, &bytesReturned, FALSE); /* closing the acceptor aborts the pending AcceptEx() */
+    ok(!bret && GetLastError() == ERROR_OPERATION_ABORTED, "GetOverlappedResult failed, error %d\n", GetLastError());
 
     acceptor = socket(AF_INET, SOCK_STREAM, 0);
     ok(acceptor != INVALID_SOCKET, "failed to create socket, error %u\n", GetLastError());
@@ -9381,12 +9374,12 @@ static void test_completion_port(void)
 
     bret = GetQueuedCompletionStatus(io_port, &num_bytes, &key, &olp, 100);
     ok(bret == FALSE, "failed to get completion status %u\n", bret);
-    todo_wine ok(GetLastError() == ERROR_OPERATION_ABORTED
+    ok(GetLastError() == ERROR_OPERATION_ABORTED
             || GetLastError() == ERROR_CONNECTION_ABORTED, "got error %u\n", GetLastError());
     ok(key == 125, "Key is %lu\n", key);
     ok(num_bytes == 0, "Number of bytes transferred is %u\n", num_bytes);
     ok(olp == &ov, "Overlapped structure is at %p\n", olp);
-    todo_wine ok((NTSTATUS)olp->Internal == STATUS_CANCELLED
+    ok((NTSTATUS)olp->Internal == STATUS_CANCELLED
             || (NTSTATUS)olp->Internal == STATUS_CONNECTION_ABORTED, "got status %#lx\n", olp->Internal);
 
     SetLastError(0xdeadbeef);
diff --git a/include/wine/afd.h b/include/wine/afd.h
index 5a994084e16..07320e7bab5 100644
--- a/include/wine/afd.h
+++ b/include/wine/afd.h
@@ -22,6 +22,7 @@
 #define __WINE_WINE_AFD_H
 
 #include <winioctl.h>
+#include "wine/server_protocol.h"  /* obj_handle_t */
 
 #define IOCTL_AFD_CREATE                    CTL_CODE(FILE_DEVICE_NETWORK, 200, METHOD_BUFFERED, FILE_WRITE_ACCESS)
 #define IOCTL_AFD_ACCEPT                    CTL_CODE(FILE_DEVICE_NETWORK, 201, METHOD_BUFFERED, FILE_WRITE_ACCESS)
@@ -35,4 +36,10 @@ struct afd_create_params
     unsigned int flags;
 };
 
+struct afd_accept_into_params  /* input of IOCTL_AFD_ACCEPT_INTO */
+{
+    obj_handle_t accept_handle;        /* socket to accept the connection into */
+    unsigned int recv_len, local_len;  /* output layout: recv_len data bytes, local_len local-address bytes, then the remote address */
+};
+
 #endif
diff --git a/server/sock.c b/server/sock.c
index 4f97fe72080..2e82c7ffdde 100644
--- a/server/sock.c
+++ b/server/sock.c
@@ -84,6 +84,7 @@
 #include "winerror.h"
 #define USE_WS_PREFIX
 #include "winsock2.h"
+#include "ws2tcpip.h"
 #include "wsipx.h"
 #include "wine/afd.h"
 
@@ -120,6 +121,15 @@
 #define FD_WINE_RAW                0x80000000
 #define FD_WINE_INTERNAL           0xFFFF0000
 
+struct accept_req  /* a pending IOCTL_AFD_ACCEPT_INTO request */
+{
+    struct list entry;          /* entry in the listening socket's accept_list */
+    struct async *async;        /* async to complete (a reference is held) */
+    struct sock *acceptsock;    /* socket to accept into (not a counted reference) */
+    int accepted;               /* accepted into acceptsock, but the initial recv() would block */
+    unsigned int recv_len, local_len; /* sizes of the data and local-address areas of the output */
+};
+
 struct sock
 {
     struct object       obj;         /* object header */
@@ -143,8 +153,11 @@ struct sock
     struct async_queue  read_q;      /* queue for asynchronous reads */
     struct async_queue  write_q;     /* queue for asynchronous writes */
     struct async_queue  ifchange_q;  /* queue for interface change notifications */
+    struct async_queue  accept_q;    /* queue for asynchronous accepts */
     struct object      *ifchange_obj; /* the interface change notification object */
     struct list         ifchange_entry; /* entry in ifchange notification list */
+    struct list         accept_list; /* list of pending accept requests */
+    struct accept_req  *accept_recv_req; /* pending accept-into request which will recv on this socket */
 };
 
 static void sock_dump( struct object *obj, int verbose );
@@ -161,6 +174,7 @@ static int sock_ioctl( struct fd *fd, ioctl_code_t code, struct async *async );
 static void sock_queue_async( struct fd *fd, struct async *async, int type, int count );
 static void sock_reselect_async( struct fd *fd, struct async_queue *queue );
 
+static int accept_into_socket( struct sock *sock, struct sock *acceptsock );
 static int sock_get_ntstatus( int err );
 static unsigned int sock_get_error( int err );
 
@@ -203,6 +217,93 @@ static const struct fd_ops sock_fd_ops =
     sock_reselect_async           /* reselect_async */
 };
 
+union unix_sockaddr  /* one member per address family handled by sockaddr_from_unix() */
+{
+    struct sockaddr addr;
+    struct sockaddr_in in;
+    struct sockaddr_in6 in6;
+#ifdef HAS_IPX
+    struct sockaddr_ipx ipx;
+#endif
+#ifdef HAS_IRDA
+    struct sockaddr_irda irda;
+#endif
+};
+/* Host -> Windows sockaddr, zeroing wsaddrlen bytes first; returns the size used, 0 for AF_UNSPEC, -1 if too small or unsupported */
+static int sockaddr_from_unix( const union unix_sockaddr *uaddr, struct WS_sockaddr *wsaddr, socklen_t wsaddrlen )
+{
+    memset( wsaddr, 0, wsaddrlen );
+
+    switch (uaddr->addr.sa_family)
+    {
+    case AF_INET:
+    {
+        struct WS_sockaddr_in *win = (struct WS_sockaddr_in *)wsaddr;
+
+        if (wsaddrlen < sizeof(*win)) return -1;
+        win->sin_family = WS_AF_INET;
+        win->sin_port = uaddr->in.sin_port;
+        memcpy( &win->sin_addr, &uaddr->in.sin_addr, sizeof(win->sin_addr) );
+        return sizeof(*win);
+    }
+
+    case AF_INET6:
+    {
+        struct WS_sockaddr_in6 *win = (struct WS_sockaddr_in6 *)wsaddr;
+
+        if (wsaddrlen < sizeof(struct WS_sockaddr_in6_old)) return -1;
+        win->sin6_family = WS_AF_INET6;
+        win->sin6_port = uaddr->in6.sin6_port;
+        win->sin6_flowinfo = uaddr->in6.sin6_flowinfo;
+        memcpy( &win->sin6_addr, &uaddr->in6.sin6_addr, sizeof(win->sin6_addr) );
+#ifdef HAVE_STRUCT_SOCKADDR_IN6_SIN6_SCOPE_ID
+        if (wsaddrlen >= sizeof(struct WS_sockaddr_in6))
+        {
+            win->sin6_scope_id = uaddr->in6.sin6_scope_id;
+            return sizeof(struct WS_sockaddr_in6);
+        }
+#endif
+        return sizeof(struct WS_sockaddr_in6_old);
+    }
+
+#ifdef HAS_IPX
+    case AF_IPX:
+    {
+        struct WS_sockaddr_ipx *win = (struct WS_sockaddr_ipx *)wsaddr;
+
+        if (wsaddrlen < sizeof(*win)) return -1;
+        win->sa_family = WS_AF_IPX;
+        memcpy( win->sa_netnum, &uaddr->ipx.sipx_network, sizeof(win->sa_netnum) );
+        memcpy( win->sa_nodenum, &uaddr->ipx.sipx_node, sizeof(win->sa_nodenum) );
+        win->sa_socket = uaddr->ipx.sipx_port;
+        return sizeof(*win);
+    }
+#endif
+
+#ifdef HAS_IRDA
+    case AF_IRDA:
+    {
+        SOCKADDR_IRDA *win = (SOCKADDR_IRDA *)wsaddr;
+
+        if (wsaddrlen < sizeof(*win)) return -1;
+        win->irdaAddressFamily = WS_AF_IRDA;
+        memcpy( win->irdaDeviceID, &uaddr->irda.sir_addr, sizeof(win->irdaDeviceID) );
+        if (uaddr->irda.sir_lsap_sel != LSAP_ANY)
+            snprintf( win->irdaServiceName, sizeof(win->irdaServiceName), "LSAP-SEL%u", uaddr->irda.sir_lsap_sel );
+        else
+            memcpy( win->irdaServiceName, uaddr->irda.sir_name, sizeof(win->irdaServiceName) );
+        return sizeof(*win);
+    }
+#endif
+
+    case AF_UNSPEC:
+        return 0;
+
+    default:
+        return -1;
+
+    }
+}
 
 /* Permutation of 0..FD_MAX_EVENTS - 1 representing the order in which
  * we post messages if there are multiple events.  Used to send
@@ -339,8 +440,204 @@ static inline int sock_error( struct fd *fd )
     return optval;
 }
 
+/* Unlink an accept request from its listening socket, detach it from the
+ * socket it was going to accept into, and drop its reference to the async.
+ * NOTE(review): req->acceptsock is not a counted reference; confirm the
+ * request is always freed before that socket is destroyed. */
+static void free_accept_req( struct accept_req *req )
+{
+    list_remove( &req->entry );
+    req->acceptsock->accept_recv_req = NULL;
+    release_object( req->async );
+    free( req );
+}
+
+/* Build the AcceptEx() output for a request whose connection has been accepted
+ * into req->acceptsock.  The output buffer is laid out as
+ *
+ *     [recv_len: initial data][local_len: int length + local address]
+ *     [rest of the buffer: int length + remote address]
+ *
+ * On success the buffer is attached to the iosb and the error is set to
+ * STATUS_ALERTED.  If the initial recv() would block, the error is set to
+ * STATUS_PENDING and req->accepted is set so that the request keeps waiting
+ * for data on the accepted socket.  Any other failure is left as the current
+ * error. */
+static void fill_accept_output( struct accept_req *req, struct iosb *iosb )
+{
+    union unix_sockaddr unix_addr;
+    struct WS_sockaddr *win_addr;
+    unsigned int remote_len;
+    socklen_t unix_len;
+    int fd, size = 0;
+    char *out_data;
+    int win_len;
+
+    /* IOCTL_AFD_ACCEPT_INTO checks these against the reply size; re-check
+     * against the iosb, since every offset below depends on them */
+    if (req->recv_len > iosb->out_size || req->local_len > iosb->out_size - req->recv_len ||
+        iosb->out_size - req->recv_len - req->local_len < sizeof(int))
+    {
+        set_error( STATUS_BUFFER_TOO_SMALL );
+        return;
+    }
+    remote_len = iosb->out_size - req->recv_len - req->local_len;
+
+    if (!(out_data = mem_alloc( iosb->out_size ))) return;
+    /* don't hand stale heap contents back to the client, e.g. after a short recv() */
+    memset( out_data, 0, iosb->out_size );
+
+    fd = get_unix_fd( req->acceptsock->fd );
+
+    if (req->recv_len && (size = recv( fd, out_data, req->recv_len, 0 )) < 0)
+    {
+        if (errno == EWOULDBLOCK)
+        {
+            /* no data yet: wait for it on the accepted socket; this may also be
+             * a spurious wakeup after we already started waiting */
+            free( out_data );
+            if (!req->accepted)
+            {
+                req->accepted = 1;
+                sock_reselect( req->acceptsock );
+            }
+            set_error( STATUS_PENDING );
+            return;
+        }
+        set_win32_error( sock_get_error( errno ) );
+        goto error;
+    }
+
+    if (req->local_len)
+    {
+        if (req->local_len < sizeof(int))
+        {
+            set_error( STATUS_BUFFER_TOO_SMALL );
+            goto error;
+        }
+
+        /* the address follows its int length prefix, so only
+         * local_len - sizeof(int) bytes are available for it.
+         * NOTE(review): win_addr is not necessarily suitably aligned for the
+         * sockaddr structures that sockaddr_from_unix() writes through it. */
+        unix_len = sizeof(unix_addr);
+        win_addr = (struct WS_sockaddr *)(out_data + req->recv_len + sizeof(int));
+        if (getsockname( fd, &unix_addr.addr, &unix_len ) < 0)
+        {
+            set_win32_error( sock_get_error( errno ) );
+            goto error;
+        }
+        /* sockaddr_from_unix() does not set errno; it fails when the space is
+         * too small (or for an unsupported address family) */
+        if ((win_len = sockaddr_from_unix( &unix_addr, win_addr, req->local_len - sizeof(int) )) < 0)
+        {
+            set_error( STATUS_BUFFER_TOO_SMALL );
+            goto error;
+        }
+        /* the length prefix itself may be unaligned */
+        memcpy( out_data + req->recv_len, &win_len, sizeof(win_len) );
+    }
+
+    unix_len = sizeof(unix_addr);
+    win_addr = (struct WS_sockaddr *)(out_data + req->recv_len + req->local_len + sizeof(int));
+    if (getpeername( fd, &unix_addr.addr, &unix_len ) < 0)
+    {
+        set_win32_error( sock_get_error( errno ) );
+        goto error;
+    }
+    if ((win_len = sockaddr_from_unix( &unix_addr, win_addr, remote_len - sizeof(int) )) < 0)
+    {
+        set_error( STATUS_BUFFER_TOO_SMALL );
+        goto error;
+    }
+    memcpy( out_data + req->recv_len + req->local_len, &win_len, sizeof(win_len) );
+
+    iosb->status = STATUS_SUCCESS;
+    iosb->result = size;
+    iosb->out_data = out_data;
+    set_error( STATUS_ALERTED );
+    return;
+
+error:
+    free( out_data );
+}
+
+/* Accept a pending connection on the listening socket into req->acceptsock
+ * and fill in the output.  The outcome is left as the current error
+ * (accept_into_socket() is assumed to set it when it fails). */
+static void complete_async_accept( struct sock *sock, struct accept_req *req )
+{
+    struct sock *acceptsock = req->acceptsock;
+    struct async *async = req->async;
+    struct iosb *iosb;
+
+    if (debug_level) fprintf( stderr, "completing accept request for socket %p\n", sock );
+
+    if (!accept_into_socket( sock, acceptsock )) return;
+
+    iosb = async_get_iosb( async );
+    fill_accept_output( req, iosb );
+    release_object( iosb );
+}
+
+/* Retry the initial recv() for a request whose connection has already been
+ * accepted into req->acceptsock.  The outcome is left as the current error. */
+static void complete_async_accept_recv( struct accept_req *req )
+{
+    struct async *async = req->async;
+    struct iosb *iosb;
+
+    if (debug_level) fprintf( stderr, "completing accept recv request for socket %p\n", req->acceptsock );
+
+    assert( req->recv_len );
+
+    iosb = async_get_iosb( async );
+    fill_accept_output( req, iosb );
+    release_object( iosb );
+}
+
+/* Whether the request's async is still waiting for a result; this is the same
+ * test sock_reselect_async() uses to decide when to free a request. */
+static int accept_req_is_pending( struct accept_req *req )
+{
+    struct iosb *iosb = async_get_iosb( req->async );
+    int pending = (iosb->status == STATUS_PENDING);
+
+    release_object( iosb );
+    return pending;
+}
+
 static int sock_dispatch_asyncs( struct sock *sock, int event, int error )
 {
+    if (event & (POLLIN | POLLPRI))
+    {
+        struct accept_req *req;
+
+        /* a readable listening socket can satisfy the first request that is
+         * still waiting for a connection */
+        LIST_FOR_EACH_ENTRY( req, &sock->accept_list, struct accept_req, entry )
+        {
+            if (!req->accepted && accept_req_is_pending( req ))
+            {
+                complete_async_accept( sock, req );
+                if (get_error() != STATUS_PENDING)
+                    async_terminate( req->async, get_error() );
+                break;
+            }
+        }
+
+        /* req->accepted is only set once a connection has been accepted into
+         * this socket and the initial recv() would block, so don't try to
+         * read for a request that hasn't reached that point */
+        req = sock->accept_recv_req;
+        if (req && req->accepted && accept_req_is_pending( req ))
+        {
+            complete_async_accept_recv( req );
+            if (get_error() != STATUS_PENDING)
+                async_terminate( req->async, get_error() );
+        }
+    }
+
     if (is_fd_overlapped( sock->fd ))
     {
         if (event & (POLLIN|POLLPRI) && async_waiting( &sock->read_q ))
@@ -355,16 +583,27 @@ static int sock_dispatch_asyncs( struct sock *sock, int event, int error )
             async_wake_up( &sock->write_q, STATUS_ALERTED );
             event &= ~POLLOUT;
         }
-        if ( event & (POLLERR|POLLHUP) )
-        {
-            int status = sock_get_ntstatus( error );
+    }
 
-            if ( !(sock->state & FD_READ) )
-                async_wake_up( &sock->read_q, status );
-            if ( !(sock->state & FD_WRITE) )
-                async_wake_up( &sock->write_q, status );
-        }
+    if (event & (POLLERR | POLLHUP))
+    {
+        int status = sock_get_ntstatus( error );
+        struct accept_req *req, *next;
+
+        if (!(sock->state & FD_READ))
+            async_wake_up( &sock->read_q, status );
+        if (!(sock->state & FD_WRITE))
+            async_wake_up( &sock->write_q, status );
+
+        LIST_FOR_EACH_ENTRY_SAFE( req, next, &sock->accept_list, struct accept_req, entry )
+            async_terminate( req->async, status );
+
+        /* an error on this socket only concerns an AcceptEx() request once a
+         * connection has been accepted into it and its recv() is pending */
+        if (sock->accept_recv_req && sock->accept_recv_req->accepted)
+            async_terminate( sock->accept_recv_req->async, status );
     }
+
     return event;
 }
 
@@ -539,7 +776,11 @@ static int sock_get_poll_events( struct fd *fd )
         /* connecting, wait for writable */
         return POLLOUT;
 
-    if (async_queued( &sock->read_q ))
+    if (!list_empty( &sock->accept_list ) || sock->accept_recv_req )
+    {   /* NOTE(review): accept_recv_req is set when AcceptEx() is queued, before any connection is accepted into this socket */
+        ev |= POLLIN | POLLPRI;  /* listener: incoming connection; accept target: initial data */
+    }
+    else if (async_queued( &sock->read_q ))
     {
         if (async_waiting( &sock->read_q )) ev |= POLLIN | POLLPRI;
     }
@@ -601,6 +842,16 @@ static void sock_queue_async( struct fd *fd, struct async *async, int type, int
 static void sock_reselect_async( struct fd *fd, struct async_queue *queue )
 {
     struct sock *sock = get_fd_user( fd );
+    struct accept_req *req, *next;
+    /* drop accept requests whose asyncs are no longer pending */
+    LIST_FOR_EACH_ENTRY_SAFE( req, next, &sock->accept_list, struct accept_req, entry )
+    {
+        struct iosb *iosb = async_get_iosb( req->async );
+        if (iosb->status != STATUS_PENDING)
+            free_accept_req( req );
+        release_object( iosb );
+    }
+
     /* ignore reselect on ifchange queue */
     if (&sock->ifchange_q != queue)
         sock_reselect( sock );
@@ -615,6 +866,8 @@ static struct fd *sock_get_fd( struct object *obj )
 static void sock_destroy( struct object *obj )
 {
     struct sock *sock = (struct sock *)obj;
+    struct accept_req *req, *next;
+
     assert( obj->ops == &sock_ops );
 
     /* FIXME: special socket shutdown stuff? */
@@ -622,11 +875,18 @@ static void sock_destroy( struct object *obj )
     if ( sock->deferred )
         release_object( sock->deferred );
 
+    if (sock->accept_recv_req)  /* an AcceptEx() that would accept into this socket */
+        async_terminate( sock->accept_recv_req->async, STATUS_CANCELLED );
+    /* AcceptEx() requests queued on this socket as a listener */
+    LIST_FOR_EACH_ENTRY_SAFE( req, next, &sock->accept_list, struct accept_req, entry )
+        async_terminate( req->async, STATUS_CANCELLED );
+
     async_wake_up( &sock->ifchange_q, STATUS_CANCELLED );
     sock_release_ifchange( sock );
     free_async_queue( &sock->read_q );
     free_async_queue( &sock->write_q );
     free_async_queue( &sock->ifchange_q );
+    free_async_queue( &sock->accept_q );
     if (sock->event) release_object( sock->event );
     if (sock->fd)
     {
@@ -658,10 +918,13 @@ static struct sock *create_socket(void)
     sock->connect_time = 0;
     sock->deferred = NULL;
     sock->ifchange_obj = NULL;
+    sock->accept_recv_req = NULL;
     init_async_queue( &sock->read_q );
     init_async_queue( &sock->write_q );
     init_async_queue( &sock->ifchange_q );
+    init_async_queue( &sock->accept_q );
     memset( sock->errors, 0, sizeof(sock->errors) );
+    list_init( &sock->accept_list );
     return sock;
 }
 
@@ -1065,6 +1328,25 @@ static int sock_get_ntstatus( int err )
     }
 }
 
+/* Allocate an accept request that will accept a connection into acceptsock and
+ * complete the given async.  The request holds a reference to the async but
+ * not to acceptsock. */
+static struct accept_req *alloc_accept_req( struct sock *acceptsock, struct async *async,
+                                            const struct afd_accept_into_params *params )
+{
+    struct accept_req *req = mem_alloc( sizeof(*req) );
+
+    if (req)
+    {
+        req->async = (struct async *)grab_object( async );
+        req->acceptsock = acceptsock;
+        req->accepted = 0;
+        req->recv_len = params->recv_len;
+        req->local_len = params->local_len;
+    }
+    return req;
+}
+
 static int sock_ioctl( struct fd *fd, ioctl_code_t code, struct async *async )
 {
     struct sock *sock = get_fd_user( fd );
@@ -1111,22 +1392,46 @@ static int sock_ioctl( struct fd *fd, ioctl_code_t code, struct async *async )
     case IOCTL_AFD_ACCEPT_INTO:
     {
         static const int access = FILE_READ_ATTRIBUTES | FILE_WRITE_ATTRIBUTES | FILE_READ_DATA;
+        const struct afd_accept_into_params *params = get_req_data();
         struct sock *acceptsock;
-        obj_handle_t handle;
+        unsigned int remote_len;
+        struct accept_req *req;
 
-        if (get_req_data_size() != sizeof(handle))
+        /* compare each client-supplied length separately: their sum could
+         * wrap around and slip past a single combined comparison */
+        if (get_req_data_size() != sizeof(*params) ||
+            params->recv_len > get_reply_max_size() ||
+            params->local_len > get_reply_max_size() - params->recv_len)
         {
             set_error( STATUS_BUFFER_TOO_SMALL );
             return 0;
         }
-        handle = *(obj_handle_t *)get_req_data();
 
-        if (!(acceptsock = (struct sock *)get_handle_obj( current->process, handle, access, &sock_ops )))
+        remote_len = get_reply_max_size() - params->recv_len - params->local_len;
+        if (remote_len < sizeof(int))
+        {
+            set_error( STATUS_INVALID_PARAMETER );
+            return 0;
+        }
+
+        if (!(acceptsock = (struct sock *)get_handle_obj( current->process, params->accept_handle, access, &sock_ops )))
             return 0;
-        if (accept_into_socket( sock, acceptsock ))
-            acceptsock->wparam = handle;
+
+        if (!(req = alloc_accept_req( acceptsock, async, params )))
+        {
+            release_object( acceptsock );
+            return 0;
+        }
+        list_add_tail( &sock->accept_list, &req->entry );
+        acceptsock->accept_recv_req = req;
+        /* set while we still hold a reference; req->acceptsock does not hold one */
+        acceptsock->wparam = params->accept_handle;
         release_object( acceptsock );
-        return 0;
+
+        queue_async( &sock->accept_q, async );
+        sock_reselect( sock );
+        set_error( STATUS_PENDING );
+        return 1;
     }
 
     case IOCTL_AFD_ADDRESS_LIST_CHANGE:
-- 
2.28.0




More information about the wine-devel mailing list