[PATCH] netio.sys: Fill socket addresses when accepting connection.
Paul Gofman
pgofman at codeweavers.com
Wed Jun 24 07:09:02 CDT 2020
Signed-off-by: Paul Gofman <pgofman at codeweavers.com>
---
dlls/netio.sys/netio.c | 35 ++++++++++++++++++++++---------
dlls/ntoskrnl.exe/tests/driver4.c | 18 ++++++++++++++--
2 files changed, 41 insertions(+), 12 deletions(-)
diff --git a/dlls/netio.sys/netio.c b/dlls/netio.sys/netio.c
index dd6d2f48140..b54819664de 100644
--- a/dlls/netio.sys/netio.c
+++ b/dlls/netio.sys/netio.c
@@ -46,6 +46,7 @@ struct _WSK_CLIENT
struct listen_socket_callback_context
{
+ SOCKADDR *local_address;
SOCKADDR *remote_address;
const void *client_dispatch;
void *client_context;
@@ -53,15 +54,6 @@ struct listen_socket_callback_context
SOCKET acceptor;
};
-struct connect_socket_callback_context
-{
- struct wsk_socket_internal *socket;
- SOCKADDR *remote_address;
- const void *client_dispatch;
- void *client_context;
- IRP *pending_irp;
-};
-
#define MAX_PENDING_IO 10
struct wsk_pending_io
@@ -96,6 +88,7 @@ struct wsk_socket_internal
};
static LPFN_ACCEPTEX pAcceptEx;
+static LPFN_GETACCEPTEXSOCKADDRS pGetAcceptExSockaddrs;
static LPFN_CONNECTEX pConnectEx;
static const WSK_PROVIDER_CONNECTION_DISPATCH wsk_provider_connection_dispatch;
@@ -319,6 +312,8 @@ static void create_accept_socket(struct wsk_socket_internal *socket, struct wsk_
{
struct listen_socket_callback_context *context
= &socket->callback_context.listen_socket_callback_context;
+ INT local_address_len, remote_address_len;
+ SOCKADDR *local_address, *remote_address;
struct wsk_socket_internal *accept_socket;
if (!(accept_socket = heap_alloc_zero(sizeof(*accept_socket))))
@@ -338,7 +333,17 @@ static void create_accept_socket(struct wsk_socket_internal *socket, struct wsk_
accept_socket->protocol = socket->protocol;
accept_socket->flags = WSK_FLAG_CONNECTION_SOCKET;
socket_init(accept_socket);
- /* TODO: fill local and remote addresses. */
+
+ pGetAcceptExSockaddrs(context->addr_buffer, 0, sizeof(SOCKADDR) + 16, sizeof(SOCKADDR) + 16,
+ &local_address, &local_address_len, &remote_address, &remote_address_len);
+
+ if (context->local_address)
+ memcpy(context->local_address, local_address,
+ min(sizeof(*context->local_address), local_address_len));
+
+ if (context->remote_address)
+ memcpy(context->remote_address, remote_address,
+ min(sizeof(*context->remote_address), remote_address_len));
dispatch_pending_io(io, STATUS_SUCCESS, (ULONG_PTR)&accept_socket->wsk_socket);
}
@@ -373,6 +378,7 @@ static void WINAPI accept_callback(TP_CALLBACK_INSTANCE *instance, void *socket_
static BOOL WINAPI init_accept_functions(INIT_ONCE *once, void *param, void **context)
{
+ GUID get_acceptex_guid = WSAID_GETACCEPTEXSOCKADDRS;
GUID acceptex_guid = WSAID_ACCEPTEX;
SOCKET s = (SOCKET)param;
DWORD size;
@@ -383,6 +389,14 @@ static BOOL WINAPI init_accept_functions(INIT_ONCE *once, void *param, void **co
ERR("Could not get AcceptEx address, error %u.\n", WSAGetLastError());
return FALSE;
}
+
+ if (WSAIoctl(s, SIO_GET_EXTENSION_FUNCTION_POINTER, &get_acceptex_guid, sizeof(get_acceptex_guid),
+ &pGetAcceptExSockaddrs, sizeof(pGetAcceptExSockaddrs), &size, NULL, NULL))
+ {
+ ERR("Could not get GetAcceptExSockaddrs address, error %u.\n", WSAGetLastError());
+ return FALSE;
+ }
+
return TRUE;
}
@@ -430,6 +444,7 @@ static NTSTATUS WINAPI wsk_accept(WSK_SOCKET *listen_socket, ULONG flags, void *
return STATUS_PENDING;
}
+ context->local_address = local_address;
context->remote_address = remote_address;
context->client_dispatch = accept_socket_dispatch;
context->client_context = accept_socket_context;
diff --git a/dlls/ntoskrnl.exe/tests/driver4.c b/dlls/ntoskrnl.exe/tests/driver4.c
index dad1d6a04fd..48de151f31f 100644
--- a/dlls/ntoskrnl.exe/tests/driver4.c
+++ b/dlls/ntoskrnl.exe/tests/driver4.c
@@ -177,10 +177,10 @@ static void test_wsk_listen_socket(void)
static const WSK_CLIENT_LISTEN_DISPATCH client_listen_dispatch;
const WSK_PROVIDER_CONNECTION_DISPATCH *accept_dispatch;
WSK_SOCKET *tcp_socket, *udp_socket, *accept_socket;
+ struct sockaddr_in addr, local_addr, remote_addr;
struct socket_context context;
WSK_BUF wsk_buf1, wsk_buf2;
void *buffer1, *buffer2;
- struct sockaddr_in addr;
LARGE_INTEGER timeout;
MDL *mdl1, *mdl2;
NTSTATUS status;
@@ -287,7 +287,10 @@ static void test_wsk_listen_socket(void)
IoReuseIrp(wsk_irp, STATUS_UNSUCCESSFUL);
IoSetCompletionRoutine(wsk_irp, irp_completion_routine, &irp_complete_event, TRUE, TRUE, TRUE);
- status = tcp_dispatch->WskAccept(tcp_socket, 0, NULL, NULL, NULL, NULL, wsk_irp);
+ memset(&local_addr, 0, sizeof(local_addr));
+ memset(&remote_addr, 0, sizeof(remote_addr));
+ status = tcp_dispatch->WskAccept(tcp_socket, 0, NULL, NULL,
+ (SOCKADDR *)&local_addr, (SOCKADDR *)&remote_addr, wsk_irp);
ok(status == STATUS_PENDING, "Got unexpected status %#x.\n", status);
if (0)
@@ -306,6 +309,17 @@ static void test_wsk_listen_socket(void)
if (status == STATUS_SUCCESS && wsk_irp->IoStatus.Status == STATUS_SUCCESS)
{
+ ok(local_addr.sin_family == AF_INET, "Got unexpected sin_family %u.\n", local_addr.sin_family);
+ ok(local_addr.sin_port == htons(SERVER_LISTEN_PORT), "Got unexpected sin_port %u.\n",
+ ntohs(local_addr.sin_port));
+ ok(local_addr.sin_addr.s_addr == htonl(0x7f000001), "Got unexpected sin_addr %#x.\n",
+ ntohl(local_addr.sin_addr.s_addr));
+
+ ok(remote_addr.sin_family == AF_INET, "Got unexpected sin_family %u.\n", remote_addr.sin_family);
+ ok(remote_addr.sin_port, "Got zero sin_port.\n");
+ ok(remote_addr.sin_addr.s_addr == htonl(0x7f000001), "Got unexpected sin_addr %#x.\n",
+ ntohl(remote_addr.sin_addr.s_addr));
+
accept_socket = (WSK_SOCKET *)wsk_irp->IoStatus.Information;
accept_dispatch = accept_socket->Dispatch;
--
2.26.2
More information about the wine-devel
mailing list