[PATCH v2 12/22] ntdll: Protect global cached_modref access with ldr_data_srw_lock.
Paul Gofman
pgofman at codeweavers.com
Tue Oct 5 17:49:18 CDT 2021
Signed-off-by: Paul Gofman <pgofman at codeweavers.com>
---
v2:
- protect cached_modref with ldr_data_srw_lock instead of introducing interlocked access.
dlls/ntdll/loader.c | 65 +++++++++++++++++++++++++++++++--------------
1 file changed, 45 insertions(+), 20 deletions(-)
diff --git a/dlls/ntdll/loader.c b/dlls/ntdll/loader.c
index 047e837238c..8e2ed03ad4c 100644
--- a/dlls/ntdll/loader.c
+++ b/dlls/ntdll/loader.c
@@ -206,7 +206,10 @@ static RTL_SRWLOCK ldr_data_srw_lock = RTL_SRWLOCK_INIT;
static RTL_BITMAP tls_bitmap;
static RTL_BITMAP tls_expansion_bitmap;
+/* Guarded by ldr_data_srw_lock (see get_cached_modref() / set_cached_modref()). */
static WINE_MODREF *cached_modref;
+
+/* Used with exclusive loader lock only. */
static WINE_MODREF *current_modref;
static WINE_MODREF *last_failed_modref;
@@ -470,6 +473,43 @@ static void lock_loader_restore_exclusive(void)
locked_exclusive = TRUE;
}
+/*************************************************************************
+ * get_cached_modref
+ *
+ * Returns the most recently looked-up modref, or NULL if nothing is cached.
+ *
+ * ldr_data_srw_lock only protects the read of cached_modref itself: it does
+ * not keep the returned modref alive (free_modref() clears the cache under
+ * the lock and frees the modref after releasing it), so the caller must make
+ * sure the module cannot be unloaded while it uses the result.
+ * Must not be called with ldr_data_srw_lock already held, as SRW locks
+ * cannot be acquired recursively.
+ */
+static WINE_MODREF *get_cached_modref(void)
+{
+ WINE_MODREF *ret;
+
+ RtlAcquireSRWLockShared( &ldr_data_srw_lock );
+ ret = cached_modref;
+ RtlReleaseSRWLockShared( &ldr_data_srw_lock );
+ return ret;
+}
+
+/*************************************************************************
+ * set_cached_modref
+ *
+ * Makes wm the cached modref and returns it, so that lookup functions can
+ * cache and return a match in a single statement.
+ * Must not be called with ldr_data_srw_lock already held.
+ */
+static WINE_MODREF *set_cached_modref( WINE_MODREF *wm )
+{
+ RtlAcquireSRWLockExclusive( &ldr_data_srw_lock );
+ cached_modref = wm;
+ RtlReleaseSRWLockExclusive( &ldr_data_srw_lock );
+ return wm;
+}
+
#define RTL_UNLOAD_EVENT_TRACE_NUMBER 64
typedef struct _RTL_UNLOAD_EVENT_TRACE
@@ -753,17 +783,18 @@ static void call_ldr_notifications( ULONG reason, LDR_DATA_TABLE_ENTRY *module )
*/
static WINE_MODREF *get_modref( HMODULE hmod )
{
+ WINE_MODREF *cached = get_cached_modref(); /* snapshot taken under ldr_data_srw_lock */
PLIST_ENTRY mark, entry;
PLDR_DATA_TABLE_ENTRY mod;
- if (cached_modref && cached_modref->ldr.DllBase == hmod) return cached_modref;
+ if (cached && cached->ldr.DllBase == hmod) return cached;
mark = &NtCurrentTeb()->Peb->LdrData->InMemoryOrderModuleList;
for (entry = mark->Flink; entry != mark; entry = entry->Flink)
{
mod = CONTAINING_RECORD(entry, LDR_DATA_TABLE_ENTRY, InMemoryOrderLinks);
if (mod->DllBase == hmod)
- return cached_modref = CONTAINING_RECORD(mod, WINE_MODREF, ldr);
+ return set_cached_modref( CONTAINING_RECORD(mod, WINE_MODREF, ldr) );
}
return NULL;
}
@@ -777,23 +808,21 @@ static WINE_MODREF *get_modref( HMODULE hmod )
*/
static WINE_MODREF *find_basename_module( LPCWSTR name )
{
+ WINE_MODREF *cached = get_cached_modref(); /* snapshot taken under ldr_data_srw_lock */
PLIST_ENTRY mark, entry;
UNICODE_STRING name_str;
RtlInitUnicodeString( &name_str, name );
- if (cached_modref && RtlEqualUnicodeString( &name_str, &cached_modref->ldr.BaseDllName, TRUE ))
- return cached_modref;
+ if (cached && RtlEqualUnicodeString( &name_str, &cached->ldr.BaseDllName, TRUE ))
+ return cached;
mark = &NtCurrentTeb()->Peb->LdrData->InLoadOrderModuleList;
for (entry = mark->Flink; entry != mark; entry = entry->Flink)
{
LDR_DATA_TABLE_ENTRY *mod = CONTAINING_RECORD(entry, LDR_DATA_TABLE_ENTRY, InLoadOrderLinks);
if (RtlEqualUnicodeString( &name_str, &mod->BaseDllName, TRUE ))
- {
- cached_modref = CONTAINING_RECORD(mod, WINE_MODREF, ldr);
- return cached_modref;
- }
+ return set_cached_modref( CONTAINING_RECORD(mod, WINE_MODREF, ldr) );
}
return NULL;
}
@@ -807,6 +836,7 @@ static WINE_MODREF *find_basename_module( LPCWSTR name )
*/
static WINE_MODREF *find_fullname_module( const UNICODE_STRING *nt_name )
{
+ WINE_MODREF *cached = get_cached_modref(); /* snapshot taken under ldr_data_srw_lock */
PLIST_ENTRY mark, entry;
UNICODE_STRING name = *nt_name;
@@ -814,18 +844,15 @@ static WINE_MODREF *find_fullname_module( const UNICODE_STRING *nt_name )
name.Length -= 4 * sizeof(WCHAR); /* for \??\ prefix */
name.Buffer += 4;
- if (cached_modref && RtlEqualUnicodeString( &name, &cached_modref->ldr.FullDllName, TRUE ))
- return cached_modref;
+ if (cached && RtlEqualUnicodeString( &name, &cached->ldr.FullDllName, TRUE ))
+ return cached;
mark = &NtCurrentTeb()->Peb->LdrData->InLoadOrderModuleList;
for (entry = mark->Flink; entry != mark; entry = entry->Flink)
{
LDR_DATA_TABLE_ENTRY *mod = CONTAINING_RECORD(entry, LDR_DATA_TABLE_ENTRY, InLoadOrderLinks);
if (RtlEqualUnicodeString( &name, &mod->FullDllName, TRUE ))
- {
- cached_modref = CONTAINING_RECORD(mod, WINE_MODREF, ldr);
- return cached_modref;
- }
+ return set_cached_modref( CONTAINING_RECORD(mod, WINE_MODREF, ldr) );
}
return NULL;
}
@@ -839,9 +866,10 @@ static WINE_MODREF *find_fullname_module( const UNICODE_STRING *nt_name )
*/
static WINE_MODREF *find_fileid_module( const struct file_id *id )
{
+ WINE_MODREF *cached = get_cached_modref(); /* snapshot taken under ldr_data_srw_lock */
LIST_ENTRY *mark, *entry;
- if (cached_modref && !memcmp( &cached_modref->id, id, sizeof(*id) )) return cached_modref;
+ if (cached && !memcmp( &cached->id, id, sizeof(*id) )) return cached;
mark = &NtCurrentTeb()->Peb->LdrData->InLoadOrderModuleList;
for (entry = mark->Flink; entry != mark; entry = entry->Flink)
@@ -850,10 +878,7 @@ static WINE_MODREF *find_fileid_module( const struct file_id *id )
WINE_MODREF *wm = CONTAINING_RECORD( mod, WINE_MODREF, ldr );
if (!memcmp( &wm->id, id, sizeof(*id) ))
- {
- cached_modref = wm;
- return wm;
- }
+ return set_cached_modref( wm );
}
return NULL;
}
@@ -3810,6 +3835,7 @@ static void free_modref( WINE_MODREF *wm )
RemoveEntryList(&wm->ldr.InMemoryOrderLinks);
if (wm->ldr.InInitializationOrderLinks.Flink)
RemoveEntryList(&wm->ldr.InInitializationOrderLinks);
+ if (cached_modref == wm) cached_modref = NULL; /* clear while still holding ldr_data_srw_lock, before wm is freed below */
RtlReleaseSRWLockExclusive( &ldr_data_srw_lock );
TRACE(" unloading %s\n", debugstr_w(wm->ldr.FullDllName.Buffer));
@@ -3821,7 +3847,6 @@ static void free_modref( WINE_MODREF *wm )
free_tls_slot( &wm->ldr );
RtlReleaseActivationContext( wm->ldr.ActivationContext );
NtUnmapViewOfSection( NtCurrentProcess(), wm->ldr.DllBase );
- if (cached_modref == wm) cached_modref = NULL;
RtlFreeUnicodeString( &wm->ldr.FullDllName );
RtlFreeHeap( GetProcessHeap(), 0, wm->deps );
RtlFreeHeap( GetProcessHeap(), 0, wm );
--
2.31.1
More information about the wine-devel
mailing list