mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-04-28 13:37:57 +03:00
Clean up ZLUDA redirection helper
This commit is contained in:
parent
2753d956df
commit
164c172236
1 changed files with 61 additions and 89 deletions
|
@ -3,11 +3,7 @@
|
||||||
extern crate detours_sys;
|
extern crate detours_sys;
|
||||||
extern crate winapi;
|
extern crate winapi;
|
||||||
|
|
||||||
use std::{
|
use std::{ffi::c_void, mem, ptr, slice, usize};
|
||||||
collections::HashMap,
|
|
||||||
ffi::{c_void, CStr},
|
|
||||||
mem, ptr, slice, usize,
|
|
||||||
};
|
|
||||||
|
|
||||||
use detours_sys::{
|
use detours_sys::{
|
||||||
DetourAttach, DetourRestoreAfterWith, DetourTransactionAbort, DetourTransactionBegin,
|
DetourAttach, DetourRestoreAfterWith, DetourTransactionAbort, DetourTransactionBegin,
|
||||||
|
@ -18,6 +14,7 @@ use winapi::{
|
||||||
shared::minwindef::{BOOL, LPVOID},
|
shared::minwindef::{BOOL, LPVOID},
|
||||||
um::{
|
um::{
|
||||||
handleapi::{CloseHandle, INVALID_HANDLE_VALUE},
|
handleapi::{CloseHandle, INVALID_HANDLE_VALUE},
|
||||||
|
libloaderapi::GetModuleFileNameW,
|
||||||
minwinbase::LPSECURITY_ATTRIBUTES,
|
minwinbase::LPSECURITY_ATTRIBUTES,
|
||||||
processthreadsapi::{
|
processthreadsapi::{
|
||||||
CreateProcessA, GetCurrentProcessId, GetCurrentThreadId, OpenThread, ResumeThread,
|
CreateProcessA, GetCurrentProcessId, GetCurrentThreadId, OpenThread, ResumeThread,
|
||||||
|
@ -32,15 +29,12 @@ use winapi::{
|
||||||
};
|
};
|
||||||
use winapi::{
|
use winapi::{
|
||||||
shared::minwindef::{DWORD, FALSE, HMODULE, TRUE},
|
shared::minwindef::{DWORD, FALSE, HMODULE, TRUE},
|
||||||
um::{
|
um::{libloaderapi::LoadLibraryExA, winnt::LPCSTR},
|
||||||
libloaderapi::{GetModuleHandleA, LoadLibraryExA},
|
|
||||||
winnt::LPCSTR,
|
|
||||||
},
|
|
||||||
};
|
};
|
||||||
use winapi::{
|
use winapi::{
|
||||||
shared::minwindef::{FARPROC, HINSTANCE},
|
shared::minwindef::{FARPROC, HINSTANCE},
|
||||||
um::{
|
um::{
|
||||||
libloaderapi::{GetModuleFileNameA, GetProcAddress},
|
libloaderapi::GetProcAddress,
|
||||||
processthreadsapi::{CreateProcessAsUserW, CreateProcessW},
|
processthreadsapi::{CreateProcessAsUserW, CreateProcessW},
|
||||||
winbase::{CreateProcessWithLogonW, CreateProcessWithTokenW},
|
winbase::{CreateProcessWithLogonW, CreateProcessWithTokenW},
|
||||||
winnt::{DLL_PROCESS_ATTACH, DLL_PROCESS_DETACH, HANDLE, LPCWSTR},
|
winnt::{DLL_PROCESS_ATTACH, DLL_PROCESS_DETACH, HANDLE, LPCWSTR},
|
||||||
|
@ -158,15 +152,6 @@ unsafe extern "system" fn ZludaGetProcAddress_NoRedirect(
|
||||||
hModule: HMODULE,
|
hModule: HMODULE,
|
||||||
lpProcName: LPCSTR,
|
lpProcName: LPCSTR,
|
||||||
) -> FARPROC {
|
) -> FARPROC {
|
||||||
if let Some(detour_guard) = &DETOUR_STATE {
|
|
||||||
if hModule != ptr::null_mut() && detour_guard.nvcuda_module == hModule {
|
|
||||||
let proc_name = CStr::from_ptr(lpProcName);
|
|
||||||
return match detour_guard.overriden_cuda_fns.get(proc_name) {
|
|
||||||
Some((original_fn, _)) => mem::transmute::<*mut c_void, _>(*original_fn),
|
|
||||||
None => ptr::null_mut(),
|
|
||||||
};
|
|
||||||
}
|
|
||||||
}
|
|
||||||
GetProcAddress(hModule, lpProcName)
|
GetProcAddress(hModule, lpProcName)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -384,8 +369,6 @@ struct DetourDetachGuard {
|
||||||
suspended_threads: Vec<*mut c_void>,
|
suspended_threads: Vec<*mut c_void>,
|
||||||
// First element is the original fn, second is the new fn
|
// First element is the original fn, second is the new fn
|
||||||
overriden_non_cuda_fns: Vec<(*mut *mut c_void, *mut c_void)>,
|
overriden_non_cuda_fns: Vec<(*mut *mut c_void, *mut c_void)>,
|
||||||
nvcuda_module: HMODULE,
|
|
||||||
overriden_cuda_fns: HashMap<&'static CStr, (*mut c_void, *mut c_void)>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl DetourDetachGuard {
|
impl DetourDetachGuard {
|
||||||
|
@ -394,17 +377,11 @@ impl DetourDetachGuard {
|
||||||
// first element in the pair, because somehow otherwise original functions
|
// first element in the pair, because somehow otherwise original functions
|
||||||
// also get overriden, so for example ZludaLoadLibraryExW ends calling
|
// also get overriden, so for example ZludaLoadLibraryExW ends calling
|
||||||
// itself recursively until stack overflow exception occurs
|
// itself recursively until stack overflow exception occurs
|
||||||
unsafe fn detour_functions<'a>(
|
unsafe fn new<'a>() -> Option<Self> {
|
||||||
nvcuda_module: HMODULE,
|
|
||||||
non_cuda_fns: Vec<(*mut *mut c_void, *mut c_void)>,
|
|
||||||
cuda_fns: HashMap<&'static CStr, (*mut c_void, *mut c_void)>,
|
|
||||||
) -> Option<Self> {
|
|
||||||
let mut result = DetourDetachGuard {
|
let mut result = DetourDetachGuard {
|
||||||
state: DetourUndoState::DoNothing,
|
state: DetourUndoState::DoNothing,
|
||||||
suspended_threads: Vec::new(),
|
suspended_threads: Vec::new(),
|
||||||
overriden_non_cuda_fns: non_cuda_fns,
|
overriden_non_cuda_fns: Vec::new(),
|
||||||
nvcuda_module,
|
|
||||||
overriden_cuda_fns: cuda_fns,
|
|
||||||
};
|
};
|
||||||
if DetourTransactionBegin() != NO_ERROR as i32 {
|
if DetourTransactionBegin() != NO_ERROR as i32 {
|
||||||
return None;
|
return None;
|
||||||
|
@ -419,6 +396,19 @@ impl DetourDetachGuard {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
result.overriden_non_cuda_fns.extend_from_slice(&[
|
result.overriden_non_cuda_fns.extend_from_slice(&[
|
||||||
|
(
|
||||||
|
&mut LOAD_LIBRARY_A as *mut _ as *mut *mut c_void,
|
||||||
|
ZludaLoadLibraryA as *mut c_void,
|
||||||
|
),
|
||||||
|
(&mut LOAD_LIBRARY_W as *mut _ as _, ZludaLoadLibraryW as _),
|
||||||
|
(
|
||||||
|
&mut LOAD_LIBRARY_EX_A as *mut _ as _,
|
||||||
|
ZludaLoadLibraryExA as _,
|
||||||
|
),
|
||||||
|
(
|
||||||
|
&mut LOAD_LIBRARY_EX_W as *mut _ as _,
|
||||||
|
ZludaLoadLibraryExW as _,
|
||||||
|
),
|
||||||
(
|
(
|
||||||
&mut CREATE_PROCESS_A as *mut _ as _,
|
&mut CREATE_PROCESS_A as *mut _ as _,
|
||||||
ZludaCreateProcessA as _,
|
ZludaCreateProcessA as _,
|
||||||
|
@ -440,12 +430,7 @@ impl DetourDetachGuard {
|
||||||
ZludaCreateProcessWithTokenW as _,
|
ZludaCreateProcessWithTokenW as _,
|
||||||
),
|
),
|
||||||
]);
|
]);
|
||||||
for (original_fn, new_fn) in result.overriden_non_cuda_fns.iter().copied().chain(
|
for (original_fn, new_fn) in result.overriden_non_cuda_fns.iter().copied() {
|
||||||
result
|
|
||||||
.overriden_cuda_fns
|
|
||||||
.values_mut()
|
|
||||||
.map(|(original_ptr, new_ptr)| (original_ptr as *mut _, *new_ptr)),
|
|
||||||
) {
|
|
||||||
if DetourAttach(original_fn, new_fn) != NO_ERROR as i32 {
|
if DetourAttach(original_fn, new_fn) != NO_ERROR as i32 {
|
||||||
return None;
|
return None;
|
||||||
}
|
}
|
||||||
|
@ -659,23 +644,10 @@ unsafe extern "system" fn DllMain(instDLL: HINSTANCE, dwReason: u32, _: *const u
|
||||||
if DetourRestoreAfterWith() == FALSE {
|
if DetourRestoreAfterWith() == FALSE {
|
||||||
return FALSE;
|
return FALSE;
|
||||||
}
|
}
|
||||||
if !initialize_current_module_name(instDLL) {
|
if !initialize_globals(instDLL) {
|
||||||
return FALSE;
|
return FALSE;
|
||||||
}
|
}
|
||||||
match get_zluda_dlls_paths() {
|
match DetourDetachGuard::new() {
|
||||||
Some((nvcuda_path, nvml_path)) => {
|
|
||||||
ZLUDA_PATH_UTF8 = Some(nvcuda_path);
|
|
||||||
ZLUDA_ML_PATH_UTF8 = Some(nvml_path);
|
|
||||||
ZLUDA_PATH_UTF16 = std::str::from_utf8_unchecked(nvcuda_path)
|
|
||||||
.encode_utf16()
|
|
||||||
.collect::<Vec<_>>();
|
|
||||||
ZLUDA_ML_PATH_UTF16 = std::str::from_utf8_unchecked(nvml_path)
|
|
||||||
.encode_utf16()
|
|
||||||
.collect::<Vec<_>>();
|
|
||||||
}
|
|
||||||
None => return FALSE,
|
|
||||||
}
|
|
||||||
match detour_already_loaded_nvcuda() {
|
|
||||||
Some(g) => {
|
Some(g) => {
|
||||||
DETOUR_STATE = Some(g);
|
DETOUR_STATE = Some(g);
|
||||||
TRUE
|
TRUE
|
||||||
|
@ -692,55 +664,55 @@ unsafe extern "system" fn DllMain(instDLL: HINSTANCE, dwReason: u32, _: *const u
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[must_use]
|
unsafe fn initialize_globals(current_module: HINSTANCE) -> bool {
|
||||||
unsafe fn initialize_current_module_name(current_module: HINSTANCE) -> bool {
|
let mut module_name = vec![0; 128 as usize];
|
||||||
let mut name = vec![0; 128 as usize];
|
|
||||||
loop {
|
loop {
|
||||||
let size = GetModuleFileNameA(
|
let size = GetModuleFileNameW(
|
||||||
current_module,
|
current_module,
|
||||||
name.as_mut_ptr() as *mut _,
|
module_name.as_mut_ptr(),
|
||||||
name.len() as u32,
|
module_name.len() as u32,
|
||||||
);
|
);
|
||||||
if size == 0 {
|
if size == 0 {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
if size < name.len() as u32 {
|
if size < module_name.len() as u32 {
|
||||||
name.truncate(size as usize);
|
module_name.truncate(size as usize);
|
||||||
CURRENT_MODULE_FILENAME = name;
|
module_name.push(0);
|
||||||
return true;
|
CURRENT_MODULE_FILENAME = String::from_utf16_lossy(&module_name).into_bytes();
|
||||||
|
break;
|
||||||
}
|
}
|
||||||
name.resize(name.len() * 2, 0);
|
module_name.resize(module_name.len() * 2, 0);
|
||||||
}
|
}
|
||||||
|
if !load_global_string(
|
||||||
|
&PAYLOAD_NVML_GUID,
|
||||||
|
&mut ZLUDA_ML_PATH_UTF8,
|
||||||
|
&mut ZLUDA_ML_PATH_UTF16,
|
||||||
|
) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if !load_global_string(
|
||||||
|
&PAYLOAD_NVCUDA_GUID,
|
||||||
|
&mut ZLUDA_PATH_UTF8,
|
||||||
|
&mut ZLUDA_PATH_UTF16,
|
||||||
|
) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
true
|
||||||
}
|
}
|
||||||
|
|
||||||
#[must_use]
|
fn load_global_string(
|
||||||
unsafe fn detour_already_loaded_nvcuda() -> Option<DetourDetachGuard> {
|
guid: &detours_sys::GUID,
|
||||||
let nvcuda_mod = GetModuleHandleA(b"nvcuda\0".as_ptr() as _);
|
utf8_path: &mut Option<&'static [u8]>,
|
||||||
let detour_functions = vec![
|
utf16_path: &mut Vec<u16>,
|
||||||
(
|
) -> bool {
|
||||||
&mut LOAD_LIBRARY_A as *mut _ as *mut *mut c_void,
|
if let Some(payload) = get_payload(guid) {
|
||||||
ZludaLoadLibraryA as *mut c_void,
|
*utf8_path = Some(payload);
|
||||||
),
|
*utf16_path = unsafe { std::str::from_utf8_unchecked(payload) }
|
||||||
(&mut LOAD_LIBRARY_W as *mut _ as _, ZludaLoadLibraryW as _),
|
.encode_utf16()
|
||||||
(
|
.collect::<Vec<_>>();
|
||||||
&mut LOAD_LIBRARY_EX_A as *mut _ as _,
|
true
|
||||||
ZludaLoadLibraryExA as _,
|
} else {
|
||||||
),
|
false
|
||||||
(
|
|
||||||
&mut LOAD_LIBRARY_EX_W as *mut _ as _,
|
|
||||||
ZludaLoadLibraryExW as _,
|
|
||||||
),
|
|
||||||
];
|
|
||||||
DetourDetachGuard::detour_functions(nvcuda_mod, detour_functions, HashMap::new())
|
|
||||||
}
|
|
||||||
|
|
||||||
fn get_zluda_dlls_paths() -> Option<(&'static [u8], &'static [u8])> {
|
|
||||||
match get_payload(&PAYLOAD_NVCUDA_GUID) {
|
|
||||||
None => None,
|
|
||||||
Some(nvcuda_payload) => match get_payload(&PAYLOAD_NVML_GUID) {
|
|
||||||
None => return None,
|
|
||||||
Some(nvml_payload) => return Some((nvcuda_payload, nvml_payload)),
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue