Help to create a server request

Bruno Jesus 00cpxxx at gmail.com
Thu Aug 29 19:43:40 CDT 2013


Hi all, I need some help to continue my current Wine work.

To implement SO_PROTOCOL_INFO for getsockopt, I need to retrieve some
information about the socket, such as its family and protocol.

I searched for a few days and ended up with a solution I dislike, so I
came up with what I hope is a better idea.

Instead of using the non-portable SO_DOMAIN and SO_PROTOCOL/SO_PROTOTYPE
options to retrieve the socket family and protocol, or unreliably
guessing from the socket type alone, I thought it would be better to
ask the server for this information. It would use a request just like
the ones ws2_32 already uses for several other pieces of information
(an approach that works on every OS).

So, all I need is a server request that, given the socket fd, returns
the socket family, type and protocol. I tried to understand how
requests work, but failed completely.

Maybe this request could later be extended to return the connection
time as well, so we can finally fix the SO_CONNECT_TIME option.

The current solution is attached. Since I sent the tests separately and
they were committed, the patch will not apply; it's only for reference.
The idea is to replace the functions get_sock_[family|protocol|type]
with a single server request.

So, is this a good idea? If yes, how can I create and use the request?
If not, I'm all ears.


Best wishes and thanks in advance,
Bruno
-------------- next part --------------
diff --git a/dlls/ws2_32/socket.c b/dlls/ws2_32/socket.c
index 462f153..5218b5c 100644
--- a/dlls/ws2_32/socket.c
+++ b/dlls/ws2_32/socket.c
@@ -1018,6 +1018,74 @@ static inline int get_rcvsnd_timeo( int fd, int optname)
   return ret;
 }
 
+/* Return the host (Unix) address family of socket fd.  SO_DOMAIN is used
+ * where available; if it is missing or fails, the family is read back with
+ * getsockname().  Falls back to AF_INET when neither method yields one. */
+static inline int get_sock_family(int fd)
+{
+    int optval = AF_UNSPEC;
+    union generic_unix_sockaddr uaddr;
+    socklen_t uaddrlen = sizeof(uaddr);
+#ifdef SO_DOMAIN
+    socklen_t optlen = sizeof(optval);
+
+    if (!getsockopt(fd, SOL_SOCKET, SO_DOMAIN, (char *) &optval, &optlen))
+        return optval;
+    ERR("getsockopt(SO_DOMAIN) failed\n");
+    optval = AF_UNSPEC;
+#endif
+
+    if (!getsockname(fd, &uaddr.addr, &uaddrlen))
+        optval = uaddr.addr.sa_family;
+
+    if (optval == AF_UNSPEC)
+    {
+        optval = AF_INET;
+        ERR("could not detect socket family - defaulting to AF_INET\n");
+    }
+    return optval;
+}
+
+/* Return the host socket type (SO_TYPE) of fd, or 0 if it cannot be read. */
+static inline int get_sock_type(int fd)
+{
+    int optval = 0;
+    socklen_t optlen = sizeof(optval);
+
+    if (getsockopt(fd, SOL_SOCKET, SO_TYPE, (char *) &optval, &optlen))
+        ERR("getsockopt(SO_TYPE) failed\n");
+    return optval;
+}
+
+/* Return the host protocol number of socket fd.  SO_PROTOCOL or SO_PROTOTYPE
+ * is queried when the platform defines one; otherwise, or if the query fails,
+ * the protocol is guessed from the socket type: IPPROTO_TCP for SOCK_STREAM,
+ * IPPROTO_UDP for SOCK_DGRAM and IPPROTO_IP for anything else. */
+static inline int get_sock_protocol(int fd)
+{
+    int optval = IPPROTO_IP;
+    int socktype;
+#if defined(SO_PROTOCOL) || defined(SO_PROTOTYPE)
+    socklen_t optlen = sizeof(optval);
+
+#ifdef SO_PROTOCOL
+    if (!getsockopt(fd, SOL_SOCKET, SO_PROTOCOL, (char *) &optval, &optlen))
+        return optval;
+    ERR("getsockopt(SO_PROTOCOL) failed\n");
+#else
+    if (!getsockopt(fd, SOL_SOCKET, SO_PROTOTYPE, (char *) &optval, &optlen))
+        return optval;
+    ERR("getsockopt(SO_PROTOTYPE) failed\n");
+#endif
+    optval = IPPROTO_IP;
+#endif
+
+    socktype = get_sock_type(fd);
+    if (socktype == SOCK_STREAM) optval = IPPROTO_TCP;
+    else if (socktype == SOCK_DGRAM) optval = IPPROTO_UDP;
+    return optval;
+}
+
 /* macro wrappers for portability */
 #ifdef SO_RCVTIMEO
 #define GET_RCVTIMEO(fd) get_rcvsnd_timeo( (fd), SO_RCVTIMEO)
@@ -1116,6 +1184,66 @@ convert_socktype_u2w(int unixsocktype) {
     return -1;
 }
 
+/* Fill optval with a WSAPROTOCOL_INFOW (unicode != 0) or WSAPROTOCOL_INFOA
+ * (unicode == 0) describing socket fd, and return the size of that structure.
+ * The entry is taken from WSAEnumProtocolsW() by matching family, type and
+ * protocol; if nothing matches, only those three fields are set and the rest
+ * of the structure is zeroed.  optval must be large enough for the structure;
+ * the caller is expected to have checked this. */
+static int fill_protocol_info(int fd, int unicode, char *optval)
+{
+    int sockfamily, socktype, sockproto, items, sz, i;
+    DWORD listsize = 0;
+    WSAPROTOCOL_INFOW *buffer = NULL;
+    /* Both views alias optval; the A and W structures share their layout up
+     * to szProtocol, so the common fields can be written through either. */
+    WSAPROTOCOL_INFOA *info_a = (WSAPROTOCOL_INFOA *) optval;
+    WSAPROTOCOL_INFOW *info_w = (WSAPROTOCOL_INFOW *) optval;
+
+    sz = unicode ? sizeof(WSAPROTOCOL_INFOW) : sizeof(WSAPROTOCOL_INFOA);
+    memset(optval, 0, sz);
+
+    sockfamily = convert_af_u2w(get_sock_family(fd));
+    socktype = convert_socktype_u2w(get_sock_type(fd));
+    sockproto = convert_proto_u2w(get_sock_protocol(fd));
+
+    /* Start by filling basic information in case our search below fails */
+    info_a->iAddressFamily = sockfamily;
+    info_a->iSocketType = socktype;
+    info_a->iProtocol = sockproto;
+
+    /* The first call only queries the required buffer size. */
+    items = WSAEnumProtocolsW(NULL, NULL, &listsize);
+    if (items == SOCKET_ERROR && WSAGetLastError() == WSAENOBUFS &&
+       (buffer = HeapAlloc(GetProcessHeap(), 0, listsize)))
+    {
+        items = WSAEnumProtocolsW(NULL, buffer, &listsize);
+        for (i = 0; i < items; i++)
+        {
+            if (buffer[i].iAddressFamily == sockfamily &&
+                buffer[i].iSocketType == socktype &&
+                buffer[i].iProtocol == sockproto)
+            {
+                if (unicode)
+                    memcpy(info_w, &buffer[i], sz);
+                else
+                {
+                    /* convert the structure from W to A */
+                    memcpy(info_a, &buffer[i], FIELD_OFFSET(WSAPROTOCOL_INFOA, szProtocol));
+                    WideCharToMultiByte(CP_ACP, 0, buffer[i].szProtocol, -1,
+                                        info_a->szProtocol, WSAPROTOCOL_LEN+1, NULL, NULL);
+                    /* the name may be left unterminated if it was truncated */
+                    info_a->szProtocol[WSAPROTOCOL_LEN] = 0;
+                }
+                break;
+            }
+        }
+    }
+
+    HeapFree(GetProcessHeap(), 0, buffer);
+    return sz;
+}
+
 /* ----------------------------------- API -----
  *
  * Init / cleanup / error checking.
@@ -2776,6 +2904,24 @@ INT WINAPI WS_getsockopt(SOCKET s, INT level,
             TRACE("getting global SO_OPENTYPE = 0x%x\n", *((int*)optval) );
             return 0;
 
+        case WS_SO_PROTOCOL_INFOA:
+        case WS_SO_PROTOCOL_INFOW:
+            /* Compare as int so that a negative *optlen is rejected instead of
+             * being converted to a huge unsigned value against sizeof(). */
+            if (!optlen || !optval ||
+                *optlen < (int) (optname == WS_SO_PROTOCOL_INFOA ?
+                 sizeof(WSAPROTOCOL_INFOA) : sizeof(WSAPROTOCOL_INFOW)))
+            {
+                SetLastError(WSAEFAULT);
+                return SOCKET_ERROR;
+            }
+            if ( (fd = get_sock_fd( s, 0, NULL )) == -1)
+                return SOCKET_ERROR;
+
+            *optlen = fill_protocol_info(fd, optname == WS_SO_PROTOCOL_INFOW, optval);
+            release_sock_fd( s, fd );
+            return 0;
+
 #ifdef SO_RCVTIMEO
         case WS_SO_RCVTIMEO:
 #endif
diff --git a/dlls/ws2_32/tests/sock.c b/dlls/ws2_32/tests/sock.c
index 5ce4959..38face9 100644
--- a/dlls/ws2_32/tests/sock.c
+++ b/dlls/ws2_32/tests/sock.c
@@ -1105,6 +1105,18 @@ static void test_set_getsockopt(void)
     int timeout;
     LINGER lingval;
     int size;
+    WSAPROTOCOL_INFOA infoA;
+    WSAPROTOCOL_INFOW infoW;
+    char providername[WSAPROTOCOL_LEN+1];
+    struct _prottest
+    {
+        int family, type, proto;
+    } prottest[] = {
+        {AF_INET, SOCK_STREAM, IPPROTO_TCP},
+        {AF_INET, SOCK_DGRAM, IPPROTO_UDP},
+        {AF_INET6, SOCK_STREAM, IPPROTO_TCP},
+        {AF_INET6, SOCK_DGRAM, IPPROTO_UDP}
+    };
 
     s = socket(AF_INET, SOCK_STREAM, 0);
     ok(s!=INVALID_SOCKET, "socket() failed error: %d\n", WSAGetLastError());
@@ -1221,6 +1233,87 @@ todo_wine
         err, WSAGetLastError());
 
     closesocket(s);
+
+    /* test SO_PROTOCOL_INFOA invalid parameters */
+    s = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
+    ok(s!=INVALID_SOCKET, "socket() failed error: %d\n", WSAGetLastError());
+    ok(getsockopt(s, SOL_SOCKET, SO_PROTOCOL_INFOA, NULL, NULL),
+       "getsockopt should have failed\n");
+    err = WSAGetLastError();
+    ok(err == WSAEFAULT, "expected 10014, got %d instead\n", err);
+    ok(getsockopt(s, SOL_SOCKET, SO_PROTOCOL_INFOA, (char *) &infoA, NULL),
+       "getsockopt should have failed\n");
+    err = WSAGetLastError();
+    ok(err == WSAEFAULT, "expected 10014, got %d instead\n", err);
+    /* a large enough size must not hide the NULL buffer */
+    size = sizeof(WSAPROTOCOL_INFOA);
+    ok(getsockopt(s, SOL_SOCKET, SO_PROTOCOL_INFOA, NULL, &size),
+       "getsockopt should have failed\n");
+    err = WSAGetLastError();
+    ok(err == WSAEFAULT, "expected 10014, got %d instead\n", err);
+
+    size = sizeof(WSAPROTOCOL_INFOA) / 2;
+    ok(getsockopt(s, SOL_SOCKET, SO_PROTOCOL_INFOA, (char *) &infoA, &size),
+       "getsockopt should have failed\n");
+    err = WSAGetLastError();
+    ok(err == WSAEFAULT, "expected 10014, got %d instead\n", err);
+    closesocket(s);
+
+    /* test SO_PROTOCOL_INFO structure returned for different protocols */
+    for (i = 0; i < sizeof(prottest) / sizeof(prottest[0]); i++)
+    {
+        s = socket(prottest[i].family, prottest[i].type, prottest[i].proto);
+        /* IPv6 may not be available on the test machine */
+        if (s == INVALID_SOCKET && prottest[i].family == AF_INET6) continue;
+
+        ok(s != INVALID_SOCKET, "Failed to create socket: %d\n",
+          WSAGetLastError());
+        if (s == INVALID_SOCKET) continue;
+
+        /* compare both A and W version */
+        infoA.szProtocol[0] = 0;
+        size = sizeof(WSAPROTOCOL_INFOA);
+        err = getsockopt(s, SOL_SOCKET, SO_PROTOCOL_INFOA, (char *) &infoA, &size);
+        ok(!err,"getsockopt failed with %d\n", WSAGetLastError());
+
+        infoW.szProtocol[0] = 0;
+        size = sizeof(WSAPROTOCOL_INFOW);
+        err = getsockopt(s, SOL_SOCKET, SO_PROTOCOL_INFOW, (char *) &infoW, &size);
+        ok(!err,"getsockopt failed with %d\n", WSAGetLastError());
+
+        trace("provider name '%s', family %d, type %d, proto %d\n",
+              infoA.szProtocol, prottest[i].family, prottest[i].type, prottest[i].proto);
+
+        /* TODO: remove when WSAEnumProtocols return AF_INET6 data */
+        if (prottest[i].family == AF_INET6)
+        {
+          todo_wine {
+          ok(infoA.szProtocol[0], "WSAPROTOCOL_INFOA was not filled\n");
+          ok(infoW.szProtocol[0], "WSAPROTOCOL_INFOW was not filled\n");
+          }
+        }
+        else
+        {
+          ok(infoA.szProtocol[0], "WSAPROTOCOL_INFOA was not filled\n");
+          ok(infoW.szProtocol[0], "WSAPROTOCOL_INFOW was not filled\n");
+        }
+
+        WideCharToMultiByte(CP_ACP, 0, infoW.szProtocol, -1,
+                            providername, sizeof(providername), NULL, NULL);
+        ok(!strcmp(infoA.szProtocol,providername),
+           "different provider names '%s' != '%s'\n", infoA.szProtocol, providername);
+
+        ok(!memcmp(&infoA, &infoW, FIELD_OFFSET(WSAPROTOCOL_INFOA, szProtocol)),
+           "SO_PROTOCOL_INFO[A/W] comparison failed\n");
+        ok(infoA.iAddressFamily == prottest[i].family, "socket family invalid, expected %d received %d\n",
+           prottest[i].family, infoA.iAddressFamily);
+        ok(infoA.iSocketType == prottest[i].type, "socket type invalid, expected %d received %d\n",
+           prottest[i].type, infoA.iSocketType);
+        ok(infoA.iProtocol == prottest[i].proto, "socket protocol invalid, expected %d received %d\n",
+           prottest[i].proto, infoA.iProtocol);
+
+        closesocket(s);
+    }
 }
 
 static void test_so_reuseaddr(void)


More information about the wine-devel mailing list