mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-04-28 13:37:57 +03:00
Use calls to OpenCL builtins when translating sregs, do SPIRV->LLVM conversion on every build
This commit is contained in:
parent
4a71fefb8a
commit
b4de21fbc5
7 changed files with 278 additions and 81 deletions
|
@ -61,6 +61,7 @@ test_ptx!(block, [1u64], [2u64]);
|
|||
test_ptx!(local_align, [1u64], [1u64]);
|
||||
test_ptx!(call, [1u64], [2u64]);
|
||||
test_ptx!(vector, [1u32, 2u32], [3u32, 3u32]);
|
||||
test_ptx!(vector4, [1u32, 2u32, 3u32, 4u32], [4u32]);
|
||||
test_ptx!(ld_st_offset, [1u32, 2u32], [2u32, 1u32]);
|
||||
test_ptx!(ntid, [3u32], [4u32]);
|
||||
test_ptx!(reg_local, [12u64], [13u64]);
|
||||
|
|
|
@ -7,24 +7,27 @@
|
|||
OpCapability Int64
|
||||
OpCapability Float16
|
||||
OpCapability Float64
|
||||
%28 = OpExtInstImport "OpenCL.std"
|
||||
%31 = OpExtInstImport "OpenCL.std"
|
||||
OpMemoryModel Physical64 OpenCL
|
||||
OpEntryPoint Kernel %1 "ntid" %gl_WorkGroupSize
|
||||
OpDecorate %gl_WorkGroupSize BuiltIn WorkgroupSize
|
||||
OpEntryPoint Kernel %1 "ntid"
|
||||
OpExecutionMode %1 ContractionOff
|
||||
OpDecorate %24 LinkageAttributes "get_local_size" Import
|
||||
%void = OpTypeVoid
|
||||
%ulong = OpTypeInt 64 0
|
||||
%v3ulong = OpTypeVector %ulong 3
|
||||
%_ptr_Input_v3ulong = OpTypePointer Input %v3ulong
|
||||
%gl_WorkGroupSize = OpVariable %_ptr_Input_v3ulong Input
|
||||
%33 = OpTypeFunction %void %ulong %ulong
|
||||
%_ptr_Function_ulong = OpTypePointer Function %ulong
|
||||
%uint = OpTypeInt 32 0
|
||||
%35 = OpTypeFunction %ulong %uint
|
||||
%36 = OpTypeFunction %void %ulong %ulong
|
||||
%_ptr_Function_ulong = OpTypePointer Function %ulong
|
||||
%_ptr_Function_uint = OpTypePointer Function %uint
|
||||
%_ptr_Generic_uint = OpTypePointer Generic %uint
|
||||
%1 = OpFunction %void None %33
|
||||
%uint_0 = OpConstant %uint 0
|
||||
%24 = OpFunction %ulong None %35
|
||||
%26 = OpFunctionParameter %uint
|
||||
OpFunctionEnd
|
||||
%1 = OpFunction %void None %36
|
||||
%9 = OpFunctionParameter %ulong
|
||||
%10 = OpFunctionParameter %ulong
|
||||
%26 = OpLabel
|
||||
%29 = OpLabel
|
||||
%2 = OpVariable %_ptr_Function_ulong Function
|
||||
%3 = OpVariable %_ptr_Function_ulong Function
|
||||
%4 = OpVariable %_ptr_Function_ulong Function
|
||||
|
@ -38,13 +41,12 @@
|
|||
%12 = OpLoad %ulong %3 Aligned 8
|
||||
OpStore %5 %12
|
||||
%14 = OpLoad %ulong %4
|
||||
%24 = OpConvertUToPtr %_ptr_Generic_uint %14
|
||||
%13 = OpLoad %uint %24 Aligned 4
|
||||
%27 = OpConvertUToPtr %_ptr_Generic_uint %14
|
||||
%13 = OpLoad %uint %27 Aligned 4
|
||||
OpStore %6 %13
|
||||
%38 = OpLoad %v3ulong %gl_WorkGroupSize
|
||||
%23 = OpCompositeExtract %ulong %38 0
|
||||
%39 = OpBitcast %ulong %23
|
||||
%16 = OpUConvert %uint %39
|
||||
%23 = OpFunctionCall %ulong %24 %uint_0
|
||||
%40 = OpBitcast %ulong %23
|
||||
%16 = OpUConvert %uint %40
|
||||
%15 = OpCopyObject %uint %16
|
||||
OpStore %7 %15
|
||||
%18 = OpLoad %uint %6
|
||||
|
@ -53,7 +55,7 @@
|
|||
OpStore %6 %17
|
||||
%20 = OpLoad %ulong %5
|
||||
%21 = OpLoad %uint %6
|
||||
%25 = OpConvertUToPtr %_ptr_Generic_uint %20
|
||||
OpStore %25 %21 Aligned 4
|
||||
%28 = OpConvertUToPtr %_ptr_Generic_uint %20
|
||||
OpStore %28 %21 Aligned 4
|
||||
OpReturn
|
||||
OpFunctionEnd
|
||||
OpFunctionEnd
|
22
ptx/src/test/spirv_run/vector4.ptx
Normal file
22
ptx/src/test/spirv_run/vector4.ptx
Normal file
|
@ -0,0 +1,22 @@
|
|||
.version 6.5
|
||||
.target sm_60
|
||||
.address_size 64
|
||||
|
||||
.visible .entry vector4(
|
||||
.param .u64 input_p,
|
||||
.param .u64 output_p
|
||||
)
|
||||
{
|
||||
.reg .u64 in_addr;
|
||||
.reg .u64 out_addr;
|
||||
.reg .v4 .u32 temp;
|
||||
.reg .u32 temp_scalar;
|
||||
|
||||
ld.param.u64 in_addr, [input_p];
|
||||
ld.param.u64 out_addr, [output_p];
|
||||
|
||||
ld.v4.u32 temp, [in_addr];
|
||||
mov.b32 temp_scalar, temp.w;
|
||||
st.u32 [out_addr], temp_scalar;
|
||||
ret;
|
||||
}
|
99
ptx/src/test/spirv_run/vector4.spvtxt
Normal file
99
ptx/src/test/spirv_run/vector4.spvtxt
Normal file
|
@ -0,0 +1,99 @@
|
|||
OpCapability GenericPointer
|
||||
OpCapability Linkage
|
||||
OpCapability Addresses
|
||||
OpCapability Kernel
|
||||
OpCapability Int8
|
||||
OpCapability Int16
|
||||
OpCapability Int64
|
||||
OpCapability Float16
|
||||
OpCapability Float64
|
||||
%51 = OpExtInstImport "OpenCL.std"
|
||||
OpMemoryModel Physical64 OpenCL
|
||||
OpEntryPoint Kernel %25 "vector"
|
||||
%void = OpTypeVoid
|
||||
%uint = OpTypeInt 32 0
|
||||
%v2uint = OpTypeVector %uint 2
|
||||
%55 = OpTypeFunction %v2uint %v2uint
|
||||
%_ptr_Function_v2uint = OpTypePointer Function %v2uint
|
||||
%_ptr_Function_uint = OpTypePointer Function %uint
|
||||
%uint_0 = OpConstant %uint 0
|
||||
%uint_1 = OpConstant %uint 1
|
||||
%ulong = OpTypeInt 64 0
|
||||
%67 = OpTypeFunction %void %ulong %ulong
|
||||
%_ptr_Function_ulong = OpTypePointer Function %ulong
|
||||
%_ptr_Generic_v2uint = OpTypePointer Generic %v2uint
|
||||
%1 = OpFunction %v2uint None %55
|
||||
%7 = OpFunctionParameter %v2uint
|
||||
%24 = OpLabel
|
||||
%3 = OpVariable %_ptr_Function_v2uint Function
|
||||
%2 = OpVariable %_ptr_Function_v2uint Function
|
||||
%4 = OpVariable %_ptr_Function_v2uint Function
|
||||
%5 = OpVariable %_ptr_Function_uint Function
|
||||
%6 = OpVariable %_ptr_Function_uint Function
|
||||
OpStore %3 %7
|
||||
%59 = OpInBoundsAccessChain %_ptr_Function_uint %3 %uint_0
|
||||
%9 = OpLoad %uint %59
|
||||
%8 = OpCopyObject %uint %9
|
||||
OpStore %5 %8
|
||||
%61 = OpInBoundsAccessChain %_ptr_Function_uint %3 %uint_1
|
||||
%11 = OpLoad %uint %61
|
||||
%10 = OpCopyObject %uint %11
|
||||
OpStore %6 %10
|
||||
%13 = OpLoad %uint %5
|
||||
%14 = OpLoad %uint %6
|
||||
%12 = OpIAdd %uint %13 %14
|
||||
OpStore %6 %12
|
||||
%16 = OpLoad %uint %6
|
||||
%15 = OpCopyObject %uint %16
|
||||
%62 = OpInBoundsAccessChain %_ptr_Function_uint %4 %uint_0
|
||||
OpStore %62 %15
|
||||
%18 = OpLoad %uint %6
|
||||
%17 = OpCopyObject %uint %18
|
||||
%63 = OpInBoundsAccessChain %_ptr_Function_uint %4 %uint_1
|
||||
OpStore %63 %17
|
||||
%64 = OpInBoundsAccessChain %_ptr_Function_uint %4 %uint_1
|
||||
%20 = OpLoad %uint %64
|
||||
%19 = OpCopyObject %uint %20
|
||||
%65 = OpInBoundsAccessChain %_ptr_Function_uint %4 %uint_0
|
||||
OpStore %65 %19
|
||||
%22 = OpLoad %v2uint %4
|
||||
%21 = OpCopyObject %v2uint %22
|
||||
OpStore %2 %21
|
||||
%23 = OpLoad %v2uint %2
|
||||
OpReturnValue %23
|
||||
OpFunctionEnd
|
||||
%25 = OpFunction %void None %67
|
||||
%34 = OpFunctionParameter %ulong
|
||||
%35 = OpFunctionParameter %ulong
|
||||
%49 = OpLabel
|
||||
%26 = OpVariable %_ptr_Function_ulong Function
|
||||
%27 = OpVariable %_ptr_Function_ulong Function
|
||||
%28 = OpVariable %_ptr_Function_ulong Function
|
||||
%29 = OpVariable %_ptr_Function_ulong Function
|
||||
%30 = OpVariable %_ptr_Function_v2uint Function
|
||||
%31 = OpVariable %_ptr_Function_uint Function
|
||||
%32 = OpVariable %_ptr_Function_uint Function
|
||||
%33 = OpVariable %_ptr_Function_ulong Function
|
||||
OpStore %26 %34
|
||||
OpStore %27 %35
|
||||
%36 = OpLoad %ulong %26 Aligned 8
|
||||
OpStore %28 %36
|
||||
%37 = OpLoad %ulong %27 Aligned 8
|
||||
OpStore %29 %37
|
||||
%39 = OpLoad %ulong %28
|
||||
%46 = OpConvertUToPtr %_ptr_Generic_v2uint %39
|
||||
%38 = OpLoad %v2uint %46 Aligned 8
|
||||
OpStore %30 %38
|
||||
%41 = OpLoad %v2uint %30
|
||||
%40 = OpFunctionCall %v2uint %1 %41
|
||||
OpStore %30 %40
|
||||
%43 = OpLoad %v2uint %30
|
||||
%47 = OpBitcast %ulong %43
|
||||
%42 = OpCopyObject %ulong %47
|
||||
OpStore %33 %42
|
||||
%44 = OpLoad %ulong %29
|
||||
%45 = OpLoad %v2uint %30
|
||||
%48 = OpConvertUToPtr %_ptr_Generic_v2uint %44
|
||||
OpStore %48 %45 Aligned 8
|
||||
OpReturn
|
||||
OpFunctionEnd
|
|
@ -448,7 +448,7 @@ pub fn to_spirv_module<'a>(ast: ast::Module<'a>) -> Result<Module, TranslateErro
|
|||
let opencl_id = emit_opencl_import(&mut builder);
|
||||
emit_memory_model(&mut builder);
|
||||
let mut map = TypeWordMap::new(&mut builder);
|
||||
emit_builtins(&mut builder, &mut map, &id_defs);
|
||||
//emit_builtins(&mut builder, &mut map, &id_defs);
|
||||
let mut kernel_info = HashMap::new();
|
||||
let build_options = emit_denorm_build_string(&call_map, &denorm_information);
|
||||
emit_directives(
|
||||
|
@ -1250,7 +1250,8 @@ fn to_ssa<'input, 'b>(
|
|||
&mut numeric_id_defs,
|
||||
&mut (*func_decl).borrow_mut(),
|
||||
)?;
|
||||
let ssa_statements = fix_special_registers(ssa_statements, &mut numeric_id_defs)?;
|
||||
let ssa_statements =
|
||||
fix_special_registers(ptx_impl_imports, ssa_statements, &mut numeric_id_defs)?;
|
||||
let mut numeric_id_defs = numeric_id_defs.finish();
|
||||
let expanded_statements = expand_arguments(ssa_statements, &mut numeric_id_defs)?;
|
||||
let expanded_statements =
|
||||
|
@ -1269,6 +1270,7 @@ fn to_ssa<'input, 'b>(
|
|||
}
|
||||
|
||||
fn fix_special_registers(
|
||||
ptx_impl_imports: &mut HashMap<String, Directive>,
|
||||
typed_statements: Vec<TypedStatement>,
|
||||
numeric_id_defs: &mut NumericIdResolver,
|
||||
) -> Result<Vec<TypedStatement>, TranslateError> {
|
||||
|
@ -1276,7 +1278,6 @@ fn fix_special_registers(
|
|||
for s in typed_statements {
|
||||
match s {
|
||||
Statement::LoadVar(
|
||||
mut
|
||||
details
|
||||
@
|
||||
LoadVarDetails {
|
||||
|
@ -1285,48 +1286,53 @@ fn fix_special_registers(
|
|||
},
|
||||
) => {
|
||||
let index = details.member_index.unwrap().0;
|
||||
if index == 3 {
|
||||
result.push(Statement::Constant(ConstantDefinition {
|
||||
dst: details.arg.dst,
|
||||
typ: ast::ScalarType::U32,
|
||||
value: ast::ImmediateValue::U64(0),
|
||||
}));
|
||||
} else {
|
||||
let sreg_and_type = match numeric_id_defs.special_registers.get(details.arg.src)
|
||||
{
|
||||
Some(reg) => get_sreg_id_scalar_type(numeric_id_defs, reg),
|
||||
None => None,
|
||||
};
|
||||
let (sreg_src, scalar_typ, vector_width) = match sreg_and_type {
|
||||
Some(sreg_and_type) => sreg_and_type,
|
||||
None => {
|
||||
result.push(Statement::LoadVar(details));
|
||||
continue;
|
||||
}
|
||||
};
|
||||
let temp_id = numeric_id_defs
|
||||
.register_intermediate(Some((details.typ.clone(), details.state_space)));
|
||||
let real_dst = details.arg.dst;
|
||||
details.arg.dst = temp_id;
|
||||
result.push(Statement::LoadVar(LoadVarDetails {
|
||||
arg: Arg2 {
|
||||
src: sreg_src,
|
||||
dst: temp_id,
|
||||
},
|
||||
state_space: ast::StateSpace::Sreg,
|
||||
typ: ast::Type::Scalar(scalar_typ),
|
||||
member_index: Some((index, Some(vector_width))),
|
||||
}));
|
||||
result.push(Statement::Conversion(ImplicitConversion {
|
||||
src: temp_id,
|
||||
dst: real_dst,
|
||||
from_type: ast::Type::Scalar(scalar_typ),
|
||||
from_space: ast::StateSpace::Sreg,
|
||||
to_type: ast::Type::Scalar(ast::ScalarType::U32),
|
||||
to_space: ast::StateSpace::Sreg,
|
||||
kind: ConversionKind::Default,
|
||||
}));
|
||||
}
|
||||
let sreg = numeric_id_defs
|
||||
.special_registers
|
||||
.get(details.arg.src)
|
||||
.ok_or_else(|| error_unreachable())?;
|
||||
let (ocl_name, ocl_type) = sreg.get_opencl_fn_type();
|
||||
let index_constant = numeric_id_defs.register_intermediate(Some((
|
||||
ast::Type::Scalar(ast::ScalarType::U32),
|
||||
ast::StateSpace::Reg,
|
||||
)));
|
||||
result.push(Statement::Constant(ConstantDefinition {
|
||||
dst: index_constant,
|
||||
typ: ast::ScalarType::U32,
|
||||
value: ast::ImmediateValue::U64(index as u64),
|
||||
}));
|
||||
let fn_result = numeric_id_defs.register_intermediate(Some((
|
||||
ast::Type::Scalar(ocl_type),
|
||||
ast::StateSpace::Reg,
|
||||
)));
|
||||
let return_arguments =
|
||||
vec![(fn_result, ast::Type::Scalar(ocl_type), ast::StateSpace::Reg)];
|
||||
let input_arguments = vec![(
|
||||
TypedOperand::Reg(index_constant),
|
||||
ast::Type::Scalar(ast::ScalarType::U32),
|
||||
ast::StateSpace::Reg,
|
||||
)];
|
||||
let fn_call = register_external_fn_call(
|
||||
numeric_id_defs,
|
||||
ptx_impl_imports,
|
||||
ocl_name.to_string(),
|
||||
return_arguments.iter().map(|(_, typ, space)| (typ, *space)),
|
||||
input_arguments.iter().map(|(_, typ, space)| (typ, *space)),
|
||||
)?;
|
||||
result.push(Statement::Call(ResolvedCall {
|
||||
uniform: false,
|
||||
return_arguments,
|
||||
name: fn_call,
|
||||
input_arguments,
|
||||
}));
|
||||
result.push(Statement::Conversion(ImplicitConversion {
|
||||
src: fn_result,
|
||||
dst: details.arg.dst,
|
||||
from_type: ast::Type::Scalar(ocl_type),
|
||||
from_space: ast::StateSpace::Reg,
|
||||
to_type: ast::Type::Scalar(ast::ScalarType::U32),
|
||||
to_space: ast::StateSpace::Reg,
|
||||
kind: ConversionKind::Default,
|
||||
}));
|
||||
}
|
||||
s => result.push(s),
|
||||
}
|
||||
|
@ -1721,8 +1727,8 @@ fn instruction_to_fn_call(
|
|||
id_defs,
|
||||
ptx_impl_imports,
|
||||
fn_name,
|
||||
return_arguments,
|
||||
input_arguments,
|
||||
return_arguments.iter().map(|(_, typ, state)| (typ, *state)),
|
||||
input_arguments.iter().map(|(_, typ, state)| (typ, *state)),
|
||||
)?;
|
||||
Ok(Statement::Call(ResolvedCall {
|
||||
uniform: false,
|
||||
|
@ -1732,12 +1738,12 @@ fn instruction_to_fn_call(
|
|||
}))
|
||||
}
|
||||
|
||||
fn register_external_fn_call(
|
||||
fn register_external_fn_call<'a>(
|
||||
id_defs: &mut NumericIdResolver,
|
||||
ptx_impl_imports: &mut HashMap<String, Directive>,
|
||||
name: String,
|
||||
return_arguments: &[(ArgumentDescriptor<spirv::Word>, ast::Type, ast::StateSpace)],
|
||||
input_arguments: &[(ArgumentDescriptor<spirv::Word>, ast::Type, ast::StateSpace)],
|
||||
return_arguments: impl Iterator<Item = (&'a ast::Type, ast::StateSpace)>,
|
||||
input_arguments: impl Iterator<Item = (&'a ast::Type, ast::StateSpace)>,
|
||||
) -> Result<spirv::Word, TranslateError> {
|
||||
match ptx_impl_imports.entry(name) {
|
||||
hash_map::Entry::Vacant(entry) => {
|
||||
|
@ -1770,19 +1776,18 @@ fn register_external_fn_call(
|
|||
}
|
||||
}
|
||||
|
||||
fn fn_arguments_to_variables(
|
||||
fn fn_arguments_to_variables<'a>(
|
||||
id_defs: &mut NumericIdResolver,
|
||||
args: &[(ArgumentDescriptor<spirv::Word>, ast::Type, ast::StateSpace)],
|
||||
args: impl Iterator<Item = (&'a ast::Type, ast::StateSpace)>,
|
||||
) -> Vec<ast::Variable<spirv::Word>> {
|
||||
args.iter()
|
||||
.map(|(_, typ, space)| ast::Variable {
|
||||
align: None,
|
||||
v_type: typ.clone(),
|
||||
state_space: *space,
|
||||
name: id_defs.register_intermediate(None),
|
||||
array_init: Vec::new(),
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
args.map(|(typ, space)| ast::Variable {
|
||||
align: None,
|
||||
v_type: typ.clone(),
|
||||
state_space: space,
|
||||
name: id_defs.register_intermediate(None),
|
||||
array_init: Vec::new(),
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
}
|
||||
|
||||
fn arguments_to_resolved_arguments(
|
||||
|
@ -2226,7 +2231,7 @@ fn expand_arguments<'a, 'b>(
|
|||
Statement::StoreVar(details) => result.push(Statement::StoreVar(details)),
|
||||
Statement::RetValue(d, id) => result.push(Statement::RetValue(d, id)),
|
||||
Statement::Conversion(conv) => result.push(Statement::Conversion(conv)),
|
||||
Statement::Constant(_) => return Err(error_unreachable()),
|
||||
Statement::Constant(c) => result.push(Statement::Constant(c)),
|
||||
}
|
||||
}
|
||||
Ok(result)
|
||||
|
@ -4686,6 +4691,19 @@ impl PtxSpecialRegister {
|
|||
}
|
||||
}
|
||||
|
||||
fn get_scalar_type(self) -> ast::ScalarType {
|
||||
match self {
|
||||
PtxSpecialRegister::Tid
|
||||
| PtxSpecialRegister::Ntid
|
||||
| PtxSpecialRegister::Ctaid
|
||||
| PtxSpecialRegister::Nctaid => ast::ScalarType::U32,
|
||||
PtxSpecialRegister::Tid64
|
||||
| PtxSpecialRegister::Ntid64
|
||||
| PtxSpecialRegister::Ctaid64
|
||||
| PtxSpecialRegister::Nctaid64 => ast::ScalarType::U64,
|
||||
}
|
||||
}
|
||||
|
||||
fn get_builtin(self) -> spirv::BuiltIn {
|
||||
match self {
|
||||
PtxSpecialRegister::Tid | PtxSpecialRegister::Tid64 => {
|
||||
|
@ -4701,6 +4719,23 @@ impl PtxSpecialRegister {
|
|||
}
|
||||
}
|
||||
|
||||
fn get_opencl_fn_type(self) -> (&'static str, ast::ScalarType) {
|
||||
match self {
|
||||
PtxSpecialRegister::Tid | PtxSpecialRegister::Tid64 => {
|
||||
("_Z12get_local_idj", ast::ScalarType::U64)
|
||||
}
|
||||
PtxSpecialRegister::Ntid | PtxSpecialRegister::Ntid64 => {
|
||||
("_Z14get_local_sizej", ast::ScalarType::U64)
|
||||
}
|
||||
PtxSpecialRegister::Ctaid | PtxSpecialRegister::Ctaid64 => {
|
||||
("_Z12get_group_idj", ast::ScalarType::U64)
|
||||
}
|
||||
PtxSpecialRegister::Nctaid | PtxSpecialRegister::Nctaid64 => {
|
||||
("_Z14get_num_groupsj", ast::ScalarType::U64)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn normalized_sreg_and_type(self) -> Option<(PtxSpecialRegister, ast::ScalarType, u8)> {
|
||||
match self {
|
||||
PtxSpecialRegister::Tid => Some((PtxSpecialRegister::Tid64, ast::ScalarType::U64, 3)),
|
||||
|
@ -4743,6 +4778,8 @@ impl SpecialRegistersMap {
|
|||
}
|
||||
|
||||
fn interface(&self) -> Vec<spirv::Word> {
|
||||
return Vec::new();
|
||||
/*
|
||||
self.reg_to_id
|
||||
.iter()
|
||||
.filter_map(|(sreg, id)| {
|
||||
|
@ -4753,6 +4790,7 @@ impl SpecialRegistersMap {
|
|||
}
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
*/
|
||||
}
|
||||
|
||||
fn get(&self, id: spirv::Word) -> Option<PtxSpecialRegister> {
|
||||
|
|
|
@ -14,6 +14,7 @@ level_zero-sys = { path = "../level_zero-sys" }
|
|||
lazy_static = "1.4"
|
||||
num_enum = "0.4"
|
||||
lz4-sys = "1.9"
|
||||
tempfile = "3"
|
||||
|
||||
[dependencies.ocl-core]
|
||||
version = "0.11"
|
||||
|
|
|
@ -4,8 +4,10 @@ use std::{
|
|||
ffi::c_void,
|
||||
ffi::CStr,
|
||||
ffi::CString,
|
||||
io::{self, Write},
|
||||
mem,
|
||||
os::raw::{c_char, c_int, c_uint},
|
||||
process::{Command, Stdio},
|
||||
ptr, slice,
|
||||
};
|
||||
|
||||
|
@ -20,6 +22,7 @@ use super::{
|
|||
CUresult, GlobalState, HasLivenessCookie, LiveCheck,
|
||||
};
|
||||
use ptx;
|
||||
use tempfile::NamedTempFile;
|
||||
|
||||
pub type Module = LiveCheck<ModuleData>;
|
||||
|
||||
|
@ -88,6 +91,36 @@ impl SpirvModule {
|
|||
})
|
||||
}
|
||||
|
||||
const LLVM_SPIRV: &'static str = "/home/vosen/amd/llvm-project/build/bin/llvm-spirv";
|
||||
const AMDGPU: &'static str = "/opt/amdgpu-pro/";
|
||||
const AMDGPU_BITCODE: [&'static str; 8] = [
|
||||
"opencl",
|
||||
"ocml",
|
||||
"ockl",
|
||||
"oclc_correctly_rounded_sqrt_off",
|
||||
"oclc_daz_opt_on",
|
||||
"oclc_finite_only_off",
|
||||
"oclc_unsafe_math_off",
|
||||
"oclc_wavefrontsize64_off",
|
||||
];
|
||||
const AMDGPU_BITCODE_DEVICE_PREFIX: &'static str = "oclc_isa_version_";
|
||||
const AMDGPU_DEVICE: &'static str = "gfx1010";
|
||||
|
||||
fn compile_amd(spirv_il: &[u8]) -> io::Result<()> {
|
||||
let dir = tempfile::tempdir()?;
|
||||
let mut spirv = NamedTempFile::new_in(&dir)?;
|
||||
let llvm = NamedTempFile::new_in(&dir)?;
|
||||
spirv.write_all(spirv_il)?;
|
||||
let mut cmd = Command::new(Self::LLVM_SPIRV)
|
||||
.arg("-r")
|
||||
.arg("-o")
|
||||
.arg(llvm.path())
|
||||
.arg(spirv.path())
|
||||
.status()?;
|
||||
assert!(cmd.success());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn compile<'a>(
|
||||
&self,
|
||||
ctx: &ocl_core::Context,
|
||||
|
@ -99,6 +132,7 @@ impl SpirvModule {
|
|||
self.binaries.len() * mem::size_of::<u32>(),
|
||||
)
|
||||
};
|
||||
Self::compile_amd(byte_il).unwrap();
|
||||
let main_module = ocl_core::create_program_with_il(ctx, byte_il, None)?;
|
||||
let main_module = match self.should_link_ptx_impl {
|
||||
None => {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue