[PATCH 11/11] server: Handle the entire IOCTL_AFD_POLL ioctl on the server side.
Zebediah Figura
zfigura@codeweavers.com
Thu Dec 9 21:44:01 CST 2021
Signed-off-by: Zebediah Figura <zfigura@codeweavers.com>
---
dlls/ntdll/unix/socket.c | 153 +-------------------------------
dlls/ws2_32/tests/afd.c | 2 +-
include/wine/afd.h | 30 ++++++-
server/protocol.def | 25 ------
server/sock.c | 182 ++++++++++++++++++++++++++++-----------
server/trace.c | 32 -------
6 files changed, 167 insertions(+), 257 deletions(-)
diff --git a/dlls/ntdll/unix/socket.c b/dlls/ntdll/unix/socket.c
index 20368ad6415..92374e39db7 100644
--- a/dlls/ntdll/unix/socket.c
+++ b/dlls/ntdll/unix/socket.c
@@ -770,151 +770,6 @@ static NTSTATUS sock_recv( HANDLE handle, HANDLE event, PIO_APC_ROUTINE apc, voi
}
-struct async_poll_ioctl
-{
- struct async_fileio io;
- unsigned int count;
- struct afd_poll_params *input, *output;
- struct poll_socket_output sockets[1];
-};
-
-static ULONG_PTR fill_poll_output( struct async_poll_ioctl *async, NTSTATUS status )
-{
- struct afd_poll_params *input = async->input, *output = async->output;
- unsigned int i, count = 0;
-
- memcpy( output, input, offsetof( struct afd_poll_params, sockets[0] ) );
-
- if (!status)
- {
- for (i = 0; i < async->count; ++i)
- {
- if (async->sockets[i].flags)
- {
- output->sockets[count].socket = input->sockets[i].socket;
- output->sockets[count].flags = async->sockets[i].flags;
- output->sockets[count].status = async->sockets[i].status;
- ++count;
- }
- }
- }
- output->count = count;
- return offsetof( struct afd_poll_params, sockets[count] );
-}
-
-static BOOL async_poll_proc( void *user, ULONG_PTR *info, NTSTATUS *status )
-{
- struct async_poll_ioctl *async = user;
-
- if (*status == STATUS_ALERTED)
- {
- SERVER_START_REQ( get_async_result )
- {
- req->user_arg = wine_server_client_ptr( async );
- wine_server_set_reply( req, async->sockets, async->count * sizeof(async->sockets[0]) );
- *status = wine_server_call( req );
- }
- SERVER_END_REQ;
-
- *info = fill_poll_output( async, *status );
- }
-
- free( async->input );
- release_fileio( &async->io );
- return TRUE;
-}
-
-
-/* we could handle this ioctl entirely on the server side, but the differing
- * structure size makes it painful */
-static NTSTATUS sock_poll( HANDLE handle, HANDLE event, PIO_APC_ROUTINE apc, void *apc_user, IO_STATUS_BLOCK *io,
- void *in_buffer, ULONG in_size, void *out_buffer, ULONG out_size )
-{
- const struct afd_poll_params *params = in_buffer;
- struct poll_socket_input *input;
- struct async_poll_ioctl *async;
- HANDLE wait_handle;
- DWORD async_size;
- NTSTATUS status;
- unsigned int i;
- ULONG options;
-
- if (in_size < sizeof(*params) || out_size < in_size || !params->count
- || in_size < offsetof( struct afd_poll_params, sockets[params->count] ))
- return STATUS_INVALID_PARAMETER;
-
- TRACE( "timeout %s, count %u, exclusive %#x, padding (%#x, %#x, %#x), sockets[0] {%04lx, %#x}\n",
- wine_dbgstr_longlong(params->timeout), params->count, params->exclusive,
- params->padding[0], params->padding[1], params->padding[2],
- params->sockets[0].socket, params->sockets[0].flags );
-
- if (params->padding[0]) FIXME( "padding[0] is %#x\n", params->padding[0] );
- if (params->padding[1]) FIXME( "padding[1] is %#x\n", params->padding[1] );
- if (params->padding[2]) FIXME( "padding[2] is %#x\n", params->padding[2] );
- for (i = 0; i < params->count; ++i)
- {
- if (params->sockets[i].flags & ~0x1ff)
- FIXME( "unknown socket flags %#x\n", params->sockets[i].flags );
- }
-
- if (!(input = malloc( params->count * sizeof(*input) )))
- return STATUS_NO_MEMORY;
-
- async_size = offsetof( struct async_poll_ioctl, sockets[params->count] );
-
- if (!(async = (struct async_poll_ioctl *)alloc_fileio( async_size, async_poll_proc, handle )))
- {
- free( input );
- return STATUS_NO_MEMORY;
- }
-
- if (!(async->input = malloc( in_size )))
- {
- release_fileio( &async->io );
- free( input );
- return STATUS_NO_MEMORY;
- }
- memcpy( async->input, in_buffer, in_size );
-
- async->count = params->count;
- async->output = out_buffer;
-
- for (i = 0; i < params->count; ++i)
- {
- input[i].socket = params->sockets[i].socket;
- input[i].flags = params->sockets[i].flags;
- }
-
- SERVER_START_REQ( poll_socket )
- {
- req->async = server_async( handle, &async->io, event, apc, apc_user, iosb_client_ptr(io) );
- req->exclusive = !!params->exclusive;
- req->timeout = params->timeout;
- wine_server_add_data( req, input, params->count * sizeof(*input) );
- wine_server_set_reply( req, async->sockets, params->count * sizeof(async->sockets[0]) );
- status = wine_server_call( req );
- wait_handle = wine_server_ptr_handle( reply->wait );
- options = reply->options;
- if (wait_handle && status != STATUS_PENDING)
- {
- io->Status = status;
- io->Information = fill_poll_output( async, status );
- }
- }
- SERVER_END_REQ;
-
- free( input );
-
- if (status != STATUS_PENDING)
- {
- free( async->input );
- release_fileio( &async->io );
- }
-
- if (wait_handle) status = wait_async( wait_handle, (options & FILE_SYNCHRONOUS_IO_ALERT) );
- return status;
-}
-
static NTSTATUS try_send( int fd, struct async_send_ioctl *async )
{
union unix_sockaddr unix_addr;
@@ -1374,6 +1229,10 @@ NTSTATUS sock_ioctl( HANDLE handle, HANDLE event, PIO_APC_ROUTINE apc, void *apc
status = STATUS_BAD_DEVICE_TYPE;
break;
+ case IOCTL_AFD_POLL:
+ status = STATUS_BAD_DEVICE_TYPE; /* the server handles IOCTL_AFD_POLL entirely; see server/sock.c */
+ break;
+
case IOCTL_AFD_RECV:
{
struct afd_recv_params params;
@@ -1516,10 +1375,6 @@ NTSTATUS sock_ioctl( HANDLE handle, HANDLE event, PIO_APC_ROUTINE apc, void *apc
break;
}
- case IOCTL_AFD_POLL:
- status = sock_poll( handle, event, apc, apc_user, io, in_buffer, in_size, out_buffer, out_size );
- break;
-
case IOCTL_AFD_WINE_FIONREAD:
{
int value, ret;
diff --git a/dlls/ws2_32/tests/afd.c b/dlls/ws2_32/tests/afd.c
index 17126fbdf7e..55029bfd30d 100644
--- a/dlls/ws2_32/tests/afd.c
+++ b/dlls/ws2_32/tests/afd.c
@@ -688,7 +688,7 @@ static void test_poll(void)
ret = NtDeviceIoControlFile((HANDLE)client, event, NULL, NULL, &io,
IOCTL_AFD_POLL, in_params, params_size, out_params, params_size);
ok(ret == STATUS_INVALID_HANDLE, "got %#x\n", ret);
- todo_wine ok(!io.Status, "got %#x\n", io.Status);
+ ok(!io.Status, "got %#x\n", io.Status); /* no longer todo_wine: the whole ioctl is now handled by the server */
ok(!io.Information, "got %#Ix\n", io.Information);
/* Test passing the same handle twice. */
diff --git a/include/wine/afd.h b/include/wine/afd.h
index 97128c67ca6..efd5787e90a 100644
--- a/include/wine/afd.h
+++ b/include/wine/afd.h
@@ -122,13 +122,41 @@ struct afd_poll_params
unsigned int count;
BOOLEAN exclusive;
BOOLEAN padding[3];
- struct
+ struct afd_poll_socket
{
SOCKET socket;
int flags;
NTSTATUS status;
} sockets[1];
};
+
+struct afd_poll_params_64 /* IOCTL_AFD_POLL buffer layout for 64-bit client processes */
+{
+ LONGLONG timeout;
+ unsigned int count;
+ BOOLEAN exclusive;
+ BOOLEAN padding[3];
+ struct afd_poll_socket_64
+ {
+ ULONGLONG socket; /* 64-bit SOCKET handle */
+ int flags;
+ NTSTATUS status;
+ } sockets[1];
+};
+
+struct afd_poll_params_32 /* the same buffer as laid out by 32-bit client processes */
+{
+ LONGLONG timeout;
+ unsigned int count;
+ BOOLEAN exclusive;
+ BOOLEAN padding[3];
+ struct afd_poll_socket_32
+ {
+ ULONG socket; /* 32-bit SOCKET handle */
+ int flags;
+ NTSTATUS status;
+ } sockets[1];
+};
#include <poppack.h>
struct afd_event_select_params
diff --git a/server/protocol.def b/server/protocol.def
index efe0d22cbc4..db73f0418a9 100644
--- a/server/protocol.def
+++ b/server/protocol.def
@@ -1443,31 +1443,6 @@ enum server_fd_type
@END
-struct poll_socket_input
-{
- obj_handle_t socket; /* socket handle */
- int flags; /* events to poll for */
-};
-
-struct poll_socket_output
-{
- int flags; /* events signaled */
- unsigned int status; /* socket status */
-};
-
-/* Perform an async poll on a socket */
-@REQ(poll_socket)
- int exclusive; /* is the poll exclusive? */
- async_data_t async; /* async I/O parameters */
- timeout_t timeout; /* timeout */
- VARARG(sockets,poll_socket_input); /* list of sockets to poll */
-@REPLY
- obj_handle_t wait; /* handle to wait on for blocking poll */
- unsigned int options; /* file open options */
- VARARG(sockets,poll_socket_output); /* data returned */
-@END
-
-
/* Perform a send on a socket */
@REQ(send_socket)
async_data_t async; /* async I/O parameters */
diff --git a/server/sock.c b/server/sock.c
index d674d2a8f84..4f728ef74dc 100644
--- a/server/sock.c
+++ b/server/sock.c
@@ -122,13 +122,16 @@ struct poll_req
struct async *async;
struct iosb *iosb;
struct timeout_user *timeout;
+ timeout_t orig_timeout;
int exclusive;
unsigned int count;
- struct poll_socket_output *output;
struct
{
struct sock *sock;
int mask;
+ obj_handle_t handle;
+ int flags;
+ unsigned int status;
} sockets[1];
};
@@ -235,6 +238,8 @@ static int accept_into_socket( struct sock *sock, struct sock *acceptsock );
static struct sock *accept_socket( struct sock *sock );
static int sock_get_ntstatus( int err );
static unsigned int sock_get_error( int err );
+static void poll_socket( struct sock *poll_sock, struct async *async, int exclusive, timeout_t timeout,
+ unsigned int count, const struct afd_poll_socket_64 *sockets ); /* needed by sock_ioctl(), which precedes the definition */
static const struct object_ops sock_ops =
{
@@ -789,7 +794,6 @@ static void free_poll_req( void *private )
release_object( req->async );
release_object( req->iosb );
list_remove( &req->entry );
- free( req->output );
free( req );
}
@@ -832,8 +836,7 @@ static int get_poll_flags( struct sock *sock, int event )
static void complete_async_poll( struct poll_req *req, unsigned int status )
{
- struct poll_socket_output *output = req->output;
- unsigned int i;
+ unsigned int i, signaled_count = 0; /* number of sockets reported in the reply */
for (i = 0; i < req->count; ++i)
{
@@ -843,9 +846,65 @@ static void complete_async_poll( struct poll_req *req, unsigned int status )
sock->main_poll = NULL;
}
- /* pass 0 as result; client will set actual result size */
- req->output = NULL;
- async_request_complete( req->async, status, 0, req->count * sizeof(*output), output );
+ if (!status) /* a non-success status reports no sockets */
+ {
+ for (i = 0; i < req->count; ++i)
+ {
+ if (req->sockets[i].flags)
+ ++signaled_count;
+ }
+ }
+ /* Build the reply in the caller's layout: afd_poll_params_64 for 64-bit processes, afd_poll_params_32 otherwise. */
+ if (is_machine_64bit( async_get_thread( req->async )->process->machine ))
+ {
+ size_t output_size = offsetof( struct afd_poll_params_64, sockets[signaled_count] );
+ struct afd_poll_params_64 *output;
+
+ if (!(output = mem_alloc( output_size )))
+ {
+ async_terminate( req->async, get_error() );
+ return;
+ }
+ memset( output, 0, output_size ); /* also zeroes count and padding */
+ output->timeout = req->orig_timeout;
+ output->exclusive = req->exclusive;
+ for (i = 0; i < req->count && output->count < signaled_count; ++i) /* never write past output_size */
+ {
+ if (!req->sockets[i].flags) continue;
+ output->sockets[output->count].socket = req->sockets[i].handle;
+ output->sockets[output->count].flags = req->sockets[i].flags;
+ output->sockets[output->count].status = req->sockets[i].status;
+ ++output->count;
+ }
+ assert( output->count == signaled_count );
+
+ async_request_complete( req->async, status, output_size, output_size, output ); /* output is passed on, not freed here */
+ }
+ else
+ {
+ size_t output_size = offsetof( struct afd_poll_params_32, sockets[signaled_count] );
+ struct afd_poll_params_32 *output;
+
+ if (!(output = mem_alloc( output_size )))
+ {
+ async_terminate( req->async, get_error() );
+ return;
+ }
+ memset( output, 0, output_size ); /* also zeroes count and padding */
+ output->timeout = req->orig_timeout;
+ output->exclusive = req->exclusive;
+ for (i = 0; i < req->count && output->count < signaled_count; ++i) /* never write past output_size */
+ {
+ if (!req->sockets[i].flags) continue;
+ output->sockets[output->count].socket = req->sockets[i].handle;
+ output->sockets[output->count].flags = req->sockets[i].flags;
+ output->sockets[output->count].status = req->sockets[i].status;
+ ++output->count;
+ }
+ assert( output->count == signaled_count );
+
+ async_request_complete( req->async, status, output_size, output_size, output ); /* output is passed on, not freed here */
+ }
}
static void complete_async_polls( struct sock *sock, int event, int error )
@@ -868,8 +927,8 @@ static void complete_async_polls( struct sock *sock, int event, int error )
fprintf( stderr, "completing poll for socket %p, wanted %#x got %#x\n",
sock, req->sockets[i].mask, flags );
- req->output[i].flags = req->sockets[i].mask & flags;
- req->output[i].status = sock_get_ntstatus( error );
+ req->sockets[i].flags = req->sockets[i].mask & flags;
+ req->sockets[i].status = sock_get_ntstatus( error );
complete_async_poll( req, STATUS_SUCCESS );
break;
@@ -1353,8 +1412,8 @@ static int sock_close_handle( struct object *obj, struct process *process, obj_h
if (poll_req->sockets[i].sock == sock)
{
signaled = TRUE;
- poll_req->output[i].flags = AFD_POLL_CLOSE;
- poll_req->output[i].status = 0;
+ poll_req->sockets[i].flags = AFD_POLL_CLOSE;
+ poll_req->sockets[i].status = 0;
}
}
@@ -2849,6 +2908,55 @@ static void sock_ioctl( struct fd *fd, ioctl_code_t code, struct async *async )
return;
}
+ case IOCTL_AFD_POLL:
+ {
+ if (get_reply_max_size() < get_req_data_size())
+ {
+ set_error( STATUS_INVALID_PARAMETER );
+ return;
+ }
+
+ if (is_machine_64bit( current->process->machine ))
+ {
+ const struct afd_poll_params_64 *params = get_req_data();
+ /* count is client-controlled: bound it by division so the size check cannot wrap */
+ if (get_req_data_size() < sizeof(struct afd_poll_params_64) || !params->count ||
+ params->count > (get_req_data_size() - offsetof( struct afd_poll_params_64, sockets )) / sizeof(params->sockets[0]))
+ {
+ set_error( STATUS_INVALID_PARAMETER );
+ return;
+ }
+
+ poll_socket( sock, async, params->exclusive, params->timeout, params->count, params->sockets );
+ }
+ else
+ {
+ const struct afd_poll_params_32 *params = get_req_data();
+ struct afd_poll_socket_64 *sockets;
+ unsigned int i;
+ /* same overflow-safe bound; rejecting count == 0 up front also keeps mem_alloc() away from a zero size */
+ if (get_req_data_size() < sizeof(struct afd_poll_params_32) || !params->count ||
+ params->count > (get_req_data_size() - offsetof( struct afd_poll_params_32, sockets )) / sizeof(params->sockets[0]))
+ {
+ set_error( STATUS_INVALID_PARAMETER );
+ return;
+ }
+ /* widen the 32-bit entries to the afd_poll_socket_64 layout poll_socket() expects */
+ if (!(sockets = mem_alloc( params->count * sizeof(*sockets) ))) return;
+ for (i = 0; i < params->count; ++i)
+ {
+ sockets[i].socket = params->sockets[i].socket;
+ sockets[i].flags = params->sockets[i].flags;
+ sockets[i].status = params->sockets[i].status;
+ }
+
+ poll_socket( sock, async, params->exclusive, params->timeout, params->count, sockets );
+ free( sockets );
+ }
+
+ return;
+ }
+
default:
set_error( STATUS_NOT_SUPPORTED );
return;
@@ -2899,51 +3007,49 @@ static void handle_exclusive_poll(struct poll_req *req)
}
static void poll_socket( struct sock *poll_sock, struct async *async, int exclusive, timeout_t timeout,
- unsigned int count, const struct poll_socket_input *input )
+ unsigned int count, const struct afd_poll_socket_64 *sockets )
{
- struct poll_socket_output *output;
BOOL signaled = FALSE;
struct poll_req *req;
unsigned int i, j;
- if (!(output = mem_alloc( count * sizeof(*output) )))
+ if (!count)
+ {
+ set_error( STATUS_INVALID_PARAMETER );
return;
- memset( output, 0, count * sizeof(*output) );
+ }
if (!(req = mem_alloc( offsetof( struct poll_req, sockets[count] ) )))
- {
- free( output );
return;
- }
req->timeout = NULL;
if (timeout && timeout != TIMEOUT_INFINITE &&
!(req->timeout = add_timeout_user( timeout, async_poll_timeout, req )))
{
free( req );
- free( output );
return;
}
+ req->orig_timeout = timeout;
for (i = 0; i < count; ++i)
{
- req->sockets[i].sock = (struct sock *)get_handle_obj( current->process, input[i].socket, 0, &sock_ops );
+ req->sockets[i].sock = (struct sock *)get_handle_obj( current->process, sockets[i].socket, 0, &sock_ops );
if (!req->sockets[i].sock)
{
- for (j = 0; j < i; ++j) release_object( req->sockets[i].sock );
+ for (j = 0; j < i; ++j) release_object( req->sockets[j].sock ); /* drop the references taken so far; [i] is NULL */
if (req->timeout) remove_timeout_user( req->timeout );
free( req );
- free( output );
return;
}
- req->sockets[i].mask = input[i].flags;
+ req->sockets[i].handle = sockets[i].socket;
+ req->sockets[i].mask = sockets[i].flags;
+ req->sockets[i].flags = 0;
}
req->exclusive = exclusive;
req->count = count;
req->async = (struct async *)grab_object( async );
req->iosb = async_get_iosb( async );
- req->output = output;
handle_exclusive_poll(req);
@@ -2960,16 +3066,16 @@ static void poll_socket( struct sock *poll_sock, struct async *async, int exclus
if (flags)
{
signaled = TRUE;
- output[i].flags = flags;
- output[i].status = sock_get_ntstatus( sock_error( sock->fd ) );
+ req->sockets[i].flags = flags;
+ req->sockets[i].status = sock_get_ntstatus( sock_error( sock->fd ) );
}
/* FIXME: do other error conditions deserve a similar treatment? */
if (sock->state != SOCK_CONNECTING && sock->errors[AFD_POLL_BIT_CONNECT_ERR] && (mask & AFD_POLL_CONNECT_ERR))
{
signaled = TRUE;
- output[i].flags |= AFD_POLL_CONNECT_ERR;
- output[i].status = sock_get_ntstatus( sock->errors[AFD_POLL_BIT_CONNECT_ERR] );
+ req->sockets[i].flags |= AFD_POLL_CONNECT_ERR;
+ req->sockets[i].status = sock_get_ntstatus( sock->errors[AFD_POLL_BIT_CONNECT_ERR] );
}
}
@@ -3335,28 +3441,6 @@ DECL_HANDLER(recv_socket)
release_object( sock );
}
-DECL_HANDLER(poll_socket)
-{
- struct sock *sock = (struct sock *)get_handle_obj( current->process, req->async.handle, 0, &sock_ops );
- const struct poll_socket_input *input = get_req_data();
- struct async *async;
- unsigned int count;
-
- if (!sock) return;
-
- count = get_req_data_size() / sizeof(*input);
-
- if ((async = create_request_async( sock->fd, get_fd_comp_flags( sock->fd ), &req->async )))
- {
- poll_socket( sock, async, req->exclusive, req->timeout, count, input );
- reply->wait = async_handoff( async, NULL, 0 );
- reply->options = get_fd_options( sock->fd );
- release_object( async );
- }
-
- release_object( sock );
-}
-
DECL_HANDLER(send_socket)
{
struct sock *sock = (struct sock *)get_handle_obj( current->process, req->async.handle, 0, &sock_ops );
diff --git a/server/trace.c b/server/trace.c
index f7c9cbd975e..3d0f39cb30b 100644
--- a/server/trace.c
+++ b/server/trace.c
@@ -1394,38 +1394,6 @@ static void dump_varargs_handle_infos( const char *prefix, data_size_t size )
fputc( '}', stderr );
}
-static void dump_varargs_poll_socket_input( const char *prefix, data_size_t size )
-{
- const struct poll_socket_input *input;
-
- fprintf( stderr, "%s{", prefix );
- while (size >= sizeof(*input))
- {
- input = cur_data;
- fprintf( stderr, "{socket=%04x,flags=%08x}", input->socket, input->flags );
- size -= sizeof(*input);
- remove_data( sizeof(*input) );
- if (size) fputc( ',', stderr );
- }
- fputc( '}', stderr );
-}
-
-static void dump_varargs_poll_socket_output( const char *prefix, data_size_t size )
-{
- const struct poll_socket_output *output;
-
- fprintf( stderr, "%s{", prefix );
- while (size >= sizeof(*output))
- {
- output = cur_data;
- fprintf( stderr, "{flags=%08x,status=%s}", output->flags, get_status_name( output->status ) );
- size -= sizeof(*output);
- remove_data( sizeof(*output) );
- if (size) fputc( ',', stderr );
- }
- fputc( '}', stderr );
-}
-
typedef void (*dump_func)( const void *req );
/* Everything below this line is generated automatically by tools/make_requests */
--
2.34.1
More information about the wine-devel
mailing list