[PATCH v2 05/22] ntdll: Only lock loader_section when calling application callbacks.

Paul Gofman pgofman at codeweavers.com
Tue Oct 5 17:49:11 CDT 2021


Signed-off-by: Paul Gofman <pgofman at codeweavers.com>
---
 dlls/kernel32/tests/loader.c | 193 +++++++++++++++++++++++++++++++++++
 dlls/ntdll/loader.c          |  58 ++++++++++--
 include/winternl.h           |   2 +-
 3 files changed, 246 insertions(+), 7 deletions(-)

diff --git a/dlls/kernel32/tests/loader.c b/dlls/kernel32/tests/loader.c
index 4f1b11338a6..045338c1289 100644
--- a/dlls/kernel32/tests/loader.c
+++ b/dlls/kernel32/tests/loader.c
@@ -70,6 +70,10 @@ static NTSTATUS (WINAPI *pLdrLockLoaderLock)(ULONG, ULONG *, ULONG_PTR *);
 static NTSTATUS (WINAPI *pLdrUnlockLoaderLock)(ULONG, ULONG_PTR);
 static NTSTATUS (WINAPI *pLdrLoadDll)(LPCWSTR,DWORD,const UNICODE_STRING *,HMODULE*);
 static NTSTATUS (WINAPI *pLdrUnloadDll)(HMODULE);
+
+typedef void (CALLBACK *LDRENUMPROC)(LDR_DATA_TABLE_ENTRY *, void *, BOOLEAN *);
+static NTSTATUS (WINAPI *pLdrEnumerateLoadedModules)( void *unknown, LDRENUMPROC callback, void *context );
+
 static void (WINAPI *pRtlInitUnicodeString)(PUNICODE_STRING,LPCWSTR);
 static void (WINAPI *pRtlAcquirePebLock)(void);
 static void (WINAPI *pRtlReleasePebLock)(void);
@@ -4036,6 +4040,193 @@ static void test_Wow64Transition(void)
             debugstr_wn(name->SectionFileName.Buffer, name->SectionFileName.Length / sizeof(WCHAR)));
 }
 
+static BOOL test_loader_lock_repeat_lock;
+static unsigned int test_loader_notification_count;
+static const WCHAR *ldr_notify_track_dll;
+static HANDLE lock_ready_event, next_test_event;
+static BOOL test_loader_lock_abort_test;
+static volatile LONG test_loader_lock_timeout_count;
+
+#define BLOCKING_TESTS_ENABLED 0
+
+/* Helper thread which holds the loader lock (LdrLockLoaderLock) while the main
+ * thread performs loader operations. If next_test_event is not signaled within
+ * 3 seconds, the main thread is assumed to be blocked on the loader lock: the
+ * timeout is counted in test_loader_lock_timeout_count and the lock is released. */
+static DWORD WINAPI test_loader_lock_thread(void *param)
+{
+    ULONG_PTR magic;
+    NTSTATUS status;
+    DWORD result;
+
+    WaitForSingleObject(next_test_event, INFINITE);
+    SetEvent(lock_ready_event);
+    WaitForSingleObject(next_test_event, INFINITE);
+
+    /* 1. Test with loader lock held. */
+    do
+    {
+        status = pLdrLockLoaderLock(0, NULL, &magic);
+        ok(!status, "Got unexpected status %#x.\n", status);
+        SetEvent(lock_ready_event);
+
+        result = WaitForSingleObject(next_test_event, 3000);
+        if (result == WAIT_TIMEOUT)
+            ++test_loader_lock_timeout_count;
+
+        status = pLdrUnlockLoaderLock(0, magic);
+        ok(!status, "Got unexpected status %#x.\n", status);
+
+        if (result == WAIT_TIMEOUT)
+        {
+            /* Wait until the main thread has handled the timeout. */
+            WaitForSingleObject(next_test_event, INFINITE);
+            test_loader_lock_timeout_count = 0;
+            if (test_loader_lock_abort_test)
+                return 0;
+        }
+    } while (test_loader_lock_repeat_lock);
+
+    SetEvent(lock_ready_event);
+    WaitForSingleObject(next_test_event, INFINITE);
+
+    SetEvent(lock_ready_event);
+    return 0;
+}
+
+/* LDR notification callback: counts notifications for ldr_notify_track_dll
+ * received while no loader lock timeout is pending. */
+static void CALLBACK test_loader_lock_ldr_notify(ULONG reason, LDR_DLL_NOTIFICATION_DATA *data, void *context)
+{
+    if (!test_loader_lock_timeout_count && !lstrcmpiW(data->Loaded.BaseDllName->Buffer, ldr_notify_track_dll))
+        ++test_loader_notification_count;
+}
+
+/* LdrEnumerateLoadedModules() callback which stops at the first module. */
+static void CALLBACK test_loader_lock_enum_callback(LDR_DATA_TABLE_ENTRY *mod, void *context, BOOLEAN *stop)
+{
+    *stop = TRUE;
+}
+
+/* Checks whether the helper thread timed out waiting for the last operation and,
+ * if it did, resumes the helper thread and waits until it signals readiness again. */
+#define check_timeout(expected_timeout) check_timeout_(__LINE__, expected_timeout)
+static void check_timeout_(unsigned int line, BOOL expected_timeout)
+{
+    ok_(__FILE__, line)(!!test_loader_lock_timeout_count == !!expected_timeout,
+            "Got timeout count %u, expected_timeout %d.\n",
+            test_loader_lock_timeout_count, expected_timeout);
+    if (test_loader_lock_timeout_count)
+    {
+        SetEvent(next_test_event);
+        WaitForSingleObject(lock_ready_event, INFINITE);
+    }
+}
+
+/* Makes the helper thread leave its current lock loop and enter a new one. */
+static void test_loader_lock_next_test(void)
+{
+    /* Exit previous test loop. */
+    test_loader_lock_repeat_lock = FALSE;
+    SetEvent(next_test_event);
+    WaitForSingleObject(lock_ready_event, INFINITE);
+    test_loader_lock_repeat_lock = TRUE;
+
+    /* Start new loop. */
+    SetEvent(next_test_event);
+    WaitForSingleObject(lock_ready_event, INFINITE);
+}
+
+/* Checks which loader operations block while the helper thread holds the loader
+ * lock taken with LdrLockLoaderLock(). */
+static void test_loader_lock(void)
+{
+    static const WCHAR not_loaded_dll_name[] = L"authz.dll";
+    static const WCHAR preloaded_dll_name[] = L"winmm.dll";
+    HMODULE hmodule_preloaded, hmodule;
+    NTSTATUS status;
+    HANDLE thread;
+    void *cookie;
+    BOOL bret;
+
+    /* Unnamed events, so that concurrently running tests cannot interfere. */
+    lock_ready_event = CreateEventA(NULL, FALSE, FALSE, NULL);
+    next_test_event = CreateEventA(NULL, FALSE, FALSE, NULL);
+
+    thread = CreateThread(NULL, 0, test_loader_lock_thread, NULL, 0, NULL);
+
+    hmodule_preloaded = LoadLibraryW(preloaded_dll_name);
+    ok(!!hmodule_preloaded, "LoadLibrary failed, err %u.\n", GetLastError());
+    hmodule = GetModuleHandleW(not_loaded_dll_name);
+    ok(!hmodule, "%s is already loaded.\n", wine_dbgstr_w(not_loaded_dll_name));
+
+    test_loader_notification_count = 0;
+
+    /* 1. Test with loader lock held. */
+    trace("Test 1.\n");
+    test_loader_lock_next_test();
+
+    status = LdrRegisterDllNotification(0, test_loader_lock_ldr_notify, NULL, &cookie);
+    ok(!status, "Got unexpected status %#x.\n", status);
+    if (test_loader_lock_timeout_count)
+    {
+        test_loader_lock_abort_test = TRUE;
+        SetEvent(next_test_event);
+        /* Do not leave the notification callback registered for the rest of the run. */
+        if (!status) LdrUnregisterDllNotification( cookie );
+        win_skip("Old loader, tests skipped.\n");
+        goto done;
+    }
+    check_timeout(FALSE);
+    ldr_notify_track_dll = not_loaded_dll_name;
+
+    bret = GetModuleHandleExW(0, preloaded_dll_name, &hmodule);
+    ok(bret, "GetModuleHandleEx failed, err %u.\n", GetLastError());
+    ok(hmodule == hmodule_preloaded, "Got unexpected hmodule %p, expected %p.\n", hmodule, hmodule_preloaded);
+    check_timeout(FALSE);
+
+    bret = FreeLibrary(hmodule);
+    ok(bret, "FreeLibrary failed, err %u.\n", GetLastError());
+    check_timeout(FALSE);
+
+    ok(!test_loader_notification_count, "Got unexpected test_loader_notification_count %u.\n",
+            test_loader_notification_count);
+
+    if (BLOCKING_TESTS_ENABLED)
+    {
+        /* With the loader lock held the notification callback is still called, which should mean that:
+         *  - LDR notifications themselves do not wait on loader lock;
+         *  - The library load goes far enough to call the LDR notification until it blocks on the loader lock.
+         */
+        test_loader_notification_count = 0;
+        hmodule = LoadLibraryW(not_loaded_dll_name);
+        todo_wine ok(test_loader_notification_count == 1
+                || broken(!test_loader_notification_count) /* before Win10 1607. */,
+                "Got unexpected notification count %u.\n", test_loader_notification_count);
+        check_timeout(TRUE);
+
+        bret = FreeLibrary(hmodule);
+        ok(bret, "FreeLibrary failed, err %u.\n", GetLastError());
+        check_timeout(TRUE);
+
+        pLdrEnumerateLoadedModules(NULL, test_loader_lock_enum_callback, NULL);
+        check_timeout(TRUE);
+    }
+
+    LdrUnregisterDllNotification( cookie );
+    check_timeout(FALSE);
+
+    test_loader_lock_next_test();
+
+done:
+    WaitForSingleObject(thread, INFINITE);
+    CloseHandle(thread);
+
+    FreeLibrary(hmodule_preloaded);
+    CloseHandle(lock_ready_event);
+    CloseHandle(next_test_event);
+}
+
 START_TEST(loader)
 {
     int argc;
@@ -4060,6 +4251,7 @@ START_TEST(loader)
     pLdrUnlockLoaderLock = (void *)GetProcAddress(ntdll, "LdrUnlockLoaderLock");
     pLdrLoadDll = (void *)GetProcAddress(ntdll, "LdrLoadDll");
     pLdrUnloadDll = (void *)GetProcAddress(ntdll, "LdrUnloadDll");
+    pLdrEnumerateLoadedModules = (void *)GetProcAddress(ntdll, "LdrEnumerateLoadedModules");
     pRtlInitUnicodeString = (void *)GetProcAddress(ntdll, "RtlInitUnicodeString");
     pRtlAcquirePebLock = (void *)GetProcAddress(ntdll, "RtlAcquirePebLock");
     pRtlReleasePebLock = (void *)GetProcAddress(ntdll, "RtlReleasePebLock");
@@ -4113,6 +4306,7 @@ START_TEST(loader)
     test_dll_file( "advapi32.dll" );
     test_dll_file( "user32.dll" );
     test_Wow64Transition();
+    test_loader_lock();
     /* loader test must be last, it can corrupt the internal loader state on Windows */
     test_Loader();
 }
diff --git a/dlls/ntdll/loader.c b/dlls/ntdll/loader.c
index 5ee4215f875..d0d4e5448ed 100644
--- a/dlls/ntdll/loader.c
+++ b/dlls/ntdll/loader.c
@@ -160,6 +160,8 @@ static RTL_CRITICAL_SECTION_DEBUG critsect_debug =
 };
 static RTL_CRITICAL_SECTION loader_section = { &critsect_debug, -1, 0, 0, 0, 0 };
 
+static RTL_SRWLOCK loader_srw_lock = RTL_SRWLOCK_INIT;  /* internal loader lock; loader_section is taken internally only around app callbacks */
+
 static CRITICAL_SECTION dlldir_section;
 static CRITICAL_SECTION_DEBUG dlldir_critsect_debug =
 {
@@ -217,24 +219,58 @@ static inline BOOL contains_path( LPCWSTR name )
     return ((*name && (name[1] == ':')) || wcschr(name, '/') || wcschr(name, '\\'));
 }
 
+/*************************************************************************
+ *		inc_recursion_count
+ *
+ * Increment the per-thread (TEB Spare2) internal loader lock recursion count and return the old value.
+ */
+static ULONG inc_recursion_count(void)
+{
+    return NtCurrentTeb()->Spare2++;
+}
+
+/*************************************************************************
+ *		dec_recursion_count
+ *
+ * Decrement the per-thread (TEB Spare2) internal loader lock recursion count and return the new value.
+ */
+static ULONG dec_recursion_count(void)
+{
+    return --NtCurrentTeb()->Spare2;
+}
+
 /*************************************************************************
  *		lock_loader_exclusive
  *
- * Take exclusive loader lock.
+ * Take exclusive ownership of internal loader lock.
+ * Recursive locking is allowed; loader_srw_lock is not acquired once DLL shutdown is in progress.
  */
 static void lock_loader_exclusive(void)
 {
-    RtlEnterCriticalSection( &loader_section );
+    ULONG recursion_count = inc_recursion_count();
+
+    TRACE( "recursion_count %u.\n", recursion_count );
+    if (!recursion_count && !RtlDllShutdownInProgress())
+        RtlAcquireSRWLockExclusive( &loader_srw_lock );
 }
 
 /*************************************************************************
  *		unlock_loader
  *
- * Release loader lock.
+ * Release internal loader lock; only the outermost call releases loader_srw_lock.
  */
 static void unlock_loader(void)
 {
-    RtlLeaveCriticalSection( &loader_section );
+    ULONG recursion_count = dec_recursion_count();
+
+    TRACE( "recursion_count %u.\n", recursion_count );
+
+    if (RtlDllShutdownInProgress()) return; /* NOTE(review): a lock taken before shutdown began is never released here; confirm intended */
+
+    assert( recursion_count != ~0u );
+
+    if (!recursion_count)
+        RtlReleaseSRWLockExclusive( &loader_srw_lock );
 }
 
 #define RTL_UNLOAD_EVENT_TRACE_NUMBER 64
@@ -494,6 +530,7 @@ static void call_ldr_notifications( ULONG reason, LDR_DATA_TABLE_ENTRY *module )
     data.Loaded.DllBase     = module->DllBase;
     data.Loaded.SizeOfImage = module->SizeOfImage;
 
+    RtlEnterCriticalSection( &loader_section );  /* notification callbacks are application code */
     RtlEnterCriticalSection( &ldr_notifications_section );
     LIST_FOR_EACH_ENTRY_SAFE( notify, notify_next, &ldr_notifications, struct ldr_notification, entry )
     {
@@ -506,6 +543,7 @@ static void call_ldr_notifications( ULONG reason, LDR_DATA_TABLE_ENTRY *module )
                 notify->callback, reason, &data, notify->context );
     }
     RtlLeaveCriticalSection( &ldr_notifications_section );
+    RtlLeaveCriticalSection( &loader_section );
 }
 
 /*************************************************************************
@@ -1355,6 +1393,8 @@ static void call_tls_callbacks( HMODULE module, UINT reason )
     dir = RtlImageDirectoryEntryToData( module, TRUE, IMAGE_DIRECTORY_ENTRY_TLS, &dirsize );
     if (!dir || !dir->AddressOfCallBacks) return;
 
+    RtlEnterCriticalSection( &loader_section );  /* TLS callbacks are module (application) code */
+
     for (callback = (const PIMAGE_TLS_CALLBACK *)dir->AddressOfCallBacks; *callback; callback++)
     {
         TRACE_(relay)("\1Call TLS callback (proc=%p,module=%p,reason=%s,reserved=0)\n",
@@ -1373,6 +1413,8 @@ static void call_tls_callbacks( HMODULE module, UINT reason )
         TRACE_(relay)("\1Ret  TLS callback (proc=%p,module=%p,reason=%s,reserved=0)\n",
                       *callback, module, reason_names[reason] );
     }
+
+    RtlLeaveCriticalSection( &loader_section );
 }
 
 /*************************************************************************
@@ -1405,6 +1447,8 @@ static NTSTATUS MODULE_InitDLL( WINE_MODREF *wm, UINT reason, LPVOID lpReserved
     else TRACE("(%p %s,%s,%p) - CALL\n", module, debugstr_w(wm->ldr.BaseDllName.Buffer),
                reason_names[reason], lpReserved );
 
+    RtlEnterCriticalSection( &loader_section );  /* the DLL entry point is application code */
+
     __TRY
     {
         retv = call_dll_entry_point( entry, module, reason, lpReserved );
@@ -1419,6 +1463,8 @@ static NTSTATUS MODULE_InitDLL( WINE_MODREF *wm, UINT reason, LPVOID lpReserved
     }
     __ENDTRY
 
+    RtlLeaveCriticalSection( &loader_section );
+
     /* The state of the module list may have changed due to the call
        to the dll. We cannot assume that this module has not been
        deleted.  */
@@ -1658,7 +1704,7 @@ NTSTATUS WINAPI LdrEnumerateLoadedModules( void *unknown, LDRENUMPROC callback,
         return STATUS_INVALID_PARAMETER;
 
     lock_loader_exclusive();
-
+    RtlEnterCriticalSection( &loader_section );  /* the enumeration callback is application code */
     mark = &NtCurrentTeb()->Peb->LdrData->InMemoryOrderModuleList;
     for (entry = mark->Flink; entry != mark; entry = entry->Flink)
     {
@@ -1666,7 +1712,7 @@ NTSTATUS WINAPI LdrEnumerateLoadedModules( void *unknown, LDRENUMPROC callback,
         callback( mod, context, &stop );
         if (stop) break;
     }
-
+    RtlLeaveCriticalSection( &loader_section );
     unlock_loader();
     return STATUS_SUCCESS;
 }
diff --git a/include/winternl.h b/include/winternl.h
index b6f93c116d8..521454396b4 100644
--- a/include/winternl.h
+++ b/include/winternl.h
@@ -466,7 +466,7 @@ typedef struct _TEB
     PVOID                        Instrumentation[16];               /* f2c/16b8 */
     PVOID                        WinSockData;                       /* f6c/1738 */
     ULONG                        GdiBatchCount;                     /* f70/1740 */
-    ULONG                        Spare2;                            /* f74/1744 */
+    ULONG                        Spare2;                            /* f74/1744 Wine ntdll: internal loader lock recursion count */
     ULONG                        GuaranteedStackBytes;              /* f78/1748 */
     PVOID                        ReservedForPerf;                   /* f7c/1750 */
     PVOID                        ReservedForOle;                    /* f80/1758 */
-- 
2.31.1




More information about the wine-devel mailing list