Clean up ZLUDA redirection helper

This commit is contained in:
Andrzej Janik 2022-02-04 14:14:51 +01:00
parent 2753d956df
commit 164c172236

View file

@ -3,11 +3,7 @@
extern crate detours_sys; extern crate detours_sys;
extern crate winapi; extern crate winapi;
use std::{ use std::{ffi::c_void, mem, ptr, slice, usize};
collections::HashMap,
ffi::{c_void, CStr},
mem, ptr, slice, usize,
};
use detours_sys::{ use detours_sys::{
DetourAttach, DetourRestoreAfterWith, DetourTransactionAbort, DetourTransactionBegin, DetourAttach, DetourRestoreAfterWith, DetourTransactionAbort, DetourTransactionBegin,
@ -18,6 +14,7 @@ use winapi::{
shared::minwindef::{BOOL, LPVOID}, shared::minwindef::{BOOL, LPVOID},
um::{ um::{
handleapi::{CloseHandle, INVALID_HANDLE_VALUE}, handleapi::{CloseHandle, INVALID_HANDLE_VALUE},
libloaderapi::GetModuleFileNameW,
minwinbase::LPSECURITY_ATTRIBUTES, minwinbase::LPSECURITY_ATTRIBUTES,
processthreadsapi::{ processthreadsapi::{
CreateProcessA, GetCurrentProcessId, GetCurrentThreadId, OpenThread, ResumeThread, CreateProcessA, GetCurrentProcessId, GetCurrentThreadId, OpenThread, ResumeThread,
@ -32,15 +29,12 @@ use winapi::{
}; };
use winapi::{ use winapi::{
shared::minwindef::{DWORD, FALSE, HMODULE, TRUE}, shared::minwindef::{DWORD, FALSE, HMODULE, TRUE},
um::{ um::{libloaderapi::LoadLibraryExA, winnt::LPCSTR},
libloaderapi::{GetModuleHandleA, LoadLibraryExA},
winnt::LPCSTR,
},
}; };
use winapi::{ use winapi::{
shared::minwindef::{FARPROC, HINSTANCE}, shared::minwindef::{FARPROC, HINSTANCE},
um::{ um::{
libloaderapi::{GetModuleFileNameA, GetProcAddress}, libloaderapi::GetProcAddress,
processthreadsapi::{CreateProcessAsUserW, CreateProcessW}, processthreadsapi::{CreateProcessAsUserW, CreateProcessW},
winbase::{CreateProcessWithLogonW, CreateProcessWithTokenW}, winbase::{CreateProcessWithLogonW, CreateProcessWithTokenW},
winnt::{DLL_PROCESS_ATTACH, DLL_PROCESS_DETACH, HANDLE, LPCWSTR}, winnt::{DLL_PROCESS_ATTACH, DLL_PROCESS_DETACH, HANDLE, LPCWSTR},
@ -158,15 +152,6 @@ unsafe extern "system" fn ZludaGetProcAddress_NoRedirect(
hModule: HMODULE, hModule: HMODULE,
lpProcName: LPCSTR, lpProcName: LPCSTR,
) -> FARPROC { ) -> FARPROC {
if let Some(detour_guard) = &DETOUR_STATE {
if hModule != ptr::null_mut() && detour_guard.nvcuda_module == hModule {
let proc_name = CStr::from_ptr(lpProcName);
return match detour_guard.overriden_cuda_fns.get(proc_name) {
Some((original_fn, _)) => mem::transmute::<*mut c_void, _>(*original_fn),
None => ptr::null_mut(),
};
}
}
GetProcAddress(hModule, lpProcName) GetProcAddress(hModule, lpProcName)
} }
@ -384,8 +369,6 @@ struct DetourDetachGuard {
suspended_threads: Vec<*mut c_void>, suspended_threads: Vec<*mut c_void>,
// First element is the original fn, second is the new fn // First element is the original fn, second is the new fn
overriden_non_cuda_fns: Vec<(*mut *mut c_void, *mut c_void)>, overriden_non_cuda_fns: Vec<(*mut *mut c_void, *mut c_void)>,
nvcuda_module: HMODULE,
overriden_cuda_fns: HashMap<&'static CStr, (*mut c_void, *mut c_void)>,
} }
impl DetourDetachGuard { impl DetourDetachGuard {
@ -394,17 +377,11 @@ impl DetourDetachGuard {
// first element in the pair, because somehow otherwise original functions // first element in the pair, because somehow otherwise original functions
// also get overriden, so for example ZludaLoadLibraryExW ends calling // also get overriden, so for example ZludaLoadLibraryExW ends calling
// itself recursively until stack overflow exception occurs // itself recursively until stack overflow exception occurs
unsafe fn detour_functions<'a>( unsafe fn new<'a>() -> Option<Self> {
nvcuda_module: HMODULE,
non_cuda_fns: Vec<(*mut *mut c_void, *mut c_void)>,
cuda_fns: HashMap<&'static CStr, (*mut c_void, *mut c_void)>,
) -> Option<Self> {
let mut result = DetourDetachGuard { let mut result = DetourDetachGuard {
state: DetourUndoState::DoNothing, state: DetourUndoState::DoNothing,
suspended_threads: Vec::new(), suspended_threads: Vec::new(),
overriden_non_cuda_fns: non_cuda_fns, overriden_non_cuda_fns: Vec::new(),
nvcuda_module,
overriden_cuda_fns: cuda_fns,
}; };
if DetourTransactionBegin() != NO_ERROR as i32 { if DetourTransactionBegin() != NO_ERROR as i32 {
return None; return None;
@ -419,6 +396,19 @@ impl DetourDetachGuard {
} }
} }
result.overriden_non_cuda_fns.extend_from_slice(&[ result.overriden_non_cuda_fns.extend_from_slice(&[
(
&mut LOAD_LIBRARY_A as *mut _ as *mut *mut c_void,
ZludaLoadLibraryA as *mut c_void,
),
(&mut LOAD_LIBRARY_W as *mut _ as _, ZludaLoadLibraryW as _),
(
&mut LOAD_LIBRARY_EX_A as *mut _ as _,
ZludaLoadLibraryExA as _,
),
(
&mut LOAD_LIBRARY_EX_W as *mut _ as _,
ZludaLoadLibraryExW as _,
),
( (
&mut CREATE_PROCESS_A as *mut _ as _, &mut CREATE_PROCESS_A as *mut _ as _,
ZludaCreateProcessA as _, ZludaCreateProcessA as _,
@ -440,12 +430,7 @@ impl DetourDetachGuard {
ZludaCreateProcessWithTokenW as _, ZludaCreateProcessWithTokenW as _,
), ),
]); ]);
for (original_fn, new_fn) in result.overriden_non_cuda_fns.iter().copied().chain( for (original_fn, new_fn) in result.overriden_non_cuda_fns.iter().copied() {
result
.overriden_cuda_fns
.values_mut()
.map(|(original_ptr, new_ptr)| (original_ptr as *mut _, *new_ptr)),
) {
if DetourAttach(original_fn, new_fn) != NO_ERROR as i32 { if DetourAttach(original_fn, new_fn) != NO_ERROR as i32 {
return None; return None;
} }
@ -659,23 +644,10 @@ unsafe extern "system" fn DllMain(instDLL: HINSTANCE, dwReason: u32, _: *const u
if DetourRestoreAfterWith() == FALSE { if DetourRestoreAfterWith() == FALSE {
return FALSE; return FALSE;
} }
if !initialize_current_module_name(instDLL) { if !initialize_globals(instDLL) {
return FALSE; return FALSE;
} }
match get_zluda_dlls_paths() { match DetourDetachGuard::new() {
Some((nvcuda_path, nvml_path)) => {
ZLUDA_PATH_UTF8 = Some(nvcuda_path);
ZLUDA_ML_PATH_UTF8 = Some(nvml_path);
ZLUDA_PATH_UTF16 = std::str::from_utf8_unchecked(nvcuda_path)
.encode_utf16()
.collect::<Vec<_>>();
ZLUDA_ML_PATH_UTF16 = std::str::from_utf8_unchecked(nvml_path)
.encode_utf16()
.collect::<Vec<_>>();
}
None => return FALSE,
}
match detour_already_loaded_nvcuda() {
Some(g) => { Some(g) => {
DETOUR_STATE = Some(g); DETOUR_STATE = Some(g);
TRUE TRUE
@ -692,55 +664,55 @@ unsafe extern "system" fn DllMain(instDLL: HINSTANCE, dwReason: u32, _: *const u
} }
} }
#[must_use] unsafe fn initialize_globals(current_module: HINSTANCE) -> bool {
unsafe fn initialize_current_module_name(current_module: HINSTANCE) -> bool { let mut module_name = vec![0; 128 as usize];
let mut name = vec![0; 128 as usize];
loop { loop {
let size = GetModuleFileNameA( let size = GetModuleFileNameW(
current_module, current_module,
name.as_mut_ptr() as *mut _, module_name.as_mut_ptr(),
name.len() as u32, module_name.len() as u32,
); );
if size == 0 { if size == 0 {
return false; return false;
} }
if size < name.len() as u32 { if size < module_name.len() as u32 {
name.truncate(size as usize); module_name.truncate(size as usize);
CURRENT_MODULE_FILENAME = name; module_name.push(0);
return true; CURRENT_MODULE_FILENAME = String::from_utf16_lossy(&module_name).into_bytes();
break;
} }
name.resize(name.len() * 2, 0); module_name.resize(module_name.len() * 2, 0);
} }
if !load_global_string(
&PAYLOAD_NVML_GUID,
&mut ZLUDA_ML_PATH_UTF8,
&mut ZLUDA_ML_PATH_UTF16,
) {
return false;
}
if !load_global_string(
&PAYLOAD_NVCUDA_GUID,
&mut ZLUDA_PATH_UTF8,
&mut ZLUDA_PATH_UTF16,
) {
return false;
}
true
} }
#[must_use] fn load_global_string(
unsafe fn detour_already_loaded_nvcuda() -> Option<DetourDetachGuard> { guid: &detours_sys::GUID,
let nvcuda_mod = GetModuleHandleA(b"nvcuda\0".as_ptr() as _); utf8_path: &mut Option<&'static [u8]>,
let detour_functions = vec![ utf16_path: &mut Vec<u16>,
( ) -> bool {
&mut LOAD_LIBRARY_A as *mut _ as *mut *mut c_void, if let Some(payload) = get_payload(guid) {
ZludaLoadLibraryA as *mut c_void, *utf8_path = Some(payload);
), *utf16_path = unsafe { std::str::from_utf8_unchecked(payload) }
(&mut LOAD_LIBRARY_W as *mut _ as _, ZludaLoadLibraryW as _), .encode_utf16()
( .collect::<Vec<_>>();
&mut LOAD_LIBRARY_EX_A as *mut _ as _, true
ZludaLoadLibraryExA as _, } else {
), false
(
&mut LOAD_LIBRARY_EX_W as *mut _ as _,
ZludaLoadLibraryExW as _,
),
];
DetourDetachGuard::detour_functions(nvcuda_mod, detour_functions, HashMap::new())
}
fn get_zluda_dlls_paths() -> Option<(&'static [u8], &'static [u8])> {
match get_payload(&PAYLOAD_NVCUDA_GUID) {
None => None,
Some(nvcuda_payload) => match get_payload(&PAYLOAD_NVML_GUID) {
None => return None,
Some(nvml_payload) => return Some((nvcuda_payload, nvml_payload)),
},
} }
} }