[PATCH 4/4] ws2_32/tests: Add tests for IOCTL_AFD_POLL.

Zebediah Figura z.figura12 at gmail.com
Mon May 24 23:13:42 CDT 2021


Signed-off-by: Zebediah Figura <z.figura12 at gmail.com>
---
 dlls/ws2_32/tests/afd.c | 683 ++++++++++++++++++++++++++++++++++++++++
 1 file changed, 683 insertions(+)

diff --git a/dlls/ws2_32/tests/afd.c b/dlls/ws2_32/tests/afd.c
index 48b177ee845..13d0f3942d4 100644
--- a/dlls/ws2_32/tests/afd.c
+++ b/dlls/ws2_32/tests/afd.c
@@ -31,6 +31,41 @@
 #include "wine/afd.h"
 #include "wine/test.h"
 
+/* Create a connected pair of loopback TCP sockets: *src is the connecting
+ * end and *dst the accepted end. The temporary listener is closed again. */
+static void tcp_socketpair(SOCKET *src, SOCKET *dst)
+{
+    struct sockaddr_in sa = {.sin_family = AF_INET, .sin_addr.s_addr = htonl(INADDR_LOOPBACK)};
+    SOCKET listener;
+    int addrlen, err;
+
+    *src = WSASocketW(AF_INET, SOCK_STREAM, IPPROTO_TCP, NULL, 0, WSA_FLAG_OVERLAPPED);
+    ok(*src != INVALID_SOCKET, "failed to create socket, error %u\n", WSAGetLastError());
+
+    listener = WSASocketW(AF_INET, SOCK_STREAM, IPPROTO_TCP, NULL, 0, WSA_FLAG_OVERLAPPED);
+    ok(listener != INVALID_SOCKET, "failed to create socket, error %u\n", WSAGetLastError());
+
+    /* sin_port is 0, so bind() chooses a free port; read it back for connect(). */
+    err = bind(listener, (struct sockaddr *)&sa, sizeof(sa));
+    ok(!err, "failed to bind socket, error %u\n", WSAGetLastError());
+    addrlen = sizeof(sa);
+    err = getsockname(listener, (struct sockaddr *)&sa, &addrlen);
+    ok(!err, "failed to get address, error %u\n", WSAGetLastError());
+
+    err = listen(listener, 1);
+    ok(!err, "failed to listen, error %u\n", WSAGetLastError());
+
+    /* The listen backlog queues this connection until accept() below. */
+    err = connect(*src, (struct sockaddr *)&sa, sizeof(sa));
+    ok(!err, "failed to connect, error %u\n", WSAGetLastError());
+
+    addrlen = sizeof(sa);
+    *dst = accept(listener, (struct sockaddr *)&sa, &addrlen);
+    ok(*dst != INVALID_SOCKET, "failed to accept socket, error %u\n", WSAGetLastError());
+
+    closesocket(listener);
+}
+
 static void set_blocking(SOCKET s, ULONG blocking)
 {
     int ret;
@@ -76,6 +111,652 @@ static void test_open_device(void)
     closesocket(s);
 }
 
+/* Poll socket "s" for every event (1 second timeout) and check that exactly "expect" is reported. */
+#define check_poll(a, b, c) check_poll_(__LINE__, a, b, c, FALSE)
+#define check_poll_todo(a, b, c) check_poll_(__LINE__, a, b, c, TRUE)
+static void check_poll_(int line, SOCKET s, HANDLE event, int expect, BOOL todo)
+{
+    struct afd_poll_params in_params = {.timeout = -1000 * 10000, .count = 1,
+            .sockets[0] = {.socket = s, .flags = ~0, .status = 0xdeadbeef}};
+    struct afd_poll_params out_params = {0};
+    IO_STATUS_BLOCK io;
+    NTSTATUS ret;
+
+    ret = NtDeviceIoControlFile((HANDLE)s, event, NULL, NULL, &io,
+            IOCTL_AFD_POLL, &in_params, sizeof(in_params), &out_params, sizeof(out_params));
+    ok_(__FILE__, line)(!ret, "got %#x\n", ret);
+    /* Don't return while a pending request can still write to io/out_params on this stack. */
+    if (ret == STATUS_PENDING)
+        WaitForSingleObject(event, 2000);
+    ok_(__FILE__, line)(!io.Status, "got %#x\n", io.Status);
+    ok_(__FILE__, line)(io.Information == offsetof(struct afd_poll_params, sockets[1]), "got %#Ix\n", io.Information);
+    ok_(__FILE__, line)(out_params.timeout == in_params.timeout, "got timeout %I64d\n", out_params.timeout);
+    ok_(__FILE__, line)(out_params.count == 1, "got count %u\n", out_params.count);
+    ok_(__FILE__, line)(out_params.sockets[0].socket == s, "got socket %#Ix\n", out_params.sockets[0].socket);
+    todo_wine_if (todo) ok_(__FILE__, line)(out_params.sockets[0].flags == expect, "got flags %#x\n", out_params.sockets[0].flags);
+    ok_(__FILE__, line)(!out_params.sockets[0].status, "got status %#x\n", out_params.sockets[0].status);
+}
+
+/* Tests for IOCTL_AFD_POLL on listening, connected, shut-down and UDP sockets, plus argument validation. */
+static void test_poll(void)
+{
+    const struct sockaddr_in bind_addr = {.sin_family = AF_INET, .sin_addr.s_addr = htonl(INADDR_LOOPBACK)};
+    /* Room for up to three sockets; the unions give the buffers the alignment of the struct. */
+    union { struct afd_poll_params p; char b[offsetof(struct afd_poll_params, sockets[3])]; } in_buffer;
+    union { struct afd_poll_params p; char b[offsetof(struct afd_poll_params, sockets[3])]; } out_buffer;
+    struct afd_poll_params *in_params = &in_buffer.p;
+    struct afd_poll_params *out_params = &out_buffer.p;
+    int large_buffer_size = 1024 * 1024;
+    SOCKET client, server, listener;
+    struct sockaddr_in addr;
+    char *large_buffer;
+    IO_STATUS_BLOCK io;
+    LARGE_INTEGER now;
+    ULONG params_size;
+    HANDLE event;
+    int ret, len, oob_inline = 1;
+
+    large_buffer = calloc(1, large_buffer_size);
+    memset(&in_buffer, 0, sizeof(in_buffer));
+    memset(&out_buffer, 0, sizeof(out_buffer));
+    event = CreateEventW(NULL, TRUE, FALSE, NULL);
+
+    listener = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
+    ret = bind(listener, (const struct sockaddr *)&bind_addr, sizeof(bind_addr));
+    ok(!ret, "got error %u\n", WSAGetLastError());
+    ret = listen(listener, 1);
+    ok(!ret, "got error %u\n", WSAGetLastError());
+    len = sizeof(addr);
+    ret = getsockname(listener, (struct sockaddr *)&addr, &len);
+    ok(!ret, "got error %u\n", WSAGetLastError());
+
+    params_size = offsetof(struct afd_poll_params, sockets[1]);
+    in_params->count = 1;
+
+    /* Argument validation: out_size must be at least as large as in_size. Requests that pass the
+     * size checks fail with STATUS_INVALID_HANDLE, presumably because sockets[0].socket is still 0. */
+    ret = NtDeviceIoControlFile((HANDLE)listener, event, NULL, NULL, &io,
+            IOCTL_AFD_POLL, in_params, params_size, NULL, 0);
+    ok(ret == STATUS_INVALID_PARAMETER, "got %#x\n", ret);
+
+    ret = NtDeviceIoControlFile((HANDLE)listener, event, NULL, NULL, &io,
+            IOCTL_AFD_POLL, NULL, 0, out_params, params_size);
+    ok(ret == STATUS_INVALID_PARAMETER, "got %#x\n", ret);
+
+    ret = NtDeviceIoControlFile((HANDLE)listener, event, NULL, NULL, &io,
+            IOCTL_AFD_POLL, in_params, params_size, out_params, params_size + 1);
+    ok(ret == STATUS_INVALID_HANDLE, "got %#x\n", ret);
+
+    ret = NtDeviceIoControlFile((HANDLE)listener, event, NULL, NULL, &io,
+            IOCTL_AFD_POLL, in_params, params_size + 1, out_params, params_size);
+    ok(ret == STATUS_INVALID_PARAMETER, "got %#x\n", ret);
+
+    ret = NtDeviceIoControlFile((HANDLE)listener, event, NULL, NULL, &io,
+            IOCTL_AFD_POLL, in_params, params_size - 1, out_params, params_size - 1);
+    ok(ret == STATUS_INVALID_PARAMETER, "got %#x\n", ret);
+
+    ret = NtDeviceIoControlFile((HANDLE)listener, event, NULL, NULL, &io,
+            IOCTL_AFD_POLL, in_params, params_size + 1, out_params, params_size + 1);
+    ok(ret == STATUS_INVALID_HANDLE, "got %#x\n", ret);
+
+    in_params->count = 0;
+    ret = NtDeviceIoControlFile((HANDLE)listener, event, NULL, NULL, &io,
+            IOCTL_AFD_POLL, in_params, params_size, out_params, params_size);
+    ok(ret == STATUS_INVALID_PARAMETER, "got %#x\n", ret);
+
+    /* Basic semantics of the ioctl. */
+
+    in_params->timeout = 0;
+    in_params->count = 1;
+    in_params->sockets[0].socket = listener;
+    in_params->sockets[0].flags = ~0;
+    in_params->sockets[0].status = 0xdeadbeef;
+
+    memset(out_params, 0, params_size);
+    ret = NtDeviceIoControlFile((HANDLE)listener, event, NULL, NULL, &io,
+            IOCTL_AFD_POLL, in_params, params_size, out_params, params_size);
+    ok(!ret, "got %#x\n", ret);
+    ok(!io.Status, "got %#x\n", io.Status);
+    ok(io.Information == offsetof(struct afd_poll_params, sockets[0]), "got %#Ix\n", io.Information);
+    ok(!out_params->timeout, "got timeout %#I64x\n", out_params->timeout);
+    ok(!out_params->count, "got count %u\n", out_params->count);
+    ok(!out_params->sockets[0].socket, "got socket %#Ix\n", out_params->sockets[0].socket);
+    ok(!out_params->sockets[0].flags, "got flags %#x\n", out_params->sockets[0].flags);
+    ok(!out_params->sockets[0].status, "got status %#x\n", out_params->sockets[0].status);
+
+    NtQuerySystemTime(&now);
+    in_params->timeout = now.QuadPart;
+
+    ret = NtDeviceIoControlFile((HANDLE)listener, event, NULL, NULL, &io,
+            IOCTL_AFD_POLL, in_params, params_size, out_params, params_size);
+    ok(ret == STATUS_PENDING, "got %#x\n", ret);
+    ret = WaitForSingleObject(event, 100);
+    ok(!ret, "got %#x\n", ret);
+    ok(io.Status == STATUS_TIMEOUT, "got %#x\n", io.Status);
+    ok(io.Information == offsetof(struct afd_poll_params, sockets[0]), "got %#Ix\n", io.Information);
+    ok(out_params->timeout == now.QuadPart, "got timeout %#I64x\n", out_params->timeout);
+    ok(!out_params->count, "got count %u\n", out_params->count);
+
+    in_params->timeout = -1000 * 10000;
+
+    ret = NtDeviceIoControlFile((HANDLE)listener, event, NULL, NULL, &io,
+            IOCTL_AFD_POLL, in_params, params_size, out_params, params_size);
+    ok(ret == STATUS_PENDING, "got %#x\n", ret);
+
+    client = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
+    set_blocking(client, FALSE);
+    ret = connect(client, (struct sockaddr *)&addr, sizeof(addr));
+    ok(!ret || WSAGetLastError() == WSAEWOULDBLOCK, "got error %u\n", WSAGetLastError());
+
+    ret = WaitForSingleObject(event, 100);
+    ok(!ret, "got %#x\n", ret);
+    ok(!io.Status, "got %#x\n", io.Status);
+    ok(io.Information == offsetof(struct afd_poll_params, sockets[1]), "got %#Ix\n", io.Information);
+    ok(out_params->timeout == -1000 * 10000, "got timeout %#I64x\n", out_params->timeout);
+    ok(out_params->count == 1, "got count %u\n", out_params->count);
+    ok(out_params->sockets[0].socket == listener, "got socket %#Ix\n", out_params->sockets[0].socket);
+    ok(out_params->sockets[0].flags == AFD_POLL_ACCEPT, "got flags %#x\n", out_params->sockets[0].flags);
+    ok(!out_params->sockets[0].status, "got status %#x\n", out_params->sockets[0].status);
+
+    ret = NtDeviceIoControlFile((HANDLE)listener, event, NULL, NULL, &io,
+            IOCTL_AFD_POLL, in_params, params_size, out_params, params_size);
+    ok(!ret, "got %#x\n", ret);
+    ok(!io.Status, "got %#x\n", io.Status);
+    ok(io.Information == offsetof(struct afd_poll_params, sockets[1]), "got %#Ix\n", io.Information);
+    ok(out_params->timeout == -1000 * 10000, "got timeout %#I64x\n", out_params->timeout);
+    ok(out_params->count == 1, "got count %u\n", out_params->count);
+    ok(out_params->sockets[0].socket == listener, "got socket %#Ix\n", out_params->sockets[0].socket);
+    ok(out_params->sockets[0].flags == AFD_POLL_ACCEPT, "got flags %#x\n", out_params->sockets[0].flags);
+    ok(!out_params->sockets[0].status, "got status %#x\n", out_params->sockets[0].status);
+
+    in_params->timeout = now.QuadPart;
+    in_params->sockets[0].flags = (~0) & ~AFD_POLL_ACCEPT;
+
+    ret = NtDeviceIoControlFile((HANDLE)listener, event, NULL, NULL, &io,
+            IOCTL_AFD_POLL, in_params, params_size, out_params, params_size);
+    ok(ret == STATUS_PENDING, "got %#x\n", ret);
+    ret = WaitForSingleObject(event, 100);
+    ok(!ret, "got %#x\n", ret);
+    ok(io.Status == STATUS_TIMEOUT, "got %#x\n", io.Status);
+    ok(io.Information == offsetof(struct afd_poll_params, sockets[0]), "got %#Ix\n", io.Information);
+    ok(!out_params->count, "got count %u\n", out_params->count);
+
+    server = accept(listener, NULL, NULL);
+    ok(server != INVALID_SOCKET, "got error %u\n", WSAGetLastError());
+    set_blocking(server, FALSE);
+
+    /* Test flags exposed by connected sockets. */
+
+    check_poll(client, event, AFD_POLL_WRITE | AFD_POLL_CONNECT);
+    check_poll(server, event, AFD_POLL_WRITE | AFD_POLL_CONNECT);
+
+    /* It is valid to poll on a socket other than the one passed to
+     * NtDeviceIoControlFile(). */
+
+    in_params->count = 1;
+    in_params->sockets[0].socket = server;
+    in_params->sockets[0].flags = ~0;
+
+    ret = NtDeviceIoControlFile((HANDLE)listener, event, NULL, NULL, &io,
+            IOCTL_AFD_POLL, in_params, params_size, out_params, params_size);
+    ok(!ret, "got %#x\n", ret);
+    ok(!io.Status, "got %#x\n", io.Status);
+    ok(io.Information == offsetof(struct afd_poll_params, sockets[1]), "got %#Ix\n", io.Information);
+    ok(out_params->count == 1, "got count %u\n", out_params->count);
+    ok(out_params->sockets[0].socket == server, "got socket %#Ix\n", out_params->sockets[0].socket);
+    ok(out_params->sockets[0].flags == (AFD_POLL_WRITE | AFD_POLL_CONNECT), "got flags %#x\n", out_params->sockets[0].flags);
+    ok(!out_params->sockets[0].status, "got status %#x\n", out_params->sockets[0].status);
+
+    /* Test sending data. */
+
+    ret = send(server, "data", 5, 0);
+    ok(ret == 5, "got %d\n", ret);
+
+    check_poll(client, event, AFD_POLL_WRITE | AFD_POLL_CONNECT | AFD_POLL_READ);
+    check_poll(server, event, AFD_POLL_WRITE | AFD_POLL_CONNECT);
+
+    while (send(server, large_buffer, large_buffer_size, 0) == large_buffer_size);
+
+    check_poll(client, event, AFD_POLL_WRITE | AFD_POLL_CONNECT | AFD_POLL_READ);
+    check_poll(server, event, AFD_POLL_CONNECT);
+
+    /* Test sending out-of-band data. */
+
+    ret = send(client, "a", 1, MSG_OOB);
+    ok(ret == 1, "got %d\n", ret);
+
+    check_poll(client, event, AFD_POLL_WRITE | AFD_POLL_CONNECT | AFD_POLL_READ);
+    check_poll(server, event, AFD_POLL_CONNECT | AFD_POLL_OOB);
+
+    ret = recv(server, large_buffer, 1, MSG_OOB);
+    ok(ret == 1, "got %d\n", ret);
+
+    check_poll(client, event, AFD_POLL_WRITE | AFD_POLL_CONNECT | AFD_POLL_READ);
+    check_poll(server, event, AFD_POLL_CONNECT);
+
+    /* With SO_OOBINLINE set, urgent data is reported as AFD_POLL_READ rather than AFD_POLL_OOB. */
+    ret = setsockopt(server, SOL_SOCKET, SO_OOBINLINE, (char *)&oob_inline, sizeof(oob_inline));
+    ok(!ret, "got error %u\n", WSAGetLastError());
+
+    ret = send(client, "a", 1, MSG_OOB);
+    ok(ret == 1, "got %d\n", ret);
+
+    check_poll(client, event, AFD_POLL_WRITE | AFD_POLL_CONNECT | AFD_POLL_READ);
+    check_poll(server, event, AFD_POLL_CONNECT | AFD_POLL_READ);
+
+    closesocket(client);
+    closesocket(server);
+
+    /* Test shutdown. */
+
+    client = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
+    ret = connect(client, (struct sockaddr *)&addr, sizeof(addr));
+    ok(!ret, "got error %u\n", WSAGetLastError());
+    server = accept(listener, NULL, NULL);
+    ok(server != INVALID_SOCKET, "got error %u\n", WSAGetLastError());
+
+    ret = shutdown(client, SD_RECEIVE);
+    ok(!ret, "got error %u\n", WSAGetLastError());
+
+    check_poll(client, event, AFD_POLL_WRITE | AFD_POLL_CONNECT);
+    check_poll(server, event, AFD_POLL_WRITE | AFD_POLL_CONNECT);
+
+    ret = shutdown(client, SD_SEND);
+    ok(!ret, "got error %u\n", WSAGetLastError());
+
+    check_poll(client, event, AFD_POLL_WRITE | AFD_POLL_CONNECT);
+    check_poll(server, event, AFD_POLL_WRITE | AFD_POLL_CONNECT | AFD_POLL_HUP);
+
+    closesocket(client);
+    closesocket(server);
+
+    /* Test shutdown with data in the pipe. */
+
+    client = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
+    ret = connect(client, (struct sockaddr *)&addr, sizeof(addr));
+    ok(!ret, "got error %u\n", WSAGetLastError());
+    server = accept(listener, NULL, NULL);
+    ok(server != INVALID_SOCKET, "got error %u\n", WSAGetLastError());
+
+    ret = send(client, "data", 5, 0);
+    ok(ret == 5, "got %d\n", ret);
+
+    check_poll(client, event, AFD_POLL_WRITE | AFD_POLL_CONNECT);
+    check_poll(server, event, AFD_POLL_WRITE | AFD_POLL_CONNECT | AFD_POLL_READ);
+
+    ret = shutdown(client, SD_SEND);
+    ok(!ret, "got error %u\n", WSAGetLastError());
+
+    check_poll(client, event, AFD_POLL_WRITE | AFD_POLL_CONNECT);
+    check_poll_todo(server, event, AFD_POLL_WRITE | AFD_POLL_CONNECT | AFD_POLL_READ | AFD_POLL_HUP);
+
+    /* Test closing a socket while polling on it. Note that AFD_POLL_CLOSE
+     * is always returned, regardless of whether it's polled for. */
+
+    in_params->timeout = -1000 * 10000;
+    in_params->count = 1;
+    in_params->sockets[0].socket = client;
+    in_params->sockets[0].flags = 0;
+
+    ret = NtDeviceIoControlFile((HANDLE)client, event, NULL, NULL, &io,
+            IOCTL_AFD_POLL, in_params, params_size, out_params, params_size);
+    ok(ret == STATUS_PENDING, "got %#x\n", ret);
+
+    closesocket(client);
+
+    ret = WaitForSingleObject(event, 100);
+    ok(!ret, "got %#x\n", ret);
+    ok(!io.Status, "got %#x\n", io.Status);
+    ok(io.Information == offsetof(struct afd_poll_params, sockets[1]), "got %#Ix\n", io.Information);
+    ok(out_params->count == 1, "got count %u\n", out_params->count);
+    ok(out_params->sockets[0].socket == client, "got socket %#Ix\n", out_params->sockets[0].socket);
+    ok(out_params->sockets[0].flags == AFD_POLL_CLOSE, "got flags %#x\n", out_params->sockets[0].flags);
+    ok(!out_params->sockets[0].status, "got status %#x\n", out_params->sockets[0].status);
+
+    closesocket(server);
+
+    /* Test a failed connection.
+     *
+     * The following poll works even where the equivalent WSAPoll() call fails.
+     * However, it can take over 2 seconds to complete on the testbot. */
+
+    if (winetest_interactive)
+    {
+        client = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
+        set_blocking(client, FALSE);
+
+        in_params->timeout = -10000 * 10000;
+        in_params->count = 1;
+        in_params->sockets[0].socket = client;
+        in_params->sockets[0].flags = ~0;
+
+        ret = NtDeviceIoControlFile((HANDLE)client, event, NULL, NULL, &io,
+                IOCTL_AFD_POLL, in_params, params_size, out_params, params_size);
+        ok(ret == STATUS_PENDING, "got %#x\n", ret);
+
+        addr.sin_port = 255;
+        ret = connect(client, (struct sockaddr *)&addr, sizeof(addr));
+        ok(!ret || WSAGetLastError() == WSAEWOULDBLOCK, "got error %u\n", WSAGetLastError());
+
+        ret = WaitForSingleObject(event, 10000);
+        ok(!ret, "got %#x\n", ret);
+        ok(!io.Status, "got %#x\n", io.Status);
+        ok(io.Information == offsetof(struct afd_poll_params, sockets[1]), "got %#Ix\n", io.Information);
+        ok(out_params->count == 1, "got count %u\n", out_params->count);
+        ok(out_params->sockets[0].socket == client, "got socket %#Ix\n", out_params->sockets[0].socket);
+        ok(out_params->sockets[0].flags == AFD_POLL_CONNECT_ERR, "got flags %#x\n", out_params->sockets[0].flags);
+        ok(out_params->sockets[0].status == STATUS_CONNECTION_REFUSED, "got status %#x\n", out_params->sockets[0].status);
+
+        closesocket(client);
+    }
+
+    /* Test supplying multiple handles to the ioctl. */
+
+    len = sizeof(addr);
+    ret = getsockname(listener, (struct sockaddr *)&addr, &len);
+    ok(!ret, "got error %u\n", WSAGetLastError());
+
+    client = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
+    ret = connect(client, (struct sockaddr *)&addr, sizeof(addr));
+    ok(!ret, "got error %u\n", WSAGetLastError());
+    server = accept(listener, NULL, NULL);
+    ok(server != INVALID_SOCKET, "got error %u\n", WSAGetLastError());
+
+    in_params->count = 2;
+    in_params->sockets[0].socket = client;
+    in_params->sockets[0].flags = AFD_POLL_READ;
+    in_params->sockets[1].socket = server;
+    in_params->sockets[1].flags = AFD_POLL_READ;
+
+    ret = NtDeviceIoControlFile((HANDLE)client, event, NULL, NULL, &io,
+            IOCTL_AFD_POLL, in_params, params_size, out_params, params_size);
+    ok(ret == STATUS_INVALID_PARAMETER, "got %#x\n", ret);
+
+    params_size = offsetof(struct afd_poll_params, sockets[2]);
+
+    ret = NtDeviceIoControlFile((HANDLE)client, event, NULL, NULL, &io,
+            IOCTL_AFD_POLL, in_params, params_size, out_params, params_size);
+    ok(ret == STATUS_PENDING, "got %#x\n", ret);
+
+    ret = send(client, "data", 5, 0);
+    ok(ret == 5, "got %d\n", ret);
+
+    ret = WaitForSingleObject(event, 100);
+    ok(!ret, "got %#x\n", ret);
+    ok(!io.Status, "got %#x\n", io.Status);
+    ok(io.Information == offsetof(struct afd_poll_params, sockets[1]), "got %#Ix\n", io.Information);
+    ok(out_params->count == 1, "got count %u\n", out_params->count);
+    ok(out_params->sockets[0].socket == server, "got socket %#Ix\n", out_params->sockets[0].socket);
+    ok(out_params->sockets[0].flags == AFD_POLL_READ, "got flags %#x\n", out_params->sockets[0].flags);
+    ok(!out_params->sockets[0].status, "got status %#x\n", out_params->sockets[0].status);
+
+    in_params->count = 2;
+    in_params->sockets[0].socket = client;
+    in_params->sockets[0].flags = AFD_POLL_READ | AFD_POLL_WRITE;
+    in_params->sockets[1].socket = server;
+    in_params->sockets[1].flags = AFD_POLL_READ | AFD_POLL_WRITE;
+
+    ret = NtDeviceIoControlFile((HANDLE)client, event, NULL, NULL, &io,
+            IOCTL_AFD_POLL, in_params, params_size, out_params, params_size);
+    ok(!ret, "got %#x\n", ret);
+    ok(!io.Status, "got %#x\n", io.Status);
+    ok(io.Information == offsetof(struct afd_poll_params, sockets[2]), "got %#Ix\n", io.Information);
+    ok(out_params->count == 2, "got count %u\n", out_params->count);
+    ok(out_params->sockets[0].socket == client, "got socket %#Ix\n", out_params->sockets[0].socket);
+    ok(out_params->sockets[0].flags == AFD_POLL_WRITE, "got flags %#x\n", out_params->sockets[0].flags);
+    ok(!out_params->sockets[0].status, "got status %#x\n", out_params->sockets[0].status);
+    ok(out_params->sockets[1].socket == server, "got socket %#Ix\n", out_params->sockets[1].socket);
+    ok(out_params->sockets[1].flags == (AFD_POLL_READ | AFD_POLL_WRITE),
+            "got flags %#x\n", out_params->sockets[1].flags);
+    ok(!out_params->sockets[1].status, "got status %#x\n", out_params->sockets[1].status);
+    /* Repeating the identical request reports the same events again. */
+    in_params->count = 2;
+    in_params->sockets[0].socket = client;
+    in_params->sockets[0].flags = AFD_POLL_READ | AFD_POLL_WRITE;
+    in_params->sockets[1].socket = server;
+    in_params->sockets[1].flags = AFD_POLL_READ | AFD_POLL_WRITE;
+
+    ret = NtDeviceIoControlFile((HANDLE)client, event, NULL, NULL, &io,
+            IOCTL_AFD_POLL, in_params, params_size, out_params, params_size);
+    ok(!ret, "got %#x\n", ret);
+    ok(!io.Status, "got %#x\n", io.Status);
+    ok(io.Information == offsetof(struct afd_poll_params, sockets[2]), "got %#Ix\n", io.Information);
+    ok(out_params->count == 2, "got count %u\n", out_params->count);
+    ok(out_params->sockets[0].socket == client, "got socket %#Ix\n", out_params->sockets[0].socket);
+    ok(out_params->sockets[0].flags == AFD_POLL_WRITE, "got flags %#x\n", out_params->sockets[0].flags);
+    ok(!out_params->sockets[0].status, "got status %#x\n", out_params->sockets[0].status);
+    ok(out_params->sockets[1].socket == server, "got socket %#Ix\n", out_params->sockets[1].socket);
+    ok(out_params->sockets[1].flags == (AFD_POLL_READ | AFD_POLL_WRITE),
+            "got flags %#x\n", out_params->sockets[1].flags);
+    ok(!out_params->sockets[1].status, "got status %#x\n", out_params->sockets[1].status);
+
+    /* Close a socket while polling on another. */
+
+    in_params->timeout = -100 * 10000;
+    in_params->count = 1;
+    in_params->sockets[0].socket = client;
+    in_params->sockets[0].flags = AFD_POLL_READ;
+    params_size = offsetof(struct afd_poll_params, sockets[1]);
+
+    ret = NtDeviceIoControlFile((HANDLE)server, event, NULL, NULL, &io,
+            IOCTL_AFD_POLL, in_params, params_size, out_params, params_size);
+    ok(ret == STATUS_PENDING, "got %#x\n", ret);
+
+    closesocket(server);
+
+    ret = WaitForSingleObject(event, 1000);
+    ok(!ret, "got %#x\n", ret);
+    todo_wine ok(io.Status == STATUS_TIMEOUT, "got %#x\n", io.Status);
+    todo_wine ok(io.Information == offsetof(struct afd_poll_params, sockets[0]), "got %#Ix\n", io.Information);
+    todo_wine ok(!out_params->count, "got count %u\n", out_params->count);
+
+    closesocket(client);
+
+    closesocket(listener);
+
+    /* Test UDP sockets. */
+
+    client = socket(AF_INET, SOCK_DGRAM, IPPROTO_UDP);
+    server = socket(AF_INET, SOCK_DGRAM, IPPROTO_UDP);
+
+    check_poll(client, event, AFD_POLL_WRITE);
+    check_poll(server, event, AFD_POLL_WRITE);
+
+    ret = bind(client, (const struct sockaddr *)&bind_addr, sizeof(bind_addr));
+    ok(!ret, "got error %u\n", WSAGetLastError());
+    len = sizeof(addr);
+    ret = getsockname(client, (struct sockaddr *)&addr, &len);
+    ok(!ret, "got error %u\n", WSAGetLastError());
+
+    check_poll(client, event, AFD_POLL_WRITE);
+    check_poll(server, event, AFD_POLL_WRITE);
+
+    in_params->timeout = -1000 * 10000;
+    in_params->count = 1;
+    in_params->sockets[0].socket = client;
+    in_params->sockets[0].flags = (~0) & ~AFD_POLL_WRITE;
+    params_size = offsetof(struct afd_poll_params, sockets[1]);
+
+    ret = NtDeviceIoControlFile((HANDLE)client, event, NULL, NULL, &io,
+            IOCTL_AFD_POLL, in_params, params_size, out_params, params_size);
+    ok(ret == STATUS_PENDING, "got %#x\n", ret);
+
+    ret = sendto(server, "data", 5, 0, (struct sockaddr *)&addr, sizeof(addr));
+    ok(ret == 5, "got %d\n", ret);
+
+    ret = WaitForSingleObject(event, 100);
+    ok(!ret, "got %#x\n", ret);
+    ok(!io.Status, "got %#x\n", io.Status);
+    ok(io.Information == offsetof(struct afd_poll_params, sockets[1]), "got %#Ix\n", io.Information);
+    ok(out_params->count == 1, "got count %u\n", out_params->count);
+    ok(out_params->sockets[0].socket == client, "got socket %#Ix\n", out_params->sockets[0].socket);
+    ok(out_params->sockets[0].flags == AFD_POLL_READ, "got flags %#x\n", out_params->sockets[0].flags);
+    ok(!out_params->sockets[0].status, "got status %#x\n", out_params->sockets[0].status);
+
+    closesocket(client);
+    closesocket(server);
+
+    /* Passing any invalid sockets yields STATUS_INVALID_HANDLE.
+     *
+     * Note however that WSAPoll() happily accepts invalid sockets. It seems
+     * user-side cached data is used: closing a handle with CloseHandle() before
+     * passing it to WSAPoll() yields ENOTSOCK. */
+
+    tcp_socketpair(&client, &server);
+
+    in_params->count = 2;
+    in_params->sockets[0].socket = 0xabacab;
+    in_params->sockets[0].flags = AFD_POLL_READ | AFD_POLL_WRITE;
+    in_params->sockets[1].socket = client;
+    in_params->sockets[1].flags = AFD_POLL_READ | AFD_POLL_WRITE;
+    params_size = offsetof(struct afd_poll_params, sockets[2]);
+
+    memset(&io, 0, sizeof(io));
+    ret = NtDeviceIoControlFile((HANDLE)client, event, NULL, NULL, &io,
+            IOCTL_AFD_POLL, in_params, params_size, out_params, params_size);
+    ok(ret == STATUS_INVALID_HANDLE, "got %#x\n", ret);
+    todo_wine ok(!io.Status, "got %#x\n", io.Status);
+    ok(!io.Information, "got %#Ix\n", io.Information);
+
+    /* Test passing the same handle twice. */
+
+    in_params->count = 3;
+    in_params->sockets[0].socket = client;
+    in_params->sockets[0].flags = AFD_POLL_READ | AFD_POLL_WRITE;
+    in_params->sockets[1].socket = client;
+    in_params->sockets[1].flags = AFD_POLL_READ | AFD_POLL_WRITE;
+    in_params->sockets[2].socket = client;
+    in_params->sockets[2].flags = AFD_POLL_READ | AFD_POLL_WRITE | AFD_POLL_CONNECT;
+    params_size = offsetof(struct afd_poll_params, sockets[3]);
+
+    ret = NtDeviceIoControlFile((HANDLE)client, event, NULL, NULL, &io,
+            IOCTL_AFD_POLL, in_params, params_size, out_params, params_size);
+    ok(!ret, "got %#x\n", ret);
+    ok(!io.Status, "got %#x\n", io.Status);
+    ok(io.Information == offsetof(struct afd_poll_params, sockets[3]), "got %#Ix\n", io.Information);
+    ok(out_params->count == 3, "got count %u\n", out_params->count);
+    ok(out_params->sockets[0].socket == client, "got socket %#Ix\n", out_params->sockets[0].socket);
+    ok(out_params->sockets[0].flags == AFD_POLL_WRITE, "got flags %#x\n", out_params->sockets[0].flags);
+    ok(!out_params->sockets[0].status, "got status %#x\n", out_params->sockets[0].status);
+    ok(out_params->sockets[1].socket == client, "got socket %#Ix\n", out_params->sockets[1].socket);
+    ok(out_params->sockets[1].flags == AFD_POLL_WRITE, "got flags %#x\n", out_params->sockets[1].flags);
+    ok(!out_params->sockets[1].status, "got status %#x\n", out_params->sockets[1].status);
+    ok(out_params->sockets[2].socket == client, "got socket %#Ix\n", out_params->sockets[2].socket);
+    ok(out_params->sockets[2].flags == (AFD_POLL_WRITE | AFD_POLL_CONNECT),
+            "got flags %#x\n", out_params->sockets[2].flags);
+    ok(!out_params->sockets[2].status, "got status %#x\n", out_params->sockets[2].status);
+
+    in_params->count = 2;
+    in_params->sockets[0].socket = client;
+    in_params->sockets[0].flags = AFD_POLL_READ;
+    in_params->sockets[1].socket = client;
+    in_params->sockets[1].flags = AFD_POLL_READ;
+    params_size = offsetof(struct afd_poll_params, sockets[2]);
+
+    ret = NtDeviceIoControlFile((HANDLE)client, event, NULL, NULL, &io,
+            IOCTL_AFD_POLL, in_params, params_size, out_params, params_size);
+    ok(ret == STATUS_PENDING, "got %#x\n", ret);
+
+    ret = send(server, "data", 5, 0);
+    ok(ret == 5, "got %d\n", ret);
+
+    ret = WaitForSingleObject(event, 100);
+    ok(!ret, "got %#x\n", ret);
+    ok(!io.Status, "got %#x\n", io.Status);
+    ok(io.Information == offsetof(struct afd_poll_params, sockets[1]), "got %#Ix\n", io.Information);
+    ok(out_params->count == 1, "got count %u\n", out_params->count);
+    ok(out_params->sockets[0].socket == client, "got socket %#Ix\n", out_params->sockets[0].socket);
+    ok(out_params->sockets[0].flags == AFD_POLL_READ, "got flags %#x\n", out_params->sockets[0].flags);
+    ok(!out_params->sockets[0].status, "got status %#x\n", out_params->sockets[0].status);
+
+    closesocket(client);
+    closesocket(server);
+
+    CloseHandle(event);
+    free(large_buffer);
+}
+
+/* Check which IOCTL_AFD_POLL requests post a packet to an I/O completion port. */
+static void test_poll_completion_port(void)
+{
+    IO_STATUS_BLOCK iosb;
+    ULONG_PTR completion_key, completion_value;
+    struct afd_poll_params poll_params = {0};
+    LARGE_INTEGER no_wait = {0};
+    HANDLE event, port;
+    SOCKET client, server;
+    int res;
+
+    event = CreateEventW(NULL, TRUE, FALSE, NULL);
+    tcp_socketpair(&client, &server);
+    port = CreateIoCompletionPort((HANDLE)client, NULL, 0, 0);
+
+    poll_params.timeout = -100 * 10000;
+    poll_params.count = 1;
+    poll_params.sockets[0].socket = client;
+    poll_params.sockets[0].flags = AFD_POLL_WRITE;
+    poll_params.sockets[0].status = 0xdeadbeef;
+    /* No APC context: even a request that completes immediately posts nothing to the port. */
+    res = NtDeviceIoControlFile((HANDLE)client, event, NULL, NULL, &iosb,
+            IOCTL_AFD_POLL, &poll_params, sizeof(poll_params), &poll_params, sizeof(poll_params));
+    ok(!res, "got %#x\n", res);
+
+    res = NtRemoveIoCompletion(port, &completion_key, &completion_value, &iosb, &no_wait);
+    ok(res == STATUS_TIMEOUT, "got %#x\n", res);
+    /* With an APC context the same request posts a packet whose value is that context. */
+    res = NtDeviceIoControlFile((HANDLE)client, event, NULL, (void *)0xdeadbeef, &iosb,
+            IOCTL_AFD_POLL, &poll_params, sizeof(poll_params), &poll_params, sizeof(poll_params));
+    ok(!res, "got %#x\n", res);
+    res = NtRemoveIoCompletion(port, &completion_key, &completion_value, &iosb, &no_wait);
+    ok(!res, "got %#x\n", res);
+    ok(!completion_key, "got key %#Ix\n", completion_key);
+    ok(completion_value == 0xdeadbeef, "got value %#Ix\n", completion_value);
+
+    poll_params.timeout = 0;
+    poll_params.count = 1;
+    poll_params.sockets[0].socket = client;
+    poll_params.sockets[0].flags = AFD_POLL_READ;
+    poll_params.sockets[0].status = 0xdeadbeef;
+    /* A zero timeout also completes at once, and still posts a packet. */
+    res = NtDeviceIoControlFile((HANDLE)client, event, NULL, (void *)0xdeadbeef, &iosb,
+            IOCTL_AFD_POLL, &poll_params, sizeof(poll_params), &poll_params, sizeof(poll_params));
+    ok(!res, "got %#x\n", res);
+
+    res = NtRemoveIoCompletion(port, &completion_key, &completion_value, &iosb, &no_wait);
+    ok(!res, "got %#x\n", res);
+    ok(!completion_key, "got key %#Ix\n", completion_key);
+    ok(completion_value == 0xdeadbeef, "got value %#Ix\n", completion_value);
+
+    /* Close the socket the request was issued on while it polls the other one;
+     * the poll is still expected to end with STATUS_TIMEOUT and post a packet. */
+    poll_params.timeout = -100 * 10000;
+    poll_params.count = 1;
+    poll_params.sockets[0].socket = server;
+    poll_params.sockets[0].flags = AFD_POLL_READ;
+    poll_params.sockets[0].status = 0xdeadbeef;
+
+    res = NtDeviceIoControlFile((HANDLE)client, event, NULL, (void *)0xdeadbeef, &iosb,
+            IOCTL_AFD_POLL, &poll_params, sizeof(poll_params), &poll_params, sizeof(poll_params));
+    ok(res == STATUS_PENDING, "got %#x\n", res);
+
+    closesocket(client);
+
+    res = WaitForSingleObject(event, 1000);
+    ok(!res, "got %#x\n", res);
+    todo_wine ok(iosb.Status == STATUS_TIMEOUT, "got %#x\n", iosb.Status);
+    todo_wine ok(iosb.Information == offsetof(struct afd_poll_params, sockets[0]), "got %#Ix\n", iosb.Information);
+    todo_wine ok(!poll_params.count, "got count %u\n", poll_params.count);
+
+    res = NtRemoveIoCompletion(port, &completion_key, &completion_value, &iosb, &no_wait);
+    ok(!res, "got %#x\n", res);
+    ok(!completion_key, "got key %#Ix\n", completion_key);
+    ok(completion_value == 0xdeadbeef, "got value %#Ix\n", completion_value);
+
+    CloseHandle(port);
+    closesocket(server);
+    CloseHandle(event);
+}
+
 static void test_recv(void)
 {
     const struct sockaddr_in bind_addr = {.sin_family = AF_INET, .sin_addr.s_addr = htonl(INADDR_LOOPBACK)};
@@ -435,6 +1116,8 @@ START_TEST(afd)
     WSAStartup(MAKEWORD(2, 2), &data);
 
     test_open_device();
+    test_poll();
+    test_poll_completion_port();
     test_recv();
 
     WSACleanup();
-- 
2.30.2




More information about the wine-devel mailing list