rpcrt4: Implement RpcMgmtWaitServerListen

Dan Hipschman dsh at linux.ucla.edu
Tue Apr 10 20:30:20 CDT 2007


Hi, this is my first Wine patch in a while, so hopefully I remembered
all the customs.  This patch implements RpcMgmtWaitServerListen.  I was
working on writing a framework to test the stub code generated by widl,
starting with something like the very simple example at

     http://www.softlookup.com/tutorial/vc++/vcu43fi.asp

which needs RpcMgmtWaitServerListen.  This includes a couple of tests for
error cases (conforming to XP).  I'd like to add further tests, but they
involve opening a server port on people's computers, and I'm still
looking into the best way to do that kind of stuff.  The hello world
application on the page above (or something very close to it) works with
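this patch, and I tested it using the ncacn_ip_tcp and ncacn_np
protocols.  Note that incoming calls are now processed directly on each
connection's I/O thread instead of being queued to the thread pool, so a
connection's calls have all completed by the time its I/O thread exits,
which RpcMgmtWaitServerListen relies on.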

---
 dlls/rpcrt4/rpc_binding.h   |    8 +++--
 dlls/rpcrt4/rpc_message.c   |   10 +++---
 dlls/rpcrt4/rpc_server.c    |   69 ++++++++++++++++++++++++++++--------------
 dlls/rpcrt4/rpc_transport.c |   57 +++++++++++++++++++++++++++++++++--
 dlls/rpcrt4/tests/rpc.c     |    8 +++++
 5 files changed, 118 insertions(+), 34 deletions(-)

diff --git a/dlls/rpcrt4/rpc_binding.h b/dlls/rpcrt4/rpc_binding.h
index 6dde2ae..fdb84be 100644
--- a/dlls/rpcrt4/rpc_binding.h
+++ b/dlls/rpcrt4/rpc_binding.h
@@ -86,6 +86,7 @@ typedef struct _RpcConnection
 
   /* client-only */
   struct list conn_pool_entry;
+  struct list client_entry; /* used server-side: entry in rpc_server.c's client_connections list */
 } RpcConnection;
 
 struct connection_ops {
@@ -94,7 +95,8 @@ struct connection_ops {
   RpcConnection *(*alloc)(void);
   RPC_STATUS (*open_connection_client)(RpcConnection *conn);
   RPC_STATUS (*handoff)(RpcConnection *old_conn, RpcConnection *new_conn);
-  int (*read)(RpcConnection *conn, void *buffer, unsigned int len);
+  int (*read)(RpcConnection *conn, void *buffer, unsigned int len, BOOL check_stop_event); /* TRUE: return 0 once signal_to_stop() has been called */
+  int (*signal_to_stop)(RpcConnection *conn); /* makes a read() with check_stop_event == TRUE return 0 */
   int (*write)(RpcConnection *conn, const void *buffer, unsigned int len);
   int (*close)(RpcConnection *conn);
   size_t (*get_top_of_tower)(unsigned char *tower_data, const char *networkaddr, const char *endpoint);
@@ -164,9 +166,9 @@ static inline const char *rpcrt4_conn_get_name(RpcConnection *Connection)
 }
 
 static inline int rpcrt4_conn_read(RpcConnection *Connection,
-                     void *buffer, unsigned int len)
+                     void *buffer, unsigned int len, BOOL check_stop_event)
 {
-  return Connection->ops->read(Connection, buffer, len);
+  return Connection->ops->read(Connection, buffer, len, check_stop_event);
 }
 
 static inline int rpcrt4_conn_write(RpcConnection *Connection,
diff --git a/dlls/rpcrt4/rpc_message.c b/dlls/rpcrt4/rpc_message.c
index 1086602..d921dc1 100644
--- a/dlls/rpcrt4/rpc_message.c
+++ b/dlls/rpcrt4/rpc_message.c
@@ -622,7 +622,7 @@ RPC_STATUS RPCRT4_Receive(RpcConnection *Connection, RpcPktHdr **Header,
   TRACE("(%p, %p, %p)\n", Connection, Header, pMsg);
 
   /* read packet common header */
-  dwRead = rpcrt4_conn_read(Connection, &common_hdr, sizeof(common_hdr));
+  dwRead = rpcrt4_conn_read(Connection, &common_hdr, sizeof(common_hdr), TRUE); /* a stop request may end the wait for a new packet; the 0-byte result takes the short-read path below */
   if (dwRead != sizeof(common_hdr)) {
     WARN("Short read of header, %d bytes\n", dwRead);
     status = RPC_S_PROTOCOL_ERROR;
@@ -648,7 +648,7 @@ RPC_STATUS RPCRT4_Receive(RpcConnection *Connection, RpcPktHdr **Header,
   memcpy(*Header, &common_hdr, sizeof(common_hdr));
 
   /* read the rest of packet header */
-  dwRead = rpcrt4_conn_read(Connection, &(*Header)->common + 1, hdr_length - sizeof(common_hdr));
+  dwRead = rpcrt4_conn_read(Connection, &(*Header)->common + 1, hdr_length - sizeof(common_hdr), FALSE); /* no stop check once a packet has started */
   if (dwRead != hdr_length - sizeof(common_hdr)) {
     WARN("bad header length, %d bytes, hdr_length %d\n", dwRead, hdr_length);
     status = RPC_S_PROTOCOL_ERROR;
@@ -720,7 +720,7 @@ RPC_STATUS RPCRT4_Receive(RpcConnection *Connection, RpcPktHdr **Header,
 
     if (data_length == 0) dwRead = 0; else
     dwRead = rpcrt4_conn_read(Connection,
-        (unsigned char *)pMsg->Buffer + buffer_length, data_length);
+        (unsigned char *)pMsg->Buffer + buffer_length, data_length, FALSE);
     if (dwRead != data_length) {
       WARN("bad data length, %d/%ld\n", dwRead, data_length);
       status = RPC_S_PROTOCOL_ERROR;
@@ -739,7 +739,7 @@ RPC_STATUS RPCRT4_Receive(RpcConnection *Connection, RpcPktHdr **Header,
        * however, the details of how this is done is very sketchy in the
        * DCE/RPC spec. for all other packet types that have authentication
        * verifier data then it is just duplicated in all the fragments */
-      dwRead = rpcrt4_conn_read(Connection, auth_data, header_auth_len);
+      dwRead = rpcrt4_conn_read(Connection, auth_data, header_auth_len, FALSE);
       if (dwRead != header_auth_len) {
         WARN("bad authentication data length, %d/%d\n", dwRead,
           header_auth_len);
@@ -765,7 +765,7 @@ RPC_STATUS RPCRT4_Receive(RpcConnection *Connection, RpcPktHdr **Header,
       TRACE("next header\n");
 
       /* read the header of next packet */
-      dwRead = rpcrt4_conn_read(Connection, *Header, hdr_length);
+      dwRead = rpcrt4_conn_read(Connection, *Header, hdr_length, FALSE);
       if (dwRead != hdr_length) {
         WARN("invalid packet header size (%d)\n", dwRead);
         status = RPC_S_PROTOCOL_ERROR;
diff --git a/dlls/rpcrt4/rpc_server.c b/dlls/rpcrt4/rpc_server.c
index 7152453..1284f3f 100644
--- a/dlls/rpcrt4/rpc_server.c
+++ b/dlls/rpcrt4/rpc_server.c
@@ -69,6 +69,16 @@ static RpcObjTypeMap *RpcObjTypeMaps;
 /* list of type RpcServerProtseq */
 static struct list protseqs = LIST_INIT(protseqs);
 static struct list server_interfaces = LIST_INIT(server_interfaces);
+static struct list client_connections = LIST_INIT(client_connections); /* connections with a running RPCRT4_io_thread; guarded by client_connections_cs */
+
+static CRITICAL_SECTION client_connections_cs;
+static CRITICAL_SECTION_DEBUG client_connections_cs_debug =
+{
+    0, 0, &client_connections_cs,
+    { &client_connections_cs_debug.ProcessLocksList, &client_connections_cs_debug.ProcessLocksList },
+      0, 0, { (DWORD_PTR)(__FILE__ ": client_connections_cs") }
+};
+static CRITICAL_SECTION client_connections_cs = { &client_connections_cs_debug, -1, 0, 0, 0, 0 };
 
 static CRITICAL_SECTION server_cs;
 static CRITICAL_SECTION_DEBUG server_cs_debug =
@@ -90,6 +100,7 @@ static CRITICAL_SECTION listen_cs = { &listen_cs_debug, -1, 0, 0, 0, 0 };
 
 /* whether the server is currently listening */
 static BOOL std_listen;
+static HANDLE server_stop_event, clients_completed_event; /* manual-reset events, created by the first RpcServerListen */
 /* number of manual listeners (calls to RpcServerListen) */
 static LONG manual_listen_count;
 /* total listeners including auto listeners */
@@ -304,14 +315,6 @@ fail:
   HeapFree(GetProcessHeap(), 0, msg);
 }
 
-static DWORD CALLBACK RPCRT4_worker_thread(LPVOID the_arg)
-{
-  RpcPacket *pkt = the_arg;
-  RPCRT4_process_packet(pkt->conn, pkt->hdr, pkt->msg);
-  HeapFree(GetProcessHeap(), 0, pkt);
-  return 0;
-}
-
 static DWORD CALLBACK RPCRT4_io_thread(LPVOID the_arg)
 {
   RpcConnection* conn = (RpcConnection*)the_arg;
@@ -319,10 +322,14 @@ static DWORD CALLBACK RPCRT4_io_thread(LPVOID the_arg)
   RpcBinding *pbind;
   RPC_MESSAGE *msg;
   RPC_STATUS status;
-  RpcPacket *packet;
 
   TRACE("(%p)\n", conn);
 
+  EnterCriticalSection(&client_connections_cs);
+  list_add_head(&client_connections, &conn->client_entry); /* lets RpcMgmtWaitServerListen signal this connection */
+  ResetEvent(clients_completed_event); /* at least one client is active now */
+  LeaveCriticalSection(&client_connections_cs);
+
   for (;;) {
     msg = HeapAlloc(GetProcessHeap(), HEAP_ZERO_MEMORY, sizeof(RPC_MESSAGE));
 
@@ -338,17 +345,17 @@ static DWORD CALLBACK RPCRT4_io_thread(LPVOID the_arg)
       break;
     }
 
-#if 0
     RPCRT4_process_packet(conn, hdr, msg);
-#else
-    packet = HeapAlloc(GetProcessHeap(), 0, sizeof(RpcPacket));
-    packet->conn = conn;
-    packet->hdr = hdr;
-    packet->msg = msg;
-    QueueUserWorkItem(RPCRT4_worker_thread, packet, WT_EXECUTELONGFUNCTION);
-#endif
-    msg = NULL;
   }
+
+  EnterCriticalSection(&client_connections_cs);
+  list_remove(&conn->client_entry); /* removed before destruction, so signal_to_stop never sees a closed connection */
+  if (list_empty(&client_connections)) {
+    TRACE("last in the list to complete (%p)\n", conn);
+    SetEvent(clients_completed_event); /* releases RpcMgmtWaitServerListen */
+  }
+  LeaveCriticalSection(&client_connections_cs);
+
   RPCRT4_DestroyConnection(conn);
   return 0;
 }
@@ -361,10 +368,6 @@ void RPCRT4_new_client(RpcConnection* conn)
     ERR("failed to create thread, error=%08x\n", err);
     RPCRT4_DestroyConnection(conn);
   }
-  /* we could set conn->thread, but then we'd have to make the io_thread wait
-   * for that, otherwise the thread might finish, destroy the connection, and
-   * free the memory we'd write to before we did, causing crashes and stuff -
-   * so let's implement that later, when we really need conn->thread */
 
   CloseHandle( thread );
 }
@@ -971,6 +974,14 @@ RPC_STATUS WINAPI RpcServerListen( UINT MinimumCallThreads, UINT MaxCalls, UINT
   if (list_empty(&protseqs))
     return RPC_S_NO_PROTSEQS_REGISTERED;
 
+  EnterCriticalSection(&client_connections_cs);
+  if (server_stop_event == NULL)
+    server_stop_event = CreateEventW(NULL, TRUE, FALSE, NULL); /* manual-reset: every waiter must wake on stop */
+  ResetEvent(server_stop_event);
+  if (clients_completed_event == NULL)
+    clients_completed_event = CreateEventW(NULL, TRUE, TRUE, NULL); /* manual-reset, initially signalled: no clients yet */
+  LeaveCriticalSection(&client_connections_cs);
+
   status = RPCRT4_start_listen(FALSE);
 
   if (DontWait || (status != RPC_S_OK)) return status;
@@ -983,6 +994,8 @@ RPC_STATUS WINAPI RpcServerListen( UINT MinimumCallThreads, UINT MaxCalls, UINT
  */
 RPC_STATUS WINAPI RpcMgmtWaitServerListen( void )
 {
+  RpcConnection *conn;
+
   TRACE("()\n");
 
   EnterCriticalSection(&listen_cs);
@@ -994,7 +1007,15 @@ RPC_STATUS WINAPI RpcMgmtWaitServerListen( void )
   
   LeaveCriticalSection(&listen_cs);
 
-  FIXME("not waiting for server calls to finish\n");
+  WaitForSingleObject(server_stop_event, INFINITE); /* NOTE(review): NULL (returns at once) if RpcServerListen was never called */
+
+  EnterCriticalSection(&client_connections_cs);
+  LIST_FOR_EACH_ENTRY(conn, &client_connections, RpcConnection, client_entry) {
+    conn->ops->signal_to_stop(conn); /* wake io threads blocked waiting for a new packet */
+  }
+  LeaveCriticalSection(&client_connections_cs);
+
+  WaitForSingleObject(clients_completed_event, INFINITE); /* until the last io thread has left client_connections */
 
   return RPC_S_OK;
 }
@@ -1012,6 +1033,8 @@ RPC_STATUS WINAPI RpcMgmtStopServerListening ( RPC_BINDING_HANDLE Binding )
   }
   
   RPCRT4_stop_listen(FALSE);
+  if (server_stop_event)
+    SetEvent(server_stop_event);
 
   return RPC_S_OK;
 }
diff --git a/dlls/rpcrt4/rpc_transport.c b/dlls/rpcrt4/rpc_transport.c
index b4af8e9..600a8a2 100644
--- a/dlls/rpcrt4/rpc_transport.c
+++ b/dlls/rpcrt4/rpc_transport.c
@@ -99,6 +99,7 @@ typedef struct _RpcConnection_np
   HANDLE pipe;
   OVERLAPPED ovl;
   BOOL listening;
+  HANDLE stop_event; /* set by signal_to_stop; may be NULL */
 } RpcConnection_np;
 
 static RpcConnection *rpcrt4_conn_np_alloc(void)
@@ -109,6 +110,7 @@ static RpcConnection *rpcrt4_conn_np_alloc(void)
     npc->pipe = NULL;
     memset(&npc->ovl, 0, sizeof(npc->ovl));
     npc->listening = FALSE;
+    npc->stop_event = CreateEventW(NULL, FALSE, FALSE, NULL); /* NULL on failure: read() then skips the stop check */
   }
   return &npc->common;
 }
@@ -363,12 +365,22 @@ static RPC_STATUS rpcrt4_ncalrpc_handoff(RpcConnection *old_conn, RpcConnection
 }
 
 static int rpcrt4_conn_np_read(RpcConnection *Connection,
-                        void *buffer, unsigned int count)
+                        void *buffer, unsigned int count, BOOL check_stop_event)
 {
   RpcConnection_np *npc = (RpcConnection_np *) Connection;
   char *buf = buffer;
   BOOL ret = TRUE;
   unsigned int bytes_left = count;
+  HANDLE objects[2] = {npc->pipe, npc->stop_event};
+  BOOL check_stop = check_stop_event && npc->stop_event;
+  DWORD wait_res = WAIT_OBJECT_0;
+
+  if (check_stop) wait_res = WaitForMultipleObjects(2, objects, FALSE, INFINITE); /* otherwise ReadFile below blocks by itself */
+  if (wait_res == WAIT_OBJECT_0 + 1) {
+    TRACE("noticed a stop event for %p\n", npc);
+    return 0;
+  } else if (wait_res != WAIT_OBJECT_0)
+    return -1;
 
   while (bytes_left)
   {
@@ -414,6 +426,15 @@ static int rpcrt4_conn_np_close(RpcConnection *Connection)
     CloseHandle(npc->ovl.hEvent);
     npc->ovl.hEvent = 0;
   }
+  if (npc->stop_event) { CloseHandle(npc->stop_event); npc->stop_event = 0; }
+  return 0;
+}
+
+static int rpcrt4_conn_np_signal_to_stop(RpcConnection *Connection)
+{
+  RpcConnection_np *npc = (RpcConnection_np *) Connection;
+  TRACE("(%p)\n", npc);
+  if (npc->stop_event) SetEvent(npc->stop_event);
   return 0;
 }
 
@@ -704,6 +725,7 @@ typedef struct _RpcConnection_tcp
 {
   RpcConnection common;
   int sock;
+  int stop_event[2]; /* pipe: [0] polled by read(), [1] written by signal_to_stop; -1 when unavailable */
 } RpcConnection_tcp;
 
 static RpcConnection *rpcrt4_conn_tcp_alloc(void)
@@ -713,6 +735,7 @@ static RpcConnection *rpcrt4_conn_tcp_alloc(void)
   if (tcpc == NULL)
     return NULL;
   tcpc->sock = -1;
+  if (pipe(tcpc->stop_event) != 0) tcpc->stop_event[0] = tcpc->stop_event[1] = -1; /* poll() ignores fd -1 */
   return &tcpc->common;
 }
 
@@ -941,10 +964,22 @@ static RPC_STATUS rpcrt4_conn_tcp_handoff(RpcConnection *old_conn, RpcConnection
 }
 
 static int rpcrt4_conn_tcp_read(RpcConnection *Connection,
-                                void *buffer, unsigned int count)
+                                void *buffer, unsigned int count, BOOL check_stop_event)
 {
   RpcConnection_tcp *tcpc = (RpcConnection_tcp *) Connection;
-  int r = recv(tcpc->sock, buffer, count, MSG_WAITALL);
+  struct pollfd fds[2] = {{tcpc->sock, POLLIN, 0}, {tcpc->stop_event[0], POLLIN, 0}};
+  int num_fds = check_stop_event ? 2 : 1;
+  int poll_res, r;
+
+  do { poll_res = poll(fds, num_fds, -1); } while (poll_res < 0 && errno == EINTR); /* a signal must not drop the connection */
+  if (poll_res < 0)
+    return -1;
+  else if (fds[1].revents & POLLIN) {
+    TRACE("noticed a stop event for %p\n", tcpc);
+    return 0;
+  }
+
+  r = recv(tcpc->sock, buffer, count, MSG_WAITALL);
   TRACE("%d %p %u -> %d\n", tcpc->sock, buffer, count, r);
   return r;
 }
@@ -967,6 +1002,19 @@ static int rpcrt4_conn_tcp_close(RpcConnection *Connection)
   if (tcpc->sock != -1)
     close(tcpc->sock);
   tcpc->sock = -1;
+
+  if (tcpc->stop_event[0] != -1) close(tcpc->stop_event[0]);
+  if (tcpc->stop_event[1] != -1) close(tcpc->stop_event[1]);
+  tcpc->stop_event[0] = tcpc->stop_event[1] = -1; /* like sock: a second close must not hit a reused fd */
+  return 0;
+}
+
+static int rpcrt4_conn_tcp_signal_to_stop(RpcConnection *Connection)
+{
+  RpcConnection_tcp *tcpc = (RpcConnection_tcp *) Connection;
+  char c = 0;
+  TRACE("(%p)\n", tcpc);
+  if (tcpc->stop_event[1] != -1) write(tcpc->stop_event[1], &c, 1); /* byte is never drained: later checked reads also return 0 */
   return 0;
 }
 
@@ -1250,6 +1298,7 @@ static const struct connection_ops conn_protseq_list[] = {
     rpcrt4_ncacn_np_open,
     rpcrt4_ncacn_np_handoff,
     rpcrt4_conn_np_read,
+    rpcrt4_conn_np_signal_to_stop,
     rpcrt4_conn_np_write,
     rpcrt4_conn_np_close,
     rpcrt4_ncacn_np_get_top_of_tower,
@@ -1261,6 +1310,7 @@ static const struct connection_ops conn_protseq_list[] = {
     rpcrt4_ncalrpc_open,
     rpcrt4_ncalrpc_handoff,
     rpcrt4_conn_np_read,
+    rpcrt4_conn_np_signal_to_stop,
     rpcrt4_conn_np_write,
     rpcrt4_conn_np_close,
     rpcrt4_ncalrpc_get_top_of_tower,
@@ -1272,6 +1322,7 @@ static const struct connection_ops conn_protseq_list[] = {
     rpcrt4_ncacn_ip_tcp_open,
     rpcrt4_conn_tcp_handoff,
     rpcrt4_conn_tcp_read,
+    rpcrt4_conn_tcp_signal_to_stop,
     rpcrt4_conn_tcp_write,
     rpcrt4_conn_tcp_close,
     rpcrt4_ncacn_ip_tcp_get_top_of_tower,
diff --git a/dlls/rpcrt4/tests/rpc.c b/dlls/rpcrt4/tests/rpc.c
index 023741f..bc7f627 100644
--- a/dlls/rpcrt4/tests/rpc.c
+++ b/dlls/rpcrt4/tests/rpc.c
@@ -186,6 +186,14 @@ static void test_rpc_ncacn_ip_tcp(void)
 
     status = RpcNetworkIsProtseqValid(ncacn_ip_tcp);
     ok(status == RPC_S_OK, "return wrong\n");
+
+    status = RpcMgmtWaitServerListen(); /* expected to fail: the server is not listening at this point */
+    ok(status == RPC_S_NOT_LISTENING,
+       "wrong RpcMgmtWaitServerListen error status (%lu)\n", status);
+
+    status = RpcServerListen(1, 20, FALSE); /* expected to fail: no protocol sequence has been registered */
+    ok(status == RPC_S_NO_PROTSEQS_REGISTERED,
+       "wrong RpcServerListen error status (%lu)\n", status);
 }
 
 /* this is what's generated with MS/RPC - it includes an extra 2



More information about the wine-patches mailing list