[PATCH v2 3/3] ntdll/tests: Add tests for hooking exports.

Andrew Wesie awesie at gmail.com
Thu May 30 02:47:59 CDT 2019


Signed-off-by: Andrew Wesie <awesie at gmail.com>
---
 dlls/ntdll/tests/Makefile.in |   1 +
 dlls/ntdll/tests/hooks.c     | 773 +++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 774 insertions(+)
 create mode 100644 dlls/ntdll/tests/hooks.c

diff --git a/dlls/ntdll/tests/Makefile.in b/dlls/ntdll/tests/Makefile.in
index 5c70f3f..542def9 100644
--- a/dlls/ntdll/tests/Makefile.in
+++ b/dlls/ntdll/tests/Makefile.in
@@ -10,6 +10,7 @@ C_SRCS = \
 	exception.c \
 	file.c \
 	generated.c \
+	hooks.c \
 	info.c \
 	large_int.c \
 	om.c \
diff --git a/dlls/ntdll/tests/hooks.c b/dlls/ntdll/tests/hooks.c
new file mode 100644
index 0000000..e7a8779
--- /dev/null
+++ b/dlls/ntdll/tests/hooks.c
@@ -0,0 +1,773 @@
+/*
+ * Unit test suite for hooking ntdll functions
+ *
+ * Copyright 2018 Andrew Wesie
+ *
+ * This library is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation; either
+ * version 2.1 of the License, or (at your option) any later version.
+ *
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this library; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301, USA
+ */
+
+#define NONAMELESSUNION
+#include "ntdll_test.h"
+#include "excpt.h"
+
+#ifdef __i386__
+
+/* ntdll exports, resolved with GetProcAddress() in START_TEST */
+static void     (WINAPI *pLdrInitializeThunk)(CONTEXT *,ULONG_PTR,ULONG_PTR,ULONG_PTR);
+static NTSTATUS (WINAPI *pNtContinue)(CONTEXT *,BOOLEAN);
+static NTSTATUS (WINAPI *pNtCreateThread)(HANDLE *,ACCESS_MASK,OBJECT_ATTRIBUTES *,HANDLE,CLIENT_ID *,CONTEXT *,INITIAL_TEB *,BOOLEAN);
+static ULONG (WINAPI *pNtGetTickCount)(void); /* may be NULL; test_NtGetTickCount() then skips */
+static NTSTATUS (WINAPI *pNtQueryInformationThread)(HANDLE,THREADINFOCLASS,void *,ULONG,ULONG *);
+
+/* kernel32 exports, resolved with GetProcAddress() in START_TEST */
+static DWORD (WINAPI *pGetTickCount)(void);
+static ULONGLONG (WINAPI *pGetTickCount64)(void); /* may be NULL; checked before every use */
+
+struct hook_state /* one installed hook, as set up by hook_function/hook_syscall/hook_syscall_ret */
+{
+    void *target; /* patched address: the hooked entry point, or the ret found by hook_syscall_ret() */
+    void *mem; /* trampoline from VirtualAlloc(); NULL for hook_syscall_ret() */
+    BYTE original[64]; /* bytes overwritten at target, written back by remove_hook() */
+    SIZE_T count; /* number of valid bytes in original */
+    void *callback; /* trampolines pass it a copy of 16 stack dwords; the int3 handler passes none */
+    void *exception_handler; /* vectored handler registered by hook_syscall_ret(), else NULL */
+};
+
+static DWORD callback_result; /* bits set by the hook callbacks; each test resets it to 0 */
+static ULONG callback_process_id; /* written by dummy_thread() */
+static ULONG callback_thread_id; /* written by dummy_thread() */
+static CLIENT_ID callback_client_id; /* copied by callback_NtCreateThread_ret() */
+static CLIENT_ID *callback_client_id_ptr; /* NtCreateThread's CLIENT_ID argument, set by callback_NtCreateThread() */
+static HANDLE *callback_handle_ptr; /* NtCreateThread's handle argument, set by callback_NtCreateThread() */
+
+/* Same as kernel32's GetThreadId(), which Windows XP lacks; adapted from
+ * dlls/kernel32/thread.c. Returns the thread id, or 0 with the last error set. */
+static DWORD get_thread_id(HANDLE thread)
+{
+    THREAD_BASIC_INFORMATION info;
+    NTSTATUS ret;
+
+    ret = pNtQueryInformationThread(thread, ThreadBasicInformation,
+                                    &info, sizeof(info), NULL);
+    if (!ret)
+        return HandleToULong(info.ClientId.UniqueThread);
+
+    /* report the failure the way the Win32 API would */
+    SetLastError( RtlNtStatusToDosError(ret) );
+    return 0;
+}
+
+/* Skip a ModRM byte and the SIB/displacement bytes it implies; addr16 = 0x67 prefix seen. */
+static BOOL modrm_size(BYTE **pp, BOOL addr16)
+{
+    BYTE *p = *pp;
+    BYTE mod = (p[0] >> 6) & 3, rm = p[0] & 7;
+    p++;
+
+    if (!addr16 && mod != 3 && rm == 4)
+        rm = *p++ & 7; /* SIB; continue with its base, where 5 with mod 0 means disp32 */
+
+    if (mod == 0 && rm == (addr16 ? 6 : 5))
+        p += addr16 ? 2 : 4; /* [disp16] / [disp32] without a base register */
+    else if (mod == 1)
+        p += 1; /* ... + disp8 */
+    else if (mod == 2)
+        p += addr16 ? 2 : 4; /* ... + disp16 / disp32 */
+    *pp = p;
+    return TRUE;
+}
+
+/* Return 1 if the len bytes at ip are only "nop" (0x90) or "pause" (0xF3 0x90), else 0. */
+static int is_nop(const BYTE *ip, unsigned int len)
+{
+    const BYTE *p = ip, *end = ip + len;
+    while (p < end)
+    {
+        switch (*p)
+        {
+        case 0x90: /* nop */
+            p++;
+            break;
+        case 0xF3: /* pause, encoded as rep nop */
+            if (p + 1 == end || p[1] != 0x90)
+                return 0;
+            p += 2;
+            break;
+        default: /* anything else is real code, not padding */
+            return 0;
+        }
+    }
+    return 1;
+}
+
+/* Return the length of the i386 instruction at ip, or -1 if it is not decoded here
+ * (including relative jumps and calls, which would break if copied elsewhere). A short
+ * jmp over nop padding is reported as one instruction spanning the padding. */
+static int instruction_size(void *ip)
+{
+    BOOL pfx66 = FALSE; /* operand-size prefix */
+    BOOL pfx67 = FALSE; /* address-size prefix */
+    BYTE *p = ip;
+    BYTE ext;
+
+    while (1) /* skip prefixes */
+    {
+        switch (*p)
+        {
+        case 0x66:
+            pfx66 = TRUE;
+            p++;
+            break;
+        case 0x67:
+            pfx67 = TRUE;
+            p++;
+            break;
+        case 0x26:
+        case 0x2E:
+        case 0x36:
+        case 0x3E:
+        case 0x64:
+        case 0x65:
+        case 0x9B:
+        case 0xF0:
+        case 0xF2:
+        case 0xF3:
+            p++;
+            break;
+        default:
+            goto no_prefix;
+        }
+    }
+no_prefix:
+    switch (*p)
+    {
+    case 0x00: case 0x01: case 0x02: case 0x03:
+    case 0x08: case 0x09: case 0x0A: case 0x0B:
+    case 0x10: case 0x11: case 0x12: case 0x13:
+    case 0x18: case 0x19: case 0x1A: case 0x1B:
+    case 0x20: case 0x21: case 0x22: case 0x23:
+    case 0x28: case 0x29: case 0x2A: case 0x2B:
+    case 0x30: case 0x31: case 0x32: case 0x33:
+    case 0x38: case 0x39: case 0x3A: case 0x3B:
+    case 0x84: case 0x85: case 0x86: case 0x87:
+    case 0x88: case 0x89: case 0x8A: case 0x8B:
+    case 0x8C: case 0x8D: case 0x8E: case 0x8F:
+    case 0xC4: case 0xC5: case 0xD0: case 0xD1:
+    case 0xD2: case 0xD3: case 0xFE: case 0xFF:
+        p++;
+        if (!modrm_size(&p, pfx67)) return -1;
+        break;
+    case 0x04: case 0x0C: case 0x14: case 0x1C:
+    case 0x24: case 0x2C: case 0x34: case 0x3C:
+    case 0x6A: case 0xA8:
+    case 0xB0: case 0xB1: case 0xB2: case 0xB3:
+    case 0xB4: case 0xB5: case 0xB6: case 0xB7:
+    case 0xCD: case 0xD4: case 0xD5:
+        p += 2;
+        break;
+    case 0x05: case 0x0D: case 0x15: case 0x1D:
+    case 0x25: case 0x2D: case 0x35: case 0x3D:
+    case 0x68: case 0xA9:
+    case 0xB8: case 0xB9: case 0xBA: case 0xBB:
+    case 0xBC: case 0xBD: case 0xBE: case 0xBF:
+        p += 1 + (pfx66 ? 2 : 4);
+        break;
+    case 0x06: case 0x07: case 0x0E: case 0x16:
+    case 0x17: case 0x1E: case 0x1F: case 0x27:
+    case 0x2F: case 0x37: case 0x3F:
+    case 0x40: case 0x41: case 0x42: case 0x43:
+    case 0x44: case 0x45: case 0x46: case 0x47:
+    case 0x48: case 0x49: case 0x4A: case 0x4B:
+    case 0x4C: case 0x4D: case 0x4E: case 0x4F:
+    case 0x50: case 0x51: case 0x52: case 0x53:
+    case 0x54: case 0x55: case 0x56: case 0x57:
+    case 0x58: case 0x59: case 0x5A: case 0x5B:
+    case 0x5C: case 0x5D: case 0x5E: case 0x5F:
+    case 0x60: case 0x61:
+    case 0x90: case 0x91: case 0x92: case 0x93:
+    case 0x94: case 0x95: case 0x96: case 0x97:
+    case 0x98: case 0x99: case 0x9B: case 0x9C:
+    case 0x9D: case 0x9E: case 0x9F: case 0xA4:
+    case 0xA5: case 0xA6: case 0xA7: case 0xAA:
+    case 0xAB: case 0xAC: case 0xAD: case 0xAE:
+    case 0xAF: case 0xC3: case 0xC9: case 0xCB:
+    case 0xCC: case 0xCE: case 0xCF: case 0xF1:
+    case 0xF4: case 0xF5: case 0xF8: case 0xF9:
+    case 0xFA: case 0xFB: case 0xFC: case 0xFD:
+        p++;
+        break;
+    case 0x6B: case 0x80: case 0x82: case 0x83:
+    case 0xC0: case 0xC1: case 0xC6:
+        p++;
+        if (!modrm_size(&p, pfx67)) return -1;
+        p++;
+        break;
+    case 0x69: case 0x81: case 0xC7:
+        p++;
+        if (!modrm_size(&p, pfx67)) return -1;
+        p += pfx66 ? 2 : 4;
+        break;
+    case 0xA0: case 0xA1: case 0xA2: case 0xA3:
+        p += 1 + (pfx67 ? 2 : 4);
+        break;
+    case 0xC2: case 0xCA:
+        p += 3;
+        break;
+    case 0xF6:
+        p++;
+        ext = (*p >> 3) & 7;
+        if (!modrm_size(&p, pfx67)) return -1;
+
+        switch (ext)
+        {
+        case 0: case 1:
+            p++;
+            break;
+        }
+        break;
+    case 0xF7:
+        p++;
+        ext = (*p >> 3) & 7;
+        if (!modrm_size(&p, pfx67)) return -1;
+
+        switch (ext)
+        {
+        case 0: case 1:
+            p += pfx66 ? 2 : 4;
+            break;
+        }
+        break;
+    case 0x0F: /* 2-byte opcodes */
+        p++;
+        switch (*p)
+        {
+        case 0x0D: case 0x18: case 0x19: case 0x1A:
+        case 0x1B: case 0x1C: case 0x1D: case 0x1E:
+        case 0x1F:
+        case 0x90: case 0x91: case 0x92: case 0x93:
+        case 0x94: case 0x95: case 0x96: case 0x97:
+        case 0x98: case 0x99: case 0x9A: case 0x9B:
+        case 0x9C: case 0x9D: case 0x9E: case 0x9F:
+        case 0xA3: case 0xAB: case 0xAE: case 0xAF:
+        case 0xB0: case 0xB1: case 0xB3: case 0xB6:
+        case 0xB7: case 0xBB: case 0xBC: case 0xBD:
+        case 0xBE: case 0xBF: case 0xC0: case 0xC1:
+        case 0xC7:
+            p++;
+            if (!modrm_size(&p, pfx67)) return -1;
+            break;
+        case 0x31: case 0xA0: case 0xA1: case 0xA2:
+        case 0xA8: case 0xA9:
+        case 0xC8: case 0xC9: case 0xCA: case 0xCB:
+        case 0xCC: case 0xCD: case 0xCE: case 0xCF:
+            p++;
+            break;
+        case 0xBA:
+            p++;
+            if (!modrm_size(&p, pfx67)) return -1;
+            p++;
+            break;
+        default: /* unsupported instructions */
+            return -1;
+        }
+        break; /* do not fall through into the 0xEB case */
+    case 0xEB: /* jmp short, accepted only when it skips over nop padding */
+        p++;
+        if (*p <= 0x7f && is_nop(p + 1, *p))
+            return (p + 1 + *p) - (BYTE *)ip;
+        else
+            return -1;
+    case 0x62: case 0x63: /* unsupported instructions: bound / arpl */
+    case 0x6C: case 0x6D: case 0x6E: case 0x6F: /* in / out */
+    case 0xE4: case 0xE5: case 0xE6: case 0xE7:
+    case 0xEC: case 0xED: case 0xEE: case 0xEF:
+    case 0x70: case 0x71: case 0x72: case 0x73: /* jump */
+    case 0x74: case 0x75: case 0x76: case 0x77:
+    case 0x78: case 0x79: case 0x7A: case 0x7B:
+    case 0x7C: case 0x7D: case 0x7E: case 0x7F:
+    case 0xE0: case 0xE1: case 0xE2: case 0xE3:
+    case 0xE9: case 0xEA:
+    case 0x9A: case 0xE8: /* call */
+    case 0xC8: /* enter */
+    case 0xD7: /* xlat */
+    case 0xD8: case 0xD9: case 0xDA: case 0xDB: /* fpu */
+    case 0xDC: case 0xDD: case 0xDE: case 0xDF:
+    default:
+        return -1;
+    }
+    return p - (BYTE *)ip;
+}
+
+/* Redirect func to a trampoline that calls callback (cdecl) with a copy of the first
+ * 16 stack dwords of the caller, then runs the displaced instructions and jumps back.
+ * Returns FALSE with nothing patched and hook->count == 0 if func cannot be hooked. */
+static BOOL hook_function(struct hook_state *hook, void *func, void *callback)
+{
+    static const BYTE prologue[] =
+    {
+        0x56,                         /* push esi */
+        0x57,                         /* push edi */
+        0x8d, 0x74, 0x24, 0x0c,       /* lea esi, [esp + 12] (first argument) */
+        0x83, 0xec, 0x40,             /* sub esp, 64 */
+        0x89, 0xe7,                   /* mov edi, esp */
+        0xb9, 0x10, 0x00, 0x00, 0x00, /* mov ecx, 16 */
+        0xf3, 0xa5,                   /* rep movsd */
+    };
+    static const BYTE epilogue[] =
+    {
+        0x83, 0xc4, 0x40,             /* add esp, 64 */
+        0x5f,                         /* pop edi */
+        0x5e,                         /* pop esi */
+    };
+    DWORD prot;
+    BYTE *p = func;
+    BYTE *mem;
+
+    hook->callback = callback;
+    hook->exception_handler = NULL;
+    hook->target = p;
+    hook->count = 0; /* keeps remove_hook() harmless if we fail below */
+    hook->mem = mem = VirtualAlloc(NULL, 0x1000, MEM_COMMIT | MEM_RESERVE, PAGE_EXECUTE_READWRITE);
+    if (!mem)
+        return FALSE;
+
+    /* displace whole instructions covering the 5 bytes of the jmp */
+    while (hook->count < 5)
+    {
+        int rv = instruction_size(p + hook->count);
+        if (rv < 0 || hook->count + rv > sizeof(hook->original))
+            goto failed;
+        hook->count += rv;
+    }
+    memcpy(hook->original, p, hook->count);
+
+    /* build the complete trampoline before the target is patched */
+    memcpy(mem, prologue, sizeof(prologue));
+    mem += sizeof(prologue);
+    *mem = 0xe8; /* call callback */
+    *(DWORD *)(mem + 1) = (BYTE *)callback - (mem + 5);
+    mem += 5;
+    memcpy(mem, epilogue, sizeof(epilogue));
+    mem += sizeof(epilogue);
+    memcpy(mem, hook->original, hook->count); /* execute the original code */
+    mem += hook->count;
+    *mem = 0xe9; /* jmp back past the displaced instructions */
+    *(DWORD *)(mem + 1) = (p + hook->count) - (mem + 5);
+
+    if (!VirtualProtect(hook->target, hook->count, PAGE_EXECUTE_READWRITE, &prot))
+        goto failed;
+    *p = 0xe9; /* jmp trampoline */
+    *(DWORD *)(p + 1) = (BYTE *)hook->mem - (p + 5);
+    if (!VirtualProtect(hook->target, hook->count, prot, &prot))
+        return FALSE; /* the hook is in place; remove_hook() can still undo it */
+    return TRUE;
+
+failed:
+    VirtualFree(hook->mem, 0, MEM_RELEASE);
+    hook->mem = NULL;
+    hook->count = 0;
+    return FALSE;
+}
+
+/* Hook a syscall stub by displacing exactly its first 5 bytes, the size of the
+ * "mov eax, <number>" it is expected to start with. The trampoline calls callback (cdecl)
+ * with a copy of the first 16 stack dwords of the caller, then replays those bytes and
+ * jumps back. Returns FALSE with nothing patched if allocation or unprotecting fails. */
+static BOOL hook_syscall(struct hook_state *hook, void *syscall, void *callback)
+{
+    static const BYTE prologue[] =
+    {
+        0x56,                         /* push esi */
+        0x57,                         /* push edi */
+        0x8d, 0x74, 0x24, 0x0c,       /* lea esi, [esp + 12] (first argument) */
+        0x83, 0xec, 0x40,             /* sub esp, 64 */
+        0x89, 0xe7,                   /* mov edi, esp */
+        0xb9, 0x10, 0x00, 0x00, 0x00, /* mov ecx, 16 */
+        0xf3, 0xa5,                   /* rep movsd */
+    };
+    static const BYTE epilogue[] =
+    {
+        0x83, 0xc4, 0x40,             /* add esp, 64 */
+        0x5f,                         /* pop edi */
+        0x5e,                         /* pop esi */
+    };
+    DWORD prot;
+    BYTE *p = syscall;
+    BYTE *mem;
+
+    hook->callback = callback;
+    hook->exception_handler = NULL;
+    hook->target = p;
+    hook->count = 0; /* keeps remove_hook() harmless if we fail below */
+    hook->mem = mem = VirtualAlloc(NULL, 0x1000, MEM_COMMIT | MEM_RESERVE, PAGE_EXECUTE_READWRITE);
+    if (!mem)
+        return FALSE;
+
+    /* Both 32-bit Windows and WoW64 start with "mov eax, 0x..". */
+    todo_wine ok(p[0] == 0xb8, "syscall does not begin with 0xb8  got: 0x%02x\n", p[0]);
+    hook->count = 5;
+    memcpy(hook->original, p, hook->count);
+
+    /* build the complete trampoline before the stub is patched */
+    memcpy(mem, prologue, sizeof(prologue));
+    mem += sizeof(prologue);
+    *mem = 0xe8; /* call callback */
+    *(DWORD *)(mem + 1) = (BYTE *)callback - (mem + 5);
+    mem += 5;
+    memcpy(mem, epilogue, sizeof(epilogue));
+    mem += sizeof(epilogue);
+    memcpy(mem, hook->original, hook->count); /* replay the displaced bytes */
+    mem += hook->count;
+    *mem = 0xe9; /* jmp back past them */
+    *(DWORD *)(mem + 1) = (p + hook->count) - (mem + 5);
+
+    if (!VirtualProtect(hook->target, hook->count, PAGE_EXECUTE_READWRITE, &prot))
+    {
+        VirtualFree(hook->mem, 0, MEM_RELEASE);
+        hook->mem = NULL;
+        hook->count = 0;
+        return FALSE;
+    }
+
+    /* redirect the stub to the trampoline */
+    *p = 0xe9; /* jmp trampoline */
+    *(DWORD *)(p + 1) = (BYTE *)hook->mem - (p + 5);
+    if (!VirtualProtect(hook->target, hook->count, prot, &prot))
+        return FALSE; /* the hook is in place; remove_hook() can still undo it */
+    return TRUE;
+}
+
+static struct hook_state *hook_syscall_ret_hook; /* the hook serviced by hook_syscall_ret_handler() */
+/* Vectored handler standing in for the ret that hook_syscall_ret() replaced with int3. */
+static LONG CALLBACK hook_syscall_ret_handler(EXCEPTION_POINTERS *info)
+{
+    struct hook_state *hook = hook_syscall_ret_hook;
+    WORD args_size = 0;
+
+    /* vectored handlers see every exception in the process; only claim our breakpoint */
+    if (info->ExceptionRecord->ExceptionCode != EXCEPTION_BREAKPOINT ||
+        info->ExceptionRecord->ExceptionAddress != hook->target)
+        return EXCEPTION_CONTINUE_SEARCH;
+    ((void (*)(void))hook->callback)();
+    if (hook->original[0] == 0xc2) /* "ret imm16" also pops the arguments */
+        args_size = hook->original[1] | (hook->original[2] << 8);
+    info->ContextRecord->Eip = *(ULONG*)info->ContextRecord->Esp;
+    info->ContextRecord->Esp += 4 + args_size;
+    return EXCEPTION_CONTINUE_EXECUTION;
+}
+
+/* Hook the "ret"/"ret imm16" of a syscall stub starting with "mov eax, <number>". There is
+ * no room for a jmp there, so the ret becomes int3 and hook_syscall_ret_handler() calls
+ * callback and emulates the ret. Install before hook_syscall() overwrites the mov. */
+static BOOL hook_syscall_ret(struct hook_state *hook, void *syscall, void *callback)
+{
+    DWORD prot;
+    BYTE *p = syscall;
+
+    hook->callback = callback;
+    hook->target = NULL;
+    hook->mem = NULL;
+    hook->count = 0;
+
+    hook_syscall_ret_hook = hook;
+    hook->exception_handler = AddVectoredExceptionHandler(TRUE, hook_syscall_ret_handler);
+    if (hook->exception_handler == NULL)
+        return FALSE;
+
+    ok(p[0] == 0xb8, "syscall[0] expected: 0xb8  got: 0x%02x\n", p[0]);
+    if (p[0] != 0xb8)
+        goto failed; /* unknown stub layout: do not scan arbitrary code for a ret */
+    p += 5;
+    while (*p != 0xc2 && *p != 0xc3)
+    {
+        int rv = (*p == 0xe8) ? 5 : instruction_size(p); /* Win8 may have a call after the syscall */
+        if (rv < 0)
+            goto failed;
+        p += rv;
+    }
+
+    hook->count = instruction_size(p); /* 3 for "ret imm16", 1 for "ret" */
+    memcpy(hook->original, p, hook->count);
+    if (!VirtualProtect(p, hook->count, PAGE_EXECUTE_READWRITE, &prot))
+        goto failed;
+    hook->target = p; /* restored by remove_hook() */
+    *p = 0xcc; /* int 3 */
+    if (!VirtualProtect(hook->target, hook->count, prot, &prot))
+        return FALSE;
+    return TRUE;
+
+failed:
+    RemoveVectoredExceptionHandler(hook->exception_handler);
+    hook->exception_handler = NULL;
+    hook->count = 0;
+    return FALSE;
+}
+
+/* Undo a hook: write the saved bytes back to hook->target, then release the vectored
+ * handler and the trampoline when the hook owns them. Returns FALSE at the first
+ * failing step, leaving the remaining resources in place. */
+static BOOL remove_hook(struct hook_state *hook)
+{
+    DWORD old_prot;
+
+    /* make the patched bytes writable, put the originals back, restore the protection */
+    if (!VirtualProtect(hook->target, hook->count, PAGE_EXECUTE_READWRITE, &old_prot))
+        return FALSE;
+    memcpy(hook->target, hook->original, hook->count);
+    if (!VirtualProtect(hook->target, hook->count, old_prot, &old_prot))
+        return FALSE;
+
+    /* only hook_syscall_ret() registers a handler; only the jmp-based hooks allocate memory */
+    if (hook->exception_handler != NULL && RemoveVectoredExceptionHandler(hook->exception_handler) == 0)
+        return FALSE;
+
+    return hook->mem == NULL || VirtualFree(hook->mem, 0, MEM_RELEASE);
+}
+
+static DWORD WINAPI dummy_thread(void *param) /* thread body: records its process and thread ids */
+{
+    callback_thread_id = GetCurrentThreadId();
+    callback_process_id = GetCurrentProcessId();
+    return 0xdeadbeef;
+}
+
+/* LdrInitializeThunk hook: flags that it ran and, when a context is passed, checks its
+ * Eax/Ebx against the start routine and parameter used by test_LdrInitializeThunk(). */
+static void callback_LdrInitializeThunk(CONTEXT *context, ULONG_PTR unknown1, ULONG_PTR unknown2, ULONG_PTR unknown3)
+{
+    callback_result = 1;
+    if (!context)
+        return; /* context is NULL on Windows XP */
+
+    ok(context->Eax == (ULONG)dummy_thread, "context->Eax expected: %08X  got: %08X\n",
+       (ULONG)dummy_thread, context->Eax);
+    ok(context->Ebx == (ULONG)0x12345678, "context->Ebx expected: 12345678  got: %08X\n", context->Ebx);
+}
+
+/* Checks that a thread created by CreateThread() passes through the hooked LdrInitializeThunk. */
+static void test_LdrInitializeThunk(void)
+{
+    struct hook_state hook;
+    HANDLE thread;
+
+    ok(hook_function(&hook, pLdrInitializeThunk, callback_LdrInitializeThunk), "failed to hook LdrInitializeThunk\n");
+    callback_result = 0;
+
+    thread = CreateThread(NULL, 0, dummy_thread, (void *)0x12345678, 0, NULL);
+    ok(thread != NULL, "CreateThread failed\n");
+    ok(WaitForSingleObject(thread, 1000) == WAIT_OBJECT_0, "wait for thread failed\n");
+    ok(callback_result == 1, "callback never ran\n");
+    CloseHandle(thread);
+    ok(remove_hook(&hook), "failed to remove hook\n");
+}
+
+/* Vectored handler for test_NtContinue(): expects the access violation raised by calling
+ * a NULL function pointer and makes that call return 0x12345678 to its caller. */
+static LONG CALLBACK dummy_exception_handler(EXCEPTION_POINTERS *info)
+{
+    CONTEXT *ctx = info->ContextRecord;
+    DWORD code = info->ExceptionRecord->ExceptionCode;
+    ok(code == EXCEPTION_ACCESS_VIOLATION, "wrong exception expected: %08X got: %08X\n", EXCEPTION_ACCESS_VIOLATION, code);
+    ctx->Eax = 0x12345678;
+    ctx->Eip = *(ULONG*)ctx->Esp; /* pop the return address pushed by the faulting call */
+    ctx->Esp += 4;
+    return EXCEPTION_CONTINUE_EXECUTION;
+}
+
+/* NtContinue hook: checks the context being resumed, then sets its Eax to 0xdeadbeef. */
+static void callback_NtContinue(CONTEXT *context, BOOLEAN alert)
+{
+    const DWORD flags = context->ContextFlags;
+    callback_result = 1;
+    ok((flags & CONTEXT_FULL) == CONTEXT_FULL, "wrong context flags expected: %08x  got: %08x\n", CONTEXT_FULL, flags);
+    ok(context->Eax == 0x12345678, "wrong Eax expected: 0x12345678  got: %08x\n", context->Eax);
+    ok(context->Eip == *(ULONG*)(context->Esp - 4), "wrong Eip expected: %08x  got: %08x\n", *(ULONG*)(context->Esp - 4), context->Eip);
+    context->Eax = 0xdeadbeef; /* test_NtContinue() expects this as the faulting call's result */
+}
+
+/* Raise an access violation that dummy_exception_handler() resolves; resuming should go through the hooked NtContinue. */
+static void test_NtContinue(void)
+{
+    int (*volatile null_func)(void) = NULL; /* volatile: stop the compiler exploiting the UB of a known NULL call */
+    int result;
+    void *handle;
+    struct hook_state hook;
+
+    ok(hook_function(&hook, pNtContinue, callback_NtContinue), "failed to hook NtContinue\n");
+    handle = AddVectoredExceptionHandler(TRUE, dummy_exception_handler);
+    ok(handle != NULL, "failed to register exception handler\n");
+    if (handle) /* without the handler the fault below would kill the test process */
+    {
+        callback_result = 0;
+        result = null_func(); /* raise an exception */
+        todo_wine ok(callback_result == 1, "callback never ran\n");
+        todo_wine ok(result == 0xdeadbeef, "wrong return value expected: deadbeef got: %08x\n", result);
+        RemoveVectoredExceptionHandler(handle);
+    }
+    ok(remove_hook(&hook), "failed to remove hook\n");
+}
+
+/* Entry callback of the NtCreateThread hook: remembers the handle and CLIENT_ID output
+ * pointers (static fallbacks for NULL arguments) and checks that the initial context
+ * carries the start routine and parameter used by test_NtCreateThread(). */
+static void callback_NtCreateThread( HANDLE *handle_ptr, ACCESS_MASK access, OBJECT_ATTRIBUTES *attr, HANDLE process,
+                                     CLIENT_ID *id, CONTEXT *context, INITIAL_TEB *teb, BOOLEAN suspended )
+{
+    static CLIENT_ID dummy_id;
+    static HANDLE dummy_handle;
+
+    callback_result |= 1;
+    callback_handle_ptr = handle_ptr ? handle_ptr : &dummy_handle;
+    callback_client_id_ptr = id ? id : &dummy_id;
+
+    ok(context != NULL, "context is NULL\n");
+    if (!context) return;
+    ok(context->Eax == (ULONG)dummy_thread, "context->Eax expected: %08X  got: %08X\n", (ULONG)dummy_thread, context->Eax);
+    ok(context->Ebx == (ULONG)0x12345678, "context->Ebx expected: 12345678  got: %08X\n", context->Ebx);
+}
+
+static void callback_NtCreateThread_ret(void) /* copies and checks the CLIENT_ID that NtCreateThread filled in */
+{
+    callback_result |= 2;
+    if (!callback_client_id_ptr || !callback_handle_ptr) return; /* entry callback did not run */
+    callback_client_id = *callback_client_id_ptr;
+    ok(HandleToULong(callback_client_id.UniqueThread) == get_thread_id(*callback_handle_ptr),
+        "wrong thread id  %d != %d\n", HandleToULong(callback_client_id.UniqueThread), get_thread_id(*callback_handle_ptr));
+}
+
+/* The test expects CreateThread() to reach NtCreateThread only before Vista (major < 6);
+ * there both the entry and the return of the stub are hooked and the CLIENT_ID checked. */
+static void test_NtCreateThread(void)
+{
+    struct hook_state entry_hook, ret_hook;
+    ULONG major = NtCurrentTeb()->Peb->OSMajorVersion;
+    HANDLE thread;
+
+    /* hook_syscall_ret() finds the ret by parsing the stub from its leading "mov eax",
+     * so it has to be installed before hook_syscall() overwrites that instruction */
+    todo_wine ok(hook_syscall_ret(&ret_hook, pNtCreateThread, callback_NtCreateThread_ret), "failed to hook NtCreateThread ret\n");
+    ok(hook_syscall(&entry_hook, pNtCreateThread, callback_NtCreateThread), "failed to hook NtCreateThread\n");
+
+    callback_result = 0;
+    callback_client_id_ptr = NULL;
+    callback_handle_ptr = NULL;
+    thread = CreateThread(NULL, 0, dummy_thread, (void *)0x12345678, 0, NULL);
+    ok(thread != NULL, "CreateThread failed\n");
+    ok(WaitForSingleObject(thread, 1000) == WAIT_OBJECT_0, "wait for thread failed\n");
+
+    if (major >= 6)
+    {
+        ok(callback_result == 0, "callbacks ran but should not on major version %d\n", major);
+    }
+    else
+    {
+        ok(callback_result & 1, "callback never ran\n");
+        ok(callback_result & 2, "ret callback never ran\n");
+        ok(callback_process_id == HandleToULong(callback_client_id.UniqueProcess), "wrong process id expected: %d  got: %d\n",
+           callback_process_id, HandleToULong(callback_client_id.UniqueProcess));
+        ok(callback_thread_id == HandleToULong(callback_client_id.UniqueThread), "wrong thread id expected: %d  got: %d\n",
+           callback_thread_id, HandleToULong(callback_client_id.UniqueThread));
+    }
+    CloseHandle(thread);
+
+    ok(remove_hook(&entry_hook), "failed to remove hook\n");
+    todo_wine ok(remove_hook(&ret_hook), "failed to remove hook\n");
+}
+
+static void callback_NtGetTickCount(void) /* shared by every tick-count hook: sets bit 0 of callback_result */
+{
+    callback_result = callback_result | 1;
+}
+
+/* Checks that GetTickCount()/GetTickCount64() do not call NtGetTickCount, and that hooking
+ * them neither modifies NtGetTickCount's code nor is triggered by calling it. */
+static void test_NtGetTickCount(void)
+{
+    struct hook_state hook;
+    BYTE snapshot[64];
+
+    if (!pNtGetTickCount)
+    {
+        win_skip("NtGetTickCount is not available.\n");
+        return;
+    }
+
+    /* NtGetTickCount may not be a system call, so hook it like a normal function. */
+    ok(hook_function(&hook, pNtGetTickCount, callback_NtGetTickCount),
+       "failed to hook NtGetTickCount\n");
+
+    /* GetTickCount and GetTickCount64 must not reach NtGetTickCount */
+    callback_result = 0;
+    pGetTickCount();
+    ok(callback_result == 0, "callbacks ran during GetTickCount\n");
+    if (pGetTickCount64)
+    {
+        callback_result = 0;
+        pGetTickCount64();
+        ok(callback_result == 0, "callbacks ran during GetTickCount64\n");
+    }
+
+    /* ...but calling NtGetTickCount directly must hit the hook */
+    callback_result = 0;
+    pNtGetTickCount();
+    ok(callback_result & 1, "callback never ran\n");
+    ok(remove_hook(&hook), "failed to remove hook\n");
+
+    /* with the hook removed, remember the pristine bytes of NtGetTickCount */
+    memcpy(snapshot, pNtGetTickCount, sizeof(snapshot));
+
+    /* hooking GetTickCount must neither patch NtGetTickCount nor be reached from it */
+    ok(hook_function(&hook, pGetTickCount, callback_NtGetTickCount),
+       "failed to hook GetTickCount\n");
+    ok(!memcmp(pNtGetTickCount, snapshot, sizeof(snapshot)), "hook modified NtGetTickCount\n");
+    callback_result = 0;
+    pNtGetTickCount();
+    ok(callback_result == 0, "callbacks ran during NtGetTickCount\n");
+    ok(remove_hook(&hook), "failed to remove hook\n");
+
+    if (!pGetTickCount64)
+        return;
+
+    ok(hook_function(&hook, pGetTickCount64, callback_NtGetTickCount),
+       "failed to hook GetTickCount64\n");
+    ok(!memcmp(pNtGetTickCount, snapshot, sizeof(snapshot)), "hook modified NtGetTickCount\n");
+    callback_result = 0;
+    pNtGetTickCount();
+    ok(callback_result == 0, "callbacks ran during NtGetTickCount\n");
+    ok(remove_hook(&hook), "failed to remove hook\n");
+}
+#endif
+
+START_TEST(hooks)
+{
+#ifdef __i386__ /* the hooks write i386 machine code */
+    HMODULE mod = GetModuleHandleA("ntdll.dll");
+
+    pLdrInitializeThunk       = (void *)GetProcAddress(mod, "LdrInitializeThunk");
+    pNtContinue               = (void *)GetProcAddress(mod, "NtContinue");
+    pNtCreateThread           = (void *)GetProcAddress(mod, "NtCreateThread");
+    pNtGetTickCount           = (void *)GetProcAddress(mod, "NtGetTickCount");
+    pNtQueryInformationThread = (void *)GetProcAddress(mod, "NtQueryInformationThread");
+
+    mod = GetModuleHandleA("kernel32.dll");
+    pGetTickCount   = (void *)GetProcAddress(mod, "GetTickCount");
+    pGetTickCount64 = (void *)GetProcAddress(mod, "GetTickCount64");
+
+    test_LdrInitializeThunk();
+    test_NtContinue();
+    test_NtCreateThread();
+    test_NtGetTickCount();
+#endif
+}
-- 
2.1.4




More information about the wine-devel mailing list