Refactor L0 bindings

This commit is contained in:
Andrzej Janik 2021-05-27 02:05:17 +02:00
parent 58a7fe53c6
commit e40785aa74
9 changed files with 577 additions and 419 deletions

File diff suppressed because it is too large Load diff

View file

@ -201,8 +201,8 @@ impl<T: Debug> error::Error for DisplayError<T> {}
fn test_ptx_assert< fn test_ptx_assert<
'a, 'a,
Input: From<u8> + ze::SafeRepr + Debug + Copy + PartialEq, Input: From<u8> + Debug + Copy + PartialEq,
Output: From<u8> + ze::SafeRepr + Debug + Copy + PartialEq, Output: From<u8> + Debug + Copy + PartialEq,
>( >(
name: &str, name: &str,
ptx_text: &'a str, ptx_text: &'a str,
@ -220,10 +220,7 @@ fn test_ptx_assert<
Ok(()) Ok(())
} }
fn run_spirv< fn run_spirv<Input: From<u8> + Copy + Debug, Output: From<u8> + Copy + Debug>(
Input: From<u8> + ze::SafeRepr + Copy + Debug,
Output: From<u8> + ze::SafeRepr + Copy + Debug,
>(
name: &CStr, name: &CStr,
module: translate::Module, module: translate::Module,
input: &[Input], input: &[Input],
@ -242,25 +239,25 @@ fn run_spirv<
.get(name.to_str().unwrap()) .get(name.to_str().unwrap())
.map(|info| info.uses_shared_mem) .map(|info| info.uses_shared_mem)
.unwrap_or(false); .unwrap_or(false);
let mut result = vec![0u8.into(); output.len()]; let result = vec![0u8.into(); output.len()];
{ {
let mut drivers = ze::Driver::get()?; let mut drivers = ze::Driver::get()?;
let drv = drivers.drain(0..1).next().unwrap(); let drv = drivers.drain(0..1).next().unwrap();
let mut ctx = ze::Context::new(&drv)?;
let mut devices = drv.devices()?; let mut devices = drv.devices()?;
let dev = devices.drain(0..1).next().unwrap(); let dev = devices.drain(0..1).next().unwrap();
let queue = ze::CommandQueue::new(&mut ctx, &dev)?; let ctx = ze::Context::new(drv, None)?;
let queue = ze::CommandQueue::new(&ctx, dev)?;
let (module, maybe_log) = match module.should_link_ptx_impl { let (module, maybe_log) = match module.should_link_ptx_impl {
Some(ptx_impl) => ze::Module::build_link_spirv( Some(ptx_impl) => ze::Module::build_link_spirv(
&mut ctx, &ctx,
&dev, dev,
&[ptx_impl, byte_il], &[ptx_impl, byte_il],
Some(module.build_options.as_c_str()), Some(module.build_options.as_c_str()),
), ),
None => { None => {
let (module, log) = ze::Module::build_spirv_logged( let (module, log) = ze::Module::build_spirv_logged(
&mut ctx, &ctx,
&dev, dev,
byte_il, byte_il,
Some(module.build_options.as_c_str()), Some(module.build_options.as_c_str()),
); );
@ -271,38 +268,38 @@ fn run_spirv<
Ok(m) => m, Ok(m) => m,
Err(err) => { Err(err) => {
let raw_err_string = maybe_log let raw_err_string = maybe_log
.map(|log| log.get_cstring()) .map(|log| log.to_cstring())
.transpose()? .transpose()?
.unwrap_or(CString::default()); .unwrap_or(CString::default());
let err_string = raw_err_string.to_string_lossy(); let err_string = raw_err_string.to_string_lossy();
panic!("{:?}\n{}", err, err_string); panic!("{:?}\n{}", err, err_string);
} }
}; };
let mut kernel = ze::Kernel::new_resident(&module, name)?; let kernel = ze::Kernel::new_resident(&module, name)?;
kernel.set_indirect_access( kernel.set_indirect_access(
ze::sys::ze_kernel_indirect_access_flags_t::ZE_KERNEL_INDIRECT_ACCESS_FLAG_DEVICE, ze::sys::ze_kernel_indirect_access_flags_t::ZE_KERNEL_INDIRECT_ACCESS_FLAG_DEVICE,
)?; )?;
let mut inp_b = ze::DeviceBuffer::<Input>::new(&mut ctx, &dev, cmp::max(input.len(), 1))?; let inp_b = ze::DeviceBuffer::<Input>::new(&ctx, dev, cmp::max(input.len(), 1))?;
let mut out_b = ze::DeviceBuffer::<Output>::new(&mut ctx, &dev, cmp::max(output.len(), 1))?; let out_b = ze::DeviceBuffer::<Output>::new(&ctx, dev, cmp::max(output.len(), 1))?;
let inp_b_ptr_mut: ze::BufferPtrMut<Input> = (&mut inp_b).into(); let event_pool = ze::EventPool::new(&ctx, 3, Some(&[dev]))?;
let event_pool = ze::EventPool::new(&mut ctx, 3, Some(&[&dev]))?;
let ev0 = ze::Event::new(&event_pool, 0)?; let ev0 = ze::Event::new(&event_pool, 0)?;
let ev1 = ze::Event::new(&event_pool, 1)?; let ev1 = ze::Event::new(&event_pool, 1)?;
let mut ev2 = ze::Event::new(&event_pool, 2)?; let ev2 = ze::Event::new(&event_pool, 2)?;
let mut cmd_list = ze::CommandList::new(&mut ctx, &dev)?; {
let out_b_ptr_mut: ze::BufferPtrMut<Output> = (&mut out_b).into(); let cmd_list = ze::CommandList::new(&ctx, dev)?;
let mut init_evs = [ev0, ev1]; let init_evs = [ev0, ev1];
cmd_list.append_memory_copy(inp_b_ptr_mut, input, Some(&mut init_evs[0]), &mut [])?; cmd_list.append_memory_copy(&inp_b, input, Some(&init_evs[0]), &[])?;
cmd_list.append_memory_fill(out_b_ptr_mut, 0, Some(&mut init_evs[1]), &mut [])?; cmd_list.append_memory_fill(&out_b, 0, Some(&init_evs[1]), &[])?;
kernel.set_group_size(1, 1, 1)?; kernel.set_group_size(1, 1, 1)?;
kernel.set_arg_buffer(0, inp_b_ptr_mut)?; kernel.set_arg_buffer(0, &inp_b)?;
kernel.set_arg_buffer(1, out_b_ptr_mut)?; kernel.set_arg_buffer(1, &out_b)?;
if use_shared_mem { if use_shared_mem {
unsafe { kernel.set_arg_raw(2, 128, ptr::null())? }; unsafe { kernel.set_arg_raw(2, 128, ptr::null())? };
}
cmd_list.append_launch_kernel(&kernel, &[1, 1, 1], Some(&ev2), &init_evs)?;
cmd_list.append_memory_copy(&*result, &out_b, None, &[ev2])?;
queue.execute_and_synchronize(cmd_list)?;
} }
cmd_list.append_launch_kernel(&kernel, &[1, 1, 1], Some(&mut ev2), &mut init_evs)?;
cmd_list.append_memory_copy(result.as_mut_slice(), out_b_ptr_mut, None, &mut [ev2])?;
queue.execute(cmd_list)?;
} }
Ok(result) Ok(result)
} }

View file

@ -98,8 +98,8 @@ pub struct ContextData {
impl ContextData { impl ContextData {
pub fn new( pub fn new(
l0_ctx: &mut l0::Context, l0_ctx: &'static l0::Context,
l0_dev: &l0::Device, l0_dev: l0::Device,
flags: c_uint, flags: c_uint,
is_primary: bool, is_primary: bool,
dev: *mut device::Device, dev: *mut device::Device,
@ -137,7 +137,7 @@ pub fn create_v2(
let dev_ptr = dev as *mut _; let dev_ptr = dev as *mut _;
let mut ctx_box = Box::new(LiveCheck::new(ContextData::new( let mut ctx_box = Box::new(LiveCheck::new(ContextData::new(
&mut dev.l0_context, &mut dev.l0_context,
&dev.base, dev.base,
flags, flags,
false, false,
dev_ptr as *mut _, dev_ptr as *mut _,

View file

@ -18,7 +18,7 @@ pub struct Index(pub c_int);
pub struct Device { pub struct Device {
pub index: Index, pub index: Index,
pub base: l0::Device, pub base: l0::Device,
pub default_queue: l0::CommandQueue, pub default_queue: l0::CommandQueue<'static>,
pub l0_context: l0::Context, pub l0_context: l0::Context,
pub primary_context: context::Context, pub primary_context: context::Context,
properties: Option<Box<l0::sys::ze_device_properties_t>>, properties: Option<Box<l0::sys::ze_device_properties_t>>,
@ -31,12 +31,13 @@ unsafe impl Send for Device {}
impl Device { impl Device {
// Unsafe because it does not fully initialize primary_context // Unsafe because it does not fully initialize primary_context
// and we transmute lifetimes left and right
unsafe fn new(drv: &l0::Driver, l0_dev: l0::Device, idx: usize) -> Result<Self, CUresult> { unsafe fn new(drv: &l0::Driver, l0_dev: l0::Device, idx: usize) -> Result<Self, CUresult> {
let mut ctx = l0::Context::new(drv)?; let ctx = l0::Context::new(*drv, Some(&[l0_dev]))?;
let queue = l0::CommandQueue::new(&mut ctx, &l0_dev)?; let queue = l0::CommandQueue::new(mem::transmute(&ctx), l0_dev)?;
let primary_context = context::Context::new(context::ContextData::new( let primary_context = context::Context::new(context::ContextData::new(
&mut ctx, mem::transmute(&ctx),
&l0_dev, l0_dev,
0, 0,
true, true,
ptr::null_mut(), ptr::null_mut(),
@ -58,20 +59,18 @@ impl Device {
if let Some(ref prop) = self.properties { if let Some(ref prop) = self.properties {
return Ok(prop); return Ok(prop);
} }
match self.base.get_properties() { let mut props = Default::default();
Ok(prop) => Ok(self.properties.get_or_insert(prop)), self.base.get_properties(&mut props)?;
Err(e) => Err(e), Ok(self.properties.get_or_insert(Box::new(props)))
}
} }
fn get_image_properties(&mut self) -> l0::Result<&l0::sys::ze_device_image_properties_t> { fn get_image_properties(&mut self) -> l0::Result<&l0::sys::ze_device_image_properties_t> {
if let Some(ref prop) = self.image_properties { if let Some(ref prop) = self.image_properties {
return Ok(prop); return Ok(prop);
} }
match self.base.get_image_properties() { let mut props = Default::default();
Ok(prop) => Ok(self.image_properties.get_or_insert(prop)), self.base.get_image_properties(&mut props)?;
Err(e) => Err(e), Ok(self.image_properties.get_or_insert(Box::new(props)))
}
} }
fn get_memory_properties(&mut self) -> l0::Result<&[l0::sys::ze_device_memory_properties_t]> { fn get_memory_properties(&mut self) -> l0::Result<&[l0::sys::ze_device_memory_properties_t]> {
@ -88,10 +87,9 @@ impl Device {
if let Some(ref prop) = self.compute_properties { if let Some(ref prop) = self.compute_properties {
return Ok(prop); return Ok(prop);
} }
match self.base.get_compute_properties() { let mut props = Default::default();
Ok(prop) => Ok(self.compute_properties.get_or_insert(prop)), self.base.get_compute_properties(&mut props)?;
Err(e) => Err(e), Ok(self.compute_properties.get_or_insert(Box::new(props)))
}
} }
pub fn late_init(&mut self) { pub fn late_init(&mut self) {
@ -351,7 +349,11 @@ pub fn get_uuid(uuid: *mut CUuuid_st, dev_idx: Index) -> Result<(), CUresult> {
} }
// TODO: add support if Level 0 exposes it // TODO: add support if Level 0 exposes it
pub fn get_luid(luid: *mut c_char, dev_node_mask: *mut c_uint, _dev_idx: Index) -> Result<(), CUresult> { pub fn get_luid(
luid: *mut c_char,
dev_node_mask: *mut c_uint,
_dev_idx: Index,
) -> Result<(), CUresult> {
unsafe { ptr::write_bytes(luid, 0u8, 8) }; unsafe { ptr::write_bytes(luid, 0u8, 8) };
unsafe { *dev_node_mask = 0 }; unsafe { *dev_node_mask = 0 };
Ok(()) Ok(())

View file

@ -144,14 +144,14 @@ pub fn launch_kernel(
func.base func.base
.set_group_size(block_dim_x, block_dim_y, block_dim_z)?; .set_group_size(block_dim_x, block_dim_y, block_dim_z)?;
func.legacy_args.reset(); func.legacy_args.reset();
let mut cmd_list = stream.command_list()?; let cmd_list = stream.command_list()?;
cmd_list.append_launch_kernel( cmd_list.append_launch_kernel(
&mut func.base, &mut func.base,
&[grid_dim_x, grid_dim_y, grid_dim_z], &[grid_dim_x, grid_dim_y, grid_dim_z],
None, None,
&mut [], &mut [],
)?; )?;
stream.queue.execute(cmd_list)?; stream.queue.execute_and_synchronize(cmd_list)?;
Ok(()) Ok(())
})? })?
} }

View file

@ -4,7 +4,7 @@ use std::{ffi::c_void, mem};
pub fn alloc_v2(dptr: *mut *mut c_void, bytesize: usize) -> Result<(), CUresult> { pub fn alloc_v2(dptr: *mut *mut c_void, bytesize: usize) -> Result<(), CUresult> {
let ptr = GlobalState::lock_current_context(|ctx| { let ptr = GlobalState::lock_current_context(|ctx| {
let dev = unsafe { &mut *ctx.device }; let dev = unsafe { &mut *ctx.device };
Ok::<_, CUresult>(unsafe { dev.base.mem_alloc_device(&mut dev.l0_context, bytesize, 0) }?) Ok::<_, CUresult>(dev.l0_context.mem_alloc_device(bytesize, 0, dev.base)?)
})??; })??;
unsafe { *dptr = ptr }; unsafe { *dptr = ptr };
Ok(()) Ok(())
@ -12,9 +12,9 @@ pub fn alloc_v2(dptr: *mut *mut c_void, bytesize: usize) -> Result<(), CUresult>
pub fn copy_v2(dst: *mut c_void, src: *const c_void, bytesize: usize) -> Result<(), CUresult> { pub fn copy_v2(dst: *mut c_void, src: *const c_void, bytesize: usize) -> Result<(), CUresult> {
GlobalState::lock_stream(stream::CU_STREAM_LEGACY, |stream| { GlobalState::lock_stream(stream::CU_STREAM_LEGACY, |stream| {
let mut cmd_list = stream.command_list()?; let cmd_list = stream.command_list()?;
unsafe { cmd_list.append_memory_copy_unsafe(dst, src, bytesize, None, &mut []) }?; unsafe { cmd_list.append_memory_copy_unsafe(dst, src, bytesize, None, &mut [])? };
stream.queue.execute(cmd_list)?; stream.queue.execute_and_synchronize(cmd_list)?;
Ok::<_, CUresult>(()) Ok::<_, CUresult>(())
})? })?
} }
@ -22,29 +22,29 @@ pub fn copy_v2(dst: *mut c_void, src: *const c_void, bytesize: usize) -> Result<
pub fn free_v2(ptr: *mut c_void) -> Result<(), CUresult> { pub fn free_v2(ptr: *mut c_void) -> Result<(), CUresult> {
GlobalState::lock_current_context(|ctx| { GlobalState::lock_current_context(|ctx| {
let dev = unsafe { &mut *ctx.device }; let dev = unsafe { &mut *ctx.device };
Ok::<_, CUresult>(unsafe { dev.l0_context.mem_free(ptr) }?) Ok::<_, CUresult>(dev.l0_context.mem_free(ptr)?)
}) })
.map_err(|_| CUresult::CUDA_ERROR_INVALID_VALUE)? .map_err(|_| CUresult::CUDA_ERROR_INVALID_VALUE)?
} }
pub(crate) fn set_d32_v2(dst: *mut c_void, ui: u32, n: usize) -> Result<(), CUresult> { pub(crate) fn set_d32_v2(dst: *mut c_void, ui: u32, n: usize) -> Result<(), CUresult> {
GlobalState::lock_stream(stream::CU_STREAM_LEGACY, |stream| { GlobalState::lock_stream(stream::CU_STREAM_LEGACY, |stream| {
let mut cmd_list = stream.command_list()?; let cmd_list = stream.command_list()?;
unsafe { unsafe {
cmd_list.append_memory_fill_unsafe(dst, &ui, mem::size_of::<u32>() * n, None, &mut []) cmd_list.append_memory_fill_unsafe(dst, &ui, mem::size_of::<u32>() * n, None, &mut [])
}?; }?;
stream.queue.execute(cmd_list)?; stream.queue.execute_and_synchronize(cmd_list)?;
Ok::<_, CUresult>(()) Ok::<_, CUresult>(())
})? })?
} }
pub(crate) fn set_d8_v2(dst: *mut c_void, uc: u8, n: usize) -> Result<(), CUresult> { pub(crate) fn set_d8_v2(dst: *mut c_void, uc: u8, n: usize) -> Result<(), CUresult> {
GlobalState::lock_stream(stream::CU_STREAM_LEGACY, |stream| { GlobalState::lock_stream(stream::CU_STREAM_LEGACY, |stream| {
let mut cmd_list = stream.command_list()?; let cmd_list = stream.command_list()?;
unsafe { unsafe {
cmd_list.append_memory_fill_unsafe(dst, &uc, mem::size_of::<u8>() * n, None, &mut []) cmd_list.append_memory_fill_unsafe(dst, &uc, mem::size_of::<u8>() * n, None, &mut [])
}?; }?;
stream.queue.execute(cmd_list)?; stream.queue.execute_and_synchronize(cmd_list)?;
Ok::<_, CUresult>(()) Ok::<_, CUresult>(())
})? })?
} }

View file

@ -41,7 +41,7 @@ pub struct SpirvModule {
} }
pub struct CompiledModule { pub struct CompiledModule {
pub base: l0::Module, pub base: l0::Module<'static>,
pub kernels: HashMap<CString, Box<Function>>, pub kernels: HashMap<CString, Box<Function>>,
} }
@ -78,7 +78,11 @@ impl SpirvModule {
}) })
} }
pub fn compile(&self, ctx: &mut l0::Context, dev: &l0::Device) -> Result<l0::Module, CUresult> { pub fn compile<'a>(
&self,
ctx: &'a l0::Context,
dev: l0::Device,
) -> Result<l0::Module<'a>, CUresult> {
let byte_il = unsafe { let byte_il = unsafe {
slice::from_raw_parts( slice::from_raw_parts(
self.binaries.as_ptr() as *const u8, self.binaries.as_ptr() as *const u8,
@ -86,13 +90,11 @@ impl SpirvModule {
) )
}; };
let l0_module = match self.should_link_ptx_impl { let l0_module = match self.should_link_ptx_impl {
None => { None => l0::Module::build_spirv(ctx, dev, byte_il, Some(self.build_options.as_c_str())),
l0::Module::build_spirv(ctx, dev, byte_il, Some(self.build_options.as_c_str()))
}
Some(ptx_impl) => { Some(ptx_impl) => {
l0::Module::build_link_spirv( l0::Module::build_link_spirv(
ctx, ctx,
&dev, dev,
&[ptx_impl, byte_il], &[ptx_impl, byte_il],
Some(self.build_options.as_c_str()), Some(self.build_options.as_c_str()),
) )
@ -119,7 +121,7 @@ pub fn get_function(
hash_map::Entry::Occupied(entry) => entry.into_mut(), hash_map::Entry::Occupied(entry) => entry.into_mut(),
hash_map::Entry::Vacant(entry) => { hash_map::Entry::Vacant(entry) => {
let new_module = CompiledModule { let new_module = CompiledModule {
base: module.spirv.compile(&mut device.l0_context, &device.base)?, base: module.spirv.compile(&mut device.l0_context, device.base)?,
kernels: HashMap::new(), kernels: HashMap::new(),
}; };
entry.insert(new_module) entry.insert(new_module)
@ -135,7 +137,7 @@ pub fn get_function(
std::str::from_utf8_unchecked(entry.key().as_c_str().to_bytes()) std::str::from_utf8_unchecked(entry.key().as_c_str().to_bytes())
}) })
.ok_or(CUresult::CUDA_ERROR_NOT_FOUND)?; .ok_or(CUresult::CUDA_ERROR_NOT_FOUND)?;
let mut kernel = let kernel =
l0::Kernel::new_resident(&compiled_module.base, entry.key().as_c_str())?; l0::Kernel::new_resident(&compiled_module.base, entry.key().as_c_str())?;
kernel.set_indirect_access( kernel.set_indirect_access(
l0::sys::ze_kernel_indirect_access_flags_t::ZE_KERNEL_INDIRECT_ACCESS_FLAG_DEVICE l0::sys::ze_kernel_indirect_access_flags_t::ZE_KERNEL_INDIRECT_ACCESS_FLAG_DEVICE
@ -165,7 +167,7 @@ pub(crate) fn load_data(pmod: *mut *mut Module, image: *const c_void) -> Result<
pub fn load_data_impl(pmod: *mut *mut Module, spirv_data: SpirvModule) -> Result<(), CUresult> { pub fn load_data_impl(pmod: *mut *mut Module, spirv_data: SpirvModule) -> Result<(), CUresult> {
let module = GlobalState::lock_current_context(|ctx| { let module = GlobalState::lock_current_context(|ctx| {
let device = unsafe { &mut *ctx.device }; let device = unsafe { &mut *ctx.device };
let l0_module = spirv_data.compile(&mut device.l0_context, &device.base)?; let l0_module = spirv_data.compile(&device.l0_context, device.base)?;
let mut device_binaries = HashMap::new(); let mut device_binaries = HashMap::new();
let compiled_module = CompiledModule { let compiled_module = CompiledModule {
base: l0_module, base: l0_module,

View file

@ -33,11 +33,11 @@ impl HasLivenessCookie for StreamData {
pub struct StreamData { pub struct StreamData {
pub context: *mut ContextData, pub context: *mut ContextData,
pub queue: l0::CommandQueue, pub queue: l0::CommandQueue<'static>,
} }
impl StreamData { impl StreamData {
pub fn new_unitialized(ctx: &mut l0::Context, dev: &l0::Device) -> Result<Self, CUresult> { pub fn new_unitialized(ctx: &'static l0::Context, dev: l0::Device) -> Result<Self, CUresult> {
Ok(StreamData { Ok(StreamData {
context: ptr::null_mut(), context: ptr::null_mut(),
queue: l0::CommandQueue::new(ctx, dev)?, queue: l0::CommandQueue::new(ctx, dev)?,
@ -45,7 +45,7 @@ impl StreamData {
} }
pub fn new(ctx: &mut ContextData) -> Result<Self, CUresult> { pub fn new(ctx: &mut ContextData) -> Result<Self, CUresult> {
let l0_ctx = &mut unsafe { &mut *ctx.device }.l0_context; let l0_ctx = &mut unsafe { &mut *ctx.device }.l0_context;
let l0_dev = &unsafe { &*ctx.device }.base; let l0_dev = unsafe { &*ctx.device }.base;
Ok(StreamData { Ok(StreamData {
context: ctx as *mut _, context: ctx as *mut _,
queue: l0::CommandQueue::new(l0_ctx, l0_dev)?, queue: l0::CommandQueue::new(l0_ctx, l0_dev)?,
@ -55,7 +55,7 @@ impl StreamData {
pub fn command_list(&self) -> Result<l0::CommandList, l0::sys::_ze_result_t> { pub fn command_list(&self) -> Result<l0::CommandList, l0::sys::_ze_result_t> {
let ctx = unsafe { &mut *self.context }; let ctx = unsafe { &mut *self.context };
let dev = unsafe { &mut *ctx.device }; let dev = unsafe { &mut *ctx.device };
l0::CommandList::new(&mut dev.l0_context, &dev.base) l0::CommandList::new(&mut dev.l0_context, dev.base)
} }
} }

View file

@ -127,7 +127,8 @@ pub(crate) fn system_get_driver_version(
len: 0, len: 0,
}; };
for d in drivers { for d in drivers {
let props = d.get_properties()?; let mut props = Default::default();
d.get_properties(&mut props)?;
let driver_version = props.driverVersion; let driver_version = props.driverVersion;
write!(&mut output_write, "{}", driver_version) write!(&mut output_write, "{}", driver_version)
.map_err(|_| nvmlReturn_t::NVML_ERROR_UNKNOWN)?; .map_err(|_| nvmlReturn_t::NVML_ERROR_UNKNOWN)?;