Skip to content

Commit

Permalink
Fix remaining issues with detouring nvcuda
Browse files Browse the repository at this point in the history
  • Loading branch information
vosen committed Dec 5, 2021
1 parent 26bf0ee commit 2c6d7ff
Show file tree
Hide file tree
Showing 4 changed files with 112 additions and 29 deletions.
1 change: 0 additions & 1 deletion zluda_dump/src/os_win.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,6 @@ impl PlatformLibrary {
if module == ptr::null_mut() {
break;
}
let mut size = 0;
let payload = GetProcAddress(module as _, b"ZLUDA_REDIRECT\0".as_ptr() as _);
if payload != ptr::null_mut() {
return Some(module as _);
Expand Down
4 changes: 2 additions & 2 deletions zluda_inject/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ winapi = { version = "0.3", features = ["jobapi2", "processthreadsapi", "synchap
detours-sys = { path = "../detours-sys" }

[dev-dependencies]
# dependency for integration tests
# all of those are used in integration tests
zluda_redirect = { path = "../zluda_redirect" }
# dependency for integration tests
zluda_dump = { path = "../zluda_dump" }
zluda_ml = { path = "../zluda_ml" }
3 changes: 2 additions & 1 deletion zluda_redirect/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,5 @@ crate-type = ["cdylib"]
[target.'cfg(windows)'.dependencies]
detours-sys = { path = "../detours-sys" }
wchar = "0.6"
winapi = { version = "0.3", features = ["processthreadsapi", "winbase", "winnt", "winerror", "libloaderapi", "tlhelp32", "handleapi", "std"] }
winapi = { version = "0.3", features = ["processthreadsapi", "winbase", "winnt", "winerror", "libloaderapi", "tlhelp32", "handleapi", "std"] }
tempfile = "3"
133 changes: 108 additions & 25 deletions zluda_redirect/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,33 +6,18 @@ extern crate winapi;
use std::{
collections::HashMap,
ffi::{c_void, CStr},
mem,
io, mem,
os::raw::c_uint,
ptr, slice, usize,
};

use detours_sys::{
DetourAttach, DetourEnumerateExports, DetourRestoreAfterWith, DetourTransactionAbort,
DetourTransactionBegin, DetourTransactionCommit, DetourUpdateProcessWithDll,
DetourUpdateThread,
DetourAttach, DetourEnumerateExports, DetourGetEntryPoint, DetourRestoreAfterWith,
DetourTransactionAbort, DetourTransactionBegin, DetourTransactionCommit,
DetourUpdateProcessWithDll, DetourUpdateThread,
};
use tempfile::TempDir;
use wchar::wch;
use winapi::{
shared::minwindef::{BOOL, LPVOID},
um::{
handleapi::{CloseHandle, INVALID_HANDLE_VALUE},
minwinbase::LPSECURITY_ATTRIBUTES,
processthreadsapi::{
CreateProcessA, GetCurrentProcessId, GetCurrentThreadId, OpenThread, ResumeThread,
SuspendThread, TerminateProcess, LPPROCESS_INFORMATION, LPSTARTUPINFOA, LPSTARTUPINFOW,
},
tlhelp32::{
CreateToolhelp32Snapshot, Thread32First, Thread32Next, TH32CS_SNAPTHREAD, THREADENTRY32,
},
winbase::CREATE_SUSPENDED,
winnt::{LPSTR, LPWSTR, THREAD_SUSPEND_RESUME},
},
};
use winapi::{
shared::minwindef::{DWORD, FALSE, HMODULE, TRUE},
um::{libloaderapi::LoadLibraryExA, winnt::LPCSTR},
Expand All @@ -50,6 +35,26 @@ use winapi::{
shared::winerror::NO_ERROR,
um::libloaderapi::{LoadLibraryA, LoadLibraryExW, LoadLibraryW},
};
use winapi::{
shared::{
minwindef::{BOOL, LPVOID},
winerror::E_UNEXPECTED,
},
um::{
handleapi::{CloseHandle, INVALID_HANDLE_VALUE},
libloaderapi::GetModuleHandleW,
minwinbase::LPSECURITY_ATTRIBUTES,
processthreadsapi::{
CreateProcessA, GetCurrentProcessId, GetCurrentThreadId, OpenThread, ResumeThread,
SuspendThread, TerminateProcess, LPPROCESS_INFORMATION, LPSTARTUPINFOA, LPSTARTUPINFOW,
},
tlhelp32::{
CreateToolhelp32Snapshot, Thread32First, Thread32Next, TH32CS_SNAPTHREAD, THREADENTRY32,
},
winbase::{CopyFileW, CreateSymbolicLinkW, CREATE_SUSPENDED},
winnt::{LPSTR, LPWSTR, THREAD_SUSPEND_RESUME},
},
};

include!("payload_guid.rs");

Expand Down Expand Up @@ -375,6 +380,59 @@ unsafe extern "system" fn ZludaCreateProcessWithTokenW(
continue_create_process_hook(create_proc_result, dwCreationFlags, lpProcessInformation)
}

static mut MAIN: unsafe extern "system" fn() -> DWORD = ZludaMain;

// https://docs.microsoft.com/en-us/windows/win32/dlls/dynamic-link-library-search-order#search-order-for-desktop-applications
// "If a DLL with the same module name is already loaded in memory, the system
// uses the loaded DLL, no matter which directory it is in. The system does not
// search for the DLL."
#[allow(non_snake_case)]
unsafe extern "system" fn ZludaMain() -> DWORD {
let temp_dir = match do_zluda_preload() {
Ok(f) => f,
Err(e) => return e.raw_os_error().unwrap_or(E_UNEXPECTED) as u32,
};
let result = MAIN();
drop(temp_dir);
result
}

unsafe fn do_zluda_preload() -> std::io::Result<TempDir> {
let temp_dir = tempfile::tempdir()?;
do_single_zluda_preload(&temp_dir, ZLUDA_PATH_UTF16.unwrap().as_ptr(), NVCUDA_UTF8)?;
do_single_zluda_preload(&temp_dir, ZLUDA_ML_PATH_UTF16.unwrap().as_ptr(), NVML_UTF8)?;
Ok(temp_dir)
}

unsafe fn do_single_zluda_preload(
temp_dir: &TempDir,
full_path: *const u16,
file_name: &'static str,
) -> io::Result<()> {
let mut temp_file_path = temp_dir.path().to_path_buf();
temp_file_path.push(file_name);
let mut temp_file_path_utf16 = temp_file_path
.into_os_string()
.to_string_lossy()
.encode_utf16()
.collect::<Vec<_>>();
temp_file_path_utf16.push(0);
// Probably we are not in developer mode, do a copty then
if 0 == CreateSymbolicLinkW(
temp_file_path_utf16.as_ptr(),
full_path,
0x2, //SYMBOLIC_LINK_FLAG_ALLOW_UNPRIVILEGED_CREATE
) {
if 0 == CopyFileW(full_path, temp_file_path_utf16.as_ptr(), 1) {
return Err(io::Error::last_os_error());
}
}
if ptr::null_mut() == ZludaLoadLibraryW_NoRedirect(temp_file_path_utf16.as_ptr()) {
return Err(io::Error::last_os_error());
}
Ok(())
}

// This type encapsulates typical calling sequence of detours and cleanup.
// We have two ways we do detours:
// * If we are loaded before nvcuda.dll, we hook LoadLibrary*
Expand Down Expand Up @@ -668,8 +726,8 @@ unsafe extern "system" fn DllMain(instDLL: HINSTANCE, dwReason: u32, _: *const u
// redirecting LoadLibrary* to load ZLUDA, we override already loaded
// functions
let detach_guard = match get_cuinit() {
Some((nvcuda_mod, _)) => attach_cuinit(nvcuda_mod),
None => attach_load_libary(),
Some((nvcuda_mod, _)) => detour_already_loaded_nvcuda(nvcuda_mod),
None => detour_main(),
};
match detach_guard {
Some(g) => {
Expand Down Expand Up @@ -724,7 +782,7 @@ unsafe fn get_cuinit() -> Option<(HMODULE, FARPROC)> {
}

#[must_use]
unsafe fn attach_cuinit(nvcuda_mod: HMODULE) -> Option<DetourDetachGuard> {
unsafe fn detour_already_loaded_nvcuda(nvcuda_mod: HMODULE) -> Option<DetourDetachGuard> {
let zluda_module = LoadLibraryW(ZLUDA_PATH_UTF16.unwrap().as_ptr());
if zluda_module == ptr::null_mut() {
return None;
Expand All @@ -747,7 +805,22 @@ unsafe fn attach_cuinit(nvcuda_mod: HMODULE) -> Option<DetourDetachGuard> {
(original_fn_address as _, override_fn_address),
);
}
DetourDetachGuard::detour_functions(nvcuda_mod, Vec::new(), override_fn_pairs)
let detour_functions = vec![
(
&mut LOAD_LIBRARY_A as *mut _ as *mut *mut c_void,
ZludaLoadLibraryA as *mut c_void,
),
(&mut LOAD_LIBRARY_W as *mut _ as _, ZludaLoadLibraryW as _),
(
&mut LOAD_LIBRARY_EX_A as *mut _ as _,
ZludaLoadLibraryExA as _,
),
(
&mut LOAD_LIBRARY_EX_W as *mut _ as _,
ZludaLoadLibraryExW as _,
),
];
DetourDetachGuard::detour_functions(nvcuda_mod, detour_functions, override_fn_pairs)
}

unsafe extern "system" fn cuda_unsupported() -> c_uint {
Expand Down Expand Up @@ -776,8 +849,18 @@ unsafe extern "stdcall" fn gather_imports_impl(
}

#[must_use]
unsafe fn attach_load_libary() -> Option<DetourDetachGuard> {
unsafe fn detour_main() -> Option<DetourDetachGuard> {
let exe_handle = GetModuleHandleW(ptr::null());
let entry_point = DetourGetEntryPoint(exe_handle as _);
if entry_point == ptr::null_mut() {
return None;
}
MAIN = mem::transmute(entry_point);
let detour_functions = vec![
(
&mut MAIN as *mut _ as *mut *mut c_void,
ZludaMain as *mut c_void,
),
(
&mut LOAD_LIBRARY_A as *mut _ as *mut *mut c_void,
ZludaLoadLibraryA as *mut c_void,
Expand Down

0 comments on commit 2c6d7ff

Please sign in to comment.