mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-04-28 13:37:57 +03:00
Implement div, sqrt, rsqrt and more of setp
This commit is contained in:
parent
a82eb20817
commit
b7d61baf37
12 changed files with 645 additions and 44 deletions
|
@ -539,6 +539,9 @@ pub enum Instruction<P: ArgParams> {
|
|||
Bar(BarDetails, Arg1Bar<P>),
|
||||
Atom(AtomDetails, Arg3<P>),
|
||||
AtomCas(AtomCasDetails, Arg4<P>),
|
||||
Div(DivDetails, Arg3<P>),
|
||||
Sqrt(SqrtDetails, Arg2<P>),
|
||||
Rsqrt(RsqrtDetails, Arg2<P>),
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone)]
|
||||
|
@ -1132,7 +1135,28 @@ pub struct AtomCasDetails {
|
|||
pub semantics: AtomSemantics,
|
||||
pub scope: MemScope,
|
||||
pub space: AtomSpace,
|
||||
pub typ: BitType
|
||||
pub typ: BitType,
|
||||
}
|
||||
|
||||
/// Type-dependent details attached to a `div` instruction.
#[derive(Copy, Clone)]
|
||||
pub enum DivDetails {
|
||||
/// Division on an unsigned integer type.
Unsigned(UIntType),
|
||||
/// Division on a signed integer type.
Signed(SIntType),
|
||||
/// Floating-point division together with its modifiers.
Float(DivFloatDetails),
|
||||
}
|
||||
|
||||
/// Modifiers carried by a floating-point `div`.
#[derive(Copy, Clone)]
|
||||
pub struct DivFloatDetails {
|
||||
/// Floating-point type the division operates on.
pub typ: FloatType,
|
||||
/// `Some(true)`/`Some(false)` records whether `.ftz` was written; the grammar
/// stores `None` for the `.f64` form, which takes no `.ftz` modifier there.
pub flush_to_zero: Option<bool>,
|
||||
/// Precision variant requested by the instruction.
pub kind: DivFloatKind,
|
||||
}
|
||||
|
||||
/// Precision variant of a floating-point `div`.
#[derive(Copy, Clone, Eq, PartialEq)]
|
||||
pub enum DivFloatKind {
|
||||
/// Produced by the grammar for the `.approx` modifier.
Approx,
|
||||
/// Produced by the grammar for the `.full` modifier.
Full,
|
||||
/// An explicit rounding mode was given instead of `.approx`/`.full`.
Rounding(RoundingMode),
|
||||
}
|
||||
|
||||
pub enum NumsOrArrays<'a> {
|
||||
|
@ -1140,6 +1164,25 @@ pub enum NumsOrArrays<'a> {
|
|||
Arrays(Vec<NumsOrArrays<'a>>),
|
||||
}
|
||||
|
||||
/// Modifiers carried by a `sqrt` instruction.
#[derive(Copy, Clone)]
|
||||
pub struct SqrtDetails {
|
||||
/// Floating-point type of the operand and the result.
pub typ: FloatType,
|
||||
/// `Some(..)` records whether `.ftz` was written on an `.f32` form; the
/// grammar stores `None` for the `.f64` form.
pub flush_to_zero: Option<bool>,
|
||||
/// Approximate or explicitly rounded square root.
pub kind: SqrtKind,
|
||||
}
|
||||
|
||||
/// Precision variant of a `sqrt` instruction.
#[derive(Copy, Clone, Eq, PartialEq)]
|
||||
pub enum SqrtKind {
|
||||
/// Produced by the grammar for `sqrt.approx`.
Approx,
|
||||
/// `sqrt` with an explicit rounding mode.
Rounding(RoundingMode),
|
||||
}
|
||||
|
||||
/// Modifiers carried by an `rsqrt` instruction.
#[derive(Copy, Clone, Eq, PartialEq)]
|
||||
pub struct RsqrtDetails {
|
||||
/// Floating-point type of the operand and the result.
pub typ: FloatType,
|
||||
/// `true` when the instruction was written with `.ftz`.
pub flush_to_zero: bool,
|
||||
}
|
||||
|
||||
impl<'a> NumsOrArrays<'a> {
|
||||
pub fn to_vec(self, typ: SizedScalarType, dimensions: &mut [u32]) -> Result<Vec<u8>, PtxError> {
|
||||
self.normalize_dimensions(dimensions)?;
|
||||
|
|
|
@ -66,6 +66,7 @@ match {
|
|||
".f64",
|
||||
".file",
|
||||
".ftz",
|
||||
".full",
|
||||
".func",
|
||||
".ge",
|
||||
".geu",
|
||||
|
@ -94,6 +95,7 @@ match {
|
|||
".num",
|
||||
".or",
|
||||
".param",
|
||||
".pragma",
|
||||
".pred",
|
||||
".reg",
|
||||
".relaxed",
|
||||
|
@ -145,6 +147,7 @@ match {
|
|||
"cvt",
|
||||
"cvta",
|
||||
"debug",
|
||||
"div",
|
||||
"fma",
|
||||
"ld",
|
||||
"mad",
|
||||
|
@ -157,11 +160,13 @@ match {
|
|||
"or",
|
||||
"rcp",
|
||||
"ret",
|
||||
"rsqrt",
|
||||
"selp",
|
||||
"setp",
|
||||
"shl",
|
||||
"shr",
|
||||
r"sm_[0-9]+" => ShaderModel,
|
||||
"sqrt",
|
||||
"st",
|
||||
"sub",
|
||||
"texmode_independent",
|
||||
|
@ -184,6 +189,7 @@ ExtendedID : &'input str = {
|
|||
"cvt",
|
||||
"cvta",
|
||||
"debug",
|
||||
"div",
|
||||
"fma",
|
||||
"ld",
|
||||
"mad",
|
||||
|
@ -196,11 +202,13 @@ ExtendedID : &'input str = {
|
|||
"or",
|
||||
"rcp",
|
||||
"ret",
|
||||
"rsqrt",
|
||||
"selp",
|
||||
"setp",
|
||||
"shl",
|
||||
"shr",
|
||||
ShaderModel,
|
||||
"sqrt",
|
||||
"st",
|
||||
"sub",
|
||||
"texmode_independent",
|
||||
|
@ -415,9 +423,14 @@ Statement: Option<ast::Statement<ast::ParsedArgParams<'input>>> = {
|
|||
DebugDirective => None,
|
||||
<v:MultiVariable> ";" => Some(ast::Statement::Variable(v)),
|
||||
<p:PredAt?> <i:Instruction> ";" => Some(ast::Statement::Instruction(p, i)),
|
||||
PragmaStatement => None,
|
||||
"{" <s:Statement*> "}" => Some(ast::Statement::Block(without_none(s)))
|
||||
};
|
||||
|
||||
// A `.pragma` directive: the keyword, a string and the terminating `;`.
// The rule carries no value (`()`).
PragmaStatement: () = {
|
||||
".pragma" String ";"
|
||||
}
|
||||
|
||||
// Debug directives accepted inside a statement list; currently only
// `DebugLocation`. The rule carries no value (`()`).
DebugDirective: () = {
|
||||
DebugLocation
|
||||
};
|
||||
|
@ -667,7 +680,10 @@ Instruction: ast::Instruction<ast::ParsedArgParams<'input>> = {
|
|||
InstSelp,
|
||||
InstBar,
|
||||
InstAtom,
|
||||
InstAtomCas
|
||||
InstAtomCas,
|
||||
InstDiv,
|
||||
InstSqrt,
|
||||
InstRsqrt,
|
||||
};
|
||||
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-ld
|
||||
|
@ -1485,6 +1501,82 @@ AtomSIntType: ast::SIntType = {
|
|||
".s64" => ast::SIntType::S64,
|
||||
}
|
||||
|
||||
// `div` in its integer and floating-point forms.
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-div
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-div
InstDiv: ast::Instruction<ast::ParsedArgParams<'input>> = {
    // Integer division, unsigned and signed.
    "div" <typ:UIntType> <args:Arg3> =>
        ast::Instruction::Div(ast::DivDetails::Unsigned(typ), args),
    "div" <typ:SIntType> <args:Arg3> =>
        ast::Instruction::Div(ast::DivDetails::Signed(typ), args),
    // `.f32`: any precision kind, `.ftz` optional and always recorded.
    "div" <kind:DivFloatKind> <ftz:".ftz"?> ".f32" <args:Arg3> =>
        ast::Instruction::Div(
            ast::DivDetails::Float(ast::DivFloatDetails {
                typ: ast::FloatType::F32,
                flush_to_zero: Some(ftz.is_some()),
                kind,
            }),
            args,
        ),
    // `.f64`: only an explicit rounding mode, no `.ftz`.
    "div" <mode:RoundingModeFloat> ".f64" <args:Arg3> =>
        ast::Instruction::Div(
            ast::DivDetails::Float(ast::DivFloatDetails {
                typ: ast::FloatType::F64,
                flush_to_zero: None,
                kind: ast::DivFloatKind::Rounding(mode),
            }),
            args,
        ),
}
|
||||
|
||||
// Precision modifier of a floating-point `div`: `.approx`, `.full` or an
// explicit rounding mode.
DivFloatKind: ast::DivFloatKind = {
    ".approx" => ast::DivFloatKind::Approx,
    ".full" => ast::DivFloatKind::Full,
    RoundingModeFloat => ast::DivFloatKind::Rounding(<>),
}
|
||||
|
||||
// `sqrt`: approximate `.f32`, rounded `.f32` and rounded `.f64`.
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-sqrt
InstSqrt: ast::Instruction<ast::ParsedArgParams<'input>> = {
    "sqrt" ".approx" <ftz:".ftz"?> ".f32" <args:Arg2> =>
        ast::Instruction::Sqrt(
            ast::SqrtDetails {
                typ: ast::FloatType::F32,
                flush_to_zero: Some(ftz.is_some()),
                kind: ast::SqrtKind::Approx,
            },
            args,
        ),
    "sqrt" <mode:RoundingModeFloat> <ftz:".ftz"?> ".f32" <args:Arg2> =>
        ast::Instruction::Sqrt(
            ast::SqrtDetails {
                typ: ast::FloatType::F32,
                flush_to_zero: Some(ftz.is_some()),
                kind: ast::SqrtKind::Rounding(mode),
            },
            args,
        ),
    // The `.f64` form takes no `.ftz`, hence `None`.
    "sqrt" <mode:RoundingModeFloat> ".f64" <args:Arg2> =>
        ast::Instruction::Sqrt(
            ast::SqrtDetails {
                typ: ast::FloatType::F64,
                flush_to_zero: None,
                kind: ast::SqrtKind::Rounding(mode),
            },
            args,
        ),
}
|
||||
|
||||
// `rsqrt.approx` on `.f32` and `.f64`, each with an optional `.ftz`.
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-rsqrt
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-rsqrt-approx-ftz-f64
InstRsqrt: ast::Instruction<ast::ParsedArgParams<'input>> = {
    "rsqrt" ".approx" <ftz:".ftz"?> ".f32" <args:Arg2> =>
        ast::Instruction::Rsqrt(
            ast::RsqrtDetails {
                typ: ast::FloatType::F32,
                flush_to_zero: ftz.is_some(),
            },
            args,
        ),
    "rsqrt" ".approx" <ftz:".ftz"?> ".f64" <args:Arg2> =>
        ast::Instruction::Rsqrt(
            ast::RsqrtDetails {
                typ: ast::FloatType::F64,
                flush_to_zero: ftz.is_some(),
            },
            args,
        ),
}
|
||||
|
||||
ArithDetails: ast::ArithDetails = {
|
||||
<t:UIntType> => ast::ArithDetails::Unsigned(t),
|
||||
<t:SIntType> => ast::ArithDetails::Signed(ast::ArithSInt {
|
||||
|
|
23
ptx/src/test/spirv_run/div_approx.ptx
Normal file
23
ptx/src/test/spirv_run/div_approx.ptx
Normal file
|
@ -0,0 +1,23 @@
|
|||
// Test kernel for div.approx.f32: reads two consecutive f32 values from
// `input`, divides the first by the second and stores the quotient to `output`.
.version 6.5
|
||||
.target sm_30
|
||||
.address_size 64
|
||||
|
||||
.visible .entry div_approx(
|
||||
.param .u64 input,
|
||||
.param .u64 output
|
||||
)
|
||||
{
|
||||
.reg .u64 in_addr;
|
||||
.reg .u64 out_addr;
|
||||
.reg .f32 temp1;
|
||||
.reg .f32 temp2;
|
||||
|
||||
// Fetch the two pointer parameters.
ld.param.u64 in_addr, [input];
|
||||
ld.param.u64 out_addr, [output];
|
||||
|
||||
// Dividend at offset 0, divisor at offset 4.
ld.f32 temp1, [in_addr];
|
||||
ld.f32 temp2, [in_addr+4];
|
||||
// Instruction under test; the result overwrites temp1.
div.approx.f32 temp1, temp1, temp2;
|
||||
st.f32 [out_addr], temp1;
|
||||
ret;
|
||||
}
|
65
ptx/src/test/spirv_run/div_approx.spvtxt
Normal file
65
ptx/src/test/spirv_run/div_approx.spvtxt
Normal file
|
@ -0,0 +1,65 @@
|
|||
; SPIR-V
|
||||
; Version: 1.3
|
||||
; Generator: rspirv
|
||||
; Bound: 38
|
||||
OpCapability GenericPointer
|
||||
OpCapability Linkage
|
||||
OpCapability Addresses
|
||||
OpCapability Kernel
|
||||
OpCapability Int8
|
||||
OpCapability Int16
|
||||
OpCapability Int64
|
||||
OpCapability Float16
|
||||
OpCapability Float64
|
||||
; OpCapability FunctionFloatControlINTEL
|
||||
; OpExtension "SPV_INTEL_float_controls2"
|
||||
%30 = OpExtInstImport "OpenCL.std"
|
||||
OpMemoryModel Physical64 OpenCL
|
||||
OpEntryPoint Kernel %1 "div_approx"
|
||||
OpDecorate %1 FunctionDenormModeINTEL 32 Preserve
|
||||
OpDecorate %18 FPFastMathMode AllowRecip
|
||||
%31 = OpTypeVoid
|
||||
%32 = OpTypeInt 64 0
|
||||
%33 = OpTypeFunction %31 %32 %32
|
||||
%34 = OpTypePointer Function %32
|
||||
%35 = OpTypeFloat 32
|
||||
%36 = OpTypePointer Function %35
|
||||
%37 = OpTypePointer Generic %35
|
||||
%23 = OpConstant %32 4
|
||||
%1 = OpFunction %31 None %33
|
||||
%8 = OpFunctionParameter %32
|
||||
%9 = OpFunctionParameter %32
|
||||
%28 = OpLabel
|
||||
%2 = OpVariable %34 Function
|
||||
%3 = OpVariable %34 Function
|
||||
%4 = OpVariable %34 Function
|
||||
%5 = OpVariable %34 Function
|
||||
%6 = OpVariable %36 Function
|
||||
%7 = OpVariable %36 Function
|
||||
OpStore %2 %8
|
||||
OpStore %3 %9
|
||||
%11 = OpLoad %32 %2
|
||||
%10 = OpCopyObject %32 %11
|
||||
OpStore %4 %10
|
||||
%13 = OpLoad %32 %3
|
||||
%12 = OpCopyObject %32 %13
|
||||
OpStore %5 %12
|
||||
%15 = OpLoad %32 %4
|
||||
%25 = OpConvertUToPtr %37 %15
|
||||
%14 = OpLoad %35 %25
|
||||
OpStore %6 %14
|
||||
%17 = OpLoad %32 %4
|
||||
%24 = OpIAdd %32 %17 %23
|
||||
%26 = OpConvertUToPtr %37 %24
|
||||
%16 = OpLoad %35 %26
|
||||
OpStore %7 %16
|
||||
%19 = OpLoad %35 %6
|
||||
%20 = OpLoad %35 %7
|
||||
%18 = OpFDiv %35 %19 %20
|
||||
OpStore %6 %18
|
||||
%21 = OpLoad %32 %5
|
||||
%22 = OpLoad %35 %6
|
||||
%27 = OpConvertUToPtr %37 %21
|
||||
OpStore %27 %22
|
||||
OpReturn
|
||||
OpFunctionEnd
|
|
@ -97,9 +97,13 @@ test_ptx!(and, [6u32, 3u32], [2u32]);
|
|||
test_ptx!(selp, [100u16, 200u16], [200u16]);
|
||||
test_ptx!(fma, [2f32, 3f32, 5f32], [11f32]);
|
||||
test_ptx!(shared_variable, [513u64], [513u64]);
|
||||
test_ptx!(shared_ptr_32, [513u64], [513u64]);
|
||||
test_ptx!(atom_cas, [91u32, 91u32], [91u32, 100u32]);
|
||||
test_ptx!(atom_inc, [100u32], [100u32, 101u32, 0u32]);
|
||||
test_ptx!(atom_add, [2u32, 4u32], [2u32, 6u32]);
|
||||
test_ptx!(div_approx, [1f32, 2f32], [0.5f32]);
|
||||
test_ptx!(sqrt, [0.25f32], [0.5f32]);
|
||||
test_ptx!(rsqrt, [0.25f64], [2f64]);
|
||||
|
||||
struct DisplayError<T: Debug> {
|
||||
err: T,
|
||||
|
|
21
ptx/src/test/spirv_run/rsqrt.ptx
Normal file
21
ptx/src/test/spirv_run/rsqrt.ptx
Normal file
|
@ -0,0 +1,21 @@
|
|||
// Test kernel for rsqrt.approx.f64: reads one f64 from `input`, applies
// rsqrt.approx.f64 to it and stores the result to `output`.
.version 6.5
|
||||
.target sm_30
|
||||
.address_size 64
|
||||
|
||||
.visible .entry rsqrt(
|
||||
.param .u64 input,
|
||||
.param .u64 output
|
||||
)
|
||||
{
|
||||
.reg .u64 in_addr;
|
||||
.reg .u64 out_addr;
|
||||
.reg .f64 temp1;
|
||||
|
||||
// Fetch the two pointer parameters.
ld.param.u64 in_addr, [input];
|
||||
ld.param.u64 out_addr, [output];
|
||||
|
||||
ld.f64 temp1, [in_addr];
|
||||
// Instruction under test; the result overwrites temp1.
rsqrt.approx.f64 temp1, temp1;
|
||||
st.f64 [out_addr], temp1;
|
||||
ret;
|
||||
}
|
56
ptx/src/test/spirv_run/rsqrt.spvtxt
Normal file
56
ptx/src/test/spirv_run/rsqrt.spvtxt
Normal file
|
@ -0,0 +1,56 @@
|
|||
; SPIR-V
|
||||
; Version: 1.3
|
||||
; Generator: rspirv
|
||||
; Bound: 31
|
||||
OpCapability GenericPointer
|
||||
OpCapability Linkage
|
||||
OpCapability Addresses
|
||||
OpCapability Kernel
|
||||
OpCapability Int8
|
||||
OpCapability Int16
|
||||
OpCapability Int64
|
||||
OpCapability Float16
|
||||
OpCapability Float64
|
||||
; OpCapability FunctionFloatControlINTEL
|
||||
; OpExtension "SPV_INTEL_float_controls2"
|
||||
%23 = OpExtInstImport "OpenCL.std"
|
||||
OpMemoryModel Physical64 OpenCL
|
||||
OpEntryPoint Kernel %1 "rsqrt"
|
||||
OpDecorate %1 FunctionDenormModeINTEL 64 Preserve
|
||||
%24 = OpTypeVoid
|
||||
%25 = OpTypeInt 64 0
|
||||
%26 = OpTypeFunction %24 %25 %25
|
||||
%27 = OpTypePointer Function %25
|
||||
%28 = OpTypeFloat 64
|
||||
%29 = OpTypePointer Function %28
|
||||
%30 = OpTypePointer Generic %28
|
||||
%1 = OpFunction %24 None %26
|
||||
%7 = OpFunctionParameter %25
|
||||
%8 = OpFunctionParameter %25
|
||||
%21 = OpLabel
|
||||
%2 = OpVariable %27 Function
|
||||
%3 = OpVariable %27 Function
|
||||
%4 = OpVariable %27 Function
|
||||
%5 = OpVariable %27 Function
|
||||
%6 = OpVariable %29 Function
|
||||
OpStore %2 %7
|
||||
OpStore %3 %8
|
||||
%10 = OpLoad %25 %2
|
||||
%9 = OpCopyObject %25 %10
|
||||
OpStore %4 %9
|
||||
%12 = OpLoad %25 %3
|
||||
%11 = OpCopyObject %25 %12
|
||||
OpStore %5 %11
|
||||
%14 = OpLoad %25 %4
|
||||
%19 = OpConvertUToPtr %30 %14
|
||||
%13 = OpLoad %28 %19
|
||||
OpStore %6 %13
|
||||
%16 = OpLoad %28 %6
|
||||
%15 = OpExtInst %28 %23 native_rsqrt %16
|
||||
OpStore %6 %15
|
||||
%17 = OpLoad %25 %5
|
||||
%18 = OpLoad %28 %6
|
||||
%20 = OpConvertUToPtr %30 %17
|
||||
OpStore %20 %18
|
||||
OpReturn
|
||||
OpFunctionEnd
|
29
ptx/src/test/spirv_run/shared_ptr_32.ptx
Normal file
29
ptx/src/test/spirv_run/shared_ptr_32.ptx
Normal file
|
@ -0,0 +1,29 @@
|
|||
// Test kernel: copies one u64 from `input` to `output` by way of a .shared
// buffer whose address is held in a 32-bit register.
.version 6.5
|
||||
.target sm_30
|
||||
.address_size 64
|
||||
|
||||
|
||||
.visible .entry shared_ptr_32(
|
||||
.param .u64 input,
|
||||
.param .u64 output
|
||||
)
|
||||
{
|
||||
.shared .align 4 .b8 shared_mem1[128];
|
||||
|
||||
.reg .u64 in_addr;
|
||||
.reg .u64 out_addr;
|
||||
// 32-bit register that receives the address of shared_mem1.
.reg .u32 shared_addr;
|
||||
|
||||
.reg .u64 temp1;
|
||||
.reg .u64 temp2;
|
||||
|
||||
ld.param.u64 in_addr, [input];
|
||||
ld.param.u64 out_addr, [output];
|
||||
mov.u32 shared_addr, shared_mem1;
|
||||
|
||||
// global -> shared -> global round trip; the second shared access uses an
// explicit +0 offset.
ld.global.u64 temp1, [in_addr];
|
||||
st.shared.u64 [shared_addr], temp1;
|
||||
ld.shared.u64 temp2, [shared_addr+0];
|
||||
st.global.u64 [out_addr], temp2;
|
||||
ret;
|
||||
}
|
74
ptx/src/test/spirv_run/shared_ptr_32.spvtxt
Normal file
74
ptx/src/test/spirv_run/shared_ptr_32.spvtxt
Normal file
|
@ -0,0 +1,74 @@
|
|||
; SPIR-V
|
||||
; Version: 1.3
|
||||
; Generator: rspirv
|
||||
; Bound: 47
|
||||
OpCapability GenericPointer
|
||||
OpCapability Linkage
|
||||
OpCapability Addresses
|
||||
OpCapability Kernel
|
||||
OpCapability Int8
|
||||
OpCapability Int16
|
||||
OpCapability Int64
|
||||
OpCapability Float16
|
||||
OpCapability Float64
|
||||
; OpCapability FunctionFloatControlINTEL
|
||||
; OpExtension "SPV_INTEL_float_controls2"
|
||||
%34 = OpExtInstImport "OpenCL.std"
|
||||
OpMemoryModel Physical64 OpenCL
|
||||
OpEntryPoint Kernel %1 "shared_ptr_32" %4
|
||||
OpDecorate %4 Alignment 4
|
||||
%35 = OpTypeVoid
|
||||
%36 = OpTypeInt 32 0
|
||||
%37 = OpTypeInt 8 0
|
||||
%38 = OpConstant %36 128
|
||||
%39 = OpTypeArray %37 %38
|
||||
%40 = OpTypePointer Workgroup %39
|
||||
%4 = OpVariable %40 Workgroup
|
||||
%41 = OpTypeInt 64 0
|
||||
%42 = OpTypeFunction %35 %41 %41
|
||||
%43 = OpTypePointer Function %41
|
||||
%44 = OpTypePointer Function %36
|
||||
%45 = OpTypePointer CrossWorkgroup %41
|
||||
%46 = OpTypePointer Workgroup %41
|
||||
%25 = OpConstant %36 0
|
||||
%1 = OpFunction %35 None %42
|
||||
%10 = OpFunctionParameter %41
|
||||
%11 = OpFunctionParameter %41
|
||||
%32 = OpLabel
|
||||
%2 = OpVariable %43 Function
|
||||
%3 = OpVariable %43 Function
|
||||
%5 = OpVariable %43 Function
|
||||
%6 = OpVariable %43 Function
|
||||
%7 = OpVariable %44 Function
|
||||
%8 = OpVariable %43 Function
|
||||
%9 = OpVariable %43 Function
|
||||
OpStore %2 %10
|
||||
OpStore %3 %11
|
||||
%13 = OpLoad %41 %2
|
||||
%12 = OpCopyObject %41 %13
|
||||
OpStore %5 %12
|
||||
%15 = OpLoad %41 %3
|
||||
%14 = OpCopyObject %41 %15
|
||||
OpStore %6 %14
|
||||
%27 = OpConvertPtrToU %36 %4
|
||||
%16 = OpCopyObject %36 %27
|
||||
OpStore %7 %16
|
||||
%18 = OpLoad %41 %5
|
||||
%28 = OpConvertUToPtr %45 %18
|
||||
%17 = OpLoad %41 %28
|
||||
OpStore %8 %17
|
||||
%19 = OpLoad %36 %7
|
||||
%20 = OpLoad %41 %8
|
||||
%29 = OpConvertUToPtr %46 %19
|
||||
OpStore %29 %20
|
||||
%22 = OpLoad %36 %7
|
||||
%26 = OpIAdd %36 %22 %25
|
||||
%30 = OpConvertUToPtr %46 %26
|
||||
%21 = OpLoad %41 %30
|
||||
OpStore %9 %21
|
||||
%23 = OpLoad %41 %6
|
||||
%24 = OpLoad %41 %9
|
||||
%31 = OpConvertUToPtr %45 %23
|
||||
OpStore %31 %24
|
||||
OpReturn
|
||||
OpFunctionEnd
|
21
ptx/src/test/spirv_run/sqrt.ptx
Normal file
21
ptx/src/test/spirv_run/sqrt.ptx
Normal file
|
@ -0,0 +1,21 @@
|
|||
// Test kernel for sqrt.approx.f32: reads one f32 from `input`, applies
// sqrt.approx.f32 to it and stores the result to `output`.
.version 6.5
|
||||
.target sm_30
|
||||
.address_size 64
|
||||
|
||||
.visible .entry sqrt(
|
||||
.param .u64 input,
|
||||
.param .u64 output
|
||||
)
|
||||
{
|
||||
.reg .u64 in_addr;
|
||||
.reg .u64 out_addr;
|
||||
.reg .f32 temp1;
|
||||
|
||||
// Fetch the two pointer parameters.
ld.param.u64 in_addr, [input];
|
||||
ld.param.u64 out_addr, [output];
|
||||
|
||||
ld.f32 temp1, [in_addr];
|
||||
// Instruction under test; the result overwrites temp1.
sqrt.approx.f32 temp1, temp1;
|
||||
st.f32 [out_addr], temp1;
|
||||
ret;
|
||||
}
|
56
ptx/src/test/spirv_run/sqrt.spvtxt
Normal file
56
ptx/src/test/spirv_run/sqrt.spvtxt
Normal file
|
@ -0,0 +1,56 @@
|
|||
; SPIR-V
|
||||
; Version: 1.3
|
||||
; Generator: rspirv
|
||||
; Bound: 31
|
||||
OpCapability GenericPointer
|
||||
OpCapability Linkage
|
||||
OpCapability Addresses
|
||||
OpCapability Kernel
|
||||
OpCapability Int8
|
||||
OpCapability Int16
|
||||
OpCapability Int64
|
||||
OpCapability Float16
|
||||
OpCapability Float64
|
||||
; OpCapability FunctionFloatControlINTEL
|
||||
; OpExtension "SPV_INTEL_float_controls2"
|
||||
%23 = OpExtInstImport "OpenCL.std"
|
||||
OpMemoryModel Physical64 OpenCL
|
||||
OpEntryPoint Kernel %1 "sqrt"
|
||||
OpDecorate %1 FunctionDenormModeINTEL 32 Preserve
|
||||
%24 = OpTypeVoid
|
||||
%25 = OpTypeInt 64 0
|
||||
%26 = OpTypeFunction %24 %25 %25
|
||||
%27 = OpTypePointer Function %25
|
||||
%28 = OpTypeFloat 32
|
||||
%29 = OpTypePointer Function %28
|
||||
%30 = OpTypePointer Generic %28
|
||||
%1 = OpFunction %24 None %26
|
||||
%7 = OpFunctionParameter %25
|
||||
%8 = OpFunctionParameter %25
|
||||
%21 = OpLabel
|
||||
%2 = OpVariable %27 Function
|
||||
%3 = OpVariable %27 Function
|
||||
%4 = OpVariable %27 Function
|
||||
%5 = OpVariable %27 Function
|
||||
%6 = OpVariable %29 Function
|
||||
OpStore %2 %7
|
||||
OpStore %3 %8
|
||||
%10 = OpLoad %25 %2
|
||||
%9 = OpCopyObject %25 %10
|
||||
OpStore %4 %9
|
||||
%12 = OpLoad %25 %3
|
||||
%11 = OpCopyObject %25 %12
|
||||
OpStore %5 %11
|
||||
%14 = OpLoad %25 %4
|
||||
%19 = OpConvertUToPtr %30 %14
|
||||
%13 = OpLoad %28 %19
|
||||
OpStore %6 %13
|
||||
%16 = OpLoad %28 %6
|
||||
%15 = OpExtInst %28 %23 native_sqrt %16
|
||||
OpStore %6 %15
|
||||
%17 = OpLoad %25 %5
|
||||
%18 = OpLoad %28 %6
|
||||
%20 = OpConvertUToPtr %30 %17
|
||||
OpStore %20 %18
|
||||
OpReturn
|
||||
OpFunctionEnd
|
|
@ -1,8 +1,11 @@
|
|||
use crate::ast;
|
||||
use half::f16;
|
||||
use rspirv::{binary::Disassemble, dr};
|
||||
use std::collections::{hash_map, HashMap, HashSet};
|
||||
use std::{borrow::Cow, hash::Hash, iter, mem};
|
||||
use std::{
|
||||
collections::{hash_map, HashMap, HashSet},
|
||||
convert::TryInto,
|
||||
};
|
||||
|
||||
use rspirv::binary::Assemble;
|
||||
|
||||
|
@ -1499,6 +1502,15 @@ fn convert_to_typed_statements(
|
|||
ast::Instruction::AtomCas(d, a) => result.push(Statement::Instruction(
|
||||
ast::Instruction::AtomCas(d, a.cast()),
|
||||
)),
|
||||
ast::Instruction::Div(d, a) => {
|
||||
result.push(Statement::Instruction(ast::Instruction::Div(d, a.cast())))
|
||||
}
|
||||
ast::Instruction::Sqrt(d, a) => {
|
||||
result.push(Statement::Instruction(ast::Instruction::Sqrt(d, a.cast())))
|
||||
}
|
||||
ast::Instruction::Rsqrt(d, a) => {
|
||||
result.push(Statement::Instruction(ast::Instruction::Rsqrt(d, a.cast())))
|
||||
}
|
||||
},
|
||||
Statement::Label(i) => result.push(Statement::Label(i)),
|
||||
Statement::Variable(v) => result.push(Statement::Variable(v)),
|
||||
|
@ -1982,7 +1994,7 @@ impl<'a, 'b> FlattenArguments<'a, 'b> {
|
|||
| ArgumentSemantics::DefaultRelaxed
|
||||
| ArgumentSemantics::PhysicalPointer => {
|
||||
if desc.sema == ArgumentSemantics::PhysicalPointer {
|
||||
typ = ast::Type::Scalar(ast::ScalarType::U64);
|
||||
typ = self.id_def.get_typed(reg)?;
|
||||
}
|
||||
let (width, kind) = match typ {
|
||||
ast::Type::Scalar(scalar_t) => {
|
||||
|
@ -2013,7 +2025,7 @@ impl<'a, 'b> FlattenArguments<'a, 'b> {
|
|||
self.func.push(Statement::Constant(ConstantDefinition {
|
||||
dst: id_constant_stmt,
|
||||
typ: ast::ScalarType::from_parts(width, kind),
|
||||
value: ast::ImmediateValue::S64(-(offset as i64)),
|
||||
value: ast::ImmediateValue::U64(-(offset as i64) as u64),
|
||||
}));
|
||||
self.func.push(Statement::Instruction(
|
||||
ast::Instruction::<ExpandedArgParams>::Sub(
|
||||
|
@ -2765,6 +2777,34 @@ fn emit_function_body_ops(
|
|||
arg.src2,
|
||||
)?;
|
||||
}
|
||||
ast::Instruction::Div(details, arg) => match details {
|
||||
ast::DivDetails::Unsigned(t) => {
|
||||
let result_type = map.get_or_add_scalar(builder, (*t).into());
|
||||
builder.u_div(result_type, Some(arg.dst), arg.src1, arg.src2)?;
|
||||
}
|
||||
ast::DivDetails::Signed(t) => {
|
||||
let result_type = map.get_or_add_scalar(builder, (*t).into());
|
||||
builder.s_div(result_type, Some(arg.dst), arg.src1, arg.src2)?;
|
||||
}
|
||||
ast::DivDetails::Float(t) => {
|
||||
let result_type = map.get_or_add_scalar(builder, t.typ.into());
|
||||
builder.f_div(result_type, Some(arg.dst), arg.src1, arg.src2)?;
|
||||
emit_float_div_decoration(builder, arg.dst, t.kind);
|
||||
}
|
||||
},
|
||||
ast::Instruction::Sqrt(details, a) => {
|
||||
emit_sqrt(builder, map, opencl, details, a)?;
|
||||
}
|
||||
ast::Instruction::Rsqrt(details, a) => {
|
||||
let result_type = map.get_or_add_scalar(builder, details.typ.into());
|
||||
builder.ext_inst(
|
||||
result_type,
|
||||
Some(a.dst),
|
||||
opencl,
|
||||
spirv::CLOp::native_rsqrt as spirv::Word,
|
||||
&[a.src],
|
||||
)?;
|
||||
}
|
||||
},
|
||||
Statement::LoadVar(arg, typ) => {
|
||||
let type_id = map.get_or_add(builder, SpirvType::from(typ.clone()));
|
||||
|
@ -2795,6 +2835,47 @@ fn emit_function_body_ops(
|
|||
Ok(())
|
||||
}
|
||||
|
||||
fn emit_sqrt(
|
||||
builder: &mut dr::Builder,
|
||||
map: &mut TypeWordMap,
|
||||
opencl: spirv::Word,
|
||||
details: &ast::SqrtDetails,
|
||||
a: &ast::Arg2<ExpandedArgParams>,
|
||||
) -> Result<(), TranslateError> {
|
||||
let result_type = map.get_or_add_scalar(builder, details.typ.into());
|
||||
let (ocl_op, rounding) = match details.kind {
|
||||
ast::SqrtKind::Approx => (spirv::CLOp::native_sqrt, None),
|
||||
ast::SqrtKind::Rounding(rnd) => (spirv::CLOp::sqrt, Some(rnd)),
|
||||
};
|
||||
builder.ext_inst(
|
||||
result_type,
|
||||
Some(a.dst),
|
||||
opencl,
|
||||
ocl_op as spirv::Word,
|
||||
&[a.src],
|
||||
)?;
|
||||
emit_rounding_decoration(builder, a.dst, rounding);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn emit_float_div_decoration(builder: &mut dr::Builder, dst: spirv::Word, kind: ast::DivFloatKind) {
|
||||
match kind {
|
||||
ast::DivFloatKind::Approx => {
|
||||
builder.decorate(
|
||||
dst,
|
||||
spirv::Decoration::FPFastMathMode,
|
||||
&[dr::Operand::FPFastMathMode(
|
||||
spirv::FPFastMathMode::ALLOW_RECIP,
|
||||
)],
|
||||
);
|
||||
}
|
||||
ast::DivFloatKind::Rounding(rnd) => {
|
||||
emit_rounding_decoration(builder, dst, Some(rnd));
|
||||
}
|
||||
ast::DivFloatKind::Full => {}
|
||||
}
|
||||
}
|
||||
|
||||
fn emit_atom(
|
||||
builder: &mut dr::Builder,
|
||||
map: &mut TypeWordMap,
|
||||
|
@ -3307,7 +3388,25 @@ fn emit_setp(
|
|||
(ast::SetpCompareOp::GreaterOrEq, ScalarKind::Float) => {
|
||||
builder.f_ord_greater_than_equal(result_type, result_id, operand_1, operand_2)
|
||||
}
|
||||
_ => todo!(),
|
||||
(ast::SetpCompareOp::NanEq, _) => {
|
||||
builder.f_unord_equal(result_type, result_id, operand_1, operand_2)
|
||||
}
|
||||
(ast::SetpCompareOp::NanNotEq, _) => {
|
||||
builder.f_unord_not_equal(result_type, result_id, operand_1, operand_2)
|
||||
}
|
||||
(ast::SetpCompareOp::NanLess, _) => {
|
||||
builder.f_unord_less_than(result_type, result_id, operand_1, operand_2)
|
||||
}
|
||||
(ast::SetpCompareOp::NanLessOrEq, _) => {
|
||||
builder.f_unord_less_than_equal(result_type, result_id, operand_1, operand_2)
|
||||
}
|
||||
(ast::SetpCompareOp::NanGreater, _) => {
|
||||
builder.f_unord_greater_than(result_type, result_id, operand_1, operand_2)
|
||||
}
|
||||
(ast::SetpCompareOp::NanGreaterOrEq, _) => {
|
||||
builder.f_unord_greater_than_equal(result_type, result_id, operand_1, operand_2)
|
||||
}
|
||||
_ => todo!()
|
||||
}?;
|
||||
Ok(())
|
||||
}
|
||||
|
@ -3486,8 +3585,8 @@ fn emit_implicit_conversion(
|
|||
let from_parts = cv.from.to_parts();
|
||||
let to_parts = cv.to.to_parts();
|
||||
match (from_parts.kind, to_parts.kind, cv.kind) {
|
||||
(_, _, ConversionKind::PtrToBit) => {
|
||||
let dst_type = map.get_or_add_scalar(builder, ast::ScalarType::B64);
|
||||
(_, _, ConversionKind::PtrToBit(typ)) => {
|
||||
let dst_type = map.get_or_add_scalar(builder, typ.into());
|
||||
builder.convert_ptr_to_u(dst_type, Some(cv.dst), cv.src)?;
|
||||
}
|
||||
(_, _, ConversionKind::BitToPtr(_)) => {
|
||||
|
@ -4570,6 +4669,15 @@ impl<T: ArgParamsEx> ast::Instruction<T> {
|
|||
ast::Instruction::AtomCas(d, a) => {
|
||||
ast::Instruction::AtomCas(d, a.map_atom(visitor, d.typ, d.space)?)
|
||||
}
|
||||
ast::Instruction::Div(d, a) => {
|
||||
ast::Instruction::Div(d, a.map_non_shift(visitor, &d.get_type(), false)?)
|
||||
}
|
||||
ast::Instruction::Sqrt(d, a) => {
|
||||
ast::Instruction::Sqrt(d, a.map(visitor, &ast::Type::Scalar(d.typ.into()))?)
|
||||
}
|
||||
ast::Instruction::Rsqrt(d, a) => {
|
||||
ast::Instruction::Rsqrt(d, a.map(visitor, &ast::Type::Scalar(d.typ.into()))?)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -4794,32 +4902,7 @@ impl ast::Instruction<ExpandedArgParams> {
|
|||
fn jump_target(&self) -> Option<spirv::Word> {
|
||||
match self {
|
||||
ast::Instruction::Bra(_, a) => Some(a.src),
|
||||
ast::Instruction::Ld(_, _)
|
||||
| ast::Instruction::Mov(_, _)
|
||||
| ast::Instruction::Mul(_, _)
|
||||
| ast::Instruction::Add(_, _)
|
||||
| ast::Instruction::Setp(_, _)
|
||||
| ast::Instruction::SetpBool(_, _)
|
||||
| ast::Instruction::Not(_, _)
|
||||
| ast::Instruction::Cvt(_, _)
|
||||
| ast::Instruction::Cvta(_, _)
|
||||
| ast::Instruction::Shl(_, _)
|
||||
| ast::Instruction::Shr(_, _)
|
||||
| ast::Instruction::St(_, _)
|
||||
| ast::Instruction::Ret(_)
|
||||
| ast::Instruction::Abs(_, _)
|
||||
| ast::Instruction::Call(_)
|
||||
| ast::Instruction::Or(_, _)
|
||||
| ast::Instruction::Sub(_, _)
|
||||
| ast::Instruction::Min(_, _)
|
||||
| ast::Instruction::Max(_, _)
|
||||
| ast::Instruction::Rcp(_, _)
|
||||
| ast::Instruction::And(_, _)
|
||||
| ast::Instruction::Selp(_, _)
|
||||
| ast::Instruction::Bar(_, _)
|
||||
| ast::Instruction::Atom(_, _)
|
||||
| ast::Instruction::AtomCas(_, _)
|
||||
| ast::Instruction::Mad(_, _) => None,
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -4856,6 +4939,9 @@ impl ast::Instruction<ExpandedArgParams> {
|
|||
ast::Instruction::Max(ast::MinMaxDetails::Signed(_), _) => None,
|
||||
ast::Instruction::Max(ast::MinMaxDetails::Unsigned(_), _) => None,
|
||||
ast::Instruction::Cvt(ast::CvtDetails::IntFromInt(_), _) => None,
|
||||
ast::Instruction::Cvt(ast::CvtDetails::FloatFromInt(_), _) => None,
|
||||
ast::Instruction::Div(ast::DivDetails::Unsigned(_), _) => None,
|
||||
ast::Instruction::Div(ast::DivDetails::Signed(_), _) => None,
|
||||
ast::Instruction::Sub(ast::ArithDetails::Float(float_control), _)
|
||||
| ast::Instruction::Add(ast::ArithDetails::Float(float_control), _)
|
||||
| ast::Instruction::Mul(ast::MulDetails::Float(float_control), _)
|
||||
|
@ -4884,14 +4970,20 @@ impl ast::Instruction<ExpandedArgParams> {
|
|||
ast::CvtDetails::FloatFromFloat(ast::CvtDesc { flush_to_zero, .. }),
|
||||
_,
|
||||
)
|
||||
| ast::Instruction::Cvt(
|
||||
ast::CvtDetails::FloatFromInt(ast::CvtDesc { flush_to_zero, .. }),
|
||||
_,
|
||||
)
|
||||
| ast::Instruction::Cvt(
|
||||
ast::CvtDetails::IntFromFloat(ast::CvtDesc { flush_to_zero, .. }),
|
||||
_,
|
||||
) => flush_to_zero.map(|ftz| (ftz, 4)),
|
||||
ast::Instruction::Div(ast::DivDetails::Float(details), _) => details
|
||||
.flush_to_zero
|
||||
.map(|ftz| (ftz, ast::ScalarType::from(details.typ).size_of())),
|
||||
ast::Instruction::Sqrt(details, _) => details
|
||||
.flush_to_zero
|
||||
.map(|ftz| (ftz, ast::ScalarType::from(details.typ).size_of())),
|
||||
ast::Instruction::Rsqrt(details, _) => Some((
|
||||
details.flush_to_zero,
|
||||
ast::ScalarType::from(details.typ).size_of(),
|
||||
)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -4978,13 +5070,13 @@ struct ImplicitConversion {
|
|||
kind: ConversionKind,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Copy, Clone)]
|
||||
#[derive(PartialEq, Copy, Clone)]
|
||||
enum ConversionKind {
|
||||
Default,
|
||||
// zero-extend/chop/bitcast depending on types
|
||||
SignExtend,
|
||||
BitToPtr(ast::LdStateSpace),
|
||||
PtrToBit,
|
||||
PtrToBit(ast::UIntType),
|
||||
PtrToPtr { spirv_ptr: bool },
|
||||
}
|
||||
|
||||
|
@ -6027,6 +6119,16 @@ impl ast::MinMaxDetails {
|
|||
}
|
||||
}
|
||||
|
||||
impl ast::DivDetails {
    /// Returns the scalar type the division operates on, wrapped in
    /// `ast::Type::Scalar`.
    fn get_type(&self) -> ast::Type {
        let scalar: ast::ScalarType = match *self {
            ast::DivDetails::Unsigned(uint) => uint.into(),
            ast::DivDetails::Signed(sint) => sint.into(),
            ast::DivDetails::Float(float) => float.typ.into(),
        };
        ast::Type::Scalar(scalar)
    }
}
|
||||
|
||||
impl ast::AtomInnerDetails {
|
||||
fn get_type(&self) -> ast::ScalarType {
|
||||
match self {
|
||||
|
@ -6193,6 +6295,15 @@ fn bitcast_physical_pointer(
|
|||
Err(TranslateError::Unreachable)
|
||||
}
|
||||
}
|
||||
ast::Type::Scalar(ast::ScalarType::B32)
|
||||
| ast::Type::Scalar(ast::ScalarType::U32)
|
||||
| ast::Type::Scalar(ast::ScalarType::S32) => {
|
||||
if let Some(ast::LdStateSpace::Shared) = ss {
|
||||
Ok(Some(ConversionKind::BitToPtr(ast::LdStateSpace::Shared)))
|
||||
} else {
|
||||
Err(TranslateError::MismatchedType)
|
||||
}
|
||||
}
|
||||
ast::Type::Pointer(op_scalar_t, op_space) => {
|
||||
if let ast::Type::Pointer(instr_scalar_t, instr_space) = instr_type {
|
||||
if op_space == instr_space {
|
||||
|
@ -6220,10 +6331,16 @@ fn bitcast_physical_pointer(
|
|||
|
||||
fn force_bitcast_ptr_to_bit(
|
||||
_: &ast::Type,
|
||||
_: &ast::Type,
|
||||
instr_type: &ast::Type,
|
||||
_: Option<ast::LdStateSpace>,
|
||||
) -> Result<Option<ConversionKind>, TranslateError> {
|
||||
Ok(Some(ConversionKind::PtrToBit))
|
||||
// TODO: verify this on f32, u16 and the like
|
||||
if let ast::Type::Scalar(scalar_t) = instr_type {
|
||||
if let Ok(int_type) = (*scalar_t).try_into() {
|
||||
return Ok(Some(ConversionKind::PtrToBit(int_type)));
|
||||
}
|
||||
}
|
||||
Err(TranslateError::MismatchedType)
|
||||
}
|
||||
|
||||
fn should_bitcast(instr: &ast::Type, operand: &ast::Type) -> bool {
|
||||
|
@ -6542,9 +6659,9 @@ mod tests {
|
|||
&ast::Type::Scalar(*instr_type),
|
||||
);
|
||||
if instr_idx == op_idx {
|
||||
assert_eq!(conversion, None);
|
||||
assert!(conversion == None);
|
||||
} else {
|
||||
assert_eq!(conversion, conv_table[instr_idx][op_idx]);
|
||||
assert!(conversion == conv_table[instr_idx][op_idx]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue