[PATCH v2 19/22] ntdll: Use shared loader locking in LdrGetProcedureAddress() when possible.

Paul Gofman pgofman at codeweavers.com
Tue Oct 5 17:49:25 CDT 2021


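Try the lookup under the shared loader lock first. If resolving a
forwarded export would require loading a module that is not loaded yet,
find_forwarded_export() reports STATUS_NOT_FOUND and
LdrGetProcedureAddress() retries the lookup under the exclusive lock.

Signed-off-by: Paul Gofman <pgofman at codeweavers.com>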
---
 dlls/kernel32/tests/loader.c | 24 ++++++++++++++
 dlls/ntdll/loader.c          | 62 ++++++++++++++++++++++++++----------
 2 files changed, 70 insertions(+), 16 deletions(-)

diff --git a/dlls/kernel32/tests/loader.c b/dlls/kernel32/tests/loader.c
index 9f502aaf857..ba979a26a32 100644
--- a/dlls/kernel32/tests/loader.c
+++ b/dlls/kernel32/tests/loader.c
@@ -4337,6 +4337,10 @@ static void test_loader_lock(void)
     ok(bret, "GetModuleHandleEx failed, err %u.\n", GetLastError());
     ok(hmodule == hmodule_preloaded, "Got unexpected hmodule %p, expected %p.\n", hmodule, hmodule_preloaded);
 
+    proc = GetProcAddress(hmodule, "timeGetTime");
+    check_timeout(FALSE);
+    ok(!!proc, "GetProcAddress failed.\n");
+
     status = pLdrUnloadDll(hmodule);
     ok(!status, "Got unexpected status %#x.\n", status);
     if (test_loader_lock_timeout_count)
@@ -4417,6 +4421,10 @@ static void test_loader_lock(void)
     ok(hmodule == hmodule_preloaded, "Got unexpected hmodule %p, expected %p.\n", hmodule, hmodule_preloaded);
     check_timeout(FALSE);
 
+    proc = GetProcAddress(hmodule, "timeGetTime");
+    ok(!!proc, "GetProcAddress failed.\n");
+    check_timeout(FALSE);
+
     if (!blocks_on_decref_library)
     {
         bret = FreeLibrary(hmodule);
@@ -4426,6 +4434,22 @@ static void test_loader_lock(void)
 
     ok(!!lock_dll_handle, "Got NULL lock_dll_handle.\n");
 
+    if (BLOCKING_TESTS_ENABLED)
+    {
+        proc = GetProcAddress(lock_dll_handle, "test_proc");
+        if (GetLastError() == ERROR_MOD_NOT_FOUND && !blocks_on_load_in_progress_module)
+        {
+            /* Win 10 1507-1709. */
+            check_timeout(TRUE);
+        }
+        else
+        {
+            check_timeout(FALSE);
+        }
+        ok(!proc, "GetProcAddress failed, err %u.\n", GetLastError());
+        ok(GetLastError() == ERROR_INVALID_PARAMETER || broken(GetLastError() == ERROR_MOD_NOT_FOUND),
+                "Got unexpected error %u.\n", GetLastError());
+    }
     status = LdrAddRefDll(0, lock_dll_handle);
     check_timeout(FALSE);
     ok(status == STATUS_DLL_NOT_FOUND, "Got unexpected status %#x.\n", status);
diff --git a/dlls/ntdll/loader.c b/dlls/ntdll/loader.c
index a503b577027..a044b9cce27 100644
--- a/dlls/ntdll/loader.c
+++ b/dlls/ntdll/loader.c
@@ -217,9 +217,9 @@ static NTSTATUS load_dll( const WCHAR *load_path, const WCHAR *libname, const WC
                           DWORD flags, WINE_MODREF** pwm );
 static NTSTATUS process_attach( WINE_MODREF *wm, LPVOID lpReserved );
 static FARPROC find_ordinal_export( HMODULE module, const IMAGE_EXPORT_DIRECTORY *exports,
-                                    DWORD exp_size, DWORD ordinal, LPCWSTR load_path );
+                                    DWORD exp_size, DWORD ordinal, LPCWSTR load_path, NTSTATUS *status );
 static FARPROC find_named_export( HMODULE module, const IMAGE_EXPORT_DIRECTORY *exports,
-                                  DWORD exp_size, const char *name, int hint, LPCWSTR load_path );
+                                  DWORD exp_size, const char *name, int hint, LPCWSTR load_path, NTSTATUS *status );
 
 /* convert PE image VirtualAddress to Real Address */
 static inline void *get_rva( HMODULE module, DWORD va )
@@ -921,7 +921,7 @@ static WINE_MODREF **grow_module_deps( WINE_MODREF *wm, int count )
  * Find the final function pointer for a forwarded function.
  * The loader must be locked while calling this function.
  */
-static FARPROC find_forwarded_export( HMODULE module, const char *forward, LPCWSTR load_path )
+static FARPROC find_forwarded_export( HMODULE module, const char *forward, LPCWSTR load_path, NTSTATUS *status )
 {
     const IMAGE_EXPORT_DIRECTORY *exports;
     DWORD exp_size;
@@ -944,6 +944,14 @@ static FARPROC find_forwarded_export( HMODULE module, const char *forward, LPCWS
 
     if (!(wm = find_basename_module( mod_name )))
     {
+        if (!locked_exclusive)
+        {
+            TRACE( "Need to load %s for '%s' while not in exclusive lock.\n", debugstr_w(mod_name), forward );
+            assert( status );
+            if (mod_name != buffer) RtlFreeHeap( GetProcessHeap(), 0, mod_name );
+            if (status) *status = STATUS_NOT_FOUND;
+            return NULL;
+        }
         TRACE( "delay loading %s for '%s'\n", debugstr_w(mod_name), forward );
         if (load_dll( load_path, mod_name, L".dll", 0, &wm ) == STATUS_SUCCESS &&
             !(wm->ldr.Flags & LDR_DONT_RESOLVE_REFS))
@@ -975,9 +983,9 @@ static FARPROC find_forwarded_export( HMODULE module, const char *forward, LPCWS
 
         if (*name == '#') { /* ordinal */
             proc = find_ordinal_export( wm->ldr.DllBase, exports, exp_size,
-                                        atoi(name+1) - exports->Base, load_path );
+                                        atoi(name+1) - exports->Base, load_path, status );
         } else
-            proc = find_named_export( wm->ldr.DllBase, exports, exp_size, name, -1, load_path );
+            proc = find_named_export( wm->ldr.DllBase, exports, exp_size, name, -1, load_path, status );
     }
 
     if (!proc)
@@ -1000,11 +1008,13 @@ static FARPROC find_forwarded_export( HMODULE module, const char *forward, LPCWS
  * The loader must be locked while calling this function.
  */
 static FARPROC find_ordinal_export( HMODULE module, const IMAGE_EXPORT_DIRECTORY *exports,
-                                    DWORD exp_size, DWORD ordinal, LPCWSTR load_path )
+                                    DWORD exp_size, DWORD ordinal, LPCWSTR load_path, NTSTATUS *status )
 {
     FARPROC proc;
     const DWORD *functions = get_rva( module, exports->AddressOfFunctions );
 
+    if (status) *status = STATUS_PROCEDURE_NOT_FOUND;
+
     if (ordinal >= exports->NumberOfFunctions)
     {
         TRACE("	ordinal %d out of range!\n", ordinal + exports->Base );
@@ -1017,7 +1027,7 @@ static FARPROC find_ordinal_export( HMODULE module, const IMAGE_EXPORT_DIRECTORY
     /* if the address falls into the export dir, it's a forward */
     if (((const char *)proc >= (const char *)exports) && 
         ((const char *)proc < (const char *)exports + exp_size))
-        return find_forwarded_export( module, (const char *)proc, load_path );
+        return find_forwarded_export( module, (const char *)proc, load_path, status );
 
     if (TRACE_ON(snoop))
     {
@@ -1063,7 +1073,7 @@ static int find_name_in_exports( HMODULE module, const IMAGE_EXPORT_DIRECTORY *e
  * The loader must be locked while calling this function.
  */
 static FARPROC find_named_export( HMODULE module, const IMAGE_EXPORT_DIRECTORY *exports,
-                                  DWORD exp_size, const char *name, int hint, LPCWSTR load_path )
+                                  DWORD exp_size, const char *name, int hint, LPCWSTR load_path, NTSTATUS *status )
 {
     const WORD *ordinals = get_rva( module, exports->AddressOfNameOrdinals );
     const DWORD *names = get_rva( module, exports->AddressOfNames );
@@ -1074,12 +1084,12 @@ static FARPROC find_named_export( HMODULE module, const IMAGE_EXPORT_DIRECTORY *
     {
         char *ename = get_rva( module, names[hint] );
         if (!strcmp( ename, name ))
-            return find_ordinal_export( module, exports, exp_size, ordinals[hint], load_path );
+            return find_ordinal_export( module, exports, exp_size, ordinals[hint], load_path, status );
     }
 
     /* then do a binary search */
     if ((ordinal = find_name_in_exports( module, exports, name )) == -1) return NULL;
-    return find_ordinal_export( module, exports, exp_size, ordinal, load_path );
+    return find_ordinal_export( module, exports, exp_size, ordinal, load_path, status );
 
 }
 
@@ -1213,7 +1223,7 @@ static BOOL import_dll( HMODULE module, const IMAGE_IMPORT_DESCRIPTOR *descr, LP
             int ordinal = IMAGE_ORDINAL(import_list->u1.Ordinal);
 
             thunk_list->u1.Function = (ULONG_PTR)find_ordinal_export( imp_mod, exports, exp_size,
-                                                                      ordinal - exports->Base, load_path );
+                                                                      ordinal - exports->Base, load_path, NULL );
             if (!thunk_list->u1.Function)
             {
                 thunk_list->u1.Function = allocate_stub( name, IntToPtr(ordinal) );
@@ -1229,7 +1239,7 @@ static BOOL import_dll( HMODULE module, const IMAGE_IMPORT_DESCRIPTOR *descr, LP
             pe_name = get_rva( module, (DWORD)import_list->u1.AddressOfData );
             thunk_list->u1.Function = (ULONG_PTR)find_named_export( imp_mod, exports, exp_size,
                                                                     (const char*)pe_name->Name,
-                                                                    pe_name->Hint, load_path );
+                                                                    pe_name->Hint, load_path, NULL );
             if (!thunk_list->u1.Function)
             {
                 thunk_list->u1.Function = allocate_stub( name, (const char*)pe_name->Name );
@@ -2061,24 +2071,37 @@ NTSTATUS WINAPI LdrUnlockLoaderLock( ULONG flags, ULONG_PTR magic )
 NTSTATUS get_procedure_address( HMODULE module, const ANSI_STRING *name,
                                 ULONG ord, void **address )
 {
+    NTSTATUS status = STATUS_PROCEDURE_NOT_FOUND;
     IMAGE_EXPORT_DIRECTORY *exports;
+    WINE_MODREF *wm;
     DWORD exp_size;
 
+retry:
     /* check if the module itself is invalid to return the proper error */
-    if (!get_modref( module )) return STATUS_DLL_NOT_FOUND;
+    if (!(wm = get_modref( module ))) return STATUS_DLL_NOT_FOUND;
+
+    RtlAcquireSRWLockExclusive( &ldr_data_srw_lock );
+    if (!is_thread_exclusive() && (!wm->ldr.LoadCount || (wm->ldr.Flags & LDR_LOAD_IN_PROGRESS)))
+    {
+        RtlReleaseSRWLockExclusive( &ldr_data_srw_lock );
+        if (!wm->ldr.LoadCount) return STATUS_INVALID_PARAMETER;
+        wait_for_exclusive_lock_release();
+        goto retry;
+    }
+    RtlReleaseSRWLockExclusive( &ldr_data_srw_lock );
 
     if ((exports = RtlImageDirectoryEntryToData( module, TRUE,
                                                       IMAGE_DIRECTORY_ENTRY_EXPORT, &exp_size )))
     {
-        void *proc = name ? find_named_export( module, exports, exp_size, name->Buffer, -1, NULL )
-                          : find_ordinal_export( module, exports, exp_size, ord - exports->Base, NULL );
+        void *proc = name ? find_named_export( module, exports, exp_size, name->Buffer, -1, NULL, &status )
+                          : find_ordinal_export( module, exports, exp_size, ord - exports->Base, NULL, &status );
         if (proc)
         {
             *address = proc;
             return STATUS_SUCCESS;
         }
     }
-    return STATUS_PROCEDURE_NOT_FOUND;
+    return status;
 }
 
 /******************************************************************
@@ -2089,9 +2112,16 @@ NTSTATUS WINAPI LdrGetProcedureAddress( HMODULE module, const ANSI_STRING *name,
 {
     NTSTATUS ret;
 
+    lock_loader_shared();
+    ret = get_procedure_address( module, name, ord, address );
+    unlock_loader();
+
+    if (ret != STATUS_NOT_FOUND) return ret;
+
     lock_loader_exclusive();
     ret = get_procedure_address( module, name, ord, address );
     unlock_loader();
+
     return ret;
 }
 
-- 
2.31.1




More information about the wine-devel mailing list