Fix some unhandled cases in cvt instruction

This commit is contained in:
Andrzej Janik 2021-09-14 23:38:06 +00:00
parent 2cd0fcb650
commit 467782b1d0
9 changed files with 318 additions and 60 deletions

View file

@ -14,7 +14,7 @@ extern {
match { match {
r"\s+" => { }, r"\s+" => { },
r"//[^\n\r]*[\n\r]*" => { }, r"//[^\n\r]*[\n\r]*" => { },
r"/\*([^\*]*\*+[^\*/])*([^\*]*\*+|[^\*])*\*/" => { }, r"/\*[^*]*\*+(?:[^/*][^*]*\*+)*/" => { },
r"0[fF][0-9a-zA-Z]{8}" => F32NumToken, r"0[fF][0-9a-zA-Z]{8}" => F32NumToken,
r"0[dD][0-9a-zA-Z]{16}" => F64NumToken, r"0[dD][0-9a-zA-Z]{16}" => F64NumToken,
r"0[xX][0-9a-zA-Z]+U?" => HexNumToken, r"0[xX][0-9a-zA-Z]+U?" => HexNumToken,
@ -1143,11 +1143,11 @@ InstCvt: ast::Instruction<ast::ParsedArgParams<'input>> = {
} }
), a) ), a)
}, },
"cvt" <s:".sat"?> ".f64" ".f32" <a:Arg2> => { "cvt" <s:".sat"?> <f:".ftz"?> ".f64" ".f32" <a:Arg2> => {
ast::Instruction::Cvt(ast::CvtDetails::FloatFromFloat( ast::Instruction::Cvt(ast::CvtDetails::FloatFromFloat(
ast::CvtDesc { ast::CvtDesc {
rounding: None, rounding: None,
flush_to_zero: None, flush_to_zero: Some(f.is_some()),
saturate: s.is_some(), saturate: s.is_some(),
dst: ast::ScalarType::F64, dst: ast::ScalarType::F64,
src: ast::ScalarType::F32 src: ast::ScalarType::F32

View file

@ -19,10 +19,10 @@
ld.param.u64 in_addr, [input]; ld.param.u64 in_addr, [input];
ld.param.u64 out_addr, [output]; ld.param.u64 out_addr, [output];
ld.const.b16 temp1, constparams[0]; ld.const.b16 temp1, [constparams];
ld.const.b16 temp2, constparams[1]; ld.const.b16 temp2, [constparams+2];
ld.const.b16 temp3, constparams[2]; ld.const.b16 temp3, [constparams+4];
ld.const.b16 temp4, constparams[3]; ld.const.b16 temp4, [constparams+6];
st.u16 [out_addr], temp1; st.u16 [out_addr], temp1;
st.u16 [out_addr+2], temp2; st.u16 [out_addr+2], temp2;
st.u16 [out_addr+4], temp3; st.u16 [out_addr+4], temp3;

View file

@ -7,41 +7,106 @@
OpCapability Int64 OpCapability Int64
OpCapability Float16 OpCapability Float16
OpCapability Float64 OpCapability Float64
%21 = OpExtInstImport "OpenCL.std" %53 = OpExtInstImport "OpenCL.std"
OpMemoryModel Physical64 OpenCL OpMemoryModel Physical64 OpenCL
OpEntryPoint Kernel %1 "clz" OpEntryPoint Kernel %2 "const" %1
OpExecutionMode %2 ContractionOff
OpDecorate %1 Alignment 8
%void = OpTypeVoid %void = OpTypeVoid
%ulong = OpTypeInt 64 0
%24 = OpTypeFunction %void %ulong %ulong
%_ptr_Function_ulong = OpTypePointer Function %ulong
%uint = OpTypeInt 32 0 %uint = OpTypeInt 32 0
%_ptr_Function_uint = OpTypePointer Function %uint %ushort = OpTypeInt 16 0
%_ptr_Generic_uint = OpTypePointer Generic %uint %uint_4 = OpConstant %uint 4
%1 = OpFunction %void None %24 %_arr_ushort_uint_4 = OpTypeArray %ushort %uint_4
%7 = OpFunctionParameter %ulong %ushort_10 = OpConstant %ushort 10
%8 = OpFunctionParameter %ulong %ushort_20 = OpConstant %ushort 20
%19 = OpLabel %ushort_30 = OpConstant %ushort 30
%2 = OpVariable %_ptr_Function_ulong Function %ushort_40 = OpConstant %ushort 40
%63 = OpConstantComposite %_arr_ushort_uint_4 %ushort_10 %ushort_20 %ushort_30 %ushort_40
%uint_4_0 = OpConstant %uint 4
%_ptr_UniformConstant__arr_ushort_uint_4 = OpTypePointer UniformConstant %_arr_ushort_uint_4
%1 = OpVariable %_ptr_UniformConstant__arr_ushort_uint_4 UniformConstant %63
%ulong = OpTypeInt 64 0
%67 = OpTypeFunction %void %ulong %ulong
%_ptr_Function_ulong = OpTypePointer Function %ulong
%_ptr_Function_ushort = OpTypePointer Function %ushort
%_ptr_UniformConstant_ushort = OpTypePointer UniformConstant %ushort
%ulong_2 = OpConstant %ulong 2
%uchar = OpTypeInt 8 0
%_ptr_UniformConstant_uchar = OpTypePointer UniformConstant %uchar
%ulong_4 = OpConstant %ulong 4
%ulong_6 = OpConstant %ulong 6
%_ptr_Generic_ushort = OpTypePointer Generic %ushort
%ulong_2_0 = OpConstant %ulong 2
%_ptr_Generic_uchar = OpTypePointer Generic %uchar
%ulong_4_0 = OpConstant %ulong 4
%ulong_6_0 = OpConstant %ulong 6
%2 = OpFunction %void None %67
%11 = OpFunctionParameter %ulong
%12 = OpFunctionParameter %ulong
%51 = OpLabel
%3 = OpVariable %_ptr_Function_ulong Function %3 = OpVariable %_ptr_Function_ulong Function
%4 = OpVariable %_ptr_Function_ulong Function %4 = OpVariable %_ptr_Function_ulong Function
%5 = OpVariable %_ptr_Function_ulong Function %5 = OpVariable %_ptr_Function_ulong Function
%6 = OpVariable %_ptr_Function_uint Function %6 = OpVariable %_ptr_Function_ulong Function
OpStore %2 %7 %7 = OpVariable %_ptr_Function_ushort Function
OpStore %3 %8 %8 = OpVariable %_ptr_Function_ushort Function
%9 = OpLoad %ulong %2 Aligned 8 %9 = OpVariable %_ptr_Function_ushort Function
OpStore %4 %9 %10 = OpVariable %_ptr_Function_ushort Function
%10 = OpLoad %ulong %3 Aligned 8 OpStore %3 %11
OpStore %5 %10 OpStore %4 %12
%12 = OpLoad %ulong %4 %13 = OpLoad %ulong %3 Aligned 8
%17 = OpConvertUToPtr %_ptr_Generic_uint %12 OpStore %5 %13
%11 = OpLoad %uint %17 Aligned 4 %14 = OpLoad %ulong %4 Aligned 8
OpStore %6 %11 OpStore %6 %14
%14 = OpLoad %uint %6 %39 = OpBitcast %_ptr_UniformConstant_ushort %1
%13 = OpExtInst %uint %21 clz %14 %15 = OpLoad %ushort %39 Aligned 2
OpStore %6 %13 OpStore %7 %15
%15 = OpLoad %ulong %5 %40 = OpBitcast %_ptr_UniformConstant_ushort %1
%16 = OpLoad %uint %6 %73 = OpBitcast %_ptr_UniformConstant_uchar %40
%18 = OpConvertUToPtr %_ptr_Generic_uint %15 %74 = OpInBoundsPtrAccessChain %_ptr_UniformConstant_uchar %73 %ulong_2
OpStore %18 %16 Aligned 4 %28 = OpBitcast %_ptr_UniformConstant_ushort %74
%16 = OpLoad %ushort %28 Aligned 2
OpStore %8 %16
%41 = OpBitcast %_ptr_UniformConstant_ushort %1
%75 = OpBitcast %_ptr_UniformConstant_uchar %41
%76 = OpInBoundsPtrAccessChain %_ptr_UniformConstant_uchar %75 %ulong_4
%30 = OpBitcast %_ptr_UniformConstant_ushort %76
%17 = OpLoad %ushort %30 Aligned 2
OpStore %9 %17
%42 = OpBitcast %_ptr_UniformConstant_ushort %1
%77 = OpBitcast %_ptr_UniformConstant_uchar %42
%78 = OpInBoundsPtrAccessChain %_ptr_UniformConstant_uchar %77 %ulong_6
%32 = OpBitcast %_ptr_UniformConstant_ushort %78
%18 = OpLoad %ushort %32 Aligned 2
OpStore %10 %18
%19 = OpLoad %ulong %6
%20 = OpLoad %ushort %7
%43 = OpConvertUToPtr %_ptr_Generic_ushort %19
%44 = OpCopyObject %ushort %20
OpStore %43 %44 Aligned 2
%21 = OpLoad %ulong %6
%22 = OpLoad %ushort %8
%45 = OpConvertUToPtr %_ptr_Generic_ushort %21
%81 = OpBitcast %_ptr_Generic_uchar %45
%82 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %81 %ulong_2_0
%34 = OpBitcast %_ptr_Generic_ushort %82
%46 = OpCopyObject %ushort %22
OpStore %34 %46 Aligned 2
%23 = OpLoad %ulong %6
%24 = OpLoad %ushort %9
%47 = OpConvertUToPtr %_ptr_Generic_ushort %23
%83 = OpBitcast %_ptr_Generic_uchar %47
%84 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %83 %ulong_4_0
%36 = OpBitcast %_ptr_Generic_ushort %84
%48 = OpCopyObject %ushort %24
OpStore %36 %48 Aligned 2
%25 = OpLoad %ulong %6
%26 = OpLoad %ushort %10
%49 = OpConvertUToPtr %_ptr_Generic_ushort %25
%85 = OpBitcast %_ptr_Generic_uchar %49
%86 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %85 %ulong_6_0
%38 = OpBitcast %_ptr_Generic_ushort %86
%50 = OpCopyObject %ushort %26
OpStore %38 %50 Aligned 2
OpReturn OpReturn
OpFunctionEnd OpFunctionEnd

View file

@ -0,0 +1,22 @@
.version 6.5
.target sm_30
.address_size 64
// Test kernel: reads one f32 from global memory at the address passed in
// `input`, converts it to f64, and stores the result at the address passed in
// `output`.
.visible .entry cvt_f64_f32(
.param .u64 input,
.param .u64 output
)
{
.reg .u64 in_addr;
.reg .u64 out_addr;
.reg .f32 temp_f32;
.reg .f64 temp_f64;
ld.param.u64 in_addr, [input];
ld.param.u64 out_addr, [output];
ld.global.f32 temp_f32, [in_addr];
// The `.ftz` modifier on an f32 -> f64 conversion is the form under test.
cvt.ftz.f64.f32 temp_f64, temp_f32;
// No state space is given on the store, so it goes through a generic address.
st.f64 [out_addr], temp_f64;
ret;
}

View file

@ -0,0 +1,55 @@
OpCapability GenericPointer
OpCapability Linkage
OpCapability Addresses
OpCapability Kernel
OpCapability Int8
OpCapability Int16
OpCapability Int64
OpCapability Float16
OpCapability Float64
%22 = OpExtInstImport "OpenCL.std"
OpMemoryModel Physical64 OpenCL
OpEntryPoint Kernel %1 "cvt_f64_f32"
OpExecutionMode %1 DenormFlushToZero 16
OpExecutionMode %1 DenormFlushToZero 32
OpExecutionMode %1 DenormFlushToZero 64
OpExecutionMode %1 ContractionOff
%void = OpTypeVoid
%ulong = OpTypeInt 64 0
%25 = OpTypeFunction %void %ulong %ulong
%_ptr_Function_ulong = OpTypePointer Function %ulong
%float = OpTypeFloat 32
%_ptr_Function_float = OpTypePointer Function %float
%double = OpTypeFloat 64
%_ptr_Function_double = OpTypePointer Function %double
%_ptr_CrossWorkgroup_float = OpTypePointer CrossWorkgroup %float
%_ptr_Generic_double = OpTypePointer Generic %double
%1 = OpFunction %void None %25
%8 = OpFunctionParameter %ulong
%9 = OpFunctionParameter %ulong
%20 = OpLabel
%2 = OpVariable %_ptr_Function_ulong Function
%3 = OpVariable %_ptr_Function_ulong Function
%4 = OpVariable %_ptr_Function_ulong Function
%5 = OpVariable %_ptr_Function_ulong Function
%6 = OpVariable %_ptr_Function_float Function
%7 = OpVariable %_ptr_Function_double Function
OpStore %2 %8
OpStore %3 %9
%10 = OpLoad %ulong %2 Aligned 8
OpStore %4 %10
%11 = OpLoad %ulong %3 Aligned 8
OpStore %5 %11
%13 = OpLoad %ulong %4
%18 = OpConvertUToPtr %_ptr_CrossWorkgroup_float %13
%12 = OpLoad %float %18 Aligned 4
OpStore %6 %12
%15 = OpLoad %float %6
%14 = OpFConvert %double %15
OpStore %7 %14
%16 = OpLoad %ulong %5
%17 = OpLoad %double %7
%19 = OpConvertUToPtr %_ptr_Generic_double %16
OpStore %19 %17 Aligned 8
OpReturn
OpFunctionEnd

View file

@ -0,0 +1,26 @@
.version 6.5
.target sm_30
.address_size 64
// Test kernel: loads a 32-bit word from global memory at the address passed in
// `input`, applies `cvt.s16.s8` to it, and stores the 32-bit result at the
// address passed in `output`.
.visible .entry cvt_s16_s8(
.param .u64 input,
.param .u64 output
)
{
.reg .u64 in_addr;
.reg .u64 out_addr;
// Both registers are declared .b32 although the conversion below is typed
// s8 -> s16, so the instruction works on registers wider than its types.
.reg .b32 temp_16;
.reg .b32 temp_8;
// NOTE(review): the three comment lines below look like deliberate input for
// the lexer's comment handling - confirm before removing them.
// inline asm
/*ptx_texBake_end*/
// inline asm
ld.param.u64 in_addr, [input];
ld.param.u64 out_addr, [output];
ld.global.b32 temp_8, [in_addr];
// The test harness feeds 0x139231C2 and expects 0xFFFFFFC2: only the low byte
// (0xC2) is converted, and its sign extension fills the whole 32-bit register.
cvt.s16.s8 temp_16, temp_8;
st.b32 [out_addr], temp_16;
ret;
}

View file

@ -0,0 +1,58 @@
OpCapability GenericPointer
OpCapability Linkage
OpCapability Addresses
OpCapability Kernel
OpCapability Int8
OpCapability Int16
OpCapability Int64
OpCapability Float16
OpCapability Float64
%24 = OpExtInstImport "OpenCL.std"
OpMemoryModel Physical64 OpenCL
OpEntryPoint Kernel %1 "cvt_s16_s8"
OpExecutionMode %1 ContractionOff
%void = OpTypeVoid
%ulong = OpTypeInt 64 0
%27 = OpTypeFunction %void %ulong %ulong
%_ptr_Function_ulong = OpTypePointer Function %ulong
%uint = OpTypeInt 32 0
%_ptr_Function_uint = OpTypePointer Function %uint
%_ptr_CrossWorkgroup_uint = OpTypePointer CrossWorkgroup %uint
%uchar = OpTypeInt 8 0
%ushort = OpTypeInt 16 0
%_ptr_Generic_uint = OpTypePointer Generic %uint
%1 = OpFunction %void None %27
%8 = OpFunctionParameter %ulong
%9 = OpFunctionParameter %ulong
%22 = OpLabel
%2 = OpVariable %_ptr_Function_ulong Function
%3 = OpVariable %_ptr_Function_ulong Function
%4 = OpVariable %_ptr_Function_ulong Function
%5 = OpVariable %_ptr_Function_ulong Function
%6 = OpVariable %_ptr_Function_uint Function
%7 = OpVariable %_ptr_Function_uint Function
OpStore %2 %8
OpStore %3 %9
%10 = OpLoad %ulong %2 Aligned 8
OpStore %4 %10
%11 = OpLoad %ulong %3 Aligned 8
OpStore %5 %11
%13 = OpLoad %ulong %4
%18 = OpConvertUToPtr %_ptr_CrossWorkgroup_uint %13
%12 = OpLoad %uint %18 Aligned 4
OpStore %7 %12
%15 = OpLoad %uint %7
%32 = OpBitcast %uint %15
%34 = OpUConvert %uchar %32
%20 = OpCopyObject %uchar %34
%35 = OpBitcast %uchar %20
%37 = OpSConvert %ushort %35
%19 = OpCopyObject %ushort %37
%14 = OpSConvert %uint %19
OpStore %6 %14
%16 = OpLoad %ulong %5
%17 = OpLoad %uint %6
%21 = OpConvertUToPtr %_ptr_Generic_uint %16
OpStore %21 %17 Aligned 4
OpReturn
OpFunctionEnd

View file

@ -204,6 +204,8 @@ test_ptx!(
test_ptx!(non_scalar_ptr_offset, [1u32, 2u32, 3u32, 4u32], [7u32]); test_ptx!(non_scalar_ptr_offset, [1u32, 2u32, 3u32, 4u32], [7u32]);
test_ptx!(stateful_neg_offset, [1237518u64], [1237518u64]); test_ptx!(stateful_neg_offset, [1237518u64], [1237518u64]);
test_ptx!(const, [0u16], [10u16, 20, 30, 40]); test_ptx!(const, [0u16], [10u16, 20, 30, 40]);
test_ptx!(cvt_s16_s8, [0x139231C2u32], [0xFFFFFFC2u32]);
test_ptx!(cvt_f64_f32, [0.125f32], [0.125f64]);
struct DisplayError<T: Debug> { struct DisplayError<T: Debug> {
err: T, err: T,

View file

@ -37,6 +37,10 @@ fn error_unreachable() -> TranslateError {
TranslateError::Unreachable TranslateError::Unreachable
} }
/// Builds the `TranslateError::UnknownSymbol` error value.
///
/// Being a plain zero-argument function, it can be handed directly to
/// `Option::ok_or_else`, which is how the symbol lookups use it.
fn error_unknown_symbol() -> TranslateError {
TranslateError::UnknownSymbol
}
#[derive(PartialEq, Eq, Hash, Clone)] #[derive(PartialEq, Eq, Hash, Clone)]
enum SpirvType { enum SpirvType {
Base(SpirvScalarKey), Base(SpirvScalarKey),
@ -3301,7 +3305,7 @@ fn emit_variable<'input>(
} }
ast::StateSpace::Global => (true, spirv::StorageClass::CrossWorkgroup), ast::StateSpace::Global => (true, spirv::StorageClass::CrossWorkgroup),
ast::StateSpace::Shared => (false, spirv::StorageClass::Workgroup), ast::StateSpace::Shared => (false, spirv::StorageClass::Workgroup),
ast::StateSpace::Const => todo!(), ast::StateSpace::Const => (false, spirv::StorageClass::UniformConstant),
ast::StateSpace::Generic => todo!(), ast::StateSpace::Generic => todo!(),
ast::StateSpace::Sreg => todo!(), ast::StateSpace::Sreg => todo!(),
}; };
@ -4168,6 +4172,7 @@ fn normalize_identifiers<'input, 'b>(
match s { match s {
ast::Statement::Label(id) => { ast::Statement::Label(id) => {
id_defs.add_def(*id, None, false); id_defs.add_def(*id, None, false);
eprintln!("{}", id);
} }
_ => (), _ => (),
} }
@ -4996,7 +5001,7 @@ impl<'input> GlobalStringIdResolver<'input> {
self.variables self.variables
.get(id) .get(id)
.copied() .copied()
.ok_or(TranslateError::UnknownSymbol) .ok_or_else(error_unknown_symbol)
} }
fn current_id(&self) -> spirv::Word { fn current_id(&self) -> spirv::Word {
@ -5058,7 +5063,7 @@ pub struct GlobalFnDeclResolver<'input, 'a> {
impl<'input, 'a> GlobalFnDeclResolver<'input, 'a> { impl<'input, 'a> GlobalFnDeclResolver<'input, 'a> {
fn get_fn_sig_resolver(&self, id: spirv::Word) -> Result<&FnSigMapper<'input>, TranslateError> { fn get_fn_sig_resolver(&self, id: spirv::Word) -> Result<&FnSigMapper<'input>, TranslateError> {
self.fns.get(&id).ok_or(TranslateError::UnknownSymbol) self.fns.get(&id).ok_or_else(error_unknown_symbol)
} }
} }
@ -5099,8 +5104,7 @@ impl<'a, 'b> FnStringIdResolver<'a, 'b> {
match self.global_variables.get(id) { match self.global_variables.get(id) {
Some(id) => Ok(*id), Some(id) => Ok(*id),
None => { None => {
let sreg = let sreg = PtxSpecialRegister::try_parse(id).ok_or_else(error_unknown_symbol)?;
PtxSpecialRegister::try_parse(id).ok_or(TranslateError::UnknownSymbol)?;
Ok(self.special_registers.get_or_add(self.current_id, sreg)) Ok(self.special_registers.get_or_add(self.current_id, sreg))
} }
} }
@ -5778,25 +5782,13 @@ impl<T: ArgParamsEx> ast::Instruction<T> {
ast::Instruction::Not(t, a.map(visitor, &ast::Type::Scalar(t))?) ast::Instruction::Not(t, a.map(visitor, &ast::Type::Scalar(t))?)
} }
ast::Instruction::Cvt(d, a) => { ast::Instruction::Cvt(d, a) => {
let (dst_t, src_t) = match &d { let (dst_t, src_t, int_to_int) = match &d {
ast::CvtDetails::FloatFromFloat(desc) => ( ast::CvtDetails::FloatFromFloat(desc) => ((desc.dst, desc.src, false)),
ast::Type::Scalar(desc.dst.into()), ast::CvtDetails::FloatFromInt(desc) => ((desc.dst, desc.src, false)),
ast::Type::Scalar(desc.src.into()), ast::CvtDetails::IntFromFloat(desc) => ((desc.dst, desc.src, false)),
), ast::CvtDetails::IntFromInt(desc) => ((desc.dst, desc.src, true)),
ast::CvtDetails::FloatFromInt(desc) => (
ast::Type::Scalar(desc.dst.into()),
ast::Type::Scalar(desc.src.into()),
),
ast::CvtDetails::IntFromFloat(desc) => (
ast::Type::Scalar(desc.dst.into()),
ast::Type::Scalar(desc.src.into()),
),
ast::CvtDetails::IntFromInt(desc) => (
ast::Type::Scalar(desc.dst.into()),
ast::Type::Scalar(desc.src.into()),
),
}; };
ast::Instruction::Cvt(d, a.map_different_types(visitor, &dst_t, &src_t)?) ast::Instruction::Cvt(d, a.map_cvt(visitor, dst_t, src_t, int_to_int)?)
} }
ast::Instruction::Shl(t, a) => { ast::Instruction::Shl(t, a) => {
ast::Instruction::Shl(t, a.map_shift(visitor, &ast::Type::Scalar(t))?) ast::Instruction::Shl(t, a.map_shift(visitor, &ast::Type::Scalar(t))?)
@ -6413,6 +6405,44 @@ impl<T: ArgParamsEx> ast::Arg2<T> {
}) })
} }
/// Maps the two operands of a `cvt` instruction through `visitor`.
///
/// The destination is visited first, typed as the scalar `dst_t`; the source
/// is visited second, typed as the scalar `src_t`. Both are visited in the
/// register state space and neither is flagged as a memory access.
///
/// When `is_int_to_int` is true the descriptors carry
/// `should_convert_relaxed_dst_wrapper` / `should_convert_relaxed_src_wrapper`
/// as their `non_default_implicit_conversion`; otherwise that field is `None`.
///
/// # Errors
///
/// Propagates the first error returned by `visitor.operand`.
fn map_cvt<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
    self,
    visitor: &mut V,
    dst_t: ast::ScalarType,
    src_t: ast::ScalarType,
    is_int_to_int: bool,
) -> Result<ast::Arg2<U>, TranslateError> {
    let dst_type = ast::Type::Scalar(dst_t);
    let src_type = ast::Type::Scalar(src_t);

    // Building the descriptors only moves the operands out of `self`; the
    // visitor itself is still invoked destination-first below.
    let dst_desc = ArgumentDescriptor {
        op: self.dst,
        is_dst: true,
        is_memory_access: false,
        non_default_implicit_conversion: if is_int_to_int {
            Some(should_convert_relaxed_dst_wrapper)
        } else {
            None
        },
    };
    let src_desc = ArgumentDescriptor {
        op: self.src,
        is_dst: false,
        is_memory_access: false,
        non_default_implicit_conversion: if is_int_to_int {
            Some(should_convert_relaxed_src_wrapper)
        } else {
            None
        },
    };

    let dst = visitor.operand(dst_desc, &dst_type, ast::StateSpace::Reg)?;
    let src = visitor.operand(src_desc, &src_type, ast::StateSpace::Reg)?;
    Ok(ast::Arg2 { dst, src })
}
fn map_different_types<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>( fn map_different_types<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self, self,
visitor: &mut V, visitor: &mut V,