[PATCH 4/4] ws2_32: Do not assume that an fd_set is bounded by FD_SETSIZE.

Zebediah Figura zfigura at codeweavers.com
Wed Dec 29 15:46:00 CST 2021


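A caller may pass an fd_set holding more than FD_SETSIZE sockets (the new
test passes 200). select() previously built its poll parameters in a stack
buffer sized for FD_SETSIZE * 3 sockets and copied the read set into a
fixed-size struct fd_set, and add_fd_to_set() refused to store more than
FD_SETSIZE entries. Size the poll parameters and the read set copy from the
callers' fd_count values and allocate them on the heap instead.

Wine-Bug: https://bugs.winehq.org/show_bug.cgi?id=52302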
Signed-off-by: Zebediah Figura <zfigura at codeweavers.com>
---
 dlls/ws2_32/socket.c         | 92 +++++++++++++++++++++++-------------
 dlls/ws2_32/tests/sock.c     | 31 +++++++++---
 dlls/ws2_32/ws2_32_private.h |  1 +
 3 files changed, 84 insertions(+), 40 deletions(-)

diff --git a/dlls/ws2_32/socket.c b/dlls/ws2_32/socket.c
index e517b239d5b..2b05884acaf 100644
--- a/dlls/ws2_32/socket.c
+++ b/dlls/ws2_32/socket.c
@@ -2364,13 +2364,8 @@ static int add_fd_to_set( SOCKET fd, struct fd_set *set )
             return 0;
     }
 
-    if (set->fd_count < FD_SETSIZE)
-    {
-        set->fd_array[set->fd_count++] = fd;
-        return 1;
-    }
-
-    return 0;
+    set->fd_array[set->fd_count++] = fd;
+    return 1;
 }
 
 
@@ -2380,9 +2375,9 @@ static int add_fd_to_set( SOCKET fd, struct fd_set *set )
 int WINAPI select( int count, fd_set *read_ptr, fd_set *write_ptr,
                    fd_set *except_ptr, const struct timeval *timeout)
 {
-    char buffer[offsetof( struct afd_poll_params, sockets[FD_SETSIZE * 3] )] = {0};
-    struct afd_poll_params *params = (struct afd_poll_params *)buffer;
-    struct fd_set read_input;
+    struct fd_set *read_input = NULL;
+    struct afd_poll_params *params;
+    unsigned int poll_count = 0;
     ULONG params_size, i, j;
     SOCKET poll_socket = 0;
     IO_STATUS_BLOCK io;
@@ -2392,22 +2387,49 @@ int WINAPI select( int count, fd_set *read_ptr, fd_set *write_ptr,
 
     TRACE( "read %p, write %p, except %p, timeout %p\n", read_ptr, write_ptr, except_ptr, timeout );
 
-    FD_ZERO( &read_input );
-    if (read_ptr) read_input.fd_count = read_ptr->fd_count;
-
     if (!(sync_event = get_sync_event())) return -1;
 
+    if (read_ptr) poll_count += read_ptr->fd_count;
+    if (write_ptr) poll_count += write_ptr->fd_count;
+    if (except_ptr) poll_count += except_ptr->fd_count;
+
+    if (!poll_count)
+    {
+        SetLastError( WSAEINVAL );
+        return -1;
+    }
+
+    params_size = offsetof( struct afd_poll_params, sockets[poll_count] );
+    if (!(params = calloc( params_size, 1 )))
+    {
+        SetLastError( WSAENOBUFS );
+        return -1;
+    }
+
     if (timeout)
         params->timeout = timeout->tv_sec * -10000000 + timeout->tv_usec * -10;
     else
         params->timeout = TIMEOUT_INFINITE;
 
-    for (i = 0; i < read_input.fd_count; ++i)
+    if (read_ptr)
     {
-        params->sockets[params->count].socket = read_input.fd_array[i] = read_ptr->fd_array[i];
-        params->sockets[params->count].flags = AFD_POLL_READ | AFD_POLL_ACCEPT | AFD_POLL_HUP;
-        ++params->count;
-        poll_socket = read_input.fd_array[i];
+        unsigned int read_size = offsetof( struct fd_set, fd_array[read_ptr->fd_count] );
+
+        if (!(read_input = malloc( read_size )))
+        {
+            free( params );
+            SetLastError( WSAENOBUFS );
+            return -1;
+        }
+        memcpy( read_input, read_ptr, read_size );
+
+        for (i = 0; i < read_ptr->fd_count; ++i)
+        {
+            params->sockets[params->count].socket = read_ptr->fd_array[i];
+            params->sockets[params->count].flags = AFD_POLL_READ | AFD_POLL_ACCEPT | AFD_POLL_HUP;
+            ++params->count;
+            poll_socket = read_ptr->fd_array[i];
+        }
     }
 
     if (write_ptr)
@@ -2432,42 +2454,43 @@ int WINAPI select( int count, fd_set *read_ptr, fd_set *write_ptr,
         }
     }
 
-    if (!params->count)
-    {
-        SetLastError( WSAEINVAL );
-        return -1;
-    }
-
-    params_size = offsetof( struct afd_poll_params, sockets[params->count] );
+    assert( params->count == poll_count );
 
     status = NtDeviceIoControlFile( (HANDLE)poll_socket, sync_event, NULL, NULL, &io,
                                     IOCTL_AFD_POLL, params, params_size, params, params_size );
     if (status == STATUS_PENDING)
     {
         if (WaitForSingleObject( sync_event, INFINITE ) == WAIT_FAILED)
+        {
+            free( read_input );
+            free( params );
             return -1;
+        }
         status = io.u.Status;
     }
     if (status == STATUS_TIMEOUT) status = STATUS_SUCCESS;
     if (!status)
     {
         /* pointers may alias, so clear them all first */
-        if (read_ptr) FD_ZERO( read_ptr );
-        if (write_ptr) FD_ZERO( write_ptr );
-        if (except_ptr) FD_ZERO( except_ptr );
+        if (read_ptr) read_ptr->fd_count = 0;
+        if (write_ptr) write_ptr->fd_count = 0;
+        if (except_ptr) except_ptr->fd_count = 0;
 
         for (i = 0; i < params->count; ++i)
         {
             unsigned int flags = params->sockets[i].flags;
             SOCKET s = params->sockets[i].socket;
 
-            for (j = 0; j < read_input.fd_count; ++j)
+            if (read_input)
             {
-                if (read_input.fd_array[j] == s
-                        && (flags & (AFD_POLL_READ | AFD_POLL_ACCEPT | AFD_POLL_HUP | AFD_POLL_CLOSE)))
+                for (j = 0; j < read_input->fd_count; ++j)
                 {
-                    ret_count += add_fd_to_set( s, read_ptr );
-                    flags &= ~AFD_POLL_CLOSE;
+                    if (read_input->fd_array[j] == s
+                            && (flags & (AFD_POLL_READ | AFD_POLL_ACCEPT | AFD_POLL_HUP | AFD_POLL_CLOSE)))
+                    {
+                        ret_count += add_fd_to_set( s, read_ptr );
+                        flags &= ~AFD_POLL_CLOSE;
+                    }
                 }
             }
 
@@ -2482,6 +2505,9 @@ int WINAPI select( int count, fd_set *read_ptr, fd_set *write_ptr,
         }
     }
 
+    free( read_input );
+    free( params );
+
     SetLastError( NtStatusToWSAError( status ) );
     return status ? -1 : ret_count;
 }
diff --git a/dlls/ws2_32/tests/sock.c b/dlls/ws2_32/tests/sock.c
index f1a020d1f0c..46f26d32e50 100644
--- a/dlls/ws2_32/tests/sock.c
+++ b/dlls/ws2_32/tests/sock.c
@@ -3204,9 +3204,8 @@ static void test_select(void)
 {
     static char tmp_buf[1024];
 
-    SOCKET fdListen, fdRead, fdWrite;
-    fd_set readfds, writefds, exceptfds, *alloc_readfds;
-    unsigned int maxfd;
+    fd_set readfds, writefds, exceptfds, *alloc_fds;
+    SOCKET fdListen, fdRead, fdWrite, sockets[200];
     int ret, len;
     char buffer;
     struct timeval select_timeout;
@@ -3214,6 +3213,7 @@ static void test_select(void)
     select_thread_params thread_params;
     HANDLE thread_handle;
     DWORD ticks, id, old_protect;
+    unsigned int maxfd, i;
     char *page_pair;
 
     fdRead = socket(AF_INET, SOCK_STREAM, 0);
@@ -3392,16 +3392,33 @@ static void test_select(void)
 
     page_pair = VirtualAlloc(NULL, 0x1000 * 2, MEM_RESERVE | MEM_COMMIT, PAGE_READWRITE);
     VirtualProtect(page_pair + 0x1000, 0x1000, PAGE_NOACCESS, &old_protect);
-    alloc_readfds = (fd_set *)((page_pair + 0x1000) - offsetof(fd_set, fd_array[1]));
-    alloc_readfds->fd_count = 1;
-    alloc_readfds->fd_array[0] = fdRead;
-    ret = select(fdRead+1, alloc_readfds, NULL, NULL, &select_timeout);
+    alloc_fds = (fd_set *)((page_pair + 0x1000) - offsetof(fd_set, fd_array[1]));
+    alloc_fds->fd_count = 1;
+    alloc_fds->fd_array[0] = fdRead;
+    ret = select(fdRead+1, alloc_fds, NULL, NULL, &select_timeout);
     ok(ret == 1, "select returned %d\n", ret);
     VirtualFree(page_pair, 0, MEM_RELEASE);
 
     closesocket(fdRead);
     closesocket(fdWrite);
 
+    alloc_fds = malloc(offsetof(fd_set, fd_array[ARRAY_SIZE(sockets)]));
+    alloc_fds->fd_count = ARRAY_SIZE(sockets);
+    for (i = 0; i < ARRAY_SIZE(sockets); i += 2)
+    {
+        tcp_socketpair(&sockets[i], &sockets[i + 1]);
+        alloc_fds->fd_array[i] = sockets[i];
+        alloc_fds->fd_array[i + 1] = sockets[i + 1];
+    }
+    ret = select(0, NULL, alloc_fds, NULL, &select_timeout);
+    ok(ret == ARRAY_SIZE(sockets), "got %d\n", ret);
+    for (i = 0; i < ARRAY_SIZE(sockets); ++i)
+    {
+        ok(alloc_fds->fd_array[i] == sockets[i], "got socket %#Ix at index %u\n", alloc_fds->fd_array[i], i);
+        closesocket(sockets[i]);
+    }
+    free(alloc_fds);
+
     /* select() works in 3 distinct states:
      * - to check if a connection attempt ended with success or error;
      * - to check if a pending connection is waiting for acceptance;
diff --git a/dlls/ws2_32/ws2_32_private.h b/dlls/ws2_32/ws2_32_private.h
index fa2f89f6e98..f6b6ecc7eba 100644
--- a/dlls/ws2_32/ws2_32_private.h
+++ b/dlls/ws2_32/ws2_32_private.h
@@ -19,6 +19,7 @@
 #ifndef __WINE_WS2_32_PRIVATE_H
 #define __WINE_WS2_32_PRIVATE_H
 
+#include <assert.h>
 #include <stdarg.h>
 #include <stdio.h>
 #include <stdlib.h>
-- 
2.34.1




More information about the wine-devel mailing list