Implement neg instruction

This commit is contained in:
Andrzej Janik 2020-11-01 14:58:44 +01:00
parent b7d61baf37
commit e5a53ed5d3
6 changed files with 131 additions and 1 deletions

View file

@ -542,6 +542,7 @@ pub enum Instruction<P: ArgParams> {
Div(DivDetails, Arg3<P>),
Sqrt(SqrtDetails, Arg2<P>),
Rsqrt(RsqrtDetails, Arg2<P>),
Neg(NegDetails, Arg2<P>),
}
#[derive(Copy, Clone)]
@ -1183,6 +1184,12 @@ pub struct RsqrtDetails {
pub flush_to_zero: bool,
}
/// Modifiers attached to a parsed `neg` instruction.
#[derive(Copy, Clone, Eq, PartialEq)]
pub struct NegDetails {
/// Scalar type taken from the instruction's type suffix.
pub typ: ScalarType,
/// The parser stores `Some(x)` for `.f16`, `.f16x2` and `.f32`, where `x`
/// records whether `.ftz` was written, and `None` for `.s16`, `.s32`,
/// `.s64` and `.f64`, which have no `.ftz` form in the grammar.
pub flush_to_zero: Option<bool>,
}
impl<'a> NumsOrArrays<'a> {
pub fn to_vec(self, typ: SizedScalarType, dimensions: &mut [u32]) -> Result<Vec<u8>, PtxError> {
self.normalize_dimensions(dimensions)?;

View file

@ -156,6 +156,7 @@ match {
"min",
"mov",
"mul",
"neg",
"not",
"or",
"rcp",
@ -198,6 +199,7 @@ ExtendedID : &'input str = {
"min",
"mov",
"mul",
"neg",
"not",
"or",
"rcp",
@ -684,6 +686,7 @@ Instruction: ast::Instruction<ast::ParsedArgParams<'input>> = {
InstDiv,
InstSqrt,
InstRsqrt,
InstNeg,
};
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-ld
@ -1577,6 +1580,39 @@ InstRsqrt: ast::Instruction<ast::ParsedArgParams<'input>> = {
},
}
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-neg
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-neg
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-neg
InstNeg: ast::Instruction<ast::ParsedArgParams<'input>> = {
    // Types that may carry `.ftz`: always record a flush-to-zero setting,
    // true only when the modifier was actually written.
    "neg" <ftz:".ftz"?> <typ:NegTypeFtz> <a:Arg2> => ast::Instruction::Neg(
        ast::NegDetails { typ, flush_to_zero: Some(ftz.is_some()) },
        a,
    ),
    // Types with no `.ftz` form: there is no flush-to-zero setting at all.
    "neg" <typ:NegTypeNonFtz> <a:Arg2> => ast::Instruction::Neg(
        ast::NegDetails { typ, flush_to_zero: None },
        a,
    ),
}
// Type suffixes of `neg` that InstNeg accepts together with an optional `.ftz`.
NegTypeFtz: ast::ScalarType = {
".f16" => ast::ScalarType::F16,
".f16x2" => ast::ScalarType::F16x2,
".f32" => ast::ScalarType::F32,
}
// Type suffixes of `neg` that InstNeg accepts only without `.ftz`:
// the signed integer types and `.f64`.
NegTypeNonFtz: ast::ScalarType = {
".s16" => ast::ScalarType::S16,
".s32" => ast::ScalarType::S32,
".s64" => ast::ScalarType::S64,
".f64" => ast::ScalarType::F64
}
ArithDetails: ast::ArithDetails = {
<t:UIntType> => ast::ArithDetails::Unsigned(t),
<t:SIntType> => ast::ArithDetails::Signed(ast::ArithSInt {

View file

@ -104,6 +104,7 @@ test_ptx!(atom_add, [2u32, 4u32], [2u32, 6u32]);
test_ptx!(div_approx, [1f32, 2f32], [0.5f32]);
test_ptx!(sqrt, [0.25f32], [0.5f32]);
test_ptx!(rsqrt, [0.25f64], [2f64]);
test_ptx!(neg, [181i32], [-181i32]);
struct DisplayError<T: Debug> {
err: T,

View file

@ -0,0 +1,21 @@
// Test kernel for the `neg` instruction: reads one .s32 value from `input`,
// negates it, and writes the result to `output`.
.version 6.5
.target sm_30
.address_size 64
.visible .entry neg(
.param .u64 input,
.param .u64 output
)
{
.reg .u64 in_addr;
.reg .u64 out_addr;
.reg .s32 temp1;
// Fetch the two 64-bit addresses passed as kernel parameters.
ld.param.u64 in_addr, [input];
ld.param.u64 out_addr, [output];
// Load the operand, negate it in place, and store it back out.
ld.s32 temp1, [in_addr];
neg.s32 temp1, temp1;
st.s32 [out_addr], temp1;
ret;
}

View file

@ -0,0 +1,47 @@
; Expected SPIR-V listing stored alongside the `neg` PTX test kernel.
; NOTE(review): this listing does not appear to match that kernel. The entry
; point is named "not", every load/store is on the 64-bit %ulong type, and the
; only arithmetic operation is OpNot, whereas the PTX kernel is named `neg`
; and executes ld.s32 / neg.s32 / st.s32. It looks like a copy of the expected
; output for the `not` test - confirm and regenerate from the translator.
OpCapability GenericPointer
OpCapability Linkage
OpCapability Addresses
OpCapability Kernel
OpCapability Int64
OpCapability Int8
%26 = OpExtInstImport "OpenCL.std"
OpMemoryModel Physical64 OpenCL
; NOTE(review): kernel name is "not" here; the PTX entry is `neg`.
OpEntryPoint Kernel %1 "not"
%void = OpTypeVoid
%ulong = OpTypeInt 64 0
%29 = OpTypeFunction %void %ulong %ulong
%_ptr_Function_ulong = OpTypePointer Function %ulong
%_ptr_Generic_ulong = OpTypePointer Generic %ulong
%1 = OpFunction %void None %29
%8 = OpFunctionParameter %ulong
%9 = OpFunctionParameter %ulong
%24 = OpLabel
%2 = OpVariable %_ptr_Function_ulong Function
%3 = OpVariable %_ptr_Function_ulong Function
%4 = OpVariable %_ptr_Function_ulong Function
%5 = OpVariable %_ptr_Function_ulong Function
%6 = OpVariable %_ptr_Function_ulong Function
%7 = OpVariable %_ptr_Function_ulong Function
; Spill both function parameters into local variables.
OpStore %2 %8
OpStore %3 %9
%11 = OpLoad %ulong %2
%10 = OpCopyObject %ulong %11
OpStore %4 %10
%13 = OpLoad %ulong %3
%12 = OpCopyObject %ulong %13
OpStore %5 %12
; Load a 64-bit value through the first address.
%15 = OpLoad %ulong %4
%20 = OpConvertUToPtr %_ptr_Generic_ulong %15
%14 = OpLoad %ulong %20
OpStore %6 %14
%17 = OpLoad %ulong %6
%22 = OpCopyObject %ulong %17
; NOTE(review): bitwise complement, not a negation (see header note).
%21 = OpNot %ulong %22
%16 = OpCopyObject %ulong %21
OpStore %7 %16
; Store the result through the second address.
%18 = OpLoad %ulong %5
%19 = OpLoad %ulong %7
%23 = OpConvertUToPtr %_ptr_Generic_ulong %18
OpStore %23 %19
OpReturn
OpFunctionEnd

View file

@ -1511,6 +1511,9 @@ fn convert_to_typed_statements(
ast::Instruction::Rsqrt(d, a) => {
result.push(Statement::Instruction(ast::Instruction::Rsqrt(d, a.cast())))
}
ast::Instruction::Neg(d, a) => {
result.push(Statement::Instruction(ast::Instruction::Neg(d, a.cast())))
}
},
Statement::Label(i) => result.push(Statement::Label(i)),
Statement::Variable(v) => result.push(Statement::Variable(v)),
@ -2805,6 +2808,15 @@ fn emit_function_body_ops(
&[a.src],
)?;
}
ast::Instruction::Neg(details, arg) => {
let result_type = map.get_or_add_scalar(builder, details.typ);
let negate_func = if details.typ.kind() == ScalarKind::Float {
dr::Builder::f_negate
} else {
dr::Builder::s_negate
};
negate_func(builder, result_type, Some(arg.dst), arg.src)?;
}
},
Statement::LoadVar(arg, typ) => {
let type_id = map.get_or_add(builder, SpirvType::from(typ.clone()));
@ -3406,7 +3418,7 @@ fn emit_setp(
(ast::SetpCompareOp::NanGreaterOrEq, _) => {
builder.f_unord_greater_than_equal(result_type, result_id, operand_1, operand_2)
}
_ => todo!()
_ => todo!(),
}?;
Ok(())
}
@ -4678,6 +4690,9 @@ impl<T: ArgParamsEx> ast::Instruction<T> {
ast::Instruction::Rsqrt(d, a) => {
ast::Instruction::Rsqrt(d, a.map(visitor, &ast::Type::Scalar(d.typ.into()))?)
}
ast::Instruction::Neg(d, a) => {
ast::Instruction::Neg(d, a.map(visitor, &ast::Type::Scalar(d.typ))?)
}
})
}
}
@ -4984,6 +4999,9 @@ impl ast::Instruction<ExpandedArgParams> {
details.flush_to_zero,
ast::ScalarType::from(details.typ).size_of(),
)),
ast::Instruction::Neg(details, _) => details
.flush_to_zero
.map(|ftz| (ftz, details.typ.size_of())),
}
}
}