[PATCH 1/2 v5] ntdll/socket: Implement exclusive flag for IOCTL_AFD_POLL.

Zebediah Figura (she/her) zfigura at codeweavers.com
Thu Sep 2 10:45:51 CDT 2021


On 9/1/21 10:52 AM, Guillaume Charifi wrote:
> v2-v5: Fix some corner cases.
> 
> Signed-off-by: Guillaume Charifi <guillaume.charifi at sfr.fr>
> ---
>   dlls/ntdll/unix/socket.c       |  6 +--
>   include/wine/afd.h             |  2 +-
>   include/wine/server_protocol.h |  5 ++-
>   server/protocol.def            |  1 +
>   server/request.h               |  1 +
>   server/sock.c                  | 75 ++++++++++++++++++++++++++++++++--
>   server/trace.c                 |  3 +-
>   7 files changed, 83 insertions(+), 10 deletions(-)
> 
> diff --git a/dlls/ntdll/unix/socket.c b/dlls/ntdll/unix/socket.c
> index 8469def786a..b1b90b363ee 100644
> --- a/dlls/ntdll/unix/socket.c
> +++ b/dlls/ntdll/unix/socket.c
> @@ -747,12 +747,11 @@ static NTSTATUS sock_poll( HANDLE handle, HANDLE event, PIO_APC_ROUTINE apc, voi
>               || in_size < offsetof( struct afd_poll_params, sockets[params->count] ))
>           return STATUS_INVALID_PARAMETER;
>   
> -    TRACE( "timeout %s, count %u, unknown %#x, padding (%#x, %#x, %#x), sockets[0] {%04lx, %#x}\n",
> -            wine_dbgstr_longlong(params->timeout), params->count, params->unknown,
> +    TRACE( "timeout %s, count %u, exclusive %u, padding (%#x, %#x, %#x), sockets[0] {%04lx, %#x}\n",
> +            wine_dbgstr_longlong(params->timeout), params->count, params->exclusive,
>               params->padding[0], params->padding[1], params->padding[2],
>               params->sockets[0].socket, params->sockets[0].flags );
>   
> -    if (params->unknown) FIXME( "unknown boolean is %#x\n", params->unknown );
>       if (params->padding[0]) FIXME( "padding[0] is %#x\n", params->padding[0] );
>       if (params->padding[1]) FIXME( "padding[1] is %#x\n", params->padding[1] );
>       if (params->padding[2]) FIXME( "padding[2] is %#x\n", params->padding[2] );
> @@ -793,6 +792,7 @@ static NTSTATUS sock_poll( HANDLE handle, HANDLE event, PIO_APC_ROUTINE apc, voi
>       SERVER_START_REQ( poll_socket )
>       {
>           req->async = server_async( handle, &async->io, event, apc, apc_user, iosb_client_ptr(io) );
> +        req->exclusive = !!params->exclusive;
>           req->timeout = params->timeout;
>           wine_server_add_data( req, input, params->count * sizeof(*input) );
>           wine_server_set_reply( req, async->sockets, params->count * sizeof(async->sockets[0]) );
> diff --git a/include/wine/afd.h b/include/wine/afd.h
> index e67ecae25a9..f4682f464e8 100644
> --- a/include/wine/afd.h
> +++ b/include/wine/afd.h
> @@ -104,7 +104,7 @@ struct afd_poll_params
>   {
>       LONGLONG timeout;
>       unsigned int count;
> -    BOOLEAN unknown;
> +    BOOLEAN exclusive;
>       BOOLEAN padding[3];
>       struct
>       {
> diff --git a/include/wine/server_protocol.h b/include/wine/server_protocol.h
> index bb4862c9669..14518f0bf6f 100644
> --- a/include/wine/server_protocol.h
> +++ b/include/wine/server_protocol.h
> @@ -1761,7 +1761,8 @@ struct poll_socket_output
>   struct poll_socket_request
>   {
>       struct request_header __header;
> -    char __pad_12[4];
> +    char         exclusive;
> +    char __pad_13[3];
>       async_data_t async;
>       timeout_t    timeout;
>       /* VARARG(sockets,poll_socket_input); */
> @@ -6252,7 +6253,7 @@ union generic_reply
>   
>   /* ### protocol_version begin ### */
>   
> -#define SERVER_PROTOCOL_VERSION 726
> +#define SERVER_PROTOCOL_VERSION 727
>   
>   /* ### protocol_version end ### */
>   
> diff --git a/server/protocol.def b/server/protocol.def
> index 133d6ad0552..8eefdaae17f 100644
> --- a/server/protocol.def
> +++ b/server/protocol.def
> @@ -1450,6 +1450,7 @@ struct poll_socket_output
>   
>   /* Perform an async poll on a socket */
>   @REQ(poll_socket)
> +    char         exclusive;

If you look at the generated server/trace.c output, declaring the field 
as char makes the dump function print it with %c, which isn't really what 
we want for a boolean flag. int is fine here.

>       async_data_t async;         /* async I/O parameters */
>       timeout_t    timeout;       /* timeout */
>       VARARG(sockets,poll_socket_input); /* list of sockets to poll */
> diff --git a/server/request.h b/server/request.h
> index 18a6a2df3c7..f1c8f1c5432 100644
> --- a/server/request.h
> +++ b/server/request.h
> @@ -1045,6 +1045,7 @@ C_ASSERT( sizeof(struct recv_socket_request) == 64 );
>   C_ASSERT( FIELD_OFFSET(struct recv_socket_reply, wait) == 8 );
>   C_ASSERT( FIELD_OFFSET(struct recv_socket_reply, options) == 12 );
>   C_ASSERT( sizeof(struct recv_socket_reply) == 16 );
> +C_ASSERT( FIELD_OFFSET(struct poll_socket_request, exclusive) == 12 );
>   C_ASSERT( FIELD_OFFSET(struct poll_socket_request, async) == 16 );
>   C_ASSERT( FIELD_OFFSET(struct poll_socket_request, timeout) == 56 );
>   C_ASSERT( sizeof(struct poll_socket_request) == 64 );
> diff --git a/server/sock.c b/server/sock.c
> index 50bfc08e145..7cdbed940c6 100644
> --- a/server/sock.c
> +++ b/server/sock.c
> @@ -128,6 +128,7 @@ struct poll_req
>       struct async *async;
>       struct iosb *iosb;
>       struct timeout_user *timeout;
> +    char exclusive;
>       unsigned int count;
>       struct poll_socket_output *output;
>       struct
> @@ -2853,8 +2854,8 @@ static int poll_single_socket( struct sock *sock, int mask )
>       return get_poll_flags( sock, pollfd.revents ) & mask;
>   }
>   
> -static int poll_socket( struct sock *poll_sock, struct async *async, timeout_t timeout,
> -                        unsigned int count, const struct poll_socket_input *input )
> +static int poll_socket( struct sock *poll_sock, struct async *async, char exclusive,
> +                        timeout_t timeout, unsigned int count, const struct poll_socket_input *input )
>   {
>       struct poll_socket_output *output;
>       struct poll_req *req;
> @@ -2893,11 +2894,79 @@ static int poll_socket( struct sock *poll_sock, struct async *async, timeout_t t
>           req->sockets[i].flags = input[i].flags;
>       }
>   
> +    req->exclusive = exclusive;
>       req->count = count;
>       req->async = (struct async *)grab_object( async );
>       req->iosb = async_get_iosb( async );
>       req->output = output;
>   
> +    if (exclusive)
> +    {
> +        int has_shared = FALSE;
> +        struct poll_req *areq;
> +
> +        LIST_FOR_EACH_ENTRY( areq, &poll_list, struct poll_req, entry )
> +        {
> +            int skip_areq = FALSE;
> +
> +            for (i = 0; i < areq->count; ++i)
> +            {
> +                struct sock *asock = areq->sockets[i].sock;
> +
> +                for (j = 0; j < req->count; ++j)
> +                {
> +                    if (asock == req->sockets[j].sock)
> +                    {
> +                        if (!areq->exclusive)
> +                            has_shared = TRUE;
> +
> +                        skip_areq = TRUE;
> +                        break;
> +                    }
> +                }
> +
> +                if (skip_areq || has_shared)
> +                    break;
> +            }
> +
> +            if (has_shared)
> +                break;
> +        }
> +
> +        LIST_FOR_EACH_ENTRY( areq, &poll_list, struct poll_req, entry )
> +        {
> +            int areq_term = FALSE;
> +
> +            if (has_shared)
> +                break;

Wait, so we don't wake up *any* exclusive pollers if there are *any* 
shared pollers on one of the same sockets? This definitely deserves a 
test.

> +
> +            for (i = 0; i < areq->count; ++i)
> +            {
> +                struct sock *asock = areq->sockets[i].sock;
> +
> +                for (j = 0; j < req->count; ++j)
> +                {
> +                    if (asock == req->sockets[j].sock)
> +                    {
> +                        areq_term = TRUE;
> +                        break;
> +                    }
> +                }
> +
> +                if (areq_term)
> +                    break;
> +            }
> +
> +            if (areq_term)
> +            {
> +                areq->iosb->status = STATUS_SUCCESS;
> +                areq->iosb->out_data = areq->output;
> +                areq->iosb->out_size = areq->count * sizeof(*areq->output);
> +                async_terminate( areq->async, STATUS_ALERTED );
> +            }
> +        }

These loops are not particularly easy to read, and would probably be a 
lot better with a helper function.

> +    }
> +
>       list_add_tail( &poll_list, &req->entry );
>       async_set_completion_callback( async, free_poll_req, req );
>       queue_async( &poll_sock->poll_q, async );
> @@ -3312,7 +3381,7 @@ DECL_HANDLER(poll_socket)
>   
>       if ((async = create_request_async( sock->fd, get_fd_comp_flags( sock->fd ), &req->async )))
>       {
> -        reply->wait = async_handoff( async, poll_socket( sock, async, req->timeout, count, input ), NULL, 0 );
> +        reply->wait = async_handoff( async, poll_socket( sock, async, req->exclusive, req->timeout, count, input ), NULL, 0 );
>           reply->options = get_fd_options( sock->fd );
>           release_object( async );
>       }
> diff --git a/server/trace.c b/server/trace.c
> index 4e91d933a5d..567c54c4df7 100644
> --- a/server/trace.c
> +++ b/server/trace.c
> @@ -2122,7 +2122,8 @@ static void dump_recv_socket_reply( const struct recv_socket_reply *req )
>   
>   static void dump_poll_socket_request( const struct poll_socket_request *req )
>   {
> -    dump_async_data( " async=", &req->async );
> +    fprintf( stderr, " exclusive=%c", req->exclusive );
> +    dump_async_data( ", async=", &req->async );
>       dump_timeout( ", timeout=", &req->timeout );
>       dump_varargs_poll_socket_input( ", sockets=", cur_size );
>   }
> 



More information about the wine-devel mailing list