Add sub, min, max

This commit is contained in:
Andrzej Janik 2020-10-02 00:11:28 +02:00
parent bd3d440dba
commit 9a65dd32f5
12 changed files with 820 additions and 181 deletions

View file

@ -241,6 +241,10 @@ sub_scalar_type!(IntType {
S64
});
sub_scalar_type!(UIntType { U8, U16, U32, U64 });
sub_scalar_type!(SIntType { S8, S16, S32, S64 });
impl IntType {
pub fn is_signed(self) -> bool {
match self {
@ -331,7 +335,7 @@ pub enum Instruction<P: ArgParams> {
Ld(LdDetails, Arg2Ld<P>),
Mov(MovDetails, Arg2Mov<P>),
Mul(MulDetails, Arg3<P>),
Add(AddDetails, Arg3<P>),
Add(ArithDetails, Arg3<P>),
Setp(SetpData, Arg4Setp<P>),
SetpBool(SetpBoolData, Arg5<P>),
Not(NotType, Arg2<P>),
@ -346,6 +350,9 @@ pub enum Instruction<P: ArgParams> {
Abs(AbsDetails, Arg2<P>),
Mad(MulDetails, Arg4<P>),
Or(OrType, Arg3<P>),
Sub(ArithDetails, Arg3<P>),
Min(MinMaxDetails, Arg3<P>),
Max(MinMaxDetails, Arg3<P>),
}
#[derive(Copy, Clone)]
@ -554,11 +561,6 @@ impl MovDetails {
}
}
pub enum MulDetails {
Int(MulIntDesc),
Float(MulFloatDesc),
}
#[derive(Copy, Clone)]
pub struct MulIntDesc {
pub typ: IntType,
@ -572,14 +574,6 @@ pub enum MulIntControl {
Wide,
}
#[derive(Copy, Clone)]
pub struct MulFloatDesc {
pub typ: FloatType,
pub rounding: Option<RoundingMode>,
pub flush_to_zero: bool,
pub saturate: bool,
}
#[derive(PartialEq, Eq, Copy, Clone)]
pub enum RoundingMode {
NearestEven,
@ -588,23 +582,11 @@ pub enum RoundingMode {
PositiveInf,
}
pub enum AddDetails {
Int(AddIntDesc),
Float(AddFloatDesc),
}
pub struct AddIntDesc {
pub typ: IntType,
pub saturate: bool,
}
pub struct AddFloatDesc {
pub typ: FloatType,
pub rounding: Option<RoundingMode>,
pub flush_to_zero: bool,
pub saturate: bool,
}
pub struct SetpData {
pub typ: ScalarType,
pub flush_to_zero: bool,
@ -810,3 +792,57 @@ sub_scalar_type!(OrType {
B32,
B64,
});
#[derive(Copy, Clone)]
pub enum MulDetails {
Unsigned(MulUInt),
Signed(MulSInt),
Float(ArithFloat),
}
#[derive(Copy, Clone)]
pub struct MulUInt {
pub typ: UIntType,
pub control: MulIntControl,
}
#[derive(Copy, Clone)]
pub struct MulSInt {
pub typ: SIntType,
pub control: MulIntControl,
}
#[derive(Copy, Clone)]
pub enum ArithDetails {
Unsigned(UIntType),
Signed(ArithSInt),
Float(ArithFloat),
}
#[derive(Copy, Clone)]
pub struct ArithSInt {
pub typ: SIntType,
pub saturate: bool,
}
#[derive(Copy, Clone)]
pub struct ArithFloat {
pub typ: FloatType,
pub rounding: Option<RoundingMode>,
pub flush_to_zero: bool,
pub saturate: bool,
}
#[derive(Copy, Clone)]
pub enum MinMaxDetails {
Signed(SIntType),
Unsigned(UIntType),
Float(MinMaxFloat),
}
#[derive(Copy, Clone)]
pub struct MinMaxFloat {
pub ftz: bool,
pub nan: bool,
pub typ: FloatType,
}

View file

@ -70,6 +70,7 @@ match {
".ltu",
".lu",
".nan",
".NaN",
".ne",
".neu",
".num",
@ -124,6 +125,8 @@ match {
"ld",
"mad",
"map_f64_to_f32",
"max",
"min",
"mov",
"mul",
"not",
@ -134,6 +137,7 @@ match {
"shr",
r"sm_[0-9]+" => ShaderModel,
"st",
"sub",
"texmode_independent",
"texmode_unified",
} else {
@ -153,6 +157,8 @@ ExtendedID : &'input str = {
"ld",
"mad",
"map_f64_to_f32",
"max",
"min",
"mov",
"mul",
"not",
@ -163,6 +169,7 @@ ExtendedID : &'input str = {
"shr",
ShaderModel,
"st",
"sub",
"texmode_independent",
"texmode_unified",
ID
@ -448,7 +455,10 @@ Instruction: ast::Instruction<ast::ParsedArgParams<'input>> = {
InstCall,
InstAbs,
InstMad,
InstOr
InstOr,
InstSub,
InstMin,
InstMax,
};
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-ld
@ -570,38 +580,19 @@ MovVectorType: ast::ScalarType = {
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-mul
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-mul
InstMul: ast::Instruction<ast::ParsedArgParams<'input>> = {
"mul" <d:InstMulMode> <a:Arg3> => ast::Instruction::Mul(d, a)
"mul" <d:MulDetails> <a:Arg3> => ast::Instruction::Mul(d, a)
};
InstMulMode: ast::MulDetails = {
<ctr:MulIntControl> <t:IntType> => ast::MulDetails::Int(ast::MulIntDesc {
MulDetails: ast::MulDetails = {
<ctr:MulIntControl> <t:UIntType> => ast::MulDetails::Unsigned(ast::MulUInt{
typ: t,
control: ctr
}),
<r:RoundingModeFloat?> <ftz:".ftz"?> <s:".sat"?> ".f32" => ast::MulDetails::Float(ast::MulFloatDesc {
typ: ast::FloatType::F32,
rounding: r,
flush_to_zero: ftz.is_some(),
saturate: s.is_some()
<ctr:MulIntControl> <t:SIntType> => ast::MulDetails::Signed(ast::MulSInt{
typ: t,
control: ctr
}),
<r:RoundingModeFloat?> ".f64" => ast::MulDetails::Float(ast::MulFloatDesc {
typ: ast::FloatType::F64,
rounding: r,
flush_to_zero: false,
saturate: false
}),
<r:".rn"?> <ftz:".ftz"?> <s:".sat"?> ".f16" => ast::MulDetails::Float(ast::MulFloatDesc {
typ: ast::FloatType::F16,
rounding: r.map(|_| ast::RoundingMode::NearestEven),
flush_to_zero: ftz.is_some(),
saturate: s.is_some()
}),
<r:".rn"?> <ftz:".ftz"?> <s:".sat"?> ".f16x2" => ast::MulDetails::Float(ast::MulFloatDesc {
typ: ast::FloatType::F16x2,
rounding: r.map(|_| ast::RoundingMode::NearestEven),
flush_to_zero: ftz.is_some(),
saturate: s.is_some()
})
<f:ArithFloat> => ast::MulDetails::Float(f)
};
MulIntControl: ast::MulIntControl = {
@ -634,41 +625,23 @@ IntType : ast::IntType = {
".s64" => ast::IntType::S64,
};
UIntType: ast::UIntType = {
".u16" => ast::UIntType::U16,
".u32" => ast::UIntType::U32,
".u64" => ast::UIntType::U64,
};
SIntType: ast::SIntType = {
".s16" => ast::SIntType::S16,
".s32" => ast::SIntType::S32,
".s64" => ast::SIntType::S64,
};
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-add
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-add
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-add
InstAdd: ast::Instruction<ast::ParsedArgParams<'input>> = {
"add" <d:InstAddMode> <a:Arg3> => ast::Instruction::Add(d, a)
};
InstAddMode: ast::AddDetails = {
<t:IntType> => ast::AddDetails::Int(ast::AddIntDesc {
typ: t,
saturate: false,
}),
".sat" ".s32" => ast::AddDetails::Int(ast::AddIntDesc {
typ: ast::IntType::S32,
saturate: true,
}),
<rn:RoundingModeFloat?> <ftz:".ftz"?> <sat:".sat"?> ".f32" => ast::AddDetails::Float(ast::AddFloatDesc {
typ: ast::FloatType::F32,
rounding: rn,
flush_to_zero: ftz.is_some(),
saturate: sat.is_some(),
}),
<rn:RoundingModeFloat?> ".f64" => ast::AddDetails::Float(ast::AddFloatDesc {
typ: ast::FloatType::F64,
rounding: rn,
flush_to_zero: false,
saturate: false,
}),
<rn:".rn"?> <ftz:".ftz"?> <sat:".sat"?>".f16" => ast::AddDetails::Float(ast::AddFloatDesc {
typ: ast::FloatType::F16,
rounding: rn.map(|_| ast::RoundingMode::NearestEven),
flush_to_zero: ftz.is_some(),
saturate: sat.is_some(),
}),
".rn"? ".ftz"? ".sat"? ".f16x2" => todo!()
"add" <d:ArithDetails> <a:Arg3> => ast::Instruction::Add(d, a)
};
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#comparison-and-selection-instructions-setp
@ -1041,7 +1014,7 @@ InstAbs: ast::Instruction<ast::ParsedArgParams<'input>> = {
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-mad
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-mad
InstMad: ast::Instruction<ast::ParsedArgParams<'input>> = {
"mad" <d:InstMulMode> <a:Arg4> => ast::Instruction::Mad(d, a),
"mad" <d:MulDetails> <a:Arg4> => ast::Instruction::Mad(d, a),
"mad" ".hi" ".sat" ".s32" => todo!()
};
@ -1063,6 +1036,84 @@ OrType: ast::OrType = {
".b64" => ast::OrType::B64,
}
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-sub
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-sub
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-sub
InstSub: ast::Instruction<ast::ParsedArgParams<'input>> = {
"sub" <d:ArithDetails> <a:Arg3> => ast::Instruction::Sub(d, a),
};
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-min
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-min
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-min
InstMin: ast::Instruction<ast::ParsedArgParams<'input>> = {
"min" <d:MinMaxDetails> <a:Arg3> => ast::Instruction::Min(d, a),
};
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-max
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-max
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-max
InstMax: ast::Instruction<ast::ParsedArgParams<'input>> = {
"max" <d:MinMaxDetails> <a:Arg3> => ast::Instruction::Max(d, a),
};
MinMaxDetails: ast::MinMaxDetails = {
<t:UIntType> => ast::MinMaxDetails::Unsigned(t),
<t:SIntType> => ast::MinMaxDetails::Signed(t),
<ftz:".ftz"?> <nan:".NaN"?> ".f32" => ast::MinMaxDetails::Float(
ast::MinMaxFloat{ ftz: ftz.is_some(), nan: nan.is_some(), typ: ast::FloatType::F32 }
),
".f64" => ast::MinMaxDetails::Float(
ast::MinMaxFloat{ ftz: false, nan: false, typ: ast::FloatType::F64 }
),
<ftz:".ftz"?> <nan:".NaN"?> ".f16" => ast::MinMaxDetails::Float(
ast::MinMaxFloat{ ftz: ftz.is_some(), nan: nan.is_some(), typ: ast::FloatType::F16 }
),
<ftz:".ftz"?> <nan:".NaN"?> ".f16x2" => ast::MinMaxDetails::Float(
ast::MinMaxFloat{ ftz: ftz.is_some(), nan: nan.is_some(), typ: ast::FloatType::F16x2 }
)
}
ArithDetails: ast::ArithDetails = {
<t:UIntType> => ast::ArithDetails::Unsigned(t),
<t:SIntType> => ast::ArithDetails::Signed(ast::ArithSInt {
typ: t,
saturate: false,
}),
".sat" ".s32" => ast::ArithDetails::Signed(ast::ArithSInt {
typ: ast::SIntType::S32,
saturate: true,
}),
<f:ArithFloat> => ast::ArithDetails::Float(f)
}
ArithFloat: ast::ArithFloat = {
<rn:RoundingModeFloat?> <ftz:".ftz"?> <sat:".sat"?> ".f32" => ast::ArithFloat {
typ: ast::FloatType::F32,
rounding: rn,
flush_to_zero: ftz.is_some(),
saturate: sat.is_some(),
},
<rn:RoundingModeFloat?> ".f64" => ast::ArithFloat {
typ: ast::FloatType::F64,
rounding: rn,
flush_to_zero: false,
saturate: false,
},
<rn:".rn"?> <ftz:".ftz"?> <sat:".sat"?> ".f16" => ast::ArithFloat {
typ: ast::FloatType::F16,
rounding: rn.map(|_| ast::RoundingMode::NearestEven),
flush_to_zero: ftz.is_some(),
saturate: sat.is_some(),
},
<rn:".rn"?> <ftz:".ftz"?> <sat:".sat"?> ".f16x2" => ast::ArithFloat {
typ: ast::FloatType::F16x2,
rounding: rn.map(|_| ast::RoundingMode::NearestEven),
flush_to_zero: ftz.is_some(),
saturate: sat.is_some(),
},
}
Operand: ast::Operand<&'input str> = {
<r:ExtendedID> => ast::Operand::Reg(r),
<r:ExtendedID> "+" <o:Num> => {

View file

@ -0,0 +1,23 @@
.version 6.5
.target sm_30
.address_size 64
.visible .entry max(
.param .u64 input,
.param .u64 output
)
{
.reg .u64 in_addr;
.reg .u64 out_addr;
.reg .s32 temp1;
.reg .s32 temp2;
ld.param.u64 in_addr, [input];
ld.param.u64 out_addr, [output];
ld.s32 temp1, [in_addr];
ld.s32 temp2, [in_addr+4];
max.s32 temp1, temp1, temp2;
st.s32 [out_addr], temp1;
ret;
}

View file

@ -0,0 +1,57 @@
OpCapability GenericPointer
OpCapability Linkage
OpCapability Addresses
OpCapability Kernel
OpCapability Int8
OpCapability Int16
OpCapability Int64
OpCapability Float16
OpCapability Float64
%30 = OpExtInstImport "OpenCL.std"
OpMemoryModel Physical64 OpenCL
OpEntryPoint Kernel %1 "max"
%void = OpTypeVoid
%ulong = OpTypeInt 64 0
%33 = OpTypeFunction %void %ulong %ulong
%_ptr_Function_ulong = OpTypePointer Function %ulong
%uint = OpTypeInt 32 0
%_ptr_Function_uint = OpTypePointer Function %uint
%_ptr_Generic_uint = OpTypePointer Generic %uint
%ulong_4 = OpConstant %ulong 4
%1 = OpFunction %void None %33
%8 = OpFunctionParameter %ulong
%9 = OpFunctionParameter %ulong
%28 = OpLabel
%2 = OpVariable %_ptr_Function_ulong Function
%3 = OpVariable %_ptr_Function_ulong Function
%4 = OpVariable %_ptr_Function_ulong Function
%5 = OpVariable %_ptr_Function_ulong Function
%6 = OpVariable %_ptr_Function_uint Function
%7 = OpVariable %_ptr_Function_uint Function
OpStore %2 %8
OpStore %3 %9
%11 = OpLoad %ulong %2
%10 = OpCopyObject %ulong %11
OpStore %4 %10
%13 = OpLoad %ulong %3
%12 = OpCopyObject %ulong %13
OpStore %5 %12
%15 = OpLoad %ulong %4
%25 = OpConvertUToPtr %_ptr_Generic_uint %15
%14 = OpLoad %uint %25
OpStore %6 %14
%17 = OpLoad %ulong %4
%24 = OpIAdd %ulong %17 %ulong_4
%26 = OpConvertUToPtr %_ptr_Generic_uint %24
%16 = OpLoad %uint %26
OpStore %7 %16
%19 = OpLoad %uint %6
%20 = OpLoad %uint %7
%18 = OpExtInst %uint %30 s_max %19 %20
OpStore %6 %18
%21 = OpLoad %ulong %5
%22 = OpLoad %uint %6
%27 = OpConvertUToPtr %_ptr_Generic_uint %21
OpStore %27 %22
OpReturn
OpFunctionEnd

View file

@ -0,0 +1,23 @@
.version 6.5
.target sm_30
.address_size 64
.visible .entry min(
.param .u64 input,
.param .u64 output
)
{
.reg .u64 in_addr;
.reg .u64 out_addr;
.reg .s32 temp1;
.reg .s32 temp2;
ld.param.u64 in_addr, [input];
ld.param.u64 out_addr, [output];
ld.s32 temp1, [in_addr];
ld.s32 temp2, [in_addr+4];
min.s32 temp1, temp1, temp2;
st.s32 [out_addr], temp1;
ret;
}

View file

@ -0,0 +1,57 @@
OpCapability GenericPointer
OpCapability Linkage
OpCapability Addresses
OpCapability Kernel
OpCapability Int8
OpCapability Int16
OpCapability Int64
OpCapability Float16
OpCapability Float64
%30 = OpExtInstImport "OpenCL.std"
OpMemoryModel Physical64 OpenCL
OpEntryPoint Kernel %1 "min"
%void = OpTypeVoid
%ulong = OpTypeInt 64 0
%33 = OpTypeFunction %void %ulong %ulong
%_ptr_Function_ulong = OpTypePointer Function %ulong
%uint = OpTypeInt 32 0
%_ptr_Function_uint = OpTypePointer Function %uint
%_ptr_Generic_uint = OpTypePointer Generic %uint
%ulong_4 = OpConstant %ulong 4
%1 = OpFunction %void None %33
%8 = OpFunctionParameter %ulong
%9 = OpFunctionParameter %ulong
%28 = OpLabel
%2 = OpVariable %_ptr_Function_ulong Function
%3 = OpVariable %_ptr_Function_ulong Function
%4 = OpVariable %_ptr_Function_ulong Function
%5 = OpVariable %_ptr_Function_ulong Function
%6 = OpVariable %_ptr_Function_uint Function
%7 = OpVariable %_ptr_Function_uint Function
OpStore %2 %8
OpStore %3 %9
%11 = OpLoad %ulong %2
%10 = OpCopyObject %ulong %11
OpStore %4 %10
%13 = OpLoad %ulong %3
%12 = OpCopyObject %ulong %13
OpStore %5 %12
%15 = OpLoad %ulong %4
%25 = OpConvertUToPtr %_ptr_Generic_uint %15
%14 = OpLoad %uint %25
OpStore %6 %14
%17 = OpLoad %ulong %4
%24 = OpIAdd %ulong %17 %ulong_4
%26 = OpConvertUToPtr %_ptr_Generic_uint %24
%16 = OpLoad %uint %26
OpStore %7 %16
%19 = OpLoad %uint %6
%20 = OpLoad %uint %7
%18 = OpExtInst %uint %30 s_min %19 %20
OpStore %6 %18
%21 = OpLoad %ulong %5
%22 = OpLoad %uint %6
%27 = OpConvertUToPtr %_ptr_Generic_uint %21
OpStore %27 %22
OpReturn
OpFunctionEnd

View file

@ -70,6 +70,9 @@ test_ptx!(mul_wide, [0x01_00_00_00__01_00_00_00i64], [0x1_00_00_00_00_00_00i64])
test_ptx!(vector_extract, [1u8, 2u8, 3u8, 4u8], [3u8, 4u8, 1u8, 2u8]);
test_ptx!(shr, [-2i32], [-1i32]);
test_ptx!(or, [1u64, 2u64], [3u64]);
test_ptx!(sub, [2u64], [1u64]);
test_ptx!(min, [555i32, 444i32], [444i32]);
test_ptx!(max, [555i32, 444i32], [555i32]);
struct DisplayError<T: Debug> {

View file

@ -0,0 +1,23 @@
.version 6.5
.target sm_30
.address_size 64
.visible .entry or(
.param .u64 input,
.param .u64 output
)
{
.reg .u64 in_addr;
.reg .u64 out_addr;
.reg .u64 temp1;
.reg .u64 temp2;
ld.param.u64 in_addr, [input];
ld.param.u64 out_addr, [output];
ld.u64 temp1, [in_addr];
ld.u64 temp2, [in_addr+8];
or.b64 temp1, temp1, temp2;
st.u64 [out_addr], temp1;
ret;
}

View file

@ -0,0 +1,58 @@
OpCapability GenericPointer
OpCapability Linkage
OpCapability Addresses
OpCapability Kernel
OpCapability Int8
OpCapability Int16
OpCapability Int64
OpCapability Float16
OpCapability Float64
%33 = OpExtInstImport "OpenCL.std"
OpMemoryModel Physical64 OpenCL
OpEntryPoint Kernel %1 "or"
%void = OpTypeVoid
%ulong = OpTypeInt 64 0
%36 = OpTypeFunction %void %ulong %ulong
%_ptr_Function_ulong = OpTypePointer Function %ulong
%_ptr_Generic_ulong = OpTypePointer Generic %ulong
%ulong_8 = OpConstant %ulong 8
%1 = OpFunction %void None %36
%8 = OpFunctionParameter %ulong
%9 = OpFunctionParameter %ulong
%31 = OpLabel
%2 = OpVariable %_ptr_Function_ulong Function
%3 = OpVariable %_ptr_Function_ulong Function
%4 = OpVariable %_ptr_Function_ulong Function
%5 = OpVariable %_ptr_Function_ulong Function
%6 = OpVariable %_ptr_Function_ulong Function
%7 = OpVariable %_ptr_Function_ulong Function
OpStore %2 %8
OpStore %3 %9
%11 = OpLoad %ulong %2
%10 = OpCopyObject %ulong %11
OpStore %4 %10
%13 = OpLoad %ulong %3
%12 = OpCopyObject %ulong %13
OpStore %5 %12
%15 = OpLoad %ulong %4
%25 = OpConvertUToPtr %_ptr_Generic_ulong %15
%14 = OpLoad %ulong %25
OpStore %6 %14
%17 = OpLoad %ulong %4
%24 = OpIAdd %ulong %17 %ulong_8
%26 = OpConvertUToPtr %_ptr_Generic_ulong %24
%16 = OpLoad %ulong %26
OpStore %7 %16
%19 = OpLoad %ulong %6
%20 = OpLoad %ulong %7
%28 = OpCopyObject %ulong %19
%29 = OpCopyObject %ulong %20
%27 = OpBitwiseOr %ulong %28 %29
%18 = OpCopyObject %ulong %27
OpStore %6 %18
%21 = OpLoad %ulong %5
%22 = OpLoad %ulong %6
%30 = OpConvertUToPtr %_ptr_Generic_ulong %21
OpStore %30 %22
OpReturn
OpFunctionEnd

View file

@ -0,0 +1,22 @@
.version 6.5
.target sm_30
.address_size 64
.visible .entry sub(
.param .u64 input,
.param .u64 output
)
{
.reg .u64 in_addr;
.reg .u64 out_addr;
.reg .u64 temp;
.reg .u64 temp2;
ld.param.u64 in_addr, [input];
ld.param.u64 out_addr, [output];
ld.u64 temp, [in_addr];
sub.u64 temp2, temp, 1;
st.u64 [out_addr], temp2;
ret;
}

View file

@ -0,0 +1,49 @@
OpCapability GenericPointer
OpCapability Linkage
OpCapability Addresses
OpCapability Kernel
OpCapability Int8
OpCapability Int16
OpCapability Int64
OpCapability Float16
OpCapability Float64
%25 = OpExtInstImport "OpenCL.std"
OpMemoryModel Physical64 OpenCL
OpEntryPoint Kernel %1 "sub"
%void = OpTypeVoid
%ulong = OpTypeInt 64 0
%28 = OpTypeFunction %void %ulong %ulong
%_ptr_Function_ulong = OpTypePointer Function %ulong
%_ptr_Generic_ulong = OpTypePointer Generic %ulong
%ulong_1 = OpConstant %ulong 1
%1 = OpFunction %void None %28
%8 = OpFunctionParameter %ulong
%9 = OpFunctionParameter %ulong
%23 = OpLabel
%2 = OpVariable %_ptr_Function_ulong Function
%3 = OpVariable %_ptr_Function_ulong Function
%4 = OpVariable %_ptr_Function_ulong Function
%5 = OpVariable %_ptr_Function_ulong Function
%6 = OpVariable %_ptr_Function_ulong Function
%7 = OpVariable %_ptr_Function_ulong Function
OpStore %2 %8
OpStore %3 %9
%11 = OpLoad %ulong %2
%10 = OpCopyObject %ulong %11
OpStore %4 %10
%13 = OpLoad %ulong %3
%12 = OpCopyObject %ulong %13
OpStore %5 %12
%15 = OpLoad %ulong %4
%21 = OpConvertUToPtr %_ptr_Generic_ulong %15
%14 = OpLoad %ulong %21
OpStore %6 %14
%17 = OpLoad %ulong %6
%16 = OpISub %ulong %17 %ulong_1
OpStore %7 %16
%18 = OpLoad %ulong %5
%19 = OpLoad %ulong %7
%22 = OpConvertUToPtr %_ptr_Generic_ulong %18
OpStore %22 %19
OpReturn
OpFunctionEnd

View file

@ -595,6 +595,15 @@ fn convert_to_typed_statements(
ast::Instruction::Or(d, a) => {
result.push(Statement::Instruction(ast::Instruction::Or(d, a.cast())))
}
ast::Instruction::Sub(d, a) => {
result.push(Statement::Instruction(ast::Instruction::Sub(d, a.cast())))
}
ast::Instruction::Min(d, a) => {
result.push(Statement::Instruction(ast::Instruction::Min(d, a.cast())))
}
ast::Instruction::Max(d, a) => {
result.push(Statement::Instruction(ast::Instruction::Max(d, a.cast())))
}
},
Statement::Label(i) => result.push(Statement::Label(i)),
Statement::Variable(v) => result.push(Statement::Variable(v)),
@ -968,62 +977,74 @@ impl<'a, 'b> FlattenArguments<'a, 'b> {
fn reg_offset(
&mut self,
desc: ArgumentDescriptor<(spirv::Word, i32)>,
typ: ast::Type,
mut typ: ast::Type,
) -> Result<spirv::Word, TranslateError> {
let (reg, offset) = desc.op;
match desc.sema {
ArgumentSemantics::Default | ArgumentSemantics::DefaultRelaxed => {
let scalar_t = if let ast::Type::Scalar(scalar) = typ {
scalar
} else {
todo!()
ArgumentSemantics::Default
| ArgumentSemantics::DefaultRelaxed
| ArgumentSemantics::PhysicalPointer => {
if desc.sema == ArgumentSemantics::PhysicalPointer {
typ = ast::Type::Scalar(ast::ScalarType::U64);
}
let (width, kind) = match typ {
ast::Type::Scalar(scalar_t) => {
let kind = match scalar_t.kind() {
kind @ ScalarKind::Bit
| kind @ ScalarKind::Unsigned
| kind @ ScalarKind::Signed => kind,
ScalarKind::Float => return Err(TranslateError::MismatchedType),
ScalarKind::Float2 => return Err(TranslateError::MismatchedType),
ScalarKind::Pred => return Err(TranslateError::MismatchedType),
};
(scalar_t.width(), kind)
}
_ => return Err(TranslateError::MismatchedType),
};
let id_constant_stmt = self.id_def.new_id(ast::Type::Scalar(scalar_t));
let arith_detail = if kind == ScalarKind::Signed {
ast::ArithDetails::Signed(ast::ArithSInt {
typ: ast::SIntType::from_size(width),
saturate: false,
})
} else {
ast::ArithDetails::Unsigned(ast::UIntType::from_size(width))
};
let id_constant_stmt = self.id_def.new_id(typ);
let result_id = self.id_def.new_id(typ);
self.func.push(Statement::Constant(ConstantDefinition {
dst: id_constant_stmt,
typ: scalar_t,
value: offset as i64,
}));
let int_type = ast::IntType::try_new(scalar_t).unwrap_or_else(|| todo!());
self.func.push(Statement::Instruction(
ast::Instruction::<ExpandedArgParams>::Add(
ast::AddDetails::Int(ast::AddIntDesc {
typ: int_type,
saturate: false,
}),
ast::Arg3 {
dst: result_id,
src1: reg,
src2: id_constant_stmt,
},
),
));
Ok(result_id)
}
ArgumentSemantics::PhysicalPointer => {
let scalar_t = ast::ScalarType::U64;
let id_constant_stmt = self.id_def.new_id(ast::Type::Scalar(scalar_t));
let result_id = self.id_def.new_id(ast::Type::Scalar(scalar_t));
self.func.push(Statement::Constant(ConstantDefinition {
dst: id_constant_stmt,
typ: scalar_t,
value: offset as i64,
}));
let int_type = ast::IntType::U64;
self.func.push(Statement::Instruction(
ast::Instruction::<ExpandedArgParams>::Add(
ast::AddDetails::Int(ast::AddIntDesc {
typ: int_type,
saturate: false,
}),
ast::Arg3 {
dst: result_id,
src1: reg,
src2: id_constant_stmt,
},
),
));
// TODO: check for edge cases around min value/max value/wrapping
if offset < 0 && kind != ScalarKind::Signed {
self.func.push(Statement::Constant(ConstantDefinition {
dst: id_constant_stmt,
typ: ast::ScalarType::from_parts(width, kind),
value: -(offset as i64),
}));
self.func.push(Statement::Instruction(
ast::Instruction::<ExpandedArgParams>::Sub(
arith_detail,
ast::Arg3 {
dst: result_id,
src1: reg,
src2: id_constant_stmt,
},
),
));
} else {
self.func.push(Statement::Constant(ConstantDefinition {
dst: id_constant_stmt,
typ: ast::ScalarType::from_parts(width, kind),
value: offset as i64,
}));
self.func.push(Statement::Instruction(
ast::Instruction::<ExpandedArgParams>::Add(
arith_detail,
ast::Arg3 {
dst: result_id,
src1: reg,
src2: id_constant_stmt,
},
),
));
}
Ok(result_id)
}
ArgumentSemantics::RegisterPointer => {
@ -1522,14 +1543,22 @@ fn emit_function_body_ops(
}
},
ast::Instruction::Mul(mul, arg) => match mul {
ast::MulDetails::Int(ref ctr) => {
emit_mul_int(builder, map, opencl, ctr, arg)?;
ast::MulDetails::Signed(ref ctr) => {
emit_mul_sint(builder, map, opencl, ctr, arg)?
}
ast::MulDetails::Unsigned(ref ctr) => {
emit_mul_uint(builder, map, opencl, ctr, arg)?
}
ast::MulDetails::Float(_) => todo!(),
},
ast::Instruction::Add(add, arg) => match add {
ast::AddDetails::Int(ref desc) => emit_add_int(builder, map, desc, arg)?,
ast::AddDetails::Float(desc) => emit_add_float(builder, map, desc, arg)?,
ast::ArithDetails::Signed(ref desc) => {
emit_add_int(builder, map, desc.typ.into(), desc.saturate, arg)?
}
ast::ArithDetails::Unsigned(ref desc) => {
emit_add_int(builder, map, (*desc).into(), false, arg)?
}
ast::ArithDetails::Float(desc) => emit_add_float(builder, map, desc, arg)?,
},
ast::Instruction::Setp(setp, arg) => {
if arg.dst2.is_some() {
@ -1581,8 +1610,11 @@ fn emit_function_body_ops(
}
ast::Instruction::SetpBool(_, _) => todo!(),
ast::Instruction::Mad(mad, arg) => match mad {
ast::MulDetails::Int(ref desc) => {
emit_mad_int(builder, map, opencl, desc, arg)?
ast::MulDetails::Signed(ref desc) => {
emit_mad_sint(builder, map, opencl, desc, arg)?
}
ast::MulDetails::Unsigned(ref desc) => {
emit_mad_uint(builder, map, opencl, desc, arg)?
}
ast::MulDetails::Float(desc) => emit_mad_float(builder, map, desc, arg)?,
},
@ -1594,6 +1626,23 @@ fn emit_function_body_ops(
builder.bitwise_or(result_type, Some(a.dst), a.src1, a.src2)?;
}
}
ast::Instruction::Sub(d, arg) => match d {
ast::ArithDetails::Signed(desc) => {
emit_sub_int(builder, map, desc.typ.into(), desc.saturate, arg)?;
}
ast::ArithDetails::Unsigned(desc) => {
emit_sub_int(builder, map, (*desc).into(), false, arg)?;
}
ast::ArithDetails::Float(desc) => {
emit_sub_float(builder, map, desc, arg)?;
}
},
ast::Instruction::Min(d, a) => {
emit_min(builder, map, opencl, d, a)?;
}
ast::Instruction::Max(d, a) => {
emit_max(builder, map, opencl, d, a)?;
}
},
Statement::LoadVar(arg, typ) => {
let type_id = map.get_or_add(builder, SpirvType::from(*typ));
@ -1624,11 +1673,11 @@ fn emit_function_body_ops(
Ok(())
}
fn emit_mad_int(
fn emit_mad_uint(
builder: &mut dr::Builder,
map: &mut TypeWordMap,
opencl: spirv::Word,
desc: &ast::MulIntDesc,
desc: &ast::MulUInt,
arg: &ast::Arg4<ExpandedArgParams>,
) -> Result<(), dr::Error> {
let inst_type = map.get_or_add(builder, SpirvType::from(ast::ScalarType::from(desc.typ)));
@ -1638,16 +1687,38 @@ fn emit_mad_int(
builder.i_add(inst_type, Some(arg.dst), arg.src3, mul_result)?;
}
ast::MulIntControl::High => {
let cl_op = if desc.typ.is_signed() {
spirv::CLOp::s_mad_hi
} else {
spirv::CLOp::u_mad_hi
};
builder.ext_inst(
inst_type,
Some(arg.dst),
opencl,
cl_op as spirv::Word,
spirv::CLOp::u_mad_hi as spirv::Word,
[arg.src1, arg.src2, arg.src3],
)?;
}
ast::MulIntControl::Wide => todo!(),
};
Ok(())
}
fn emit_mad_sint(
builder: &mut dr::Builder,
map: &mut TypeWordMap,
opencl: spirv::Word,
desc: &ast::MulSInt,
arg: &ast::Arg4<ExpandedArgParams>,
) -> Result<(), dr::Error> {
let inst_type = map.get_or_add(builder, SpirvType::from(ast::ScalarType::from(desc.typ)));
match desc.control {
ast::MulIntControl::Low => {
let mul_result = builder.i_mul(inst_type, None, arg.src1, arg.src2)?;
builder.i_add(inst_type, Some(arg.dst), arg.src3, mul_result)?;
}
ast::MulIntControl::High => {
builder.ext_inst(
inst_type,
Some(arg.dst),
opencl,
spirv::CLOp::s_mad_hi as spirv::Word,
[arg.src1, arg.src2, arg.src3],
)?;
}
@ -1659,7 +1730,7 @@ fn emit_mad_int(
fn emit_mad_float(
builder: &mut dr::Builder,
map: &mut TypeWordMap,
desc: &ast::MulFloatDesc,
desc: &ast::ArithFloat,
arg: &ast::Arg4<ExpandedArgParams>,
) -> Result<(), dr::Error> {
todo!()
@ -1668,7 +1739,7 @@ fn emit_mad_float(
fn emit_add_float(
builder: &mut dr::Builder,
map: &mut TypeWordMap,
desc: &ast::AddFloatDesc,
desc: &ast::ArithFloat,
arg: &ast::Arg3<ExpandedArgParams>,
) -> Result<(), dr::Error> {
if desc.flush_to_zero {
@ -1680,6 +1751,67 @@ fn emit_add_float(
Ok(())
}
fn emit_sub_float(
builder: &mut dr::Builder,
map: &mut TypeWordMap,
desc: &ast::ArithFloat,
arg: &ast::Arg3<ExpandedArgParams>,
) -> Result<(), dr::Error> {
if desc.flush_to_zero {
todo!()
}
let inst_type = map.get_or_add(builder, SpirvType::from(ast::ScalarType::from(desc.typ)));
builder.f_sub(inst_type, Some(arg.dst), arg.src1, arg.src2)?;
emit_rounding_decoration(builder, arg.dst, desc.rounding);
Ok(())
}
fn emit_min(
builder: &mut dr::Builder,
map: &mut TypeWordMap,
opencl: spirv::Word,
desc: &ast::MinMaxDetails,
arg: &ast::Arg3<ExpandedArgParams>,
) -> Result<(), dr::Error> {
let cl_op = match desc {
ast::MinMaxDetails::Signed(_) => spirv::CLOp::s_min,
ast::MinMaxDetails::Unsigned(_) => spirv::CLOp::u_min,
ast::MinMaxDetails::Float(_) => spirv::CLOp::fmin,
};
let inst_type = map.get_or_add(builder, SpirvType::from(desc.get_type()));
builder.ext_inst(
inst_type,
Some(arg.dst),
opencl,
cl_op as spirv::Word,
[arg.src1, arg.src2],
)?;
Ok(())
}
fn emit_max(
builder: &mut dr::Builder,
map: &mut TypeWordMap,
opencl: spirv::Word,
desc: &ast::MinMaxDetails,
arg: &ast::Arg3<ExpandedArgParams>,
) -> Result<(), dr::Error> {
let cl_op = match desc {
ast::MinMaxDetails::Signed(_) => spirv::CLOp::s_max,
ast::MinMaxDetails::Unsigned(_) => spirv::CLOp::u_max,
ast::MinMaxDetails::Float(_) => spirv::CLOp::fmax,
};
let inst_type = map.get_or_add(builder, SpirvType::from(desc.get_type()));
builder.ext_inst(
inst_type,
Some(arg.dst),
opencl,
cl_op as spirv::Word,
[arg.src1, arg.src2],
)?;
Ok(())
}
fn emit_cvt(
builder: &mut dr::Builder,
map: &mut TypeWordMap,
@ -1880,11 +2012,11 @@ fn emit_setp(
Ok(())
}
fn emit_mul_int(
fn emit_mul_sint(
builder: &mut dr::Builder,
map: &mut TypeWordMap,
opencl: spirv::Word,
desc: &ast::MulIntDesc,
desc: &ast::MulSInt,
arg: &ast::Arg3<ExpandedArgParams>,
) -> Result<(), dr::Error> {
let instruction_type = ast::ScalarType::from(desc.typ);
@ -1894,16 +2026,11 @@ fn emit_mul_int(
builder.i_mul(inst_type, Some(arg.dst), arg.src1, arg.src2)?;
}
ast::MulIntControl::High => {
let ocl_mul_hi = if desc.typ.is_signed() {
spirv::CLOp::s_mul_hi
} else {
spirv::CLOp::u_mul_hi
};
builder.ext_inst(
inst_type,
Some(arg.dst),
opencl,
ocl_mul_hi as spirv::Word,
spirv::CLOp::s_mul_hi as spirv::Word,
[arg.src1, arg.src2],
)?;
}
@ -1913,11 +2040,54 @@ fn emit_mul_int(
SpirvScalarKey::from(instruction_type),
]);
let mul_ext_type_id = map.get_or_add(builder, mul_ext_type);
let mul = if desc.typ.is_signed() {
builder.s_mul_extended(mul_ext_type_id, None, arg.src1, arg.src2)?
} else {
builder.u_mul_extended(mul_ext_type_id, None, arg.src1, arg.src2)?
};
let mul = builder.s_mul_extended(mul_ext_type_id, None, arg.src1, arg.src2)?;
let instr_width = instruction_type.width();
let instr_kind = instruction_type.kind();
let dst_type = ast::ScalarType::from_parts(instr_width * 2, instr_kind);
let dst_type_id = map.get_or_add_scalar(builder, dst_type);
struct2_bitcast_to_wide(
builder,
map,
SpirvScalarKey::from(instruction_type),
inst_type,
arg.dst,
dst_type_id,
mul,
)?;
}
}
Ok(())
}
fn emit_mul_uint(
builder: &mut dr::Builder,
map: &mut TypeWordMap,
opencl: spirv::Word,
desc: &ast::MulUInt,
arg: &ast::Arg3<ExpandedArgParams>,
) -> Result<(), dr::Error> {
let instruction_type = ast::ScalarType::from(desc.typ);
let inst_type = map.get_or_add(builder, SpirvType::from(ast::ScalarType::from(desc.typ)));
match desc.control {
ast::MulIntControl::Low => {
builder.i_mul(inst_type, Some(arg.dst), arg.src1, arg.src2)?;
}
ast::MulIntControl::High => {
builder.ext_inst(
inst_type,
Some(arg.dst),
opencl,
spirv::CLOp::u_mul_hi as spirv::Word,
[arg.src1, arg.src2],
)?;
}
ast::MulIntControl::Wide => {
let mul_ext_type = SpirvType::Struct(vec![
SpirvScalarKey::from(instruction_type),
SpirvScalarKey::from(instruction_type),
]);
let mul_ext_type_id = map.get_or_add(builder, mul_ext_type);
let mul = builder.u_mul_extended(mul_ext_type_id, None, arg.src1, arg.src2)?;
let instr_width = instruction_type.width();
let instr_kind = instruction_type.kind();
let dst_type = ast::ScalarType::from_parts(instr_width * 2, instr_kind);
@ -1981,14 +2151,33 @@ fn emit_abs(
fn emit_add_int(
builder: &mut dr::Builder,
map: &mut TypeWordMap,
ctr: &ast::AddIntDesc,
typ: ast::ScalarType,
saturate: bool,
arg: &ast::Arg3<ExpandedArgParams>,
) -> Result<(), dr::Error> {
let inst_type = map.get_or_add(builder, SpirvType::from(ast::ScalarType::from(ctr.typ)));
if saturate {
todo!()
}
let inst_type = map.get_or_add(builder, SpirvType::from(ast::ScalarType::from(typ)));
builder.i_add(inst_type, Some(arg.dst), arg.src1, arg.src2)?;
Ok(())
}
fn emit_sub_int(
builder: &mut dr::Builder,
map: &mut TypeWordMap,
typ: ast::ScalarType,
saturate: bool,
arg: &ast::Arg3<ExpandedArgParams>,
) -> Result<(), dr::Error> {
if saturate {
todo!()
}
let inst_type = map.get_or_add(builder, SpirvType::from(ast::ScalarType::from(typ)));
builder.i_sub(inst_type, Some(arg.dst), arg.src1, arg.src2)?;
Ok(())
}
fn emit_implicit_conversion(
builder: &mut dr::Builder,
map: &mut TypeWordMap,
@ -2920,6 +3109,18 @@ impl<T: ArgParamsEx> ast::Instruction<T> {
t,
a.map_non_shift(visitor, ast::Type::Scalar(t.into()), false)?,
),
ast::Instruction::Sub(d, a) => {
let typ = d.get_type();
ast::Instruction::Sub(d, a.map_non_shift(visitor, typ, false)?)
}
ast::Instruction::Min(d, a) => {
let typ = d.get_type();
ast::Instruction::Min(d, a.map_non_shift(visitor, typ, false)?)
}
ast::Instruction::Max(d, a) => {
let typ = d.get_type();
ast::Instruction::Max(d, a.map_non_shift(visitor, typ, false)?)
}
})
}
}
@ -3129,6 +3330,9 @@ impl ast::Instruction<ExpandedArgParams> {
| ast::Instruction::Abs(_, _)
| ast::Instruction::Call(_)
| ast::Instruction::Or(_, _)
| ast::Instruction::Sub(_, _)
| ast::Instruction::Min(_, _)
| ast::Instruction::Max(_, _)
| ast::Instruction::Mad(_, _) => None,
}
}
@ -4049,25 +4253,33 @@ impl ast::ShrType {
}
}
impl ast::AddDetails {
impl ast::ArithDetails {
fn get_type(&self) -> ast::Type {
match self {
ast::AddDetails::Int(ast::AddIntDesc { typ, .. }) => ast::Type::Scalar((*typ).into()),
ast::AddDetails::Float(ast::AddFloatDesc { typ, .. }) => {
ast::Type::Scalar((*typ).into())
}
}
ast::Type::Scalar(match self {
ast::ArithDetails::Unsigned(t) => (*t).into(),
ast::ArithDetails::Signed(d) => d.typ.into(),
ast::ArithDetails::Float(d) => d.typ.into(),
})
}
}
impl ast::MulDetails {
fn get_type(&self) -> ast::Type {
match self {
ast::MulDetails::Int(ast::MulIntDesc { typ, .. }) => ast::Type::Scalar((*typ).into()),
ast::MulDetails::Float(ast::MulFloatDesc { typ, .. }) => {
ast::Type::Scalar((*typ).into())
}
}
ast::Type::Scalar(match self {
ast::MulDetails::Unsigned(d) => d.typ.into(),
ast::MulDetails::Signed(d) => d.typ.into(),
ast::MulDetails::Float(d) => d.typ.into(),
})
}
}
impl ast::MinMaxDetails {
fn get_type(&self) -> ast::Type {
ast::Type::Scalar(match self {
ast::MinMaxDetails::Signed(t) => (*t).into(),
ast::MinMaxDetails::Unsigned(t) => (*t).into(),
ast::MinMaxDetails::Float(d) => d.typ.into(),
})
}
}
@ -4085,6 +4297,30 @@ impl ast::IntType {
}
}
impl ast::SIntType {
fn from_size(width: u8) -> Self {
match width {
1 => ast::SIntType::S8,
2 => ast::SIntType::S16,
4 => ast::SIntType::S32,
8 => ast::SIntType::S64,
_ => unreachable!(),
}
}
}
impl ast::UIntType {
fn from_size(width: u8) -> Self {
match width {
1 => ast::UIntType::U8,
2 => ast::UIntType::U16,
4 => ast::UIntType::U32,
8 => ast::UIntType::U64,
_ => unreachable!(),
}
}
}
impl ast::LdStateSpace {
fn to_spirv(self) -> spirv::StorageClass {
match self {
@ -4128,7 +4364,8 @@ impl<T> ast::OperandOrVector<T> {
impl ast::MulDetails {
fn is_wide(&self) -> bool {
match self {
ast::MulDetails::Int(desc) => desc.control == ast::MulIntControl::Wide,
ast::MulDetails::Unsigned(d) => d.control == ast::MulIntControl::Wide,
ast::MulDetails::Signed(d) => d.control == ast::MulIntControl::Wide,
ast::MulDetails::Float(_) => false,
}
}