Fix some unhandled cases in cvt instruction

This commit is contained in:
Andrzej Janik 2021-09-14 23:38:06 +00:00
parent 2cd0fcb650
commit 467782b1d0
9 changed files with 318 additions and 60 deletions

View file

@ -14,7 +14,7 @@ extern {
match {
r"\s+" => { },
r"//[^\n\r]*[\n\r]*" => { },
r"/\*([^\*]*\*+[^\*/])*([^\*]*\*+|[^\*])*\*/" => { },
r"/\*[^*]*\*+(?:[^/*][^*]*\*+)*/" => { },
r"0[fF][0-9a-zA-Z]{8}" => F32NumToken,
r"0[dD][0-9a-zA-Z]{16}" => F64NumToken,
r"0[xX][0-9a-zA-Z]+U?" => HexNumToken,
@ -1143,11 +1143,11 @@ InstCvt: ast::Instruction<ast::ParsedArgParams<'input>> = {
}
), a)
},
"cvt" <s:".sat"?> ".f64" ".f32" <a:Arg2> => {
"cvt" <s:".sat"?> <f:".ftz"?> ".f64" ".f32" <a:Arg2> => {
ast::Instruction::Cvt(ast::CvtDetails::FloatFromFloat(
ast::CvtDesc {
rounding: None,
flush_to_zero: None,
flush_to_zero: Some(f.is_some()),
saturate: s.is_some(),
dst: ast::ScalarType::F64,
src: ast::ScalarType::F32

View file

@ -19,10 +19,10 @@
ld.param.u64 in_addr, [input];
ld.param.u64 out_addr, [output];
ld.const.b16 temp1, constparams[0];
ld.const.b16 temp2, constparams[1];
ld.const.b16 temp3, constparams[2];
ld.const.b16 temp4, constparams[3];
ld.const.b16 temp1, [constparams];
ld.const.b16 temp2, [constparams+2];
ld.const.b16 temp3, [constparams+4];
ld.const.b16 temp4, [constparams+6];
st.u16 [out_addr], temp1;
st.u16 [out_addr+2], temp2;
st.u16 [out_addr+4], temp3;

View file

@ -7,41 +7,106 @@
OpCapability Int64
OpCapability Float16
OpCapability Float64
%21 = OpExtInstImport "OpenCL.std"
%53 = OpExtInstImport "OpenCL.std"
OpMemoryModel Physical64 OpenCL
OpEntryPoint Kernel %1 "clz"
OpEntryPoint Kernel %2 "const" %1
OpExecutionMode %2 ContractionOff
OpDecorate %1 Alignment 8
%void = OpTypeVoid
%ulong = OpTypeInt 64 0
%24 = OpTypeFunction %void %ulong %ulong
%_ptr_Function_ulong = OpTypePointer Function %ulong
%uint = OpTypeInt 32 0
%_ptr_Function_uint = OpTypePointer Function %uint
%_ptr_Generic_uint = OpTypePointer Generic %uint
%1 = OpFunction %void None %24
%7 = OpFunctionParameter %ulong
%8 = OpFunctionParameter %ulong
%19 = OpLabel
%2 = OpVariable %_ptr_Function_ulong Function
%ushort = OpTypeInt 16 0
%uint_4 = OpConstant %uint 4
%_arr_ushort_uint_4 = OpTypeArray %ushort %uint_4
%ushort_10 = OpConstant %ushort 10
%ushort_20 = OpConstant %ushort 20
%ushort_30 = OpConstant %ushort 30
%ushort_40 = OpConstant %ushort 40
%63 = OpConstantComposite %_arr_ushort_uint_4 %ushort_10 %ushort_20 %ushort_30 %ushort_40
%uint_4_0 = OpConstant %uint 4
%_ptr_UniformConstant__arr_ushort_uint_4 = OpTypePointer UniformConstant %_arr_ushort_uint_4
%1 = OpVariable %_ptr_UniformConstant__arr_ushort_uint_4 UniformConstant %63
%ulong = OpTypeInt 64 0
%67 = OpTypeFunction %void %ulong %ulong
%_ptr_Function_ulong = OpTypePointer Function %ulong
%_ptr_Function_ushort = OpTypePointer Function %ushort
%_ptr_UniformConstant_ushort = OpTypePointer UniformConstant %ushort
%ulong_2 = OpConstant %ulong 2
%uchar = OpTypeInt 8 0
%_ptr_UniformConstant_uchar = OpTypePointer UniformConstant %uchar
%ulong_4 = OpConstant %ulong 4
%ulong_6 = OpConstant %ulong 6
%_ptr_Generic_ushort = OpTypePointer Generic %ushort
%ulong_2_0 = OpConstant %ulong 2
%_ptr_Generic_uchar = OpTypePointer Generic %uchar
%ulong_4_0 = OpConstant %ulong 4
%ulong_6_0 = OpConstant %ulong 6
%2 = OpFunction %void None %67
%11 = OpFunctionParameter %ulong
%12 = OpFunctionParameter %ulong
%51 = OpLabel
%3 = OpVariable %_ptr_Function_ulong Function
%4 = OpVariable %_ptr_Function_ulong Function
%5 = OpVariable %_ptr_Function_ulong Function
%6 = OpVariable %_ptr_Function_uint Function
OpStore %2 %7
OpStore %3 %8
%9 = OpLoad %ulong %2 Aligned 8
OpStore %4 %9
%10 = OpLoad %ulong %3 Aligned 8
OpStore %5 %10
%12 = OpLoad %ulong %4
%17 = OpConvertUToPtr %_ptr_Generic_uint %12
%11 = OpLoad %uint %17 Aligned 4
OpStore %6 %11
%14 = OpLoad %uint %6
%13 = OpExtInst %uint %21 clz %14
OpStore %6 %13
%15 = OpLoad %ulong %5
%16 = OpLoad %uint %6
%18 = OpConvertUToPtr %_ptr_Generic_uint %15
OpStore %18 %16 Aligned 4
%6 = OpVariable %_ptr_Function_ulong Function
%7 = OpVariable %_ptr_Function_ushort Function
%8 = OpVariable %_ptr_Function_ushort Function
%9 = OpVariable %_ptr_Function_ushort Function
%10 = OpVariable %_ptr_Function_ushort Function
OpStore %3 %11
OpStore %4 %12
%13 = OpLoad %ulong %3 Aligned 8
OpStore %5 %13
%14 = OpLoad %ulong %4 Aligned 8
OpStore %6 %14
%39 = OpBitcast %_ptr_UniformConstant_ushort %1
%15 = OpLoad %ushort %39 Aligned 2
OpStore %7 %15
%40 = OpBitcast %_ptr_UniformConstant_ushort %1
%73 = OpBitcast %_ptr_UniformConstant_uchar %40
%74 = OpInBoundsPtrAccessChain %_ptr_UniformConstant_uchar %73 %ulong_2
%28 = OpBitcast %_ptr_UniformConstant_ushort %74
%16 = OpLoad %ushort %28 Aligned 2
OpStore %8 %16
%41 = OpBitcast %_ptr_UniformConstant_ushort %1
%75 = OpBitcast %_ptr_UniformConstant_uchar %41
%76 = OpInBoundsPtrAccessChain %_ptr_UniformConstant_uchar %75 %ulong_4
%30 = OpBitcast %_ptr_UniformConstant_ushort %76
%17 = OpLoad %ushort %30 Aligned 2
OpStore %9 %17
%42 = OpBitcast %_ptr_UniformConstant_ushort %1
%77 = OpBitcast %_ptr_UniformConstant_uchar %42
%78 = OpInBoundsPtrAccessChain %_ptr_UniformConstant_uchar %77 %ulong_6
%32 = OpBitcast %_ptr_UniformConstant_ushort %78
%18 = OpLoad %ushort %32 Aligned 2
OpStore %10 %18
%19 = OpLoad %ulong %6
%20 = OpLoad %ushort %7
%43 = OpConvertUToPtr %_ptr_Generic_ushort %19
%44 = OpCopyObject %ushort %20
OpStore %43 %44 Aligned 2
%21 = OpLoad %ulong %6
%22 = OpLoad %ushort %8
%45 = OpConvertUToPtr %_ptr_Generic_ushort %21
%81 = OpBitcast %_ptr_Generic_uchar %45
%82 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %81 %ulong_2_0
%34 = OpBitcast %_ptr_Generic_ushort %82
%46 = OpCopyObject %ushort %22
OpStore %34 %46 Aligned 2
%23 = OpLoad %ulong %6
%24 = OpLoad %ushort %9
%47 = OpConvertUToPtr %_ptr_Generic_ushort %23
%83 = OpBitcast %_ptr_Generic_uchar %47
%84 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %83 %ulong_4_0
%36 = OpBitcast %_ptr_Generic_ushort %84
%48 = OpCopyObject %ushort %24
OpStore %36 %48 Aligned 2
%25 = OpLoad %ulong %6
%26 = OpLoad %ushort %10
%49 = OpConvertUToPtr %_ptr_Generic_ushort %25
%85 = OpBitcast %_ptr_Generic_uchar %49
%86 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %85 %ulong_6_0
%38 = OpBitcast %_ptr_Generic_ushort %86
%50 = OpCopyObject %ushort %26
OpStore %38 %50 Aligned 2
OpReturn
OpFunctionEnd

View file

@ -0,0 +1,22 @@
.version 6.5
.target sm_30
.address_size 64
.visible .entry cvt_f64_f32(
.param .u64 input,
.param .u64 output
)
{
.reg .u64 in_addr;
.reg .u64 out_addr;
.reg .f32 temp_f32;
.reg .f64 temp_f64;
ld.param.u64 in_addr, [input];
ld.param.u64 out_addr, [output];
ld.global.f32 temp_f32, [in_addr];
cvt.ftz.f64.f32 temp_f64, temp_f32;
st.f64 [out_addr], temp_f64;
ret;
}

View file

@ -0,0 +1,55 @@
OpCapability GenericPointer
OpCapability Linkage
OpCapability Addresses
OpCapability Kernel
OpCapability Int8
OpCapability Int16
OpCapability Int64
OpCapability Float16
OpCapability Float64
%22 = OpExtInstImport "OpenCL.std"
OpMemoryModel Physical64 OpenCL
OpEntryPoint Kernel %1 "cvt_f64_f32"
OpExecutionMode %1 DenormFlushToZero 16
OpExecutionMode %1 DenormFlushToZero 32
OpExecutionMode %1 DenormFlushToZero 64
OpExecutionMode %1 ContractionOff
%void = OpTypeVoid
%ulong = OpTypeInt 64 0
%25 = OpTypeFunction %void %ulong %ulong
%_ptr_Function_ulong = OpTypePointer Function %ulong
%float = OpTypeFloat 32
%_ptr_Function_float = OpTypePointer Function %float
%double = OpTypeFloat 64
%_ptr_Function_double = OpTypePointer Function %double
%_ptr_CrossWorkgroup_float = OpTypePointer CrossWorkgroup %float
%_ptr_Generic_double = OpTypePointer Generic %double
%1 = OpFunction %void None %25
%8 = OpFunctionParameter %ulong
%9 = OpFunctionParameter %ulong
%20 = OpLabel
%2 = OpVariable %_ptr_Function_ulong Function
%3 = OpVariable %_ptr_Function_ulong Function
%4 = OpVariable %_ptr_Function_ulong Function
%5 = OpVariable %_ptr_Function_ulong Function
%6 = OpVariable %_ptr_Function_float Function
%7 = OpVariable %_ptr_Function_double Function
OpStore %2 %8
OpStore %3 %9
%10 = OpLoad %ulong %2 Aligned 8
OpStore %4 %10
%11 = OpLoad %ulong %3 Aligned 8
OpStore %5 %11
%13 = OpLoad %ulong %4
%18 = OpConvertUToPtr %_ptr_CrossWorkgroup_float %13
%12 = OpLoad %float %18 Aligned 4
OpStore %6 %12
%15 = OpLoad %float %6
%14 = OpFConvert %double %15
OpStore %7 %14
%16 = OpLoad %ulong %5
%17 = OpLoad %double %7
%19 = OpConvertUToPtr %_ptr_Generic_double %16
OpStore %19 %17 Aligned 8
OpReturn
OpFunctionEnd

View file

@ -0,0 +1,26 @@
.version 6.5
.target sm_30
.address_size 64
.visible .entry cvt_s16_s8(
.param .u64 input,
.param .u64 output
)
{
.reg .u64 in_addr;
.reg .u64 out_addr;
.reg .b32 temp_16;
.reg .b32 temp_8;
// inline asm
/*ptx_texBake_end*/
// inline asm
ld.param.u64 in_addr, [input];
ld.param.u64 out_addr, [output];
ld.global.b32 temp_8, [in_addr];
cvt.s16.s8 temp_16, temp_8;
st.b32 [out_addr], temp_16;
ret;
}

View file

@ -0,0 +1,58 @@
OpCapability GenericPointer
OpCapability Linkage
OpCapability Addresses
OpCapability Kernel
OpCapability Int8
OpCapability Int16
OpCapability Int64
OpCapability Float16
OpCapability Float64
%24 = OpExtInstImport "OpenCL.std"
OpMemoryModel Physical64 OpenCL
OpEntryPoint Kernel %1 "cvt_s16_s8"
OpExecutionMode %1 ContractionOff
%void = OpTypeVoid
%ulong = OpTypeInt 64 0
%27 = OpTypeFunction %void %ulong %ulong
%_ptr_Function_ulong = OpTypePointer Function %ulong
%uint = OpTypeInt 32 0
%_ptr_Function_uint = OpTypePointer Function %uint
%_ptr_CrossWorkgroup_uint = OpTypePointer CrossWorkgroup %uint
%uchar = OpTypeInt 8 0
%ushort = OpTypeInt 16 0
%_ptr_Generic_uint = OpTypePointer Generic %uint
%1 = OpFunction %void None %27
%8 = OpFunctionParameter %ulong
%9 = OpFunctionParameter %ulong
%22 = OpLabel
%2 = OpVariable %_ptr_Function_ulong Function
%3 = OpVariable %_ptr_Function_ulong Function
%4 = OpVariable %_ptr_Function_ulong Function
%5 = OpVariable %_ptr_Function_ulong Function
%6 = OpVariable %_ptr_Function_uint Function
%7 = OpVariable %_ptr_Function_uint Function
OpStore %2 %8
OpStore %3 %9
%10 = OpLoad %ulong %2 Aligned 8
OpStore %4 %10
%11 = OpLoad %ulong %3 Aligned 8
OpStore %5 %11
%13 = OpLoad %ulong %4
%18 = OpConvertUToPtr %_ptr_CrossWorkgroup_uint %13
%12 = OpLoad %uint %18 Aligned 4
OpStore %7 %12
%15 = OpLoad %uint %7
%32 = OpBitcast %uint %15
%34 = OpUConvert %uchar %32
%20 = OpCopyObject %uchar %34
%35 = OpBitcast %uchar %20
%37 = OpSConvert %ushort %35
%19 = OpCopyObject %ushort %37
%14 = OpSConvert %uint %19
OpStore %6 %14
%16 = OpLoad %ulong %5
%17 = OpLoad %uint %6
%21 = OpConvertUToPtr %_ptr_Generic_uint %16
OpStore %21 %17 Aligned 4
OpReturn
OpFunctionEnd

View file

@ -204,6 +204,8 @@ test_ptx!(
test_ptx!(non_scalar_ptr_offset, [1u32, 2u32, 3u32, 4u32], [7u32]);
test_ptx!(stateful_neg_offset, [1237518u64], [1237518u64]);
test_ptx!(const, [0u16], [10u16, 20, 30, 40]);
test_ptx!(cvt_s16_s8, [0x139231C2u32], [0xFFFFFFC2u32]);
test_ptx!(cvt_f64_f32, [0.125f32], [0.125f64]);
struct DisplayError<T: Debug> {
err: T,

View file

@ -37,6 +37,10 @@ fn error_unreachable() -> TranslateError {
TranslateError::Unreachable
}
fn error_unknown_symbol() -> TranslateError {
TranslateError::UnknownSymbol
}
#[derive(PartialEq, Eq, Hash, Clone)]
enum SpirvType {
Base(SpirvScalarKey),
@ -3301,7 +3305,7 @@ fn emit_variable<'input>(
}
ast::StateSpace::Global => (true, spirv::StorageClass::CrossWorkgroup),
ast::StateSpace::Shared => (false, spirv::StorageClass::Workgroup),
ast::StateSpace::Const => todo!(),
ast::StateSpace::Const => (false, spirv::StorageClass::UniformConstant),
ast::StateSpace::Generic => todo!(),
ast::StateSpace::Sreg => todo!(),
};
@ -4168,6 +4172,7 @@ fn normalize_identifiers<'input, 'b>(
match s {
ast::Statement::Label(id) => {
id_defs.add_def(*id, None, false);
eprintln!("{}", id);
}
_ => (),
}
@ -4996,7 +5001,7 @@ impl<'input> GlobalStringIdResolver<'input> {
self.variables
.get(id)
.copied()
.ok_or(TranslateError::UnknownSymbol)
.ok_or_else(error_unknown_symbol)
}
fn current_id(&self) -> spirv::Word {
@ -5058,7 +5063,7 @@ pub struct GlobalFnDeclResolver<'input, 'a> {
impl<'input, 'a> GlobalFnDeclResolver<'input, 'a> {
fn get_fn_sig_resolver(&self, id: spirv::Word) -> Result<&FnSigMapper<'input>, TranslateError> {
self.fns.get(&id).ok_or(TranslateError::UnknownSymbol)
self.fns.get(&id).ok_or_else(error_unknown_symbol)
}
}
@ -5099,8 +5104,7 @@ impl<'a, 'b> FnStringIdResolver<'a, 'b> {
match self.global_variables.get(id) {
Some(id) => Ok(*id),
None => {
let sreg =
PtxSpecialRegister::try_parse(id).ok_or(TranslateError::UnknownSymbol)?;
let sreg = PtxSpecialRegister::try_parse(id).ok_or_else(error_unknown_symbol)?;
Ok(self.special_registers.get_or_add(self.current_id, sreg))
}
}
@ -5778,25 +5782,13 @@ impl<T: ArgParamsEx> ast::Instruction<T> {
ast::Instruction::Not(t, a.map(visitor, &ast::Type::Scalar(t))?)
}
ast::Instruction::Cvt(d, a) => {
let (dst_t, src_t) = match &d {
ast::CvtDetails::FloatFromFloat(desc) => (
ast::Type::Scalar(desc.dst.into()),
ast::Type::Scalar(desc.src.into()),
),
ast::CvtDetails::FloatFromInt(desc) => (
ast::Type::Scalar(desc.dst.into()),
ast::Type::Scalar(desc.src.into()),
),
ast::CvtDetails::IntFromFloat(desc) => (
ast::Type::Scalar(desc.dst.into()),
ast::Type::Scalar(desc.src.into()),
),
ast::CvtDetails::IntFromInt(desc) => (
ast::Type::Scalar(desc.dst.into()),
ast::Type::Scalar(desc.src.into()),
),
let (dst_t, src_t, int_to_int) = match &d {
ast::CvtDetails::FloatFromFloat(desc) => ((desc.dst, desc.src, false)),
ast::CvtDetails::FloatFromInt(desc) => ((desc.dst, desc.src, false)),
ast::CvtDetails::IntFromFloat(desc) => ((desc.dst, desc.src, false)),
ast::CvtDetails::IntFromInt(desc) => ((desc.dst, desc.src, true)),
};
ast::Instruction::Cvt(d, a.map_different_types(visitor, &dst_t, &src_t)?)
ast::Instruction::Cvt(d, a.map_cvt(visitor, dst_t, src_t, int_to_int)?)
}
ast::Instruction::Shl(t, a) => {
ast::Instruction::Shl(t, a.map_shift(visitor, &ast::Type::Scalar(t))?)
@ -6413,6 +6405,44 @@ impl<T: ArgParamsEx> ast::Arg2<T> {
})
}
fn map_cvt<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
dst_t: ast::ScalarType,
src_t: ast::ScalarType,
is_int_to_int: bool,
) -> Result<ast::Arg2<U>, TranslateError> {
let dst = visitor.operand(
ArgumentDescriptor {
op: self.dst,
is_dst: true,
is_memory_access: false,
non_default_implicit_conversion: if is_int_to_int {
Some(should_convert_relaxed_dst_wrapper)
} else {
None
},
},
&ast::Type::Scalar(dst_t),
ast::StateSpace::Reg,
)?;
let src = visitor.operand(
ArgumentDescriptor {
op: self.src,
is_dst: false,
is_memory_access: false,
non_default_implicit_conversion: if is_int_to_int {
Some(should_convert_relaxed_src_wrapper)
} else {
None
},
},
&ast::Type::Scalar(src_t),
ast::StateSpace::Reg,
)?;
Ok(ast::Arg2 { dst, src })
}
fn map_different_types<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,