Simplify typing

This commit is contained in:
Andrzej Janik 2021-05-07 18:22:09 +02:00
parent 7f051ad20e
commit 425edfcdd4
4 changed files with 247 additions and 333 deletions

View file

@ -1,6 +1,6 @@
use half::f16; use half::f16;
use lalrpop_util::{lexer::Token, ParseError}; use lalrpop_util::{lexer::Token, ParseError};
use std::{convert::From, mem, num::ParseFloatError, str::FromStr}; use std::{convert::From, mem, num::ParseFloatError, rc::Rc, str::FromStr};
use std::{marker::PhantomData, num::ParseIntError}; use std::{marker::PhantomData, num::ParseIntError};
#[derive(Debug, thiserror::Error)] #[derive(Debug, thiserror::Error)]
@ -86,19 +86,20 @@ pub enum Directive<'a, P: ArgParams> {
Method(Function<'a, &'a str, Statement<P>>), Method(Function<'a, &'a str, Statement<P>>),
} }
pub enum MethodDecl<'a, ID> { #[derive(Hash, PartialEq, Eq, Copy, Clone)]
Func(Vec<FnArgument<ID>>, ID, Vec<FnArgument<ID>>), pub enum MethodName<'input, ID> {
Kernel { Kernel(&'input str),
name: &'a str, Func(ID),
in_args: Vec<KernelArgument<ID>>,
},
} }
pub type FnArgument<ID> = Variable<ID>; pub struct MethodDeclaration<'input, ID> {
pub type KernelArgument<ID> = Variable<ID>; pub return_arguments: Vec<Variable<ID>>,
pub name: MethodName<'input, ID>,
pub input_arguments: Vec<Variable<ID>>,
}
pub struct Function<'a, ID, S> { pub struct Function<'a, ID, S> {
pub func_directive: MethodDecl<'a, ID>, pub func_directive: MethodDeclaration<'a, ID>,
pub tuning: Vec<TuningDirective>, pub tuning: Vec<TuningDirective>,
pub body: Option<Vec<S>>, pub body: Option<Vec<S>>,
} }

View file

@ -360,7 +360,7 @@ AddressSize = {
Function: ast::Function<'input, &'input str, ast::Statement<ast::ParsedArgParams<'input>>> = { Function: ast::Function<'input, &'input str, ast::Statement<ast::ParsedArgParams<'input>>> = {
LinkingDirectives LinkingDirectives
<func_directive:MethodDecl> <func_directive:MethodDeclaration>
<tuning:TuningDirective*> <tuning:TuningDirective*>
<body:FunctionBody> => ast::Function{<>} <body:FunctionBody> => ast::Function{<>}
}; };
@ -388,19 +388,24 @@ LinkingDirectives: ast::LinkingDirective = {
} }
} }
MethodDecl: ast::MethodDecl<'input, &'input str> = { MethodDeclaration: ast::MethodDeclaration<'input, &'input str> = {
".entry" <name:ExtendedID> <in_args:KernelArguments> => ".entry" <name:ExtendedID> <input_arguments:KernelArguments> => {
ast::MethodDecl::Kernel{ name, in_args }, let return_arguments = Vec::new();
".func" <ret_vals:FnArguments?> <name:ExtendedID> <params:FnArguments> => { let name = ast::MethodName::Kernel(name);
ast::MethodDecl::Func(ret_vals.unwrap_or_else(|| Vec::new()), name, params) ast::MethodDeclaration{ return_arguments, name, input_arguments }
},
".func" <return_arguments:FnArguments?> <name:ExtendedID> <input_arguments:FnArguments> => {
let return_arguments = return_arguments.unwrap_or_else(|| Vec::new());
let name = ast::MethodName::Func(name);
ast::MethodDeclaration{ return_arguments, name, input_arguments }
} }
}; };
KernelArguments: Vec<ast::KernelArgument<&'input str>> = { KernelArguments: Vec<ast::Variable<&'input str>> = {
"(" <args:Comma<KernelInput>> ")" => args "(" <args:Comma<KernelInput>> ")" => args
}; };
FnArguments: Vec<ast::FnArgument<&'input str>> = { FnArguments: Vec<ast::Variable<&'input str>> = {
"(" <args:Comma<FnInput>> ")" => args "(" <args:Comma<FnInput>> ")" => args
}; };

File diff suppressed because it is too large Load diff

View file

@ -191,7 +191,10 @@ unsafe fn record_module_image(module: CUmodule, image: &str) {
unsafe fn try_dump_module_image(image: &str) -> Result<(), Box<dyn Error>> { unsafe fn try_dump_module_image(image: &str) -> Result<(), Box<dyn Error>> {
let mut dump_path = get_dump_dir()?; let mut dump_path = get_dump_dir()?;
dump_path.push(format!("module_{:04}.ptx", MODULES.as_ref().unwrap().len() - 1)); dump_path.push(format!(
"module_{:04}.ptx",
MODULES.as_ref().unwrap().len() - 1
));
let mut file = File::create(dump_path)?; let mut file = File::create(dump_path)?;
file.write_all(image.as_bytes())?; file.write_all(image.as_bytes())?;
Ok(()) Ok(())
@ -217,10 +220,15 @@ unsafe fn to_str<T>(image: *const T) -> Option<&'static str> {
fn directive_to_kernel(dir: &ast::Directive<ast::ParsedArgParams>) -> Option<(String, Vec<usize>)> { fn directive_to_kernel(dir: &ast::Directive<ast::ParsedArgParams>) -> Option<(String, Vec<usize>)> {
match dir { match dir {
ast::Directive::Method(ast::Function { ast::Directive::Method(ast::Function {
func_directive: ast::MethodDecl::Kernel { name, in_args }, func_directive:
ast::MethodDeclaration {
name: ast::MethodName::Kernel(name),
input_arguments,
..
},
.. ..
}) => { }) => {
let arg_sizes = in_args let arg_sizes = input_arguments
.iter() .iter()
.map(|arg| ast::Type::from(arg.v_type.clone()).size_of()) .map(|arg| ast::Type::from(arg.v_type.clone()).size_of())
.collect(); .collect();