Zebediah Figura : ntdll: Use the recv_socket request for NtReadFile() on a socket.

Alexandre Julliard julliard at winehq.org
Thu Jul 14 16:59:25 CDT 2022


Module: wine
Branch: master
Commit: 9783e2e16c35586b5fa65ae27ac1f96c6260c534
URL:    https://gitlab.winehq.org/wine/wine/-/commit/9783e2e16c35586b5fa65ae27ac1f96c6260c534

Author: Zebediah Figura <zfigura at codeweavers.com>
Date:   Sun May 29 16:40:36 2022 -0500

ntdll: Use the recv_socket request for NtReadFile() on a socket.

recv_socket does some extra bookkeeping that's currently missing from the
register_async path. Instead of adding it to sock_queue_async(), let's just
centralize all recv requests so that they go through recv_socket.

---

 dlls/ntdll/unix/file.c         |   6 ++
 dlls/ntdll/unix/socket.c       | 129 +++++++++++++++++++++++++----------------
 dlls/ntdll/unix/unix_private.h |   2 +
 dlls/ws2_32/tests/afd.c        |   2 +-
 4 files changed, 89 insertions(+), 50 deletions(-)

diff --git a/dlls/ntdll/unix/file.c b/dlls/ntdll/unix/file.c
index e09e8cafc82..f57532f55bf 100644
--- a/dlls/ntdll/unix/file.c
+++ b/dlls/ntdll/unix/file.c
@@ -5177,6 +5177,12 @@ NTSTATUS WINAPI NtReadFile( HANDLE handle, HANDLE event, PIO_APC_ROUTINE apc, vo
             goto done;
         }
     }
+    else if (type == FD_TYPE_SOCKET)
+    {
+        status = sock_read( handle, unix_handle, event, apc, apc_user, io, buffer, length );
+        if (needs_close) close( unix_handle );
+        return status;
+    }
 
     if (type == FD_TYPE_SERIAL && async_read && length)
     {
diff --git a/dlls/ntdll/unix/socket.c b/dlls/ntdll/unix/socket.c
index 351028d4983..6abee680ee5 100644
--- a/dlls/ntdll/unix/socket.c
+++ b/dlls/ntdll/unix/socket.c
@@ -676,17 +676,67 @@ static BOOL async_recv_proc( void *user, ULONG_PTR *info, NTSTATUS *status )
 }
 
 static NTSTATUS sock_recv( HANDLE handle, HANDLE event, PIO_APC_ROUTINE apc, void *apc_user, IO_STATUS_BLOCK *io,
-                           int fd, const void *buffers_ptr, unsigned int count, WSABUF *control,
-                           struct WS_sockaddr *addr, int *addr_len, DWORD *ret_flags, int unix_flags, int force_async )
+                           int fd, struct async_recv_ioctl *async, int force_async )
 {
-    struct async_recv_ioctl *async;
+    BOOL nonblocking, alerted;
     ULONG_PTR information;
     HANDLE wait_handle;
-    DWORD async_size;
     NTSTATUS status;
     unsigned int i;
     ULONG options;
-    BOOL nonblocking, alerted;
+
+    for (i = 0; i < async->count; ++i)
+    {
+        if (!virtual_check_buffer_for_write( async->iov[i].iov_base, async->iov[i].iov_len ))
+        {
+            release_fileio( &async->io );
+            return STATUS_ACCESS_VIOLATION;
+        }
+    }
+
+    SERVER_START_REQ( recv_socket )
+    {
+        req->force_async = force_async;
+        req->async  = server_async( handle, &async->io, event, apc, apc_user, iosb_client_ptr(io) );
+        req->oob    = !!(async->unix_flags & MSG_OOB);
+        status = wine_server_call( req );
+        wait_handle = wine_server_ptr_handle( reply->wait );
+        options     = reply->options;
+        nonblocking = reply->nonblocking;
+    }
+    SERVER_END_REQ;
+
+    alerted = status == STATUS_ALERTED;
+    if (alerted)
+    {
+        status = try_recv( fd, async, &information );
+        if (status == STATUS_DEVICE_NOT_READY && (force_async || !nonblocking))
+            status = STATUS_PENDING;
+    }
+
+    if (status != STATUS_PENDING)
+    {
+        if (!NT_ERROR(status) || (wait_handle && !alerted))
+        {
+            io->Status = status;
+            io->Information = information;
+        }
+        release_fileio( &async->io );
+    }
+
+    if (alerted) set_async_direct_result( &wait_handle, status, information, FALSE );
+    if (wait_handle) status = wait_async( wait_handle, options & FILE_SYNCHRONOUS_IO_ALERT );
+    return status;
+}
+
+
+static NTSTATUS sock_ioctl_recv( HANDLE handle, HANDLE event, PIO_APC_ROUTINE apc, void *apc_user, IO_STATUS_BLOCK *io,
+                                 int fd, const void *buffers_ptr, unsigned int count, WSABUF *control,
+                                 struct WS_sockaddr *addr, int *addr_len, DWORD *ret_flags, int unix_flags, int force_async )
+{
+    struct async_recv_ioctl *async;
+    DWORD async_size;
+    unsigned int i;
 
     if (unix_flags & MSG_OOB)
     {
@@ -728,48 +778,29 @@ static NTSTATUS sock_recv( HANDLE handle, HANDLE event, PIO_APC_ROUTINE apc, voi
     async->addr_len = addr_len;
     async->ret_flags = ret_flags;
 
-    for (i = 0; i < count; ++i)
-    {
-        if (!virtual_check_buffer_for_write( async->iov[i].iov_base, async->iov[i].iov_len ))
-        {
-            release_fileio( &async->io );
-            return STATUS_ACCESS_VIOLATION;
-        }
-    }
+    return sock_recv( handle, event, apc, apc_user, io, fd, async, force_async );
+}
 
-    SERVER_START_REQ( recv_socket )
-    {
-        req->force_async = force_async;
-        req->async  = server_async( handle, &async->io, event, apc, apc_user, iosb_client_ptr(io) );
-        req->oob    = !!(unix_flags & MSG_OOB);
-        status = wine_server_call( req );
-        wait_handle = wine_server_ptr_handle( reply->wait );
-        options     = reply->options;
-        nonblocking = reply->nonblocking;
-    }
-    SERVER_END_REQ;
 
-    alerted = status == STATUS_ALERTED;
-    if (alerted)
-    {
-        status = try_recv( fd, async, &information );
-        if (status == STATUS_DEVICE_NOT_READY && (force_async || !nonblocking))
-            status = STATUS_PENDING;
-    }
+NTSTATUS sock_read( HANDLE handle, int fd, HANDLE event, PIO_APC_ROUTINE apc,
+                    void *apc_user, IO_STATUS_BLOCK *io, void *buffer, ULONG length )
+{
+    static const DWORD async_size = offsetof( struct async_recv_ioctl, iov[1] );
+    struct async_recv_ioctl *async;
 
-    if (status != STATUS_PENDING)
-    {
-        if (!NT_ERROR(status) || (wait_handle && !alerted))
-        {
-            io->Status = status;
-            io->Information = information;
-        }
-        release_fileio( &async->io );
-    }
+    if (!(async = (struct async_recv_ioctl *)alloc_fileio( async_size, async_recv_proc, handle )))
+        return STATUS_NO_MEMORY;
 
-    if (alerted) set_async_direct_result( &wait_handle, status, information, FALSE );
-    if (wait_handle) status = wait_async( wait_handle, options & FILE_SYNCHRONOUS_IO_ALERT );
-    return status;
+    async->count = 1;
+    async->iov[0].iov_base = buffer;
+    async->iov[0].iov_len = length;
+    async->unix_flags = 0;
+    async->control = NULL;
+    async->addr = NULL;
+    async->addr_len = NULL;
+    async->ret_flags = NULL;
+
+    return sock_recv( handle, event, apc, apc_user, io, fd, async, 1 );
 }
 
 
@@ -1334,8 +1365,8 @@ NTSTATUS sock_ioctl( HANDLE handle, HANDLE event, PIO_APC_ROUTINE apc, void *apc
                 unix_flags |= MSG_PEEK;
             if (params.msg_flags & AFD_MSG_WAITALL)
                 FIXME( "MSG_WAITALL is not supported\n" );
-            status = sock_recv( handle, event, apc, apc_user, io, fd, params.buffers, params.count, NULL,
-                                NULL, NULL, NULL, unix_flags, !!(params.recv_flags & AFD_RECV_FORCE_ASYNC) );
+            status = sock_ioctl_recv( handle, event, apc, apc_user, io, fd, params.buffers, params.count, NULL,
+                                      NULL, NULL, NULL, unix_flags, !!(params.recv_flags & AFD_RECV_FORCE_ASYNC) );
             if (needs_close) close( fd );
             return status;
         }
@@ -1361,10 +1392,10 @@ NTSTATUS sock_ioctl( HANDLE handle, HANDLE event, PIO_APC_ROUTINE apc, void *apc
                 unix_flags |= MSG_PEEK;
             if (*ws_flags & WS_MSG_WAITALL)
                 FIXME( "MSG_WAITALL is not supported\n" );
-            status = sock_recv( handle, event, apc, apc_user, io, fd, u64_to_user_ptr(params->buffers_ptr),
-                                params->count, u64_to_user_ptr(params->control_ptr),
-                                u64_to_user_ptr(params->addr_ptr), u64_to_user_ptr(params->addr_len_ptr),
-                                ws_flags, unix_flags, params->force_async );
+            status = sock_ioctl_recv( handle, event, apc, apc_user, io, fd, u64_to_user_ptr(params->buffers_ptr),
+                                      params->count, u64_to_user_ptr(params->control_ptr),
+                                      u64_to_user_ptr(params->addr_ptr), u64_to_user_ptr(params->addr_len_ptr),
+                                      ws_flags, unix_flags, params->force_async );
             if (needs_close) close( fd );
             return status;
         }
diff --git a/dlls/ntdll/unix/unix_private.h b/dlls/ntdll/unix/unix_private.h
index 39b667ddd4a..8d305bdada6 100644
--- a/dlls/ntdll/unix/unix_private.h
+++ b/dlls/ntdll/unix/unix_private.h
@@ -256,6 +256,8 @@ extern NTSTATUS serial_DeviceIoControl( HANDLE device, HANDLE event, PIO_APC_ROU
 extern NTSTATUS serial_FlushBuffersFile( int fd ) DECLSPEC_HIDDEN;
 extern NTSTATUS sock_ioctl( HANDLE handle, HANDLE event, PIO_APC_ROUTINE apc, void *apc_user, IO_STATUS_BLOCK *io,
                             ULONG code, void *in_buffer, ULONG in_size, void *out_buffer, ULONG out_size ) DECLSPEC_HIDDEN;
+extern NTSTATUS sock_read( HANDLE handle, int fd, HANDLE event, PIO_APC_ROUTINE apc, void *apc_user,
+                           IO_STATUS_BLOCK *io, void *buffer, ULONG length ) DECLSPEC_HIDDEN;
 extern NTSTATUS tape_DeviceIoControl( HANDLE device, HANDLE event, PIO_APC_ROUTINE apc, void *apc_user,
                                       IO_STATUS_BLOCK *io, ULONG code, void *in_buffer,
                                       ULONG in_size, void *out_buffer, ULONG out_size ) DECLSPEC_HIDDEN;
diff --git a/dlls/ws2_32/tests/afd.c b/dlls/ws2_32/tests/afd.c
index 22ccdd0adeb..334c2c41be6 100644
--- a/dlls/ws2_32/tests/afd.c
+++ b/dlls/ws2_32/tests/afd.c
@@ -2645,7 +2645,7 @@ static void test_read_write(void)
 
     ret = WSAEnumNetworkEvents(client, event, &events);
     ok(!ret, "got error %lu\n", GetLastError());
-    todo_wine ok(events.lNetworkEvents == FD_READ, "got events %#lx\n", events.lNetworkEvents);
+    ok(events.lNetworkEvents == FD_READ, "got events %#lx\n", events.lNetworkEvents);
 
     memset(buffer, 0xcc, sizeof(buffer));
     ret = NtReadFile((HANDLE)client, NULL, NULL, NULL, &io, buffer, sizeof(buffer), NULL, NULL);




More information about the wine-cvs mailing list