Use calls to OpenCL builtins when translating sregs, do SPIRV->LLVM conversion on every build

This commit is contained in:
Andrzej Janik 2021-08-02 01:04:05 +02:00
parent 4a71fefb8a
commit b4de21fbc5
7 changed files with 278 additions and 81 deletions

View file

@ -61,6 +61,7 @@ test_ptx!(block, [1u64], [2u64]);
test_ptx!(local_align, [1u64], [1u64]);
test_ptx!(call, [1u64], [2u64]);
test_ptx!(vector, [1u32, 2u32], [3u32, 3u32]);
test_ptx!(vector4, [1u32, 2u32, 3u32, 4u32], [4u32]);
test_ptx!(ld_st_offset, [1u32, 2u32], [2u32, 1u32]);
test_ptx!(ntid, [3u32], [4u32]);
test_ptx!(reg_local, [12u64], [13u64]);

View file

@ -7,24 +7,27 @@
OpCapability Int64
OpCapability Float16
OpCapability Float64
%28 = OpExtInstImport "OpenCL.std"
%31 = OpExtInstImport "OpenCL.std"
OpMemoryModel Physical64 OpenCL
OpEntryPoint Kernel %1 "ntid" %gl_WorkGroupSize
OpDecorate %gl_WorkGroupSize BuiltIn WorkgroupSize
OpEntryPoint Kernel %1 "ntid"
OpExecutionMode %1 ContractionOff
OpDecorate %24 LinkageAttributes "get_local_size" Import
%void = OpTypeVoid
%ulong = OpTypeInt 64 0
%v3ulong = OpTypeVector %ulong 3
%_ptr_Input_v3ulong = OpTypePointer Input %v3ulong
%gl_WorkGroupSize = OpVariable %_ptr_Input_v3ulong Input
%33 = OpTypeFunction %void %ulong %ulong
%_ptr_Function_ulong = OpTypePointer Function %ulong
%uint = OpTypeInt 32 0
%35 = OpTypeFunction %ulong %uint
%36 = OpTypeFunction %void %ulong %ulong
%_ptr_Function_ulong = OpTypePointer Function %ulong
%_ptr_Function_uint = OpTypePointer Function %uint
%_ptr_Generic_uint = OpTypePointer Generic %uint
%1 = OpFunction %void None %33
%uint_0 = OpConstant %uint 0
%24 = OpFunction %ulong None %35
%26 = OpFunctionParameter %uint
OpFunctionEnd
%1 = OpFunction %void None %36
%9 = OpFunctionParameter %ulong
%10 = OpFunctionParameter %ulong
%26 = OpLabel
%29 = OpLabel
%2 = OpVariable %_ptr_Function_ulong Function
%3 = OpVariable %_ptr_Function_ulong Function
%4 = OpVariable %_ptr_Function_ulong Function
@ -38,13 +41,12 @@
%12 = OpLoad %ulong %3 Aligned 8
OpStore %5 %12
%14 = OpLoad %ulong %4
%24 = OpConvertUToPtr %_ptr_Generic_uint %14
%13 = OpLoad %uint %24 Aligned 4
%27 = OpConvertUToPtr %_ptr_Generic_uint %14
%13 = OpLoad %uint %27 Aligned 4
OpStore %6 %13
%38 = OpLoad %v3ulong %gl_WorkGroupSize
%23 = OpCompositeExtract %ulong %38 0
%39 = OpBitcast %ulong %23
%16 = OpUConvert %uint %39
%23 = OpFunctionCall %ulong %24 %uint_0
%40 = OpBitcast %ulong %23
%16 = OpUConvert %uint %40
%15 = OpCopyObject %uint %16
OpStore %7 %15
%18 = OpLoad %uint %6
@ -53,7 +55,7 @@
OpStore %6 %17
%20 = OpLoad %ulong %5
%21 = OpLoad %uint %6
%25 = OpConvertUToPtr %_ptr_Generic_uint %20
OpStore %25 %21 Aligned 4
%28 = OpConvertUToPtr %_ptr_Generic_uint %20
OpStore %28 %21 Aligned 4
OpReturn
OpFunctionEnd
OpFunctionEnd

View file

@ -0,0 +1,22 @@
.version 6.5
.target sm_60
.address_size 64
// Test kernel: loads a 4-element u32 vector from the input buffer and writes
// only its last (.w) component to the output buffer.
.visible .entry vector4(
.param .u64 input_p,
.param .u64 output_p
)
{
.reg .u64 in_addr;
.reg .u64 out_addr;
.reg .v4 .u32 temp;
.reg .u32 temp_scalar;
// Fetch the two buffer addresses passed as kernel parameters.
ld.param.u64 in_addr, [input_p];
ld.param.u64 out_addr, [output_p];
// Single vector load of all four u32 elements.
ld.v4.u32 temp, [in_addr];
// Extract the fourth component into a scalar register.
mov.b32 temp_scalar, temp.w;
st.u32 [out_addr], temp_scalar;
ret;
}

View file

@ -0,0 +1,99 @@
OpCapability GenericPointer
OpCapability Linkage
OpCapability Addresses
OpCapability Kernel
OpCapability Int8
OpCapability Int16
OpCapability Int64
OpCapability Float16
OpCapability Float64
%51 = OpExtInstImport "OpenCL.std"
OpMemoryModel Physical64 OpenCL
OpEntryPoint Kernel %25 "vector"
%void = OpTypeVoid
%uint = OpTypeInt 32 0
%v2uint = OpTypeVector %uint 2
%55 = OpTypeFunction %v2uint %v2uint
%_ptr_Function_v2uint = OpTypePointer Function %v2uint
%_ptr_Function_uint = OpTypePointer Function %uint
%uint_0 = OpConstant %uint 0
%uint_1 = OpConstant %uint 1
%ulong = OpTypeInt 64 0
%67 = OpTypeFunction %void %ulong %ulong
%_ptr_Function_ulong = OpTypePointer Function %ulong
%_ptr_Generic_v2uint = OpTypePointer Generic %v2uint
; Function %1: takes a v2uint, adds its two components together and returns a
; v2uint whose two components both hold that sum.
%1 = OpFunction %v2uint None %55
%7 = OpFunctionParameter %v2uint
%24 = OpLabel
%3 = OpVariable %_ptr_Function_v2uint Function
%2 = OpVariable %_ptr_Function_v2uint Function
%4 = OpVariable %_ptr_Function_v2uint Function
%5 = OpVariable %_ptr_Function_uint Function
%6 = OpVariable %_ptr_Function_uint Function
OpStore %3 %7
; %5 <- component 0 of the argument
%59 = OpInBoundsAccessChain %_ptr_Function_uint %3 %uint_0
%9 = OpLoad %uint %59
%8 = OpCopyObject %uint %9
OpStore %5 %8
; %6 <- component 1 of the argument
%61 = OpInBoundsAccessChain %_ptr_Function_uint %3 %uint_1
%11 = OpLoad %uint %61
%10 = OpCopyObject %uint %11
OpStore %6 %10
; %6 <- %5 + %6
%13 = OpLoad %uint %5
%14 = OpLoad %uint %6
%12 = OpIAdd %uint %13 %14
OpStore %6 %12
; component 0 of %4 <- sum
%16 = OpLoad %uint %6
%15 = OpCopyObject %uint %16
%62 = OpInBoundsAccessChain %_ptr_Function_uint %4 %uint_0
OpStore %62 %15
; component 1 of %4 <- sum
%18 = OpLoad %uint %6
%17 = OpCopyObject %uint %18
%63 = OpInBoundsAccessChain %_ptr_Function_uint %4 %uint_1
OpStore %63 %17
; component 0 of %4 <- component 1 of %4
%64 = OpInBoundsAccessChain %_ptr_Function_uint %4 %uint_1
%20 = OpLoad %uint %64
%19 = OpCopyObject %uint %20
%65 = OpInBoundsAccessChain %_ptr_Function_uint %4 %uint_0
OpStore %65 %19
; copy %4 into the return slot %2 and return it
%22 = OpLoad %v2uint %4
%21 = OpCopyObject %v2uint %22
OpStore %2 %21
%23 = OpLoad %v2uint %2
OpReturnValue %23
OpFunctionEnd
; Kernel %25 (the function named by the OpEntryPoint "vector" declaration):
; reads a v2uint from the address given by the first parameter, passes it
; through function %1 and writes the result to the address given by the second
; parameter.
%25 = OpFunction %void None %67
%34 = OpFunctionParameter %ulong
%35 = OpFunctionParameter %ulong
%49 = OpLabel
%26 = OpVariable %_ptr_Function_ulong Function
%27 = OpVariable %_ptr_Function_ulong Function
%28 = OpVariable %_ptr_Function_ulong Function
%29 = OpVariable %_ptr_Function_ulong Function
%30 = OpVariable %_ptr_Function_v2uint Function
%31 = OpVariable %_ptr_Function_uint Function
%32 = OpVariable %_ptr_Function_uint Function
%33 = OpVariable %_ptr_Function_ulong Function
; Spill both parameters into locals, then copy them into %28 and %29.
OpStore %26 %34
OpStore %27 %35
%36 = OpLoad %ulong %26 Aligned 8
OpStore %28 %36
%37 = OpLoad %ulong %27 Aligned 8
OpStore %29 %37
; %30 <- v2uint loaded through the generic pointer built from %28
%39 = OpLoad %ulong %28
%46 = OpConvertUToPtr %_ptr_Generic_v2uint %39
%38 = OpLoad %v2uint %46 Aligned 8
OpStore %30 %38
; %30 <- %1(%30)
%41 = OpLoad %v2uint %30
%40 = OpFunctionCall %v2uint %1 %41
OpStore %30 %40
; Bitcast the v2uint to a single ulong and keep it in %33
; (%33 is not read again in this function).
%43 = OpLoad %v2uint %30
%47 = OpBitcast %ulong %43
%42 = OpCopyObject %ulong %47
OpStore %33 %42
; Store %30 through the generic pointer built from %29.
%44 = OpLoad %ulong %29
%45 = OpLoad %v2uint %30
%48 = OpConvertUToPtr %_ptr_Generic_v2uint %44
OpStore %48 %45 Aligned 8
OpReturn
OpFunctionEnd

View file

@ -448,7 +448,7 @@ pub fn to_spirv_module<'a>(ast: ast::Module<'a>) -> Result<Module, TranslateErro
let opencl_id = emit_opencl_import(&mut builder);
emit_memory_model(&mut builder);
let mut map = TypeWordMap::new(&mut builder);
emit_builtins(&mut builder, &mut map, &id_defs);
//emit_builtins(&mut builder, &mut map, &id_defs);
let mut kernel_info = HashMap::new();
let build_options = emit_denorm_build_string(&call_map, &denorm_information);
emit_directives(
@ -1250,7 +1250,8 @@ fn to_ssa<'input, 'b>(
&mut numeric_id_defs,
&mut (*func_decl).borrow_mut(),
)?;
let ssa_statements = fix_special_registers(ssa_statements, &mut numeric_id_defs)?;
let ssa_statements =
fix_special_registers(ptx_impl_imports, ssa_statements, &mut numeric_id_defs)?;
let mut numeric_id_defs = numeric_id_defs.finish();
let expanded_statements = expand_arguments(ssa_statements, &mut numeric_id_defs)?;
let expanded_statements =
@ -1269,6 +1270,7 @@ fn to_ssa<'input, 'b>(
}
fn fix_special_registers(
ptx_impl_imports: &mut HashMap<String, Directive>,
typed_statements: Vec<TypedStatement>,
numeric_id_defs: &mut NumericIdResolver,
) -> Result<Vec<TypedStatement>, TranslateError> {
@ -1276,7 +1278,6 @@ fn fix_special_registers(
for s in typed_statements {
match s {
Statement::LoadVar(
mut
details
@
LoadVarDetails {
@ -1285,48 +1286,53 @@ fn fix_special_registers(
},
) => {
let index = details.member_index.unwrap().0;
if index == 3 {
result.push(Statement::Constant(ConstantDefinition {
dst: details.arg.dst,
typ: ast::ScalarType::U32,
value: ast::ImmediateValue::U64(0),
}));
} else {
let sreg_and_type = match numeric_id_defs.special_registers.get(details.arg.src)
{
Some(reg) => get_sreg_id_scalar_type(numeric_id_defs, reg),
None => None,
};
let (sreg_src, scalar_typ, vector_width) = match sreg_and_type {
Some(sreg_and_type) => sreg_and_type,
None => {
result.push(Statement::LoadVar(details));
continue;
}
};
let temp_id = numeric_id_defs
.register_intermediate(Some((details.typ.clone(), details.state_space)));
let real_dst = details.arg.dst;
details.arg.dst = temp_id;
result.push(Statement::LoadVar(LoadVarDetails {
arg: Arg2 {
src: sreg_src,
dst: temp_id,
},
state_space: ast::StateSpace::Sreg,
typ: ast::Type::Scalar(scalar_typ),
member_index: Some((index, Some(vector_width))),
}));
result.push(Statement::Conversion(ImplicitConversion {
src: temp_id,
dst: real_dst,
from_type: ast::Type::Scalar(scalar_typ),
from_space: ast::StateSpace::Sreg,
to_type: ast::Type::Scalar(ast::ScalarType::U32),
to_space: ast::StateSpace::Sreg,
kind: ConversionKind::Default,
}));
}
let sreg = numeric_id_defs
.special_registers
.get(details.arg.src)
.ok_or_else(|| error_unreachable())?;
let (ocl_name, ocl_type) = sreg.get_opencl_fn_type();
let index_constant = numeric_id_defs.register_intermediate(Some((
ast::Type::Scalar(ast::ScalarType::U32),
ast::StateSpace::Reg,
)));
result.push(Statement::Constant(ConstantDefinition {
dst: index_constant,
typ: ast::ScalarType::U32,
value: ast::ImmediateValue::U64(index as u64),
}));
let fn_result = numeric_id_defs.register_intermediate(Some((
ast::Type::Scalar(ocl_type),
ast::StateSpace::Reg,
)));
let return_arguments =
vec![(fn_result, ast::Type::Scalar(ocl_type), ast::StateSpace::Reg)];
let input_arguments = vec![(
TypedOperand::Reg(index_constant),
ast::Type::Scalar(ast::ScalarType::U32),
ast::StateSpace::Reg,
)];
let fn_call = register_external_fn_call(
numeric_id_defs,
ptx_impl_imports,
ocl_name.to_string(),
return_arguments.iter().map(|(_, typ, space)| (typ, *space)),
input_arguments.iter().map(|(_, typ, space)| (typ, *space)),
)?;
result.push(Statement::Call(ResolvedCall {
uniform: false,
return_arguments,
name: fn_call,
input_arguments,
}));
result.push(Statement::Conversion(ImplicitConversion {
src: fn_result,
dst: details.arg.dst,
from_type: ast::Type::Scalar(ocl_type),
from_space: ast::StateSpace::Reg,
to_type: ast::Type::Scalar(ast::ScalarType::U32),
to_space: ast::StateSpace::Reg,
kind: ConversionKind::Default,
}));
}
s => result.push(s),
}
@ -1721,8 +1727,8 @@ fn instruction_to_fn_call(
id_defs,
ptx_impl_imports,
fn_name,
return_arguments,
input_arguments,
return_arguments.iter().map(|(_, typ, state)| (typ, *state)),
input_arguments.iter().map(|(_, typ, state)| (typ, *state)),
)?;
Ok(Statement::Call(ResolvedCall {
uniform: false,
@ -1732,12 +1738,12 @@ fn instruction_to_fn_call(
}))
}
fn register_external_fn_call(
fn register_external_fn_call<'a>(
id_defs: &mut NumericIdResolver,
ptx_impl_imports: &mut HashMap<String, Directive>,
name: String,
return_arguments: &[(ArgumentDescriptor<spirv::Word>, ast::Type, ast::StateSpace)],
input_arguments: &[(ArgumentDescriptor<spirv::Word>, ast::Type, ast::StateSpace)],
return_arguments: impl Iterator<Item = (&'a ast::Type, ast::StateSpace)>,
input_arguments: impl Iterator<Item = (&'a ast::Type, ast::StateSpace)>,
) -> Result<spirv::Word, TranslateError> {
match ptx_impl_imports.entry(name) {
hash_map::Entry::Vacant(entry) => {
@ -1770,19 +1776,18 @@ fn register_external_fn_call(
}
}
fn fn_arguments_to_variables(
fn fn_arguments_to_variables<'a>(
id_defs: &mut NumericIdResolver,
args: &[(ArgumentDescriptor<spirv::Word>, ast::Type, ast::StateSpace)],
args: impl Iterator<Item = (&'a ast::Type, ast::StateSpace)>,
) -> Vec<ast::Variable<spirv::Word>> {
args.iter()
.map(|(_, typ, space)| ast::Variable {
align: None,
v_type: typ.clone(),
state_space: *space,
name: id_defs.register_intermediate(None),
array_init: Vec::new(),
})
.collect::<Vec<_>>()
args.map(|(typ, space)| ast::Variable {
align: None,
v_type: typ.clone(),
state_space: space,
name: id_defs.register_intermediate(None),
array_init: Vec::new(),
})
.collect::<Vec<_>>()
}
fn arguments_to_resolved_arguments(
@ -2226,7 +2231,7 @@ fn expand_arguments<'a, 'b>(
Statement::StoreVar(details) => result.push(Statement::StoreVar(details)),
Statement::RetValue(d, id) => result.push(Statement::RetValue(d, id)),
Statement::Conversion(conv) => result.push(Statement::Conversion(conv)),
Statement::Constant(_) => return Err(error_unreachable()),
Statement::Constant(c) => result.push(Statement::Constant(c)),
}
}
Ok(result)
@ -4686,6 +4691,19 @@ impl PtxSpecialRegister {
}
}
/// Returns the scalar type of one component of this special register:
/// `U64` for the `*64` variants, `U32` for the others.
fn get_scalar_type(self) -> ast::ScalarType {
    // Classify first (exhaustively, so a new variant is a compile error),
    // then pick the type once.
    let is_64_bit = match self {
        PtxSpecialRegister::Tid64
        | PtxSpecialRegister::Ntid64
        | PtxSpecialRegister::Ctaid64
        | PtxSpecialRegister::Nctaid64 => true,
        PtxSpecialRegister::Tid
        | PtxSpecialRegister::Ntid
        | PtxSpecialRegister::Ctaid
        | PtxSpecialRegister::Nctaid => false,
    };
    if is_64_bit {
        ast::ScalarType::U64
    } else {
        ast::ScalarType::U32
    }
}
fn get_builtin(self) -> spirv::BuiltIn {
match self {
PtxSpecialRegister::Tid | PtxSpecialRegister::Tid64 => {
@ -4701,6 +4719,23 @@ impl PtxSpecialRegister {
}
}
/// Returns the symbol name of the OpenCL function that stands in for this
/// special register, together with the scalar type of the function's result.
///
/// The 32-bit and 64-bit flavours of a register share one function, and the
/// result type is `U64` for every register.
fn get_opencl_fn_type(self) -> (&'static str, ast::ScalarType) {
    // NOTE(review): the names look like Itanium-mangled `name(unsigned int)`
    // symbols - confirm against the library that provides them.
    let symbol = match self {
        PtxSpecialRegister::Tid | PtxSpecialRegister::Tid64 => "_Z12get_local_idj",
        PtxSpecialRegister::Ntid | PtxSpecialRegister::Ntid64 => "_Z14get_local_sizej",
        PtxSpecialRegister::Ctaid | PtxSpecialRegister::Ctaid64 => "_Z12get_group_idj",
        PtxSpecialRegister::Nctaid | PtxSpecialRegister::Nctaid64 => "_Z14get_num_groupsj",
    };
    (symbol, ast::ScalarType::U64)
}
fn normalized_sreg_and_type(self) -> Option<(PtxSpecialRegister, ast::ScalarType, u8)> {
match self {
PtxSpecialRegister::Tid => Some((PtxSpecialRegister::Tid64, ast::ScalarType::U64, 3)),
@ -4743,6 +4778,8 @@ impl SpecialRegistersMap {
}
fn interface(&self) -> Vec<spirv::Word> {
return Vec::new();
/*
self.reg_to_id
.iter()
.filter_map(|(sreg, id)| {
@ -4753,6 +4790,7 @@ impl SpecialRegistersMap {
}
})
.collect::<Vec<_>>()
*/
}
fn get(&self, id: spirv::Word) -> Option<PtxSpecialRegister> {

View file

@ -14,6 +14,7 @@ level_zero-sys = { path = "../level_zero-sys" }
lazy_static = "1.4"
num_enum = "0.4"
lz4-sys = "1.9"
tempfile = "3"
[dependencies.ocl-core]
version = "0.11"

View file

@ -4,8 +4,10 @@ use std::{
ffi::c_void,
ffi::CStr,
ffi::CString,
io::{self, Write},
mem,
os::raw::{c_char, c_int, c_uint},
process::{Command, Stdio},
ptr, slice,
};
@ -20,6 +22,7 @@ use super::{
CUresult, GlobalState, HasLivenessCookie, LiveCheck,
};
use ptx;
use tempfile::NamedTempFile;
pub type Module = LiveCheck<ModuleData>;
@ -88,6 +91,36 @@ impl SpirvModule {
})
}
// Path of the external program launched by `compile_amd`.
// NOTE(review): hard-coded absolute path into a user's home directory; this
// only works on a machine with exactly this layout - consider making it
// configurable.
const LLVM_SPIRV: &'static str = "/home/vosen/amd/llvm-project/build/bin/llvm-spirv";
// The constants below are not referenced by the code visible here.
// NOTE(review): presumably the AMD driver install root and the names of the
// device bitcode libraries to link against - confirm once they are used.
const AMDGPU: &'static str = "/opt/amdgpu-pro/";
const AMDGPU_BITCODE: [&'static str; 8] = [
"opencl",
"ocml",
"ockl",
"oclc_correctly_rounded_sqrt_off",
"oclc_daz_opt_on",
"oclc_finite_only_off",
"oclc_unsafe_math_off",
"oclc_wavefrontsize64_off",
];
const AMDGPU_BITCODE_DEVICE_PREFIX: &'static str = "oclc_isa_version_";
// NOTE(review): a single hard-coded device name - verify whether this should
// be derived from the device actually in use.
const AMDGPU_DEVICE: &'static str = "gfx1010";
/// Writes `spirv_il` to a temporary file and runs the external `llvm-spirv`
/// tool on it in reverse mode (`-r`), sending the output to a second temporary
/// file.
///
/// Both files live in a fresh temporary directory that is removed when this
/// function returns, so the converted output is currently discarded; the call
/// only checks that the conversion succeeds.
///
/// # Errors
///
/// Returns an [`io::Error`] if the temporary directory or files cannot be
/// created or written, if the tool cannot be started, or if it exits with a
/// non-success status.
fn compile_amd(spirv_il: &[u8]) -> io::Result<()> {
    let dir = tempfile::tempdir()?;
    let mut spirv = NamedTempFile::new_in(&dir)?;
    let llvm = NamedTempFile::new_in(&dir)?;
    spirv.write_all(spirv_il)?;
    // The child process reads the file by path, so make sure nothing is left
    // in a userspace buffer before it starts.
    spirv.flush()?;
    let status = Command::new(Self::LLVM_SPIRV)
        .arg("-r")
        .arg("-o")
        .arg(llvm.path())
        .arg(spirv.path())
        .status()?;
    if !status.success() {
        // Report the failure through the declared error channel instead of
        // panicking inside a function that returns `io::Result`.
        return Err(io::Error::new(
            io::ErrorKind::Other,
            format!("{} failed: {}", Self::LLVM_SPIRV, status),
        ));
    }
    Ok(())
}
pub fn compile<'a>(
&self,
ctx: &ocl_core::Context,
@ -99,6 +132,7 @@ impl SpirvModule {
self.binaries.len() * mem::size_of::<u32>(),
)
};
Self::compile_amd(byte_il).unwrap();
let main_module = ocl_core::create_program_with_il(ctx, byte_il, None)?;
let main_module = match self.should_link_ptx_impl {
None => {