diff --git a/safetyhook/safetyhook.cpp b/safetyhook/safetyhook.cpp index 4f8a9ef..496d1df 100644 --- a/safetyhook/safetyhook.cpp +++ b/safetyhook/safetyhook.cpp @@ -2,7 +2,7 @@ #define NOMINMAX -#include +#include "safetyhook.hpp" // @@ -23,19 +23,8 @@ #endif -namespace safetyhook { -template constexpr T align_up(T address, size_t align) { - const auto unaligned_address = (uintptr_t)address; - const auto aligned_address = (unaligned_address + align - 1) & ~(align - 1); - return (T)aligned_address; -} - -template constexpr T align_down(T address, size_t align) { - const auto unaligned_address = (uintptr_t)address; - const auto aligned_address = unaligned_address & ~(align - 1); - return (T)aligned_address; -} +namespace safetyhook { Allocation::Allocation(Allocation&& other) noexcept { *this = std::move(other); } @@ -338,10 +327,10 @@ VmtHook create_vmt(void* object) { #error "Windows.h not found" #endif -#if __has_include() -#include -#elif __has_include() -#include +#if __has_include("Zydis/Zydis.h") +#include "Zydis/Zydis.h" +#elif __has_include("Zydis.h") +#include "Zydis.h" #else #error "Zydis not found" #endif @@ -356,7 +345,7 @@ struct JmpE9 { uint32_t offset{0}; }; -#if defined(_M_X64) +#if SAFETYHOOK_ARCH_X86_64 struct JmpFF { uint8_t opcode0{0xFF}; uint8_t opcode1{0x25}; @@ -373,7 +362,7 @@ struct TrampolineEpilogueFF { JmpFF jmp_to_original{}; uint64_t original_address{}; }; -#elif defined(_M_IX86) +#elif SAFETYHOOK_ARCH_X86_32 struct TrampolineEpilogueE9 { JmpE9 jmp_to_original{}; JmpE9 jmp_to_destination{}; @@ -381,7 +370,7 @@ struct TrampolineEpilogueE9 { #endif #pragma pack(pop) -#ifdef _M_X64 +#if SAFETYHOOK_ARCH_X86_64 static auto make_jmp_ff(uint8_t* src, uint8_t* dst, uint8_t* data) { JmpFF jmp{}; @@ -397,12 +386,6 @@ static auto make_jmp_ff(uint8_t* src, uint8_t* dst, uint8_t* data) { return std::unexpected{InlineHook::Error::not_enough_space(dst)}; } - auto um = unprotect(src, size); - - if (!um) { - return std::unexpected{InlineHook::Error::failed_to_unprotect(src)}; - } - if (size > sizeof(JmpFF)) { std::fill_n(src, size, static_cast(0x90)); } @@ -427,12 +410,6 @@ constexpr auto make_jmp_e9(uint8_t* src, uint8_t* dst) { return std::unexpected{InlineHook::Error::not_enough_space(dst)}; } - auto um = unprotect(src, size); - - if (!um) { - return std::unexpected{InlineHook::Error::failed_to_unprotect(src)}; - } - if (size > sizeof(JmpE9)) { std::fill_n(src, size, static_cast(0x90)); } @@ -446,12 +423,10 @@ static bool decode(ZydisDecodedInstruction* ix, uint8_t* ip) { ZydisDecoder decoder{}; ZyanStatus status; -#if defined(_M_X64) +#if SAFETYHOOK_ARCH_X86_64 status = ZydisDecoderInit(&decoder, ZYDIS_MACHINE_MODE_LONG_64, ZYDIS_STACK_WIDTH_64); -#elif defined(_M_IX86) +#elif SAFETYHOOK_ARCH_X86_32 status = ZydisDecoderInit(&decoder, ZYDIS_MACHINE_MODE_LEGACY_32, ZYDIS_STACK_WIDTH_32); -#else -#error "Unsupported architecture" #endif if (!ZYAN_SUCCESS(status)) { @@ -516,11 +491,11 @@ std::expected InlineHook::setup( m_destination = destination; if (auto e9_result = e9_hook(allocator); !e9_result) { -#ifdef _M_X64 +#if SAFETYHOOK_ARCH_X86_64 if (auto ff_result = ff_hook(allocator); !ff_result) { return ff_result; } -#else +#elif SAFETYHOOK_ARCH_X86_32 return e9_result; #endif } @@ -640,13 +615,13 @@ std::expected InlineHook::e9_hook(const std::shared_ptr src = reinterpret_cast(&trampoline_epilogue->jmp_to_destination); dst = m_destination; -#ifdef _M_X64 +#if SAFETYHOOK_ARCH_X86_64 auto data = reinterpret_cast(&trampoline_epilogue->destination_address); if (auto result = emit_jmp_ff(src, dst, data); !result) { return std::unexpected{result.error()}; } -#else +#elif SAFETYHOOK_ARCH_X86_32 if (auto result = emit_jmp_e9(src, dst); !result) { return std::unexpected{result.error()}; } @@ -655,19 +630,13 @@ std::expected InlineHook::e9_hook(const std::shared_ptr std::optional error; // jmp from original to trampoline. - execute_while_frozen( - [this, &trampoline_epilogue, &error] { - if (auto result = emit_jmp_e9(m_target, - reinterpret_cast(&trampoline_epilogue->jmp_to_destination), m_original_bytes.size()); - !result) { - error = result.error(); - } - }, - [this](auto, auto, auto ctx) { - for (size_t i = 0; i < m_original_bytes.size(); ++i) { - fix_ip(ctx, m_target + i, m_trampoline.data() + i); - } - }); + trap_threads(m_target, m_trampoline.data(), m_original_bytes.size(), [this, &trampoline_epilogue, &error] { + if (auto result = emit_jmp_e9(m_target, reinterpret_cast(&trampoline_epilogue->jmp_to_destination), + m_original_bytes.size()); + !result) { + error = result.error(); + } + }); if (error) { return std::unexpected{*error}; @@ -676,7 +645,7 @@ std::expected InlineHook::e9_hook(const std::shared_ptr return {}; } -#ifdef _M_X64 +#if SAFETYHOOK_ARCH_X86_64 std::expected InlineHook::ff_hook(const std::shared_ptr& allocator) { m_original_bytes.clear(); m_trampoline_size = sizeof(TrampolineEpilogueFF); @@ -723,18 +692,12 @@ std::expected InlineHook::ff_hook(const std::shared_ptr std::optional error; // jmp from original to trampoline. - execute_while_frozen( - [this, &error] { - if (auto result = emit_jmp_ff(m_target, m_destination, m_target + sizeof(JmpFF), m_original_bytes.size()); - !result) { - error = result.error(); - } - }, - [this](auto, auto, auto ctx) { - for (size_t i = 0; i < m_original_bytes.size(); ++i) { - fix_ip(ctx, m_target + i, m_trampoline.data() + i); - } - }); + trap_threads(m_target, m_trampoline.data(), m_original_bytes.size(), [this, &error] { + if (auto result = emit_jmp_ff(m_target, m_destination, m_target + sizeof(JmpFF), m_original_bytes.size()); + !result) { + error = result.error(); + } + }); if (error) { return std::unexpected{*error}; @@ -751,17 +714,8 @@ void InlineHook::destroy() { return; } - execute_while_frozen( - [this] { - if (auto um = unprotect(m_target, m_original_bytes.size())) { - std::copy(m_original_bytes.begin(), m_original_bytes.end(), m_target); - } - }, - [this](auto, auto, auto ctx) { - for (size_t i = 0; i < m_original_bytes.size(); ++i) { - fix_ip(ctx, m_trampoline.data() + i, m_target + i); - } - }); + trap_threads(m_trampoline.data(), m_target, m_original_bytes.size(), + [this] { std::copy(m_original_bytes.begin(), m_original_bytes.end(), m_target); }); m_trampoline.free(); } @@ -778,7 +732,7 @@ void InlineHook::destroy() { namespace safetyhook { -#ifdef _M_X64 +#if SAFETYHOOK_ARCH_X86_64 constexpr std::array asm_data = {0xFF, 0x35, 0x79, 0x01, 0x00, 0x00, 0x54, 0x54, 0x55, 0x50, 0x53, 0x51, 0x52, 0x56, 0x57, 0x41, 0x50, 0x41, 0x51, 0x41, 0x52, 0x41, 0x53, 0x41, 0x54, 0x41, 0x55, 0x41, 0x56, 0x41, 0x57, 0x9C, 0x48, 0x81, 0xEC, 0x00, 0x01, 0x00, 0x00, 0xF3, 0x44, 0x0F, 0x7F, 0xBC, 0x24, 0xF0, 0x00, 0x00, 0x00, 0xF3, @@ -800,7 +754,7 @@ constexpr std::array asm_data = {0xFF, 0x35, 0x79, 0x01, 0x00, 0x0 0x00, 0x00, 0x48, 0x81, 0xC4, 0x00, 0x01, 0x00, 0x00, 0x9D, 0x41, 0x5F, 0x41, 0x5E, 0x41, 0x5D, 0x41, 0x5C, 0x41, 0x5B, 0x41, 0x5A, 0x41, 0x59, 0x41, 0x58, 0x5F, 0x5E, 0x5A, 0x59, 0x5B, 0x58, 0x5D, 0x48, 0x8D, 0x64, 0x24, 0x08, 0x5C, 0xC3, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}; -#else +#elif SAFETYHOOK_ARCH_X86_32 constexpr std::array asm_data = {0xFF, 0x35, 0xA7, 0x00, 0x00, 0x00, 0x54, 0x54, 0x55, 0x50, 0x53, 0x51, 0x52, 0x56, 0x57, 0x9C, 0x81, 0xEC, 0x80, 0x00, 0x00, 0x00, 0xF3, 0x0F, 0x7F, 0x7C, 0x24, 0x70, 0xF3, 0x0F, 0x7F, 0x74, 0x24, 0x60, 0xF3, 0x0F, 0x7F, 0x6C, 0x24, 0x50, 0xF3, 0x0F, 0x7F, 0x64, 0x24, 0x40, 0xF3, 0x0F, 0x7F, 0x5C, @@ -866,9 +820,9 @@ std::expected MidHook::setup( std::copy(asm_data.begin(), asm_data.end(), m_stub.data()); -#ifdef _M_X64 +#if SAFETYHOOK_ARCH_X86_64 store(m_stub.data() + sizeof(asm_data) - 16, m_destination); -#else +#elif SAFETYHOOK_ARCH_X86_32 store(m_stub.data() + sizeof(asm_data) - 8, m_destination); // 32-bit has some relocations we need to fix up as well. @@ -885,9 +839,9 @@ std::expected MidHook::setup( m_hook = std::move(*hook_result); -#ifdef _M_X64 +#if SAFETYHOOK_ARCH_X86_64 store(m_stub.data() + sizeof(asm_data) - 8, m_hook.trampoline().data()); -#else +#elif SAFETYHOOK_ARCH_X86_32 store(m_stub.data() + sizeof(asm_data) - 4, m_hook.trampoline().data()); #endif @@ -899,6 +853,9 @@ std::expected MidHook::setup( // Source file: thread_freezer.cpp // +#include +#include + #if __has_include() #include #elif __has_include() @@ -906,123 +863,143 @@ std::expected MidHook::setup( #else #error "Windows.h not found" #endif -#include -#pragma comment(lib, "ntdll") - -extern "C" { -NTSTATUS -NTAPI -NtGetNextThread(HANDLE ProcessHandle, HANDLE ThreadHandle, ACCESS_MASK DesiredAccess, ULONG HandleAttributes, - ULONG Flags, PHANDLE NewThreadHandle); -} namespace safetyhook { -void execute_while_frozen( - const std::function& run_fn, const std::function& visit_fn) { - // Freeze all threads. - int num_threads_frozen; - auto first_run = true; - - do { - num_threads_frozen = 0; - HANDLE thread{}; - - while (true) { - HANDLE next_thread{}; - const auto status = NtGetNextThread(GetCurrentProcess(), thread, - THREAD_QUERY_LIMITED_INFORMATION | THREAD_SUSPEND_RESUME | THREAD_GET_CONTEXT | THREAD_SET_CONTEXT, 0, - 0, &next_thread); - - if (thread != nullptr) { - CloseHandle(thread); - } +struct TrapInfo { + uint8_t* page_start; + uint8_t* page_end; + uint8_t* from; + uint8_t* to; + size_t len; +}; - if (!NT_SUCCESS(status)) { - break; - } +class TrapManager { +public: + static std::mutex mutex; + static std::unique_ptr instance; - thread = next_thread; + TrapManager() { m_trap_veh = AddVectoredExceptionHandler(1, trap_handler); } + ~TrapManager() { + if (m_trap_veh != nullptr) { + RemoveVectoredExceptionHandler(m_trap_veh); + } + } - const auto thread_id = GetThreadId(thread); + TrapInfo* find_trap(uint8_t* address) { + auto search = std::find_if(m_traps.begin(), m_traps.end(), [address](auto& trap) { + return address >= trap.second.from && address < trap.second.from + trap.second.len; + }); - if (thread_id == 0 || thread_id == GetCurrentThreadId()) { - continue; - } + if (search == m_traps.end()) { + return nullptr; + } - const auto suspend_count = SuspendThread(thread); + return &search->second; + } - if (suspend_count == static_cast(-1)) { - continue; - } + TrapInfo* find_trap_page(uint8_t* address) { + auto search = std::find_if(m_traps.begin(), m_traps.end(), + [address](auto& trap) { return address >= trap.second.page_start && address < trap.second.page_end; }); - // Check if the thread was already frozen. Only resume if the thread was already frozen, and it wasn't the - // first run of this freeze loop to account for threads that may have already been frozen for other reasons. - if (suspend_count != 0 && !first_run) { - ResumeThread(thread); - continue; - } + if (search == m_traps.end()) { + return nullptr; + } - CONTEXT thread_ctx{}; + return &search->second; + } - thread_ctx.ContextFlags = CONTEXT_FULL; + void add_trap(uint8_t* from, uint8_t* to, size_t len) { + m_traps.insert_or_assign(from, TrapInfo{.page_start = align_down(from, 0x1000), + .page_end = align_up(from + len, 0x1000), + .from = from, + .to = to, + .len = len}); + } - if (GetThreadContext(thread, &thread_ctx) == FALSE) { - continue; - } +private: + std::map m_traps; + PVOID m_trap_veh{}; - if (visit_fn) { - visit_fn(static_cast(thread_id), static_cast(thread), - static_cast(&thread_ctx)); - } + static LONG CALLBACK trap_handler(PEXCEPTION_POINTERS exp) { + auto exception_code = exp->ExceptionRecord->ExceptionCode; - ++num_threads_frozen; + if (exception_code != EXCEPTION_ACCESS_VIOLATION) { + return EXCEPTION_CONTINUE_SEARCH; } - first_run = false; - } while (num_threads_frozen != 0); + std::scoped_lock lock{mutex}; + auto* faulting_address = reinterpret_cast(exp->ExceptionRecord->ExceptionInformation[1]); + auto* trap = instance->find_trap(faulting_address); - // Run the function. - if (run_fn) { - run_fn(); + if (trap == nullptr) { + if (instance->find_trap_page(faulting_address) != nullptr) { + return EXCEPTION_CONTINUE_EXECUTION; + } else { + return EXCEPTION_CONTINUE_SEARCH; + } + } + + auto* ctx = exp->ContextRecord; + + for (size_t i = 0; i < trap->len; i++) { + fix_ip(ctx, trap->from + i, trap->to + i); + } + + return EXCEPTION_CONTINUE_EXECUTION; } +}; - // Resume all threads. - HANDLE thread{}; +std::mutex TrapManager::mutex; +std::unique_ptr TrapManager::instance; - while (true) { - HANDLE next_thread{}; - const auto status = NtGetNextThread(GetCurrentProcess(), thread, - THREAD_QUERY_LIMITED_INFORMATION | THREAD_SUSPEND_RESUME | THREAD_GET_CONTEXT | THREAD_SET_CONTEXT, 0, 0, - &next_thread); +void find_me() { +} - if (thread != nullptr) { - CloseHandle(thread); - } +void trap_threads(uint8_t* from, uint8_t* to, size_t len, const std::function& run_fn) { + MEMORY_BASIC_INFORMATION find_me_mbi{}; + MEMORY_BASIC_INFORMATION from_mbi{}; + MEMORY_BASIC_INFORMATION to_mbi{}; - if (!NT_SUCCESS(status)) { - break; - } + VirtualQuery(reinterpret_cast(find_me), &find_me_mbi, sizeof(find_me_mbi)); + VirtualQuery(from, &from_mbi, sizeof(from_mbi)); + VirtualQuery(to, &to_mbi, sizeof(to_mbi)); - thread = next_thread; + auto new_protect = PAGE_READWRITE; - const auto thread_id = GetThreadId(thread); + if (from_mbi.AllocationBase == find_me_mbi.AllocationBase || to_mbi.AllocationBase == find_me_mbi.AllocationBase) { + new_protect = PAGE_EXECUTE_READWRITE; + } - if (thread_id == 0 || thread_id == GetCurrentThreadId()) { - continue; - } + std::scoped_lock lock{TrapManager::mutex}; - ResumeThread(thread); + if (TrapManager::instance == nullptr) { + TrapManager::instance = std::make_unique(); + } + + TrapManager::instance->add_trap(from, to, len); + + DWORD from_protect; + DWORD to_protect; + + VirtualProtect(from, len, new_protect, &from_protect); + VirtualProtect(to, len, new_protect, &to_protect); + + if (run_fn) { + run_fn(); } + + VirtualProtect(to, len, to_protect, &to_protect); + VirtualProtect(from, len, from_protect, &from_protect); } void fix_ip(ThreadContext thread_ctx, uint8_t* old_ip, uint8_t* new_ip) { auto* ctx = reinterpret_cast(thread_ctx); -#ifdef _M_X64 +#if SAFETYHOOK_ARCH_X86_64 auto ip = ctx->Rip; -#else +#elif SAFETYHOOK_ARCH_X86_32 auto ip = ctx->Eip; #endif @@ -1030,9 +1007,9 @@ void fix_ip(ThreadContext thread_ctx, uint8_t* old_ip, uint8_t* new_ip) { ip = reinterpret_cast(new_ip); } -#ifdef _M_X64 +#if SAFETYHOOK_ARCH_X86_64 ctx->Rip = ip; -#else +#elif SAFETYHOOK_ARCH_X86_32 ctx->Eip = ip; #endif } @@ -1249,17 +1226,17 @@ void VmtHook::remove(void* object) { const auto original_vmt = search->second; - execute_while_frozen([&] { - if (IsBadWritePtr(object, sizeof(void*))) { - return; - } + if (IsBadWritePtr(object, sizeof(void*))) { + m_objects.erase(search); + return; + } - if (*reinterpret_cast(object) != &m_new_vmt[1]) { - return; - } + if (*reinterpret_cast(object) != &m_new_vmt[1]) { + m_objects.erase(search); + return; + } - *reinterpret_cast(object) = original_vmt; - }); + *reinterpret_cast(object) = original_vmt; m_objects.erase(search); } @@ -1269,19 +1246,17 @@ void VmtHook::reset() { } void VmtHook::destroy() { - execute_while_frozen([this] { - for (const auto [object, original_vmt] : m_objects) { - if (IsBadWritePtr(object, sizeof(void*))) { - return; - } - - if (*reinterpret_cast(object) != &m_new_vmt[1]) { - return; - } + for (const auto [object, original_vmt] : m_objects) { + if (IsBadWritePtr(object, sizeof(void*))) { + continue; + } - *reinterpret_cast(object) = original_vmt; + if (*reinterpret_cast(object) != &m_new_vmt[1]) { + continue; } - }); + + *reinterpret_cast(object) = original_vmt; + } m_objects.clear(); m_new_vmt_allocation.reset(); diff --git a/safetyhook/safetyhook.hpp b/safetyhook/safetyhook.hpp index ea37370..9a33880 100644 --- a/safetyhook/safetyhook.hpp +++ b/safetyhook/safetyhook.hpp @@ -179,6 +179,55 @@ class Allocator final : public std::enable_shared_from_this { }; } // namespace safetyhook +// +// Header: safetyhook/common.hpp +// +// Include stack: +// - safetyhook.hpp +// - safetyhook/easy.hpp +// - safetyhook/inline_hook.hpp +// + +#pragma once + +#if defined(_MSC_VER) +#define SAFETYHOOK_COMPILER_MSVC 1 +#define SAFETYHOOK_COMPILER_GCC 0 +#define SAFETYHOOK_COMPILER_CLANG 0 +#elif defined(__GNUC__) +#define SAFETYHOOK_COMPILER_MSVC 0 +#define SAFETYHOOK_COMPILER_GCC 1 +#define SAFETYHOOK_COMPILER_CLANG 0 +#elif defined(__clang__) +#define SAFETYHOOK_COMPILER_MSVC 0 +#define SAFETYHOOK_COMPILER_GCC 0 +#define SAFETYHOOK_COMPILER_CLANG 1 +#else +#error "Unsupported compiler" +#endif + +#if SAFETYHOOK_COMPILER_MSVC +#if defined(_M_IX86) +#define SAFETYHOOK_ARCH_X86_32 1 +#define SAFETYHOOK_ARCH_X86_64 0 +#elif defined(_M_X64) +#define SAFETYHOOK_ARCH_X86_32 0 +#define SAFETYHOOK_ARCH_X86_64 1 +#else +#error "Unsupported architecture" +#endif +#elif SAFETYHOOK_COMPILER_GCC || SAFETYHOOK_COMPILER_CLANG +#if defined(__i386__) +#define SAFETYHOOK_ARCH_X86_32 1 +#define SAFETYHOOK_ARCH_X86_64 0 +#elif defined(__x86_64__) +#define SAFETYHOOK_ARCH_X86_32 0 +#define SAFETYHOOK_ARCH_X86_64 1 +#else +#error "Unsupported architecture" +#endif +#endif + // // Header: safetyhook/utility.hpp // @@ -225,6 +274,19 @@ class UnprotectMemory { }; [[nodiscard]] std::optional unprotect(uint8_t* address, size_t size); + +template constexpr T align_up(T address, size_t align) { + const auto unaligned_address = (uintptr_t)address; + const auto aligned_address = (unaligned_address + align - 1) & ~(align - 1); + return (T)aligned_address; +} + +template constexpr T align_down(T address, size_t align) { + const auto unaligned_address = (uintptr_t)address; + const auto aligned_address = unaligned_address & ~(align - 1); + return (T)aligned_address; +} + } // namespace safetyhook namespace safetyhook { @@ -508,7 +570,7 @@ class InlineHook final { const std::shared_ptr& allocator, uint8_t* target, uint8_t* destination); std::expected e9_hook(const std::shared_ptr& allocator); -#ifdef _M_X64 +#if SAFETYHOOK_ARCH_X86_64 std::expected ff_hook(const std::shared_ptr& allocator); #endif @@ -549,6 +611,7 @@ class InlineHook final { #include + namespace safetyhook { union Xmm { uint8_t u8[16]; @@ -586,9 +649,9 @@ struct Context32 { /// to the registers at the moment the hook is called. /// @note The structure is different depending on architecture. /// @note The structure only provides access to integer registers. -#ifdef _M_X64 +#if SAFETYHOOK_ARCH_X86_64 using Context = Context64; -#else +#elif SAFETYHOOK_ARCH_X86_32 using Context = Context32; #endif @@ -951,15 +1014,7 @@ using ThreadId = uint32_t; using ThreadHandle = void*; using ThreadContext = void*; -/// @brief Executes a function while all other threads are frozen. Also allows for visiting each frozen thread and -/// modifying it's context. -/// @param run_fn The function to run while all other threads are frozen. -/// @param visit_fn The function that will be called for each frozen thread. -/// @note The visit function will be called in the order that the threads were frozen. -/// @note The visit function will be called before the run function. -/// @note Keep the logic inside run_fn and visit_fn as simple as possible to avoid deadlocks. -void execute_while_frozen(const std::function& run_fn, - const std::function& visit_fn = {}); +void trap_threads(uint8_t* from, uint8_t* to, size_t len, const std::function& run_fn); /// @brief Will modify the context of a thread's IP to point to a new address if its IP is at the old address. /// @param ctx The thread context to modify.