Start implementing .shared unification

This commit is contained in:
Andrzej Janik 2021-09-24 01:31:50 +02:00
parent 9609f86033
commit 370c0bd09e
5 changed files with 305 additions and 119 deletions

View file

@ -1970,6 +1970,9 @@ ArgCall: (Vec<&'input str>, &'input str, Vec<ast::Operand<&'input str>>) = {
"(" <ret_params:Comma<ExtendedID>> ")" "," <func:ExtendedID> "," "(" <param_list:Comma<CallOperand>> ")" => { "(" <ret_params:Comma<ExtendedID>> ")" "," <func:ExtendedID> "," "(" <param_list:Comma<CallOperand>> ")" => {
(ret_params, func, param_list) (ret_params, func, param_list)
}, },
"(" <ret_params:Comma<ExtendedID>> ")" "," <func:ExtendedID> => {
(ret_params, func, Vec::new())
},
<func:ExtendedID> "," "(" <param_list:Comma<CallOperand>> ")" => (Vec::new(), func, param_list), <func:ExtendedID> "," "(" <param_list:Comma<CallOperand>> ")" => (Vec::new(), func, param_list),
<func:ExtendedID> => (Vec::new(), func, Vec::<ast::Operand<_>>::new()), <func:ExtendedID> => (Vec::new(), func, Vec::<ast::Operand<_>>::new()),
}; };

View file

@ -221,6 +221,8 @@ test_ptx!(cvt_f64_f32, [0.125f32], [0.125f64]);
test_ptx!(prmt, [0x70c507d6u32, 0x6fbd4b5cu32], [0x6fbdd65cu32]); test_ptx!(prmt, [0x70c507d6u32, 0x6fbd4b5cu32], [0x6fbdd65cu32]);
test_ptx!(activemask, [0u32], [1u32]); test_ptx!(activemask, [0u32], [1u32]);
test_ptx!(membar, [152731u32], [152731u32]); test_ptx!(membar, [152731u32], [152731u32]);
test_ptx!(shared_unify_extern, [7681u64], [15362u64]);
test_ptx!(func_ptr); test_ptx!(func_ptr);
test_ptx!(lanemask_lt); test_ptx!(lanemask_lt);
test_ptx!(extern_func); test_ptx!(extern_func);

View file

@ -0,0 +1,34 @@
// Module header: PTX ISA version 6.5, target sm_30, 64-bit addresses.
.version 6.5
.target sm_30
.address_size 64
// Two .shared arrays of .b32 elements:
// * `shared_ex` is declared .extern with no size given here,
// * `shared_mod` is defined in this module with 4 elements.
.extern .shared .b32 shared_ex[];
.shared .b32 shared_mod[4];
// Helper function without input parameters: returns in `out` the 64-bit
// value loaded from the start of the .shared array `shared_mod`.
.func (.reg .b64 out) load_from_shared()
{
ld.shared.u64 out, [shared_mod];
ret;
}
// Test kernel: reads one u64 from `input`, stores it through the .extern
// .shared array `shared_ex`, calls `load_from_shared` and writes the sum of
// the returned value and the original value to `output`.
// NOTE(review): `load_from_shared` reads `shared_mod`, not `shared_ex`, so a
// result of exactly twice the input presumably relies on both .shared names
// resolving to the same memory - confirm against the test expectation.
.visible .entry shared_unify_extern(
.param .u64 input,
.param .u64 output
)
{
.reg .u64 in_addr;
.reg .u64 out_addr;
.reg .u64 temp1;
.reg .u64 temp2;
// Fetch the two pointer parameters.
ld.param.u64 in_addr, [input];
ld.param.u64 out_addr, [output];
// temp1 = *input
ld.u64 temp1, [in_addr];
// Write temp1 to the start of `shared_ex`.
st.shared.u64 [shared_ex], temp1;
// temp2 = value returned by load_from_shared (call with a return value and
// an empty argument list).
call (temp2), load_from_shared;
// *output = temp2 + temp1
add.u64 temp2, temp2, temp1;
st.u64 [out_addr], temp2;
ret;
}

View file

@ -0,0 +1,62 @@
; Module preamble: capabilities, OpenCL memory model, one kernel entry point
; (%2) and one imported Workgroup variable (%1).
; NOTE(review): the entry point is named "shared_ptr_take_address" and the
; imported variable "shared_mem"; neither name matches a kernel called
; shared_unify_extern or .shared arrays called shared_ex / shared_mod. This
; looks like placeholder output copied from another test - confirm before
; relying on it.
OpCapability GenericPointer
OpCapability Linkage
OpCapability Addresses
OpCapability Kernel
OpCapability Int8
OpCapability Int16
OpCapability Int64
OpCapability Float16
OpCapability Float64
%30 = OpExtInstImport "OpenCL.std"
OpMemoryModel Physical64 OpenCL
OpEntryPoint Kernel %2 "shared_ptr_take_address" %1
OpExecutionMode %2 ContractionOff
; %1 is 4-byte aligned and imported under the linkage name "shared_mem".
OpDecorate %1 Alignment 4
OpDecorate %1 LinkageAttributes "shared_mem" Import
; Types: void, 8-bit and 64-bit integers, and the pointer types used below.
%void = OpTypeVoid
%uchar = OpTypeInt 8 0
%_ptr_Workgroup_uchar = OpTypePointer Workgroup %uchar
; %1: module-level variable in Workgroup storage, typed as pointer to uchar.
%1 = OpVariable %_ptr_Workgroup_uchar Workgroup
%ulong = OpTypeInt 64 0
; %35: function type void(ulong, ulong) used by the kernel %2.
%35 = OpTypeFunction %void %ulong %ulong
%_ptr_Function_ulong = OpTypePointer Function %ulong
%_ptr_CrossWorkgroup_ulong = OpTypePointer CrossWorkgroup %ulong
%_ptr_Workgroup_ulong = OpTypePointer Workgroup %ulong
; Kernel %2 with two 64-bit parameters (%10, %11).
; NOTE(review): this body contains no OpFunctionCall and no integer add; it
; stores the loaded value to Workgroup memory, reads it back and writes it out
; unchanged. A kernel that calls a helper and adds two values would need both
; - presumably this fixture is a placeholder; confirm.
%2 = OpFunction %void None %35
%10 = OpFunctionParameter %ulong
%11 = OpFunctionParameter %ulong
%28 = OpLabel
; Function-local 64-bit slots.
%3 = OpVariable %_ptr_Function_ulong Function
%4 = OpVariable %_ptr_Function_ulong Function
%5 = OpVariable %_ptr_Function_ulong Function
%6 = OpVariable %_ptr_Function_ulong Function
%7 = OpVariable %_ptr_Function_ulong Function
%8 = OpVariable %_ptr_Function_ulong Function
%9 = OpVariable %_ptr_Function_ulong Function
; Spill both parameters into %3/%4, then copy them into %5/%6.
OpStore %3 %10
OpStore %4 %11
%12 = OpLoad %ulong %3 Aligned 8
OpStore %5 %12
%13 = OpLoad %ulong %4 Aligned 8
OpStore %6 %13
; %7 = address of the Workgroup variable %1 as a 64-bit integer.
%23 = OpConvertPtrToU %ulong %1
%14 = OpCopyObject %ulong %23
OpStore %7 %14
; %8 = 64-bit value loaded through the CrossWorkgroup pointer held in %5.
%16 = OpLoad %ulong %5
%24 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %16
%15 = OpLoad %ulong %24 Aligned 8
OpStore %8 %15
; Store %8 to Workgroup memory at the address held in %7.
%17 = OpLoad %ulong %7
%18 = OpLoad %ulong %8
%25 = OpConvertUToPtr %_ptr_Workgroup_ulong %17
OpStore %25 %18 Aligned 8
; %9 = value read back from the same Workgroup address.
%20 = OpLoad %ulong %7
%26 = OpConvertUToPtr %_ptr_Workgroup_ulong %20
%19 = OpLoad %ulong %26 Aligned 8
OpStore %9 %19
; Write %9 through the CrossWorkgroup pointer held in %6.
%21 = OpLoad %ulong %6
%22 = OpLoad %ulong %9
%27 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %21
OpStore %27 %22 Aligned 8
OpReturn
OpFunctionEnd

View file

@ -443,7 +443,8 @@ pub fn to_spirv_module<'input>(ast: ast::Module<'input>) -> Result<Module, Trans
let mut builder = dr::Builder::new(); let mut builder = dr::Builder::new();
builder.reserve_ids(id_defs.current_id()); builder.reserve_ids(id_defs.current_id());
let call_map = get_kernels_call_map(&directives); let call_map = get_kernels_call_map(&directives);
//let mut directives = convert_dynamic_shared_memory_usage(directives, &mut || builder.id()); let mut directives =
convert_dynamic_shared_memory_usage(directives, &call_map, &mut || builder.id());
normalize_variable_decls(&mut directives); normalize_variable_decls(&mut directives);
let denorm_information = compute_denorm_information(&directives); let denorm_information = compute_denorm_information(&directives);
// https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_a_id_logicallayout_a_logical_layout_of_a_module // https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_a_id_logicallayout_a_logical_layout_of_a_module
@ -607,7 +608,7 @@ fn emit_directives<'input>(
} }
} }
emit_function_body_ops(builder, map, id_defs, opencl_id, &f_body)?; emit_function_body_ops(builder, map, id_defs, opencl_id, &f_body)?;
emit_function_linkage(builder, id_defs, f, fn_id); emit_function_linkage(builder, id_defs, f, fn_id)?;
builder.select_block(None)?; builder.select_block(None)?;
builder.end_function()?; builder.end_function()?;
} }
@ -683,7 +684,7 @@ fn get_kernels_call_map<'input>(
} }
fn add_call_map_single<'input>( fn add_call_map_single<'input>(
directly_called_by: &MultiHashMap<ast::MethodName<'input, spirv::Word>, spirv::Word>, directly_called_by: &HashMap<ast::MethodName<'input, spirv::Word>, Vec<spirv::Word>>,
visited: &mut HashSet<spirv::Word>, visited: &mut HashSet<spirv::Word>,
current: spirv::Word, current: spirv::Word,
) { ) {
@ -697,15 +698,21 @@ fn add_call_map_single<'input>(
} }
} }
/// Appends `value` to the collection stored under `key` in `m`.
///
/// When `key` is not present yet, an empty `Collection` is created first and
/// `value` becomes its first element. `Collection` can be any type that is
/// `Default` and `Extend<V>` (for example `Vec<V>` or `HashSet<V>`).
fn multi_hash_map_append<
    K: Eq + std::hash::Hash,
    V,
    Collection: std::iter::Extend<V> + std::default::Default,
>(
    m: &mut HashMap<K, Collection>,
    key: K,
    value: V,
) {
    // A single code path for both the occupied and the vacant case: the
    // previous vacant branch inserted an empty collection and silently
    // dropped `value`, losing the first element appended under every key.
    m.entry(key).or_default().extend(std::iter::once(value));
}
@ -713,7 +720,8 @@ fn multi_hash_map_append<K: Eq + std::hash::Hash, V>(m: &mut MultiHashMap<K, V>,
/* /*
PTX represents dynamically allocated shared local memory as PTX represents dynamically allocated shared local memory as
.extern .shared .b32 shared_mem[]; .extern .shared .b32 shared_mem[];
In SPIRV/OpenCL world this is expressed as an additional argument In SPIRV/OpenCL world this is expressed as an additional argument to the kernel
And in AMD compilation
This pass looks for all uses of .extern .shared and converts them to This pass looks for all uses of .extern .shared and converts them to
an additional method argument an additional method argument
The question is how this artificial argument should be expressed. There are The question is how this artificial argument should be expressed. There are
@ -735,30 +743,35 @@ fn multi_hash_map_append<K: Eq + std::hash::Hash, V>(m: &mut MultiHashMap<K, V>,
*/ */
fn convert_dynamic_shared_memory_usage<'input>( fn convert_dynamic_shared_memory_usage<'input>(
module: Vec<Directive<'input>>, module: Vec<Directive<'input>>,
kernels_methods_call_map: &HashMap<&'input str, HashSet<spirv::Word>>,
new_id: &mut impl FnMut() -> spirv::Word, new_id: &mut impl FnMut() -> spirv::Word,
) -> Vec<Directive<'input>> { ) -> Vec<Directive<'input>> {
let mut extern_shared_decls = HashMap::new(); let mut globals_shared = HashMap::new();
for dir in module.iter() { for dir in module.iter() {
match dir { match dir {
Directive::Variable( Directive::Variable(
linking, linking,
ast::Variable { ast::Variable {
v_type: ast::Type::Array(p_type, dims),
state_space: ast::StateSpace::Shared, state_space: ast::StateSpace::Shared,
name, name,
v_type,
.. ..
}, },
) if linking.contains(ast::LinkingDirective::EXTERN) && dims.len() == 0 => { ) => {
extern_shared_decls.insert(*name, *p_type); let size = if linking.contains(ast::LinkingDirective::EXTERN) {
GlobalSharedSize::ExternUnsized
} else {
GlobalSharedSize::Sized((*v_type).size_of())
};
globals_shared.insert(*name, (size, v_type.clone()));
} }
_ => {} _ => {}
} }
} }
if extern_shared_decls.len() == 0 { if globals_shared.len() == 0 {
return module; return module;
} }
let mut methods_using_extern_shared = HashSet::new(); let mut methods_to_globals_shared_direct_only_use = HashMap::<_, GlobalSharedSize>::new();
let mut directly_called_by = MultiHashMap::new();
let module = module let module = module
.into_iter() .into_iter()
.map(|directive| match directive { .map(|directive| match directive {
@ -773,17 +786,21 @@ fn convert_dynamic_shared_memory_usage<'input>(
let call_key = (*func_decl).borrow().name; let call_key = (*func_decl).borrow().name;
let statements = statements let statements = statements
.into_iter() .into_iter()
.map(|statement| match statement { .map(|statement| {
Statement::Call(call) => { statement.map_id(&mut |id, _| {
multi_hash_map_append(&mut directly_called_by, call.name, call_key); if let Some((size, _)) = globals_shared.get(&id) {
Statement::Call(call) match methods_to_globals_shared_direct_only_use.entry(call_key) {
} hash_map::Entry::Occupied(mut e) => {
statement => statement.map_id(&mut |id, _| { let original_size = *e.get();
if extern_shared_decls.contains_key(&id) { e.insert(original_size.fold(*size));
methods_using_extern_shared.insert(call_key); }
hash_map::Entry::Vacant(mut e) => {
e.insert(*size);
}
}
} }
id id
}), })
}) })
.collect(); .collect();
Directive::Method(Function { Directive::Method(Function {
@ -800,11 +817,15 @@ fn convert_dynamic_shared_memory_usage<'input>(
.collect::<Vec<_>>(); .collect::<Vec<_>>();
// If there's a chain `kernel` -> `fn1` -> `fn2`, where only `fn2` uses extern shared, // If there's a chain `kernel` -> `fn1` -> `fn2`, where only `fn2` uses extern shared,
// make sure it gets propagated to `fn1` and `kernel` // make sure it gets propagated to `fn1` and `kernel`
get_callers_of_extern_shared(&mut methods_using_extern_shared, &directly_called_by); let (kernels_to_global_shared, functions_to_global_shared) =
resolve_indirect_uses_of_globals_shared(
methods_to_globals_shared_direct_only_use,
kernels_methods_call_map,
);
// now visit every method declaration and inject those additional arguments // now visit every method declaration and inject those additional arguments
module let mut result = Vec::with_capacity(module.len());
.into_iter() for directive in module.into_iter() {
.map(|directive| match directive { match directive {
Directive::Method(Function { Directive::Method(Function {
func_decl, func_decl,
globals, globals,
@ -813,46 +834,119 @@ fn convert_dynamic_shared_memory_usage<'input>(
tuning, tuning,
linkage, linkage,
}) => { }) => {
if !methods_using_extern_shared.contains(&(*func_decl).borrow().name) { let statements = {
return Directive::Method(Function { let func_decl_ref = &mut (*func_decl).borrow_mut();
func_decl, let method_name = func_decl_ref.name;
globals, insert_arguments_remap_statements(
body: Some(statements), method_name,
import_as, &kernels_to_global_shared,
tuning, new_id,
linkage, &mut result,
}); &functions_to_global_shared,
} func_decl_ref,
let shared_id_param = new_id(); &globals_shared,
{ statements,
let mut func_decl = (*func_decl).borrow_mut(); )
func_decl.shared_mem = Some(shared_id_param); };
} result.push(Directive::Method(Function {
let statements = replace_uses_of_shared_memory(
new_id,
&extern_shared_decls,
&mut methods_using_extern_shared,
shared_id_param,
statements,
);
Directive::Method(Function {
func_decl, func_decl,
globals, globals,
body: Some(statements), body: Some(statements),
import_as, import_as,
tuning, tuning,
linkage, linkage,
}) }));
} }
directive => directive, directive => result.push(directive),
}) }
.collect::<Vec<_>>() }
result
}
fn insert_arguments_remap_statements(
method_name: ast::MethodName<u32>,
kernels_to_global_shared: &HashMap<&str, GlobalSharedSize>,
new_id: &mut impl FnMut() -> u32,
result: &mut Vec<Directive>,
functions_to_global_shared: &HashSet<u32>,
func_decl_ref: &mut std::cell::RefMut<ast::MethodDeclaration<u32>>,
globals_shared: &HashMap<u32, (GlobalSharedSize, ast::Type)>,
statements: Vec<Statement<ast::Instruction<ExpandedArgParams>, ExpandedArgParams>>,
) -> Vec<Statement<ast::Instruction<ExpandedArgParams>, ExpandedArgParams>> {
let shared_id_param = match method_name {
ast::MethodName::Kernel(kernel_name) => {
let globals_shared_size = match kernels_to_global_shared.get(kernel_name) {
Some(s) => *s,
None => return statements,
};
let shared_id_param = new_id();
let (linkage, type_) = match globals_shared_size {
GlobalSharedSize::ExternUnsized => (
ast::LinkingDirective::EXTERN,
ast::Type::Array(ast::ScalarType::U8, Vec::new()),
),
GlobalSharedSize::Sized(size) => (
ast::LinkingDirective::NONE,
ast::Type::Array(ast::ScalarType::U8, vec![size as u32]),
),
};
result.push(Directive::Variable(
linkage,
ast::Variable {
align: None,
v_type: type_,
state_space: ast::StateSpace::Shared,
name: shared_id_param,
array_init: Vec::new(),
},
));
shared_id_param
}
ast::MethodName::Func(function_name) => {
if !functions_to_global_shared.contains(&function_name) {
return statements;
}
let shared_id_param = new_id();
func_decl_ref.input_arguments.push(ast::Variable {
align: None,
v_type: ast::Type::Pointer(ast::ScalarType::B8, ast::StateSpace::Shared),
state_space: ast::StateSpace::Reg,
name: shared_id_param,
array_init: Vec::new(),
});
shared_id_param
}
};
replace_uses_of_shared_memory(
new_id,
globals_shared,
functions_to_global_shared,
shared_id_param,
statements,
)
}
/// How much `.shared` memory a global (or a set of globals) requires.
#[derive(Clone, Copy, PartialEq, Eq, Hash)]
enum GlobalSharedSize {
    /// No size is known; used for `.shared` variables declared `EXTERN`.
    ExternUnsized,
    /// A known size.
    Sized(usize),
}

impl GlobalSharedSize {
    /// Combines two requirements into one that satisfies both: the larger of
    /// two known sizes, or `ExternUnsized` as soon as either side is unsized.
    fn fold(self, other: GlobalSharedSize) -> GlobalSharedSize {
        if let (GlobalSharedSize::Sized(lhs), GlobalSharedSize::Sized(rhs)) = (self, other) {
            GlobalSharedSize::Sized(lhs.max(rhs))
        } else {
            GlobalSharedSize::ExternUnsized
        }
    }
}
fn replace_uses_of_shared_memory<'a>( fn replace_uses_of_shared_memory<'a>(
new_id: &mut impl FnMut() -> spirv::Word, new_id: &mut impl FnMut() -> spirv::Word,
extern_shared_decls: &HashMap<spirv::Word, ast::ScalarType>, extern_shared_decls: &HashMap<spirv::Word, (GlobalSharedSize, ast::Type)>,
methods_using_extern_shared: &mut HashSet<ast::MethodName<'a, spirv::Word>>, methods_using_extern_shared: &HashSet<spirv::Word>,
shared_id_param: spirv::Word, shared_id_param: spirv::Word,
statements: Vec<ExpandedStatement>, statements: Vec<ExpandedStatement>,
) -> Vec<ExpandedStatement> { ) -> Vec<ExpandedStatement> {
@ -863,7 +957,7 @@ fn replace_uses_of_shared_memory<'a>(
// We can safely skip checking call arguments, // We can safely skip checking call arguments,
// because there's simply no way to pass shared ptr // because there's simply no way to pass shared ptr
// without converting it to .b64 first // without converting it to .b64 first
if methods_using_extern_shared.contains(&ast::MethodName::Func(call.name)) { if methods_using_extern_shared.contains(&call.name) {
call.input_arguments.push(( call.input_arguments.push((
shared_id_param, shared_id_param,
ast::Type::Scalar(ast::ScalarType::B8), ast::Type::Scalar(ast::ScalarType::B8),
@ -874,8 +968,8 @@ fn replace_uses_of_shared_memory<'a>(
} }
statement => { statement => {
let new_statement = statement.map_id(&mut |id, _| { let new_statement = statement.map_id(&mut |id, _| {
if let Some(scalar_type) = extern_shared_decls.get(&id) { if let Some((_, type_)) = extern_shared_decls.get(&id) {
if *scalar_type == ast::ScalarType::B8 { if *type_ == ast::Type::Scalar(ast::ScalarType::B8) {
return shared_id_param; return shared_id_param;
} }
let replacement_id = new_id(); let replacement_id = new_id();
@ -884,7 +978,7 @@ fn replace_uses_of_shared_memory<'a>(
dst: replacement_id, dst: replacement_id,
from_type: ast::Type::Scalar(ast::ScalarType::B8), from_type: ast::Type::Scalar(ast::ScalarType::B8),
from_space: ast::StateSpace::Shared, from_space: ast::StateSpace::Shared,
to_type: ast::Type::Scalar(*scalar_type), to_type: type_.clone(),
to_space: ast::StateSpace::Shared, to_space: ast::StateSpace::Shared,
kind: ConversionKind::PtrToPtr, kind: ConversionKind::PtrToPtr,
})); }));
@ -900,43 +994,40 @@ fn replace_uses_of_shared_memory<'a>(
result result
} }
fn get_callers_of_extern_shared<'a>( // We need to compute two kinds of information:
methods_using_extern_shared: &mut HashSet<ast::MethodName<'a, spirv::Word>>, // * If it's a kernel -> size of .shared globals in use (direct or indirect)
directly_called_by: &MultiHashMap<spirv::Word, ast::MethodName<'a, spirv::Word>>, // * If it's a function -> does it use .shared global (directly or indirectly)
) { fn resolve_indirect_uses_of_globals_shared<'input>(
let direct_uses_of_extern_shared = methods_using_extern_shared mut methods_use_of_globals_shared: HashMap<
.iter() ast::MethodName<'input, spirv::Word>,
.filter_map(|method| { GlobalSharedSize,
if let ast::MethodName::Func(f_id) = method { >,
Some(*f_id) kernels_methods_call_map: &HashMap<&'input str, HashSet<spirv::Word>>,
} else { ) -> (HashMap<&'input str, GlobalSharedSize>, HashSet<spirv::Word>) {
None let mut kernel_use = HashMap::new();
} let mut functions_using_global = HashSet::new();
}) let empty = HashSet::new();
.collect::<Vec<_>>(); for (method, globals) in methods_use_of_globals_shared.iter() {
for fn_id in direct_uses_of_extern_shared { match method {
get_callers_of_extern_shared_single(methods_using_extern_shared, directly_called_by, fn_id); ast::MethodName::Kernel(kernel_name) => {
} let mut size = *globals;
} for &called_subfunction in
kernels_methods_call_map.get(kernel_name).unwrap_or(&empty)
fn get_callers_of_extern_shared_single<'a>( {
methods_using_extern_shared: &mut HashSet<ast::MethodName<'a, spirv::Word>>, if let Some(new_size) = methods_use_of_globals_shared
directly_called_by: &MultiHashMap<spirv::Word, ast::MethodName<'a, spirv::Word>>, .get(&ast::MethodName::Func(called_subfunction))
fn_id: spirv::Word, {
) { size = size.fold(*new_size);
if let Some(callers) = directly_called_by.get(&fn_id) { }
for caller in callers {
if methods_using_extern_shared.insert(*caller) {
if let ast::MethodName::Func(caller_fn) = caller {
get_callers_of_extern_shared_single(
methods_using_extern_shared,
directly_called_by,
*caller_fn,
);
} }
kernel_use.insert(*kernel_name, size);
}
ast::MethodName::Func(fn_id) => {
functions_using_global.insert(*fn_id);
} }
} }
} }
(kernel_use, functions_using_global)
} }
type DenormCountMap<T> = HashMap<T, isize>; type DenormCountMap<T> = HashMap<T, isize>;
@ -3480,7 +3571,10 @@ fn emit_variable<'input>(
[dr::Operand::LiteralInt32(align)].iter().cloned(), [dr::Operand::LiteralInt32(align)].iter().cloned(),
); );
} }
emit_linking_decoration(builder, id_defs, None, var.name, linking); if var.state_space != ast::StateSpace::Shared || !linking.contains(ast::LinkingDirective::EXTERN)
{
emit_linking_decoration(builder, id_defs, None, var.name, linking);
}
Ok(()) Ok(())
} }
@ -3494,9 +3588,9 @@ fn emit_linking_decoration<'input>(
if linking == ast::LinkingDirective::NONE { if linking == ast::LinkingDirective::NONE {
return; return;
} }
let string_name =
name_override.unwrap_or_else(|| id_defs.reverse_variables.get(&name).unwrap());
if linking.contains(ast::LinkingDirective::VISIBLE) { if linking.contains(ast::LinkingDirective::VISIBLE) {
let string_name =
name_override.unwrap_or_else(|| id_defs.reverse_variables.get(&name).unwrap());
builder.decorate( builder.decorate(
name, name,
spirv::Decoration::LinkageAttributes, spirv::Decoration::LinkageAttributes,
@ -3508,6 +3602,8 @@ fn emit_linking_decoration<'input>(
.cloned(), .cloned(),
); );
} else if linking.contains(ast::LinkingDirective::EXTERN) { } else if linking.contains(ast::LinkingDirective::EXTERN) {
let string_name =
name_override.unwrap_or_else(|| id_defs.reverse_variables.get(&name).unwrap());
builder.decorate( builder.decorate(
name, name,
spirv::Decoration::LinkageAttributes, spirv::Decoration::LinkageAttributes,
@ -4454,7 +4550,7 @@ fn convert_to_stateful_memory_access<'a, 'input>(
}) })
.collect::<HashSet<_>>(); .collect::<HashSet<_>>();
let mut stateful_markers = Vec::new(); let mut stateful_markers = Vec::new();
let mut stateful_init_reg = MultiHashMap::new(); let mut stateful_init_reg = HashMap::<_, Vec<_>>::new();
for statement in func_body.iter() { for statement in func_body.iter() {
match statement { match statement {
Statement::Instruction(ast::Instruction::Cvta( Statement::Instruction(ast::Instruction::Cvta(
@ -7863,26 +7959,15 @@ impl<'a> ast::MethodDeclaration<'a, &'a str> {
impl<'a> ast::MethodDeclaration<'a, spirv::Word> { impl<'a> ast::MethodDeclaration<'a, spirv::Word> {
fn effective_input_arguments(&self) -> impl Iterator<Item = (spirv::Word, SpirvType)> + '_ { fn effective_input_arguments(&self) -> impl Iterator<Item = (spirv::Word, SpirvType)> + '_ {
let is_kernel = self.name.is_kernel(); let is_kernel = self.name.is_kernel();
self.input_arguments self.input_arguments.iter().map(move |arg| {
.iter() if !is_kernel && arg.state_space != ast::StateSpace::Reg {
.map(move |arg| { let spirv_type =
if !is_kernel && arg.state_space != ast::StateSpace::Reg { SpirvType::pointer_to(arg.v_type.clone(), arg.state_space.to_spirv());
let spirv_type = (arg.name, spirv_type)
SpirvType::pointer_to(arg.v_type.clone(), arg.state_space.to_spirv()); } else {
(arg.name, spirv_type) (arg.name, SpirvType::new(arg.v_type.clone()))
} else { }
(arg.name, SpirvType::new(arg.v_type.clone())) })
}
})
.chain(self.shared_mem.iter().map(|id| {
(
*id,
SpirvType::Pointer(
Box::new(SpirvType::Base(SpirvScalarKey::B8)),
spirv::StorageClass::Workgroup,
),
)
}))
} }
} }