mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-04-28 13:37:57 +03:00
Add sub, min, max
This commit is contained in:
parent
bd3d440dba
commit
9a65dd32f5
12 changed files with 820 additions and 181 deletions
|
@ -241,6 +241,10 @@ sub_scalar_type!(IntType {
|
|||
S64
|
||||
});
|
||||
|
||||
sub_scalar_type!(UIntType { U8, U16, U32, U64 });
|
||||
|
||||
sub_scalar_type!(SIntType { S8, S16, S32, S64 });
|
||||
|
||||
impl IntType {
|
||||
pub fn is_signed(self) -> bool {
|
||||
match self {
|
||||
|
@ -331,7 +335,7 @@ pub enum Instruction<P: ArgParams> {
|
|||
Ld(LdDetails, Arg2Ld<P>),
|
||||
Mov(MovDetails, Arg2Mov<P>),
|
||||
Mul(MulDetails, Arg3<P>),
|
||||
Add(AddDetails, Arg3<P>),
|
||||
Add(ArithDetails, Arg3<P>),
|
||||
Setp(SetpData, Arg4Setp<P>),
|
||||
SetpBool(SetpBoolData, Arg5<P>),
|
||||
Not(NotType, Arg2<P>),
|
||||
|
@ -346,6 +350,9 @@ pub enum Instruction<P: ArgParams> {
|
|||
Abs(AbsDetails, Arg2<P>),
|
||||
Mad(MulDetails, Arg4<P>),
|
||||
Or(OrType, Arg3<P>),
|
||||
Sub(ArithDetails, Arg3<P>),
|
||||
Min(MinMaxDetails, Arg3<P>),
|
||||
Max(MinMaxDetails, Arg3<P>),
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone)]
|
||||
|
@ -554,11 +561,6 @@ impl MovDetails {
|
|||
}
|
||||
}
|
||||
|
||||
pub enum MulDetails {
|
||||
Int(MulIntDesc),
|
||||
Float(MulFloatDesc),
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone)]
|
||||
pub struct MulIntDesc {
|
||||
pub typ: IntType,
|
||||
|
@ -572,14 +574,6 @@ pub enum MulIntControl {
|
|||
Wide,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone)]
|
||||
pub struct MulFloatDesc {
|
||||
pub typ: FloatType,
|
||||
pub rounding: Option<RoundingMode>,
|
||||
pub flush_to_zero: bool,
|
||||
pub saturate: bool,
|
||||
}
|
||||
|
||||
#[derive(PartialEq, Eq, Copy, Clone)]
|
||||
pub enum RoundingMode {
|
||||
NearestEven,
|
||||
|
@ -588,23 +582,11 @@ pub enum RoundingMode {
|
|||
PositiveInf,
|
||||
}
|
||||
|
||||
pub enum AddDetails {
|
||||
Int(AddIntDesc),
|
||||
Float(AddFloatDesc),
|
||||
}
|
||||
|
||||
pub struct AddIntDesc {
|
||||
pub typ: IntType,
|
||||
pub saturate: bool,
|
||||
}
|
||||
|
||||
pub struct AddFloatDesc {
|
||||
pub typ: FloatType,
|
||||
pub rounding: Option<RoundingMode>,
|
||||
pub flush_to_zero: bool,
|
||||
pub saturate: bool,
|
||||
}
|
||||
|
||||
pub struct SetpData {
|
||||
pub typ: ScalarType,
|
||||
pub flush_to_zero: bool,
|
||||
|
@ -810,3 +792,57 @@ sub_scalar_type!(OrType {
|
|||
B32,
|
||||
B64,
|
||||
});
|
||||
|
||||
/// Type-dependent modifiers of a multiply. This is the payload carried by
/// both `Instruction::Mul` and `Instruction::Mad`.
#[derive(Copy, Clone)]
|
||||
pub enum MulDetails {
|
||||
/// Unsigned integer operands together with the integer multiply mode.
Unsigned(MulUInt),
|
||||
/// Signed integer operands together with the integer multiply mode.
Signed(MulSInt),
|
||||
/// Floating-point operands; the same modifier set is used by `ArithDetails::Float`.
Float(ArithFloat),
|
||||
}
|
||||
|
||||
/// Modifiers of an unsigned integer multiply (see `MulDetails::Unsigned`).
#[derive(Copy, Clone)]
|
||||
pub struct MulUInt {
|
||||
/// Unsigned operand type.
pub typ: UIntType,
|
||||
/// Integer multiply mode, a `MulIntControl` value (e.g. `High` or `Wide`).
pub control: MulIntControl,
|
||||
}
|
||||
|
||||
/// Modifiers of a signed integer multiply (see `MulDetails::Signed`).
#[derive(Copy, Clone)]
|
||||
pub struct MulSInt {
|
||||
/// Signed operand type.
pub typ: SIntType,
|
||||
/// Integer multiply mode, a `MulIntControl` value (e.g. `High` or `Wide`).
pub control: MulIntControl,
|
||||
}
|
||||
|
||||
/// Type-dependent modifiers shared by `Instruction::Add` and `Instruction::Sub`.
#[derive(Copy, Clone)]
|
||||
pub enum ArithDetails {
|
||||
/// Unsigned integer operands; this case carries no saturation flag.
Unsigned(UIntType),
|
||||
/// Signed integer operands, with an optional saturation flag.
Signed(ArithSInt),
|
||||
/// Floating-point operands with rounding / flush-to-zero / saturation modifiers.
Float(ArithFloat),
|
||||
}
|
||||
|
||||
/// Modifiers of a signed integer add/sub (see `ArithDetails::Signed`).
#[derive(Copy, Clone)]
|
||||
pub struct ArithSInt {
|
||||
/// Signed operand type.
pub typ: SIntType,
|
||||
/// Whether the saturating (`.sat`) form of the instruction was requested.
pub saturate: bool,
|
||||
}
|
||||
|
||||
/// Floating-point modifiers used by both `MulDetails::Float` and
/// `ArithDetails::Float`.
#[derive(Copy, Clone)]
|
||||
pub struct ArithFloat {
|
||||
/// Floating-point operand type.
pub typ: FloatType,
|
||||
/// Explicitly requested rounding mode, or `None` when none was given.
pub rounding: Option<RoundingMode>,
|
||||
/// Whether the `.ftz` modifier was present.
pub flush_to_zero: bool,
|
||||
/// Whether the `.sat` modifier was present.
pub saturate: bool,
|
||||
}
|
||||
|
||||
/// Type-dependent modifiers shared by `Instruction::Min` and `Instruction::Max`.
#[derive(Copy, Clone)]
|
||||
pub enum MinMaxDetails {
|
||||
/// Signed integer operands.
Signed(SIntType),
|
||||
/// Unsigned integer operands.
Unsigned(UIntType),
|
||||
/// Floating-point operands with their `.ftz` / `.NaN` flags.
Float(MinMaxFloat),
|
||||
}
|
||||
|
||||
/// Floating-point modifiers of a `min`/`max` (see `MinMaxDetails::Float`).
#[derive(Copy, Clone)]
|
||||
pub struct MinMaxFloat {
|
||||
/// Whether the `.ftz` modifier was present.
pub ftz: bool,
|
||||
/// Whether the `.NaN` modifier was present.
pub nan: bool,
|
||||
/// Floating-point operand type.
pub typ: FloatType,
|
||||
}
|
||||
|
|
|
@ -70,6 +70,7 @@ match {
|
|||
".ltu",
|
||||
".lu",
|
||||
".nan",
|
||||
".NaN",
|
||||
".ne",
|
||||
".neu",
|
||||
".num",
|
||||
|
@ -124,6 +125,8 @@ match {
|
|||
"ld",
|
||||
"mad",
|
||||
"map_f64_to_f32",
|
||||
"max",
|
||||
"min",
|
||||
"mov",
|
||||
"mul",
|
||||
"not",
|
||||
|
@ -134,6 +137,7 @@ match {
|
|||
"shr",
|
||||
r"sm_[0-9]+" => ShaderModel,
|
||||
"st",
|
||||
"sub",
|
||||
"texmode_independent",
|
||||
"texmode_unified",
|
||||
} else {
|
||||
|
@ -153,6 +157,8 @@ ExtendedID : &'input str = {
|
|||
"ld",
|
||||
"mad",
|
||||
"map_f64_to_f32",
|
||||
"max",
|
||||
"min",
|
||||
"mov",
|
||||
"mul",
|
||||
"not",
|
||||
|
@ -163,6 +169,7 @@ ExtendedID : &'input str = {
|
|||
"shr",
|
||||
ShaderModel,
|
||||
"st",
|
||||
"sub",
|
||||
"texmode_independent",
|
||||
"texmode_unified",
|
||||
ID
|
||||
|
@ -448,7 +455,10 @@ Instruction: ast::Instruction<ast::ParsedArgParams<'input>> = {
|
|||
InstCall,
|
||||
InstAbs,
|
||||
InstMad,
|
||||
InstOr
|
||||
InstOr,
|
||||
InstSub,
|
||||
InstMin,
|
||||
InstMax,
|
||||
};
|
||||
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-ld
|
||||
|
@ -570,38 +580,19 @@ MovVectorType: ast::ScalarType = {
|
|||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-mul
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-mul
|
||||
InstMul: ast::Instruction<ast::ParsedArgParams<'input>> = {
|
||||
"mul" <d:InstMulMode> <a:Arg3> => ast::Instruction::Mul(d, a)
|
||||
"mul" <d:MulDetails> <a:Arg3> => ast::Instruction::Mul(d, a)
|
||||
};
|
||||
|
||||
InstMulMode: ast::MulDetails = {
|
||||
<ctr:MulIntControl> <t:IntType> => ast::MulDetails::Int(ast::MulIntDesc {
|
||||
MulDetails: ast::MulDetails = {
|
||||
<ctr:MulIntControl> <t:UIntType> => ast::MulDetails::Unsigned(ast::MulUInt{
|
||||
typ: t,
|
||||
control: ctr
|
||||
}),
|
||||
<r:RoundingModeFloat?> <ftz:".ftz"?> <s:".sat"?> ".f32" => ast::MulDetails::Float(ast::MulFloatDesc {
|
||||
typ: ast::FloatType::F32,
|
||||
rounding: r,
|
||||
flush_to_zero: ftz.is_some(),
|
||||
saturate: s.is_some()
|
||||
<ctr:MulIntControl> <t:SIntType> => ast::MulDetails::Signed(ast::MulSInt{
|
||||
typ: t,
|
||||
control: ctr
|
||||
}),
|
||||
<r:RoundingModeFloat?> ".f64" => ast::MulDetails::Float(ast::MulFloatDesc {
|
||||
typ: ast::FloatType::F64,
|
||||
rounding: r,
|
||||
flush_to_zero: false,
|
||||
saturate: false
|
||||
}),
|
||||
<r:".rn"?> <ftz:".ftz"?> <s:".sat"?> ".f16" => ast::MulDetails::Float(ast::MulFloatDesc {
|
||||
typ: ast::FloatType::F16,
|
||||
rounding: r.map(|_| ast::RoundingMode::NearestEven),
|
||||
flush_to_zero: ftz.is_some(),
|
||||
saturate: s.is_some()
|
||||
}),
|
||||
<r:".rn"?> <ftz:".ftz"?> <s:".sat"?> ".f16x2" => ast::MulDetails::Float(ast::MulFloatDesc {
|
||||
typ: ast::FloatType::F16x2,
|
||||
rounding: r.map(|_| ast::RoundingMode::NearestEven),
|
||||
flush_to_zero: ftz.is_some(),
|
||||
saturate: s.is_some()
|
||||
})
|
||||
<f:ArithFloat> => ast::MulDetails::Float(f)
|
||||
};
|
||||
|
||||
MulIntControl: ast::MulIntControl = {
|
||||
|
@ -634,41 +625,23 @@ IntType : ast::IntType = {
|
|||
".s64" => ast::IntType::S64,
|
||||
};
|
||||
|
||||
// Unsigned integer type suffixes. Only the 16/32/64-bit forms are accepted
// here, although `ast::UIntType` also declares `U8`.
UIntType: ast::UIntType = {
|
||||
".u16" => ast::UIntType::U16,
|
||||
".u32" => ast::UIntType::U32,
|
||||
".u64" => ast::UIntType::U64,
|
||||
};
|
||||
|
||||
// Signed integer type suffixes. Only the 16/32/64-bit forms are accepted
// here, although `ast::SIntType` also declares `S8`.
SIntType: ast::SIntType = {
|
||||
".s16" => ast::SIntType::S16,
|
||||
".s32" => ast::SIntType::S32,
|
||||
".s64" => ast::SIntType::S64,
|
||||
};
|
||||
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-add
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-add
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-add
|
||||
InstAdd: ast::Instruction<ast::ParsedArgParams<'input>> = {
|
||||
"add" <d:InstAddMode> <a:Arg3> => ast::Instruction::Add(d, a)
|
||||
};
|
||||
|
||||
InstAddMode: ast::AddDetails = {
|
||||
<t:IntType> => ast::AddDetails::Int(ast::AddIntDesc {
|
||||
typ: t,
|
||||
saturate: false,
|
||||
}),
|
||||
".sat" ".s32" => ast::AddDetails::Int(ast::AddIntDesc {
|
||||
typ: ast::IntType::S32,
|
||||
saturate: true,
|
||||
}),
|
||||
<rn:RoundingModeFloat?> <ftz:".ftz"?> <sat:".sat"?> ".f32" => ast::AddDetails::Float(ast::AddFloatDesc {
|
||||
typ: ast::FloatType::F32,
|
||||
rounding: rn,
|
||||
flush_to_zero: ftz.is_some(),
|
||||
saturate: sat.is_some(),
|
||||
}),
|
||||
<rn:RoundingModeFloat?> ".f64" => ast::AddDetails::Float(ast::AddFloatDesc {
|
||||
typ: ast::FloatType::F64,
|
||||
rounding: rn,
|
||||
flush_to_zero: false,
|
||||
saturate: false,
|
||||
}),
|
||||
<rn:".rn"?> <ftz:".ftz"?> <sat:".sat"?>".f16" => ast::AddDetails::Float(ast::AddFloatDesc {
|
||||
typ: ast::FloatType::F16,
|
||||
rounding: rn.map(|_| ast::RoundingMode::NearestEven),
|
||||
flush_to_zero: ftz.is_some(),
|
||||
saturate: sat.is_some(),
|
||||
}),
|
||||
".rn"? ".ftz"? ".sat"? ".f16x2" => todo!()
|
||||
"add" <d:ArithDetails> <a:Arg3> => ast::Instruction::Add(d, a)
|
||||
};
|
||||
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#comparison-and-selection-instructions-setp
|
||||
|
@ -1041,7 +1014,7 @@ InstAbs: ast::Instruction<ast::ParsedArgParams<'input>> = {
|
|||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-mad
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-mad
|
||||
InstMad: ast::Instruction<ast::ParsedArgParams<'input>> = {
|
||||
"mad" <d:InstMulMode> <a:Arg4> => ast::Instruction::Mad(d, a),
|
||||
"mad" <d:MulDetails> <a:Arg4> => ast::Instruction::Mad(d, a),
|
||||
"mad" ".hi" ".sat" ".s32" => todo!()
|
||||
};
|
||||
|
||||
|
@ -1063,6 +1036,84 @@ OrType: ast::OrType = {
|
|||
".b64" => ast::OrType::B64,
|
||||
}
|
||||
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-sub
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-sub
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-sub
|
||||
// `sub` uses the same type/modifier rule (`ArithDetails`) as `add`,
// followed by an `Arg3` operand list.
InstSub: ast::Instruction<ast::ParsedArgParams<'input>> = {
|
||||
"sub" <d:ArithDetails> <a:Arg3> => ast::Instruction::Sub(d, a),
|
||||
};
|
||||
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-min
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-min
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-min
|
||||
// `min` followed by the type/modifier suffix shared with `max`
// (`MinMaxDetails`) and an `Arg3` operand list.
InstMin: ast::Instruction<ast::ParsedArgParams<'input>> = {
|
||||
"min" <d:MinMaxDetails> <a:Arg3> => ast::Instruction::Min(d, a),
|
||||
};
|
||||
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-max
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-max
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-max
|
||||
// `max` followed by the type/modifier suffix shared with `min`
// (`MinMaxDetails`) and an `Arg3` operand list.
InstMax: ast::Instruction<ast::ParsedArgParams<'input>> = {
|
||||
"max" <d:MinMaxDetails> <a:Arg3> => ast::Instruction::Max(d, a),
|
||||
};
|
||||
|
||||
// Type suffix of `min`/`max`. Integer forms take no modifiers; the float
// forms other than `.f64` accept optional `.ftz` and `.NaN` (in that order).
MinMaxDetails: ast::MinMaxDetails = {
|
||||
<t:UIntType> => ast::MinMaxDetails::Unsigned(t),
|
||||
<t:SIntType> => ast::MinMaxDetails::Signed(t),
|
||||
<ftz:".ftz"?> <nan:".NaN"?> ".f32" => ast::MinMaxDetails::Float(
|
||||
ast::MinMaxFloat{ ftz: ftz.is_some(), nan: nan.is_some(), typ: ast::FloatType::F32 }
|
||||
),
|
||||
// `.f64` is parsed without any modifier, so both flags are fixed to false.
".f64" => ast::MinMaxDetails::Float(
|
||||
ast::MinMaxFloat{ ftz: false, nan: false, typ: ast::FloatType::F64 }
|
||||
),
|
||||
<ftz:".ftz"?> <nan:".NaN"?> ".f16" => ast::MinMaxDetails::Float(
|
||||
ast::MinMaxFloat{ ftz: ftz.is_some(), nan: nan.is_some(), typ: ast::FloatType::F16 }
|
||||
),
|
||||
<ftz:".ftz"?> <nan:".NaN"?> ".f16x2" => ast::MinMaxDetails::Float(
|
||||
ast::MinMaxFloat{ ftz: ftz.is_some(), nan: nan.is_some(), typ: ast::FloatType::F16x2 }
|
||||
)
|
||||
}
|
||||
|
||||
// Type suffix of `add`/`sub`. Among the integer forms, `.sat` is accepted
// only in front of `.s32`; every other integer form is non-saturating.
ArithDetails: ast::ArithDetails = {
|
||||
<t:UIntType> => ast::ArithDetails::Unsigned(t),
|
||||
<t:SIntType> => ast::ArithDetails::Signed(ast::ArithSInt {
|
||||
typ: t,
|
||||
saturate: false,
|
||||
}),
|
||||
// The only saturating integer form: `.sat.s32`.
".sat" ".s32" => ast::ArithDetails::Signed(ast::ArithSInt {
|
||||
typ: ast::SIntType::S32,
|
||||
saturate: true,
|
||||
}),
|
||||
<f:ArithFloat> => ast::ArithDetails::Float(f)
|
||||
}
|
||||
|
||||
// Floating-point suffix shared by the `ArithDetails` and `MulDetails` rules.
// Modifiers must appear in the order: rounding, `.ftz`, `.sat`, type.
ArithFloat: ast::ArithFloat = {
|
||||
<rn:RoundingModeFloat?> <ftz:".ftz"?> <sat:".sat"?> ".f32" => ast::ArithFloat {
|
||||
typ: ast::FloatType::F32,
|
||||
rounding: rn,
|
||||
flush_to_zero: ftz.is_some(),
|
||||
saturate: sat.is_some(),
|
||||
},
|
||||
// `.f64` accepts a rounding mode but neither `.ftz` nor `.sat`.
<rn:RoundingModeFloat?> ".f64" => ast::ArithFloat {
|
||||
typ: ast::FloatType::F64,
|
||||
rounding: rn,
|
||||
flush_to_zero: false,
|
||||
saturate: false,
|
||||
},
|
||||
// For the half-precision forms the only rounding modifier accepted is `.rn`,
// which is recorded as `RoundingMode::NearestEven`.
<rn:".rn"?> <ftz:".ftz"?> <sat:".sat"?> ".f16" => ast::ArithFloat {
|
||||
typ: ast::FloatType::F16,
|
||||
rounding: rn.map(|_| ast::RoundingMode::NearestEven),
|
||||
flush_to_zero: ftz.is_some(),
|
||||
saturate: sat.is_some(),
|
||||
},
|
||||
<rn:".rn"?> <ftz:".ftz"?> <sat:".sat"?> ".f16x2" => ast::ArithFloat {
|
||||
typ: ast::FloatType::F16x2,
|
||||
rounding: rn.map(|_| ast::RoundingMode::NearestEven),
|
||||
flush_to_zero: ftz.is_some(),
|
||||
saturate: sat.is_some(),
|
||||
},
|
||||
}
|
||||
|
||||
Operand: ast::Operand<&'input str> = {
|
||||
<r:ExtendedID> => ast::Operand::Reg(r),
|
||||
<r:ExtendedID> "+" <o:Num> => {
|
||||
|
|
23
ptx/src/test/spirv_run/max.ptx
Normal file
23
ptx/src/test/spirv_run/max.ptx
Normal file
|
@ -0,0 +1,23 @@
|
|||
.version 6.5
|
||||
.target sm_30
|
||||
.address_size 64
|
||||
|
||||
.visible .entry max(
|
||||
.param .u64 input,
|
||||
.param .u64 output
|
||||
)
|
||||
{
|
||||
.reg .u64 in_addr;
|
||||
.reg .u64 out_addr;
|
||||
.reg .s32 temp1;
|
||||
.reg .s32 temp2;
|
||||
|
||||
ld.param.u64 in_addr, [input];
|
||||
ld.param.u64 out_addr, [output];
|
||||
|
||||
ld.s32 temp1, [in_addr];
|
||||
ld.s32 temp2, [in_addr+4];
|
||||
max.s32 temp1, temp1, temp2;
|
||||
st.s32 [out_addr], temp1;
|
||||
ret;
|
||||
}
|
57
ptx/src/test/spirv_run/max.spvtxt
Normal file
57
ptx/src/test/spirv_run/max.spvtxt
Normal file
|
@ -0,0 +1,57 @@
|
|||
OpCapability GenericPointer
|
||||
OpCapability Linkage
|
||||
OpCapability Addresses
|
||||
OpCapability Kernel
|
||||
OpCapability Int8
|
||||
OpCapability Int16
|
||||
OpCapability Int64
|
||||
OpCapability Float16
|
||||
OpCapability Float64
|
||||
%30 = OpExtInstImport "OpenCL.std"
|
||||
OpMemoryModel Physical64 OpenCL
|
||||
OpEntryPoint Kernel %1 "max"
|
||||
%void = OpTypeVoid
|
||||
%ulong = OpTypeInt 64 0
|
||||
%33 = OpTypeFunction %void %ulong %ulong
|
||||
%_ptr_Function_ulong = OpTypePointer Function %ulong
|
||||
%uint = OpTypeInt 32 0
|
||||
%_ptr_Function_uint = OpTypePointer Function %uint
|
||||
%_ptr_Generic_uint = OpTypePointer Generic %uint
|
||||
%ulong_4 = OpConstant %ulong 4
|
||||
%1 = OpFunction %void None %33
|
||||
%8 = OpFunctionParameter %ulong
|
||||
%9 = OpFunctionParameter %ulong
|
||||
%28 = OpLabel
|
||||
%2 = OpVariable %_ptr_Function_ulong Function
|
||||
%3 = OpVariable %_ptr_Function_ulong Function
|
||||
%4 = OpVariable %_ptr_Function_ulong Function
|
||||
%5 = OpVariable %_ptr_Function_ulong Function
|
||||
%6 = OpVariable %_ptr_Function_uint Function
|
||||
%7 = OpVariable %_ptr_Function_uint Function
|
||||
OpStore %2 %8
|
||||
OpStore %3 %9
|
||||
%11 = OpLoad %ulong %2
|
||||
%10 = OpCopyObject %ulong %11
|
||||
OpStore %4 %10
|
||||
%13 = OpLoad %ulong %3
|
||||
%12 = OpCopyObject %ulong %13
|
||||
OpStore %5 %12
|
||||
%15 = OpLoad %ulong %4
|
||||
%25 = OpConvertUToPtr %_ptr_Generic_uint %15
|
||||
%14 = OpLoad %uint %25
|
||||
OpStore %6 %14
|
||||
%17 = OpLoad %ulong %4
|
||||
%24 = OpIAdd %ulong %17 %ulong_4
|
||||
%26 = OpConvertUToPtr %_ptr_Generic_uint %24
|
||||
%16 = OpLoad %uint %26
|
||||
OpStore %7 %16
|
||||
%19 = OpLoad %uint %6
|
||||
%20 = OpLoad %uint %7
|
||||
%18 = OpExtInst %uint %30 s_max %19 %20
|
||||
OpStore %6 %18
|
||||
%21 = OpLoad %ulong %5
|
||||
%22 = OpLoad %uint %6
|
||||
%27 = OpConvertUToPtr %_ptr_Generic_uint %21
|
||||
OpStore %27 %22
|
||||
OpReturn
|
||||
OpFunctionEnd
|
23
ptx/src/test/spirv_run/min.ptx
Normal file
23
ptx/src/test/spirv_run/min.ptx
Normal file
|
@ -0,0 +1,23 @@
|
|||
.version 6.5
|
||||
.target sm_30
|
||||
.address_size 64
|
||||
|
||||
.visible .entry min(
|
||||
.param .u64 input,
|
||||
.param .u64 output
|
||||
)
|
||||
{
|
||||
.reg .u64 in_addr;
|
||||
.reg .u64 out_addr;
|
||||
.reg .s32 temp1;
|
||||
.reg .s32 temp2;
|
||||
|
||||
ld.param.u64 in_addr, [input];
|
||||
ld.param.u64 out_addr, [output];
|
||||
|
||||
ld.s32 temp1, [in_addr];
|
||||
ld.s32 temp2, [in_addr+4];
|
||||
min.s32 temp1, temp1, temp2;
|
||||
st.s32 [out_addr], temp1;
|
||||
ret;
|
||||
}
|
57
ptx/src/test/spirv_run/min.spvtxt
Normal file
57
ptx/src/test/spirv_run/min.spvtxt
Normal file
|
@ -0,0 +1,57 @@
|
|||
OpCapability GenericPointer
|
||||
OpCapability Linkage
|
||||
OpCapability Addresses
|
||||
OpCapability Kernel
|
||||
OpCapability Int8
|
||||
OpCapability Int16
|
||||
OpCapability Int64
|
||||
OpCapability Float16
|
||||
OpCapability Float64
|
||||
%30 = OpExtInstImport "OpenCL.std"
|
||||
OpMemoryModel Physical64 OpenCL
|
||||
OpEntryPoint Kernel %1 "min"
|
||||
%void = OpTypeVoid
|
||||
%ulong = OpTypeInt 64 0
|
||||
%33 = OpTypeFunction %void %ulong %ulong
|
||||
%_ptr_Function_ulong = OpTypePointer Function %ulong
|
||||
%uint = OpTypeInt 32 0
|
||||
%_ptr_Function_uint = OpTypePointer Function %uint
|
||||
%_ptr_Generic_uint = OpTypePointer Generic %uint
|
||||
%ulong_4 = OpConstant %ulong 4
|
||||
%1 = OpFunction %void None %33
|
||||
%8 = OpFunctionParameter %ulong
|
||||
%9 = OpFunctionParameter %ulong
|
||||
%28 = OpLabel
|
||||
%2 = OpVariable %_ptr_Function_ulong Function
|
||||
%3 = OpVariable %_ptr_Function_ulong Function
|
||||
%4 = OpVariable %_ptr_Function_ulong Function
|
||||
%5 = OpVariable %_ptr_Function_ulong Function
|
||||
%6 = OpVariable %_ptr_Function_uint Function
|
||||
%7 = OpVariable %_ptr_Function_uint Function
|
||||
OpStore %2 %8
|
||||
OpStore %3 %9
|
||||
%11 = OpLoad %ulong %2
|
||||
%10 = OpCopyObject %ulong %11
|
||||
OpStore %4 %10
|
||||
%13 = OpLoad %ulong %3
|
||||
%12 = OpCopyObject %ulong %13
|
||||
OpStore %5 %12
|
||||
%15 = OpLoad %ulong %4
|
||||
%25 = OpConvertUToPtr %_ptr_Generic_uint %15
|
||||
%14 = OpLoad %uint %25
|
||||
OpStore %6 %14
|
||||
%17 = OpLoad %ulong %4
|
||||
%24 = OpIAdd %ulong %17 %ulong_4
|
||||
%26 = OpConvertUToPtr %_ptr_Generic_uint %24
|
||||
%16 = OpLoad %uint %26
|
||||
OpStore %7 %16
|
||||
%19 = OpLoad %uint %6
|
||||
%20 = OpLoad %uint %7
|
||||
%18 = OpExtInst %uint %30 s_min %19 %20
|
||||
OpStore %6 %18
|
||||
%21 = OpLoad %ulong %5
|
||||
%22 = OpLoad %uint %6
|
||||
%27 = OpConvertUToPtr %_ptr_Generic_uint %21
|
||||
OpStore %27 %22
|
||||
OpReturn
|
||||
OpFunctionEnd
|
|
@ -70,6 +70,9 @@ test_ptx!(mul_wide, [0x01_00_00_00__01_00_00_00i64], [0x1_00_00_00_00_00_00i64])
|
|||
test_ptx!(vector_extract, [1u8, 2u8, 3u8, 4u8], [3u8, 4u8, 1u8, 2u8]);
|
||||
test_ptx!(shr, [-2i32], [-1i32]);
|
||||
test_ptx!(or, [1u64, 2u64], [3u64]);
|
||||
test_ptx!(sub, [2u64], [1u64]);
|
||||
test_ptx!(min, [555i32, 444i32], [444i32]);
|
||||
test_ptx!(max, [555i32, 444i32], [555i32]);
|
||||
|
||||
|
||||
struct DisplayError<T: Debug> {
|
||||
|
|
23
ptx/src/test/spirv_run/or.ptx
Normal file
23
ptx/src/test/spirv_run/or.ptx
Normal file
|
@ -0,0 +1,23 @@
|
|||
.version 6.5
|
||||
.target sm_30
|
||||
.address_size 64
|
||||
|
||||
.visible .entry or(
|
||||
.param .u64 input,
|
||||
.param .u64 output
|
||||
)
|
||||
{
|
||||
.reg .u64 in_addr;
|
||||
.reg .u64 out_addr;
|
||||
.reg .u64 temp1;
|
||||
.reg .u64 temp2;
|
||||
|
||||
ld.param.u64 in_addr, [input];
|
||||
ld.param.u64 out_addr, [output];
|
||||
|
||||
ld.u64 temp1, [in_addr];
|
||||
ld.u64 temp2, [in_addr+8];
|
||||
or.b64 temp1, temp1, temp2;
|
||||
st.u64 [out_addr], temp1;
|
||||
ret;
|
||||
}
|
58
ptx/src/test/spirv_run/or.spvtxt
Normal file
58
ptx/src/test/spirv_run/or.spvtxt
Normal file
|
@ -0,0 +1,58 @@
|
|||
OpCapability GenericPointer
|
||||
OpCapability Linkage
|
||||
OpCapability Addresses
|
||||
OpCapability Kernel
|
||||
OpCapability Int8
|
||||
OpCapability Int16
|
||||
OpCapability Int64
|
||||
OpCapability Float16
|
||||
OpCapability Float64
|
||||
%33 = OpExtInstImport "OpenCL.std"
|
||||
OpMemoryModel Physical64 OpenCL
|
||||
OpEntryPoint Kernel %1 "or"
|
||||
%void = OpTypeVoid
|
||||
%ulong = OpTypeInt 64 0
|
||||
%36 = OpTypeFunction %void %ulong %ulong
|
||||
%_ptr_Function_ulong = OpTypePointer Function %ulong
|
||||
%_ptr_Generic_ulong = OpTypePointer Generic %ulong
|
||||
%ulong_8 = OpConstant %ulong 8
|
||||
%1 = OpFunction %void None %36
|
||||
%8 = OpFunctionParameter %ulong
|
||||
%9 = OpFunctionParameter %ulong
|
||||
%31 = OpLabel
|
||||
%2 = OpVariable %_ptr_Function_ulong Function
|
||||
%3 = OpVariable %_ptr_Function_ulong Function
|
||||
%4 = OpVariable %_ptr_Function_ulong Function
|
||||
%5 = OpVariable %_ptr_Function_ulong Function
|
||||
%6 = OpVariable %_ptr_Function_ulong Function
|
||||
%7 = OpVariable %_ptr_Function_ulong Function
|
||||
OpStore %2 %8
|
||||
OpStore %3 %9
|
||||
%11 = OpLoad %ulong %2
|
||||
%10 = OpCopyObject %ulong %11
|
||||
OpStore %4 %10
|
||||
%13 = OpLoad %ulong %3
|
||||
%12 = OpCopyObject %ulong %13
|
||||
OpStore %5 %12
|
||||
%15 = OpLoad %ulong %4
|
||||
%25 = OpConvertUToPtr %_ptr_Generic_ulong %15
|
||||
%14 = OpLoad %ulong %25
|
||||
OpStore %6 %14
|
||||
%17 = OpLoad %ulong %4
|
||||
%24 = OpIAdd %ulong %17 %ulong_8
|
||||
%26 = OpConvertUToPtr %_ptr_Generic_ulong %24
|
||||
%16 = OpLoad %ulong %26
|
||||
OpStore %7 %16
|
||||
%19 = OpLoad %ulong %6
|
||||
%20 = OpLoad %ulong %7
|
||||
%28 = OpCopyObject %ulong %19
|
||||
%29 = OpCopyObject %ulong %20
|
||||
%27 = OpBitwiseOr %ulong %28 %29
|
||||
%18 = OpCopyObject %ulong %27
|
||||
OpStore %6 %18
|
||||
%21 = OpLoad %ulong %5
|
||||
%22 = OpLoad %ulong %6
|
||||
%30 = OpConvertUToPtr %_ptr_Generic_ulong %21
|
||||
OpStore %30 %22
|
||||
OpReturn
|
||||
OpFunctionEnd
|
22
ptx/src/test/spirv_run/sub.ptx
Normal file
22
ptx/src/test/spirv_run/sub.ptx
Normal file
|
@ -0,0 +1,22 @@
|
|||
.version 6.5
|
||||
.target sm_30
|
||||
.address_size 64
|
||||
|
||||
.visible .entry sub(
|
||||
.param .u64 input,
|
||||
.param .u64 output
|
||||
)
|
||||
{
|
||||
.reg .u64 in_addr;
|
||||
.reg .u64 out_addr;
|
||||
.reg .u64 temp;
|
||||
.reg .u64 temp2;
|
||||
|
||||
ld.param.u64 in_addr, [input];
|
||||
ld.param.u64 out_addr, [output];
|
||||
|
||||
ld.u64 temp, [in_addr];
|
||||
sub.u64 temp2, temp, 1;
|
||||
st.u64 [out_addr], temp2;
|
||||
ret;
|
||||
}
|
49
ptx/src/test/spirv_run/sub.spvtxt
Normal file
49
ptx/src/test/spirv_run/sub.spvtxt
Normal file
|
@ -0,0 +1,49 @@
|
|||
OpCapability GenericPointer
|
||||
OpCapability Linkage
|
||||
OpCapability Addresses
|
||||
OpCapability Kernel
|
||||
OpCapability Int8
|
||||
OpCapability Int16
|
||||
OpCapability Int64
|
||||
OpCapability Float16
|
||||
OpCapability Float64
|
||||
%25 = OpExtInstImport "OpenCL.std"
|
||||
OpMemoryModel Physical64 OpenCL
|
||||
OpEntryPoint Kernel %1 "sub"
|
||||
%void = OpTypeVoid
|
||||
%ulong = OpTypeInt 64 0
|
||||
%28 = OpTypeFunction %void %ulong %ulong
|
||||
%_ptr_Function_ulong = OpTypePointer Function %ulong
|
||||
%_ptr_Generic_ulong = OpTypePointer Generic %ulong
|
||||
%ulong_1 = OpConstant %ulong 1
|
||||
%1 = OpFunction %void None %28
|
||||
%8 = OpFunctionParameter %ulong
|
||||
%9 = OpFunctionParameter %ulong
|
||||
%23 = OpLabel
|
||||
%2 = OpVariable %_ptr_Function_ulong Function
|
||||
%3 = OpVariable %_ptr_Function_ulong Function
|
||||
%4 = OpVariable %_ptr_Function_ulong Function
|
||||
%5 = OpVariable %_ptr_Function_ulong Function
|
||||
%6 = OpVariable %_ptr_Function_ulong Function
|
||||
%7 = OpVariable %_ptr_Function_ulong Function
|
||||
OpStore %2 %8
|
||||
OpStore %3 %9
|
||||
%11 = OpLoad %ulong %2
|
||||
%10 = OpCopyObject %ulong %11
|
||||
OpStore %4 %10
|
||||
%13 = OpLoad %ulong %3
|
||||
%12 = OpCopyObject %ulong %13
|
||||
OpStore %5 %12
|
||||
%15 = OpLoad %ulong %4
|
||||
%21 = OpConvertUToPtr %_ptr_Generic_ulong %15
|
||||
%14 = OpLoad %ulong %21
|
||||
OpStore %6 %14
|
||||
%17 = OpLoad %ulong %6
|
||||
%16 = OpISub %ulong %17 %ulong_1
|
||||
OpStore %7 %16
|
||||
%18 = OpLoad %ulong %5
|
||||
%19 = OpLoad %ulong %7
|
||||
%22 = OpConvertUToPtr %_ptr_Generic_ulong %18
|
||||
OpStore %22 %19
|
||||
OpReturn
|
||||
OpFunctionEnd
|
|
@ -595,6 +595,15 @@ fn convert_to_typed_statements(
|
|||
ast::Instruction::Or(d, a) => {
|
||||
result.push(Statement::Instruction(ast::Instruction::Or(d, a.cast())))
|
||||
}
|
||||
ast::Instruction::Sub(d, a) => {
|
||||
result.push(Statement::Instruction(ast::Instruction::Sub(d, a.cast())))
|
||||
}
|
||||
ast::Instruction::Min(d, a) => {
|
||||
result.push(Statement::Instruction(ast::Instruction::Min(d, a.cast())))
|
||||
}
|
||||
ast::Instruction::Max(d, a) => {
|
||||
result.push(Statement::Instruction(ast::Instruction::Max(d, a.cast())))
|
||||
}
|
||||
},
|
||||
Statement::Label(i) => result.push(Statement::Label(i)),
|
||||
Statement::Variable(v) => result.push(Statement::Variable(v)),
|
||||
|
@ -968,62 +977,74 @@ impl<'a, 'b> FlattenArguments<'a, 'b> {
|
|||
fn reg_offset(
|
||||
&mut self,
|
||||
desc: ArgumentDescriptor<(spirv::Word, i32)>,
|
||||
typ: ast::Type,
|
||||
mut typ: ast::Type,
|
||||
) -> Result<spirv::Word, TranslateError> {
|
||||
let (reg, offset) = desc.op;
|
||||
match desc.sema {
|
||||
ArgumentSemantics::Default | ArgumentSemantics::DefaultRelaxed => {
|
||||
let scalar_t = if let ast::Type::Scalar(scalar) = typ {
|
||||
scalar
|
||||
} else {
|
||||
todo!()
|
||||
ArgumentSemantics::Default
|
||||
| ArgumentSemantics::DefaultRelaxed
|
||||
| ArgumentSemantics::PhysicalPointer => {
|
||||
if desc.sema == ArgumentSemantics::PhysicalPointer {
|
||||
typ = ast::Type::Scalar(ast::ScalarType::U64);
|
||||
}
|
||||
let (width, kind) = match typ {
|
||||
ast::Type::Scalar(scalar_t) => {
|
||||
let kind = match scalar_t.kind() {
|
||||
kind @ ScalarKind::Bit
|
||||
| kind @ ScalarKind::Unsigned
|
||||
| kind @ ScalarKind::Signed => kind,
|
||||
ScalarKind::Float => return Err(TranslateError::MismatchedType),
|
||||
ScalarKind::Float2 => return Err(TranslateError::MismatchedType),
|
||||
ScalarKind::Pred => return Err(TranslateError::MismatchedType),
|
||||
};
|
||||
(scalar_t.width(), kind)
|
||||
}
|
||||
_ => return Err(TranslateError::MismatchedType),
|
||||
};
|
||||
let id_constant_stmt = self.id_def.new_id(ast::Type::Scalar(scalar_t));
|
||||
let arith_detail = if kind == ScalarKind::Signed {
|
||||
ast::ArithDetails::Signed(ast::ArithSInt {
|
||||
typ: ast::SIntType::from_size(width),
|
||||
saturate: false,
|
||||
})
|
||||
} else {
|
||||
ast::ArithDetails::Unsigned(ast::UIntType::from_size(width))
|
||||
};
|
||||
let id_constant_stmt = self.id_def.new_id(typ);
|
||||
let result_id = self.id_def.new_id(typ);
|
||||
self.func.push(Statement::Constant(ConstantDefinition {
|
||||
dst: id_constant_stmt,
|
||||
typ: scalar_t,
|
||||
value: offset as i64,
|
||||
}));
|
||||
let int_type = ast::IntType::try_new(scalar_t).unwrap_or_else(|| todo!());
|
||||
self.func.push(Statement::Instruction(
|
||||
ast::Instruction::<ExpandedArgParams>::Add(
|
||||
ast::AddDetails::Int(ast::AddIntDesc {
|
||||
typ: int_type,
|
||||
saturate: false,
|
||||
}),
|
||||
ast::Arg3 {
|
||||
dst: result_id,
|
||||
src1: reg,
|
||||
src2: id_constant_stmt,
|
||||
},
|
||||
),
|
||||
));
|
||||
Ok(result_id)
|
||||
}
|
||||
ArgumentSemantics::PhysicalPointer => {
|
||||
let scalar_t = ast::ScalarType::U64;
|
||||
let id_constant_stmt = self.id_def.new_id(ast::Type::Scalar(scalar_t));
|
||||
let result_id = self.id_def.new_id(ast::Type::Scalar(scalar_t));
|
||||
self.func.push(Statement::Constant(ConstantDefinition {
|
||||
dst: id_constant_stmt,
|
||||
typ: scalar_t,
|
||||
value: offset as i64,
|
||||
}));
|
||||
let int_type = ast::IntType::U64;
|
||||
self.func.push(Statement::Instruction(
|
||||
ast::Instruction::<ExpandedArgParams>::Add(
|
||||
ast::AddDetails::Int(ast::AddIntDesc {
|
||||
typ: int_type,
|
||||
saturate: false,
|
||||
}),
|
||||
ast::Arg3 {
|
||||
dst: result_id,
|
||||
src1: reg,
|
||||
src2: id_constant_stmt,
|
||||
},
|
||||
),
|
||||
));
|
||||
// TODO: check for edge cases around min value/max value/wrapping
|
||||
if offset < 0 && kind != ScalarKind::Signed {
|
||||
self.func.push(Statement::Constant(ConstantDefinition {
|
||||
dst: id_constant_stmt,
|
||||
typ: ast::ScalarType::from_parts(width, kind),
|
||||
value: -(offset as i64),
|
||||
}));
|
||||
self.func.push(Statement::Instruction(
|
||||
ast::Instruction::<ExpandedArgParams>::Sub(
|
||||
arith_detail,
|
||||
ast::Arg3 {
|
||||
dst: result_id,
|
||||
src1: reg,
|
||||
src2: id_constant_stmt,
|
||||
},
|
||||
),
|
||||
));
|
||||
} else {
|
||||
self.func.push(Statement::Constant(ConstantDefinition {
|
||||
dst: id_constant_stmt,
|
||||
typ: ast::ScalarType::from_parts(width, kind),
|
||||
value: offset as i64,
|
||||
}));
|
||||
self.func.push(Statement::Instruction(
|
||||
ast::Instruction::<ExpandedArgParams>::Add(
|
||||
arith_detail,
|
||||
ast::Arg3 {
|
||||
dst: result_id,
|
||||
src1: reg,
|
||||
src2: id_constant_stmt,
|
||||
},
|
||||
),
|
||||
));
|
||||
}
|
||||
Ok(result_id)
|
||||
}
|
||||
ArgumentSemantics::RegisterPointer => {
|
||||
|
@ -1522,14 +1543,22 @@ fn emit_function_body_ops(
|
|||
}
|
||||
},
|
||||
ast::Instruction::Mul(mul, arg) => match mul {
|
||||
ast::MulDetails::Int(ref ctr) => {
|
||||
emit_mul_int(builder, map, opencl, ctr, arg)?;
|
||||
ast::MulDetails::Signed(ref ctr) => {
|
||||
emit_mul_sint(builder, map, opencl, ctr, arg)?
|
||||
}
|
||||
ast::MulDetails::Unsigned(ref ctr) => {
|
||||
emit_mul_uint(builder, map, opencl, ctr, arg)?
|
||||
}
|
||||
ast::MulDetails::Float(_) => todo!(),
|
||||
},
|
||||
ast::Instruction::Add(add, arg) => match add {
|
||||
ast::AddDetails::Int(ref desc) => emit_add_int(builder, map, desc, arg)?,
|
||||
ast::AddDetails::Float(desc) => emit_add_float(builder, map, desc, arg)?,
|
||||
ast::ArithDetails::Signed(ref desc) => {
|
||||
emit_add_int(builder, map, desc.typ.into(), desc.saturate, arg)?
|
||||
}
|
||||
ast::ArithDetails::Unsigned(ref desc) => {
|
||||
emit_add_int(builder, map, (*desc).into(), false, arg)?
|
||||
}
|
||||
ast::ArithDetails::Float(desc) => emit_add_float(builder, map, desc, arg)?,
|
||||
},
|
||||
ast::Instruction::Setp(setp, arg) => {
|
||||
if arg.dst2.is_some() {
|
||||
|
@ -1581,8 +1610,11 @@ fn emit_function_body_ops(
|
|||
}
|
||||
ast::Instruction::SetpBool(_, _) => todo!(),
|
||||
ast::Instruction::Mad(mad, arg) => match mad {
|
||||
ast::MulDetails::Int(ref desc) => {
|
||||
emit_mad_int(builder, map, opencl, desc, arg)?
|
||||
ast::MulDetails::Signed(ref desc) => {
|
||||
emit_mad_sint(builder, map, opencl, desc, arg)?
|
||||
}
|
||||
ast::MulDetails::Unsigned(ref desc) => {
|
||||
emit_mad_uint(builder, map, opencl, desc, arg)?
|
||||
}
|
||||
ast::MulDetails::Float(desc) => emit_mad_float(builder, map, desc, arg)?,
|
||||
},
|
||||
|
@ -1594,6 +1626,23 @@ fn emit_function_body_ops(
|
|||
builder.bitwise_or(result_type, Some(a.dst), a.src1, a.src2)?;
|
||||
}
|
||||
}
|
||||
ast::Instruction::Sub(d, arg) => match d {
|
||||
ast::ArithDetails::Signed(desc) => {
|
||||
emit_sub_int(builder, map, desc.typ.into(), desc.saturate, arg)?;
|
||||
}
|
||||
ast::ArithDetails::Unsigned(desc) => {
|
||||
emit_sub_int(builder, map, (*desc).into(), false, arg)?;
|
||||
}
|
||||
ast::ArithDetails::Float(desc) => {
|
||||
emit_sub_float(builder, map, desc, arg)?;
|
||||
}
|
||||
},
|
||||
ast::Instruction::Min(d, a) => {
|
||||
emit_min(builder, map, opencl, d, a)?;
|
||||
}
|
||||
ast::Instruction::Max(d, a) => {
|
||||
emit_max(builder, map, opencl, d, a)?;
|
||||
}
|
||||
},
|
||||
Statement::LoadVar(arg, typ) => {
|
||||
let type_id = map.get_or_add(builder, SpirvType::from(*typ));
|
||||
|
@ -1624,11 +1673,11 @@ fn emit_function_body_ops(
|
|||
Ok(())
|
||||
}
|
||||
|
||||
fn emit_mad_int(
|
||||
fn emit_mad_uint(
|
||||
builder: &mut dr::Builder,
|
||||
map: &mut TypeWordMap,
|
||||
opencl: spirv::Word,
|
||||
desc: &ast::MulIntDesc,
|
||||
desc: &ast::MulUInt,
|
||||
arg: &ast::Arg4<ExpandedArgParams>,
|
||||
) -> Result<(), dr::Error> {
|
||||
let inst_type = map.get_or_add(builder, SpirvType::from(ast::ScalarType::from(desc.typ)));
|
||||
|
@ -1638,16 +1687,38 @@ fn emit_mad_int(
|
|||
builder.i_add(inst_type, Some(arg.dst), arg.src3, mul_result)?;
|
||||
}
|
||||
ast::MulIntControl::High => {
|
||||
let cl_op = if desc.typ.is_signed() {
|
||||
spirv::CLOp::s_mad_hi
|
||||
} else {
|
||||
spirv::CLOp::u_mad_hi
|
||||
};
|
||||
builder.ext_inst(
|
||||
inst_type,
|
||||
Some(arg.dst),
|
||||
opencl,
|
||||
cl_op as spirv::Word,
|
||||
spirv::CLOp::u_mad_hi as spirv::Word,
|
||||
[arg.src1, arg.src2, arg.src3],
|
||||
)?;
|
||||
}
|
||||
ast::MulIntControl::Wide => todo!(),
|
||||
};
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn emit_mad_sint(
|
||||
builder: &mut dr::Builder,
|
||||
map: &mut TypeWordMap,
|
||||
opencl: spirv::Word,
|
||||
desc: &ast::MulSInt,
|
||||
arg: &ast::Arg4<ExpandedArgParams>,
|
||||
) -> Result<(), dr::Error> {
|
||||
let inst_type = map.get_or_add(builder, SpirvType::from(ast::ScalarType::from(desc.typ)));
|
||||
match desc.control {
|
||||
ast::MulIntControl::Low => {
|
||||
let mul_result = builder.i_mul(inst_type, None, arg.src1, arg.src2)?;
|
||||
builder.i_add(inst_type, Some(arg.dst), arg.src3, mul_result)?;
|
||||
}
|
||||
ast::MulIntControl::High => {
|
||||
builder.ext_inst(
|
||||
inst_type,
|
||||
Some(arg.dst),
|
||||
opencl,
|
||||
spirv::CLOp::s_mad_hi as spirv::Word,
|
||||
[arg.src1, arg.src2, arg.src3],
|
||||
)?;
|
||||
}
|
||||
|
@ -1659,7 +1730,7 @@ fn emit_mad_int(
|
|||
fn emit_mad_float(
|
||||
builder: &mut dr::Builder,
|
||||
map: &mut TypeWordMap,
|
||||
desc: &ast::MulFloatDesc,
|
||||
desc: &ast::ArithFloat,
|
||||
arg: &ast::Arg4<ExpandedArgParams>,
|
||||
) -> Result<(), dr::Error> {
|
||||
todo!()
|
||||
|
@ -1668,7 +1739,7 @@ fn emit_mad_float(
|
|||
fn emit_add_float(
|
||||
builder: &mut dr::Builder,
|
||||
map: &mut TypeWordMap,
|
||||
desc: &ast::AddFloatDesc,
|
||||
desc: &ast::ArithFloat,
|
||||
arg: &ast::Arg3<ExpandedArgParams>,
|
||||
) -> Result<(), dr::Error> {
|
||||
if desc.flush_to_zero {
|
||||
|
@ -1680,6 +1751,67 @@ fn emit_add_float(
|
|||
Ok(())
|
||||
}
|
||||
|
||||
fn emit_sub_float(
|
||||
builder: &mut dr::Builder,
|
||||
map: &mut TypeWordMap,
|
||||
desc: &ast::ArithFloat,
|
||||
arg: &ast::Arg3<ExpandedArgParams>,
|
||||
) -> Result<(), dr::Error> {
|
||||
if desc.flush_to_zero {
|
||||
todo!()
|
||||
}
|
||||
let inst_type = map.get_or_add(builder, SpirvType::from(ast::ScalarType::from(desc.typ)));
|
||||
builder.f_sub(inst_type, Some(arg.dst), arg.src1, arg.src2)?;
|
||||
emit_rounding_decoration(builder, arg.dst, desc.rounding);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn emit_min(
|
||||
builder: &mut dr::Builder,
|
||||
map: &mut TypeWordMap,
|
||||
opencl: spirv::Word,
|
||||
desc: &ast::MinMaxDetails,
|
||||
arg: &ast::Arg3<ExpandedArgParams>,
|
||||
) -> Result<(), dr::Error> {
|
||||
let cl_op = match desc {
|
||||
ast::MinMaxDetails::Signed(_) => spirv::CLOp::s_min,
|
||||
ast::MinMaxDetails::Unsigned(_) => spirv::CLOp::u_min,
|
||||
ast::MinMaxDetails::Float(_) => spirv::CLOp::fmin,
|
||||
};
|
||||
let inst_type = map.get_or_add(builder, SpirvType::from(desc.get_type()));
|
||||
builder.ext_inst(
|
||||
inst_type,
|
||||
Some(arg.dst),
|
||||
opencl,
|
||||
cl_op as spirv::Word,
|
||||
[arg.src1, arg.src2],
|
||||
)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn emit_max(
|
||||
builder: &mut dr::Builder,
|
||||
map: &mut TypeWordMap,
|
||||
opencl: spirv::Word,
|
||||
desc: &ast::MinMaxDetails,
|
||||
arg: &ast::Arg3<ExpandedArgParams>,
|
||||
) -> Result<(), dr::Error> {
|
||||
let cl_op = match desc {
|
||||
ast::MinMaxDetails::Signed(_) => spirv::CLOp::s_max,
|
||||
ast::MinMaxDetails::Unsigned(_) => spirv::CLOp::u_max,
|
||||
ast::MinMaxDetails::Float(_) => spirv::CLOp::fmax,
|
||||
};
|
||||
let inst_type = map.get_or_add(builder, SpirvType::from(desc.get_type()));
|
||||
builder.ext_inst(
|
||||
inst_type,
|
||||
Some(arg.dst),
|
||||
opencl,
|
||||
cl_op as spirv::Word,
|
||||
[arg.src1, arg.src2],
|
||||
)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn emit_cvt(
|
||||
builder: &mut dr::Builder,
|
||||
map: &mut TypeWordMap,
|
||||
|
@ -1880,11 +2012,11 @@ fn emit_setp(
|
|||
Ok(())
|
||||
}
|
||||
|
||||
fn emit_mul_int(
|
||||
fn emit_mul_sint(
|
||||
builder: &mut dr::Builder,
|
||||
map: &mut TypeWordMap,
|
||||
opencl: spirv::Word,
|
||||
desc: &ast::MulIntDesc,
|
||||
desc: &ast::MulSInt,
|
||||
arg: &ast::Arg3<ExpandedArgParams>,
|
||||
) -> Result<(), dr::Error> {
|
||||
let instruction_type = ast::ScalarType::from(desc.typ);
|
||||
|
@ -1894,16 +2026,11 @@ fn emit_mul_int(
|
|||
builder.i_mul(inst_type, Some(arg.dst), arg.src1, arg.src2)?;
|
||||
}
|
||||
ast::MulIntControl::High => {
|
||||
let ocl_mul_hi = if desc.typ.is_signed() {
|
||||
spirv::CLOp::s_mul_hi
|
||||
} else {
|
||||
spirv::CLOp::u_mul_hi
|
||||
};
|
||||
builder.ext_inst(
|
||||
inst_type,
|
||||
Some(arg.dst),
|
||||
opencl,
|
||||
ocl_mul_hi as spirv::Word,
|
||||
spirv::CLOp::s_mul_hi as spirv::Word,
|
||||
[arg.src1, arg.src2],
|
||||
)?;
|
||||
}
|
||||
|
@ -1913,11 +2040,54 @@ fn emit_mul_int(
|
|||
SpirvScalarKey::from(instruction_type),
|
||||
]);
|
||||
let mul_ext_type_id = map.get_or_add(builder, mul_ext_type);
|
||||
let mul = if desc.typ.is_signed() {
|
||||
builder.s_mul_extended(mul_ext_type_id, None, arg.src1, arg.src2)?
|
||||
} else {
|
||||
builder.u_mul_extended(mul_ext_type_id, None, arg.src1, arg.src2)?
|
||||
};
|
||||
let mul = builder.s_mul_extended(mul_ext_type_id, None, arg.src1, arg.src2)?;
|
||||
let instr_width = instruction_type.width();
|
||||
let instr_kind = instruction_type.kind();
|
||||
let dst_type = ast::ScalarType::from_parts(instr_width * 2, instr_kind);
|
||||
let dst_type_id = map.get_or_add_scalar(builder, dst_type);
|
||||
struct2_bitcast_to_wide(
|
||||
builder,
|
||||
map,
|
||||
SpirvScalarKey::from(instruction_type),
|
||||
inst_type,
|
||||
arg.dst,
|
||||
dst_type_id,
|
||||
mul,
|
||||
)?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn emit_mul_uint(
|
||||
builder: &mut dr::Builder,
|
||||
map: &mut TypeWordMap,
|
||||
opencl: spirv::Word,
|
||||
desc: &ast::MulUInt,
|
||||
arg: &ast::Arg3<ExpandedArgParams>,
|
||||
) -> Result<(), dr::Error> {
|
||||
let instruction_type = ast::ScalarType::from(desc.typ);
|
||||
let inst_type = map.get_or_add(builder, SpirvType::from(ast::ScalarType::from(desc.typ)));
|
||||
match desc.control {
|
||||
ast::MulIntControl::Low => {
|
||||
builder.i_mul(inst_type, Some(arg.dst), arg.src1, arg.src2)?;
|
||||
}
|
||||
ast::MulIntControl::High => {
|
||||
builder.ext_inst(
|
||||
inst_type,
|
||||
Some(arg.dst),
|
||||
opencl,
|
||||
spirv::CLOp::u_mul_hi as spirv::Word,
|
||||
[arg.src1, arg.src2],
|
||||
)?;
|
||||
}
|
||||
ast::MulIntControl::Wide => {
|
||||
let mul_ext_type = SpirvType::Struct(vec![
|
||||
SpirvScalarKey::from(instruction_type),
|
||||
SpirvScalarKey::from(instruction_type),
|
||||
]);
|
||||
let mul_ext_type_id = map.get_or_add(builder, mul_ext_type);
|
||||
let mul = builder.u_mul_extended(mul_ext_type_id, None, arg.src1, arg.src2)?;
|
||||
let instr_width = instruction_type.width();
|
||||
let instr_kind = instruction_type.kind();
|
||||
let dst_type = ast::ScalarType::from_parts(instr_width * 2, instr_kind);
|
||||
|
@ -1981,14 +2151,33 @@ fn emit_abs(
|
|||
fn emit_add_int(
|
||||
builder: &mut dr::Builder,
|
||||
map: &mut TypeWordMap,
|
||||
ctr: &ast::AddIntDesc,
|
||||
typ: ast::ScalarType,
|
||||
saturate: bool,
|
||||
arg: &ast::Arg3<ExpandedArgParams>,
|
||||
) -> Result<(), dr::Error> {
|
||||
let inst_type = map.get_or_add(builder, SpirvType::from(ast::ScalarType::from(ctr.typ)));
|
||||
if saturate {
|
||||
todo!()
|
||||
}
|
||||
let inst_type = map.get_or_add(builder, SpirvType::from(ast::ScalarType::from(typ)));
|
||||
builder.i_add(inst_type, Some(arg.dst), arg.src1, arg.src2)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn emit_sub_int(
|
||||
builder: &mut dr::Builder,
|
||||
map: &mut TypeWordMap,
|
||||
typ: ast::ScalarType,
|
||||
saturate: bool,
|
||||
arg: &ast::Arg3<ExpandedArgParams>,
|
||||
) -> Result<(), dr::Error> {
|
||||
if saturate {
|
||||
todo!()
|
||||
}
|
||||
let inst_type = map.get_or_add(builder, SpirvType::from(ast::ScalarType::from(typ)));
|
||||
builder.i_sub(inst_type, Some(arg.dst), arg.src1, arg.src2)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn emit_implicit_conversion(
|
||||
builder: &mut dr::Builder,
|
||||
map: &mut TypeWordMap,
|
||||
|
@ -2920,6 +3109,18 @@ impl<T: ArgParamsEx> ast::Instruction<T> {
|
|||
t,
|
||||
a.map_non_shift(visitor, ast::Type::Scalar(t.into()), false)?,
|
||||
),
|
||||
ast::Instruction::Sub(d, a) => {
|
||||
let typ = d.get_type();
|
||||
ast::Instruction::Sub(d, a.map_non_shift(visitor, typ, false)?)
|
||||
}
|
||||
ast::Instruction::Min(d, a) => {
|
||||
let typ = d.get_type();
|
||||
ast::Instruction::Min(d, a.map_non_shift(visitor, typ, false)?)
|
||||
}
|
||||
ast::Instruction::Max(d, a) => {
|
||||
let typ = d.get_type();
|
||||
ast::Instruction::Max(d, a.map_non_shift(visitor, typ, false)?)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -3129,6 +3330,9 @@ impl ast::Instruction<ExpandedArgParams> {
|
|||
| ast::Instruction::Abs(_, _)
|
||||
| ast::Instruction::Call(_)
|
||||
| ast::Instruction::Or(_, _)
|
||||
| ast::Instruction::Sub(_, _)
|
||||
| ast::Instruction::Min(_, _)
|
||||
| ast::Instruction::Max(_, _)
|
||||
| ast::Instruction::Mad(_, _) => None,
|
||||
}
|
||||
}
|
||||
|
@ -4049,25 +4253,33 @@ impl ast::ShrType {
|
|||
}
|
||||
}
|
||||
|
||||
impl ast::AddDetails {
|
||||
impl ast::ArithDetails {
|
||||
fn get_type(&self) -> ast::Type {
|
||||
match self {
|
||||
ast::AddDetails::Int(ast::AddIntDesc { typ, .. }) => ast::Type::Scalar((*typ).into()),
|
||||
ast::AddDetails::Float(ast::AddFloatDesc { typ, .. }) => {
|
||||
ast::Type::Scalar((*typ).into())
|
||||
}
|
||||
}
|
||||
ast::Type::Scalar(match self {
|
||||
ast::ArithDetails::Unsigned(t) => (*t).into(),
|
||||
ast::ArithDetails::Signed(d) => d.typ.into(),
|
||||
ast::ArithDetails::Float(d) => d.typ.into(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl ast::MulDetails {
    /// Returns the operand type of this multiply-style instruction as a scalar `ast::Type`.
    ///
    /// Every variant carries a descriptor; its `typ` field is converted into
    /// the scalar type.
    fn get_type(&self) -> ast::Type {
        ast::Type::Scalar(match self {
            ast::MulDetails::Unsigned(d) => d.typ.into(),
            ast::MulDetails::Signed(d) => d.typ.into(),
            ast::MulDetails::Float(d) => d.typ.into(),
        })
    }
}
|
||||
|
||||
impl ast::MinMaxDetails {
    /// Returns the operand type of this `min`/`max` instruction as a scalar `ast::Type`.
    ///
    /// The signed and unsigned variants carry their integer type directly; the
    /// float variant carries a descriptor whose `typ` field is used.
    fn get_type(&self) -> ast::Type {
        let scalar = match self {
            Self::Signed(int_type) => (*int_type).into(),
            Self::Unsigned(int_type) => (*int_type).into(),
            Self::Float(details) => details.typ.into(),
        };
        ast::Type::Scalar(scalar)
    }
}
|
||||
|
||||
|
@ -4085,6 +4297,30 @@ impl ast::IntType {
|
|||
}
|
||||
}
|
||||
|
||||
impl ast::SIntType {
|
||||
fn from_size(width: u8) -> Self {
|
||||
match width {
|
||||
1 => ast::SIntType::S8,
|
||||
2 => ast::SIntType::S16,
|
||||
4 => ast::SIntType::S32,
|
||||
8 => ast::SIntType::S64,
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ast::UIntType {
|
||||
fn from_size(width: u8) -> Self {
|
||||
match width {
|
||||
1 => ast::UIntType::U8,
|
||||
2 => ast::UIntType::U16,
|
||||
4 => ast::UIntType::U32,
|
||||
8 => ast::UIntType::U64,
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ast::LdStateSpace {
|
||||
fn to_spirv(self) -> spirv::StorageClass {
|
||||
match self {
|
||||
|
@ -4128,7 +4364,8 @@ impl<T> ast::OperandOrVector<T> {
|
|||
impl ast::MulDetails {
|
||||
fn is_wide(&self) -> bool {
|
||||
match self {
|
||||
ast::MulDetails::Int(desc) => desc.control == ast::MulIntControl::Wide,
|
||||
ast::MulDetails::Unsigned(d) => d.control == ast::MulIntControl::Wide,
|
||||
ast::MulDetails::Signed(d) => d.control == ast::MulIntControl::Wide,
|
||||
ast::MulDetails::Float(_) => false,
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue