mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-04-28 21:47:57 +03:00
Add missing support for Milestone 1
This commit is contained in:
parent
42bcd999eb
commit
e0190fcbe1
18 changed files with 980 additions and 615 deletions
|
@ -726,6 +726,11 @@ impl<'a> Kernel<'a> {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
pub unsafe fn set_arg_raw(&self, index: u32, size: usize, value: *const c_void) -> Result<()> {
|
||||
check!(sys::zeKernelSetArgumentValue(self.0, index, size, value));
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn set_group_size(&self, x: u32, y: u32, z: u32) -> Result<()> {
|
||||
check!(sys::zeKernelSetGroupSize(self.0, x, y, z));
|
||||
Ok(())
|
||||
|
|
File diff suppressed because it is too large
Load diff
52
notcuda/src/impl/function.rs
Normal file
52
notcuda/src/impl/function.rs
Normal file
|
@ -0,0 +1,52 @@
|
|||
use ::std::os::raw::{c_uint, c_void};
|
||||
use std::ptr;
|
||||
|
||||
use super::{context, device, stream::Stream, CUresult};
|
||||
|
||||
pub struct Function {
|
||||
pub base: l0::Kernel<'static>,
|
||||
pub arg_size: Vec<usize>,
|
||||
}
|
||||
|
||||
pub fn launch_kernel(
|
||||
f: *mut Function,
|
||||
grid_dim_x: c_uint,
|
||||
grid_dim_y: c_uint,
|
||||
grid_dim_z: c_uint,
|
||||
block_dim_x: c_uint,
|
||||
block_dim_y: c_uint,
|
||||
block_dim_z: c_uint,
|
||||
shared_mem_bytes: c_uint,
|
||||
strean: *mut Stream,
|
||||
kernel_params: *mut *mut c_void,
|
||||
extra: *mut *mut c_void,
|
||||
) -> Result<(), CUresult> {
|
||||
if f == ptr::null_mut() {
|
||||
return Err(CUresult::CUDA_ERROR_INVALID_VALUE);
|
||||
}
|
||||
if shared_mem_bytes != 0 || strean != ptr::null_mut() || extra != ptr::null_mut() {
|
||||
return Err(CUresult::CUDA_ERROR_NOT_SUPPORTED);
|
||||
}
|
||||
let func = unsafe { &*f };
|
||||
for (i, arg_size) in func.arg_size.iter().copied().enumerate() {
|
||||
unsafe {
|
||||
func.base
|
||||
.set_arg_raw(i as u32, arg_size, *kernel_params.add(i))?
|
||||
};
|
||||
}
|
||||
unsafe { &*f }
|
||||
.base
|
||||
.set_group_size(block_dim_x, block_dim_y, block_dim_z)?;
|
||||
device::with_current_exclusive(|dev| {
|
||||
let mut cmd_list = l0::CommandList::new(&mut dev.l0_context, &dev.base)?;
|
||||
cmd_list.append_launch_kernel(
|
||||
&unsafe { &*f }.base,
|
||||
&[grid_dim_x, grid_dim_y, grid_dim_z],
|
||||
None,
|
||||
&mut [],
|
||||
)?;
|
||||
dev.default_queue.execute(cmd_list)?;
|
||||
l0::Result::Ok(())
|
||||
})??;
|
||||
Ok(())
|
||||
}
|
|
@ -46,6 +46,10 @@ unsafe fn memcpy_impl(
|
|||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn free_v2(mem: *mut c_void)-> l0::Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::super::test::CudaDriverFns;
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
use crate::cuda::{CUctx_st, CUdevice, CUdeviceptr, CUfunction, CUmod_st, CUmodule, CUresult};
|
||||
use crate::cuda::{CUctx_st, CUdevice, CUdeviceptr, CUfunc_st, CUfunction, CUmod_st, CUmodule, CUresult, CUstream, CUstream_st};
|
||||
use std::{ffi::c_void, mem::{self, ManuallyDrop}, os::raw::c_int, sync::Mutex};
|
||||
|
||||
#[cfg(test)]
|
||||
|
@ -9,6 +9,8 @@ pub mod device;
|
|||
pub mod export_table;
|
||||
pub mod memory;
|
||||
pub mod module;
|
||||
pub mod function;
|
||||
pub mod stream;
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
pub fn unimplemented() -> CUresult {
|
||||
|
@ -242,6 +244,10 @@ impl<'a> CudaRepr for CUmod_st {
|
|||
type Impl = module::Module;
|
||||
}
|
||||
|
||||
impl<'a> CudaRepr for CUfunction {
|
||||
type Impl = *mut module::Function;
|
||||
impl<'a> CudaRepr for CUfunc_st {
|
||||
type Impl = function::Function;
|
||||
}
|
||||
|
||||
impl<'a> CudaRepr for CUstream_st {
|
||||
type Impl = stream::Stream;
|
||||
}
|
||||
|
|
|
@ -1,6 +1,9 @@
|
|||
use std::{ffi::c_void, ffi::CStr, mem, os::raw::c_char, ptr, slice, sync::Mutex};
|
||||
use std::{
|
||||
collections::HashMap, ffi::c_void, ffi::CStr, ffi::CString, mem, os::raw::c_char, ptr, slice,
|
||||
sync::Mutex,
|
||||
};
|
||||
|
||||
use super::{transmute_lifetime, CUresult};
|
||||
use super::{function::Function, transmute_lifetime, CUresult};
|
||||
use ptx;
|
||||
|
||||
use super::context;
|
||||
|
@ -9,6 +12,7 @@ pub type Module = Mutex<ModuleData>;
|
|||
|
||||
pub struct ModuleData {
|
||||
base: l0::Module,
|
||||
arg_lens: HashMap<CString, Vec<usize>>,
|
||||
}
|
||||
|
||||
pub enum ModuleCompileError<'a> {
|
||||
|
@ -52,7 +56,7 @@ impl ModuleData {
|
|||
Ok(_) if errors.len() > 0 => return Err(ModuleCompileError::Parse(errors, None)),
|
||||
Ok(ast) => ast,
|
||||
};
|
||||
let spirv = ptx::to_spirv(ast)?;
|
||||
let (spirv, all_arg_lens) = ptx::to_spirv(ast)?;
|
||||
let byte_il = unsafe {
|
||||
slice::from_raw_parts::<u8>(
|
||||
spirv.as_ptr() as *const _,
|
||||
|
@ -63,17 +67,19 @@ impl ModuleData {
|
|||
l0::Module::new_spirv(&mut dev.l0_context, &dev.base, byte_il, None)
|
||||
});
|
||||
match module {
|
||||
Ok(Ok(module)) => Ok(Mutex::new(Self { base: module })),
|
||||
Ok(Ok(module)) => Ok(Mutex::new(Self {
|
||||
base: module,
|
||||
arg_lens: all_arg_lens
|
||||
.into_iter()
|
||||
.map(|(k, v)| (CString::new(k).unwrap(), v))
|
||||
.collect(),
|
||||
})),
|
||||
Ok(Err(err)) => Err(ModuleCompileError::from(err)),
|
||||
Err(err) => Err(ModuleCompileError::from(err)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Function {
|
||||
base: l0::Kernel<'static>,
|
||||
}
|
||||
|
||||
pub fn get_function(
|
||||
hfunc: *mut *mut Function,
|
||||
hmod: *mut Module,
|
||||
|
@ -83,10 +89,33 @@ pub fn get_function(
|
|||
return Err(CUresult::CUDA_ERROR_INVALID_VALUE);
|
||||
}
|
||||
let name = unsafe { CStr::from_ptr(name) };
|
||||
let kernel = unsafe { &*hmod }
|
||||
let (mut kernel, args_len) = unsafe { &*hmod }
|
||||
.try_lock()
|
||||
.map(|module| l0::Kernel::new_resident(unsafe { transmute_lifetime(&module.base) }, name))
|
||||
.map(|module| {
|
||||
Result::<_, CUresult>::Ok((
|
||||
l0::Kernel::new_resident(unsafe { transmute_lifetime(&module.base) }, name)?,
|
||||
module
|
||||
.arg_lens
|
||||
.get(name)
|
||||
.ok_or(CUresult::CUDA_ERROR_NOT_FOUND)?
|
||||
.clone(),
|
||||
))
|
||||
})
|
||||
.map_err(|_| CUresult::CUDA_ERROR_ILLEGAL_STATE)??;
|
||||
unsafe { *hfunc = Box::into_raw(Box::new(Function { base: kernel })) };
|
||||
kernel.set_indirect_access(
|
||||
l0::sys::ze_kernel_indirect_access_flags_t::ZE_KERNEL_INDIRECT_ACCESS_FLAG_DEVICE
|
||||
| l0::sys::ze_kernel_indirect_access_flags_t::ZE_KERNEL_INDIRECT_ACCESS_FLAG_HOST
|
||||
| l0::sys::ze_kernel_indirect_access_flags_t::ZE_KERNEL_INDIRECT_ACCESS_FLAG_SHARED,
|
||||
)?;
|
||||
unsafe {
|
||||
*hfunc = Box::into_raw(Box::new(Function {
|
||||
base: kernel,
|
||||
arg_size: args_len,
|
||||
}))
|
||||
};
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn unload(decuda: *mut Module) -> Result<(), CUresult> {
|
||||
Ok(())
|
||||
}
|
||||
|
|
69
notcuda/src/impl/stream.rs
Normal file
69
notcuda/src/impl/stream.rs
Normal file
|
@ -0,0 +1,69 @@
|
|||
use std::cell::RefCell;
|
||||
|
||||
use device::Device;
|
||||
|
||||
use super::device;
|
||||
|
||||
pub struct Stream {
|
||||
dev: *mut Device,
|
||||
}
|
||||
|
||||
pub struct DefaultStream {
|
||||
streams: Vec<Option<Stream>>,
|
||||
}
|
||||
|
||||
impl DefaultStream {
|
||||
fn new() -> Self {
|
||||
DefaultStream {
|
||||
streams: Vec::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
thread_local! {
|
||||
pub static DEFAULT_STREAM: RefCell<DefaultStream> = RefCell::new(DefaultStream::new());
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::cuda::CUstream;
|
||||
|
||||
use super::super::test::CudaDriverFns;
|
||||
use super::super::CUresult;
|
||||
use std::{ffi::c_void, ptr};
|
||||
|
||||
const CU_STREAM_LEGACY: CUstream = 1 as *mut _;
|
||||
const CU_STREAM_PER_THREAD: CUstream = 2 as *mut _;
|
||||
|
||||
cuda_driver_test!(default_stream_uses_current_ctx_legacy);
|
||||
cuda_driver_test!(default_stream_uses_current_ctx_ptsd);
|
||||
|
||||
fn default_stream_uses_current_ctx_legacy<T: CudaDriverFns>() {
|
||||
default_stream_uses_current_ctx_impl::<T>(CU_STREAM_LEGACY);
|
||||
}
|
||||
|
||||
fn default_stream_uses_current_ctx_ptsd<T: CudaDriverFns>() {
|
||||
default_stream_uses_current_ctx_impl::<T>(CU_STREAM_PER_THREAD);
|
||||
}
|
||||
|
||||
fn default_stream_uses_current_ctx_impl<T: CudaDriverFns>(stream: CUstream) {
|
||||
assert_eq!(T::cuInit(0), CUresult::CUDA_SUCCESS);
|
||||
let mut ctx1 = ptr::null_mut();
|
||||
assert_eq!(T::cuCtxCreate_v2(&mut ctx1, 0, 0), CUresult::CUDA_SUCCESS);
|
||||
let mut stream_ctx1 = ptr::null_mut();
|
||||
assert_eq!(
|
||||
T::cuStreamGetCtx(stream, &mut stream_ctx1),
|
||||
CUresult::CUDA_SUCCESS
|
||||
);
|
||||
assert_eq!(ctx1, stream_ctx1);
|
||||
let mut ctx2 = ptr::null_mut();
|
||||
assert_eq!(T::cuCtxCreate_v2(&mut ctx2, 0, 0), CUresult::CUDA_SUCCESS);
|
||||
assert_ne!(ctx1, ctx2);
|
||||
let mut stream_ctx2 = ptr::null_mut();
|
||||
assert_eq!(
|
||||
T::cuStreamGetCtx(stream, &mut stream_ctx2),
|
||||
CUresult::CUDA_SUCCESS
|
||||
);
|
||||
assert_eq!(ctx2, stream_ctx2);
|
||||
}
|
||||
}
|
|
@ -1,6 +1,6 @@
|
|||
#![allow(non_snake_case)]
|
||||
|
||||
use crate::r#impl as notcuda;
|
||||
use crate::{cuda::CUcontext, cuda::CUstream, r#impl as notcuda};
|
||||
use crate::r#impl::CUresult;
|
||||
use crate::{cuda::CUuuid, r#impl::Encuda};
|
||||
use ::std::{
|
||||
|
@ -36,14 +36,14 @@ pub trait CudaDriverFns {
|
|||
fn cuMemAlloc_v2(dptr: *mut *mut c_void, bytesize: usize) -> CUresult;
|
||||
fn cuDeviceGetUuid(uuid: *mut CUuuid, dev: c_int) -> CUresult;
|
||||
fn cuDevicePrimaryCtxGetState(dev: c_int, flags: *mut c_uint, active: *mut c_int) -> CUresult;
|
||||
fn cuStreamGetCtx(hStream: CUstream, pctx: *mut *mut c_void) -> CUresult;
|
||||
}
|
||||
|
||||
pub struct NotCuda();
|
||||
|
||||
impl CudaDriverFns for NotCuda {
|
||||
fn cuInit(_flags: c_uint) -> CUresult {
|
||||
assert!(notcuda::context::is_context_stack_empty());
|
||||
notcuda::init().encuda()
|
||||
crate::cuda::cuInit(_flags as _)
|
||||
}
|
||||
|
||||
fn cuCtxCreate_v2(pctx: *mut *mut c_void, flags: c_uint, dev: c_int) -> CUresult {
|
||||
|
@ -76,6 +76,10 @@ impl CudaDriverFns for NotCuda {
|
|||
fn cuDevicePrimaryCtxGetState(dev: c_int, flags: *mut c_uint, active: *mut c_int) -> CUresult {
|
||||
notcuda::device::primary_ctx_get_state(notcuda::device::Index(dev), flags, active).encuda()
|
||||
}
|
||||
|
||||
fn cuStreamGetCtx(hStream: CUstream, pctx: *mut *mut c_void) -> CUresult {
|
||||
crate::cuda::cuStreamGetCtx(hStream, pctx as _)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Cuda();
|
||||
|
@ -115,4 +119,8 @@ impl CudaDriverFns for Cuda {
|
|||
fn cuDevicePrimaryCtxGetState(dev: c_int, flags: *mut c_uint, active: *mut c_int) -> CUresult {
|
||||
unsafe { CUresult(cuda::cuDevicePrimaryCtxGetState(dev, flags, active) as c_uint) }
|
||||
}
|
||||
|
||||
fn cuStreamGetCtx(hStream: CUstream, pctx: *mut *mut c_void) -> CUresult {
|
||||
unsafe { CUresult(cuda::cuStreamGetCtx(hStream as _, pctx as _) as c_uint) }
|
||||
}
|
||||
}
|
||||
|
|
|
@ -320,7 +320,7 @@ pub enum Instruction<P: ArgParams> {
|
|||
MovVector(MovVectorDetails, Arg2Vec<P>),
|
||||
Mul(MulDetails, Arg3<P>),
|
||||
Add(AddDetails, Arg3<P>),
|
||||
Setp(SetpData, Arg4<P>),
|
||||
Setp(SetpData, Arg4Setp<P>),
|
||||
SetpBool(SetpBoolData, Arg5<P>),
|
||||
Not(NotType, Arg2<P>),
|
||||
Bra(BraData, Arg1<P>),
|
||||
|
@ -331,8 +331,12 @@ pub enum Instruction<P: ArgParams> {
|
|||
Ret(RetData),
|
||||
Call(CallInst<P>),
|
||||
Abs(AbsDetails, Arg2<P>),
|
||||
Mad(MulDetails, Arg4<P>),
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone)]
|
||||
pub struct MadFloatDesc {}
|
||||
|
||||
#[derive(Copy, Clone)]
|
||||
pub struct MovVectorDetails {
|
||||
pub typ: MovVectorType,
|
||||
|
@ -398,6 +402,13 @@ pub struct Arg3<P: ArgParams> {
|
|||
}
|
||||
|
||||
pub struct Arg4<P: ArgParams> {
|
||||
pub dst: P::ID,
|
||||
pub src1: P::Operand,
|
||||
pub src2: P::Operand,
|
||||
pub src3: P::Operand,
|
||||
}
|
||||
|
||||
pub struct Arg4Setp<P: ArgParams> {
|
||||
pub dst1: P::ID,
|
||||
pub dst2: Option<P::ID>,
|
||||
pub src1: P::Operand,
|
||||
|
@ -503,7 +514,7 @@ sub_scalar_type!(MovVectorType {
|
|||
|
||||
pub struct MovDetails {
|
||||
pub typ: MovType,
|
||||
pub src_is_address: bool
|
||||
pub src_is_address: bool,
|
||||
}
|
||||
|
||||
sub_type! {
|
||||
|
@ -518,17 +529,20 @@ pub enum MulDetails {
|
|||
Float(MulFloatDesc),
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone)]
|
||||
pub struct MulIntDesc {
|
||||
pub typ: IntType,
|
||||
pub control: MulIntControl,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone)]
|
||||
pub enum MulIntControl {
|
||||
Low,
|
||||
High,
|
||||
Wide,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone)]
|
||||
pub struct MulFloatDesc {
|
||||
pub typ: FloatType,
|
||||
pub rounding: Option<RoundingMode>,
|
||||
|
|
|
@ -122,6 +122,7 @@ match {
|
|||
"cvta",
|
||||
"debug",
|
||||
"ld",
|
||||
"mad",
|
||||
"map_f64_to_f32",
|
||||
"mov",
|
||||
"mul",
|
||||
|
@ -149,6 +150,7 @@ ExtendedID : &'input str = {
|
|||
"cvta",
|
||||
"debug",
|
||||
"ld",
|
||||
"mad",
|
||||
"map_f64_to_f32",
|
||||
"mov",
|
||||
"mul",
|
||||
|
@ -442,6 +444,7 @@ Instruction: ast::Instruction<ast::ParsedArgParams<'input>> = {
|
|||
InstCvta,
|
||||
InstCall,
|
||||
InstAbs,
|
||||
InstMad
|
||||
};
|
||||
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-ld
|
||||
|
@ -649,7 +652,7 @@ InstAddMode: ast::AddDetails = {
|
|||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-comparison-instructions-setp
|
||||
// TODO: support f16 setp
|
||||
InstSetp: ast::Instruction<ast::ParsedArgParams<'input>> = {
|
||||
"setp" <d:SetpMode> <a:Arg4> => ast::Instruction::Setp(d, a),
|
||||
"setp" <d:SetpMode> <a:Arg4Setp> => ast::Instruction::Setp(d, a),
|
||||
"setp" <d:SetpBoolMode> <a:Arg5> => ast::Instruction::SetpBool(d, a),
|
||||
};
|
||||
|
||||
|
@ -995,6 +998,13 @@ InstAbs: ast::Instruction<ast::ParsedArgParams<'input>> = {
|
|||
},
|
||||
};
|
||||
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-mad
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-mad
|
||||
InstMad: ast::Instruction<ast::ParsedArgParams<'input>> = {
|
||||
"mad" <d:InstMulMode> <a:Arg4> => ast::Instruction::Mad(d, a),
|
||||
"mad" ".hi" ".sat" ".s32" => todo!()
|
||||
};
|
||||
|
||||
SignedIntType: ast::ScalarType = {
|
||||
".s16" => ast::ScalarType::S16,
|
||||
".s32" => ast::ScalarType::S32,
|
||||
|
@ -1056,7 +1066,11 @@ Arg3: ast::Arg3<ast::ParsedArgParams<'input>> = {
|
|||
};
|
||||
|
||||
Arg4: ast::Arg4<ast::ParsedArgParams<'input>> = {
|
||||
<dst1:ExtendedID> <dst2:OptionalDst?> "," <src1:Operand> "," <src2:Operand> => ast::Arg4{<>}
|
||||
<dst:ExtendedID> "," <src1:Operand> "," <src2:Operand> "," <src3:Operand> => ast::Arg4{<>}
|
||||
};
|
||||
|
||||
Arg4Setp: ast::Arg4Setp<ast::ParsedArgParams<'input>> = {
|
||||
<dst1:ExtendedID> <dst2:OptionalDst?> "," <src1:Operand> "," <src2:Operand> => ast::Arg4Setp{<>}
|
||||
};
|
||||
|
||||
// TODO: pass src3 negation somewhere
|
||||
|
|
|
@ -40,3 +40,10 @@ fn _Z9vectorAddPKfS0_Pfi_ptx() -> Result<(), TranslateError> {
|
|||
let vector_add = include_str!("_Z9vectorAddPKfS0_Pfi.ptx");
|
||||
compile_and_assert(vector_add)
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[allow(non_snake_case)]
|
||||
fn vectorAdd_11_ptx() -> Result<(), TranslateError> {
|
||||
let vector_add = include_str!("vectorAdd_11.ptx");
|
||||
compile_and_assert(vector_add)
|
||||
}
|
||||
|
|
28
ptx/src/test/spirv_run/mad_s32.ptx
Normal file
28
ptx/src/test/spirv_run/mad_s32.ptx
Normal file
|
@ -0,0 +1,28 @@
|
|||
.version 6.5
|
||||
.target sm_30
|
||||
.address_size 64
|
||||
|
||||
.visible .entry mad_s32(
|
||||
.param .u64 input,
|
||||
.param .u64 output
|
||||
)
|
||||
{
|
||||
.reg .u64 in_addr;
|
||||
.reg .u64 out_addr;
|
||||
.reg .s32 dst;
|
||||
.reg .s32 src1;
|
||||
.reg .s32 src2;
|
||||
.reg .s32 src3;
|
||||
|
||||
ld.param.u64 in_addr, [input];
|
||||
ld.param.u64 out_addr, [output];
|
||||
|
||||
ld.s32 src1, [in_addr];
|
||||
ld.s32 src2, [in_addr+4];
|
||||
ld.s32 src3, [in_addr+8];
|
||||
mad.lo.s32 dst, src1, src2, src3;
|
||||
st.s32 [out_addr], dst;
|
||||
st.s32 [out_addr+4], dst;
|
||||
st.s32 [out_addr+8], dst;
|
||||
ret;
|
||||
}
|
77
ptx/src/test/spirv_run/mad_s32.spvtxt
Normal file
77
ptx/src/test/spirv_run/mad_s32.spvtxt
Normal file
|
@ -0,0 +1,77 @@
|
|||
OpCapability GenericPointer
|
||||
OpCapability Linkage
|
||||
OpCapability Addresses
|
||||
OpCapability Kernel
|
||||
OpCapability Int64
|
||||
OpCapability Int8
|
||||
OpCapability Float64
|
||||
%48 = OpExtInstImport "OpenCL.std"
|
||||
OpMemoryModel Physical64 OpenCL
|
||||
OpEntryPoint Kernel %1 "mad_s32"
|
||||
%void = OpTypeVoid
|
||||
%ulong = OpTypeInt 64 0
|
||||
%51 = OpTypeFunction %void %ulong %ulong
|
||||
%_ptr_Function_ulong = OpTypePointer Function %ulong
|
||||
%uint = OpTypeInt 32 0
|
||||
%_ptr_Function_uint = OpTypePointer Function %uint
|
||||
%_ptr_Generic_uint = OpTypePointer Generic %uint
|
||||
%ulong_4 = OpConstant %ulong 4
|
||||
%ulong_8 = OpConstant %ulong 8
|
||||
%ulong_4_0 = OpConstant %ulong 4
|
||||
%ulong_8_0 = OpConstant %ulong 8
|
||||
%1 = OpFunction %void None %51
|
||||
%10 = OpFunctionParameter %ulong
|
||||
%11 = OpFunctionParameter %ulong
|
||||
%46 = OpLabel
|
||||
%2 = OpVariable %_ptr_Function_ulong Function
|
||||
%3 = OpVariable %_ptr_Function_ulong Function
|
||||
%4 = OpVariable %_ptr_Function_ulong Function
|
||||
%5 = OpVariable %_ptr_Function_ulong Function
|
||||
%6 = OpVariable %_ptr_Function_uint Function
|
||||
%7 = OpVariable %_ptr_Function_uint Function
|
||||
%8 = OpVariable %_ptr_Function_uint Function
|
||||
%9 = OpVariable %_ptr_Function_uint Function
|
||||
OpStore %2 %10
|
||||
OpStore %3 %11
|
||||
%13 = OpLoad %ulong %2
|
||||
%12 = OpCopyObject %ulong %13
|
||||
OpStore %4 %12
|
||||
%15 = OpLoad %ulong %3
|
||||
%14 = OpCopyObject %ulong %15
|
||||
OpStore %5 %14
|
||||
%17 = OpLoad %ulong %4
|
||||
%40 = OpConvertUToPtr %_ptr_Generic_uint %17
|
||||
%16 = OpLoad %uint %40
|
||||
OpStore %7 %16
|
||||
%19 = OpLoad %ulong %4
|
||||
%33 = OpIAdd %ulong %19 %ulong_4
|
||||
%41 = OpConvertUToPtr %_ptr_Generic_uint %33
|
||||
%18 = OpLoad %uint %41
|
||||
OpStore %8 %18
|
||||
%21 = OpLoad %ulong %4
|
||||
%35 = OpIAdd %ulong %21 %ulong_8
|
||||
%42 = OpConvertUToPtr %_ptr_Generic_uint %35
|
||||
%20 = OpLoad %uint %42
|
||||
OpStore %9 %20
|
||||
%23 = OpLoad %uint %7
|
||||
%24 = OpLoad %uint %8
|
||||
%25 = OpLoad %uint %9
|
||||
%56 = OpIMul %uint %23 %24
|
||||
%22 = OpIAdd %uint %25 %56
|
||||
OpStore %6 %22
|
||||
%26 = OpLoad %ulong %5
|
||||
%27 = OpLoad %uint %6
|
||||
%43 = OpConvertUToPtr %_ptr_Generic_uint %26
|
||||
OpStore %43 %27
|
||||
%28 = OpLoad %ulong %5
|
||||
%29 = OpLoad %uint %6
|
||||
%37 = OpIAdd %ulong %28 %ulong_4_0
|
||||
%44 = OpConvertUToPtr %_ptr_Generic_uint %37
|
||||
OpStore %44 %29
|
||||
%30 = OpLoad %ulong %5
|
||||
%31 = OpLoad %uint %6
|
||||
%39 = OpIAdd %ulong %30 %ulong_8_0
|
||||
%45 = OpConvertUToPtr %_ptr_Generic_uint %39
|
||||
OpStore %45 %31
|
||||
OpReturn
|
||||
OpFunctionEnd
|
|
@ -8,7 +8,6 @@ use spirv_headers::Word;
|
|||
use spirv_tools_sys::{
|
||||
spv_binary, spv_endianness_t, spv_parsed_instruction_t, spv_result_t, spv_target_env,
|
||||
};
|
||||
use std::{collections::hash_map::Entry, cmp};
|
||||
use std::error;
|
||||
use std::ffi::{c_void, CStr, CString};
|
||||
use std::fmt;
|
||||
|
@ -17,6 +16,7 @@ use std::hash::Hash;
|
|||
use std::mem;
|
||||
use std::slice;
|
||||
use std::{borrow::Cow, collections::HashMap, env, fs, path::PathBuf, ptr, str};
|
||||
use std::{cmp, collections::hash_map::Entry};
|
||||
|
||||
macro_rules! test_ptx {
|
||||
($fn_name:ident, $input:expr, $output:expr) => {
|
||||
|
@ -65,6 +65,8 @@ test_ptx!(mov_address, [0xDEADu64], [0u64]);
|
|||
test_ptx!(b64tof64, [111u64], [111u64]);
|
||||
test_ptx!(implicit_param, [34u32], [34u32]);
|
||||
test_ptx!(pred_not, [10u64, 11u64], [2u64, 0u64]);
|
||||
test_ptx!(mad_s32, [2i32, 3i32, 4i32], [10i32, 10i32, 10i32]);
|
||||
test_ptx!(mul_wide, [0x01_00_00_00__01_00_00_00i64], [0x1_00_00_00_00_00_00i64]);
|
||||
|
||||
struct DisplayError<T: Debug> {
|
||||
err: T,
|
||||
|
@ -93,7 +95,7 @@ fn test_ptx_assert<'a, T: From<u8> + ze::SafeRepr + Debug + Copy + PartialEq>(
|
|||
let mut errors = Vec::new();
|
||||
let ast = ptx::ModuleParser::new().parse(&mut errors, ptx_text)?;
|
||||
assert!(errors.len() == 0);
|
||||
let spirv = translate::to_spirv(ast)?;
|
||||
let (spirv, _) = translate::to_spirv(ast)?;
|
||||
let name = CString::new(name)?;
|
||||
let result =
|
||||
run_spirv(name.as_c_str(), &spirv, input, output).map_err(|err| DisplayError { err })?;
|
||||
|
@ -127,7 +129,7 @@ fn run_spirv<T: From<u8> + ze::SafeRepr + Copy + Debug>(
|
|||
kernel.set_indirect_access(
|
||||
ze::sys::ze_kernel_indirect_access_flags_t::ZE_KERNEL_INDIRECT_ACCESS_FLAG_DEVICE,
|
||||
)?;
|
||||
let mut inp_b = ze::DeviceBuffer::<T>::new(&mut ctx, &dev, cmp::max(input.len(),1))?;
|
||||
let mut inp_b = ze::DeviceBuffer::<T>::new(&mut ctx, &dev, cmp::max(input.len(), 1))?;
|
||||
let mut out_b = ze::DeviceBuffer::<T>::new(&mut ctx, &dev, cmp::max(output.len(), 1))?;
|
||||
let inp_b_ptr_mut: ze::BufferPtrMut<T> = (&mut inp_b).into();
|
||||
let event_pool = ze::EventPool::new(&mut ctx, 3, Some(&[&dev]))?;
|
||||
|
@ -157,7 +159,7 @@ fn test_spvtxt_assert<'a>(
|
|||
let mut errors = Vec::new();
|
||||
let ast = ptx::ModuleParser::new().parse(&mut errors, ptx_txt)?;
|
||||
assert!(errors.len() == 0);
|
||||
let ptx_mod = translate::to_spirv_module(ast)?;
|
||||
let (ptx_mod, _) = translate::to_spirv_module(ast)?;
|
||||
let spv_context =
|
||||
unsafe { spirv_tools::spvContextCreate(spv_target_env::SPV_ENV_UNIVERSAL_1_3) };
|
||||
assert!(spv_context != ptr::null_mut());
|
||||
|
|
24
ptx/src/test/spirv_run/mul_wide.ptx
Normal file
24
ptx/src/test/spirv_run/mul_wide.ptx
Normal file
|
@ -0,0 +1,24 @@
|
|||
.version 6.5
|
||||
.target sm_30
|
||||
.address_size 64
|
||||
|
||||
.visible .entry mul_wide(
|
||||
.param .u64 input,
|
||||
.param .u64 output
|
||||
)
|
||||
{
|
||||
.reg .u64 in_addr;
|
||||
.reg .u64 out_addr;
|
||||
.reg .s32 inp1;
|
||||
.reg .s32 inp2;
|
||||
.reg .s64 result;
|
||||
|
||||
ld.param.u64 in_addr, [input];
|
||||
ld.param.u64 out_addr, [output];
|
||||
|
||||
ld.global.s32 inp1, [in_addr];
|
||||
ld.global.s32 inp2, [in_addr+4];
|
||||
mul.wide.s32 result, inp1, inp2;
|
||||
st.u64 [out_addr], result;
|
||||
ret;
|
||||
}
|
64
ptx/src/test/spirv_run/mul_wide.spvtxt
Normal file
64
ptx/src/test/spirv_run/mul_wide.spvtxt
Normal file
|
@ -0,0 +1,64 @@
|
|||
OpCapability GenericPointer
|
||||
OpCapability Linkage
|
||||
OpCapability Addresses
|
||||
OpCapability Kernel
|
||||
OpCapability Int64
|
||||
OpCapability Int8
|
||||
OpCapability Float64
|
||||
%32 = OpExtInstImport "OpenCL.std"
|
||||
OpMemoryModel Physical64 OpenCL
|
||||
OpEntryPoint Kernel %1 "mul_wide"
|
||||
%void = OpTypeVoid
|
||||
%ulong = OpTypeInt 64 0
|
||||
%35 = OpTypeFunction %void %ulong %ulong
|
||||
%_ptr_Function_ulong = OpTypePointer Function %ulong
|
||||
%uint = OpTypeInt 32 0
|
||||
%_ptr_Function_uint = OpTypePointer Function %uint
|
||||
%_ptr_CrossWorkgroup_uint = OpTypePointer CrossWorkgroup %uint
|
||||
%ulong_4 = OpConstant %ulong 4
|
||||
%_struct_40 = OpTypeStruct %uint %uint
|
||||
%v2uint = OpTypeVector %uint 2
|
||||
%_ptr_Generic_ulong = OpTypePointer Generic %ulong
|
||||
%1 = OpFunction %void None %35
|
||||
%9 = OpFunctionParameter %ulong
|
||||
%10 = OpFunctionParameter %ulong
|
||||
%30 = OpLabel
|
||||
%2 = OpVariable %_ptr_Function_ulong Function
|
||||
%3 = OpVariable %_ptr_Function_ulong Function
|
||||
%4 = OpVariable %_ptr_Function_ulong Function
|
||||
%5 = OpVariable %_ptr_Function_ulong Function
|
||||
%6 = OpVariable %_ptr_Function_uint Function
|
||||
%7 = OpVariable %_ptr_Function_uint Function
|
||||
%8 = OpVariable %_ptr_Function_ulong Function
|
||||
OpStore %2 %9
|
||||
OpStore %3 %10
|
||||
%12 = OpLoad %ulong %2
|
||||
%11 = OpCopyObject %ulong %12
|
||||
OpStore %4 %11
|
||||
%14 = OpLoad %ulong %3
|
||||
%13 = OpCopyObject %ulong %14
|
||||
OpStore %5 %13
|
||||
%16 = OpLoad %ulong %4
|
||||
%26 = OpConvertUToPtr %_ptr_CrossWorkgroup_uint %16
|
||||
%15 = OpLoad %uint %26
|
||||
OpStore %6 %15
|
||||
%18 = OpLoad %ulong %4
|
||||
%25 = OpIAdd %ulong %18 %ulong_4
|
||||
%27 = OpConvertUToPtr %_ptr_CrossWorkgroup_uint %25
|
||||
%17 = OpLoad %uint %27
|
||||
OpStore %7 %17
|
||||
%20 = OpLoad %uint %6
|
||||
%21 = OpLoad %uint %7
|
||||
%41 = OpSMulExtended %_struct_40 %20 %21
|
||||
%42 = OpCompositeExtract %uint %41 0
|
||||
%43 = OpCompositeExtract %uint %41 1
|
||||
%45 = OpCompositeConstruct %v2uint %42 %43
|
||||
%19 = OpBitcast %ulong %45
|
||||
OpStore %8 %19
|
||||
%22 = OpLoad %ulong %5
|
||||
%23 = OpLoad %ulong %8
|
||||
%28 = OpCopyObject %ulong %23
|
||||
%29 = OpConvertUToPtr %_ptr_Generic_ulong %22
|
||||
OpStore %29 %28
|
||||
OpReturn
|
||||
OpFunctionEnd
|
55
ptx/src/test/vectorAdd_11.ptx
Normal file
55
ptx/src/test/vectorAdd_11.ptx
Normal file
|
@ -0,0 +1,55 @@
|
|||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
.version 7.0
|
||||
.target sm_80
|
||||
.address_size 64
|
||||
|
||||
|
||||
|
||||
.visible .entry _Z9vectorAddPKfS0_Pfi(
|
||||
.param .u64 _Z9vectorAddPKfS0_Pfi_param_0,
|
||||
.param .u64 _Z9vectorAddPKfS0_Pfi_param_1,
|
||||
.param .u64 _Z9vectorAddPKfS0_Pfi_param_2,
|
||||
.param .u32 _Z9vectorAddPKfS0_Pfi_param_3
|
||||
)
|
||||
{
|
||||
.reg .pred %p<2>;
|
||||
.reg .f32 %f<4>;
|
||||
.reg .b32 %r<6>;
|
||||
.reg .b64 %rd<11>;
|
||||
|
||||
|
||||
ld.param.u64 %rd1, [_Z9vectorAddPKfS0_Pfi_param_0];
|
||||
ld.param.u64 %rd2, [_Z9vectorAddPKfS0_Pfi_param_1];
|
||||
ld.param.u64 %rd3, [_Z9vectorAddPKfS0_Pfi_param_2];
|
||||
ld.param.u32 %r2, [_Z9vectorAddPKfS0_Pfi_param_3];
|
||||
mov.u32 %r3, %ntid.x;
|
||||
mov.u32 %r4, %ctaid.x;
|
||||
mov.u32 %r5, %tid.x;
|
||||
mad.lo.s32 %r1, %r4, %r3, %r5;
|
||||
setp.ge.s32 %p1, %r1, %r2;
|
||||
@%p1 bra BB0_2;
|
||||
|
||||
cvta.to.global.u64 %rd4, %rd1;
|
||||
mul.wide.s32 %rd5, %r1, 4;
|
||||
add.s64 %rd6, %rd4, %rd5;
|
||||
cvta.to.global.u64 %rd7, %rd2;
|
||||
add.s64 %rd8, %rd7, %rd5;
|
||||
ld.global.f32 %f1, [%rd8];
|
||||
ld.global.f32 %f2, [%rd6];
|
||||
add.f32 %f3, %f2, %f1;
|
||||
cvta.to.global.u64 %rd9, %rd3;
|
||||
add.s64 %rd10, %rd9, %rd5;
|
||||
st.global.f32 [%rd10], %f3;
|
||||
|
||||
BB0_2:
|
||||
ret;
|
||||
}
|
||||
|
||||
|
|
@ -28,6 +28,7 @@ enum SpirvType {
|
|||
Array(SpirvScalarKey, u32),
|
||||
Pointer(Box<SpirvType>, spirv::StorageClass),
|
||||
Func(Option<Box<SpirvType>>, Vec<SpirvType>),
|
||||
Struct(Vec<SpirvScalarKey>),
|
||||
}
|
||||
|
||||
impl SpirvType {
|
||||
|
@ -174,6 +175,16 @@ impl TypeWordMap {
|
|||
.entry(t)
|
||||
.or_insert_with(|| b.type_function(out_t, in_t))
|
||||
}
|
||||
SpirvType::Struct(ref underlying) => {
|
||||
let underlying_ids = underlying
|
||||
.iter()
|
||||
.map(|t| self.get_or_add_spirv_scalar(b, *t))
|
||||
.collect::<Vec<_>>();
|
||||
*self
|
||||
.complex
|
||||
.entry(t)
|
||||
.or_insert_with(|| b.type_struct(underlying_ids))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -201,7 +212,9 @@ impl TypeWordMap {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn to_spirv_module<'a>(ast: ast::Module<'a>) -> Result<dr::Module, TranslateError> {
|
||||
pub fn to_spirv_module<'a>(
|
||||
ast: ast::Module<'a>,
|
||||
) -> Result<(dr::Module, HashMap<String, Vec<usize>>), TranslateError> {
|
||||
let mut id_defs = GlobalStringIdResolver::new(1);
|
||||
let ssa_functions = ast
|
||||
.functions
|
||||
|
@ -218,17 +231,24 @@ pub fn to_spirv_module<'a>(ast: ast::Module<'a>) -> Result<dr::Module, Translate
|
|||
emit_memory_model(&mut builder);
|
||||
let mut map = TypeWordMap::new(&mut builder);
|
||||
emit_builtins(&mut builder, &mut map, &id_defs);
|
||||
let mut args_len = HashMap::new();
|
||||
for f in ssa_functions {
|
||||
let f_body = match f.body {
|
||||
Some(f) => f,
|
||||
None => continue,
|
||||
};
|
||||
emit_function_body_ops(&mut builder, &mut map, opencl_id, &f.globals)?;
|
||||
emit_function_header(&mut builder, &mut map, &id_defs, f.func_directive)?;
|
||||
emit_function_header(
|
||||
&mut builder,
|
||||
&mut map,
|
||||
&id_defs,
|
||||
f.func_directive,
|
||||
&mut args_len,
|
||||
)?;
|
||||
emit_function_body_ops(&mut builder, &mut map, opencl_id, &f_body)?;
|
||||
builder.end_function()?;
|
||||
}
|
||||
Ok(builder.module())
|
||||
Ok((builder.module(), args_len))
|
||||
}
|
||||
|
||||
fn emit_builtins(
|
||||
|
@ -263,7 +283,12 @@ fn emit_function_header<'a>(
|
|||
map: &mut TypeWordMap,
|
||||
global: &GlobalStringIdResolver<'a>,
|
||||
func_directive: ast::MethodDecl<ExpandedArgParams>,
|
||||
all_args_lens: &mut HashMap<String, Vec<usize>>,
|
||||
) -> Result<(), TranslateError> {
|
||||
if let ast::MethodDecl::Kernel(name, args) = &func_directive {
|
||||
let args_lens = args.iter().map(|param| param.v_type.width()).collect();
|
||||
all_args_lens.insert(name.to_string(), args_lens);
|
||||
}
|
||||
let (ret_type, func_type) = get_function_type(builder, map, &func_directive);
|
||||
let fn_id = match func_directive {
|
||||
ast::MethodDecl::Kernel(name, _) => {
|
||||
|
@ -297,9 +322,11 @@ fn emit_function_header<'a>(
|
|||
Ok(())
|
||||
}
|
||||
|
||||
pub fn to_spirv<'a>(ast: ast::Module<'a>) -> Result<Vec<u32>, TranslateError> {
|
||||
let module = to_spirv_module(ast)?;
|
||||
Ok(module.assemble())
|
||||
pub fn to_spirv<'a>(
|
||||
ast: ast::Module<'a>,
|
||||
) -> Result<(Vec<u32>, HashMap<String, Vec<usize>>), TranslateError> {
|
||||
let (module, all_args_lens) = to_spirv_module(ast)?;
|
||||
Ok((module.assemble(), all_args_lens))
|
||||
}
|
||||
|
||||
fn emit_capabilities(builder: &mut dr::Builder) {
|
||||
|
@ -905,7 +932,7 @@ impl<'a, 'b> ArgumentMapVisitor<NormalizedArgParams, ExpandedArgParams>
|
|||
ArgumentSemantics::PhysicalPointer => {
|
||||
let scalar_t = ast::ScalarType::U64;
|
||||
let id_constant_stmt = self.id_def.new_id(ast::Type::Scalar(scalar_t));
|
||||
let result_id = self.id_def.new_id(typ);
|
||||
let result_id = self.id_def.new_id(ast::Type::Scalar(scalar_t));
|
||||
self.func.push(Statement::Constant(ConstantDefinition {
|
||||
dst: id_constant_stmt,
|
||||
typ: scalar_t,
|
||||
|
@ -1314,8 +1341,8 @@ fn emit_function_body_ops(
|
|||
let type_pred = map.get_or_add_scalar(builder, ast::ScalarType::Pred);
|
||||
let const_true = builder.constant_true(type_pred);
|
||||
let const_false = builder.constant_false(type_pred);
|
||||
builder.select(result_type, result_id, operand, const_false, const_true)
|
||||
},
|
||||
builder.select(result_type, result_id, operand, const_false, const_true)
|
||||
}
|
||||
_ => builder.not(result_type, result_id, operand),
|
||||
}?;
|
||||
}
|
||||
|
@ -1359,6 +1386,12 @@ fn emit_function_body_ops(
|
|||
builder.copy_object(result_type, Some(*dst), *src)?;
|
||||
}
|
||||
},
|
||||
ast::Instruction::Mad(mad, arg) => match mad {
|
||||
ast::MulDetails::Int(ref desc) => {
|
||||
emit_mad_int(builder, map, opencl, desc, arg)?
|
||||
}
|
||||
ast::MulDetails::Float(desc) => emit_mad_float(builder, map, desc, arg)?,
|
||||
},
|
||||
},
|
||||
Statement::LoadVar(arg, typ) => {
|
||||
let type_id = map.get_or_add(builder, SpirvType::from(*typ));
|
||||
|
@ -1385,6 +1418,47 @@ fn emit_function_body_ops(
|
|||
Ok(())
|
||||
}
|
||||
|
||||
fn emit_mad_int(
|
||||
builder: &mut dr::Builder,
|
||||
map: &mut TypeWordMap,
|
||||
opencl: spirv::Word,
|
||||
desc: &ast::MulIntDesc,
|
||||
arg: &ast::Arg4<ExpandedArgParams>,
|
||||
) -> Result<(), dr::Error> {
|
||||
let inst_type = map.get_or_add(builder, SpirvType::from(ast::ScalarType::from(desc.typ)));
|
||||
match desc.control {
|
||||
ast::MulIntControl::Low => {
|
||||
let mul_result = builder.i_mul(inst_type, None, arg.src1, arg.src2)?;
|
||||
builder.i_add(inst_type, Some(arg.dst), arg.src3, mul_result)?;
|
||||
}
|
||||
ast::MulIntControl::High => {
|
||||
let cl_op = if desc.typ.is_signed() {
|
||||
spirv::CLOp::s_mad_hi
|
||||
} else {
|
||||
spirv::CLOp::u_mad_hi
|
||||
};
|
||||
builder.ext_inst(
|
||||
inst_type,
|
||||
Some(arg.dst),
|
||||
opencl,
|
||||
cl_op as spirv::Word,
|
||||
[arg.src1, arg.src2, arg.src3],
|
||||
)?;
|
||||
}
|
||||
ast::MulIntControl::Wide => todo!(),
|
||||
};
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn emit_mad_float(
|
||||
builder: &mut dr::Builder,
|
||||
map: &mut TypeWordMap,
|
||||
desc: &ast::MulFloatDesc,
|
||||
arg: &ast::Arg4<ExpandedArgParams>,
|
||||
) -> Result<(), dr::Error> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn emit_add_float(
|
||||
builder: &mut dr::Builder,
|
||||
map: &mut TypeWordMap,
|
||||
|
@ -1529,7 +1603,7 @@ fn emit_setp(
|
|||
builder: &mut dr::Builder,
|
||||
map: &mut TypeWordMap,
|
||||
setp: &ast::SetpData,
|
||||
arg: &ast::Arg4<ExpandedArgParams>,
|
||||
arg: &ast::Arg4Setp<ExpandedArgParams>,
|
||||
) -> Result<(), dr::Error> {
|
||||
if setp.flush_to_zero {
|
||||
todo!()
|
||||
|
@ -1607,6 +1681,7 @@ fn emit_mul_int(
|
|||
desc: &ast::MulIntDesc,
|
||||
arg: &ast::Arg3<ExpandedArgParams>,
|
||||
) -> Result<(), dr::Error> {
|
||||
let instruction_type = ast::ScalarType::from(desc.typ);
|
||||
let inst_type = map.get_or_add(builder, SpirvType::from(ast::ScalarType::from(desc.typ)));
|
||||
match desc.control {
|
||||
ast::MulIntControl::Low => {
|
||||
|
@ -1626,11 +1701,53 @@ fn emit_mul_int(
|
|||
[arg.src1, arg.src2],
|
||||
)?;
|
||||
}
|
||||
ast::MulIntControl::Wide => todo!(),
|
||||
ast::MulIntControl::Wide => {
|
||||
let mul_ext_type = SpirvType::Struct(vec![
|
||||
SpirvScalarKey::from(instruction_type),
|
||||
SpirvScalarKey::from(instruction_type),
|
||||
]);
|
||||
let mul_ext_type_id = map.get_or_add(builder, mul_ext_type);
|
||||
let mul = if desc.typ.is_signed() {
|
||||
builder.s_mul_extended(mul_ext_type_id, None, arg.src1, arg.src2)?
|
||||
} else {
|
||||
builder.u_mul_extended(mul_ext_type_id, None, arg.src1, arg.src2)?
|
||||
};
|
||||
let instr_width = instruction_type.width();
|
||||
let instr_kind = instruction_type.kind();
|
||||
let dst_type = ast::ScalarType::from_parts(instr_width * 2, instr_kind);
|
||||
let dst_type_id = map.get_or_add_scalar(builder, dst_type);
|
||||
struct2_bitcast_to_wide(
|
||||
builder,
|
||||
map,
|
||||
SpirvScalarKey::from(instruction_type),
|
||||
inst_type,
|
||||
arg.dst,
|
||||
dst_type_id,
|
||||
mul,
|
||||
)?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// Surprisingly, structs can't be bitcast, so we route everything through a vector
|
||||
fn struct2_bitcast_to_wide(
|
||||
builder: &mut dr::Builder,
|
||||
map: &mut TypeWordMap,
|
||||
base_type_key: SpirvScalarKey,
|
||||
instruction_type: spirv::Word,
|
||||
dst: spirv::Word,
|
||||
dst_type_id: spirv::Word,
|
||||
src: spirv::Word,
|
||||
) -> Result<(), dr::Error> {
|
||||
let low_bits = builder.composite_extract(instruction_type, None, src, [0])?;
|
||||
let high_bits = builder.composite_extract(instruction_type, None, src, [1])?;
|
||||
let vector_type = map.get_or_add(builder, SpirvType::Vector(base_type_key, 2));
|
||||
let vector = builder.composite_construct(vector_type, None, [low_bits, high_bits])?;
|
||||
builder.bitcast(dst_type_id, Some(dst), vector)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn emit_abs(
|
||||
builder: &mut dr::Builder,
|
||||
map: &mut TypeWordMap,
|
||||
|
@ -1844,8 +1961,8 @@ impl PtxSpecialRegister {
|
|||
|
||||
fn get_builtin(self) -> spirv::BuiltIn {
|
||||
match self {
|
||||
PtxSpecialRegister::Tid => spirv::BuiltIn::GlobalInvocationId,
|
||||
PtxSpecialRegister::Ntid => spirv::BuiltIn::GlobalSize,
|
||||
PtxSpecialRegister::Tid => spirv::BuiltIn::LocalInvocationId,
|
||||
PtxSpecialRegister::Ntid => spirv::BuiltIn::WorkgroupSize,
|
||||
PtxSpecialRegister::Ctaid => spirv::BuiltIn::WorkgroupId,
|
||||
PtxSpecialRegister::Nctaid => spirv::BuiltIn::NumWorkgroups,
|
||||
}
|
||||
|
@ -2492,6 +2609,10 @@ impl<T: ArgParamsEx> ast::Instruction<T> {
|
|||
let inst_type = ast::Type::Scalar(ast::ScalarType::B64);
|
||||
ast::Instruction::Cvta(d, a.map(visitor, false, inst_type)?)
|
||||
}
|
||||
ast::Instruction::Mad(d, a) => {
|
||||
let inst_type = d.get_type();
|
||||
ast::Instruction::Mad(d, a.map(visitor, inst_type)?)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -2641,7 +2762,8 @@ impl ast::Instruction<ExpandedArgParams> {
|
|||
| ast::Instruction::St(_, _)
|
||||
| ast::Instruction::Ret(_)
|
||||
| ast::Instruction::Abs(_, _)
|
||||
| ast::Instruction::Call(_) => None,
|
||||
| ast::Instruction::Call(_)
|
||||
| ast::Instruction::Mad(_, _) => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -2741,6 +2863,17 @@ impl<'a> ast::Instruction<ast::ParsedArgParams<'a>> {
|
|||
}
|
||||
}
|
||||
|
||||
impl ast::VariableParamType {
|
||||
fn width(self) -> usize {
|
||||
match self {
|
||||
ast::VariableParamType::Scalar(t) => ast::ScalarType::from(t).width() as usize,
|
||||
ast::VariableParamType::Array(t, len) => {
|
||||
(ast::ScalarType::from(t).width() as usize) * (len as usize)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: ArgParamsEx> ast::Arg1<T> {
|
||||
fn map<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
|
||||
self,
|
||||
|
@ -3042,6 +3175,53 @@ impl<T: ArgParamsEx> ast::Arg4<T> {
|
|||
visitor: &mut V,
|
||||
t: ast::Type,
|
||||
) -> Result<ast::Arg4<U>, TranslateError> {
|
||||
let dst = visitor.variable(
|
||||
ArgumentDescriptor {
|
||||
op: self.dst,
|
||||
is_dst: true,
|
||||
sema: ArgumentSemantics::Default,
|
||||
},
|
||||
Some(t),
|
||||
)?;
|
||||
let src1 = visitor.operand(
|
||||
ArgumentDescriptor {
|
||||
op: self.src1,
|
||||
is_dst: false,
|
||||
sema: ArgumentSemantics::Default,
|
||||
},
|
||||
t,
|
||||
)?;
|
||||
let src2 = visitor.operand(
|
||||
ArgumentDescriptor {
|
||||
op: self.src2,
|
||||
is_dst: false,
|
||||
sema: ArgumentSemantics::Default,
|
||||
},
|
||||
t,
|
||||
)?;
|
||||
let src3 = visitor.operand(
|
||||
ArgumentDescriptor {
|
||||
op: self.src3,
|
||||
is_dst: false,
|
||||
sema: ArgumentSemantics::Default,
|
||||
},
|
||||
t,
|
||||
)?;
|
||||
Ok(ast::Arg4 {
|
||||
dst,
|
||||
src1,
|
||||
src2,
|
||||
src3,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: ArgParamsEx> ast::Arg4Setp<T> {
|
||||
fn map<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
|
||||
self,
|
||||
visitor: &mut V,
|
||||
t: ast::Type,
|
||||
) -> Result<ast::Arg4Setp<U>, TranslateError> {
|
||||
let dst1 = visitor.variable(
|
||||
ArgumentDescriptor {
|
||||
op: self.dst1,
|
||||
|
@ -3079,7 +3259,7 @@ impl<T: ArgParamsEx> ast::Arg4<T> {
|
|||
},
|
||||
t,
|
||||
)?;
|
||||
Ok(ast::Arg4 {
|
||||
Ok(ast::Arg4Setp {
|
||||
dst1,
|
||||
dst2,
|
||||
src1,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue