[PATCH v2 11/11] server: Handle the entire IOCTL_AFD_POLL ioctl on the server side.

Zebediah Figura zfigura at codeweavers.com
Fri Dec 10 11:27:39 CST 2021


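The server now parses the raw IOCTL_AFD_POLL input buffer itself, converting
between the 32-bit and 64-bit layouts of struct afd_poll_params, and builds
the output buffer in complete_async_poll(). ntdll's sock_ioctl() now returns
STATUS_BAD_DEVICE_TYPE for this ioctl, so the client-side async wrapper, the
poll_socket server request and its trace helpers are removed.

Signed-off-by: Zebediah Figura <zfigura at codeweavers.com>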
---
 dlls/ntdll/unix/socket.c | 153 +-------------------------------
 include/wine/afd.h       |  30 ++++++-
 server/protocol.def      |  25 ------
 server/sock.c            | 182 ++++++++++++++++++++++++++++-----------
 server/trace.c           |  32 -------
 5 files changed, 166 insertions(+), 256 deletions(-)

diff --git a/dlls/ntdll/unix/socket.c b/dlls/ntdll/unix/socket.c
index 20368ad6415..92374e39db7 100644
--- a/dlls/ntdll/unix/socket.c
+++ b/dlls/ntdll/unix/socket.c
@@ -770,151 +770,6 @@ static NTSTATUS sock_recv( HANDLE handle, HANDLE event, PIO_APC_ROUTINE apc, voi
 }
 
 
-struct async_poll_ioctl
-{
-    struct async_fileio io;
-    unsigned int count;
-    struct afd_poll_params *input, *output;
-    struct poll_socket_output sockets[1];
-};
-
-static ULONG_PTR fill_poll_output( struct async_poll_ioctl *async, NTSTATUS status )
-{
-    struct afd_poll_params *input = async->input, *output = async->output;
-    unsigned int i, count = 0;
-
-    memcpy( output, input, offsetof( struct afd_poll_params, sockets[0] ) );
-
-    if (!status)
-    {
-        for (i = 0; i < async->count; ++i)
-        {
-            if (async->sockets[i].flags)
-            {
-                output->sockets[count].socket = input->sockets[i].socket;
-                output->sockets[count].flags = async->sockets[i].flags;
-                output->sockets[count].status = async->sockets[i].status;
-                ++count;
-            }
-        }
-    }
-    output->count = count;
-    return offsetof( struct afd_poll_params, sockets[count] );
-}
-
-static BOOL async_poll_proc( void *user, ULONG_PTR *info, NTSTATUS *status )
-{
-    struct async_poll_ioctl *async = user;
-
-    if (*status == STATUS_ALERTED)
-    {
-        SERVER_START_REQ( get_async_result )
-        {
-            req->user_arg = wine_server_client_ptr( async );
-            wine_server_set_reply( req, async->sockets, async->count * sizeof(async->sockets[0]) );
-            *status = wine_server_call( req );
-        }
-        SERVER_END_REQ;
-
-        *info = fill_poll_output( async, *status );
-    }
-
-    free( async->input );
-    release_fileio( &async->io );
-    return TRUE;
-}
-
-
-/* we could handle this ioctl entirely on the server side, but the differing
- * structure size makes it painful */
-static NTSTATUS sock_poll( HANDLE handle, HANDLE event, PIO_APC_ROUTINE apc, void *apc_user, IO_STATUS_BLOCK *io,
-                           void *in_buffer, ULONG in_size, void *out_buffer, ULONG out_size )
-{
-    const struct afd_poll_params *params = in_buffer;
-    struct poll_socket_input *input;
-    struct async_poll_ioctl *async;
-    HANDLE wait_handle;
-    DWORD async_size;
-    NTSTATUS status;
-    unsigned int i;
-    ULONG options;
-
-    if (in_size < sizeof(*params) || out_size < in_size || !params->count
-            || in_size < offsetof( struct afd_poll_params, sockets[params->count] ))
-        return STATUS_INVALID_PARAMETER;
-
-    TRACE( "timeout %s, count %u, exclusive %#x, padding (%#x, %#x, %#x), sockets[0] {%04lx, %#x}\n",
-            wine_dbgstr_longlong(params->timeout), params->count, params->exclusive,
-            params->padding[0], params->padding[1], params->padding[2],
-            params->sockets[0].socket, params->sockets[0].flags );
-
-    if (params->padding[0]) FIXME( "padding[0] is %#x\n", params->padding[0] );
-    if (params->padding[1]) FIXME( "padding[1] is %#x\n", params->padding[1] );
-    if (params->padding[2]) FIXME( "padding[2] is %#x\n", params->padding[2] );
-    for (i = 0; i < params->count; ++i)
-    {
-        if (params->sockets[i].flags & ~0x1ff)
-            FIXME( "unknown socket flags %#x\n", params->sockets[i].flags );
-    }
-
-    if (!(input = malloc( params->count * sizeof(*input) )))
-        return STATUS_NO_MEMORY;
-
-    async_size = offsetof( struct async_poll_ioctl, sockets[params->count] );
-
-    if (!(async = (struct async_poll_ioctl *)alloc_fileio( async_size, async_poll_proc, handle )))
-    {
-        free( input );
-        return STATUS_NO_MEMORY;
-    }
-
-    if (!(async->input = malloc( in_size )))
-    {
-        release_fileio( &async->io );
-        free( input );
-        return STATUS_NO_MEMORY;
-    }
-    memcpy( async->input, in_buffer, in_size );
-
-    async->count = params->count;
-    async->output = out_buffer;
-
-    for (i = 0; i < params->count; ++i)
-    {
-        input[i].socket = params->sockets[i].socket;
-        input[i].flags = params->sockets[i].flags;
-    }
-
-    SERVER_START_REQ( poll_socket )
-    {
-        req->async = server_async( handle, &async->io, event, apc, apc_user, iosb_client_ptr(io) );
-        req->exclusive = !!params->exclusive;
-        req->timeout = params->timeout;
-        wine_server_add_data( req, input, params->count * sizeof(*input) );
-        wine_server_set_reply( req, async->sockets, params->count * sizeof(async->sockets[0]) );
-        status = wine_server_call( req );
-        wait_handle = wine_server_ptr_handle( reply->wait );
-        options = reply->options;
-        if (wait_handle && status != STATUS_PENDING)
-        {
-            io->Status = status;
-            io->Information = fill_poll_output( async, status );
-        }
-    }
-    SERVER_END_REQ;
-
-    free( input );
-
-    if (status != STATUS_PENDING)
-    {
-        free( async->input );
-        release_fileio( &async->io );
-    }
-
-    if (wait_handle) status = wait_async( wait_handle, (options & FILE_SYNCHRONOUS_IO_ALERT) );
-    return status;
-}
-
 static NTSTATUS try_send( int fd, struct async_send_ioctl *async )
 {
     union unix_sockaddr unix_addr;
@@ -1374,6 +1229,10 @@ NTSTATUS sock_ioctl( HANDLE handle, HANDLE event, PIO_APC_ROUTINE apc, void *apc
             status = STATUS_BAD_DEVICE_TYPE;
             break;
 
+        case IOCTL_AFD_POLL:
+            status = STATUS_BAD_DEVICE_TYPE;
+            break;
+
         case IOCTL_AFD_RECV:
         {
             struct afd_recv_params params;
@@ -1516,10 +1375,6 @@ NTSTATUS sock_ioctl( HANDLE handle, HANDLE event, PIO_APC_ROUTINE apc, void *apc
             break;
         }
 
-        case IOCTL_AFD_POLL:
-            status = sock_poll( handle, event, apc, apc_user, io, in_buffer, in_size, out_buffer, out_size );
-            break;
-
         case IOCTL_AFD_WINE_FIONREAD:
         {
             int value, ret;
diff --git a/include/wine/afd.h b/include/wine/afd.h
index 97128c67ca6..efd5787e90a 100644
--- a/include/wine/afd.h
+++ b/include/wine/afd.h
@@ -122,13 +122,41 @@ struct afd_poll_params
     unsigned int count;
     BOOLEAN exclusive;
     BOOLEAN padding[3];
-    struct
+    struct afd_poll_socket
     {
         SOCKET socket;
         int flags;
         NTSTATUS status;
     } sockets[1];
 };
+/* struct afd_poll_params as laid out by 64-bit and 32-bit client processes */
+struct afd_poll_params_64
+{
+    LONGLONG timeout;
+    unsigned int count;
+    BOOLEAN exclusive;
+    BOOLEAN padding[3];
+    struct afd_poll_socket_64
+    {
+        ULONGLONG socket;
+        int flags;
+        NTSTATUS status;
+    } sockets[1];
+};
+
+struct afd_poll_params_32
+{
+    LONGLONG timeout;
+    unsigned int count;
+    BOOLEAN exclusive;
+    BOOLEAN padding[3];
+    struct afd_poll_socket_32
+    {
+        ULONG socket;
+        int flags;
+        NTSTATUS status;
+    } sockets[1];
+};
 #include <poppack.h>
 
 struct afd_event_select_params
diff --git a/server/protocol.def b/server/protocol.def
index efe0d22cbc4..db73f0418a9 100644
--- a/server/protocol.def
+++ b/server/protocol.def
@@ -1443,31 +1443,6 @@ enum server_fd_type
 @END
 
 
-struct poll_socket_input
-{
-    obj_handle_t socket;        /* socket handle */
-    int flags;                  /* events to poll for */
-};
-
-struct poll_socket_output
-{
-    int flags;                  /* events signaled */
-    unsigned int status;        /* socket status */
-};
-
-/* Perform an async poll on a socket */
-@REQ(poll_socket)
-    int          exclusive;     /* is the poll exclusive? */
-    async_data_t async;         /* async I/O parameters */
-    timeout_t    timeout;       /* timeout */
-    VARARG(sockets,poll_socket_input); /* list of sockets to poll */
-@REPLY
-    obj_handle_t wait;          /* handle to wait on for blocking poll */
-    unsigned int options;       /* file open options */
-    VARARG(sockets,poll_socket_output); /* data returned */
-@END
-
-
 /* Perform a send on a socket */
 @REQ(send_socket)
     async_data_t async;         /* async I/O parameters */
diff --git a/server/sock.c b/server/sock.c
index d674d2a8f84..4f728ef74dc 100644
--- a/server/sock.c
+++ b/server/sock.c
@@ -122,13 +122,16 @@ struct poll_req
     struct async *async;
     struct iosb *iosb;
     struct timeout_user *timeout;
+    timeout_t orig_timeout;
     int exclusive;
     unsigned int count;
-    struct poll_socket_output *output;
     struct
     {
         struct sock *sock;
         int mask;
+        obj_handle_t handle;
+        int flags;
+        unsigned int status;
     } sockets[1];
 };
 
@@ -235,6 +238,8 @@ static int accept_into_socket( struct sock *sock, struct sock *acceptsock );
 static struct sock *accept_socket( struct sock *sock );
 static int sock_get_ntstatus( int err );
 static unsigned int sock_get_error( int err );
+static void poll_socket( struct sock *poll_sock, struct async *async, int exclusive, timeout_t timeout,
+                         unsigned int count, const struct afd_poll_socket_64 *sockets );
 
 static const struct object_ops sock_ops =
 {
@@ -789,7 +794,6 @@ static void free_poll_req( void *private )
     release_object( req->async );
     release_object( req->iosb );
     list_remove( &req->entry );
-    free( req->output );
     free( req );
 }
 
@@ -832,8 +836,7 @@ static int get_poll_flags( struct sock *sock, int event )
 
 static void complete_async_poll( struct poll_req *req, unsigned int status )
 {
-    struct poll_socket_output *output = req->output;
-    unsigned int i;
+    unsigned int i, signaled_count = 0;
 
     for (i = 0; i < req->count; ++i)
     {
@@ -843,9 +846,65 @@ static void complete_async_poll( struct poll_req *req, unsigned int status )
             sock->main_poll = NULL;
     }
 
-    /* pass 0 as result; client will set actual result size */
-    req->output = NULL;
-    async_request_complete( req->async, status, 0, req->count * sizeof(*output), output );
+    if (!status)
+    {
+        for (i = 0; i < req->count; ++i)
+        {
+            if (req->sockets[i].flags)
+                ++signaled_count;
+        }
+    }
+    /* report signaled sockets only on success; the fill loops below must agree with signaled_count */
+    if (is_machine_64bit( async_get_thread( req->async )->process->machine ))
+    {
+        size_t output_size = offsetof( struct afd_poll_params_64, sockets[signaled_count] );
+        struct afd_poll_params_64 *output;
+
+        if (!(output = mem_alloc( output_size )))
+        {
+            async_terminate( req->async, get_error() );
+            return;
+        }
+        memset( output, 0, output_size );
+        output->timeout = req->orig_timeout;
+        output->exclusive = req->exclusive;
+        for (i = 0; i < req->count; ++i)
+        {
+            if (status || !req->sockets[i].flags) continue;
+            output->sockets[output->count].socket = req->sockets[i].handle;
+            output->sockets[output->count].flags = req->sockets[i].flags;
+            output->sockets[output->count].status = req->sockets[i].status;
+            ++output->count;
+        }
+        assert( output->count == signaled_count );
+
+        async_request_complete( req->async, status, output_size, output_size, output );
+    }
+    else
+    {
+        size_t output_size = offsetof( struct afd_poll_params_32, sockets[signaled_count] );
+        struct afd_poll_params_32 *output;
+
+        if (!(output = mem_alloc( output_size )))
+        {
+            async_terminate( req->async, get_error() );
+            return;
+        }
+        memset( output, 0, output_size );
+        output->timeout = req->orig_timeout;
+        output->exclusive = req->exclusive;
+        for (i = 0; i < req->count; ++i)
+        {
+            if (status || !req->sockets[i].flags) continue;
+            output->sockets[output->count].socket = req->sockets[i].handle;
+            output->sockets[output->count].flags = req->sockets[i].flags;
+            output->sockets[output->count].status = req->sockets[i].status;
+            ++output->count;
+        }
+        assert( output->count == signaled_count );
+
+        async_request_complete( req->async, status, output_size, output_size, output );
+    }
 }
 
 static void complete_async_polls( struct sock *sock, int event, int error )
@@ -868,8 +927,8 @@ static void complete_async_polls( struct sock *sock, int event, int error )
                 fprintf( stderr, "completing poll for socket %p, wanted %#x got %#x\n",
                          sock, req->sockets[i].mask, flags );
 
-            req->output[i].flags = req->sockets[i].mask & flags;
-            req->output[i].status = sock_get_ntstatus( error );
+            req->sockets[i].flags = req->sockets[i].mask & flags;
+            req->sockets[i].status = sock_get_ntstatus( error );
 
             complete_async_poll( req, STATUS_SUCCESS );
             break;
@@ -1353,8 +1412,8 @@ static int sock_close_handle( struct object *obj, struct process *process, obj_h
                 if (poll_req->sockets[i].sock == sock)
                 {
                     signaled = TRUE;
-                    poll_req->output[i].flags = AFD_POLL_CLOSE;
-                    poll_req->output[i].status = 0;
+                    poll_req->sockets[i].flags = AFD_POLL_CLOSE;
+                    poll_req->sockets[i].status = 0;
                 }
             }
 
@@ -2849,6 +2908,55 @@ static void sock_ioctl( struct fd *fd, ioctl_code_t code, struct async *async )
         return;
     }
 
+    case IOCTL_AFD_POLL:
+    {
+        if (get_reply_max_size() < get_req_data_size())
+        {
+            set_error( STATUS_INVALID_PARAMETER );
+            return;
+        }
+
+        if (is_machine_64bit( current->process->machine ))
+        {
+            const struct afd_poll_params_64 *params = get_req_data();
+            /* bound count by division so a huge client-supplied count cannot overflow the size check */
+            if (get_req_data_size() < sizeof(struct afd_poll_params_64) ||
+                params->count > (get_req_data_size() - offsetof( struct afd_poll_params_64, sockets )) / sizeof(params->sockets[0]))
+            {
+                set_error( STATUS_INVALID_PARAMETER );
+                return;
+            }
+
+            poll_socket( sock, async, params->exclusive, params->timeout, params->count, params->sockets );
+        }
+        else
+        {
+            const struct afd_poll_params_32 *params = get_req_data();
+            struct afd_poll_socket_64 *sockets;
+            unsigned int i;
+            /* bound count by division so a huge client-supplied count cannot overflow the size check */
+            if (get_req_data_size() < sizeof(struct afd_poll_params_32) ||
+                params->count > (get_req_data_size() - offsetof( struct afd_poll_params_32, sockets )) / sizeof(params->sockets[0]))
+            {
+                set_error( STATUS_INVALID_PARAMETER );
+                return;
+            }
+
+            if (!(sockets = mem_alloc( params->count * sizeof(*sockets) ))) return;
+            for (i = 0; i < params->count; ++i)
+            {
+                sockets[i].socket = params->sockets[i].socket;
+                sockets[i].flags = params->sockets[i].flags;
+                sockets[i].status = params->sockets[i].status;
+            }
+
+            poll_socket( sock, async, params->exclusive, params->timeout, params->count, sockets );
+            free( sockets );
+        }
+
+        return;
+    }
+
     default:
         set_error( STATUS_NOT_SUPPORTED );
         return;
@@ -2899,51 +3007,49 @@ static void handle_exclusive_poll(struct poll_req *req)
 }
 
 static void poll_socket( struct sock *poll_sock, struct async *async, int exclusive, timeout_t timeout,
-                         unsigned int count, const struct poll_socket_input *input )
+                         unsigned int count, const struct afd_poll_socket_64 *sockets )
 {
-    struct poll_socket_output *output;
     BOOL signaled = FALSE;
     struct poll_req *req;
     unsigned int i, j;
 
-    if (!(output = mem_alloc( count * sizeof(*output) )))
+    if (!count)
+    {
+        set_error( STATUS_INVALID_PARAMETER );
         return;
-    memset( output, 0, count * sizeof(*output) );
+    }
 
     if (!(req = mem_alloc( offsetof( struct poll_req, sockets[count] ) )))
-    {
-        free( output );
         return;
-    }
 
     req->timeout = NULL;
     if (timeout && timeout != TIMEOUT_INFINITE &&
         !(req->timeout = add_timeout_user( timeout, async_poll_timeout, req )))
     {
         free( req );
-        free( output );
         return;
     }
+    req->orig_timeout = timeout;
 
     for (i = 0; i < count; ++i)
     {
-        req->sockets[i].sock = (struct sock *)get_handle_obj( current->process, input[i].socket, 0, &sock_ops );
+        req->sockets[i].sock = (struct sock *)get_handle_obj( current->process, sockets[i].socket, 0, &sock_ops );
         if (!req->sockets[i].sock)
         {
-            for (j = 0; j < i; ++j) release_object( req->sockets[i].sock );
+            for (j = 0; j < i; ++j) release_object( req->sockets[j].sock );
             if (req->timeout) remove_timeout_user( req->timeout );
             free( req );
-            free( output );
             return;
         }
-        req->sockets[i].mask = input[i].flags;
+        req->sockets[i].handle = sockets[i].socket;
+        req->sockets[i].mask = sockets[i].flags;
+        req->sockets[i].flags = 0;
     }
 
     req->exclusive = exclusive;
     req->count = count;
     req->async = (struct async *)grab_object( async );
     req->iosb = async_get_iosb( async );
-    req->output = output;
 
     handle_exclusive_poll(req);
 
@@ -2960,16 +3066,16 @@ static void poll_socket( struct sock *poll_sock, struct async *async, int exclus
         if (flags)
         {
             signaled = TRUE;
-            output[i].flags = flags;
-            output[i].status = sock_get_ntstatus( sock_error( sock->fd ) );
+            req->sockets[i].flags = flags;
+            req->sockets[i].status = sock_get_ntstatus( sock_error( sock->fd ) );
         }
 
         /* FIXME: do other error conditions deserve a similar treatment? */
         if (sock->state != SOCK_CONNECTING && sock->errors[AFD_POLL_BIT_CONNECT_ERR] && (mask & AFD_POLL_CONNECT_ERR))
         {
             signaled = TRUE;
-            output[i].flags |= AFD_POLL_CONNECT_ERR;
-            output[i].status = sock_get_ntstatus( sock->errors[AFD_POLL_BIT_CONNECT_ERR] );
+            req->sockets[i].flags |= AFD_POLL_CONNECT_ERR;
+            req->sockets[i].status = sock_get_ntstatus( sock->errors[AFD_POLL_BIT_CONNECT_ERR] );
         }
     }
 
@@ -3335,28 +3441,6 @@ DECL_HANDLER(recv_socket)
     release_object( sock );
 }
 
-DECL_HANDLER(poll_socket)
-{
-    struct sock *sock = (struct sock *)get_handle_obj( current->process, req->async.handle, 0, &sock_ops );
-    const struct poll_socket_input *input = get_req_data();
-    struct async *async;
-    unsigned int count;
-
-    if (!sock) return;
-
-    count = get_req_data_size() / sizeof(*input);
-
-    if ((async = create_request_async( sock->fd, get_fd_comp_flags( sock->fd ), &req->async )))
-    {
-        poll_socket( sock, async, req->exclusive, req->timeout, count, input );
-        reply->wait = async_handoff( async, NULL, 0 );
-        reply->options = get_fd_options( sock->fd );
-        release_object( async );
-    }
-
-    release_object( sock );
-}
-
 DECL_HANDLER(send_socket)
 {
     struct sock *sock = (struct sock *)get_handle_obj( current->process, req->async.handle, 0, &sock_ops );
diff --git a/server/trace.c b/server/trace.c
index f7c9cbd975e..3d0f39cb30b 100644
--- a/server/trace.c
+++ b/server/trace.c
@@ -1394,38 +1394,6 @@ static void dump_varargs_handle_infos( const char *prefix, data_size_t size )
     fputc( '}', stderr );
 }
 
-static void dump_varargs_poll_socket_input( const char *prefix, data_size_t size )
-{
-    const struct poll_socket_input *input;
-
-    fprintf( stderr, "%s{", prefix );
-    while (size >= sizeof(*input))
-    {
-        input = cur_data;
-        fprintf( stderr, "{socket=%04x,flags=%08x}", input->socket, input->flags );
-        size -= sizeof(*input);
-        remove_data( sizeof(*input) );
-        if (size) fputc( ',', stderr );
-    }
-    fputc( '}', stderr );
-}
-
-static void dump_varargs_poll_socket_output( const char *prefix, data_size_t size )
-{
-    const struct poll_socket_output *output;
-
-    fprintf( stderr, "%s{", prefix );
-    while (size >= sizeof(*output))
-    {
-        output = cur_data;
-        fprintf( stderr, "{flags=%08x,status=%s}", output->flags, get_status_name( output->status ) );
-        size -= sizeof(*output);
-        remove_data( sizeof(*output) );
-        if (size) fputc( ',', stderr );
-    }
-    fputc( '}', stderr );
-}
-
 typedef void (*dump_func)( const void *req );
 
 /* Everything below this line is generated automatically by tools/make_requests */
-- 
2.34.1




More information about the wine-devel mailing list