[PATCH 3/3] ntdll/tests: Add tests for hooking exports.
Andrew Wesie
awesie at gmail.com
Thu May 30 01:34:53 CDT 2019
Signed-off-by: Andrew Wesie <awesie at gmail.com>
---
dlls/ntdll/tests/Makefile.in | 1 +
dlls/ntdll/tests/hooks.c | 728 +++++++++++++++++++++++++++++++++++++++++++
2 files changed, 729 insertions(+)
create mode 100644 dlls/ntdll/tests/hooks.c
diff --git a/dlls/ntdll/tests/Makefile.in b/dlls/ntdll/tests/Makefile.in
index 5c70f3f..542def9 100644
--- a/dlls/ntdll/tests/Makefile.in
+++ b/dlls/ntdll/tests/Makefile.in
@@ -10,6 +10,7 @@ C_SRCS = \
exception.c \
file.c \
generated.c \
+ hooks.c \
info.c \
large_int.c \
om.c \
diff --git a/dlls/ntdll/tests/hooks.c b/dlls/ntdll/tests/hooks.c
new file mode 100644
index 0000000..64dea87
--- /dev/null
+++ b/dlls/ntdll/tests/hooks.c
@@ -0,0 +1,728 @@
+/*
+ * Unit test suite for hooking ntdll functions
+ *
+ * Copyright 2018 Andrew Wesie
+ *
+ * This library is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation; either
+ * version 2.1 of the License, or (at your option) any later version.
+ *
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this library; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301, USA
+ */
+
+#define NONAMELESSUNION
+#include "ntdll_test.h"
+#include "excpt.h"
+
+#ifdef __i386__
+
+/* Entry points looked up in ntdll.dll by START_TEST. */
+static void (WINAPI *pLdrInitializeThunk)(PCONTEXT, ULONG_PTR, ULONG_PTR, ULONG_PTR);
+static NTSTATUS (WINAPI *pNtContinue)(PCONTEXT, BOOLEAN);
+static NTSTATUS (WINAPI *pNtCreateThread)(PHANDLE, ACCESS_MASK, POBJECT_ATTRIBUTES, HANDLE, PCLIENT_ID,
+                                          PCONTEXT, PINITIAL_TEB, BOOLEAN);
+static NTSTATUS (WINAPI *pNtQueryInformationThread)(HANDLE, THREADINFOCLASS, PVOID, ULONG, PULONG);
+static ULONG (WINAPI *pNtGetTickCount)(void);
+
+/* Entry points looked up in kernel32.dll by START_TEST. */
+static DWORD (WINAPI *pGetTickCount)(void);
+static ULONGLONG (WINAPI *pGetTickCount64)(void);
+
+/* Everything remove_hook() needs to undo one installed hook. */
+struct hook_state
+{
+    void *target;               /* address whose first count bytes get patched */
+    void *mem;                  /* trampoline page (VirtualAlloc), or NULL */
+    BYTE original[64];          /* copy of the bytes overwritten at target */
+    SIZE_T count;               /* number of valid bytes in original[] */
+    void *callback;             /* function run when the hook fires */
+    PVOID exception_handler;    /* vectored exception handler cookie, or NULL */
+};
+
+/* State written by the hook callbacks and inspected by the tests. */
+static DWORD callback_result;
+static ULONG callback_process_id, callback_thread_id;
+static CLIENT_ID callback_client_id;
+static CLIENT_ID *callback_client_id_ptr;
+static HANDLE *callback_handle_ptr;
+
+/* Local replacement for GetThreadId(), which Windows XP's kernel32 lacks
+ * (adapted from dlls/kernel32/thread.c). Returns the thread id of the given
+ * handle, or 0 after setting the last error if the query fails. */
+static DWORD get_thread_id(HANDLE thread)
+{
+    THREAD_BASIC_INFORMATION info;
+    NTSTATUS ret = pNtQueryInformationThread(thread, ThreadBasicInformation, &info, sizeof(info), NULL);
+
+    if (!ret)
+        return HandleToULong(info.ClientId.UniqueThread);
+
+    SetLastError( RtlNtStatusToDosError(ret) );
+    return 0;
+}
+
+/* Advance *pp past a ModRM byte plus the SIB byte and displacement that it
+ * implies. addr16 selects 16-bit addressing (0x67 prefix in 32-bit code),
+ * which has no SIB byte and uses 16-bit displacements. Always returns TRUE. */
+static BOOL modrm_size(PBYTE *pp, BOOL addr16)
+{
+    PBYTE p = *pp;
+    BYTE mod = (p[0] >> 6) & 3, rm = p[0] & 7;
+
+    p++; /* ModRM */
+
+    if (mod == 3)
+    {
+        /* register operand: no SIB byte, no displacement */
+    }
+    else if (addr16)
+    {
+        if (mod == 0 && rm == 6)
+            p += 2; /* [disp16] */
+        else if (mod == 1)
+            p += 1; /* ... + disp8 */
+        else if (mod == 2)
+            p += 2; /* ... + disp16 */
+    }
+    else
+    {
+        if (rm == 4)
+        {
+            /* SIB byte; with mod == 0 a base field of 5 means "disp32, no base" */
+            if (mod == 0 && (p[0] & 7) == 5)
+                p += 4;
+            p++;
+        }
+
+        if (mod == 0 && rm == 5)
+            p += 4; /* [disp32] */
+        else if (mod == 1)
+            p += 1; /* ... + disp8 */
+        else if (mod == 2)
+            p += 4; /* ... + disp32 */
+    }
+
+    *pp = p;
+    return TRUE;
+}
+
+/* Return the length in bytes of the 32-bit x86 instruction at ip, or -1 if it
+ * is not in the opcode tables below (a subset of the one- and two-byte maps).
+ * Relative jumps and calls are rejected on purpose: once copied into a
+ * trampoline they would branch to the wrong address. */
+static int instruction_size(void *ip)
+{
+    BOOL pfx66 = FALSE; /* operand-size prefix */
+    BOOL pfx67 = FALSE; /* address-size prefix */
+    PBYTE p = ip;
+    BYTE ext;
+
+    /* consume prefixes */
+    while (1)
+    {
+        if (p - (PBYTE)ip >= 15)
+            return -1; /* no valid x86 instruction is longer than 15 bytes */
+
+        switch (*p)
+        {
+        case 0x66:
+            pfx66 = TRUE;
+            p++;
+            break;
+        case 0x67:
+            pfx67 = TRUE;
+            p++;
+            break;
+        case 0x26:
+        case 0x2E:
+        case 0x36:
+        case 0x3E:
+        case 0x64:
+        case 0x65:
+        case 0x9B: /* fwait, consumed like a prefix */
+        case 0xF0:
+        case 0xF2:
+        case 0xF3:
+            p++;
+            break;
+        default:
+            goto no_prefix;
+        }
+    }
+no_prefix:
+
+    switch (*p)
+    {
+    /* opcode + ModRM */
+    case 0x00: case 0x01: case 0x02: case 0x03:
+    case 0x08: case 0x09: case 0x0A: case 0x0B:
+    case 0x10: case 0x11: case 0x12: case 0x13:
+    case 0x18: case 0x19: case 0x1A: case 0x1B:
+    case 0x20: case 0x21: case 0x22: case 0x23:
+    case 0x28: case 0x29: case 0x2A: case 0x2B:
+    case 0x30: case 0x31: case 0x32: case 0x33:
+    case 0x38: case 0x39: case 0x3A: case 0x3B:
+    case 0x84: case 0x85: case 0x86: case 0x87:
+    case 0x88: case 0x89: case 0x8A: case 0x8B:
+    case 0x8C: case 0x8D: case 0x8E: case 0x8F:
+    case 0xC4: case 0xC5: case 0xD0: case 0xD1:
+    case 0xD2: case 0xD3: case 0xFE: case 0xFF:
+        p++;
+        if (!modrm_size(&p, pfx67)) return -1;
+        break;
+    /* opcode + imm8 */
+    case 0x04: case 0x0C: case 0x14: case 0x1C:
+    case 0x24: case 0x2C: case 0x34: case 0x3C:
+    case 0x6A: case 0xA8:
+    case 0xB0: case 0xB1: case 0xB2: case 0xB3:
+    case 0xB4: case 0xB5: case 0xB6: case 0xB7:
+    case 0xCD: case 0xD4: case 0xD5:
+        p += 2;
+        break;
+    /* opcode + imm16/imm32 (operand size) */
+    case 0x05: case 0x0D: case 0x15: case 0x1D:
+    case 0x25: case 0x2D: case 0x35: case 0x3D:
+    case 0x68: case 0xA9:
+    case 0xB8: case 0xB9: case 0xBA: case 0xBB:
+    case 0xBC: case 0xBD: case 0xBE: case 0xBF:
+        p += 1 + (pfx66 ? 2 : 4);
+        break;
+    /* opcode only */
+    case 0x06: case 0x07: case 0x0E: case 0x16:
+    case 0x17: case 0x1E: case 0x1F: case 0x27:
+    case 0x2F: case 0x37: case 0x3F:
+    case 0x40: case 0x41: case 0x42: case 0x43:
+    case 0x44: case 0x45: case 0x46: case 0x47:
+    case 0x48: case 0x49: case 0x4A: case 0x4B:
+    case 0x4C: case 0x4D: case 0x4E: case 0x4F:
+    case 0x50: case 0x51: case 0x52: case 0x53:
+    case 0x54: case 0x55: case 0x56: case 0x57:
+    case 0x58: case 0x59: case 0x5A: case 0x5B:
+    case 0x5C: case 0x5D: case 0x5E: case 0x5F:
+    case 0x60: case 0x61:
+    case 0x90: case 0x91: case 0x92: case 0x93:
+    case 0x94: case 0x95: case 0x96: case 0x97:
+    case 0x98: case 0x99: case 0x9B: case 0x9C:
+    case 0x9D: case 0x9E: case 0x9F: case 0xA4:
+    case 0xA5: case 0xA6: case 0xA7: case 0xAA:
+    case 0xAB: case 0xAC: case 0xAD: case 0xAE:
+    case 0xAF: case 0xC3: case 0xC9: case 0xCB:
+    case 0xCC: case 0xCE: case 0xCF: case 0xF1:
+    case 0xF4: case 0xF5: case 0xF8: case 0xF9:
+    case 0xFA: case 0xFB: case 0xFC: case 0xFD:
+        p++;
+        break;
+    /* opcode + ModRM + imm8 */
+    case 0x6B: case 0x80: case 0x82: case 0x83:
+    case 0xC0: case 0xC1: case 0xC6:
+        p++;
+        if (!modrm_size(&p, pfx67)) return -1;
+        p++;
+        break;
+    /* opcode + ModRM + imm16/imm32 */
+    case 0x69: case 0x81: case 0xC7:
+        p++;
+        if (!modrm_size(&p, pfx67)) return -1;
+        p += pfx66 ? 2 : 4;
+        break;
+    /* mov al/eax <-> moffs: offset size follows the address size */
+    case 0xA0: case 0xA1: case 0xA2: case 0xA3:
+        p += 1 + (pfx67 ? 2 : 4);
+        break;
+    /* ret imm16 / retf imm16 */
+    case 0xC2: case 0xCA:
+        p += 3;
+        break;
+    /* group 3 (byte): only test (/0, /1) has an immediate */
+    case 0xF6:
+        p++;
+        ext = (*p >> 3) & 7;
+        if (!modrm_size(&p, pfx67)) return -1;
+
+        switch (ext)
+        {
+        case 0: case 1:
+            p++;
+            break;
+        default:
+            break;
+        }
+        break;
+    /* group 3 (word/dword): only test (/0, /1) has an immediate */
+    case 0xF7:
+        p++;
+        ext = (*p >> 3) & 7;
+        if (!modrm_size(&p, pfx67)) return -1;
+
+        switch (ext)
+        {
+        case 0: case 1:
+            p += pfx66 ? 2 : 4;
+            break;
+        default:
+            break;
+        }
+        break;
+    /* 2-byte opcodes */
+    case 0x0F:
+        p++;
+        switch (*p)
+        {
+        case 0x0D: case 0x18: case 0x19: case 0x1A:
+        case 0x1B: case 0x1C: case 0x1D: case 0x1E:
+        case 0x1F:
+        case 0x90: case 0x91: case 0x92: case 0x93:
+        case 0x94: case 0x95: case 0x96: case 0x97:
+        case 0x98: case 0x99: case 0x9A: case 0x9B:
+        case 0x9C: case 0x9D: case 0x9E: case 0x9F:
+        case 0xA3: case 0xAB: case 0xAE: case 0xAF:
+        case 0xB0: case 0xB1: case 0xB3: case 0xB6:
+        case 0xB7: case 0xBB: case 0xBC: case 0xBD:
+        case 0xBE: case 0xBF: case 0xC0: case 0xC1:
+        case 0xC7:
+            p++;
+            if (!modrm_size(&p, pfx67)) return -1;
+            break;
+        case 0x31: case 0xA0: case 0xA1: case 0xA2:
+        case 0xA8: case 0xA9:
+        case 0xC8: case 0xC9: case 0xCA: case 0xCB:
+        case 0xCC: case 0xCD: case 0xCE: case 0xCF:
+            p++;
+            break;
+        case 0xBA:
+            p++;
+            if (!modrm_size(&p, pfx67)) return -1;
+            p++;
+            break;
+        /* unsupported instructions */
+        default:
+            return -1;
+        }
+        /* without this break every recognized 2-byte opcode fell into the
+         * unsupported cases below and was rejected */
+        break;
+    /* unsupported instructions */
+    case 0x62: case 0x63: /* bound / arpl */
+    case 0x6C: case 0x6D: case 0x6E: case 0x6F: /* in / out */
+    case 0xE4: case 0xE5: case 0xE6: case 0xE7:
+    case 0xEC: case 0xED: case 0xEE: case 0xEF:
+    case 0x70: case 0x71: case 0x72: case 0x73: /* jump */
+    case 0x74: case 0x75: case 0x76: case 0x77:
+    case 0x78: case 0x79: case 0x7A: case 0x7B:
+    case 0x7C: case 0x7D: case 0x7E: case 0x7F:
+    case 0xE0: case 0xE1: case 0xE2: case 0xE3:
+    case 0xE9: case 0xEA: case 0xEB:
+    case 0x9A: case 0xE8: /* call */
+    case 0xC8: /* enter */
+    case 0xD7: /* xlat */
+    case 0xD8: case 0xD9: case 0xDA: case 0xDB: /* fpu */
+    case 0xDC: case 0xDD: case 0xDE: case 0xDF:
+    default:
+        return -1;
+    }
+    return p - (PBYTE)ip;
+}
+
+/* Write a 5-byte "opcode rel32" branch (0xe8 call / 0xe9 jmp) at p that
+ * targets dest; returns the address just past it. */
+static PBYTE write_rel32(PBYTE p, BYTE opcode, const void *dest)
+{
+    p[0] = opcode;
+    *(DWORD *)(p + 1) = (DWORD)((ULONG_PTR)dest - (ULONG_PTR)(p + 5));
+    return p + 5;
+}
+
+/* Emit a trampoline at mem that
+ *   1. copies 16 DWORDs (64 bytes) starting at the hooked function's first
+ *      stack argument and calls callback with that copy; the trampoline pops
+ *      the copy itself, so callback must use caller cleanup (cdecl),
+ *   2. runs the count displaced bytes saved in original, and
+ *   3. jumps back to target + count.
+ * esi and edi are preserved; ecx (and whatever callback clobbers) is not.
+ * rep movsd assumes the direction flag is clear. Returns the end of the code. */
+static PBYTE write_trampoline(PBYTE mem, const BYTE *original, SIZE_T count, PBYTE target, void *callback)
+{
+    static const BYTE prologue[] =
+    {
+        0x56,                           /* push esi */
+        0x57,                           /* push edi */
+        0x8d, 0x74, 0x24, 0x0c,         /* lea esi, [esp + 12] */
+        0x83, 0xec, 0x40,               /* sub esp, 64 */
+        0x89, 0xe7,                     /* mov edi, esp */
+        0xb9, 0x10, 0x00, 0x00, 0x00,   /* mov ecx, 16 */
+        0xf3, 0xa5,                     /* rep movsd */
+    };
+    static const BYTE epilogue[] =
+    {
+        0x83, 0xc4, 0x40,               /* add esp, 64 */
+        0x5f,                           /* pop edi */
+        0x5e,                           /* pop esi */
+    };
+
+    memcpy(mem, prologue, sizeof(prologue));
+    mem += sizeof(prologue);
+    mem = write_rel32(mem, 0xe8, callback);             /* call callback */
+    memcpy(mem, epilogue, sizeof(epilogue));
+    mem += sizeof(epilogue);
+    memcpy(mem, original, count);                       /* displaced code */
+    mem += count;
+    return write_rel32(mem, 0xe9, target + count);      /* jmp back */
+}
+
+/* Save the first count bytes of hook->target, build the trampoline in
+ * hook->mem and redirect hook->target to it with a jmp. The trampoline is
+ * complete before the target is modified, so a concurrent caller never runs
+ * a half-written trampoline. If the target cannot be made writable the
+ * trampoline is released and hook->count stays 0, so no saved bytes are ever
+ * copied back over the target. */
+static BOOL install_jump_hook(struct hook_state *hook, SIZE_T count)
+{
+    PBYTE target = hook->target;
+    PBYTE end;
+    DWORD prot;
+
+    memcpy(hook->original, target, count);
+    end = write_trampoline(hook->mem, hook->original, count, target, hook->callback);
+    FlushInstructionCache(GetCurrentProcess(), hook->mem, end - (PBYTE)hook->mem);
+
+    if (!VirtualProtect(target, count, PAGE_EXECUTE_READWRITE, &prot))
+    {
+        VirtualFree(hook->mem, 0, MEM_RELEASE);
+        hook->mem = NULL;
+        return FALSE;
+    }
+
+    /* from here on the target is patched; keep the state so remove_hook() can undo it */
+    hook->count = count;
+    write_rel32(target, 0xe9, hook->mem); /* jmp trampoline */
+    FlushInstructionCache(GetCurrentProcess(), target, count);
+
+    return VirtualProtect(target, count, prot, &prot);
+}
+
+/* Hook an arbitrary function: displace whole instructions covering at least
+ * the 5 bytes needed for a jmp (as measured by instruction_size()) into a
+ * trampoline that calls callback first. Fails without touching func if one of
+ * those instructions cannot be relocated. */
+static BOOL hook_function(struct hook_state *hook, void *func, void *callback)
+{
+    PBYTE target = func;
+    SIZE_T count = 0;
+
+    hook->callback = callback;
+    hook->exception_handler = NULL;
+    hook->target = target;
+    hook->count = 0;
+    hook->mem = VirtualAlloc(NULL, 0x1000, MEM_COMMIT | MEM_RESERVE, PAGE_EXECUTE_READWRITE);
+    if (!hook->mem)
+        return FALSE;
+
+    while (count < 5)
+    {
+        int len = instruction_size(target + count);
+        if (len < 0)
+        {
+            VirtualFree(hook->mem, 0, MEM_RELEASE);
+            hook->mem = NULL;
+            return FALSE;
+        }
+        count += len;
+    }
+
+    return install_jump_hook(hook, count);
+}
+
+/* Hook a syscall stub: its leading 5 bytes are displaced as one unit. That is
+ * only a whole instruction when the stub starts with "mov eax, imm32"; the
+ * check below reports stubs that do not. */
+static BOOL hook_syscall(struct hook_state *hook, void *syscall, void *callback)
+{
+    PBYTE target = syscall;
+
+    hook->callback = callback;
+    hook->exception_handler = NULL;
+    hook->target = target;
+    hook->count = 0;
+    hook->mem = VirtualAlloc(NULL, 0x1000, MEM_COMMIT | MEM_RESERVE, PAGE_EXECUTE_READWRITE);
+    if (!hook->mem)
+        return FALSE;
+
+    /* Both 32-bit Windows and WoW64 start with "mov eax, 0x..". */
+    todo_wine ok(target[0] == 0xb8, "syscall does not begin with 0xb8 got: 0x%02x\n", target[0]);
+
+    return install_jump_hook(hook, 5);
+}
+
+/* Hook served by hook_syscall_ret_handler(); a single global slot, so only one
+ * return hook can be active at a time. */
+static struct hook_state *hook_syscall_ret_hook;
+
+/* Vectored handler for the int3 that hook_syscall_ret() writes over a syscall
+ * stub's "ret" / "ret imm16". It runs the hook's callback (no arguments) and
+ * then emulates the overwritten return instruction. Any other exception is
+ * passed on instead of being mistaken for the hook. */
+static LONG CALLBACK hook_syscall_ret_handler(EXCEPTION_POINTERS *info)
+{
+    struct hook_state *hook = hook_syscall_ret_hook;
+    CONTEXT *context = info->ContextRecord;
+    WORD args_size = 0;
+
+    if (!hook || info->ExceptionRecord->ExceptionCode != EXCEPTION_BREAKPOINT)
+        return EXCEPTION_CONTINUE_SEARCH;
+
+    ((void (*)(void))hook->callback)();
+
+    /* "ret imm16" (0xc2) additionally pops imm16 bytes of arguments */
+    if (hook->original[0] == 0xc2)
+        args_size = *(WORD *)(hook->original + 1);
+
+    context->Eip = *(ULONG *)context->Esp;
+    context->Esp += 4 + args_size;
+    return EXCEPTION_CONTINUE_EXECUTION;
+}
+
+/* Hook the return of the syscall stub at syscall. Its "ret" / "ret imm16" is
+ * too short for a 5-byte jmp, so it is replaced with int3 and
+ * hook_syscall_ret_handler() runs callback and emulates the return. The stub
+ * is scanned instruction by instruction (at most 64 bytes) after the leading
+ * 5-byte "mov eax, imm32". On failure nothing is patched and no handler is
+ * left registered. */
+static BOOL hook_syscall_ret(struct hook_state *hook, void *syscall, void *callback)
+{
+    PBYTE start = syscall;
+    PBYTE p = start;
+    SIZE_T count;
+    DWORD prot;
+
+    hook->callback = callback;
+    hook->exception_handler = NULL;
+    hook->target = NULL;
+    hook->mem = NULL;
+    hook->count = 0;
+
+    ok(p[0] == 0xb8, "syscall[0] expected: 0xb8 got: 0x%02x\n", p[0]);
+    p += 5;
+
+    while (*p != 0xc2 && *p != 0xc3)
+    {
+        int len;
+
+        /* bound the scan so a stub without a recognizable ret cannot walk into unrelated code */
+        if (p - start >= 64)
+            return FALSE;
+        len = instruction_size(p);
+        if (len < 0)
+            return FALSE;
+        p += len;
+    }
+
+    count = (*p == 0xc2) ? 3 : 1; /* ret imm16 : ret */
+    memcpy(hook->original, p, count);
+
+    /* the handler must be in place before the breakpoint can be reached */
+    hook_syscall_ret_hook = hook;
+    hook->exception_handler = AddVectoredExceptionHandler(TRUE, hook_syscall_ret_handler);
+    if (!hook->exception_handler)
+        return FALSE;
+
+    if (!VirtualProtect(p, count, PAGE_EXECUTE_READWRITE, &prot))
+    {
+        RemoveVectoredExceptionHandler(hook->exception_handler);
+        hook->exception_handler = NULL;
+        return FALSE;
+    }
+
+    hook->target = p;
+    hook->count = count;
+    *p = 0xcc; /* int 3 */
+    FlushInstructionCache(GetCurrentProcess(), p, count);
+
+    return VirtualProtect(p, count, prot, &prot);
+}
+
+/* Undo a hook set up by hook_function(), hook_syscall() or hook_syscall_ret():
+ * write the saved bytes back over hook->target, unregister the vectored
+ * exception handler and release the trampoline page.
+ * If the target cannot be made writable, the function returns FALSE at once
+ * and deliberately keeps the handler and trampoline, because the still-patched
+ * code may reach them. Later steps are attempted even if restoring the page
+ * protection fails.
+ * Returns TRUE when every step succeeded; calling it again afterwards is a no-op. */
+static BOOL remove_hook(struct hook_state *hook)
+{
+    BOOL ret = TRUE;
+    DWORD prot;
+
+    /* nothing recorded as patched when target is NULL or count is 0 */
+    if (hook->target && hook->count)
+    {
+        if (!VirtualProtect(hook->target, hook->count, PAGE_EXECUTE_READWRITE, &prot))
+            return FALSE;
+
+        memcpy(hook->target, hook->original, hook->count);
+        FlushInstructionCache(GetCurrentProcess(), hook->target, hook->count);
+
+        if (!VirtualProtect(hook->target, hook->count, prot, &prot))
+            ret = FALSE; /* bytes are restored, only the old protection is not */
+        hook->count = 0;
+    }
+
+    if (hook->exception_handler)
+    {
+        if (!RemoveVectoredExceptionHandler(hook->exception_handler))
+            ret = FALSE;
+        hook->exception_handler = NULL;
+    }
+
+    if (hook->mem)
+    {
+        if (!VirtualFree(hook->mem, 0, MEM_RELEASE))
+            ret = FALSE;
+        hook->mem = NULL;
+    }
+
+    return ret;
+}
+
+/* Thread procedure for the thread-creation tests: records the process and
+ * thread ids it runs under and exits with 0xdeadbeef. param is ignored. */
+static DWORD WINAPI dummy_thread(LPVOID param)
+{
+    DWORD pid = GetCurrentProcessId();
+    DWORD tid = GetCurrentThreadId();
+
+    callback_process_id = pid;
+    callback_thread_id = tid;
+    return 0xdeadbeef;
+}
+
+/* Hook callback for LdrInitializeThunk: flags that it ran and, when a context
+ * is passed, checks that Eax/Ebx carry the thread procedure (dummy_thread) and
+ * parameter (0x12345678) that test_LdrInitializeThunk() passes to CreateThread. */
+static void callback_LdrInitializeThunk(PCONTEXT context, ULONG_PTR unknown1, ULONG_PTR unknown2, ULONG_PTR unknown3)
+{
+    callback_result = 1;
+
+    /* context is NULL on Windows XP */
+    if (!context)
+        return;
+
+    ok(context->Eax == (ULONG)dummy_thread, "context->Eax expected: %08X got: %08X\n",
+       (ULONG)dummy_thread, context->Eax);
+    ok(context->Ebx == (ULONG)0x12345678, "context->Ebx expected: 12345678 got: %08X\n", context->Ebx);
+}
+
+/* Starting a thread must go through a hooked LdrInitializeThunk. */
+static void test_LdrInitializeThunk(void)
+{
+    struct hook_state hook;
+    HANDLE thread;
+    BOOL hooked;
+
+    hooked = hook_function(&hook, pLdrInitializeThunk, callback_LdrInitializeThunk);
+    ok(hooked, "failed to hook LdrInitializeThunk\n");
+
+    callback_result = 0;
+    thread = CreateThread(NULL, 0, dummy_thread, (LPVOID)0x12345678, 0, NULL);
+    ok(thread != NULL, "CreateThread failed\n");
+    ok(WaitForSingleObject(thread, 1000) == WAIT_OBJECT_0, "wait for thread failed\n");
+    ok(callback_result == 1, "callback never ran\n");
+    CloseHandle(thread);
+
+    ok(remove_hook(&hook), "failed to remove hook\n");
+}
+
+/* Vectored handler used by test_NtContinue(): expects an access violation,
+ * then makes the faulting call "return" 0x12345678 by popping the return
+ * address from the stack into Eip, and resumes execution. */
+static LONG CALLBACK dummy_exception_handler(EXCEPTION_POINTERS *info)
+{
+    CONTEXT *ctx = info->ContextRecord;
+    DWORD code = info->ExceptionRecord->ExceptionCode;
+
+    ok(code == EXCEPTION_ACCESS_VIOLATION,
+       "wrong exception expected: %08X got: %08X\n", EXCEPTION_ACCESS_VIOLATION,
+       code);
+
+    /* emulate "ret" with eax = 0x12345678 */
+    ctx->Eip = *(ULONG *)ctx->Esp;
+    ctx->Esp += 4;
+    ctx->Eax = 0x12345678;
+
+    return EXCEPTION_CONTINUE_EXECUTION;
+}
+
+/* Hook callback for NtContinue. It checks the context being resumed:
+ *   - the CONTEXT_FULL flags are all set,
+ *   - Eax is 0x12345678,
+ *   - Eip equals the DWORD just below the resumed Esp.
+ * It then overrides the resumed Eax with 0xdeadbeef. */
+static void callback_NtContinue(PCONTEXT context, BOOLEAN alert)
+{
+    ULONG ret_addr = *(ULONG *)(context->Esp - 4);
+
+    callback_result = 1;
+    ok((context->ContextFlags & CONTEXT_FULL) == CONTEXT_FULL, "wrong context flags expected: %08x got: %08x\n",
+       CONTEXT_FULL, context->ContextFlags);
+    ok(context->Eax == 0x12345678, "wrong Eax expected: 0x12345678 got: %08x\n", context->Eax);
+    ok(context->Eip == ret_addr, "wrong Eip expected: %08x got: %08x\n",
+       ret_addr, context->Eip);
+
+    context->Eax = 0xdeadbeef;
+}
+
+/* Exception dispatch resumes through NtContinue: a hook on it must see the
+ * context prepared by dummy_exception_handler() and be able to change the
+ * value the faulting call returns. */
+static void test_NtContinue(void)
+{
+    /* Calling through a NULL pointer constant is undefined behaviour that a
+     * compiler may turn into a trap; reading the pointer through a volatile
+     * object forces a real call to address 0 and hence a real access violation. */
+    int (* volatile null_func)(void) = NULL;
+    struct hook_state hook;
+    PVOID handle;
+    int result;
+
+    ok(hook_function(&hook, pNtContinue, callback_NtContinue), "failed to hook NtContinue\n");
+
+    handle = AddVectoredExceptionHandler(TRUE, dummy_exception_handler);
+    ok(handle != NULL, "failed to register exception handler\n");
+
+    /* without the handler the access violation below would be fatal */
+    if (handle)
+    {
+        callback_result = 0;
+        /* raise an exception */
+        result = null_func();
+        todo_wine ok(callback_result == 1, "callback never ran\n");
+        todo_wine ok(result == 0xdeadbeef, "wrong return value expected: deadbeef got: %08x\n", result);
+
+        RemoveVectoredExceptionHandler(handle);
+    }
+
+    ok(remove_hook(&hook), "failed to remove hook\n");
+}
+
+/* Entry hook callback for NtCreateThread. It does three things:
+ *   - sets bit 0 of callback_result;
+ *   - records the handle and client id output pointers, substituting private
+ *     static dummies when the caller passed NULL, for callback_NtCreateThread_ret();
+ *   - checks that the initial context carries dummy_thread and 0x12345678. */
+static void callback_NtCreateThread( HANDLE *handle_ptr, ACCESS_MASK access, OBJECT_ATTRIBUTES *attr, HANDLE process,
+                                     CLIENT_ID *id, CONTEXT *context, INITIAL_TEB *teb, BOOLEAN suspended )
+{
+    static CLIENT_ID dummy_id;
+    static HANDLE dummy_handle;
+
+    callback_result |= 1;
+    callback_client_id_ptr = id ? id : &dummy_id;
+    callback_handle_ptr = handle_ptr ? handle_ptr : &dummy_handle;
+
+    ok(context != NULL, "context is NULL\n");
+    if (!context)
+        return;
+
+    ok(context->Eax == (ULONG)dummy_thread, "context->Eax expected: %08X got: %08X\n",
+       (ULONG)dummy_thread, context->Eax);
+    ok(context->Ebx == (ULONG)0x12345678, "context->Ebx expected: 12345678 got: %08X\n", context->Ebx);
+}
+
+/* Return hook callback for NtCreateThread. It sets bit 1 of callback_result,
+ * copies the client id the syscall filled in, and checks that the thread id
+ * matches the returned handle. It depends on callback_NtCreateThread() having
+ * recorded the output pointers; if it has not, the failure is reported
+ * instead of dereferencing NULL. */
+static void callback_NtCreateThread_ret(void)
+{
+    DWORD tid;
+
+    callback_result |= 2;
+
+    if (!callback_client_id_ptr || !callback_handle_ptr)
+    {
+        ok(0, "return hook ran without the entry hook\n");
+        return;
+    }
+
+    callback_client_id = *callback_client_id_ptr;
+    tid = get_thread_id(*callback_handle_ptr);
+    ok(HandleToULong(callback_client_id.UniqueThread) == tid,
+       "wrong thread id %d != %d\n", HandleToULong(callback_client_id.UniqueThread), tid);
+}
+
+/* Hooks both the entry and the return of NtCreateThread and starts a thread.
+ * When the OS major version is below 6, both hooks must fire and the reported
+ * client id must match the ids the thread observed. Otherwise neither hook
+ * may fire. */
+static void test_NtCreateThread(void)
+{
+    ULONG major = NtCurrentTeb()->Peb->OSMajorVersion;
+    struct hook_state hook, hook2;
+    HANDLE thread;
+
+    todo_wine ok(hook_syscall_ret(&hook2, pNtCreateThread, callback_NtCreateThread_ret), "failed to hook NtCreateThread ret\n");
+    ok(hook_syscall(&hook, pNtCreateThread, callback_NtCreateThread), "failed to hook NtCreateThread\n");
+
+    callback_result = 0;
+    callback_client_id_ptr = NULL;
+    callback_handle_ptr = NULL;
+    thread = CreateThread(NULL, 0, dummy_thread, (LPVOID)0x12345678, 0, NULL);
+    ok(thread != NULL, "CreateThread failed\n");
+    ok(WaitForSingleObject(thread, 1000) == WAIT_OBJECT_0, "wait for thread failed\n");
+
+    if (major >= 6)
+    {
+        ok(callback_result == 0, "callbacks ran but should not on major version %d\n",
+           major);
+    }
+    else
+    {
+        ULONG pid = HandleToULong(callback_client_id.UniqueProcess);
+        ULONG tid = HandleToULong(callback_client_id.UniqueThread);
+
+        ok(callback_result & 1, "callback never ran\n");
+        ok(callback_result & 2, "ret callback never ran\n");
+        ok(callback_process_id == pid,
+           "wrong process id expected: %d got: %d\n", callback_process_id, pid);
+        ok(callback_thread_id == tid,
+           "wrong thread id expected: %d got: %d\n", callback_thread_id, tid);
+    }
+
+    CloseHandle(thread);
+
+    ok(remove_hook(&hook), "failed to remove hook\n");
+    todo_wine ok(remove_hook(&hook2), "failed to remove hook\n");
+}
+
+/* Callback shared by the tick count hooks: only sets bit 0 of callback_result. */
+static void callback_NtGetTickCount(void)
+{
+    callback_result = callback_result | 1;
+}
+
+/* A hook on NtGetTickCount must fire when NtGetTickCount is called directly,
+ * but not from GetTickCount/GetTickCount64. Hooking either kernel32 function
+ * must leave the first 64 bytes of NtGetTickCount unmodified.
+ * GetTickCount64 does not exist on Windows XP, so the parts using it are
+ * skipped when it could not be resolved. */
+static void test_NtGetTickCount(void)
+{
+    struct hook_state hook;
+    BYTE tmp[64];
+
+    /* NtGetTickCount may not be a system call. */
+    ok(hook_function(&hook, pNtGetTickCount, callback_NtGetTickCount),
+       "failed to hook NtGetTickCount\n");
+
+    callback_result = 0;
+    pGetTickCount();
+    ok(callback_result == 0, "callbacks ran during GetTickCount\n");
+
+    if (pGetTickCount64)
+    {
+        callback_result = 0;
+        pGetTickCount64();
+        ok(callback_result == 0, "callbacks ran during GetTickCount64\n");
+    }
+
+    callback_result = 0;
+    pNtGetTickCount();
+    ok(callback_result & 1, "callback never ran\n");
+
+    ok(remove_hook(&hook), "failed to remove hook\n");
+
+    memcpy(tmp, pNtGetTickCount, sizeof(tmp));
+
+    ok(hook_function(&hook, pGetTickCount, callback_NtGetTickCount),
+       "failed to hook GetTickCount\n");
+    ok(memcmp(pNtGetTickCount, tmp, sizeof(tmp)) == 0, "hook modified NtGetTickCount\n");
+
+    callback_result = 0;
+    pNtGetTickCount();
+    ok(callback_result == 0, "callbacks ran during NtGetTickCount\n");
+
+    ok(remove_hook(&hook), "failed to remove hook\n");
+
+    if (!pGetTickCount64)
+    {
+        win_skip("GetTickCount64 is not available\n");
+        return;
+    }
+
+    ok(hook_function(&hook, pGetTickCount64, callback_NtGetTickCount),
+       "failed to hook GetTickCount64\n");
+    ok(memcmp(pNtGetTickCount, tmp, sizeof(tmp)) == 0, "hook modified NtGetTickCount\n");
+
+    callback_result = 0;
+    pNtGetTickCount();
+    ok(callback_result == 0, "callbacks ran during NtGetTickCount\n");
+
+    ok(remove_hook(&hook), "failed to remove hook\n");
+}
+#endif
+
+START_TEST(hooks)
+{
+#ifdef __i386__
+ HMODULE hntdll = GetModuleHandleA("ntdll.dll");
+ HMODULE hkernel32 = GetModuleHandleA("kernel32.dll");
+
+ pLdrInitializeThunk = (void *)GetProcAddress(hntdll, "LdrInitializeThunk");
+ pNtContinue = (void *)GetProcAddress(hntdll, "NtContinue");
+ pNtCreateThread = (void *)GetProcAddress(hntdll, "NtCreateThread");
+ pNtGetTickCount = (void *)GetProcAddress(hntdll, "NtGetTickCount");
+ pNtQueryInformationThread = (void *)GetProcAddress(hntdll, "NtQueryInformationThread");
+
+ pGetTickCount = (void *)GetProcAddress(hkernel32, "GetTickCount");
+ pGetTickCount64 = (void *)GetProcAddress(hkernel32, "GetTickCount64");
+
+ test_LdrInitializeThunk();
+ test_NtContinue();
+ test_NtCreateThread();
+ test_NtGetTickCount();
+#endif
+}
--
2.1.4
More information about the wine-devel
mailing list