Rémi Bernon : hidclass.sys: Always copy InputReportByteLength bytes into read buffer.

Alexandre Julliard julliard at winehq.org
Thu Aug 5 16:13:39 CDT 2021


Module: wine
Branch: master
Commit: e8c6f130714a96783c684f51e031bb062aba73cf
URL:    https://source.winehq.org/git/wine.git/?a=commit;h=e8c6f130714a96783c684f51e031bb062aba73cf

Author: Rémi Bernon <rbernon at codeweavers.com>
Date:   Thu Aug  5 10:36:05 2021 +0200

hidclass.sys: Always copy InputReportByteLength bytes into read buffer.

Signed-off-by: Rémi Bernon <rbernon at codeweavers.com>
Signed-off-by: Alexandre Julliard <julliard at winehq.org>

---

 dlls/hidclass.sys/device.c         | 60 +++++++++++---------------------------
 dlls/ntoskrnl.exe/tests/ntoskrnl.c | 15 +++++-----
 2 files changed, 25 insertions(+), 50 deletions(-)

diff --git a/dlls/hidclass.sys/device.c b/dlls/hidclass.sys/device.c
index f1b0083fffd..a465cc53369 100644
--- a/dlls/hidclass.sys/device.c
+++ b/dlls/hidclass.sys/device.c
@@ -77,32 +77,6 @@ static void WINAPI read_cancel_routine(DEVICE_OBJECT *device, IRP *irp)
     IoCompleteRequest(irp, IO_NO_INCREMENT);
 }
 
-static NTSTATUS copy_packet_into_buffer(HID_XFER_PACKET *packet, BYTE* buffer, ULONG buffer_length, ULONG *out_length)
-{
-    BOOL zero_id = (packet->reportId == 0);
-
-    *out_length = 0;
-
-    if ((zero_id && buffer_length > packet->reportBufferLen) ||
-        (!zero_id && buffer_length >= packet->reportBufferLen))
-    {
-        if (packet->reportId != 0)
-        {
-            memcpy(buffer, packet->reportBuffer, packet->reportBufferLen);
-            *out_length = packet->reportBufferLen;
-        }
-        else
-        {
-            buffer[0] = 0;
-            memcpy(&buffer[1], packet->reportBuffer, packet->reportBufferLen);
-            *out_length = packet->reportBufferLen + 1;
-        }
-        return STATUS_SUCCESS;
-    }
-    else
-        return STATUS_BUFFER_OVERFLOW;
-}
-
 static void hid_device_send_input(DEVICE_OBJECT *device, HID_XFER_PACKET *packet)
 {
     BASE_DEVICE_EXTENSION *ext = device->DeviceExtension;
@@ -145,24 +119,24 @@ static void HID_Device_processQueue(DEVICE_OBJECT *device)
     IRP *irp;
     BASE_DEVICE_EXTENSION *ext = device->DeviceExtension;
     UINT buffer_size = RingBuffer_GetBufferSize(ext->u.pdo.ring_buffer);
+    const WINE_HIDP_PREPARSED_DATA *data = ext->u.pdo.preparsed_data;
     HID_XFER_PACKET *packet;
 
     packet = malloc(buffer_size);
 
     while((irp = pop_irp_from_queue(ext)))
     {
-        int ptr;
-        ptr = PtrToUlong( irp->Tail.Overlay.OriginalFileObject->FsContext );
+        BYTE *buffer = irp->AssociatedIrp.SystemBuffer, *dst = buffer;
+        int ptr = PtrToUlong( irp->Tail.Overlay.OriginalFileObject->FsContext );
 
         RingBuffer_Read(ext->u.pdo.ring_buffer, ptr, packet, &buffer_size);
         if (buffer_size)
         {
-            ULONG out_length;
-            IO_STACK_LOCATION *irpsp = IoGetCurrentIrpStackLocation(irp);
-            packet->reportBuffer = (BYTE *)packet + sizeof(*packet);
             TRACE_(hid_report)("Processing Request (%i)\n",ptr);
-            irp->IoStatus.Status = copy_packet_into_buffer( packet, irp->AssociatedIrp.SystemBuffer, irpsp->Parameters.Read.Length, &out_length );
-            irp->IoStatus.Information = out_length;
+            if (!data->reports[0].reportID) *dst++ = 0;
+            memcpy( dst, packet + 1, data->caps.InputReportByteLength - (dst - buffer) );
+            irp->IoStatus.Information = packet->reportBufferLen;
+            irp->IoStatus.Status = STATUS_SUCCESS;
         }
         else
         {
@@ -483,8 +457,7 @@ NTSTATUS WINAPI pdo_ioctl(DEVICE_OBJECT *device, IRP *irp)
             HID_XFER_PACKET *packet;
             ULONG buffer_len = irpsp->Parameters.DeviceIoControl.OutputBufferLength;
             UINT packet_size = sizeof(*packet) + buffer_len;
-            BYTE *buffer = MmGetSystemAddressForMdlSafe(irp->MdlAddress, NormalPagePriority);
-            ULONG out_length;
+            BYTE *buffer = MmGetSystemAddressForMdlSafe( irp->MdlAddress, NormalPagePriority ), *dst = buffer;
 
             if (!buffer)
             {
@@ -510,7 +483,10 @@ NTSTATUS WINAPI pdo_ioctl(DEVICE_OBJECT *device, IRP *irp)
                              sizeof(*packet), &irp->IoStatus );
 
             if (irp->IoStatus.Status == STATUS_SUCCESS)
-                irp->IoStatus.Status = copy_packet_into_buffer( packet, buffer, buffer_len, &out_length );
+            {
+                if (!data->reports[0].reportID) *dst++ = 0;
+                memcpy( dst, packet + 1, data->caps.InputReportByteLength - (dst - buffer) );
+            }
             free(packet);
             break;
         }
@@ -568,6 +544,7 @@ NTSTATUS WINAPI pdo_read(DEVICE_OBJECT *device, IRP *irp)
     const WINE_HIDP_PREPARSED_DATA *data = ext->u.pdo.preparsed_data;
     UINT buffer_size = RingBuffer_GetBufferSize(ext->u.pdo.ring_buffer);
     IO_STACK_LOCATION *irpsp = IoGetCurrentIrpStackLocation(irp);
+    BYTE *buffer = irp->AssociatedIrp.SystemBuffer, *dst = buffer;
     NTSTATUS status;
     int ptr = -1;
     BOOL removed;
@@ -599,13 +576,10 @@ NTSTATUS WINAPI pdo_read(DEVICE_OBJECT *device, IRP *irp)
 
     if (buffer_size)
     {
-        IO_STACK_LOCATION *irpsp = IoGetCurrentIrpStackLocation( irp );
-        ULONG out_length;
-        packet->reportBuffer = (BYTE *)packet + sizeof(*packet);
-        TRACE_(hid_report)("Got Packet %p %i\n", packet->reportBuffer, packet->reportBufferLen);
-
-        irp->IoStatus.Status = copy_packet_into_buffer( packet, irp->AssociatedIrp.SystemBuffer, irpsp->Parameters.Read.Length, &out_length );
-        irp->IoStatus.Information = out_length;
+        if (!data->reports[0].reportID) *dst++ = 0;
+        memcpy( dst, packet + 1, data->caps.InputReportByteLength - (dst - buffer) );
+        irp->IoStatus.Information = packet->reportBufferLen;
+        irp->IoStatus.Status = STATUS_SUCCESS;
     }
     else
     {
diff --git a/dlls/ntoskrnl.exe/tests/ntoskrnl.c b/dlls/ntoskrnl.exe/tests/ntoskrnl.c
index 02f37773c67..aa22d72828c 100644
--- a/dlls/ntoskrnl.exe/tests/ntoskrnl.c
+++ b/dlls/ntoskrnl.exe/tests/ntoskrnl.c
@@ -2684,9 +2684,9 @@ static void test_hidp(HANDLE file, HANDLE async_file, int report_id, BOOL polled
         memset(report, 0xcd, sizeof(report));
         SetLastError(0xdeadbeef);
         ret = ReadFile(file, report, caps.InputReportByteLength, &value, NULL);
-        todo_wine ok(ret, "ReadFile failed, last error %u\n", GetLastError());
+        ok(ret, "ReadFile failed, last error %u\n", GetLastError());
         todo_wine ok(value == (report_id ? 3 : 4), "ReadFile returned %x\n", value);
-        todo_wine ok(report[0] == report_id, "unexpected report data\n");
+        ok(report[0] == report_id, "unexpected report data\n");
 
         overlapped.hEvent = CreateEventA(NULL, FALSE, FALSE, NULL);
         overlapped2.hEvent = CreateEventA(NULL, FALSE, FALSE, NULL);
@@ -2695,9 +2695,9 @@ static void test_hidp(HANDLE file, HANDLE async_file, int report_id, BOOL polled
         SetLastError(0xdeadbeef);
         while (ReadFile(async_file, report, caps.InputReportByteLength, NULL, &overlapped))
             ResetEvent(overlapped.hEvent);
-        todo_wine ok(GetLastError() == ERROR_IO_PENDING, "ReadFile returned error %u\n", GetLastError());
+        ok(GetLastError() == ERROR_IO_PENDING, "ReadFile returned error %u\n", GetLastError());
         ret = GetOverlappedResult(async_file, &overlapped, &value, TRUE);
-        todo_wine ok(ret, "GetOverlappedResult failed, last error %u\n", GetLastError());
+        ok(ret, "GetOverlappedResult failed, last error %u\n", GetLastError());
         todo_wine ok(value == (report_id ? 3 : 4), "GetOverlappedResult returned length %u, expected 3\n", value);
         ResetEvent(overlapped.hEvent);
 
@@ -2716,13 +2716,14 @@ static void test_hidp(HANDLE file, HANDLE async_file, int report_id, BOOL polled
 
         /* wait for first report to be ready */
         ret = GetOverlappedResult(async_file, &overlapped, &value, TRUE);
-        todo_wine ok(ret, "GetOverlappedResult failed, last error %u\n", GetLastError());
+        ok(ret, "GetOverlappedResult failed, last error %u\n", GetLastError());
         todo_wine ok(value == (report_id ? 3 : 4), "GetOverlappedResult returned length %u, expected 3\n", value);
         /* second report should be ready and the same */
         ret = GetOverlappedResult(async_file, &overlapped2, &value, FALSE);
-        todo_wine ok(ret, "GetOverlappedResult failed, last error %u\n", GetLastError());
+        ok(ret, "GetOverlappedResult failed, last error %u\n", GetLastError());
         todo_wine ok(value == (report_id ? 3 : 4), "GetOverlappedResult returned length %u, expected 3\n", value);
-        todo_wine ok(memcmp(report, buffer + caps.InputReportByteLength, caps.InputReportByteLength), "expected different report\n");
+        ok(memcmp(report, buffer + caps.InputReportByteLength, caps.InputReportByteLength),
+           "expected different report\n");
         ok(!memcmp(report, buffer, caps.InputReportByteLength), "expected identical reports\n");
 
         CloseHandle(overlapped.hEvent);




More information about the wine-cvs mailing list