mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-04-28 21:47:57 +03:00
Start implementing .shared unification
This commit is contained in:
parent
9609f86033
commit
370c0bd09e
5 changed files with 305 additions and 119 deletions
|
@ -1970,6 +1970,9 @@ ArgCall: (Vec<&'input str>, &'input str, Vec<ast::Operand<&'input str>>) = {
|
||||||
"(" <ret_params:Comma<ExtendedID>> ")" "," <func:ExtendedID> "," "(" <param_list:Comma<CallOperand>> ")" => {
|
"(" <ret_params:Comma<ExtendedID>> ")" "," <func:ExtendedID> "," "(" <param_list:Comma<CallOperand>> ")" => {
|
||||||
(ret_params, func, param_list)
|
(ret_params, func, param_list)
|
||||||
},
|
},
|
||||||
|
"(" <ret_params:Comma<ExtendedID>> ")" "," <func:ExtendedID> => {
|
||||||
|
(ret_params, func, Vec::new())
|
||||||
|
},
|
||||||
<func:ExtendedID> "," "(" <param_list:Comma<CallOperand>> ")" => (Vec::new(), func, param_list),
|
<func:ExtendedID> "," "(" <param_list:Comma<CallOperand>> ")" => (Vec::new(), func, param_list),
|
||||||
<func:ExtendedID> => (Vec::new(), func, Vec::<ast::Operand<_>>::new()),
|
<func:ExtendedID> => (Vec::new(), func, Vec::<ast::Operand<_>>::new()),
|
||||||
};
|
};
|
||||||
|
|
|
@ -221,6 +221,8 @@ test_ptx!(cvt_f64_f32, [0.125f32], [0.125f64]);
|
||||||
test_ptx!(prmt, [0x70c507d6u32, 0x6fbd4b5cu32], [0x6fbdd65cu32]);
|
test_ptx!(prmt, [0x70c507d6u32, 0x6fbd4b5cu32], [0x6fbdd65cu32]);
|
||||||
test_ptx!(activemask, [0u32], [1u32]);
|
test_ptx!(activemask, [0u32], [1u32]);
|
||||||
test_ptx!(membar, [152731u32], [152731u32]);
|
test_ptx!(membar, [152731u32], [152731u32]);
|
||||||
|
test_ptx!(shared_unify_extern, [7681u64], [15362u64]);
|
||||||
|
|
||||||
test_ptx!(func_ptr);
|
test_ptx!(func_ptr);
|
||||||
test_ptx!(lanemask_lt);
|
test_ptx!(lanemask_lt);
|
||||||
test_ptx!(extern_func);
|
test_ptx!(extern_func);
|
||||||
|
|
34
ptx/src/test/spirv_run/shared_unify_extern.ptx
Normal file
34
ptx/src/test/spirv_run/shared_unify_extern.ptx
Normal file
|
@ -0,0 +1,34 @@
|
||||||
|
.version 6.5
|
||||||
|
.target sm_30
|
||||||
|
.address_size 64
|
||||||
|
|
||||||
|
.extern .shared .b32 shared_ex[];
|
||||||
|
.shared .b32 shared_mod[4];
|
||||||
|
|
||||||
|
|
||||||
|
.func (.reg .b64 out) load_from_shared()
|
||||||
|
{
|
||||||
|
ld.shared.u64 out, [shared_mod];
|
||||||
|
ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
.visible .entry shared_unify_extern(
|
||||||
|
.param .u64 input,
|
||||||
|
.param .u64 output
|
||||||
|
)
|
||||||
|
{
|
||||||
|
.reg .u64 in_addr;
|
||||||
|
.reg .u64 out_addr;
|
||||||
|
.reg .u64 temp1;
|
||||||
|
.reg .u64 temp2;
|
||||||
|
|
||||||
|
ld.param.u64 in_addr, [input];
|
||||||
|
ld.param.u64 out_addr, [output];
|
||||||
|
|
||||||
|
ld.u64 temp1, [in_addr];
|
||||||
|
st.shared.u64 [shared_ex], temp1;
|
||||||
|
call (temp2), load_from_shared;
|
||||||
|
add.u64 temp2, temp2, temp1;
|
||||||
|
st.u64 [out_addr], temp2;
|
||||||
|
ret;
|
||||||
|
}
|
62
ptx/src/test/spirv_run/shared_unify_extern.spvtxt
Normal file
62
ptx/src/test/spirv_run/shared_unify_extern.spvtxt
Normal file
|
@ -0,0 +1,62 @@
|
||||||
|
OpCapability GenericPointer
|
||||||
|
OpCapability Linkage
|
||||||
|
OpCapability Addresses
|
||||||
|
OpCapability Kernel
|
||||||
|
OpCapability Int8
|
||||||
|
OpCapability Int16
|
||||||
|
OpCapability Int64
|
||||||
|
OpCapability Float16
|
||||||
|
OpCapability Float64
|
||||||
|
%30 = OpExtInstImport "OpenCL.std"
|
||||||
|
OpMemoryModel Physical64 OpenCL
|
||||||
|
OpEntryPoint Kernel %2 "shared_ptr_take_address" %1
|
||||||
|
OpExecutionMode %2 ContractionOff
|
||||||
|
OpDecorate %1 Alignment 4
|
||||||
|
OpDecorate %1 LinkageAttributes "shared_mem" Import
|
||||||
|
%void = OpTypeVoid
|
||||||
|
%uchar = OpTypeInt 8 0
|
||||||
|
%_ptr_Workgroup_uchar = OpTypePointer Workgroup %uchar
|
||||||
|
%1 = OpVariable %_ptr_Workgroup_uchar Workgroup
|
||||||
|
%ulong = OpTypeInt 64 0
|
||||||
|
%35 = OpTypeFunction %void %ulong %ulong
|
||||||
|
%_ptr_Function_ulong = OpTypePointer Function %ulong
|
||||||
|
%_ptr_CrossWorkgroup_ulong = OpTypePointer CrossWorkgroup %ulong
|
||||||
|
%_ptr_Workgroup_ulong = OpTypePointer Workgroup %ulong
|
||||||
|
%2 = OpFunction %void None %35
|
||||||
|
%10 = OpFunctionParameter %ulong
|
||||||
|
%11 = OpFunctionParameter %ulong
|
||||||
|
%28 = OpLabel
|
||||||
|
%3 = OpVariable %_ptr_Function_ulong Function
|
||||||
|
%4 = OpVariable %_ptr_Function_ulong Function
|
||||||
|
%5 = OpVariable %_ptr_Function_ulong Function
|
||||||
|
%6 = OpVariable %_ptr_Function_ulong Function
|
||||||
|
%7 = OpVariable %_ptr_Function_ulong Function
|
||||||
|
%8 = OpVariable %_ptr_Function_ulong Function
|
||||||
|
%9 = OpVariable %_ptr_Function_ulong Function
|
||||||
|
OpStore %3 %10
|
||||||
|
OpStore %4 %11
|
||||||
|
%12 = OpLoad %ulong %3 Aligned 8
|
||||||
|
OpStore %5 %12
|
||||||
|
%13 = OpLoad %ulong %4 Aligned 8
|
||||||
|
OpStore %6 %13
|
||||||
|
%23 = OpConvertPtrToU %ulong %1
|
||||||
|
%14 = OpCopyObject %ulong %23
|
||||||
|
OpStore %7 %14
|
||||||
|
%16 = OpLoad %ulong %5
|
||||||
|
%24 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %16
|
||||||
|
%15 = OpLoad %ulong %24 Aligned 8
|
||||||
|
OpStore %8 %15
|
||||||
|
%17 = OpLoad %ulong %7
|
||||||
|
%18 = OpLoad %ulong %8
|
||||||
|
%25 = OpConvertUToPtr %_ptr_Workgroup_ulong %17
|
||||||
|
OpStore %25 %18 Aligned 8
|
||||||
|
%20 = OpLoad %ulong %7
|
||||||
|
%26 = OpConvertUToPtr %_ptr_Workgroup_ulong %20
|
||||||
|
%19 = OpLoad %ulong %26 Aligned 8
|
||||||
|
OpStore %9 %19
|
||||||
|
%21 = OpLoad %ulong %6
|
||||||
|
%22 = OpLoad %ulong %9
|
||||||
|
%27 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %21
|
||||||
|
OpStore %27 %22 Aligned 8
|
||||||
|
OpReturn
|
||||||
|
OpFunctionEnd
|
|
@ -443,7 +443,8 @@ pub fn to_spirv_module<'input>(ast: ast::Module<'input>) -> Result<Module, Trans
|
||||||
let mut builder = dr::Builder::new();
|
let mut builder = dr::Builder::new();
|
||||||
builder.reserve_ids(id_defs.current_id());
|
builder.reserve_ids(id_defs.current_id());
|
||||||
let call_map = get_kernels_call_map(&directives);
|
let call_map = get_kernels_call_map(&directives);
|
||||||
//let mut directives = convert_dynamic_shared_memory_usage(directives, &mut || builder.id());
|
let mut directives =
|
||||||
|
convert_dynamic_shared_memory_usage(directives, &call_map, &mut || builder.id());
|
||||||
normalize_variable_decls(&mut directives);
|
normalize_variable_decls(&mut directives);
|
||||||
let denorm_information = compute_denorm_information(&directives);
|
let denorm_information = compute_denorm_information(&directives);
|
||||||
// https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_a_id_logicallayout_a_logical_layout_of_a_module
|
// https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_a_id_logicallayout_a_logical_layout_of_a_module
|
||||||
|
@ -607,7 +608,7 @@ fn emit_directives<'input>(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
emit_function_body_ops(builder, map, id_defs, opencl_id, &f_body)?;
|
emit_function_body_ops(builder, map, id_defs, opencl_id, &f_body)?;
|
||||||
emit_function_linkage(builder, id_defs, f, fn_id);
|
emit_function_linkage(builder, id_defs, f, fn_id)?;
|
||||||
builder.select_block(None)?;
|
builder.select_block(None)?;
|
||||||
builder.end_function()?;
|
builder.end_function()?;
|
||||||
}
|
}
|
||||||
|
@ -683,7 +684,7 @@ fn get_kernels_call_map<'input>(
|
||||||
}
|
}
|
||||||
|
|
||||||
fn add_call_map_single<'input>(
|
fn add_call_map_single<'input>(
|
||||||
directly_called_by: &MultiHashMap<ast::MethodName<'input, spirv::Word>, spirv::Word>,
|
directly_called_by: &HashMap<ast::MethodName<'input, spirv::Word>, Vec<spirv::Word>>,
|
||||||
visited: &mut HashSet<spirv::Word>,
|
visited: &mut HashSet<spirv::Word>,
|
||||||
current: spirv::Word,
|
current: spirv::Word,
|
||||||
) {
|
) {
|
||||||
|
@ -697,15 +698,21 @@ fn add_call_map_single<'input>(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type MultiHashMap<K, V> = HashMap<K, Vec<V>>;
|
fn multi_hash_map_append<
|
||||||
|
K: Eq + std::hash::Hash,
|
||||||
fn multi_hash_map_append<K: Eq + std::hash::Hash, V>(m: &mut MultiHashMap<K, V>, key: K, value: V) {
|
V,
|
||||||
|
Collection: std::iter::Extend<V> + std::default::Default,
|
||||||
|
>(
|
||||||
|
m: &mut HashMap<K, Collection>,
|
||||||
|
key: K,
|
||||||
|
value: V,
|
||||||
|
) {
|
||||||
match m.entry(key) {
|
match m.entry(key) {
|
||||||
hash_map::Entry::Occupied(mut entry) => {
|
hash_map::Entry::Occupied(mut entry) => {
|
||||||
entry.get_mut().push(value);
|
entry.get_mut().extend(iter::once(value));
|
||||||
}
|
}
|
||||||
hash_map::Entry::Vacant(entry) => {
|
hash_map::Entry::Vacant(entry) => {
|
||||||
entry.insert(vec![value]);
|
entry.insert(Default::default());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -713,7 +720,8 @@ fn multi_hash_map_append<K: Eq + std::hash::Hash, V>(m: &mut MultiHashMap<K, V>,
|
||||||
/*
|
/*
|
||||||
PTX represents dynamically allocated shared local memory as
|
PTX represents dynamically allocated shared local memory as
|
||||||
.extern .shared .b32 shared_mem[];
|
.extern .shared .b32 shared_mem[];
|
||||||
In SPIRV/OpenCL world this is expressed as an additional argument
|
In SPIRV/OpenCL world this is expressed as an additional argument to the kernel
|
||||||
|
And in AMD compilation
|
||||||
This pass looks for all uses of .extern .shared and converts them to
|
This pass looks for all uses of .extern .shared and converts them to
|
||||||
an additional method argument
|
an additional method argument
|
||||||
The question is how this artificial argument should be expressed. There are
|
The question is how this artificial argument should be expressed. There are
|
||||||
|
@ -735,30 +743,35 @@ fn multi_hash_map_append<K: Eq + std::hash::Hash, V>(m: &mut MultiHashMap<K, V>,
|
||||||
*/
|
*/
|
||||||
fn convert_dynamic_shared_memory_usage<'input>(
|
fn convert_dynamic_shared_memory_usage<'input>(
|
||||||
module: Vec<Directive<'input>>,
|
module: Vec<Directive<'input>>,
|
||||||
|
kernels_methods_call_map: &HashMap<&'input str, HashSet<spirv::Word>>,
|
||||||
new_id: &mut impl FnMut() -> spirv::Word,
|
new_id: &mut impl FnMut() -> spirv::Word,
|
||||||
) -> Vec<Directive<'input>> {
|
) -> Vec<Directive<'input>> {
|
||||||
let mut extern_shared_decls = HashMap::new();
|
let mut globals_shared = HashMap::new();
|
||||||
for dir in module.iter() {
|
for dir in module.iter() {
|
||||||
match dir {
|
match dir {
|
||||||
Directive::Variable(
|
Directive::Variable(
|
||||||
linking,
|
linking,
|
||||||
ast::Variable {
|
ast::Variable {
|
||||||
v_type: ast::Type::Array(p_type, dims),
|
|
||||||
state_space: ast::StateSpace::Shared,
|
state_space: ast::StateSpace::Shared,
|
||||||
name,
|
name,
|
||||||
|
v_type,
|
||||||
..
|
..
|
||||||
},
|
},
|
||||||
) if linking.contains(ast::LinkingDirective::EXTERN) && dims.len() == 0 => {
|
) => {
|
||||||
extern_shared_decls.insert(*name, *p_type);
|
let size = if linking.contains(ast::LinkingDirective::EXTERN) {
|
||||||
|
GlobalSharedSize::ExternUnsized
|
||||||
|
} else {
|
||||||
|
GlobalSharedSize::Sized((*v_type).size_of())
|
||||||
|
};
|
||||||
|
globals_shared.insert(*name, (size, v_type.clone()));
|
||||||
}
|
}
|
||||||
_ => {}
|
_ => {}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if extern_shared_decls.len() == 0 {
|
if globals_shared.len() == 0 {
|
||||||
return module;
|
return module;
|
||||||
}
|
}
|
||||||
let mut methods_using_extern_shared = HashSet::new();
|
let mut methods_to_globals_shared_direct_only_use = HashMap::<_, GlobalSharedSize>::new();
|
||||||
let mut directly_called_by = MultiHashMap::new();
|
|
||||||
let module = module
|
let module = module
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.map(|directive| match directive {
|
.map(|directive| match directive {
|
||||||
|
@ -773,17 +786,21 @@ fn convert_dynamic_shared_memory_usage<'input>(
|
||||||
let call_key = (*func_decl).borrow().name;
|
let call_key = (*func_decl).borrow().name;
|
||||||
let statements = statements
|
let statements = statements
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.map(|statement| match statement {
|
.map(|statement| {
|
||||||
Statement::Call(call) => {
|
statement.map_id(&mut |id, _| {
|
||||||
multi_hash_map_append(&mut directly_called_by, call.name, call_key);
|
if let Some((size, _)) = globals_shared.get(&id) {
|
||||||
Statement::Call(call)
|
match methods_to_globals_shared_direct_only_use.entry(call_key) {
|
||||||
}
|
hash_map::Entry::Occupied(mut e) => {
|
||||||
statement => statement.map_id(&mut |id, _| {
|
let original_size = *e.get();
|
||||||
if extern_shared_decls.contains_key(&id) {
|
e.insert(original_size.fold(*size));
|
||||||
methods_using_extern_shared.insert(call_key);
|
}
|
||||||
|
hash_map::Entry::Vacant(mut e) => {
|
||||||
|
e.insert(*size);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
id
|
id
|
||||||
}),
|
})
|
||||||
})
|
})
|
||||||
.collect();
|
.collect();
|
||||||
Directive::Method(Function {
|
Directive::Method(Function {
|
||||||
|
@ -800,11 +817,15 @@ fn convert_dynamic_shared_memory_usage<'input>(
|
||||||
.collect::<Vec<_>>();
|
.collect::<Vec<_>>();
|
||||||
// If there's a chain `kernel` -> `fn1` -> `fn2`, where only `fn2` uses extern shared,
|
// If there's a chain `kernel` -> `fn1` -> `fn2`, where only `fn2` uses extern shared,
|
||||||
// make sure it gets propagated to `fn1` and `kernel`
|
// make sure it gets propagated to `fn1` and `kernel`
|
||||||
get_callers_of_extern_shared(&mut methods_using_extern_shared, &directly_called_by);
|
let (kernels_to_global_shared, functions_to_global_shared) =
|
||||||
|
resolve_indirect_uses_of_globals_shared(
|
||||||
|
methods_to_globals_shared_direct_only_use,
|
||||||
|
kernels_methods_call_map,
|
||||||
|
);
|
||||||
// now visit every method declaration and inject those additional arguments
|
// now visit every method declaration and inject those additional arguments
|
||||||
module
|
let mut result = Vec::with_capacity(module.len());
|
||||||
.into_iter()
|
for directive in module.into_iter() {
|
||||||
.map(|directive| match directive {
|
match directive {
|
||||||
Directive::Method(Function {
|
Directive::Method(Function {
|
||||||
func_decl,
|
func_decl,
|
||||||
globals,
|
globals,
|
||||||
|
@ -813,46 +834,119 @@ fn convert_dynamic_shared_memory_usage<'input>(
|
||||||
tuning,
|
tuning,
|
||||||
linkage,
|
linkage,
|
||||||
}) => {
|
}) => {
|
||||||
if !methods_using_extern_shared.contains(&(*func_decl).borrow().name) {
|
let statements = {
|
||||||
return Directive::Method(Function {
|
let func_decl_ref = &mut (*func_decl).borrow_mut();
|
||||||
func_decl,
|
let method_name = func_decl_ref.name;
|
||||||
globals,
|
insert_arguments_remap_statements(
|
||||||
body: Some(statements),
|
method_name,
|
||||||
import_as,
|
&kernels_to_global_shared,
|
||||||
tuning,
|
new_id,
|
||||||
linkage,
|
&mut result,
|
||||||
});
|
&functions_to_global_shared,
|
||||||
}
|
func_decl_ref,
|
||||||
let shared_id_param = new_id();
|
&globals_shared,
|
||||||
{
|
statements,
|
||||||
let mut func_decl = (*func_decl).borrow_mut();
|
)
|
||||||
func_decl.shared_mem = Some(shared_id_param);
|
};
|
||||||
}
|
result.push(Directive::Method(Function {
|
||||||
let statements = replace_uses_of_shared_memory(
|
|
||||||
new_id,
|
|
||||||
&extern_shared_decls,
|
|
||||||
&mut methods_using_extern_shared,
|
|
||||||
shared_id_param,
|
|
||||||
statements,
|
|
||||||
);
|
|
||||||
Directive::Method(Function {
|
|
||||||
func_decl,
|
func_decl,
|
||||||
globals,
|
globals,
|
||||||
body: Some(statements),
|
body: Some(statements),
|
||||||
import_as,
|
import_as,
|
||||||
tuning,
|
tuning,
|
||||||
linkage,
|
linkage,
|
||||||
})
|
}));
|
||||||
}
|
}
|
||||||
directive => directive,
|
directive => result.push(directive),
|
||||||
})
|
}
|
||||||
.collect::<Vec<_>>()
|
}
|
||||||
|
result
|
||||||
|
}
|
||||||
|
|
||||||
|
fn insert_arguments_remap_statements(
|
||||||
|
method_name: ast::MethodName<u32>,
|
||||||
|
kernels_to_global_shared: &HashMap<&str, GlobalSharedSize>,
|
||||||
|
new_id: &mut impl FnMut() -> u32,
|
||||||
|
result: &mut Vec<Directive>,
|
||||||
|
functions_to_global_shared: &HashSet<u32>,
|
||||||
|
func_decl_ref: &mut std::cell::RefMut<ast::MethodDeclaration<u32>>,
|
||||||
|
globals_shared: &HashMap<u32, (GlobalSharedSize, ast::Type)>,
|
||||||
|
statements: Vec<Statement<ast::Instruction<ExpandedArgParams>, ExpandedArgParams>>,
|
||||||
|
) -> Vec<Statement<ast::Instruction<ExpandedArgParams>, ExpandedArgParams>> {
|
||||||
|
let shared_id_param = match method_name {
|
||||||
|
ast::MethodName::Kernel(kernel_name) => {
|
||||||
|
let globals_shared_size = match kernels_to_global_shared.get(kernel_name) {
|
||||||
|
Some(s) => *s,
|
||||||
|
None => return statements,
|
||||||
|
};
|
||||||
|
let shared_id_param = new_id();
|
||||||
|
let (linkage, type_) = match globals_shared_size {
|
||||||
|
GlobalSharedSize::ExternUnsized => (
|
||||||
|
ast::LinkingDirective::EXTERN,
|
||||||
|
ast::Type::Array(ast::ScalarType::U8, Vec::new()),
|
||||||
|
),
|
||||||
|
GlobalSharedSize::Sized(size) => (
|
||||||
|
ast::LinkingDirective::NONE,
|
||||||
|
ast::Type::Array(ast::ScalarType::U8, vec![size as u32]),
|
||||||
|
),
|
||||||
|
};
|
||||||
|
result.push(Directive::Variable(
|
||||||
|
linkage,
|
||||||
|
ast::Variable {
|
||||||
|
align: None,
|
||||||
|
v_type: type_,
|
||||||
|
state_space: ast::StateSpace::Shared,
|
||||||
|
name: shared_id_param,
|
||||||
|
array_init: Vec::new(),
|
||||||
|
},
|
||||||
|
));
|
||||||
|
shared_id_param
|
||||||
|
}
|
||||||
|
ast::MethodName::Func(function_name) => {
|
||||||
|
if !functions_to_global_shared.contains(&function_name) {
|
||||||
|
return statements;
|
||||||
|
}
|
||||||
|
let shared_id_param = new_id();
|
||||||
|
func_decl_ref.input_arguments.push(ast::Variable {
|
||||||
|
align: None,
|
||||||
|
v_type: ast::Type::Pointer(ast::ScalarType::B8, ast::StateSpace::Shared),
|
||||||
|
state_space: ast::StateSpace::Reg,
|
||||||
|
name: shared_id_param,
|
||||||
|
array_init: Vec::new(),
|
||||||
|
});
|
||||||
|
shared_id_param
|
||||||
|
}
|
||||||
|
};
|
||||||
|
replace_uses_of_shared_memory(
|
||||||
|
new_id,
|
||||||
|
globals_shared,
|
||||||
|
functions_to_global_shared,
|
||||||
|
shared_id_param,
|
||||||
|
statements,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Copy, PartialEq, Eq, Hash)]
|
||||||
|
enum GlobalSharedSize {
|
||||||
|
ExternUnsized,
|
||||||
|
Sized(usize),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl GlobalSharedSize {
|
||||||
|
fn fold(self, other: GlobalSharedSize) -> GlobalSharedSize {
|
||||||
|
match (self, other) {
|
||||||
|
(GlobalSharedSize::Sized(s1), GlobalSharedSize::Sized(s2)) => {
|
||||||
|
GlobalSharedSize::Sized(usize::max(s1, s2))
|
||||||
|
}
|
||||||
|
_ => GlobalSharedSize::ExternUnsized,
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn replace_uses_of_shared_memory<'a>(
|
fn replace_uses_of_shared_memory<'a>(
|
||||||
new_id: &mut impl FnMut() -> spirv::Word,
|
new_id: &mut impl FnMut() -> spirv::Word,
|
||||||
extern_shared_decls: &HashMap<spirv::Word, ast::ScalarType>,
|
extern_shared_decls: &HashMap<spirv::Word, (GlobalSharedSize, ast::Type)>,
|
||||||
methods_using_extern_shared: &mut HashSet<ast::MethodName<'a, spirv::Word>>,
|
methods_using_extern_shared: &HashSet<spirv::Word>,
|
||||||
shared_id_param: spirv::Word,
|
shared_id_param: spirv::Word,
|
||||||
statements: Vec<ExpandedStatement>,
|
statements: Vec<ExpandedStatement>,
|
||||||
) -> Vec<ExpandedStatement> {
|
) -> Vec<ExpandedStatement> {
|
||||||
|
@ -863,7 +957,7 @@ fn replace_uses_of_shared_memory<'a>(
|
||||||
// We can safely skip checking call arguments,
|
// We can safely skip checking call arguments,
|
||||||
// because there's simply no way to pass shared ptr
|
// because there's simply no way to pass shared ptr
|
||||||
// without converting it to .b64 first
|
// without converting it to .b64 first
|
||||||
if methods_using_extern_shared.contains(&ast::MethodName::Func(call.name)) {
|
if methods_using_extern_shared.contains(&call.name) {
|
||||||
call.input_arguments.push((
|
call.input_arguments.push((
|
||||||
shared_id_param,
|
shared_id_param,
|
||||||
ast::Type::Scalar(ast::ScalarType::B8),
|
ast::Type::Scalar(ast::ScalarType::B8),
|
||||||
|
@ -874,8 +968,8 @@ fn replace_uses_of_shared_memory<'a>(
|
||||||
}
|
}
|
||||||
statement => {
|
statement => {
|
||||||
let new_statement = statement.map_id(&mut |id, _| {
|
let new_statement = statement.map_id(&mut |id, _| {
|
||||||
if let Some(scalar_type) = extern_shared_decls.get(&id) {
|
if let Some((_, type_)) = extern_shared_decls.get(&id) {
|
||||||
if *scalar_type == ast::ScalarType::B8 {
|
if *type_ == ast::Type::Scalar(ast::ScalarType::B8) {
|
||||||
return shared_id_param;
|
return shared_id_param;
|
||||||
}
|
}
|
||||||
let replacement_id = new_id();
|
let replacement_id = new_id();
|
||||||
|
@ -884,7 +978,7 @@ fn replace_uses_of_shared_memory<'a>(
|
||||||
dst: replacement_id,
|
dst: replacement_id,
|
||||||
from_type: ast::Type::Scalar(ast::ScalarType::B8),
|
from_type: ast::Type::Scalar(ast::ScalarType::B8),
|
||||||
from_space: ast::StateSpace::Shared,
|
from_space: ast::StateSpace::Shared,
|
||||||
to_type: ast::Type::Scalar(*scalar_type),
|
to_type: type_.clone(),
|
||||||
to_space: ast::StateSpace::Shared,
|
to_space: ast::StateSpace::Shared,
|
||||||
kind: ConversionKind::PtrToPtr,
|
kind: ConversionKind::PtrToPtr,
|
||||||
}));
|
}));
|
||||||
|
@ -900,43 +994,40 @@ fn replace_uses_of_shared_memory<'a>(
|
||||||
result
|
result
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_callers_of_extern_shared<'a>(
|
// We need to compute two kinds of information:
|
||||||
methods_using_extern_shared: &mut HashSet<ast::MethodName<'a, spirv::Word>>,
|
// * If it's a kernel -> size of .shared globals in use (direct or indirect)
|
||||||
directly_called_by: &MultiHashMap<spirv::Word, ast::MethodName<'a, spirv::Word>>,
|
// * If it's a function -> does it use .shared global (directly or indirectly)
|
||||||
) {
|
fn resolve_indirect_uses_of_globals_shared<'input>(
|
||||||
let direct_uses_of_extern_shared = methods_using_extern_shared
|
mut methods_use_of_globals_shared: HashMap<
|
||||||
.iter()
|
ast::MethodName<'input, spirv::Word>,
|
||||||
.filter_map(|method| {
|
GlobalSharedSize,
|
||||||
if let ast::MethodName::Func(f_id) = method {
|
>,
|
||||||
Some(*f_id)
|
kernels_methods_call_map: &HashMap<&'input str, HashSet<spirv::Word>>,
|
||||||
} else {
|
) -> (HashMap<&'input str, GlobalSharedSize>, HashSet<spirv::Word>) {
|
||||||
None
|
let mut kernel_use = HashMap::new();
|
||||||
}
|
let mut functions_using_global = HashSet::new();
|
||||||
})
|
let empty = HashSet::new();
|
||||||
.collect::<Vec<_>>();
|
for (method, globals) in methods_use_of_globals_shared.iter() {
|
||||||
for fn_id in direct_uses_of_extern_shared {
|
match method {
|
||||||
get_callers_of_extern_shared_single(methods_using_extern_shared, directly_called_by, fn_id);
|
ast::MethodName::Kernel(kernel_name) => {
|
||||||
}
|
let mut size = *globals;
|
||||||
}
|
for &called_subfunction in
|
||||||
|
kernels_methods_call_map.get(kernel_name).unwrap_or(&empty)
|
||||||
fn get_callers_of_extern_shared_single<'a>(
|
{
|
||||||
methods_using_extern_shared: &mut HashSet<ast::MethodName<'a, spirv::Word>>,
|
if let Some(new_size) = methods_use_of_globals_shared
|
||||||
directly_called_by: &MultiHashMap<spirv::Word, ast::MethodName<'a, spirv::Word>>,
|
.get(&ast::MethodName::Func(called_subfunction))
|
||||||
fn_id: spirv::Word,
|
{
|
||||||
) {
|
size = size.fold(*new_size);
|
||||||
if let Some(callers) = directly_called_by.get(&fn_id) {
|
}
|
||||||
for caller in callers {
|
|
||||||
if methods_using_extern_shared.insert(*caller) {
|
|
||||||
if let ast::MethodName::Func(caller_fn) = caller {
|
|
||||||
get_callers_of_extern_shared_single(
|
|
||||||
methods_using_extern_shared,
|
|
||||||
directly_called_by,
|
|
||||||
*caller_fn,
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
|
kernel_use.insert(*kernel_name, size);
|
||||||
|
}
|
||||||
|
ast::MethodName::Func(fn_id) => {
|
||||||
|
functions_using_global.insert(*fn_id);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
(kernel_use, functions_using_global)
|
||||||
}
|
}
|
||||||
|
|
||||||
type DenormCountMap<T> = HashMap<T, isize>;
|
type DenormCountMap<T> = HashMap<T, isize>;
|
||||||
|
@ -3480,7 +3571,10 @@ fn emit_variable<'input>(
|
||||||
[dr::Operand::LiteralInt32(align)].iter().cloned(),
|
[dr::Operand::LiteralInt32(align)].iter().cloned(),
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
emit_linking_decoration(builder, id_defs, None, var.name, linking);
|
if var.state_space != ast::StateSpace::Shared || !linking.contains(ast::LinkingDirective::EXTERN)
|
||||||
|
{
|
||||||
|
emit_linking_decoration(builder, id_defs, None, var.name, linking);
|
||||||
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -3494,9 +3588,9 @@ fn emit_linking_decoration<'input>(
|
||||||
if linking == ast::LinkingDirective::NONE {
|
if linking == ast::LinkingDirective::NONE {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
let string_name =
|
|
||||||
name_override.unwrap_or_else(|| id_defs.reverse_variables.get(&name).unwrap());
|
|
||||||
if linking.contains(ast::LinkingDirective::VISIBLE) {
|
if linking.contains(ast::LinkingDirective::VISIBLE) {
|
||||||
|
let string_name =
|
||||||
|
name_override.unwrap_or_else(|| id_defs.reverse_variables.get(&name).unwrap());
|
||||||
builder.decorate(
|
builder.decorate(
|
||||||
name,
|
name,
|
||||||
spirv::Decoration::LinkageAttributes,
|
spirv::Decoration::LinkageAttributes,
|
||||||
|
@ -3508,6 +3602,8 @@ fn emit_linking_decoration<'input>(
|
||||||
.cloned(),
|
.cloned(),
|
||||||
);
|
);
|
||||||
} else if linking.contains(ast::LinkingDirective::EXTERN) {
|
} else if linking.contains(ast::LinkingDirective::EXTERN) {
|
||||||
|
let string_name =
|
||||||
|
name_override.unwrap_or_else(|| id_defs.reverse_variables.get(&name).unwrap());
|
||||||
builder.decorate(
|
builder.decorate(
|
||||||
name,
|
name,
|
||||||
spirv::Decoration::LinkageAttributes,
|
spirv::Decoration::LinkageAttributes,
|
||||||
|
@ -4454,7 +4550,7 @@ fn convert_to_stateful_memory_access<'a, 'input>(
|
||||||
})
|
})
|
||||||
.collect::<HashSet<_>>();
|
.collect::<HashSet<_>>();
|
||||||
let mut stateful_markers = Vec::new();
|
let mut stateful_markers = Vec::new();
|
||||||
let mut stateful_init_reg = MultiHashMap::new();
|
let mut stateful_init_reg = HashMap::<_, Vec<_>>::new();
|
||||||
for statement in func_body.iter() {
|
for statement in func_body.iter() {
|
||||||
match statement {
|
match statement {
|
||||||
Statement::Instruction(ast::Instruction::Cvta(
|
Statement::Instruction(ast::Instruction::Cvta(
|
||||||
|
@ -7863,26 +7959,15 @@ impl<'a> ast::MethodDeclaration<'a, &'a str> {
|
||||||
impl<'a> ast::MethodDeclaration<'a, spirv::Word> {
|
impl<'a> ast::MethodDeclaration<'a, spirv::Word> {
|
||||||
fn effective_input_arguments(&self) -> impl Iterator<Item = (spirv::Word, SpirvType)> + '_ {
|
fn effective_input_arguments(&self) -> impl Iterator<Item = (spirv::Word, SpirvType)> + '_ {
|
||||||
let is_kernel = self.name.is_kernel();
|
let is_kernel = self.name.is_kernel();
|
||||||
self.input_arguments
|
self.input_arguments.iter().map(move |arg| {
|
||||||
.iter()
|
if !is_kernel && arg.state_space != ast::StateSpace::Reg {
|
||||||
.map(move |arg| {
|
let spirv_type =
|
||||||
if !is_kernel && arg.state_space != ast::StateSpace::Reg {
|
SpirvType::pointer_to(arg.v_type.clone(), arg.state_space.to_spirv());
|
||||||
let spirv_type =
|
(arg.name, spirv_type)
|
||||||
SpirvType::pointer_to(arg.v_type.clone(), arg.state_space.to_spirv());
|
} else {
|
||||||
(arg.name, spirv_type)
|
(arg.name, SpirvType::new(arg.v_type.clone()))
|
||||||
} else {
|
}
|
||||||
(arg.name, SpirvType::new(arg.v_type.clone()))
|
})
|
||||||
}
|
|
||||||
})
|
|
||||||
.chain(self.shared_mem.iter().map(|id| {
|
|
||||||
(
|
|
||||||
*id,
|
|
||||||
SpirvType::Pointer(
|
|
||||||
Box::new(SpirvType::Base(SpirvScalarKey::B8)),
|
|
||||||
spirv::StorageClass::Workgroup,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
}))
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue