[PATCH v2 12/22] ntdll: Protect global cached_modref access with ldr_data_srw_lock.

Paul Gofman pgofman at codeweavers.com
Tue Oct 5 17:49:18 CDT 2021


Signed-off-by: Paul Gofman <pgofman at codeweavers.com>
---
v2:
    - protect cached_modref with ldr_data_srw_lock instead of introducing interlocked access.

 dlls/ntdll/loader.c | 65 +++++++++++++++++++++++++++++++--------------
 1 file changed, 45 insertions(+), 20 deletions(-)

diff --git a/dlls/ntdll/loader.c b/dlls/ntdll/loader.c
index 047e837238c..8e2ed03ad4c 100644
--- a/dlls/ntdll/loader.c
+++ b/dlls/ntdll/loader.c
@@ -206,7 +206,10 @@ static RTL_SRWLOCK ldr_data_srw_lock = RTL_SRWLOCK_INIT;
 static RTL_BITMAP tls_bitmap;
 static RTL_BITMAP tls_expansion_bitmap;
 
+/* Guarded by ldr_data_srw_lock (see get_cached_modref / set_cached_modref). */
 static WINE_MODREF *cached_modref;
+
+/* Accessed only while the loader lock is held exclusively. */
 static WINE_MODREF *current_modref;
 static WINE_MODREF *last_failed_modref;
 
@@ -470,6 +473,33 @@ static void lock_loader_restore_exclusive(void)
     locked_exclusive = TRUE;
 }
 
+/*************************************************************************
+ *		get_cached_modref
+ * Returns a snapshot of cached_modref; the lock does not keep the modref alive.
+ */
+static WINE_MODREF *get_cached_modref(void)
+{
+    WINE_MODREF *ret;
+
+    RtlAcquireSRWLockShared( &ldr_data_srw_lock );
+    ret = cached_modref;
+    RtlReleaseSRWLockShared( &ldr_data_srw_lock );
+    return ret;
+}
+
+/*************************************************************************
+ *		set_cached_modref
+ *
+ * Stores new in cached_modref under exclusive ldr_data_srw_lock; returns new.
+ */
+static WINE_MODREF *set_cached_modref( WINE_MODREF *new )
+{
+    RtlAcquireSRWLockExclusive( &ldr_data_srw_lock );
+    cached_modref = new;
+    RtlReleaseSRWLockExclusive( &ldr_data_srw_lock );
+    return new;
+}
+
 #define RTL_UNLOAD_EVENT_TRACE_NUMBER 64
 
 typedef struct _RTL_UNLOAD_EVENT_TRACE
@@ -753,17 +783,18 @@ static void call_ldr_notifications( ULONG reason, LDR_DATA_TABLE_ENTRY *module )
  */
 static WINE_MODREF *get_modref( HMODULE hmod )
 {
+    WINE_MODREF *cached = get_cached_modref();  /* read the shared cache once */
     PLIST_ENTRY mark, entry;
     PLDR_DATA_TABLE_ENTRY mod;
 
-    if (cached_modref && cached_modref->ldr.DllBase == hmod) return cached_modref;
+    if (cached && cached->ldr.DllBase == hmod) return cached;
 
     mark = &NtCurrentTeb()->Peb->LdrData->InMemoryOrderModuleList;
     for (entry = mark->Flink; entry != mark; entry = entry->Flink)
     {
         mod = CONTAINING_RECORD(entry, LDR_DATA_TABLE_ENTRY, InMemoryOrderLinks);
         if (mod->DllBase == hmod)
-            return cached_modref = CONTAINING_RECORD(mod, WINE_MODREF, ldr);
+            return set_cached_modref( CONTAINING_RECORD(mod, WINE_MODREF, ldr) );
     }
     return NULL;
 }
@@ -777,23 +808,21 @@ static WINE_MODREF *get_modref( HMODULE hmod )
  */
 static WINE_MODREF *find_basename_module( LPCWSTR name )
 {
+    WINE_MODREF *cached = get_cached_modref();  /* read the shared cache once */
     PLIST_ENTRY mark, entry;
     UNICODE_STRING name_str;
 
     RtlInitUnicodeString( &name_str, name );
 
-    if (cached_modref && RtlEqualUnicodeString( &name_str, &cached_modref->ldr.BaseDllName, TRUE ))
-        return cached_modref;
+    if (cached && RtlEqualUnicodeString( &name_str, &cached->ldr.BaseDllName, TRUE ))
+        return cached;
 
     mark = &NtCurrentTeb()->Peb->LdrData->InLoadOrderModuleList;
     for (entry = mark->Flink; entry != mark; entry = entry->Flink)
     {
         LDR_DATA_TABLE_ENTRY *mod = CONTAINING_RECORD(entry, LDR_DATA_TABLE_ENTRY, InLoadOrderLinks);
         if (RtlEqualUnicodeString( &name_str, &mod->BaseDllName, TRUE ))
-        {
-            cached_modref = CONTAINING_RECORD(mod, WINE_MODREF, ldr);
-            return cached_modref;
-        }
+            return set_cached_modref( CONTAINING_RECORD(mod, WINE_MODREF, ldr) );
     }
     return NULL;
 }
@@ -807,6 +836,7 @@ static WINE_MODREF *find_basename_module( LPCWSTR name )
  */
 static WINE_MODREF *find_fullname_module( const UNICODE_STRING *nt_name )
 {
+    WINE_MODREF *cached = get_cached_modref();  /* read the shared cache once */
     PLIST_ENTRY mark, entry;
     UNICODE_STRING name = *nt_name;
 
@@ -814,18 +844,15 @@ static WINE_MODREF *find_fullname_module( const UNICODE_STRING *nt_name )
     name.Length -= 4 * sizeof(WCHAR);  /* for \??\ prefix */
     name.Buffer += 4;
 
-    if (cached_modref && RtlEqualUnicodeString( &name, &cached_modref->ldr.FullDllName, TRUE ))
-        return cached_modref;
+    if (cached && RtlEqualUnicodeString( &name, &cached->ldr.FullDllName, TRUE ))
+        return cached;
 
     mark = &NtCurrentTeb()->Peb->LdrData->InLoadOrderModuleList;
     for (entry = mark->Flink; entry != mark; entry = entry->Flink)
     {
         LDR_DATA_TABLE_ENTRY *mod = CONTAINING_RECORD(entry, LDR_DATA_TABLE_ENTRY, InLoadOrderLinks);
         if (RtlEqualUnicodeString( &name, &mod->FullDllName, TRUE ))
-        {
-            cached_modref = CONTAINING_RECORD(mod, WINE_MODREF, ldr);
-            return cached_modref;
-        }
+            return set_cached_modref( CONTAINING_RECORD(mod, WINE_MODREF, ldr) );
     }
     return NULL;
 }
@@ -839,9 +866,10 @@ static WINE_MODREF *find_fullname_module( const UNICODE_STRING *nt_name )
  */
 static WINE_MODREF *find_fileid_module( const struct file_id *id )
 {
+    WINE_MODREF *cached = get_cached_modref();  /* read the shared cache once */
     LIST_ENTRY *mark, *entry;
 
-    if (cached_modref && !memcmp( &cached_modref->id, id, sizeof(*id) )) return cached_modref;
+    if (cached && !memcmp( &cached->id, id, sizeof(*id) )) return cached;
 
     mark = &NtCurrentTeb()->Peb->LdrData->InLoadOrderModuleList;
     for (entry = mark->Flink; entry != mark; entry = entry->Flink)
@@ -850,10 +878,7 @@ static WINE_MODREF *find_fileid_module( const struct file_id *id )
         WINE_MODREF *wm = CONTAINING_RECORD( mod, WINE_MODREF, ldr );
 
         if (!memcmp( &wm->id, id, sizeof(*id) ))
-        {
-            cached_modref = wm;
-            return wm;
-        }
+            return set_cached_modref( wm );
     }
     return NULL;
 }
@@ -3810,6 +3835,7 @@ static void free_modref( WINE_MODREF *wm )
     RemoveEntryList(&wm->ldr.InMemoryOrderLinks);
     if (wm->ldr.InInitializationOrderLinks.Flink)
         RemoveEntryList(&wm->ldr.InInitializationOrderLinks);
+    if (cached_modref == wm) cached_modref = NULL;  /* same critical section as the unlink */
     RtlReleaseSRWLockExclusive( &ldr_data_srw_lock );
 
     TRACE(" unloading %s\n", debugstr_w(wm->ldr.FullDllName.Buffer));
@@ -3821,7 +3847,6 @@ static void free_modref( WINE_MODREF *wm )
     free_tls_slot( &wm->ldr );
     RtlReleaseActivationContext( wm->ldr.ActivationContext );
     NtUnmapViewOfSection( NtCurrentProcess(), wm->ldr.DllBase );
-    if (cached_modref == wm) cached_modref = NULL;
     RtlFreeUnicodeString( &wm->ldr.FullDllName );
     RtlFreeHeap( GetProcessHeap(), 0, wm->deps );
     RtlFreeHeap( GetProcessHeap(), 0, wm );
-- 
2.31.1




More information about the wine-devel mailing list