[PATCH 2/2] ntdll/tests: Add tests for UMIP instructions

Brendan Shanks bshanks at codeweavers.com
Tue Nov 5 19:17:39 CST 2019


Signed-off-by: Brendan Shanks <bshanks at codeweavers.com>
---
 dlls/ntdll/tests/Makefile.in |   1 +
 dlls/ntdll/tests/umip.c      | 301 +++++++++++++++++++++++++++++++++++
 2 files changed, 302 insertions(+)
 create mode 100644 dlls/ntdll/tests/umip.c

diff --git a/dlls/ntdll/tests/Makefile.in b/dlls/ntdll/tests/Makefile.in
index ed15c51339..e866c54149 100644
--- a/dlls/ntdll/tests/Makefile.in
+++ b/dlls/ntdll/tests/Makefile.in
@@ -23,4 +23,5 @@ C_SRCS = \
 	string.c \
 	threadpool.c \
 	time.c \
+	umip.c \
 	virtual.c
diff --git a/dlls/ntdll/tests/umip.c b/dlls/ntdll/tests/umip.c
new file mode 100644
index 0000000000..f39149dd8d
--- /dev/null
+++ b/dlls/ntdll/tests/umip.c
@@ -0,0 +1,301 @@
+/*
+ * Unit test suite for x86 instructions protected by UMIP.
+ *
+ * Copyright (C) 2019 Brendan Shanks for CodeWeavers
+ *
+ * This library is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation; either
+ * version 2.1 of the License, or (at your option) any later version.
+ *
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this library; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301, USA
+ */
+
+#include <stdio.h>
+#include <stdint.h>
+
+#include "ntstatus.h"
+#define WIN32_NO_STATUS
+#include "windef.h"
+#include "winternl.h"
+#include "wine/test.h"
+
+#if defined(__x86_64__) || defined(__i386__)
+
+/* ntdll vectored-exception entry points, resolved with GetProcAddress in START_TEST. */
+static PVOID (WINAPI *pRtlAddVectoredExceptionHandler)(ULONG first, PVECTORED_EXCEPTION_HANDLER func);
+static ULONG (WINAPI *pRtlRemoveVectoredExceptionHandler)(PVOID handler);
+
+/* Read/write/execute scratch buffer allocated in START_TEST; each test copies
+ * a small machine-code routine here and calls it. */
+static void *code_mem;
+
+/* sldt, str, smsw all store a 16-bit value to a register or memory.
+ * For sldt/str with a register destination the value is zero-extended in
+ * 64-bit mode; smsw is not checked, since its upper register bits are not
+ * defined to be zero. When the destination is memory, only the 16-bit value
+ * is stored.
+ *
+ * Each *_code buffer is a complete routine ending in 'ret' that is copied
+ * into code_mem and called. The mem16 routine takes the destination pointer
+ * as its first argument (rcx on x86_64, [esp+4] on i386). reg64_code is only
+ * executed on x86_64: the 0x48 REX.W prefix decodes as 'dec eax' in 32-bit mode.
+ */
+static void test_reg_mem_16(const char *insn_name,
+                            const BYTE *reg16_code, UINT reg16_code_size,
+                            const BYTE *reg32_code, UINT reg32_code_size,
+                            const BYTE *reg64_code, UINT reg64_code_size,
+                            const BYTE *mem16_code, UINT mem16_code_size)
+{
+    UINT16 value16;
+    UINT32 value32;
+
+    /* CDECL is cdecl on i386 and ms_abi on x86_64, which is what the
+     * hand-written code assumes ([esp+4] / rcx holds the first argument). */
+    UINT16 (CDECL *reg16_func)(void) = code_mem;
+    UINT32 (CDECL *reg32_func)(void) = code_mem;
+    void (CDECL *mem16_func)(UINT16 *value) = code_mem;
+
+    /* Destination is 16-bit register */
+    memcpy(code_mem, reg16_code, reg16_code_size);
+    value16 = reg16_func();
+    trace("%s reg16: 0x%x\n", insn_name, value16);
+
+    /* Destination is 32-bit register. */
+    memcpy(code_mem, reg32_code, reg32_code_size);
+    value32 = reg32_func();
+    trace("%s reg32: 0x%x\n", insn_name, value32);
+#if defined(__x86_64__)
+    /* For sldt/str in 64-bit mode, the upper 16 bits is defined to be zero */
+    if (!strcmp(insn_name, "sldt") || !strcmp(insn_name, "str"))
+        ok(value32 >> 16 == 0, "%s expected upper 16 bits = 0, got 0x%x\n", insn_name, value32 >> 16);
+#endif
+
+#if defined(__x86_64__)
+    /* Destination is 64-bit register */
+    {
+        UINT64 value64;
+        UINT64 (CDECL *reg64_func)(void) = code_mem;
+
+        memcpy(code_mem, reg64_code, reg64_code_size);
+        value64 = reg64_func();
+        trace("%s reg64: 0x%llx\n", insn_name, value64);
+
+        /* For sldt/str the upper 48 bits is defined to be zero */
+        if (!strcmp(insn_name, "sldt") || !strcmp(insn_name, "str"))
+            ok(value64 >> 16 == 0, "%s expected upper 48 bits = 0, got 0x%llx\n", insn_name, value64 >> 16);
+    }
+#endif
+
+    /* Destination is memory (only the low 16 bits are defined to be written).
+     * Pre-fill with 0xcc so an over-wide store is detectable in the upper half. */
+    memset(&value32, 0xcc, sizeof(value32));
+    memcpy(code_mem, mem16_code, mem16_code_size);
+    mem16_func((UINT16 *)&value32);
+    trace("%s mem: 0x%x\n", insn_name, value32);
+    ok(value32 >> 16 == 0xcccc, "%s expected upper 16 bits = 0xcccc, got 0x%x\n", insn_name, value32 >> 16);
+}
+
+/* sgdt and sidt write the descriptor table register to memory.
+ * The descriptor consists of a 2-byte limit field, and a base field which
+ * is 4 bytes in 32-bit mode and 8 bytes in 64-bit mode, so a 10-byte
+ * buffer covers both. The routine in 'code' receives the buffer pointer
+ * as its first argument (rcx on x86_64, [esp+4] on i386).
+ */
+static void test_mem_descriptor(const char *insn_name, const BYTE *code, UINT code_size)
+{
+    BYTE descriptor[10];
+    UINT16 limit;
+    UINT_PTR base;
+    /* CDECL: cdecl on i386, ms_abi on x86_64, matching the generated code. */
+    void (CDECL *func)(BYTE *value) = code_mem;
+
+    memset(descriptor, 0xcc, sizeof(descriptor));
+    memcpy(code_mem, code, code_size);
+    func(descriptor);
+
+    /* The base field sits at offset 2 and is therefore misaligned; copy the
+     * fields out rather than dereferencing type-punned pointers. */
+    memcpy(&limit, &descriptor[0], sizeof(limit));
+    memcpy(&base, &descriptor[2], sizeof(base));
+    trace("%s limit: 0x%x\n", insn_name, limit);
+    trace("%s base: 0x%p\n", insn_name, (void *)base);
+}
+
+/* Run sldt with 16/32/64-bit register destinations and a 16-bit memory
+ * destination. The i386 memory routine loads its pointer argument from
+ * [esp+4] (just above the return address) into ecx. */
+static void test_sldt(void)
+{
+    static const BYTE reg16[] = {
+        0x66, 0x0f, 0x00, 0xc0,         /* sldt ax */
+        0xc3,                           /* ret */
+    };
+    static const BYTE reg32[] = {
+        0x0f, 0x00, 0xc0,               /* sldt eax */
+        0xc3,                           /* ret */
+    };
+    /* Only executed on x86_64 (see test_reg_mem_16). */
+    static const BYTE reg64[] = {
+        0x48, 0x0f, 0x00, 0xc0,         /* sldt rax */
+        0xc3,                           /* ret */
+    };
+#if defined(__x86_64__)
+    static const BYTE mem16[] = {
+        0x0f, 0x00, 0x01,               /* sldt word [rcx] */
+        0xc3,                           /* ret */
+    };
+#else
+    static const BYTE mem16[] = {
+        0x8b, 0x4c, 0x24, 0x04,         /* mov ecx, dword [esp+4] */
+        0x0f, 0x00, 0x01,               /* sldt word [ecx] */
+        0xc3,                           /* ret */
+    };
+#endif
+
+    test_reg_mem_16("sldt", reg16, sizeof(reg16), reg32, sizeof(reg32), reg64, sizeof(reg64), mem16, sizeof(mem16));
+}
+
+/* Run str with 16/32/64-bit register destinations and a 16-bit memory
+ * destination. The i386 memory routine loads its pointer argument from
+ * [esp+4] (just above the return address) into ecx. */
+static void test_str(void)
+{
+    static const BYTE reg16[] = {
+        0x66, 0x0f, 0x00, 0xc8,         /* str ax */
+        0xc3,                           /* ret */
+    };
+    static const BYTE reg32[] = {
+        0x0f, 0x00, 0xc8,               /* str eax */
+        0xc3,                           /* ret */
+    };
+    /* Only executed on x86_64 (see test_reg_mem_16). */
+    static const BYTE reg64[] = {
+        0x48, 0x0f, 0x00, 0xc8,         /* str rax */
+        0xc3,                           /* ret */
+    };
+#if defined(__x86_64__)
+    static const BYTE mem16[] = {
+        0x0f, 0x00, 0x09,               /* str word [rcx] */
+        0xc3,                           /* ret */
+    };
+#else
+    static const BYTE mem16[] = {
+        0x8b, 0x4c, 0x24, 0x04,         /* mov ecx, dword [esp+4] */
+        0x0f, 0x00, 0x09,               /* str word [ecx] */
+        0xc3,                           /* ret */
+    };
+#endif
+
+    test_reg_mem_16("str", reg16, sizeof(reg16), reg32, sizeof(reg32), reg64, sizeof(reg64), mem16, sizeof(mem16));
+}
+
+/* Run sgdt, storing the GDTR into the buffer passed as the routine's first
+ * argument (rcx on x86_64, loaded from [esp+4] into ecx on i386). */
+static void test_sgdt(void)
+{
+    /* sgdt destination must be memory */
+#if defined(__x86_64__)
+    static const BYTE mem_code[] = {
+        0x0f, 0x01, 0x01,               /* sgdt [rcx] */
+        0xc3,                           /* ret */
+    };
+#else
+    static const BYTE mem_code[] = {
+        0x8b, 0x4c, 0x24, 0x04,         /* mov ecx, dword [esp+4] */
+        0x0f, 0x01, 0x01,               /* sgdt [ecx] */
+        0xc3,                           /* ret */
+    };
+#endif
+
+    test_mem_descriptor("sgdt", mem_code, sizeof(mem_code));
+}
+
+/* Run sidt, storing the IDTR into the buffer passed as the routine's first
+ * argument (rcx on x86_64, loaded from [esp+4] into ecx on i386). */
+static void test_sidt(void)
+{
+    /* sidt destination must be memory */
+#if defined(__x86_64__)
+    static const BYTE mem_code[] = {
+        0x0f, 0x01, 0x09,               /* sidt [rcx] */
+        0xc3,                           /* ret */
+    };
+#else
+    static const BYTE mem_code[] = {
+        0x8b, 0x4c, 0x24, 0x04,         /* mov ecx, dword [esp+4] */
+        0x0f, 0x01, 0x09,               /* sidt [ecx] */
+        0xc3,                           /* ret */
+    };
+#endif
+
+    test_mem_descriptor("sidt", mem_code, sizeof(mem_code));
+}
+
+/* Run smsw with 16/32/64-bit register destinations and a 16-bit memory
+ * destination. The i386 memory routine loads its pointer argument from
+ * [esp+4] (just above the return address) into ecx. */
+static void test_smsw(void)
+{
+    static const BYTE reg16[] = {
+        0x66, 0x0f, 0x01, 0xe0,         /* smsw ax */
+        0xc3,                           /* ret */
+    };
+    static const BYTE reg32[] = {
+        0x0f, 0x01, 0xe0,               /* smsw eax */
+        0xc3,                           /* ret */
+    };
+    /* Only executed on x86_64 (see test_reg_mem_16). */
+    static const BYTE reg64[] = {
+        0x48, 0x0f, 0x01, 0xe0,         /* smsw rax */
+        0xc3,                           /* ret */
+    };
+
+#if defined(__x86_64__)
+    static const BYTE mem16[] = {
+        0x0f, 0x01, 0x21,               /* smsw word [rcx] */
+        0xc3,                           /* ret */
+    };
+#else
+    static const BYTE mem16[] = {
+        0x8b, 0x4c, 0x24, 0x04,         /* mov ecx, dword [esp+4] */
+        0x0f, 0x01, 0x21,               /* smsw word [ecx] */
+        0xc3,                           /* ret */
+    };
+#endif
+
+    test_reg_mem_16("smsw", reg16, sizeof(reg16), reg32, sizeof(reg32), reg64, sizeof(reg64), mem16, sizeof(mem16));
+}
+
+/* Vectored handler active while the UMIP instructions run. Every exception
+ * that reaches it is treated as fatal. It traces the exception, records a
+ * test failure when the exception matches the access-violation signature the
+ * test associates with an unemulated UMIP instruction
+ * (ExceptionInformation[0] == 0 and ExceptionInformation[1] == UINTPTR_MAX),
+ * and then terminates the process with exit code 1.
+ */
+static LONG CALLBACK umip_vectored_handler(EXCEPTION_POINTERS *ExceptionInfo)
+{
+    const EXCEPTION_RECORD *record = ExceptionInfo->ExceptionRecord;
+    BOOL looks_like_umip_fault;
+
+    trace("vectored exception handler %08x addr:%p\n", record->ExceptionCode, record->ExceptionAddress);
+
+    looks_like_umip_fault = record->ExceptionCode == EXCEPTION_ACCESS_VIOLATION
+                            && record->ExceptionInformation[0] == 0
+                            && record->ExceptionInformation[1] == UINTPTR_MAX;
+    ok(!looks_like_umip_fault,
+       "vectored_handler caught fault for unemulated UMIP instruction, exiting\n");
+
+    /* Execution cannot be resumed past the faulting instruction, so stop here. */
+    ExitProcess(1);
+
+    return EXCEPTION_CONTINUE_SEARCH;
+}
+#endif      /* __x86_64__ || __i386__ */
+
+START_TEST(umip)
+{
+    /* Test that sldt, str, sgdt, sidt, and smsw can be executed with
+     * all possible operand types (registers/memory of different widths).
+     *
+     * We mostly cannot test/predict the returned values, but on a UMIP-enabled
+     * system without emulation the instructions will fault with an access
+     * violation. A vectored exception handler, registered with first == FALSE
+     * (called after any previously registered vectored handlers), catches the
+     * exception, fails the test, and exits the process if that happens.
+     */
+
+#if defined(__x86_64__) || defined(__i386__)
+    PVOID vectored_handler;
+    HMODULE hntdll = GetModuleHandleA("ntdll.dll");
+
+    pRtlAddVectoredExceptionHandler = (void *)GetProcAddress(hntdll, "RtlAddVectoredExceptionHandler");
+    pRtlRemoveVectoredExceptionHandler = (void *)GetProcAddress(hntdll, "RtlRemoveVectoredExceptionHandler");
+
+    if (!pRtlAddVectoredExceptionHandler || !pRtlRemoveVectoredExceptionHandler) {
+        win_skip("RtlAddVectoredExceptionHandler or RtlRemoveVectoredExceptionHandler not found\n");
+        return;
+    }
+
+    /* Allocate the code buffer before installing the handler, so that no
+     * early-exit path leaves the handler registered. */
+    code_mem = VirtualAlloc(NULL, 65536, MEM_RESERVE | MEM_COMMIT, PAGE_EXECUTE_READWRITE);
+    if (!code_mem) {
+        skip("VirtualAlloc failed\n");
+        return;
+    }
+
+    vectored_handler = pRtlAddVectoredExceptionHandler(FALSE, &umip_vectored_handler);
+    if (!vectored_handler) {
+        skip("RtlAddVectoredExceptionHandler failed\n");
+        VirtualFree(code_mem, 0, MEM_RELEASE);
+        code_mem = NULL;
+        return;
+    }
+
+    test_sldt();
+    test_str();
+    test_sgdt();
+    test_sidt();
+    test_smsw();
+
+    pRtlRemoveVectoredExceptionHandler(vectored_handler);
+    VirtualFree(code_mem, 0, MEM_RELEASE);
+    code_mem = NULL;
+#endif
+}
-- 
2.23.0




More information about the wine-devel mailing list