mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-04-28 13:37:57 +03:00
Implement neg instruction
This commit is contained in:
parent
b7d61baf37
commit
e5a53ed5d3
6 changed files with 131 additions and 1 deletions
|
@ -542,6 +542,7 @@ pub enum Instruction<P: ArgParams> {
|
|||
Div(DivDetails, Arg3<P>),
|
||||
Sqrt(SqrtDetails, Arg2<P>),
|
||||
Rsqrt(RsqrtDetails, Arg2<P>),
|
||||
Neg(NegDetails, Arg2<P>),
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone)]
|
||||
|
@ -1183,6 +1184,12 @@ pub struct RsqrtDetails {
|
|||
pub flush_to_zero: bool,
|
||||
}
|
||||
|
||||
/// Modifiers attached to a parsed `neg` instruction.
#[derive(Copy, Clone, Eq, PartialEq)]
|
||||
pub struct NegDetails {
|
||||
/// Scalar type named by the instruction's type suffix.
pub typ: ScalarType,
|
||||
/// `Some(b)` for the types the grammar allows `.ftz` on (`b` is true iff `.ftz` was written);
/// `None` for the types parsed without an `.ftz` option.
pub flush_to_zero: Option<bool>,
|
||||
}
|
||||
|
||||
impl<'a> NumsOrArrays<'a> {
|
||||
pub fn to_vec(self, typ: SizedScalarType, dimensions: &mut [u32]) -> Result<Vec<u8>, PtxError> {
|
||||
self.normalize_dimensions(dimensions)?;
|
||||
|
|
|
@ -156,6 +156,7 @@ match {
|
|||
"min",
|
||||
"mov",
|
||||
"mul",
|
||||
"neg",
|
||||
"not",
|
||||
"or",
|
||||
"rcp",
|
||||
|
@ -198,6 +199,7 @@ ExtendedID : &'input str = {
|
|||
"min",
|
||||
"mov",
|
||||
"mul",
|
||||
"neg",
|
||||
"not",
|
||||
"or",
|
||||
"rcp",
|
||||
|
@ -684,6 +686,7 @@ Instruction: ast::Instruction<ast::ParsedArgParams<'input>> = {
|
|||
InstDiv,
|
||||
InstSqrt,
|
||||
InstRsqrt,
|
||||
InstNeg,
|
||||
};
|
||||
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-ld
|
||||
|
@ -1577,6 +1580,39 @@ InstRsqrt: ast::Instruction<ast::ParsedArgParams<'input>> = {
|
|||
},
|
||||
}
|
||||
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-neg
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-neg
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-neg
|
||||
// `neg` is parsed by two alternatives so that the optional `.ftz` modifier is only
// accepted in front of the type suffixes listed in NegTypeFtz.
InstNeg: ast::Instruction<ast::ParsedArgParams<'input>> = {
|
||||
// Suffixes that may carry `.ftz`: flush_to_zero records whether the modifier was present.
"neg" <ftz:".ftz"?> <typ:NegTypeFtz> <a:Arg2> => {
|
||||
let details = ast::NegDetails {
|
||||
typ,
|
||||
flush_to_zero: Some(ftz.is_some()),
|
||||
};
|
||||
ast::Instruction::Neg(details, a)
|
||||
},
|
||||
// Suffixes parsed without an `.ftz` option: flush_to_zero is left as None.
"neg" <typ:NegTypeNonFtz> <a:Arg2> => {
|
||||
let details = ast::NegDetails {
|
||||
typ,
|
||||
flush_to_zero: None,
|
||||
};
|
||||
ast::Instruction::Neg(details, a)
|
||||
},
|
||||
}
|
||||
|
||||
// Type suffixes of `neg` in front of which InstNeg accepts an optional `.ftz`.
NegTypeFtz: ast::ScalarType = {
|
||||
".f16" => ast::ScalarType::F16,
|
||||
".f16x2" => ast::ScalarType::F16x2,
|
||||
".f32" => ast::ScalarType::F32,
|
||||
}
|
||||
|
||||
// Type suffixes of `neg` that InstNeg parses without an `.ftz` option
// (the signed integer types and `.f64`).
NegTypeNonFtz: ast::ScalarType = {
|
||||
".s16" => ast::ScalarType::S16,
|
||||
".s32" => ast::ScalarType::S32,
|
||||
".s64" => ast::ScalarType::S64,
|
||||
".f64" => ast::ScalarType::F64
|
||||
}
|
||||
|
||||
ArithDetails: ast::ArithDetails = {
|
||||
<t:UIntType> => ast::ArithDetails::Unsigned(t),
|
||||
<t:SIntType> => ast::ArithDetails::Signed(ast::ArithSInt {
|
||||
|
|
|
@ -104,6 +104,7 @@ test_ptx!(atom_add, [2u32, 4u32], [2u32, 6u32]);
|
|||
test_ptx!(div_approx, [1f32, 2f32], [0.5f32]);
|
||||
test_ptx!(sqrt, [0.25f32], [0.5f32]);
|
||||
test_ptx!(rsqrt, [0.25f64], [2f64]);
|
||||
test_ptx!(neg, [181i32], [-181i32]);
|
||||
|
||||
struct DisplayError<T: Debug> {
|
||||
err: T,
|
||||
|
|
21
ptx/src/test/spirv_run/neg.ptx
Normal file
21
ptx/src/test/spirv_run/neg.ptx
Normal file
|
@ -0,0 +1,21 @@
|
|||
.version 6.5
|
||||
.target sm_30
|
||||
.address_size 64
|
||||
|
||||
// Test kernel for `neg.s32`: reads one .s32 value through `input`, negates it and
// writes the result through `output`.
.visible .entry neg(
|
||||
.param .u64 input,
|
||||
.param .u64 output
|
||||
)
|
||||
{
|
||||
.reg .u64 in_addr;
|
||||
.reg .u64 out_addr;
|
||||
.reg .s32 temp1;
|
||||
|
||||
// Copy the two 64-bit kernel parameters into registers.
ld.param.u64 in_addr, [input];
|
||||
ld.param.u64 out_addr, [output];
|
||||
|
||||
// Load the operand, negate it in place, then store it to the output address.
ld.s32 temp1, [in_addr];
|
||||
neg.s32 temp1, temp1;
|
||||
st.s32 [out_addr], temp1;
|
||||
ret;
|
||||
}
|
47
ptx/src/test/spirv_run/neg.spvtxt
Normal file
47
ptx/src/test/spirv_run/neg.spvtxt
Normal file
|
@ -0,0 +1,47 @@
|
|||
; Expected SPIR-V for the `neg` spirv_run test.
; NOTE(review): this text looks copied from the `not` test rather than generated from neg.ptx:
; the entry point is named "not", the value is loaded and stored as 64-bit %ulong, and the
; operation is OpNot, whereas neg.ptx declares `.entry neg` and performs `neg.s32`.
; Confirm against the translator's actual output and regenerate.
OpCapability GenericPointer
|
||||
OpCapability Linkage
|
||||
OpCapability Addresses
|
||||
OpCapability Kernel
|
||||
OpCapability Int64
|
||||
OpCapability Int8
|
||||
%26 = OpExtInstImport "OpenCL.std"
|
||||
OpMemoryModel Physical64 OpenCL
|
||||
OpEntryPoint Kernel %1 "not"
|
||||
%void = OpTypeVoid
|
||||
%ulong = OpTypeInt 64 0
|
||||
%29 = OpTypeFunction %void %ulong %ulong
|
||||
%_ptr_Function_ulong = OpTypePointer Function %ulong
|
||||
%_ptr_Generic_ulong = OpTypePointer Generic %ulong
|
||||
%1 = OpFunction %void None %29
|
||||
%8 = OpFunctionParameter %ulong
|
||||
%9 = OpFunctionParameter %ulong
|
||||
%24 = OpLabel
|
||||
%2 = OpVariable %_ptr_Function_ulong Function
|
||||
%3 = OpVariable %_ptr_Function_ulong Function
|
||||
%4 = OpVariable %_ptr_Function_ulong Function
|
||||
%5 = OpVariable %_ptr_Function_ulong Function
|
||||
%6 = OpVariable %_ptr_Function_ulong Function
|
||||
%7 = OpVariable %_ptr_Function_ulong Function
|
||||
OpStore %2 %8
|
||||
OpStore %3 %9
|
||||
%11 = OpLoad %ulong %2
|
||||
%10 = OpCopyObject %ulong %11
|
||||
OpStore %4 %10
|
||||
%13 = OpLoad %ulong %3
|
||||
%12 = OpCopyObject %ulong %13
|
||||
OpStore %5 %12
|
||||
%15 = OpLoad %ulong %4
|
||||
%20 = OpConvertUToPtr %_ptr_Generic_ulong %15
|
||||
%14 = OpLoad %ulong %20
|
||||
OpStore %6 %14
|
||||
%17 = OpLoad %ulong %6
|
||||
%22 = OpCopyObject %ulong %17
|
||||
; NOTE(review): OpNot is a bitwise complement; the translator change in this commit emits
; s_negate/f_negate for Neg, so a negate op would be expected here — confirm.
%21 = OpNot %ulong %22
|
||||
%16 = OpCopyObject %ulong %21
|
||||
OpStore %7 %16
|
||||
%18 = OpLoad %ulong %5
|
||||
%19 = OpLoad %ulong %7
|
||||
%23 = OpConvertUToPtr %_ptr_Generic_ulong %18
|
||||
OpStore %23 %19
|
||||
OpReturn
|
||||
OpFunctionEnd
|
|
@ -1511,6 +1511,9 @@ fn convert_to_typed_statements(
|
|||
ast::Instruction::Rsqrt(d, a) => {
|
||||
result.push(Statement::Instruction(ast::Instruction::Rsqrt(d, a.cast())))
|
||||
}
|
||||
ast::Instruction::Neg(d, a) => {
|
||||
result.push(Statement::Instruction(ast::Instruction::Neg(d, a.cast())))
|
||||
}
|
||||
},
|
||||
Statement::Label(i) => result.push(Statement::Label(i)),
|
||||
Statement::Variable(v) => result.push(Statement::Variable(v)),
|
||||
|
@ -2805,6 +2808,15 @@ fn emit_function_body_ops(
|
|||
&[a.src],
|
||||
)?;
|
||||
}
|
||||
ast::Instruction::Neg(details, arg) => {
|
||||
let result_type = map.get_or_add_scalar(builder, details.typ);
|
||||
let negate_func = if details.typ.kind() == ScalarKind::Float {
|
||||
dr::Builder::f_negate
|
||||
} else {
|
||||
dr::Builder::s_negate
|
||||
};
|
||||
negate_func(builder, result_type, Some(arg.dst), arg.src)?;
|
||||
}
|
||||
},
|
||||
Statement::LoadVar(arg, typ) => {
|
||||
let type_id = map.get_or_add(builder, SpirvType::from(typ.clone()));
|
||||
|
@ -3406,7 +3418,7 @@ fn emit_setp(
|
|||
(ast::SetpCompareOp::NanGreaterOrEq, _) => {
|
||||
builder.f_unord_greater_than_equal(result_type, result_id, operand_1, operand_2)
|
||||
}
|
||||
_ => todo!()
|
||||
_ => todo!(),
|
||||
}?;
|
||||
Ok(())
|
||||
}
|
||||
|
@ -4678,6 +4690,9 @@ impl<T: ArgParamsEx> ast::Instruction<T> {
|
|||
ast::Instruction::Rsqrt(d, a) => {
|
||||
ast::Instruction::Rsqrt(d, a.map(visitor, &ast::Type::Scalar(d.typ.into()))?)
|
||||
}
|
||||
ast::Instruction::Neg(d, a) => {
|
||||
ast::Instruction::Neg(d, a.map(visitor, &ast::Type::Scalar(d.typ))?)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -4984,6 +4999,9 @@ impl ast::Instruction<ExpandedArgParams> {
|
|||
details.flush_to_zero,
|
||||
ast::ScalarType::from(details.typ).size_of(),
|
||||
)),
|
||||
ast::Instruction::Neg(details, _) => details
|
||||
.flush_to_zero
|
||||
.map(|ftz| (ftz, details.typ.size_of())),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue