rpcrt4: Add tests for RPC calls with authentication.

Hans Leidekker hans at codeweavers.com
Mon Dec 14 04:54:36 CST 2009


---
 dlls/rpcrt4/tests/Makefile.in |    2 +-
 dlls/rpcrt4/tests/server.c    |  159 ++++++++++++++++++++++++++++++++++++-----
 dlls/rpcrt4/tests/server.idl  |    2 +
 3 files changed, 145 insertions(+), 18 deletions(-)

diff --git a/dlls/rpcrt4/tests/Makefile.in b/dlls/rpcrt4/tests/Makefile.in
index 69bbd62..f7a2010 100644
--- a/dlls/rpcrt4/tests/Makefile.in
+++ b/dlls/rpcrt4/tests/Makefile.in
@@ -3,7 +3,7 @@ TOPOBJDIR = ../../..
 SRCDIR    = @srcdir@
 VPATH     = @srcdir@
 TESTDLL   = rpcrt4.dll
-IMPORTS   = ole32 rpcrt4 kernel32
+IMPORTS   = ole32 rpcrt4 advapi32 kernel32
 EXTRAIDLFLAGS = --prefix-server=s_
 
 IDL_C_SRCS = server.idl
diff --git a/dlls/rpcrt4/tests/server.c b/dlls/rpcrt4/tests/server.c
index 2a32b45..a871583 100644
--- a/dlls/rpcrt4/tests/server.c
+++ b/dlls/rpcrt4/tests/server.c
@@ -19,6 +19,8 @@
  */
 
 #include <windows.h>
+#include <secext.h>
+#include <rpcdce.h>
 #include "wine/test.h"
 #include "server.h"
 #include "server_defines.h"
@@ -41,6 +43,12 @@ static void (WINAPI *pNDRSContextMarshall2)(RPC_BINDING_HANDLE, NDR_SCONTEXT, vo
 static NDR_SCONTEXT (WINAPI *pNDRSContextUnmarshall2)(RPC_BINDING_HANDLE, void*, ULONG, void*, ULONG);
 static RPC_STATUS (WINAPI *pRpcServerRegisterIfEx)(RPC_IF_HANDLE,UUID*, RPC_MGR_EPV*, unsigned int,
                    unsigned int,RPC_IF_CALLBACK_FN*);
+static BOOLEAN (WINAPI *pGetUserNameExA)(EXTENDED_NAME_FORMAT, LPSTR, PULONG); /* from secur32.dll, may be NULL */
+static RPC_STATUS (WINAPI *pRpcBindingSetAuthInfoExA)(RPC_BINDING_HANDLE, RPC_CSTR, ULONG, ULONG,
+                                                      RPC_AUTH_IDENTITY_HANDLE, ULONG, RPC_SECURITY_QOS *);
+static RPC_STATUS (WINAPI *pRpcServerRegisterAuthInfoA)(RPC_CSTR, ULONG, RPC_AUTH_KEY_RETRIEVAL_FN, LPVOID);
+
+static char *domain_and_user; /* current user name (NameSamCompatible form); NULL unless set in START_TEST */
 
 /* type check statements generated in header file */
 fnprintf *p_printf = printf;
@@ -48,10 +56,15 @@ fnprintf *p_printf = printf;
 static void InitFunctionPointers(void)
 {
     HMODULE hrpcrt4 = GetModuleHandleA("rpcrt4.dll");
+    HMODULE hsecur32 = LoadLibraryA("secur32.dll");
 
     pNDRSContextMarshall2 = (void *)GetProcAddress(hrpcrt4, "NDRSContextMarshall2");
     pNDRSContextUnmarshall2 = (void *)GetProcAddress(hrpcrt4, "NDRSContextUnmarshall2");
     pRpcServerRegisterIfEx = (void *)GetProcAddress(hrpcrt4, "RpcServerRegisterIfEx");
+    pRpcBindingSetAuthInfoExA = (void *)GetProcAddress(hrpcrt4, "RpcBindingSetAuthInfoExA");
+    pRpcServerRegisterAuthInfoA = (void *)GetProcAddress(hrpcrt4, "RpcServerRegisterAuthInfoA");
+    /* secur32.dll may fail to load; only look up GetUserNameExA with a valid module handle */
+    if (hsecur32) pGetUserNameExA = (void *)GetProcAddress(hsecur32, "GetUserNameExA");
 
     if (!pNDRSContextMarshall2) old_windows_version = TRUE;
 }
@@ -1307,6 +1319,79 @@ array_tests(void)
   HeapFree(GetProcessHeap(), 0, pi);
 }
 
+/* Server side of authinfo_test(): check the authentication information the
+ * RPC runtime reports for the current call.  It must be present when the
+ * client requested it (secure != 0) or the call came in over ncalrpc;
+ * otherwise RpcBindingInqAuthClientA must fail with RPC_S_BINDING_HAS_NO_AUTH
+ * and leave its output parameters untouched. */
+void
+s_authinfo_test(unsigned int protseq, int secure)
+{
+    RPC_BINDING_HANDLE binding;
+    RPC_STATUS status;
+    ULONG level, authnsvc;
+    RPC_AUTHZ_HANDLE privs;
+    unsigned char *principal;
+
+    binding = I_RpcGetCurrentCallHandle();
+    ok(binding != NULL, "I_RpcGetCurrentCallHandle returned NULL\n");
+
+    /* sentinel values, so we can tell which outputs were written */
+    level = authnsvc = 0xdeadbeef;
+    privs = (RPC_AUTHZ_HANDLE)0xdeadbeef;
+    principal = (unsigned char *)0xdeadbeef;
+
+    if (secure || protseq == RPC_PROTSEQ_LRPC)
+    {
+        status = RpcImpersonateClient(NULL);
+        ok(status == RPC_S_OK, "expected RPC_S_OK got %u\n", status);
+        status = RpcRevertToSelf();
+        ok(status == RPC_S_OK, "expected RPC_S_OK got %u\n", status);
+
+        status = RpcBindingInqAuthClientA(binding, &privs, &principal, &level, &authnsvc, NULL);
+        ok(status == RPC_S_OK, "expected RPC_S_OK got %u\n", status);
+        ok(privs != (RPC_AUTHZ_HANDLE)0xdeadbeef, "privs unchanged\n");
+        ok(principal != (unsigned char *)0xdeadbeef, "principal unchanged\n");
+        if (protseq != RPC_PROTSEQ_LRPC)
+        {
+            todo_wine
+            ok(principal != NULL, "NULL principal\n");
+        }
+        /* domain_and_user stays NULL when GetUserNameExA is unavailable */
+        if (protseq == RPC_PROTSEQ_LRPC && principal && domain_and_user)
+        {
+            int len;
+            char *spn = NULL;
+
+            /* the test expects privs to hold the client's user name as a wide string */
+            len = WideCharToMultiByte(CP_ACP, 0, (const WCHAR *)privs, -1, NULL, 0, NULL, NULL);
+            if (len > 0) spn = HeapAlloc( GetProcessHeap(), 0, len );
+            ok(spn != NULL, "failed to convert privs to a multibyte string\n");
+            if (spn)
+            {
+                WideCharToMultiByte(CP_ACP, 0, (const WCHAR *)privs, -1, spn, len, NULL, NULL);
+                ok(!strcmp(domain_and_user, spn), "expected %s got %s\n", domain_and_user, spn);
+                HeapFree( GetProcessHeap(), 0, spn );
+            }
+        }
+        ok(level == RPC_C_AUTHN_LEVEL_PKT_PRIVACY, "level unchanged\n");
+        ok(authnsvc == RPC_C_AUTHN_WINNT, "authnsvc unchanged\n");
+
+        /* the runtime allocates the principal name string; the caller must free it */
+        if (status == RPC_S_OK && principal && principal != (unsigned char *)0xdeadbeef)
+            RpcStringFreeA(&principal);
+    }
+    else
+    {
+        status = RpcBindingInqAuthClientA(binding, &privs, &principal, &level, &authnsvc, NULL);
+        ok(status == RPC_S_BINDING_HAS_NO_AUTH, "expected RPC_S_BINDING_HAS_NO_AUTH got %u\n", status);
+        ok(privs == (RPC_AUTHZ_HANDLE)0xdeadbeef, "got %p\n", privs);
+        ok(principal == (unsigned char *)0xdeadbeef, "got %s\n", principal);
+        ok(level == 0xdeadbeef, "got %u\n", level);
+        ok(authnsvc == 0xdeadbeef, "got %u\n", authnsvc);
+    }
+}
+
 static void
 run_tests(void)
 {
@@ -1318,48 +1388,96 @@ run_tests(void)
 }
 
+/* Ask for packet-privacy NTLM (RPC_C_AUTHN_WINNT) authentication on a client
+ * binding, passing the current user (domain_and_user) as the principal name. */
 static void
+set_auth_info(RPC_BINDING_HANDLE handle)
+{
+    RPC_STATUS status;
+    RPC_SECURITY_QOS qos;
+
+    if (!pRpcBindingSetAuthInfoExA)
+    {
+        skip("RpcBindingSetAuthInfoExA not available\n");
+        return;
+    }
+
+    qos.Version = 1;
+    qos.Capabilities = RPC_C_QOS_CAPABILITIES_MUTUAL_AUTH;
+    qos.IdentityTracking = RPC_C_QOS_IDENTITY_STATIC;
+    qos.ImpersonationType = RPC_C_IMP_LEVEL_IMPERSONATE;
+
+    status = pRpcBindingSetAuthInfoExA(handle, (RPC_CSTR)domain_and_user, RPC_C_AUTHN_LEVEL_PKT_PRIVACY,
+                                       RPC_C_AUTHN_WINNT, NULL, 0, &qos);
+    ok(status == RPC_S_OK, "RpcBindingSetAuthInfoExA failed %u\n", status);
+}
+
+/* Client side of the tests; @test names the scenario, e.g. "tcp_basic" or
+ * "ncalrpc_secure" (the "_secure" variants call set_auth_info first). */
+static void
 client(const char *test)
 {
+  static unsigned char iptcp[] = "ncacn_ip_tcp";
+  static unsigned char np[] = "ncacn_np";
+  static unsigned char ncalrpc[] = "ncalrpc";
+  static unsigned char address[] = "127.0.0.1";
+  static unsigned char address_np[] = "\\\\.";
+  static unsigned char port[] = PORT;
+  static unsigned char pipe[] = PIPE;
+  static unsigned char guid[] = "00000000-4114-0704-2301-000000000000";
+
+  unsigned char *binding;
+
   if (strcmp(test, "tcp_basic") == 0)
   {
-    static unsigned char iptcp[] = "ncacn_ip_tcp";
-    static unsigned char address[] = "127.0.0.1";
-    static unsigned char port[] = PORT;
-    unsigned char *binding;
-
     ok(RPC_S_OK == RpcStringBindingCompose(NULL, iptcp, address, port, NULL, &binding), "RpcStringBindingCompose\n");
     ok(RPC_S_OK == RpcBindingFromStringBinding(binding, &IServer_IfHandle), "RpcBindingFromStringBinding\n");
 
     run_tests();
+    authinfo_test(RPC_PROTSEQ_TCP, 0);
+
+    ok(RPC_S_OK == RpcStringFree(&binding), "RpcStringFree\n");
+    ok(RPC_S_OK == RpcBindingFree(&IServer_IfHandle), "RpcBindingFree\n");
+  }
+  else if (strcmp(test, "tcp_secure") == 0)
+  {
+    ok(RPC_S_OK == RpcStringBindingCompose(NULL, iptcp, address, port, NULL, &binding), "RpcStringBindingCompose\n");
+    ok(RPC_S_OK == RpcBindingFromStringBinding(binding, &IServer_IfHandle), "RpcBindingFromStringBinding\n");
+
+    set_auth_info(IServer_IfHandle);
+    authinfo_test(RPC_PROTSEQ_TCP, 1);
 
     ok(RPC_S_OK == RpcStringFree(&binding), "RpcStringFree\n");
     ok(RPC_S_OK == RpcBindingFree(&IServer_IfHandle), "RpcBindingFree\n");
   }
   else if (strcmp(test, "ncalrpc_basic") == 0)
   {
-    static unsigned char ncalrpc[] = "ncalrpc";
-    static unsigned char guid[] = "00000000-4114-0704-2301-000000000000";
-    unsigned char *binding;
+    ok(RPC_S_OK == RpcStringBindingCompose(NULL, ncalrpc, NULL, guid, NULL, &binding), "RpcStringBindingCompose\n");
+    ok(RPC_S_OK == RpcBindingFromStringBinding(binding, &IServer_IfHandle), "RpcBindingFromStringBinding\n");
+
+    run_tests(); /* can cause RPC_X_BAD_STUB_DATA exception */
+    authinfo_test(RPC_PROTSEQ_LRPC, 0);
 
+    ok(RPC_S_OK == RpcStringFree(&binding), "RpcStringFree\n");
+    ok(RPC_S_OK == RpcBindingFree(&IServer_IfHandle), "RpcBindingFree\n");
+  }
+  else if (strcmp(test, "ncalrpc_secure") == 0)
+  {
     ok(RPC_S_OK == RpcStringBindingCompose(NULL, ncalrpc, NULL, guid, NULL, &binding), "RpcStringBindingCompose\n");
     ok(RPC_S_OK == RpcBindingFromStringBinding(binding, &IServer_IfHandle), "RpcBindingFromStringBinding\n");
 
-    run_tests();
+    set_auth_info(IServer_IfHandle);
+    authinfo_test(RPC_PROTSEQ_LRPC, 1);
 
     ok(RPC_S_OK == RpcStringFree(&binding), "RpcStringFree\n");
     ok(RPC_S_OK == RpcBindingFree(&IServer_IfHandle), "RpcBindingFree\n");
   }
   else if (strcmp(test, "np_basic") == 0)
   {
-    static unsigned char np[] = "ncacn_np";
-    static unsigned char address[] = "\\\\.";
-    static unsigned char pipe[] = PIPE;
-    unsigned char *binding;
-
-    ok(RPC_S_OK == RpcStringBindingCompose(NULL, np, address, pipe, NULL, &binding), "RpcStringBindingCompose\n");
+    ok(RPC_S_OK == RpcStringBindingCompose(NULL, np, address_np, pipe, NULL, &binding), "RpcStringBindingCompose\n");
     ok(RPC_S_OK == RpcBindingFromStringBinding(binding, &IServer_IfHandle), "RpcBindingFromStringBinding\n");
 
     run_tests();
+    authinfo_test(RPC_PROTSEQ_NMP, 0);
     stop();
 
     ok(RPC_S_OK == RpcStringFree(&binding), "RpcStringFree\n");
@@ -1409,12 +1517,22 @@ server(void)
   if (iptcp_status == RPC_S_OK)
     run_client("tcp_basic");
   else
-    skip("tcp_basic tests skipped due to earlier failure\n");
+    skip("tcp tests skipped due to earlier failure\n");
 
   if (ncalrpc_status == RPC_S_OK)
+  {
     run_client("ncalrpc_basic");
+    /* the secure client passes domain_and_user to set_auth_info as principal name */
+    if (domain_and_user)
+    {
+      /* we don't need to register RPC_C_AUTHN_WINNT for ncalrpc */
+      run_client("ncalrpc_secure");
+    }
+    else
+      skip("ncalrpc_secure tests skipped, no user name available\n");
+  }
   else
-    skip("ncalrpc_basic tests skipped due to earlier failure\n");
+    skip("lrpc tests skipped due to earlier failure\n");
 
   if (np_status == RPC_S_OK)
     run_client("np_basic");
@@ -1445,6 +1560,25 @@ START_TEST(server)
 
   InitFunctionPointers();
 
+  /* current user name in NameSamCompatible form, used by the authinfo tests */
+  if (pGetUserNameExA)
+  {
+    ULONG size = 0;
+    BOOL ret;
+
+    /* the first call only queries the required buffer size */
+    ok(!pGetUserNameExA(NameSamCompatible, NULL, &size), "GetUserNameExA\n");
+    if (size) domain_and_user = HeapAlloc(GetProcessHeap(), 0, size);
+    ret = domain_and_user && pGetUserNameExA(NameSamCompatible, domain_and_user, &size);
+    ok(ret, "GetUserNameExA\n");
+    if (!ret && domain_and_user)
+    {
+      /* don't leave a buffer with undefined contents behind */
+      HeapFree(GetProcessHeap(), 0, domain_and_user);
+      domain_and_user = NULL;
+    }
+  }
+
   argc = winetest_get_mainargs(&argv);
   progname = argv[0];
 
@@ -1462,4 +1585,6 @@ START_TEST(server)
   }
   else
     server();
+
+  HeapFree(GetProcessHeap(), 0, domain_and_user); /* NULL if GetUserNameExA is not available */
 }
diff --git a/dlls/rpcrt4/tests/server.idl b/dlls/rpcrt4/tests/server.idl
index ee03c0b..29b90b5 100644
--- a/dlls/rpcrt4/tests/server.idl
+++ b/dlls/rpcrt4/tests/server.idl
@@ -361,5 +361,7 @@ cpp_quote("#endif")
   void full_pointer_test([in, ptr] int *a, [in, ptr] int *b);
   void full_pointer_null_test([in, ptr] int *a, [in, ptr] int *b);
 
+  void authinfo_test(unsigned int protseq, int secure);
+
   void stop(void);
 }
-- 
1.6.3.3




More information about the wine-patches mailing list