mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-04-28 13:37:57 +03:00
Clean up ZLUDA redirection helper
This commit is contained in:
parent
2753d956df
commit
164c172236
1 changed files with 61 additions and 89 deletions
|
@ -3,11 +3,7 @@
|
|||
extern crate detours_sys;
|
||||
extern crate winapi;
|
||||
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
ffi::{c_void, CStr},
|
||||
mem, ptr, slice, usize,
|
||||
};
|
||||
use std::{ffi::c_void, mem, ptr, slice, usize};
|
||||
|
||||
use detours_sys::{
|
||||
DetourAttach, DetourRestoreAfterWith, DetourTransactionAbort, DetourTransactionBegin,
|
||||
|
@ -18,6 +14,7 @@ use winapi::{
|
|||
shared::minwindef::{BOOL, LPVOID},
|
||||
um::{
|
||||
handleapi::{CloseHandle, INVALID_HANDLE_VALUE},
|
||||
libloaderapi::GetModuleFileNameW,
|
||||
minwinbase::LPSECURITY_ATTRIBUTES,
|
||||
processthreadsapi::{
|
||||
CreateProcessA, GetCurrentProcessId, GetCurrentThreadId, OpenThread, ResumeThread,
|
||||
|
@ -32,15 +29,12 @@ use winapi::{
|
|||
};
|
||||
use winapi::{
|
||||
shared::minwindef::{DWORD, FALSE, HMODULE, TRUE},
|
||||
um::{
|
||||
libloaderapi::{GetModuleHandleA, LoadLibraryExA},
|
||||
winnt::LPCSTR,
|
||||
},
|
||||
um::{libloaderapi::LoadLibraryExA, winnt::LPCSTR},
|
||||
};
|
||||
use winapi::{
|
||||
shared::minwindef::{FARPROC, HINSTANCE},
|
||||
um::{
|
||||
libloaderapi::{GetModuleFileNameA, GetProcAddress},
|
||||
libloaderapi::GetProcAddress,
|
||||
processthreadsapi::{CreateProcessAsUserW, CreateProcessW},
|
||||
winbase::{CreateProcessWithLogonW, CreateProcessWithTokenW},
|
||||
winnt::{DLL_PROCESS_ATTACH, DLL_PROCESS_DETACH, HANDLE, LPCWSTR},
|
||||
|
@ -158,15 +152,6 @@ unsafe extern "system" fn ZludaGetProcAddress_NoRedirect(
|
|||
hModule: HMODULE,
|
||||
lpProcName: LPCSTR,
|
||||
) -> FARPROC {
|
||||
if let Some(detour_guard) = &DETOUR_STATE {
|
||||
if hModule != ptr::null_mut() && detour_guard.nvcuda_module == hModule {
|
||||
let proc_name = CStr::from_ptr(lpProcName);
|
||||
return match detour_guard.overriden_cuda_fns.get(proc_name) {
|
||||
Some((original_fn, _)) => mem::transmute::<*mut c_void, _>(*original_fn),
|
||||
None => ptr::null_mut(),
|
||||
};
|
||||
}
|
||||
}
|
||||
GetProcAddress(hModule, lpProcName)
|
||||
}
|
||||
|
||||
|
@ -384,8 +369,6 @@ struct DetourDetachGuard {
|
|||
suspended_threads: Vec<*mut c_void>,
|
||||
// First element is the original fn, second is the new fn
|
||||
overriden_non_cuda_fns: Vec<(*mut *mut c_void, *mut c_void)>,
|
||||
nvcuda_module: HMODULE,
|
||||
overriden_cuda_fns: HashMap<&'static CStr, (*mut c_void, *mut c_void)>,
|
||||
}
|
||||
|
||||
impl DetourDetachGuard {
|
||||
|
@ -394,17 +377,11 @@ impl DetourDetachGuard {
|
|||
// first element in the pair, because somehow otherwise original functions
|
||||
// also get overriden, so for example ZludaLoadLibraryExW ends calling
|
||||
// itself recursively until stack overflow exception occurs
|
||||
unsafe fn detour_functions<'a>(
|
||||
nvcuda_module: HMODULE,
|
||||
non_cuda_fns: Vec<(*mut *mut c_void, *mut c_void)>,
|
||||
cuda_fns: HashMap<&'static CStr, (*mut c_void, *mut c_void)>,
|
||||
) -> Option<Self> {
|
||||
unsafe fn new<'a>() -> Option<Self> {
|
||||
let mut result = DetourDetachGuard {
|
||||
state: DetourUndoState::DoNothing,
|
||||
suspended_threads: Vec::new(),
|
||||
overriden_non_cuda_fns: non_cuda_fns,
|
||||
nvcuda_module,
|
||||
overriden_cuda_fns: cuda_fns,
|
||||
overriden_non_cuda_fns: Vec::new(),
|
||||
};
|
||||
if DetourTransactionBegin() != NO_ERROR as i32 {
|
||||
return None;
|
||||
|
@ -419,6 +396,19 @@ impl DetourDetachGuard {
|
|||
}
|
||||
}
|
||||
result.overriden_non_cuda_fns.extend_from_slice(&[
|
||||
(
|
||||
&mut LOAD_LIBRARY_A as *mut _ as *mut *mut c_void,
|
||||
ZludaLoadLibraryA as *mut c_void,
|
||||
),
|
||||
(&mut LOAD_LIBRARY_W as *mut _ as _, ZludaLoadLibraryW as _),
|
||||
(
|
||||
&mut LOAD_LIBRARY_EX_A as *mut _ as _,
|
||||
ZludaLoadLibraryExA as _,
|
||||
),
|
||||
(
|
||||
&mut LOAD_LIBRARY_EX_W as *mut _ as _,
|
||||
ZludaLoadLibraryExW as _,
|
||||
),
|
||||
(
|
||||
&mut CREATE_PROCESS_A as *mut _ as _,
|
||||
ZludaCreateProcessA as _,
|
||||
|
@ -440,12 +430,7 @@ impl DetourDetachGuard {
|
|||
ZludaCreateProcessWithTokenW as _,
|
||||
),
|
||||
]);
|
||||
for (original_fn, new_fn) in result.overriden_non_cuda_fns.iter().copied().chain(
|
||||
result
|
||||
.overriden_cuda_fns
|
||||
.values_mut()
|
||||
.map(|(original_ptr, new_ptr)| (original_ptr as *mut _, *new_ptr)),
|
||||
) {
|
||||
for (original_fn, new_fn) in result.overriden_non_cuda_fns.iter().copied() {
|
||||
if DetourAttach(original_fn, new_fn) != NO_ERROR as i32 {
|
||||
return None;
|
||||
}
|
||||
|
@ -659,23 +644,10 @@ unsafe extern "system" fn DllMain(instDLL: HINSTANCE, dwReason: u32, _: *const u
|
|||
if DetourRestoreAfterWith() == FALSE {
|
||||
return FALSE;
|
||||
}
|
||||
if !initialize_current_module_name(instDLL) {
|
||||
if !initialize_globals(instDLL) {
|
||||
return FALSE;
|
||||
}
|
||||
match get_zluda_dlls_paths() {
|
||||
Some((nvcuda_path, nvml_path)) => {
|
||||
ZLUDA_PATH_UTF8 = Some(nvcuda_path);
|
||||
ZLUDA_ML_PATH_UTF8 = Some(nvml_path);
|
||||
ZLUDA_PATH_UTF16 = std::str::from_utf8_unchecked(nvcuda_path)
|
||||
.encode_utf16()
|
||||
.collect::<Vec<_>>();
|
||||
ZLUDA_ML_PATH_UTF16 = std::str::from_utf8_unchecked(nvml_path)
|
||||
.encode_utf16()
|
||||
.collect::<Vec<_>>();
|
||||
}
|
||||
None => return FALSE,
|
||||
}
|
||||
match detour_already_loaded_nvcuda() {
|
||||
match DetourDetachGuard::new() {
|
||||
Some(g) => {
|
||||
DETOUR_STATE = Some(g);
|
||||
TRUE
|
||||
|
@ -692,55 +664,55 @@ unsafe extern "system" fn DllMain(instDLL: HINSTANCE, dwReason: u32, _: *const u
|
|||
}
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
unsafe fn initialize_current_module_name(current_module: HINSTANCE) -> bool {
|
||||
let mut name = vec![0; 128 as usize];
|
||||
unsafe fn initialize_globals(current_module: HINSTANCE) -> bool {
|
||||
let mut module_name = vec![0; 128 as usize];
|
||||
loop {
|
||||
let size = GetModuleFileNameA(
|
||||
let size = GetModuleFileNameW(
|
||||
current_module,
|
||||
name.as_mut_ptr() as *mut _,
|
||||
name.len() as u32,
|
||||
module_name.as_mut_ptr(),
|
||||
module_name.len() as u32,
|
||||
);
|
||||
if size == 0 {
|
||||
return false;
|
||||
}
|
||||
if size < name.len() as u32 {
|
||||
name.truncate(size as usize);
|
||||
CURRENT_MODULE_FILENAME = name;
|
||||
return true;
|
||||
if size < module_name.len() as u32 {
|
||||
module_name.truncate(size as usize);
|
||||
module_name.push(0);
|
||||
CURRENT_MODULE_FILENAME = String::from_utf16_lossy(&module_name).into_bytes();
|
||||
break;
|
||||
}
|
||||
name.resize(name.len() * 2, 0);
|
||||
module_name.resize(module_name.len() * 2, 0);
|
||||
}
|
||||
if !load_global_string(
|
||||
&PAYLOAD_NVML_GUID,
|
||||
&mut ZLUDA_ML_PATH_UTF8,
|
||||
&mut ZLUDA_ML_PATH_UTF16,
|
||||
) {
|
||||
return false;
|
||||
}
|
||||
if !load_global_string(
|
||||
&PAYLOAD_NVCUDA_GUID,
|
||||
&mut ZLUDA_PATH_UTF8,
|
||||
&mut ZLUDA_PATH_UTF16,
|
||||
) {
|
||||
return false;
|
||||
}
|
||||
true
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
unsafe fn detour_already_loaded_nvcuda() -> Option<DetourDetachGuard> {
|
||||
let nvcuda_mod = GetModuleHandleA(b"nvcuda\0".as_ptr() as _);
|
||||
let detour_functions = vec![
|
||||
(
|
||||
&mut LOAD_LIBRARY_A as *mut _ as *mut *mut c_void,
|
||||
ZludaLoadLibraryA as *mut c_void,
|
||||
),
|
||||
(&mut LOAD_LIBRARY_W as *mut _ as _, ZludaLoadLibraryW as _),
|
||||
(
|
||||
&mut LOAD_LIBRARY_EX_A as *mut _ as _,
|
||||
ZludaLoadLibraryExA as _,
|
||||
),
|
||||
(
|
||||
&mut LOAD_LIBRARY_EX_W as *mut _ as _,
|
||||
ZludaLoadLibraryExW as _,
|
||||
),
|
||||
];
|
||||
DetourDetachGuard::detour_functions(nvcuda_mod, detour_functions, HashMap::new())
|
||||
}
|
||||
|
||||
fn get_zluda_dlls_paths() -> Option<(&'static [u8], &'static [u8])> {
|
||||
match get_payload(&PAYLOAD_NVCUDA_GUID) {
|
||||
None => None,
|
||||
Some(nvcuda_payload) => match get_payload(&PAYLOAD_NVML_GUID) {
|
||||
None => return None,
|
||||
Some(nvml_payload) => return Some((nvcuda_payload, nvml_payload)),
|
||||
},
|
||||
fn load_global_string(
|
||||
guid: &detours_sys::GUID,
|
||||
utf8_path: &mut Option<&'static [u8]>,
|
||||
utf16_path: &mut Vec<u16>,
|
||||
) -> bool {
|
||||
if let Some(payload) = get_payload(guid) {
|
||||
*utf8_path = Some(payload);
|
||||
*utf16_path = unsafe { std::str::from_utf8_unchecked(payload) }
|
||||
.encode_utf16()
|
||||
.collect::<Vec<_>>();
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue