[PATCH 7/7] kernel32: Write the wait handle before executing the callback.

Rémi Bernon rbernon at codeweavers.com
Wed Dec 2 04:01:16 CST 2020


Otherwise we may execute the callback before the value is actually
returned from RegisterWaitForSingleObject.

Gears Tactics shares a pointer to the returned handle with its callbacks
and calls UnregisterWait from there. This creates a race condition that
sometimes causes a double free.

Signed-off-by: Rémi Bernon <rbernon at codeweavers.com>
---

The loop test isn't very interesting, and the number of iterations is
a little too low to trigger the race condition consistently, but it's
there to illustrate the problem.

Note that calling RtlDeregisterWait from the wait callback crashes on
Windows, whereas calling UnregisterWait from there is fine. Maybe
UnregisterWait should be implemented separately, but I think it doesn't
matter too much.

The test also shows that UnregisterWait does not return ERROR_IO_PENDING
when called from the callback itself, but that it should when called
from outside the callback while the callback is running. We currently
return it in both cases.

 dlls/kernel32/sync.c         |  3 +-
 dlls/kernel32/tests/thread.c | 61 ++++++++++++++++++++++++++++++++++++
 dlls/ntdll/threadpool.c      |  3 ++
 3 files changed, 66 insertions(+), 1 deletion(-)

diff --git a/dlls/kernel32/sync.c b/dlls/kernel32/sync.c
index ab392c5d995..331cd80772a 100644
--- a/dlls/kernel32/sync.c
+++ b/dlls/kernel32/sync.c
@@ -110,7 +110,8 @@ DWORD WINAPI DECLSPEC_HOTPATCH GetTickCount(void)
 BOOL WINAPI RegisterWaitForSingleObject( HANDLE *wait, HANDLE object, WAITORTIMERCALLBACK callback,
                                          void *context, ULONG timeout, ULONG flags )
 {
-    return (*wait = RegisterWaitForSingleObjectEx( object, callback, context, timeout, flags)) != NULL;
+    if (!set_ntstatus( RtlRegisterWait( wait, object, callback, context, timeout, flags ))) return FALSE;
+    return TRUE;
 }
 
 /***********************************************************************
diff --git a/dlls/kernel32/tests/thread.c b/dlls/kernel32/tests/thread.c
index e861aa751e9..f33e05741a2 100644
--- a/dlls/kernel32/tests/thread.c
+++ b/dlls/kernel32/tests/thread.c
@@ -1346,6 +1346,15 @@ static void CALLBACK signaled_function(PVOID p, BOOLEAN TimerOrWaitFired)
     ok(!TimerOrWaitFired, "wait shouldn't have timed out\n");
 }
 
+static void CALLBACK wait_complete_function(PVOID p, BOOLEAN TimerOrWaitFired)
+{
+    HANDLE event = p;
+    DWORD res;
+    ok(!TimerOrWaitFired, "wait shouldn't have timed out\n");
+    res = WaitForSingleObject(event, INFINITE);
+    ok(res == WAIT_OBJECT_0, "WaitForSingleObject returned %x\n", res);
+}
+
 static void CALLBACK timeout_function(PVOID p, BOOLEAN TimerOrWaitFired)
 {
     HANDLE event = p;
@@ -1371,6 +1380,23 @@ static void CALLBACK waitthread_test_function(PVOID p, BOOLEAN TimerOrWaitFired)
     SetEvent(param->complete_event);
 }
 
+struct unregister_params
+{
+    HANDLE wait_handle;
+    HANDLE complete_event;
+};
+
+static void CALLBACK unregister_function(PVOID p, BOOLEAN TimerOrWaitFired)
+{
+    struct unregister_params *param = p;
+    HANDLE wait_handle = param->wait_handle;
+    BOOL ret;
+    ok(wait_handle != INVALID_HANDLE_VALUE, "invalid wait handle\n");
+    ret = pUnregisterWait(param->wait_handle);
+    todo_wine ok(ret, "UnregisterWait failed with error %d\n", GetLastError());
+    SetEvent(param->complete_event);
+}
+
 static void test_RegisterWaitForSingleObject(void)
 {
     BOOL ret;
@@ -1379,6 +1405,8 @@ static void test_RegisterWaitForSingleObject(void)
     HANDLE complete_event;
     HANDLE waitthread_trigger_event, waitthread_wait_event;
     struct waitthread_test_param param;
+    struct unregister_params unregister_param;
+    DWORD i;
 
     if (!pRegisterWaitForSingleObject || !pUnregisterWait)
     {
@@ -1411,8 +1439,26 @@ static void test_RegisterWaitForSingleObject(void)
     ret = pUnregisterWait(wait_handle);
     ok(ret, "UnregisterWait failed with error %d\n", GetLastError());
 
+    /* test unregister while running */
+
+    SetEvent(handle);
+    ret = pRegisterWaitForSingleObject(&wait_handle, handle, wait_complete_function, complete_event, INFINITE, WT_EXECUTEONLYONCE);
+    ok(ret, "RegisterWaitForSingleObject failed with error %d\n", GetLastError());
+
+    /* give worker thread chance to start */
+    Sleep(50);
+    ret = pUnregisterWait(wait_handle);
+    ok(!ret, "UnregisterWait succeeded\n");
+    ok(GetLastError() == ERROR_IO_PENDING, "UnregisterWait failed with error %d\n", GetLastError());
+
+    /* give worker thread chance to complete */
+    SetEvent(complete_event);
+    Sleep(50);
+
     /* test timeout case */
 
+    ResetEvent(handle);
+
     ret = pRegisterWaitForSingleObject(&wait_handle, handle, timeout_function, complete_event, 0, WT_EXECUTEONLYONCE);
     ok(ret, "RegisterWaitForSingleObject failed with error %d\n", GetLastError());
 
@@ -1442,6 +1488,21 @@ static void test_RegisterWaitForSingleObject(void)
     ret = pUnregisterWait(wait_handle);
     ok(ret, "UnregisterWait failed with error %d\n", GetLastError());
 
+    /* the callback execution should be sequentially consistent with the wait handle return,
+       even if the event is already set */
+
+    for (i = 0; i < 100; ++i)
+    {
+        SetEvent(handle);
+        unregister_param.complete_event = complete_event;
+        unregister_param.wait_handle = INVALID_HANDLE_VALUE;
+
+        ret = pRegisterWaitForSingleObject(&unregister_param.wait_handle, handle, unregister_function, &unregister_param, INFINITE, WT_EXECUTEONLYONCE | WT_EXECUTEINWAITTHREAD);
+        ok(ret, "RegisterWaitForSingleObject failed with error %d\n", GetLastError());
+
+        WaitForSingleObject(complete_event, INFINITE);
+    }
+
     /* test multiple waits with WT_EXECUTEINWAITTHREAD.
      * Windows puts multiple waits on the same wait thread, and using WT_EXECUTEINWAITTHREAD causes the callbacks to run serially.
      */
diff --git a/dlls/ntdll/threadpool.c b/dlls/ntdll/threadpool.c
index 331149855ab..bdc1ebddd7d 100644
--- a/dlls/ntdll/threadpool.c
+++ b/dlls/ntdll/threadpool.c
@@ -3226,9 +3226,12 @@ NTSTATUS WINAPI RtlRegisterWait( HANDLE *out, HANDLE handle, RTL_WAITORTIMERCALL
     object = impl_from_TP_WAIT(wait);
     object->u.wait.rtl_callback = callback;
 
+    RtlEnterCriticalSection( &waitqueue.cs );
     TpSetWait( (TP_WAIT *)object, handle, get_nt_timeout( &timeout, milliseconds ) );
 
     *out = object;
+    RtlLeaveCriticalSection( &waitqueue.cs );
+
     return STATUS_SUCCESS;
 }
 
-- 
2.29.2




More information about the wine-devel mailing list