diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs
index 364ec01..e45a6fb 100644
--- a/ptx/src/ast.rs
+++ b/ptx/src/ast.rs
@@ -1,6 +1,6 @@
use half::f16;
use lalrpop_util::{lexer::Token, ParseError};
-use std::{convert::From, mem, num::ParseFloatError, str::FromStr};
+use std::{convert::From, mem, num::ParseFloatError, rc::Rc, str::FromStr};
use std::{marker::PhantomData, num::ParseIntError};
#[derive(Debug, thiserror::Error)]
@@ -86,19 +86,20 @@ pub enum Directive<'a, P: ArgParams> {
Method(Function<'a, &'a str, Statement
>),
}
-pub enum MethodDecl<'a, ID> {
- Func(Vec>, ID, Vec>),
- Kernel {
- name: &'a str,
- in_args: Vec>,
- },
+#[derive(Hash, PartialEq, Eq, Copy, Clone)]
+pub enum MethodName<'input, ID> {
+ Kernel(&'input str),
+ Func(ID),
}
-pub type FnArgument = Variable;
-pub type KernelArgument = Variable;
+pub struct MethodDeclaration<'input, ID> {
+ pub return_arguments: Vec>,
+ pub name: MethodName<'input, ID>,
+ pub input_arguments: Vec>,
+}
pub struct Function<'a, ID, S> {
- pub func_directive: MethodDecl<'a, ID>,
+ pub func_directive: MethodDeclaration<'a, ID>,
pub tuning: Vec,
pub body: Option>,
}
diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop
index 8fee7c2..78ebf1d 100644
--- a/ptx/src/ptx.lalrpop
+++ b/ptx/src/ptx.lalrpop
@@ -360,7 +360,7 @@ AddressSize = {
Function: ast::Function<'input, &'input str, ast::Statement>> = {
LinkingDirectives
-
+
=> ast::Function{<>}
};
@@ -388,19 +388,24 @@ LinkingDirectives: ast::LinkingDirective = {
}
}
-MethodDecl: ast::MethodDecl<'input, &'input str> = {
- ".entry" =>
- ast::MethodDecl::Kernel{ name, in_args },
- ".func" => {
- ast::MethodDecl::Func(ret_vals.unwrap_or_else(|| Vec::new()), name, params)
+MethodDeclaration: ast::MethodDeclaration<'input, &'input str> = {
+ ".entry" => {
+ let return_arguments = Vec::new();
+ let name = ast::MethodName::Kernel(name);
+ ast::MethodDeclaration{ return_arguments, name, input_arguments }
+ },
+ ".func" => {
+ let return_arguments = return_arguments.unwrap_or_else(|| Vec::new());
+ let name = ast::MethodName::Func(name);
+ ast::MethodDeclaration{ return_arguments, name, input_arguments }
}
};
-KernelArguments: Vec> = {
+KernelArguments: Vec> = {
"(" > ")" => args
};
-FnArguments: Vec> = {
+FnArguments: Vec> = {
"(" > ")" => args
};
diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs
index 1a2eda3..88ef51b 100644
--- a/ptx/src/translate.rs
+++ b/ptx/src/translate.rs
@@ -1,7 +1,9 @@
use crate::ast;
+use core::borrow;
use half::f16;
use rspirv::dr;
-use std::{borrow::Cow, collections::BTreeSet, ffi::CString, hash::Hash, iter, mem};
+use std::{borrow::Borrow, cell::RefCell};
+use std::{borrow::Cow, collections::BTreeSet, ffi::CString, hash::Hash, iter, mem, rc::Rc};
use std::{
collections::{hash_map, HashMap, HashSet},
convert::TryInto,
@@ -458,7 +460,7 @@ pub fn to_spirv_module<'a>(ast: ast::Module<'a>) -> Result>();
let mut builder = dr::Builder::new();
builder.reserve_ids(id_defs.current_id());
- let call_map = get_call_map(&directives);
+ let call_map = get_kernels_call_map(&directives);
let mut directives = convert_dynamic_shared_memory_usage(directives, &mut || builder.id());
normalize_variable_decls(&mut directives);
let denorm_information = compute_denorm_information(&directives);
@@ -496,9 +498,12 @@ pub fn to_spirv_module<'a>(ast: ast::Module<'a>) -> Result(
call_map: &HashMap<&str, HashSet>,
- denorm_information: &HashMap>,
+ denorm_information: &HashMap<
+ ast::MethodName<'input, spirv::Word>,
+ HashMap,
+ >,
) -> CString {
let denorm_counts = denorm_information
.iter()
@@ -516,10 +521,12 @@ fn emit_denorm_build_string(
.collect::>();
let mut flush_over_preserve = 0;
for (kernel, children) in call_map {
- flush_over_preserve += *denorm_counts.get(&MethodName::Kernel(kernel)).unwrap_or(&0);
+ flush_over_preserve += *denorm_counts
+ .get(&ast::MethodName::Kernel(kernel))
+ .unwrap_or(&0);
for child_fn in children {
flush_over_preserve += *denorm_counts
- .get(&MethodName::Func(*child_fn))
+ .get(&ast::MethodName::Func(*child_fn))
.unwrap_or(&0);
}
}
@@ -535,9 +542,12 @@ fn emit_directives<'input>(
map: &mut TypeWordMap,
id_defs: &GlobalStringIdResolver<'input>,
opencl_id: spirv::Word,
- denorm_information: &HashMap, HashMap>,
+ denorm_information: &HashMap<
+ ast::MethodName<'input, spirv::Word>,
+ HashMap,
+ >,
call_map: &HashMap<&'input str, HashSet>,
- directives: Vec,
+ directives: Vec>,
kernel_info: &mut HashMap,
) -> Result<(), TranslateError> {
let empty_body = Vec::new();
@@ -560,16 +570,18 @@ fn emit_directives<'input>(
for var in f.globals.iter() {
emit_variable(builder, map, var)?;
}
+ let func_decl = (*f.func_decl).borrow();
let fn_id = emit_function_header(
builder,
map,
&id_defs,
&f.globals,
- &f.spirv_decl,
+ &*func_decl,
&denorm_information,
call_map,
&directives,
kernel_info,
+ f.uses_shared_mem,
)?;
for t in f.tuning.iter() {
match *t {
@@ -594,8 +606,13 @@ fn emit_directives<'input>(
}
emit_function_body_ops(builder, map, opencl_id, &f_body)?;
builder.end_function()?;
- if let (ast::MethodDecl::Func(_, fn_id, _), Some(name)) =
- (&f.func_decl, &f.import_as)
+ if let (
+ ast::MethodDeclaration {
+ name: ast::MethodName::Func(fn_id),
+ ..
+ },
+ Some(name),
+ ) = (&*func_decl, &f.import_as)
{
builder.decorate(
*fn_id,
@@ -614,7 +631,7 @@ fn emit_directives<'input>(
Ok(())
}
-fn get_call_map<'input>(
+fn get_kernels_call_map<'input>(
module: &[Directive<'input>],
) -> HashMap<&'input str, HashSet> {
let mut directly_called_by = HashMap::new();
@@ -625,7 +642,7 @@ fn get_call_map<'input>(
body: Some(statements),
..
}) => {
- let call_key = MethodName::new(&func_decl);
+ let call_key: ast::MethodName<_> = (**func_decl).borrow().name;
if let hash_map::Entry::Vacant(entry) = directly_called_by.entry(call_key) {
entry.insert(Vec::new());
}
@@ -644,28 +661,28 @@ fn get_call_map<'input>(
let mut result = HashMap::new();
for (method_key, children) in directly_called_by.iter() {
match method_key {
- MethodName::Kernel(name) => {
+ ast::MethodName::Kernel(name) => {
let mut visited = HashSet::new();
for child in children {
add_call_map_single(&directly_called_by, &mut visited, *child);
}
result.insert(*name, visited);
}
- MethodName::Func(_) => {}
+ ast::MethodName::Func(_) => {}
}
}
result
}
fn add_call_map_single<'input>(
- directly_called_by: &MultiHashMap, spirv::Word>,
+ directly_called_by: &MultiHashMap, spirv::Word>,
visited: &mut HashSet,
current: spirv::Word,
) {
if !visited.insert(current) {
return;
}
- if let Some(children) = directly_called_by.get(&MethodName::Func(current)) {
+ if let Some(children) = directly_called_by.get(&ast::MethodName::Func(current)) {
for child in children {
add_call_map_single(directly_called_by, visited, *child);
}
@@ -739,10 +756,10 @@ fn convert_dynamic_shared_memory_usage<'input>(
globals,
body: Some(statements),
import_as,
- spirv_decl,
tuning,
+ uses_shared_mem,
}) => {
- let call_key = MethodName::new(&func_decl);
+ let call_key = (*func_decl).borrow().name;
let statements = statements
.into_iter()
.map(|statement| match statement {
@@ -763,8 +780,8 @@ fn convert_dynamic_shared_memory_usage<'input>(
globals,
body: Some(statements),
import_as,
- spirv_decl,
tuning,
+ uses_shared_mem,
})
}
directive => directive,
@@ -782,30 +799,32 @@ fn convert_dynamic_shared_memory_usage<'input>(
globals,
body: Some(statements),
import_as,
- mut spirv_decl,
tuning,
+ uses_shared_mem,
}) => {
- if !methods_using_extern_shared.contains(&spirv_decl.name) {
+ if !methods_using_extern_shared.contains(&(*func_decl).borrow().name) {
return Directive::Method(Function {
func_decl,
globals,
body: Some(statements),
import_as,
- spirv_decl,
tuning,
+ uses_shared_mem,
});
}
let shared_id_param = new_id();
- spirv_decl.input.push({
- ast::Variable {
- name: shared_id_param,
- align: None,
- v_type: ast::Type::Pointer(ast::ScalarType::B8),
- state_space: ast::StateSpace::Shared,
- array_init: Vec::new(),
- }
- });
- spirv_decl.uses_shared_mem = true;
+ {
+ let mut func_decl = (*func_decl).borrow_mut();
+ func_decl.input_arguments.push({
+ ast::Variable {
+ name: shared_id_param,
+ align: None,
+ v_type: ast::Type::Pointer(ast::ScalarType::B8),
+ state_space: ast::StateSpace::Shared,
+ array_init: Vec::new(),
+ }
+ });
+ }
let statements = replace_uses_of_shared_memory(
new_id,
&extern_shared_decls,
@@ -818,8 +837,8 @@ fn convert_dynamic_shared_memory_usage<'input>(
globals,
body: Some(statements),
import_as,
- spirv_decl,
tuning,
+ uses_shared_mem: true,
})
}
directive => directive,
@@ -830,7 +849,7 @@ fn convert_dynamic_shared_memory_usage<'input>(
fn replace_uses_of_shared_memory<'a>(
new_id: &mut impl FnMut() -> spirv::Word,
extern_shared_decls: &HashMap,
- methods_using_extern_shared: &mut HashSet>,
+ methods_using_extern_shared: &mut HashSet>,
shared_id_param: spirv::Word,
statements: Vec,
) -> Vec {
@@ -841,7 +860,7 @@ fn replace_uses_of_shared_memory<'a>(
// We can safely skip checking call arguments,
// because there's simply no way to pass shared ptr
// without converting it to .b64 first
- if methods_using_extern_shared.contains(&MethodName::Func(call.func)) {
+ if methods_using_extern_shared.contains(&ast::MethodName::Func(call.func)) {
call.param_list.push((
shared_id_param,
ast::Type::Scalar(ast::ScalarType::B8),
@@ -881,13 +900,13 @@ fn replace_uses_of_shared_memory<'a>(
}
fn get_callers_of_extern_shared<'a>(
- methods_using_extern_shared: &mut HashSet>,
- directly_called_by: &MultiHashMap>,
+ methods_using_extern_shared: &mut HashSet>,
+ directly_called_by: &MultiHashMap>,
) {
let direct_uses_of_extern_shared = methods_using_extern_shared
.iter()
.filter_map(|method| {
- if let MethodName::Func(f_id) = method {
+ if let ast::MethodName::Func(f_id) = method {
Some(*f_id)
} else {
None
@@ -900,14 +919,14 @@ fn get_callers_of_extern_shared<'a>(
}
fn get_callers_of_extern_shared_single<'a>(
- methods_using_extern_shared: &mut HashSet>,
- directly_called_by: &MultiHashMap>,
+ methods_using_extern_shared: &mut HashSet>,
+ directly_called_by: &MultiHashMap>,
fn_id: spirv::Word,
) {
if let Some(callers) = directly_called_by.get(&fn_id) {
for caller in callers {
if methods_using_extern_shared.insert(*caller) {
- if let MethodName::Func(caller_fn) = caller {
+ if let ast::MethodName::Func(caller_fn) = caller {
get_callers_of_extern_shared_single(
methods_using_extern_shared,
directly_called_by,
@@ -949,7 +968,7 @@ fn denorm_count_map_update_impl(
// and emit suitable execution mode
fn compute_denorm_information<'input>(
module: &[Directive<'input>],
-) -> HashMap, HashMap> {
+) -> HashMap, HashMap> {
let mut denorm_methods = HashMap::new();
for directive in module {
match directive {
@@ -960,7 +979,7 @@ fn compute_denorm_information<'input>(
..
}) => {
let mut flush_counter = DenormCountMap::new();
- let method_key = MethodName::new(func_decl);
+ let method_key = (**func_decl).borrow().name;
for statement in statements {
match statement {
Statement::Instruction(inst) => {
@@ -1004,21 +1023,6 @@ fn compute_denorm_information<'input>(
.collect()
}
-#[derive(Hash, PartialEq, Eq, Copy, Clone)]
-enum MethodName<'input> {
- Kernel(&'input str),
- Func(spirv::Word),
-}
-
-impl<'input> MethodName<'input> {
- fn new(decl: &ast::MethodDecl<'input, spirv::Word>) -> Self {
- match decl {
- ast::MethodDecl::Kernel { name, .. } => MethodName::Kernel(name),
- ast::MethodDecl::Func(_, id, _) => MethodName::Func(*id),
- }
- }
-}
-
fn emit_builtins(
builder: &mut dr::Builder,
map: &mut TypeWordMap,
@@ -1047,17 +1051,21 @@ fn emit_function_header<'a>(
map: &mut TypeWordMap,
defined_globals: &GlobalStringIdResolver<'a>,
synthetic_globals: &[ast::Variable],
- func_decl: &SpirvMethodDecl<'a>,
- _denorm_information: &HashMap, HashMap>,
+ func_decl: &ast::MethodDeclaration<'a, spirv::Word>,
+ _denorm_information: &HashMap<
+ ast::MethodName<'a, spirv::Word>,
+ HashMap,
+ >,
call_map: &HashMap<&'a str, HashSet>,
direcitves: &[Directive],
kernel_info: &mut HashMap,
+ uses_shared_mem: bool,
) -> Result {
- if let MethodName::Kernel(name) = func_decl.name {
- let input_args = if !func_decl.uses_shared_mem {
- func_decl.input.as_slice()
+ if let ast::MethodName::Kernel(name) = func_decl.name {
+ let input_args = if !uses_shared_mem {
+ func_decl.input_arguments.as_slice()
} else {
- &func_decl.input[0..func_decl.input.len() - 1]
+ &func_decl.input_arguments[0..func_decl.input_arguments.len() - 1]
};
let args_lens = input_args
.iter()
@@ -1067,14 +1075,18 @@ fn emit_function_header<'a>(
name.to_string(),
KernelInfo {
arguments_sizes: args_lens,
- uses_shared_mem: func_decl.uses_shared_mem,
+ uses_shared_mem: uses_shared_mem,
},
);
}
- let (ret_type, func_type) =
- get_function_type(builder, map, &func_decl.input, &func_decl.output);
+ let (ret_type, func_type) = get_function_type(
+ builder,
+ map,
+ &func_decl.input_arguments,
+ &func_decl.return_arguments,
+ );
let fn_id = match func_decl.name {
- MethodName::Kernel(name) => {
+ ast::MethodName::Kernel(name) => {
let fn_id = defined_globals.get_id(name)?;
let mut global_variables = defined_globals
.variables_type_check
@@ -1090,15 +1102,16 @@ fn emit_function_header<'a>(
for directive in direcitves {
match directive {
Directive::Method(Function {
- func_decl: ast::MethodDecl::Func(_, name, _),
- globals,
- ..
+ func_decl, globals, ..
}) => {
- if child_fns.contains(name) {
- for var in globals {
- interface.push(var.name);
+ match (**func_decl).borrow().name {
+ ast::MethodName::Func(name) => {
+ for var in globals {
+ interface.push(var.name);
+ }
}
- }
+ ast::MethodName::Kernel(_) => {}
+ };
}
_ => {}
}
@@ -1107,7 +1120,7 @@ fn emit_function_header<'a>(
builder.entry_point(spirv::ExecutionModel::Kernel, fn_id, name, global_variables);
fn_id
}
- MethodName::Func(name) => name,
+ ast::MethodName::Func(name) => name,
};
builder.begin_function(
ret_type,
@@ -1130,7 +1143,7 @@ fn emit_function_header<'a>(
}
}
*/
- for input in &func_decl.input {
+ for input in &func_decl.input_arguments {
let result_type = map.get_or_add(
builder,
SpirvType::new(input.v_type.clone(), input.state_space),
@@ -1225,9 +1238,10 @@ fn translate_function<'a>(
f: ast::ParsedFunction<'a>,
) -> Result