[PATCH 5/5] ws2_32/tests: Add some more tests for reset TCP connections.

Zebediah Figura zfigura at codeweavers.com
Tue May 3 17:56:28 CDT 2022


Signed-off-by: Zebediah Figura <zfigura at codeweavers.com>
---
 dlls/ws2_32/tests/afd.c  | 114 +++++++++++++++++++++++++++-
 dlls/ws2_32/tests/sock.c | 155 +++++++++++++++++++++++++++++++--------
 2 files changed, 237 insertions(+), 32 deletions(-)

diff --git a/dlls/ws2_32/tests/afd.c b/dlls/ws2_32/tests/afd.c
index 184c7e0725a..cf27b4003c5 100644
--- a/dlls/ws2_32/tests/afd.c
+++ b/dlls/ws2_32/tests/afd.c
@@ -76,6 +76,21 @@ static void set_blocking(SOCKET s, ULONG blocking)
     ok(!ret, "got error %u\n", WSAGetLastError());
 }
 
+/* Enable SO_LINGER with a zero timeout, then close the socket. This aborts the
+ * connection, so the peer sees an RST (not a FIN) on Windows and Unix alike. */
+static void close_with_rst(SOCKET s)
+{
+    static const struct linger linger = {.l_onoff = 1}; /* l_linger is implicitly 0 */
+    int ret;
+
+    SetLastError(0xdeadbeef); /* sentinel: checks that a successful setsockopt() clears the last error */
+    ret = setsockopt(s, SOL_SOCKET, SO_LINGER, (const char *)&linger, sizeof(linger));
+    ok(!ret, "got %d\n", ret);
+    ok(!GetLastError(), "got error %lu\n", GetLastError());
+
+    closesocket(s); /* s must not be used (or closed again) by the caller after this */
+}
+
 static void test_open_device(void)
 {
     OBJECT_BASIC_INFORMATION info;
@@ -142,7 +157,8 @@ static void check_poll_(int line, SOCKET s, HANDLE event, int mask, int expect,
     ok_(__FILE__, line)(out_params.count == 1, "got count %u\n", out_params.count);
     ok_(__FILE__, line)(out_params.sockets[0].socket == s, "got socket %#Ix\n", out_params.sockets[0].socket);
     todo_wine_if (todo) ok_(__FILE__, line)(out_params.sockets[0].flags == expect, "got flags %#x\n", out_params.sockets[0].flags);
-    ok_(__FILE__, line)(!out_params.sockets[0].status, "got status %#x\n", out_params.sockets[0].status);
+    todo_wine_if (expect & AFD_POLL_RESET)
+        ok_(__FILE__, line)(!out_params.sockets[0].status, "got status %#x\n", out_params.sockets[0].status);
 }
 
 static void test_poll(void)
@@ -1311,6 +1327,50 @@ static void test_poll_completion_port(void)
     CloseHandle(event);
 }
 
+static void test_poll_reset(void) /* a pending IOCTL_AFD_POLL must complete when the peer resets the connection */
+{
+    char in_buffer[offsetof(struct afd_poll_params, sockets[3])];
+    char out_buffer[offsetof(struct afd_poll_params, sockets[3])];
+    struct afd_poll_params *in_params = (struct afd_poll_params *)in_buffer;
+    struct afd_poll_params *out_params = (struct afd_poll_params *)out_buffer;
+    SOCKET client, server;
+    IO_STATUS_BLOCK io;
+    ULONG params_size;
+    HANDLE event;
+    int ret;
+
+    memset(in_buffer, 0, sizeof(in_buffer));
+    memset(out_buffer, 0, sizeof(out_buffer));
+    event = CreateEventW(NULL, TRUE, FALSE, NULL);
+    tcp_socketpair(&client, &server);
+
+    in_params->timeout = -1000 * 10000; /* negative = relative NT timeout in 100 ns units, i.e. 1 second */
+    in_params->count = 1;
+    in_params->sockets[0].socket = client;
+    in_params->sockets[0].flags = ~(AFD_POLL_WRITE | AFD_POLL_CONNECT); /* exclude events presumably already signaled on a connected socket, so the poll pends */
+    params_size = offsetof(struct afd_poll_params, sockets[1]);
+
+    ret = NtDeviceIoControlFile((HANDLE)client, event, NULL, NULL, &io,
+            IOCTL_AFD_POLL, in_params, params_size, out_params, params_size);
+    ok(ret == STATUS_PENDING, "got %#x\n", ret);
+
+    close_with_rst(server); /* the RST should wake the pending poll */
+
+    ret = WaitForSingleObject(event, 100);
+    ok(!ret, "got %#x\n", ret);
+    ok(!io.Status, "got %#lx\n", io.Status);
+    ok(io.Information == offsetof(struct afd_poll_params, sockets[1]), "got %#Ix\n", io.Information);
+    ok(out_params->count == 1, "got count %u\n", out_params->count);
+    ok(out_params->sockets[0].socket == client, "got socket %#Ix\n", out_params->sockets[0].socket);
+    todo_wine ok(out_params->sockets[0].flags == AFD_POLL_RESET, "got flags %#x\n", out_params->sockets[0].flags); /* WRITE/CONNECT were not requested */
+    ok(!out_params->sockets[0].status, "got status %#x\n", out_params->sockets[0].status);
+
+    check_poll_todo(client, event, AFD_POLL_WRITE | AFD_POLL_CONNECT | AFD_POLL_RESET); /* a fresh full poll reports RESET alongside WRITE/CONNECT */
+
+    closesocket(client);
+    CloseHandle(event);
+}
+
 static void test_recv(void)
 {
     const struct sockaddr_in bind_addr = {.sin_family = AF_INET, .sin_addr.s_addr = htonl(INADDR_LOOPBACK)};
@@ -1914,6 +1974,56 @@ static void test_get_events(void)
     CloseHandle(event);
 }
 
+static void test_get_events_reset(void) /* IOCTL_AFD_GET_EVENTS on either end of a connection reset by its peer */
+{
+    struct afd_get_events_params params;
+    SOCKET client, server;
+    IO_STATUS_BLOCK io;
+    unsigned int i;
+    HANDLE event;
+    int ret;
+
+    event = CreateEventW(NULL, TRUE, FALSE, NULL);
+
+    tcp_socketpair(&client, &server); /* first case: the client end observes the RST */
+
+    ret = WSAEventSelect(client, event, FD_ACCEPT | FD_CONNECT | FD_CLOSE | FD_OOB | FD_READ | FD_WRITE);
+    ok(!ret, "got error %lu\n", GetLastError());
+
+    close_with_rst(server);
+
+    memset(&params, 0xcc, sizeof(params)); /* garbage fill, so unwritten fields would fail the checks below */
+    memset(&io, 0xcc, sizeof(io));
+    ret = NtDeviceIoControlFile((HANDLE)client, NULL, NULL, NULL, &io,
+            IOCTL_AFD_GET_EVENTS, NULL, 0, &params, sizeof(params));
+    ok(!ret, "got %#x\n", ret);
+    todo_wine ok(params.flags == (AFD_POLL_RESET | AFD_POLL_CONNECT | AFD_POLL_WRITE), "got flags %#x\n", params.flags);
+    for (i = 0; i < ARRAY_SIZE(params.status); ++i)
+        ok(!params.status[i], "got status[%u] %#x\n", i, params.status[i]);
+
+    closesocket(client);
+
+    tcp_socketpair(&client, &server); /* second case: the server end observes the RST */
+
+    ret = WSAEventSelect(server, event, FD_ACCEPT | FD_CONNECT | FD_CLOSE | FD_OOB | FD_READ | FD_WRITE);
+    ok(!ret, "got error %lu\n", GetLastError());
+
+    close_with_rst(client);
+
+    memset(&params, 0xcc, sizeof(params));
+    memset(&io, 0xcc, sizeof(io));
+    ret = NtDeviceIoControlFile((HANDLE)server, NULL, NULL, NULL, &io,
+            IOCTL_AFD_GET_EVENTS, NULL, 0, &params, sizeof(params));
+    ok(!ret, "got %#x\n", ret);
+    todo_wine ok(params.flags == (AFD_POLL_RESET | AFD_POLL_WRITE), "got flags %#x\n", params.flags); /* unlike the client end, no CONNECT */
+    for (i = 0; i < ARRAY_SIZE(params.status); ++i)
+        ok(!params.status[i], "got status[%u] %#x\n", i, params.status[i]);
+
+    closesocket(server);
+
+    CloseHandle(event);
+}
+
 static void test_bind(void)
 {
     const struct sockaddr_in6 bind_addr6 = {.sin6_family = AF_INET6, .sin6_addr.s6_words = {0, 0, 0, 0, 0, 0, 0, htons(1)}};
@@ -2255,9 +2365,11 @@ START_TEST(afd)
     test_poll();
     test_poll_exclusive();
     test_poll_completion_port();
+    test_poll_reset();
     test_recv();
     test_event_select();
     test_get_events();
+    test_get_events_reset();
     test_bind();
     test_getsockname();
 
diff --git a/dlls/ws2_32/tests/sock.c b/dlls/ws2_32/tests/sock.c
index 4eb2d16f871..99221ad52ca 100644
--- a/dlls/ws2_32/tests/sock.c
+++ b/dlls/ws2_32/tests/sock.c
@@ -202,6 +202,21 @@ static void tcp_socketpair(SOCKET *src, SOCKET *dst)
     tcp_socketpair_flags(src, dst, WSA_FLAG_OVERLAPPED);
 }
 
+/* Enable SO_LINGER with a zero timeout, then close the socket. This aborts the
+ * connection, so the peer sees an RST (not a FIN) on Windows and Unix alike. */
+static void close_with_rst(SOCKET s)
+{
+    static const struct linger linger = {.l_onoff = 1}; /* l_linger is implicitly 0 */
+    int ret;
+
+    SetLastError(0xdeadbeef); /* sentinel: checks that a successful setsockopt() clears the last error */
+    ret = setsockopt(s, SOL_SOCKET, SO_LINGER, (const char *)&linger, sizeof(linger));
+    ok(!ret, "got %d\n", ret);
+    ok(!GetLastError(), "got error %lu\n", GetLastError());
+
+    closesocket(s); /* s must not be used (or closed again) by the caller after this */
+}
+
 #define check_poll(a, b) check_poll_(__LINE__, a, POLLRDNORM | POLLRDBAND | POLLWRNORM, b, FALSE)
 #define check_poll_todo(a, b) check_poll_(__LINE__, a, POLLRDNORM | POLLRDBAND | POLLWRNORM, b, TRUE)
 #define check_poll_mask(a, b, c) check_poll_(__LINE__, a, b, c, FALSE)
@@ -5302,7 +5317,9 @@ static void check_events_(int line, struct event_test_ctx *ctx,
     else
     {
         WSANETWORKEVENTS events;
+        unsigned int i;
 
+        memset(&events, 0xcc, sizeof(events));
         ret = WaitForSingleObject(ctx->event, timeout);
         if (flag1 | flag2)
             todo_wine_if (todo_event && ret) ok_(__FILE__, line)(!ret, "event wait timed out\n");
@@ -5311,7 +5328,16 @@ static void check_events_(int line, struct event_test_ctx *ctx,
         ret = WSAEnumNetworkEvents(ctx->socket, ctx->event, &events);
         ok_(__FILE__, line)(!ret, "failed to get events, error %u\n", WSAGetLastError());
         todo_wine_if (todo_event)
-            ok_(__FILE__, line)(events.lNetworkEvents == (flag1 | flag2), "got events %#lx\n", events.lNetworkEvents);
+            ok_(__FILE__, line)(events.lNetworkEvents == LOWORD(flag1 | flag2), "got events %#lx\n", events.lNetworkEvents);
+        for (i = 0; i < ARRAY_SIZE(events.iErrorCode); ++i)
+        {
+            if ((1u << i) == LOWORD(flag1) && (events.lNetworkEvents & LOWORD(flag1)))
+                todo_wine_if (HIWORD(flag1)) ok_(__FILE__, line)(events.iErrorCode[i] == HIWORD(flag1),
+                        "got error code %d for event %#x\n", events.iErrorCode[i], 1u << i);
+            if ((1u << i) == LOWORD(flag2) && (events.lNetworkEvents & LOWORD(flag2)))
+                ok_(__FILE__, line)(events.iErrorCode[i] == HIWORD(flag2),
+                        "got error code %d for event %#x\n", events.iErrorCode[i], 1u << i);
+        }
     }
 }
 
@@ -6114,6 +6140,28 @@ static void test_close_events(struct event_test_ctx *ctx)
     check_events(ctx, FD_CLOSE, 0, 200);
 
     closesocket(server);
+
+    /* Trigger RST. */
+
+    tcp_socketpair(&client, &server);
+
+    select_events(ctx, server, FD_ACCEPT | FD_CLOSE | FD_CONNECT | FD_OOB | FD_READ);
+
+    close_with_rst(client);
+
+    check_events_todo_msg(ctx, MAKELONG(FD_CLOSE, WSAECONNABORTED), 0, 200);
+    check_events(ctx, 0, 0, 0);
+    select_events(ctx, server, FD_ACCEPT | FD_CLOSE | FD_CONNECT | FD_OOB | FD_READ);
+    if (ctx->is_message)
+        check_events_todo(ctx, MAKELONG(FD_CLOSE, WSAECONNABORTED), 0, 200);
+    check_events(ctx, 0, 0, 0);
+    select_events(ctx, server, 0);
+    select_events(ctx, server, FD_ACCEPT | FD_CLOSE | FD_CONNECT | FD_OOB | FD_READ);
+    if (ctx->is_message)
+        check_events_todo(ctx, MAKELONG(FD_CLOSE, WSAECONNABORTED), 0, 200);
+    check_events(ctx, 0, 0, 0);
+
+    closesocket(server);
 }
 
 static void test_events(void)
@@ -6552,7 +6600,6 @@ static void test_WSARecv(void)
     WSABUF bufs[2];
     WSAOVERLAPPED ov;
     DWORD bytesReturned, flags, id;
-    struct linger ling;
     struct sockaddr_in addr;
     int iret, len;
     DWORD dwret;
@@ -6621,19 +6668,13 @@ static void test_WSARecv(void)
     if (!event)
         goto end;
 
-    ling.l_onoff = 1;
-    ling.l_linger = 0;
-    iret = setsockopt (src, SOL_SOCKET, SO_LINGER, (char *) &ling, sizeof(ling));
-    ok(!iret, "Failed to set linger %ld\n", GetLastError());
-
     iret = WSARecv(dest, bufs, 1, NULL, &flags, &ov, NULL);
     ok(iret == SOCKET_ERROR && GetLastError() == ERROR_IO_PENDING, "WSARecv failed - %d error %ld\n", iret, GetLastError());
 
     iret = WSARecv(dest, bufs, 1, &bytesReturned, &flags, &ov, NULL);
     ok(iret == SOCKET_ERROR && GetLastError() == ERROR_IO_PENDING, "WSARecv failed - %d error %ld\n", iret, GetLastError());
 
-    closesocket(src);
-    src = INVALID_SOCKET;
+    close_with_rst(src);
 
     dwret = WaitForSingleObject(ov.hEvent, 1000);
     ok(dwret == WAIT_OBJECT_0, "Waiting for disconnect event failed with %ld + errno %ld\n", dwret, GetLastError());
@@ -9239,7 +9280,6 @@ static void test_completion_port(void)
     char buf[1024];
     WSABUF bufs;
     DWORD num_bytes, flags;
-    struct linger ling;
     int iret;
     BOOL bret;
     ULONG_PTR key;
@@ -9260,11 +9300,6 @@ static void test_completion_port(void)
     bufs.buf = buf;
     flags = 0;
 
-    ling.l_onoff = 1;
-    ling.l_linger = 0;
-    iret = setsockopt (src, SOL_SOCKET, SO_LINGER, (char *) &ling, sizeof(ling));
-    ok(!iret, "Failed to set linger %ld\n", GetLastError());
-
     io_port = CreateIoCompletionPort( (HANDLE)dest, io_port, 125, 0 );
     ok(io_port != NULL, "Failed to create completion port %lu\n", GetLastError());
 
@@ -9276,8 +9311,7 @@ static void test_completion_port(void)
 
     Sleep(100);
 
-    closesocket(src);
-    src = INVALID_SOCKET;
+    close_with_rst(src);
 
     SetLastError(0xdeadbeef);
     key = 0xdeadbeef;
@@ -9314,18 +9348,12 @@ static void test_completion_port(void)
     bufs.buf = buf;
     flags = 0;
 
-    ling.l_onoff = 1;
-    ling.l_linger = 0;
-    iret = setsockopt (src, SOL_SOCKET, SO_LINGER, (char *) &ling, sizeof(ling));
-    ok(!iret, "Failed to set linger %ld\n", GetLastError());
-
     io_port = CreateIoCompletionPort((HANDLE)dest, io_port, 125, 0);
     ok(io_port != NULL, "failed to create completion port %lu\n", GetLastError());
 
     set_blocking(dest, FALSE);
 
-    closesocket(src);
-    src = INVALID_SOCKET;
+    close_with_rst(src);
 
     Sleep(100);
 
@@ -9432,17 +9460,11 @@ static void test_completion_port(void)
     flags = 0;
     memset(&ov, 0, sizeof(ov));
 
-    ling.l_onoff = 1;
-    ling.l_linger = 0;
-    iret = setsockopt (src, SOL_SOCKET, SO_LINGER, (char *) &ling, sizeof(ling));
-    ok(!iret, "Failed to set linger %ld\n", GetLastError());
-
     io_port = CreateIoCompletionPort((HANDLE)dest, io_port, 125, 0);
     ok(io_port != NULL, "failed to create completion port %lu\n", GetLastError());
     set_blocking(dest, FALSE);
 
-    closesocket(src);
-    src = INVALID_SOCKET;
+    close_with_rst(src);
 
     FD_ZERO(&fds_recv);
     FD_SET(dest, &fds_recv);
@@ -12441,6 +12463,76 @@ static void test_sockopt_validity(void)
     CloseHandle(file);
 }
 
+static void test_tcp_reset(void) /* recv/SO_ERROR/poll/select behaviour after the peer resets the connection */
+{
+    static const struct timeval select_timeout; /* all zero: select() polls without blocking */
+    fd_set readfds, writefds, exceptfds;
+    OVERLAPPED overlapped = {0};
+    SOCKET client, server;
+    DWORD size, flags = 0;
+    int ret, len, error;
+    char buffer[10];
+    WSABUF wsabuf;
+
+    overlapped.hEvent = CreateEventW(NULL, TRUE, FALSE, NULL);
+
+    tcp_socketpair(&client, &server);
+
+    wsabuf.buf = buffer;
+    wsabuf.len = sizeof(buffer);
+    WSASetLastError(0xdeadbeef);
+    size = 0xdeadbeef;
+    ret = WSARecv(client, &wsabuf, 1, &size, &flags, &overlapped, NULL);
+    ok(ret == -1, "got %d\n", ret);
+    ok(WSAGetLastError() == ERROR_IO_PENDING, "got error %u\n", WSAGetLastError());
+
+    close_with_rst(server); /* closes server; the pending recv on client should fail with a reset */
+
+    ret = WaitForSingleObject(overlapped.hEvent, 1000);
+    ok(!ret, "wait failed\n");
+    ret = GetOverlappedResult((HANDLE)client, &overlapped, &size, FALSE);
+    todo_wine ok(!ret, "expected failure\n");
+    todo_wine ok(GetLastError() == ERROR_NETNAME_DELETED, "got error %lu\n", GetLastError());
+    ok(!size, "got size %lu\n", size);
+    todo_wine ok((NTSTATUS)overlapped.Internal == STATUS_CONNECTION_RESET, "got status %#lx\n", (NTSTATUS)overlapped.Internal);
+
+    len = sizeof(error);
+    ret = getsockopt(client, SOL_SOCKET, SO_ERROR, (char *)&error, &len);
+    ok(!ret, "got error %u\n", WSAGetLastError());
+    todo_wine ok(!error, "got error %u\n", error); /* the reset is not reported through SO_ERROR */
+
+    wsabuf.buf = buffer;
+    wsabuf.len = sizeof(buffer);
+    WSASetLastError(0xdeadbeef);
+    size = 0xdeadbeef;
+    ret = WSARecv(client, &wsabuf, 1, &size, &flags, &overlapped, NULL); /* a new recv fails immediately */
+    todo_wine ok(ret == -1, "got %d\n", ret);
+    todo_wine ok(WSAGetLastError() == WSAECONNRESET, "got error %u\n", WSAGetLastError());
+
+    check_poll_todo(client, POLLERR | POLLHUP | POLLWRNORM);
+
+    FD_ZERO(&readfds);
+    FD_ZERO(&writefds);
+    FD_ZERO(&exceptfds);
+    FD_SET(client, &readfds);
+    FD_SET(client, &writefds);
+    FD_SET(client, &exceptfds);
+    ret = select(0, &readfds, &writefds, &exceptfds, &select_timeout);
+    ok(ret == 2, "got %d\n", ret);
+    ok(FD_ISSET(client, &readfds), "FD should be set\n");
+    ok(FD_ISSET(client, &writefds), "FD should be set\n");
+    ok(!FD_ISSET(client, &exceptfds), "FD should not be set\n");
+
+    FD_ZERO(&exceptfds);
+    FD_SET(client, &exceptfds); /* exceptfds alone: still not signaled */
+    ret = select(0, NULL, NULL, &exceptfds, &select_timeout);
+    ok(!ret, "got %d\n", ret);
+    ok(!FD_ISSET(client, &exceptfds), "FD should not be set\n");
+
+    closesocket(client); /* server was already closed by close_with_rst() */
+    CloseHandle(overlapped.hEvent);
+}
+
 START_TEST( sock )
 {
     int i;
@@ -12514,6 +12606,7 @@ START_TEST( sock )
     test_simultaneous_async_recv();
     test_empty_recv();
     test_timeout();
+    test_tcp_reset();
 
     /* this is an io heavy test, do it at the end so the kernel doesn't start dropping packets */
     test_send();
-- 
2.34.1




More information about the wine-devel mailing list