[PATCH 5/6] hidclass.sys: Adjust buffer length according to report ID usage.

Rémi Bernon rbernon at codeweavers.com
Thu Aug 5 03:36:06 CDT 2021


Signed-off-by: Rémi Bernon <rbernon at codeweavers.com>
---
 dlls/hidclass.sys/device.c           | 117 +++++++++++++++------------
 dlls/ntoskrnl.exe/tests/driver_hid.c |   4 +-
 dlls/ntoskrnl.exe/tests/ntoskrnl.c   |  10 +--
 3 files changed, 72 insertions(+), 59 deletions(-)

diff --git a/dlls/hidclass.sys/device.c b/dlls/hidclass.sys/device.c
index a465cc53369..a7cb6a843e0 100644
--- a/dlls/hidclass.sys/device.c
+++ b/dlls/hidclass.sys/device.c
@@ -81,12 +81,10 @@ static void hid_device_send_input(DEVICE_OBJECT *device, HID_XFER_PACKET *packet
 {
     BASE_DEVICE_EXTENSION *ext = device->DeviceExtension;
     RAWINPUT *rawinput;
-    UCHAR *report, id;
     ULONG data_size;
     INPUT input;
 
     data_size = offsetof(RAWINPUT, data.hid.bRawData) + packet->reportBufferLen;
-    if (!(id = ext->u.pdo.preparsed_data->reports[0].reportID)) data_size += 1;
 
     if (!(rawinput = malloc(data_size)))
     {
@@ -100,10 +98,7 @@ static void hid_device_send_input(DEVICE_OBJECT *device, HID_XFER_PACKET *packet
     rawinput->header.wParam = RIM_INPUT;
     rawinput->data.hid.dwCount = 1;
     rawinput->data.hid.dwSizeHid = data_size - offsetof(RAWINPUT, data.hid.bRawData);
-
-    report = rawinput->data.hid.bRawData;
-    if (!id) *report++ = 0;
-    memcpy(report, packet->reportBuffer, packet->reportBufferLen);
+    memcpy( rawinput->data.hid.bRawData, packet->reportBuffer, packet->reportBufferLen );
 
     input.type = INPUT_HARDWARE;
     input.hi.uMsg = WM_INPUT;
@@ -126,15 +121,13 @@ static void HID_Device_processQueue(DEVICE_OBJECT *device)
 
     while((irp = pop_irp_from_queue(ext)))
     {
-        BYTE *buffer = irp->AssociatedIrp.SystemBuffer, *dst = buffer;
         int ptr = PtrToUlong( irp->Tail.Overlay.OriginalFileObject->FsContext );
 
         RingBuffer_Read(ext->u.pdo.ring_buffer, ptr, packet, &buffer_size);
         if (buffer_size)
         {
             TRACE_(hid_report)("Processing Request (%i)\n",ptr);
-            if (!data->reports[0].reportID) *dst++ = 0;
-            memcpy( dst, packet + 1, data->caps.InputReportByteLength - (dst - buffer) );
+            memcpy( irp->AssociatedIrp.SystemBuffer, packet + 1, data->caps.InputReportByteLength );
             irp->IoStatus.Information = packet->reportBufferLen;
             irp->IoStatus.Status = STATUS_SUCCESS;
         }
@@ -151,28 +144,42 @@ static void HID_Device_processQueue(DEVICE_OBJECT *device)
 static DWORD CALLBACK hid_device_thread(void *args)
 {
     DEVICE_OBJECT *device = (DEVICE_OBJECT*)args;
-    HID_XFER_PACKET *packet;
+    BASE_DEVICE_EXTENSION *ext = device->DeviceExtension;
+    const WINE_HIDP_PREPARSED_DATA *data = ext->u.pdo.preparsed_data;
+    BYTE report_id = HID_INPUT_VALUE_CAPS( data )->report_id;
+    ULONG buffer_len = data->caps.InputReportByteLength;
     IO_STATUS_BLOCK io;
+    HID_XFER_PACKET *packet;
+    BYTE *buffer;
     DWORD rc;
 
-    BASE_DEVICE_EXTENSION *ext = device->DeviceExtension;
-    USHORT report_size = ext->u.pdo.preparsed_data->caps.InputReportByteLength;
-
-    packet = malloc(sizeof(*packet) + report_size);
-    packet->reportBuffer = (BYTE *)packet + sizeof(*packet);
+    packet = malloc( sizeof(*packet) + buffer_len );
+    buffer = (BYTE *)(packet + 1);
+    packet->reportBuffer = buffer;
 
     if (ext->u.pdo.information.Polled)
     {
         while(1)
         {
-            packet->reportBufferLen = report_size;
-            packet->reportId = 0;
+            packet->reportId = buffer[0] = report_id;
+            packet->reportBufferLen = buffer_len;
+
+            if (!report_id)
+            {
+                packet->reportBuffer++;
+                packet->reportBufferLen--;
+            }
 
             call_minidriver( IOCTL_HID_GET_INPUT_REPORT, ext->u.pdo.parent_fdo, NULL, 0, packet,
                              sizeof(*packet), &io );
 
             if (io.Status == STATUS_SUCCESS)
             {
+                if (!report_id) io.Information++;
+                packet->reportId = buffer[0];
+                packet->reportBuffer = buffer;
+                packet->reportBufferLen = io.Information;
+
                 RingBuffer_Write(ext->u.pdo.ring_buffer, packet);
                 hid_device_send_input(device, packet);
                 HID_Device_processQueue(device);
@@ -193,8 +200,17 @@ static DWORD CALLBACK hid_device_thread(void *args)
 
         while(1)
         {
+            packet->reportId = buffer[0] = report_id;
+            packet->reportBufferLen = buffer_len;
+
+            if (!report_id)
+            {
+                packet->reportBuffer++;
+                packet->reportBufferLen--;
+            }
+
             call_minidriver( IOCTL_HID_READ_REPORT, ext->u.pdo.parent_fdo, NULL, 0,
-                             packet->reportBuffer, report_size, &io );
+                             packet->reportBuffer, packet->reportBufferLen, &io );
 
             rc = WaitForSingleObject(ext->u.pdo.halt_event, 0);
             if (rc == WAIT_OBJECT_0)
@@ -202,11 +218,11 @@ static DWORD CALLBACK hid_device_thread(void *args)
 
             if (!exit_now && io.Status == STATUS_SUCCESS)
             {
+                if (!report_id) io.Information++;
+                packet->reportId = buffer[0];
+                packet->reportBuffer = buffer;
                 packet->reportBufferLen = io.Information;
-                if (ext->u.pdo.preparsed_data->reports[0].reportID)
-                    packet->reportId = packet->reportBuffer[0];
-                else
-                    packet->reportId = 0;
+
                 RingBuffer_Write(ext->u.pdo.ring_buffer, packet);
                 hid_device_send_input(device, packet);
                 HID_Device_processQueue(device);
@@ -375,6 +391,7 @@ NTSTATUS WINAPI pdo_ioctl(DEVICE_OBJECT *device, IRP *irp)
     IO_STACK_LOCATION *irpsp = IoGetCurrentIrpStackLocation( irp );
     BASE_DEVICE_EXTENSION *ext = device->DeviceExtension;
     const WINE_HIDP_PREPARSED_DATA *data = ext->u.pdo.preparsed_data;
+    BYTE report_id = HID_INPUT_VALUE_CAPS( data )->report_id;
     NTSTATUS status;
     BOOL removed;
     KIRQL irql;
@@ -454,10 +471,9 @@ NTSTATUS WINAPI pdo_ioctl(DEVICE_OBJECT *device, IRP *irp)
         }
         case IOCTL_HID_GET_INPUT_REPORT:
         {
-            HID_XFER_PACKET *packet;
+            HID_XFER_PACKET packet;
             ULONG buffer_len = irpsp->Parameters.DeviceIoControl.OutputBufferLength;
-            UINT packet_size = sizeof(*packet) + buffer_len;
-            BYTE *buffer = MmGetSystemAddressForMdlSafe( irp->MdlAddress, NormalPagePriority ), *dst = buffer;
+            BYTE *buffer = MmGetSystemAddressForMdlSafe( irp->MdlAddress, NormalPagePriority );
 
             if (!buffer)
             {
@@ -470,24 +486,19 @@ NTSTATUS WINAPI pdo_ioctl(DEVICE_OBJECT *device, IRP *irp)
                 break;
             }
 
-            packet = malloc(packet_size);
+            packet.reportId = buffer[0];
+            packet.reportBuffer = buffer;
+            packet.reportBufferLen = buffer_len;
 
-            if (ext->u.pdo.preparsed_data->reports[0].reportID)
-                packet->reportId = buffer[0];
-            else
-                packet->reportId = 0;
-            packet->reportBuffer = (BYTE *)packet + sizeof(*packet);
-            packet->reportBufferLen = buffer_len - 1;
-
-            call_minidriver( IOCTL_HID_GET_INPUT_REPORT, ext->u.pdo.parent_fdo, NULL, 0, packet,
-                             sizeof(*packet), &irp->IoStatus );
-
-            if (irp->IoStatus.Status == STATUS_SUCCESS)
+            if (!report_id)
             {
-                if (!data->reports[0].reportID) *dst++ = 0;
-                memcpy( dst, packet + 1, data->caps.InputReportByteLength - (dst - buffer) );
+                packet.reportId = 0;
+                packet.reportBuffer++;
+                packet.reportBufferLen--;
             }
-            free(packet);
+
+            call_minidriver( IOCTL_HID_GET_INPUT_REPORT, ext->u.pdo.parent_fdo, NULL, 0, &packet,
+                             sizeof(packet), &irp->IoStatus );
             break;
         }
         case IOCTL_SET_NUM_DEVICE_INPUT_BUFFERS:
@@ -544,7 +555,7 @@ NTSTATUS WINAPI pdo_read(DEVICE_OBJECT *device, IRP *irp)
     const WINE_HIDP_PREPARSED_DATA *data = ext->u.pdo.preparsed_data;
     UINT buffer_size = RingBuffer_GetBufferSize(ext->u.pdo.ring_buffer);
     IO_STACK_LOCATION *irpsp = IoGetCurrentIrpStackLocation(irp);
-    BYTE *buffer = irp->AssociatedIrp.SystemBuffer, *dst = buffer;
+    BYTE report_id = HID_INPUT_VALUE_CAPS( data )->report_id;
     NTSTATUS status;
     int ptr = -1;
     BOOL removed;
@@ -576,8 +587,7 @@ NTSTATUS WINAPI pdo_read(DEVICE_OBJECT *device, IRP *irp)
 
     if (buffer_size)
     {
-        if (!data->reports[0].reportID) *dst++ = 0;
-        memcpy( dst, packet + 1, data->caps.InputReportByteLength - (dst - buffer) );
+        memcpy( irp->AssociatedIrp.SystemBuffer, packet + 1, data->caps.InputReportByteLength );
         irp->IoStatus.Information = packet->reportBufferLen;
         irp->IoStatus.Status = STATUS_SUCCESS;
     }
@@ -608,19 +618,24 @@ NTSTATUS WINAPI pdo_read(DEVICE_OBJECT *device, IRP *irp)
         else
         {
             HID_XFER_PACKET packet;
+            BYTE *buffer = irp->AssociatedIrp.SystemBuffer;
+            ULONG buffer_len = irpsp->Parameters.Read.Length;
+
             TRACE("No packet, but opportunistic reads enabled\n");
-            packet.reportId = ((BYTE*)irp->AssociatedIrp.SystemBuffer)[0];
-            packet.reportBuffer = &((BYTE*)irp->AssociatedIrp.SystemBuffer)[1];
-            packet.reportBufferLen = irpsp->Parameters.Read.Length - 1;
 
-            call_minidriver( IOCTL_HID_GET_INPUT_REPORT, ext->u.pdo.parent_fdo, NULL, 0, &packet,
-                             sizeof(packet), &irp->IoStatus );
+            packet.reportId = buffer[0];
+            packet.reportBuffer = buffer;
+            packet.reportBufferLen = buffer_len;
 
-            if (irp->IoStatus.Status == STATUS_SUCCESS)
+            if (!report_id)
             {
-                ((BYTE*)irp->AssociatedIrp.SystemBuffer)[0] = packet.reportId;
-                irp->IoStatus.Information = packet.reportBufferLen + 1;
+                packet.reportId = 0;
+                packet.reportBuffer++;
+                packet.reportBufferLen--;
             }
+
+            call_minidriver( IOCTL_HID_GET_INPUT_REPORT, ext->u.pdo.parent_fdo, NULL, 0, &packet,
+                             sizeof(packet), &irp->IoStatus );
         }
     }
     free(packet);
diff --git a/dlls/ntoskrnl.exe/tests/driver_hid.c b/dlls/ntoskrnl.exe/tests/driver_hid.c
index a45e5f56928..831b08c5c97 100644
--- a/dlls/ntoskrnl.exe/tests/driver_hid.c
+++ b/dlls/ntoskrnl.exe/tests/driver_hid.c
@@ -479,7 +479,6 @@ static NTSTATUS WINAPI driver_internal_ioctl(DEVICE_OBJECT *device, IRP *irp)
         {
             ULONG expected_size = 23;
             ok(!in_size, "got input size %u\n", in_size);
-            todo_wine_if(!report_id)
             ok(out_size == expected_size, "got output size %u\n", out_size);
 
             if (polled)
@@ -533,9 +532,8 @@ static NTSTATUS WINAPI driver_internal_ioctl(DEVICE_OBJECT *device, IRP *irp)
             ok(!in_size, "got input size %u\n", in_size);
             ok(out_size == sizeof(*packet), "got output size %u\n", out_size);
 
-            todo_wine_if(packet->reportId == 0x5a || (polled && report_id && packet->reportId == 0))
+            todo_wine_if(packet->reportId == 0x5a)
             ok(packet->reportId == report_id, "got id %u\n", packet->reportId);
-            todo_wine_if(packet->reportBufferLen == 22)
             ok(packet->reportBufferLen >= expected_size, "got len %u\n", packet->reportBufferLen);
             ok(!!packet->reportBuffer, "got buffer %p\n", packet->reportBuffer);
 
diff --git a/dlls/ntoskrnl.exe/tests/ntoskrnl.c b/dlls/ntoskrnl.exe/tests/ntoskrnl.c
index aa22d72828c..17d82fa3482 100644
--- a/dlls/ntoskrnl.exe/tests/ntoskrnl.c
+++ b/dlls/ntoskrnl.exe/tests/ntoskrnl.c
@@ -2463,7 +2463,7 @@ static void test_hidp(HANDLE file, HANDLE async_file, int report_id, BOOL polled
     else
     {
         ok(ret, "HidD_GetInputReport failed, last error %u\n", GetLastError());
-        todo_wine ok(buffer[0] == 0x5a, "got buffer[0] %x, expected 0x5a\n", (BYTE)buffer[0]);
+        ok(buffer[0] == 0x5a, "got buffer[0] %x, expected 0x5a\n", (BYTE)buffer[0]);
     }
 
     SetLastError(0xdeadbeef);
@@ -2685,7 +2685,7 @@ static void test_hidp(HANDLE file, HANDLE async_file, int report_id, BOOL polled
         SetLastError(0xdeadbeef);
         ret = ReadFile(file, report, caps.InputReportByteLength, &value, NULL);
         ok(ret, "ReadFile failed, last error %u\n", GetLastError());
-        todo_wine ok(value == (report_id ? 3 : 4), "ReadFile returned %x\n", value);
+        ok(value == (report_id ? 3 : 4), "ReadFile returned %x\n", value);
         ok(report[0] == report_id, "unexpected report data\n");
 
         overlapped.hEvent = CreateEventA(NULL, FALSE, FALSE, NULL);
@@ -2698,7 +2698,7 @@ static void test_hidp(HANDLE file, HANDLE async_file, int report_id, BOOL polled
         ok(GetLastError() == ERROR_IO_PENDING, "ReadFile returned error %u\n", GetLastError());
         ret = GetOverlappedResult(async_file, &overlapped, &value, TRUE);
         ok(ret, "GetOverlappedResult failed, last error %u\n", GetLastError());
-        todo_wine ok(value == (report_id ? 3 : 4), "GetOverlappedResult returned length %u, expected 3\n", value);
+        ok(value == (report_id ? 3 : 4), "GetOverlappedResult returned length %u, expected 3\n", value);
         ResetEvent(overlapped.hEvent);
 
         memcpy(buffer, report, caps.InputReportByteLength);
@@ -2717,11 +2717,11 @@ static void test_hidp(HANDLE file, HANDLE async_file, int report_id, BOOL polled
         /* wait for first report to be ready */
         ret = GetOverlappedResult(async_file, &overlapped, &value, TRUE);
         ok(ret, "GetOverlappedResult failed, last error %u\n", GetLastError());
-        todo_wine ok(value == (report_id ? 3 : 4), "GetOverlappedResult returned length %u, expected 3\n", value);
+        ok(value == (report_id ? 3 : 4), "GetOverlappedResult returned length %u, expected 3\n", value);
         /* second report should be ready and the same */
         ret = GetOverlappedResult(async_file, &overlapped2, &value, FALSE);
         ok(ret, "GetOverlappedResult failed, last error %u\n", GetLastError());
-        todo_wine ok(value == (report_id ? 3 : 4), "GetOverlappedResult returned length %u, expected 3\n", value);
+        ok(value == (report_id ? 3 : 4), "GetOverlappedResult returned length %u, expected 3\n", value);
         ok(memcmp(report, buffer + caps.InputReportByteLength, caps.InputReportByteLength),
            "expected different report\n");
         ok(!memcmp(report, buffer, caps.InputReportByteLength), "expected identical reports\n");
-- 
2.32.0




More information about the wine-devel mailing list