[PATCH 4/4] ntdll: Manage TPIO object destruction based on the expected completions.

Paul Gofman pgofman at codeweavers.com
Tue Jul 27 17:07:02 CDT 2021


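TpReleaseIoCompletion() no longer posts a STATUS_THREADPOOL_RELEASED_DURING_OPERATION
packet to the I/O completion port. Instead, it marks the object as shutting down and
destroys it right away when no completions are pending or skipped. Otherwise, the I/O
completion thread destroys it once a packet arrives with nothing left pending or skipped.
Completions that are still pending when callbacks are cancelled are counted in skipped_count.

Signed-off-by: Paul Gofman <pgofman at codeweavers.com>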
---
 dlls/ntdll/tests/threadpool.c | 56 +++++++++++++++++++++++++++++--
 dlls/ntdll/threadpool.c       | 63 +++++++++++++++++++++++++++++------
 2 files changed, 107 insertions(+), 12 deletions(-)

diff --git a/dlls/ntdll/tests/threadpool.c b/dlls/ntdll/tests/threadpool.c
index 986cbbcf8f1..6c28d0642d7 100644
--- a/dlls/ntdll/tests/threadpool.c
+++ b/dlls/ntdll/tests/threadpool.c
@@ -2157,10 +2157,62 @@ static void test_tp_io(void)
     }
     ok(userdata.count == 0, "callback ran %u times\n", userdata.count);
 
-    CloseHandle(ovl.hEvent);
-    CloseHandle(client);
+    pTpReleaseIoCompletion(io);
     CloseHandle(server);
+
+    /* Test TPIO object destruction. */
+    server = CreateNamedPipeA("\\\\.\\pipe\\wine_tp_test",
+            PIPE_ACCESS_DUPLEX | FILE_FLAG_OVERLAPPED, 0, 1, 1024, 1024, 0, NULL);
+    ok(server != INVALID_HANDLE_VALUE, "Failed to create server pipe, error %u.\n", GetLastError());
+    io = NULL;
+    status = pTpAllocIoCompletion(&io, server, io_cb, &userdata, &environment);
+    ok(!status, "got %#x\n", status);
+    /* No operation has been started, so the release below is expected to free the object immediately. */
+    ret = HeapValidate(GetProcessHeap(), 0, io);
+    ok(ret, "Got unexpected ret %#x.\n", ret);
     pTpReleaseIoCompletion(io);
+    ret = HeapValidate(GetProcessHeap(), 0, io);
+    ok(!ret, "Got unexpected ret %#x.\n", ret);
+    CloseHandle(server);
+    CloseHandle(client);
+
+    server = CreateNamedPipeA("\\\\.\\pipe\\wine_tp_test",
+            PIPE_ACCESS_DUPLEX | FILE_FLAG_OVERLAPPED, 0, 1, 1024, 1024, 0, NULL);
+    ok(server != INVALID_HANDLE_VALUE, "Failed to create server pipe, error %u.\n", GetLastError());
+    client = CreateFileA("\\\\.\\pipe\\wine_tp_test", GENERIC_READ | GENERIC_WRITE,
+            0, NULL, OPEN_EXISTING, 0, 0);
+    ok(client != INVALID_HANDLE_VALUE, "Failed to create client pipe, error %u.\n", GetLastError());
+    /* Start an operation and wait with pending callbacks cancelled; the object must outlive the release. */
+    io = NULL;
+    status = pTpAllocIoCompletion(&io, server, io_cb, &userdata, &environment);
+    ok(!status, "got %#x\n", status);
+    pTpStartAsyncIoOperation(io);
+    pTpWaitForIoCompletion(io, TRUE);
+    ret = HeapValidate(GetProcessHeap(), 0, io);
+    ok(ret, "Got unexpected ret %#x.\n", ret);
+    pTpReleaseIoCompletion(io);
+    ret = HeapValidate(GetProcessHeap(), 0, io);
+    ok(ret, "Got unexpected ret %#x.\n", ret);
+
+    if (0)
+    {
+        /* Object destruction waits until the completion of the operation started above arrives
+         * (the I/O itself was not cancelled). Disabled to save test time. */
+        Sleep(1000);
+        ret = HeapValidate(GetProcessHeap(), 0, io);
+        ok(ret, "Got unexpected ret %#x.\n", ret);
+        ret = ReadFile(server, in, sizeof(in), NULL, &ovl);
+        ok(!ret, "wrong ret %d\n", ret);
+        ret = WriteFile(client, out, sizeof(out), &ret_size, NULL);
+        ok(ret, "WriteFile() failed, error %u\n", GetLastError());
+        Sleep(2000);
+        ret = HeapValidate(GetProcessHeap(), 0, io);
+        ok(!ret, "Got unexpected ret %#x.\n", ret);
+    }
+
+    CloseHandle(server);
+    CloseHandle(ovl.hEvent);
+    CloseHandle(client);
     pTpReleasePool(pool);
 }
 
diff --git a/dlls/ntdll/threadpool.c b/dlls/ntdll/threadpool.c
index 50433b7c009..ca323919d05 100644
--- a/dlls/ntdll/threadpool.c
+++ b/dlls/ntdll/threadpool.c
@@ -201,7 +201,8 @@ struct threadpool_object
         {
             PTP_IO_CALLBACK callback;
             /* locked via .pool->cs */
-            unsigned int    pending_count, completion_count, completion_max;
+            unsigned int    pending_count, skipped_count, completion_count, completion_max; /* skipped_count: packets still due for cancelled operations */
+            BOOL            shutting_down; /* set once TpReleaseIoCompletion() has been called */
             struct io_completion *completions;
         } io;
     } u;
@@ -1506,6 +1507,7 @@ static void CALLBACK ioqueue_thread_proc( void *param )
     struct threadpool_object *io;
     IO_STATUS_BLOCK iosb;
     ULONG_PTR key, value;
+    BOOL destroy, skip; /* decision for the packet just dequeued from the port */
     NTSTATUS status;
 
     TRACE( "starting I/O completion thread\n" );
@@ -1519,17 +1521,33 @@ static void CALLBACK ioqueue_thread_proc( void *param )
             ERR("NtRemoveIoCompletion failed, status %#x.\n", status);
         RtlEnterCriticalSection( &ioqueue.cs );
 
+        destroy = skip = FALSE;
         io = (struct threadpool_object *)key;
 
-        if (io && io->shutdown)
+        TRACE( "io %p, iosb.Status %#x.\n", io, iosb.u.Status );
+        /* Inspect every packet under pool->cs so cancelled operations always consume skipped_count. */
+        if (io)
         {
-            if (iosb.u.Status != STATUS_THREADPOOL_RELEASED_DURING_OPERATION)
+            RtlEnterCriticalSection( &io->pool->cs );
+            if (!io->u.io.pending_count)
             {
-                /* Skip remaining completions until the final one. */
-                continue;
+                if (io->u.io.skipped_count)
+                    --io->u.io.skipped_count;
+                if (io->shutdown || io->u.io.shutting_down)
+                {
+                    skip = io->u.io.skipped_count != 0;
+                    destroy = !skip;
+                }
             }
+            RtlLeaveCriticalSection( &io->pool->cs );
+            if (skip) continue;
+        }
+
+        if (destroy)
+        {
             --ioqueue.objcount;
             TRACE( "Releasing io %p.\n", io );
+            io->shutdown = TRUE;
             tp_object_release( io );
         }
         else if (io)
@@ -2004,7 +2022,10 @@ static void tp_object_cancel( struct threadpool_object *object )
             object->u.wait.signaled = 0;
     }
     if (object->type == TP_OBJECT_TYPE_IO)
+    {
+        object->u.io.skipped_count += object->u.io.pending_count; /* their completion packets are still expected */
         object->u.io.pending_count = 0;
+    }
     RtlLeaveCriticalSection( &pool->cs );
 
     while (pending_callbacks--)
@@ -2045,6 +2066,20 @@ static void tp_object_wait( struct threadpool_object *object, BOOL group_wait )
     RtlLeaveCriticalSection( &pool->cs );
 }
 
+static void tp_ioqueue_unlock( struct threadpool_object *io )
+{
+    assert( io->type == TP_OBJECT_TYPE_IO );
+    /* Remove io from ioqueue.objcount, unless io->shutdown is already set. */
+    RtlEnterCriticalSection( &ioqueue.cs );
+
+    assert(ioqueue.objcount);
+    /* When the last object goes away, post a key-0 packet (presumably to wake the I/O completion thread). */
+    if (!io->shutdown && !--ioqueue.objcount)
+        NtSetIoCompletion( ioqueue.port, 0, 0, STATUS_SUCCESS, 0 );
+
+    RtlLeaveCriticalSection( &ioqueue.cs );
+}
+
 /***********************************************************************
  *           tp_object_prepare_shutdown    (internal)
  *
@@ -2056,6 +2091,8 @@ static void tp_object_prepare_shutdown( struct threadpool_object *object )
         tp_timerqueue_unlock( object );
     else if (object->type == TP_OBJECT_TYPE_WAIT)
         tp_waitqueue_unlock( object );
+    else if (object->type == TP_OBJECT_TYPE_IO)
+        tp_ioqueue_unlock( object ); /* decrements ioqueue.objcount unless already shut down */
 }
 
 /***********************************************************************
@@ -2797,15 +2834,21 @@ VOID WINAPI TpReleaseCleanupGroupMembers( TP_CLEANUP_GROUP *group, BOOL cancel_p
 void WINAPI TpReleaseIoCompletion( TP_IO *io )
 {
     struct threadpool_object *this = impl_from_TP_IO( io );
+    BOOL can_destroy; /* TRUE when no completion packets are still expected */
 
     TRACE( "%p\n", io );
 
-    RtlEnterCriticalSection( &ioqueue.cs );
+    RtlEnterCriticalSection( &this->pool->cs );
+    this->u.io.shutting_down = TRUE; /* from now on the I/O completion thread considers destroying the object */
+    can_destroy = !this->u.io.pending_count && !this->u.io.skipped_count; /* NOTE(review): a packet arriving while pending_count != 0 takes the regular path and never triggers destruction; confirm releasing with pending I/O cannot leak the object */
+    RtlLeaveCriticalSection( &this->pool->cs );
 
-    assert( ioqueue.objcount );
-    this->shutdown = TRUE;
-    NtSetIoCompletion( ioqueue.port, (ULONG_PTR)this, 0, STATUS_THREADPOOL_RELEASED_DURING_OPERATION, 1 );
-    RtlLeaveCriticalSection( &ioqueue.cs );
+    if (can_destroy) /* otherwise a later packet arriving with nothing pending or skipped destroys it */
+    {
+        tp_object_prepare_shutdown( this );
+        this->shutdown = TRUE;
+        tp_object_release( this );
+    }
 }
 
 /***********************************************************************
-- 
2.31.1




More information about the wine-devel mailing list