mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-04-28 21:47:57 +03:00
Implement pass to handle .extern .shared and add parsing code for it
This commit is contained in:
parent
27d25865af
commit
2b3ecc99e3
19 changed files with 877 additions and 123 deletions
|
@ -14,6 +14,7 @@ spirv_headers = "~1.4.2"
|
||||||
quick-error = "1.2"
|
quick-error = "1.2"
|
||||||
bit-vec = "0.6"
|
bit-vec = "0.6"
|
||||||
half ="1.6"
|
half ="1.6"
|
||||||
|
bitflags = "1.2"
|
||||||
|
|
||||||
[build-dependencies.lalrpop]
|
[build-dependencies.lalrpop]
|
||||||
version = "0.19"
|
version = "0.19"
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
use std::{convert::From, mem, num::ParseFloatError, str::FromStr};
|
use std::convert::TryInto;
|
||||||
|
use std::{convert::From, convert::TryFrom, mem, num::ParseFloatError, str::FromStr};
|
||||||
use std::{marker::PhantomData, num::ParseIntError};
|
use std::{marker::PhantomData, num::ParseIntError};
|
||||||
|
|
||||||
use half::f16;
|
use half::f16;
|
||||||
|
@ -22,6 +23,8 @@ quick_error! {
|
||||||
WrongVectorElement {}
|
WrongVectorElement {}
|
||||||
MultiArrayVariable {}
|
MultiArrayVariable {}
|
||||||
ZeroDimensionArray {}
|
ZeroDimensionArray {}
|
||||||
|
ArrayInitalizer {}
|
||||||
|
NonExternPointer {}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -78,6 +81,21 @@ macro_rules! sub_type {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl std::convert::TryFrom<Type> for $type_name {
|
||||||
|
type Error = ();
|
||||||
|
|
||||||
|
#[allow(non_snake_case)]
|
||||||
|
#[allow(unreachable_patterns)]
|
||||||
|
fn try_from(t: Type) -> Result<Self, Self::Error> {
|
||||||
|
match t {
|
||||||
|
$(
|
||||||
|
Type::$variant ( $($field_type),+ ) => Ok($type_name::$variant ( $($field_type.try_into().map_err(|_| ())? ),+ )),
|
||||||
|
)+
|
||||||
|
_ => Err(()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -98,14 +116,39 @@ sub_type! {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl TryFrom<VariableGlobalType> for VariableLocalType {
|
||||||
|
type Error = PtxError;
|
||||||
|
|
||||||
|
fn try_from(value: VariableGlobalType) -> Result<Self, Self::Error> {
|
||||||
|
match value {
|
||||||
|
VariableGlobalType::Scalar(t) => Ok(VariableLocalType::Scalar(t)),
|
||||||
|
VariableGlobalType::Vector(t, len) => Ok(VariableLocalType::Vector(t, len)),
|
||||||
|
VariableGlobalType::Array(t, len) => Ok(VariableLocalType::Array(t, len)),
|
||||||
|
VariableGlobalType::Pointer(_, _) => Err(PtxError::ZeroDimensionArray),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
sub_type! {
|
||||||
|
VariableGlobalType {
|
||||||
|
Scalar(SizedScalarType),
|
||||||
|
Vector(SizedScalarType, u8),
|
||||||
|
Array(SizedScalarType, VecU32),
|
||||||
|
Pointer(SizedScalarType, PointerStateSpace),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// For some weird reson this is illegal:
|
// For some weird reson this is illegal:
|
||||||
// .param .f16x2 foobar;
|
// .param .f16x2 foobar;
|
||||||
// but this is legal:
|
// but this is legal:
|
||||||
// .param .f16x2 foobar[1];
|
// .param .f16x2 foobar[1];
|
||||||
|
// even more interestingly this is legal, but only in .func (not in .entry):
|
||||||
|
// .param .b32 foobar[]
|
||||||
sub_type! {
|
sub_type! {
|
||||||
VariableParamType {
|
VariableParamType {
|
||||||
Scalar(ParamScalarType),
|
Scalar(ParamScalarType),
|
||||||
Array(SizedScalarType, VecU32),
|
Array(SizedScalarType, VecU32),
|
||||||
|
Pointer(SizedScalarType, PointerStateSpace),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -193,7 +236,7 @@ pub enum MethodDecl<'a, ID> {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub type FnArgument<ID> = Variable<FnArgumentType, ID>;
|
pub type FnArgument<ID> = Variable<FnArgumentType, ID>;
|
||||||
pub type KernelArgument<ID> = Variable<VariableParamType, ID>;
|
pub type KernelArgument<ID> = Variable<KernelArgumentType, ID>;
|
||||||
|
|
||||||
pub struct Function<'a, ID, S> {
|
pub struct Function<'a, ID, S> {
|
||||||
pub func_directive: MethodDecl<'a, ID>,
|
pub func_directive: MethodDecl<'a, ID>,
|
||||||
|
@ -206,6 +249,12 @@ pub type ParsedFunction<'a> = Function<'a, &'a str, Statement<ParsedArgParams<'a
|
||||||
pub enum FnArgumentType {
|
pub enum FnArgumentType {
|
||||||
Reg(VariableRegType),
|
Reg(VariableRegType),
|
||||||
Param(VariableParamType),
|
Param(VariableParamType),
|
||||||
|
Shared,
|
||||||
|
}
|
||||||
|
#[derive(PartialEq, Eq, Clone)]
|
||||||
|
pub enum KernelArgumentType {
|
||||||
|
Normal(VariableParamType),
|
||||||
|
Shared,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<FnArgumentType> for Type {
|
impl From<FnArgumentType> for Type {
|
||||||
|
@ -213,15 +262,25 @@ impl From<FnArgumentType> for Type {
|
||||||
match t {
|
match t {
|
||||||
FnArgumentType::Reg(x) => x.into(),
|
FnArgumentType::Reg(x) => x.into(),
|
||||||
FnArgumentType::Param(x) => x.into(),
|
FnArgumentType::Param(x) => x.into(),
|
||||||
|
FnArgumentType::Shared => Type::Scalar(ScalarType::B64),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(PartialEq, Eq, Hash, Clone)]
|
#[derive(PartialEq, Eq, Clone, Copy)]
|
||||||
|
pub enum PointerStateSpace {
|
||||||
|
Global,
|
||||||
|
Const,
|
||||||
|
Shared,
|
||||||
|
Param,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(PartialEq, Eq, Clone)]
|
||||||
pub enum Type {
|
pub enum Type {
|
||||||
Scalar(ScalarType),
|
Scalar(ScalarType),
|
||||||
Vector(ScalarType, u8),
|
Vector(ScalarType, u8),
|
||||||
Array(ScalarType, Vec<u32>),
|
Array(ScalarType, Vec<u32>),
|
||||||
|
Pointer(ScalarType, PointerStateSpace),
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(PartialEq, Eq, Hash, Clone, Copy)]
|
#[derive(PartialEq, Eq, Hash, Clone, Copy)]
|
||||||
|
@ -343,7 +402,8 @@ pub enum VariableType {
|
||||||
Reg(VariableRegType),
|
Reg(VariableRegType),
|
||||||
Local(VariableLocalType),
|
Local(VariableLocalType),
|
||||||
Param(VariableParamType),
|
Param(VariableParamType),
|
||||||
Global(VariableLocalType),
|
Global(VariableGlobalType),
|
||||||
|
Shared(VariableGlobalType),
|
||||||
}
|
}
|
||||||
|
|
||||||
impl VariableType {
|
impl VariableType {
|
||||||
|
@ -353,6 +413,7 @@ impl VariableType {
|
||||||
VariableType::Local(t) => (StateSpace::Local, t.clone().into()),
|
VariableType::Local(t) => (StateSpace::Local, t.clone().into()),
|
||||||
VariableType::Param(t) => (StateSpace::Param, t.clone().into()),
|
VariableType::Param(t) => (StateSpace::Param, t.clone().into()),
|
||||||
VariableType::Global(t) => (StateSpace::Global, t.clone().into()),
|
VariableType::Global(t) => (StateSpace::Global, t.clone().into()),
|
||||||
|
VariableType::Shared(t) => (StateSpace::Shared, t.clone().into()),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -364,6 +425,7 @@ impl From<VariableType> for Type {
|
||||||
VariableType::Local(t) => t.into(),
|
VariableType::Local(t) => t.into(),
|
||||||
VariableType::Param(t) => t.into(),
|
VariableType::Param(t) => t.into(),
|
||||||
VariableType::Global(t) => t.into(),
|
VariableType::Global(t) => t.into(),
|
||||||
|
VariableType::Shared(t) => t.into(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1039,6 +1101,20 @@ impl<'a> NumsOrArrays<'a> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub enum ArrayOrPointer {
|
||||||
|
Array { dimensions: Vec<u32>, init: Vec<u8> },
|
||||||
|
Pointer,
|
||||||
|
}
|
||||||
|
|
||||||
|
bitflags! {
|
||||||
|
pub struct LinkingDirective: u8 {
|
||||||
|
const NONE = 0b000;
|
||||||
|
const EXTERN = 0b001;
|
||||||
|
const VISIBLE = 0b10;
|
||||||
|
const WEAK = 0b100;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
|
@ -17,6 +17,9 @@ extern crate spirv_headers as spirv;
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
extern crate spirv_tools_sys as spirv_tools;
|
extern crate spirv_tools_sys as spirv_tools;
|
||||||
|
|
||||||
|
#[macro_use]
|
||||||
|
extern crate bitflags;
|
||||||
|
|
||||||
lalrpop_mod!(
|
lalrpop_mod!(
|
||||||
#[allow(warnings)]
|
#[allow(warnings)]
|
||||||
ptx
|
ptx
|
||||||
|
|
|
@ -3,6 +3,7 @@ use crate::ast::UnwrapWithVec;
|
||||||
use crate::{without_none, vector_index};
|
use crate::{without_none, vector_index};
|
||||||
|
|
||||||
use lalrpop_util::ParseError;
|
use lalrpop_util::ParseError;
|
||||||
|
use std::convert::TryInto;
|
||||||
|
|
||||||
grammar<'a>(errors: &mut Vec<ast::PtxError>);
|
grammar<'a>(errors: &mut Vec<ast::PtxError>);
|
||||||
|
|
||||||
|
@ -210,7 +211,7 @@ Directive: Option<ast::Directive<'input, ast::ParsedArgParams<'input>>> = {
|
||||||
<f:Function> => Some(ast::Directive::Method(f)),
|
<f:Function> => Some(ast::Directive::Method(f)),
|
||||||
File => None,
|
File => None,
|
||||||
Section => None,
|
Section => None,
|
||||||
<v:GlobalVariable> ";" => Some(ast::Directive::Variable(v)),
|
<v:ModuleVariable> ";" => Some(ast::Directive::Variable(v)),
|
||||||
};
|
};
|
||||||
|
|
||||||
AddressSize = {
|
AddressSize = {
|
||||||
|
@ -218,17 +219,23 @@ AddressSize = {
|
||||||
};
|
};
|
||||||
|
|
||||||
Function: ast::Function<'input, &'input str, ast::Statement<ast::ParsedArgParams<'input>>> = {
|
Function: ast::Function<'input, &'input str, ast::Statement<ast::ParsedArgParams<'input>>> = {
|
||||||
LinkingDirective*
|
LinkingDirectives
|
||||||
<func_directive:MethodDecl>
|
<func_directive:MethodDecl>
|
||||||
<body:FunctionBody> => ast::Function{<>}
|
<body:FunctionBody> => ast::Function{<>}
|
||||||
};
|
};
|
||||||
|
|
||||||
LinkingDirective = {
|
LinkingDirective: ast::LinkingDirective = {
|
||||||
".extern",
|
".extern" => ast::LinkingDirective::EXTERN,
|
||||||
".visible",
|
".visible" => ast::LinkingDirective::VISIBLE,
|
||||||
".weak"
|
".weak" => ast::LinkingDirective::WEAK,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
LinkingDirectives: ast::LinkingDirective = {
|
||||||
|
<ldirs:LinkingDirective*> => {
|
||||||
|
ldirs.into_iter().fold(ast::LinkingDirective::NONE, |x, y| x | y)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
MethodDecl: ast::MethodDecl<'input, &'input str> = {
|
MethodDecl: ast::MethodDecl<'input, &'input str> = {
|
||||||
".entry" <name:ExtendedID> <params:KernelArguments> => ast::MethodDecl::Kernel(name, params),
|
".entry" <name:ExtendedID> <params:KernelArguments> => ast::MethodDecl::Kernel(name, params),
|
||||||
".func" <ret_vals:FnArguments?> <name:ExtendedID> <params:FnArguments> => {
|
".func" <ret_vals:FnArguments?> <name:ExtendedID> <params:FnArguments> => {
|
||||||
|
@ -244,10 +251,15 @@ FnArguments: Vec<ast::FnArgument<&'input str>> = {
|
||||||
"(" <args:Comma<FnInput>> ")" => args
|
"(" <args:Comma<FnInput>> ")" => args
|
||||||
};
|
};
|
||||||
|
|
||||||
KernelInput: ast::Variable<ast::VariableParamType, &'input str> = {
|
KernelInput: ast::Variable<ast::KernelArgumentType, &'input str> = {
|
||||||
<v:ParamDeclaration> => {
|
<v:ParamDeclaration> => {
|
||||||
let (align, v_type, name) = v;
|
let (align, v_type, name) = v;
|
||||||
ast::Variable{ align, v_type, name, array_init: Vec::new() }
|
ast::Variable {
|
||||||
|
align,
|
||||||
|
v_type: ast::KernelArgumentType::Normal(v_type),
|
||||||
|
name,
|
||||||
|
array_init: Vec::new()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -357,69 +369,120 @@ Variable: ast::Variable<ast::VariableType, &'input str> = {
|
||||||
};
|
};
|
||||||
|
|
||||||
RegVariable: (Option<u32>, ast::VariableRegType, &'input str) = {
|
RegVariable: (Option<u32>, ast::VariableRegType, &'input str) = {
|
||||||
".reg" <align:Align?> <t:ScalarType> <name:ExtendedID> => {
|
".reg" <var:VariableScalar<ScalarType>> => {
|
||||||
|
let (align, t, name) = var;
|
||||||
let v_type = ast::VariableRegType::Scalar(t);
|
let v_type = ast::VariableRegType::Scalar(t);
|
||||||
(align, v_type, name)
|
(align, v_type, name)
|
||||||
},
|
},
|
||||||
".reg" <align:Align?> <v_len:VectorPrefix> <t:SizedScalarType> <name:ExtendedID> => {
|
".reg" <var:VariableVector<SizedScalarType>> => {
|
||||||
|
let (align, v_len, t, name) = var;
|
||||||
let v_type = ast::VariableRegType::Vector(t, v_len);
|
let v_type = ast::VariableRegType::Vector(t, v_len);
|
||||||
(align, v_type, name)
|
(align, v_type, name)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
LocalVariable: ast::Variable<ast::VariableType, &'input str> = {
|
LocalVariable: ast::Variable<ast::VariableType, &'input str> = {
|
||||||
".local" <def:LocalVariableDefinition> => {
|
".local" <var:VariableScalar<SizedScalarType>> => {
|
||||||
let (align, array_init, v_type, name) = def;
|
let (align, t, name) = var;
|
||||||
ast::Variable { align, v_type: ast::VariableType::Local(v_type), name, array_init }
|
let v_type = ast::VariableType::Local(ast::VariableLocalType::Scalar(t));
|
||||||
|
ast::Variable { align, v_type, name, array_init: Vec::new() }
|
||||||
|
},
|
||||||
|
".local" <var:VariableVector<SizedScalarType>> => {
|
||||||
|
let (align, v_len, t, name) = var;
|
||||||
|
let v_type = ast::VariableType::Local(ast::VariableLocalType::Vector(t, v_len));
|
||||||
|
ast::Variable { align, v_type, name, array_init: Vec::new() }
|
||||||
|
},
|
||||||
|
".local" <var:VariableArrayOrPointer<SizedScalarType>> =>? {
|
||||||
|
let (align, t, name, arr_or_ptr) = var;
|
||||||
|
let (v_type, array_init) = match arr_or_ptr {
|
||||||
|
ast::ArrayOrPointer::Array { dimensions, init } => {
|
||||||
|
(ast::VariableLocalType::Array(t, dimensions), init)
|
||||||
|
}
|
||||||
|
ast::ArrayOrPointer::Pointer => {
|
||||||
|
return Err(ParseError::User { error: ast::PtxError::ZeroDimensionArray });
|
||||||
|
}
|
||||||
|
};
|
||||||
|
Ok(ast::Variable { align, v_type: ast::VariableType::Local(v_type), name, array_init })
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
GlobalVariable: ast::Variable<ast::VariableType, &'input str> = {
|
ModuleVariable: ast::Variable<ast::VariableType, &'input str> = {
|
||||||
".global" <def:LocalVariableDefinition> => {
|
LinkingDirectives ".global" <def:GlobalVariableDefinitionNoArray> => {
|
||||||
let (align, array_init, v_type, name) = def;
|
let (align, v_type, name, array_init) = def;
|
||||||
ast::Variable { align, v_type: ast::VariableType::Global(v_type), name, array_init }
|
ast::Variable { align, v_type: ast::VariableType::Global(v_type), name, array_init }
|
||||||
|
},
|
||||||
|
LinkingDirectives ".shared" <def:GlobalVariableDefinitionNoArray> => {
|
||||||
|
let (align, v_type, name, array_init) = def;
|
||||||
|
ast::Variable { align, v_type: ast::VariableType::Shared(v_type), name, array_init: Vec::new() }
|
||||||
|
},
|
||||||
|
<ldirs:LinkingDirectives> <space:Or<".global", ".shared">> <var:VariableArrayOrPointer<SizedScalarType>> =>? {
|
||||||
|
let (align, t, name, arr_or_ptr) = var;
|
||||||
|
let (v_type, array_init) = match arr_or_ptr {
|
||||||
|
ast::ArrayOrPointer::Array { dimensions, init } => {
|
||||||
|
if space == ".global" {
|
||||||
|
(ast::VariableType::Global(ast::VariableGlobalType::Array(t, dimensions)), init)
|
||||||
|
} else {
|
||||||
|
(ast::VariableType::Shared(ast::VariableGlobalType::Array(t, dimensions)), init)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ast::ArrayOrPointer::Pointer => {
|
||||||
|
if !ldirs.contains(ast::LinkingDirective::EXTERN) {
|
||||||
|
return Err(ParseError::User { error: ast::PtxError::NonExternPointer });
|
||||||
|
}
|
||||||
|
if space == ".global" {
|
||||||
|
(ast::VariableType::Global(ast::VariableGlobalType::Pointer(t, ast::PointerStateSpace::Global)), Vec::new())
|
||||||
|
} else {
|
||||||
|
(ast::VariableType::Shared(ast::VariableGlobalType::Pointer(t, ast::PointerStateSpace::Shared)), Vec::new())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
Ok(ast::Variable{ align, array_init, v_type, name })
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parameter-state-space
|
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parameter-state-space
|
||||||
ParamVariable: (Option<u32>, Vec<u8>, ast::VariableParamType, &'input str) = {
|
ParamVariable: (Option<u32>, Vec<u8>, ast::VariableParamType, &'input str) = {
|
||||||
".param" <align:Align?> <t:ParamScalarType> <name:ExtendedID> => {
|
".param" <var:VariableScalar<ParamScalarType>> => {
|
||||||
|
let (align, t, name) = var;
|
||||||
let v_type = ast::VariableParamType::Scalar(t);
|
let v_type = ast::VariableParamType::Scalar(t);
|
||||||
(align, Vec::new(), v_type, name)
|
(align, Vec::new(), v_type, name)
|
||||||
},
|
},
|
||||||
".param" <align:Align?> <arr:ArrayDefinition> => {
|
".param" <var:VariableArrayOrPointer<SizedScalarType>> => {
|
||||||
let (array_init, name, (t, dimensions)) = arr;
|
let (align, t, name, arr_or_ptr) = var;
|
||||||
let v_type = ast::VariableParamType::Array(t, dimensions);
|
let (v_type, array_init) = match arr_or_ptr {
|
||||||
|
ast::ArrayOrPointer::Array { dimensions, init } => {
|
||||||
|
(ast::VariableParamType::Array(t, dimensions), init)
|
||||||
|
}
|
||||||
|
ast::ArrayOrPointer::Pointer => {
|
||||||
|
(ast::VariableParamType::Pointer(t, ast::PointerStateSpace::Param), Vec::new())
|
||||||
|
}
|
||||||
|
};
|
||||||
(align, array_init, v_type, name)
|
(align, array_init, v_type, name)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
ParamDeclaration: (Option<u32>, ast::VariableParamType, &'input str) = {
|
ParamDeclaration: (Option<u32>, ast::VariableParamType, &'input str) = {
|
||||||
".param" <align:Align?> <t:ParamScalarType> <name:ExtendedID> => {
|
<var:ParamVariable> =>? {
|
||||||
let v_type = ast::VariableParamType::Scalar(t);
|
let (align, array_init, v_type, name) = var;
|
||||||
(align, v_type, name)
|
if array_init.len() > 0 {
|
||||||
},
|
Err(ParseError::User { error: ast::PtxError::ArrayInitalizer })
|
||||||
".param" <align:Align?> <arr:ArrayDeclaration> => {
|
} else {
|
||||||
let (name, (t, dimensions)) = arr;
|
Ok((align, v_type, name))
|
||||||
let v_type = ast::VariableParamType::Array(t, dimensions);
|
}
|
||||||
(align, v_type, name)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
LocalVariableDefinition: (Option<u32>, Vec<u8>, ast::VariableLocalType, &'input str) = {
|
GlobalVariableDefinitionNoArray: (Option<u32>, ast::VariableGlobalType, &'input str, Vec<u8>) = {
|
||||||
<align:Align?> <t:SizedScalarType> <name:ExtendedID> => {
|
<scalar:VariableScalar<SizedScalarType>> => {
|
||||||
let v_type = ast::VariableLocalType::Scalar(t);
|
let (align, t, name) = scalar;
|
||||||
(align, Vec::new(), v_type, name)
|
let v_type = ast::VariableGlobalType::Scalar(t);
|
||||||
|
(align, v_type, name, Vec::new())
|
||||||
},
|
},
|
||||||
<align:Align?> <v_len:VectorPrefix> <t:SizedScalarType> <name:ExtendedID> => {
|
<var:VariableVector<SizedScalarType>> => {
|
||||||
let v_type = ast::VariableLocalType::Vector(t, v_len);
|
let (align, v_len, t, name) = var;
|
||||||
(align, Vec::new(), v_type, name)
|
let v_type = ast::VariableGlobalType::Vector(t, v_len);
|
||||||
|
(align, v_type, name, Vec::new())
|
||||||
},
|
},
|
||||||
<align:Align?> <arr:ArrayDefinition> => {
|
|
||||||
let (array_init, name, (t, dimensions)) = arr;
|
|
||||||
let v_type = ast::VariableLocalType::Array(t, dimensions);
|
|
||||||
(align, array_init, v_type, name)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[inline]
|
#[inline]
|
||||||
|
@ -461,60 +524,6 @@ ParamScalarType: ast::ParamScalarType = {
|
||||||
".f64" => ast::ParamScalarType::F64,
|
".f64" => ast::ParamScalarType::F64,
|
||||||
}
|
}
|
||||||
|
|
||||||
ArrayDefinition: (Vec<u8>, &'input str, (ast::SizedScalarType, Vec<u32>)) = {
|
|
||||||
<typ:SizedScalarType> <name:ExtendedID> <dims:ArrayDimensions> <init:ArrayInitializer?> =>? {
|
|
||||||
let mut dims = dims;
|
|
||||||
let array_init = init.unwrap_or(ast::NumsOrArrays::Nums(Vec::new())).to_vec(typ, &mut dims)?;
|
|
||||||
Ok((
|
|
||||||
array_init,
|
|
||||||
name,
|
|
||||||
(typ, dims)
|
|
||||||
))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
ArrayDeclaration: (&'input str, (ast::SizedScalarType, Vec<u32>)) = {
|
|
||||||
<typ:SizedScalarType> <name:ExtendedID> <dims:ArrayDimension+> =>? {
|
|
||||||
let dims = dims.into_iter().map(|x| if x > 0 { Ok(x) } else { Err(ParseError::User { error: ast::PtxError::ZeroDimensionArray }) }).collect::<Result<_,_>>()?;
|
|
||||||
Ok((name, (typ, dims)))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// [0] and [] are treated the same
|
|
||||||
ArrayDimensions: Vec<u32> = {
|
|
||||||
ArrayEmptyDimension => vec![0u32],
|
|
||||||
ArrayEmptyDimension <dims:ArrayDimension+> => {
|
|
||||||
let mut dims = dims;
|
|
||||||
let mut result = vec![0u32];
|
|
||||||
result.append(&mut dims);
|
|
||||||
result
|
|
||||||
},
|
|
||||||
<dims:ArrayDimension+> => dims
|
|
||||||
}
|
|
||||||
|
|
||||||
ArrayEmptyDimension = {
|
|
||||||
"[" "]"
|
|
||||||
}
|
|
||||||
|
|
||||||
ArrayDimension: u32 = {
|
|
||||||
"[" <n:Num> "]" =>? {
|
|
||||||
str::parse::<u32>(n).map_err(|e| ParseError::User { error: ast::PtxError::ParseInt(e) })
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
ArrayInitializer: ast::NumsOrArrays<'input> = {
|
|
||||||
"=" <nums:NumsOrArraysBracket> => nums
|
|
||||||
}
|
|
||||||
|
|
||||||
NumsOrArraysBracket: ast::NumsOrArrays<'input> = {
|
|
||||||
"{" <nums:NumsOrArrays> "}" => nums
|
|
||||||
}
|
|
||||||
|
|
||||||
NumsOrArrays: ast::NumsOrArrays<'input> = {
|
|
||||||
<n:Comma<NumsOrArraysBracket>> => ast::NumsOrArrays::Arrays(n),
|
|
||||||
<n:CommaNonEmpty<Num>> => ast::NumsOrArrays::Nums(n),
|
|
||||||
}
|
|
||||||
|
|
||||||
Instruction: ast::Instruction<ast::ParsedArgParams<'input>> = {
|
Instruction: ast::Instruction<ast::ParsedArgParams<'input>> = {
|
||||||
InstLd,
|
InstLd,
|
||||||
InstMov,
|
InstMov,
|
||||||
|
@ -1311,6 +1320,73 @@ BitType = {
|
||||||
".b8", ".b16", ".b32", ".b64"
|
".b8", ".b16", ".b32", ".b64"
|
||||||
};
|
};
|
||||||
|
|
||||||
|
VariableScalar<T>: (Option<u32>, T, &'input str) = {
|
||||||
|
<align:Align?> <v_type:T> <name:ExtendedID> => {
|
||||||
|
(align, v_type, name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
VariableVector<T>: (Option<u32>, u8, T, &'input str) = {
|
||||||
|
<align:Align?> <v_len:VectorPrefix> <v_type:T> <name:ExtendedID> => {
|
||||||
|
(align, v_len, v_type, name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// empty dimensions [0] means it's a pointer
|
||||||
|
VariableArrayOrPointer<T>: (Option<u32>, T, &'input str, ast::ArrayOrPointer) = {
|
||||||
|
<align:Align?> <typ:SizedScalarType> <name:ExtendedID> <dims:ArrayDimensions> <init:ArrayInitializer?> =>? {
|
||||||
|
let mut dims = dims;
|
||||||
|
let array_init = match init {
|
||||||
|
Some(init) => {
|
||||||
|
let init_vec = init.to_vec(typ, &mut dims)?;
|
||||||
|
ast::ArrayOrPointer::Array { dimensions: dims, init: init_vec }
|
||||||
|
}
|
||||||
|
None => {
|
||||||
|
if dims.len() > 1 && dims.contains(&0) {
|
||||||
|
return Err(ParseError::User { error: ast::PtxError::ZeroDimensionArray })
|
||||||
|
}
|
||||||
|
ast::ArrayOrPointer::Pointer
|
||||||
|
}
|
||||||
|
};
|
||||||
|
Ok((align, typ, name, array_init))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// [0] and [] are treated the same
|
||||||
|
ArrayDimensions: Vec<u32> = {
|
||||||
|
ArrayEmptyDimension => vec![0u32],
|
||||||
|
ArrayEmptyDimension <dims:ArrayDimension+> => {
|
||||||
|
let mut dims = dims;
|
||||||
|
let mut result = vec![0u32];
|
||||||
|
result.append(&mut dims);
|
||||||
|
result
|
||||||
|
},
|
||||||
|
<dims:ArrayDimension+> => dims
|
||||||
|
}
|
||||||
|
|
||||||
|
ArrayEmptyDimension = {
|
||||||
|
"[" "]"
|
||||||
|
}
|
||||||
|
|
||||||
|
ArrayDimension: u32 = {
|
||||||
|
"[" <n:Num> "]" =>? {
|
||||||
|
str::parse::<u32>(n).map_err(|e| ParseError::User { error: ast::PtxError::ParseInt(e) })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ArrayInitializer: ast::NumsOrArrays<'input> = {
|
||||||
|
"=" <nums:NumsOrArraysBracket> => nums
|
||||||
|
}
|
||||||
|
|
||||||
|
NumsOrArraysBracket: ast::NumsOrArrays<'input> = {
|
||||||
|
"{" <nums:NumsOrArrays> "}" => nums
|
||||||
|
}
|
||||||
|
|
||||||
|
NumsOrArrays: ast::NumsOrArrays<'input> = {
|
||||||
|
<n:Comma<NumsOrArraysBracket>> => ast::NumsOrArrays::Arrays(n),
|
||||||
|
<n:CommaNonEmpty<Num>> => ast::NumsOrArrays::Nums(n),
|
||||||
|
}
|
||||||
|
|
||||||
Comma<T>: Vec<T> = {
|
Comma<T>: Vec<T> = {
|
||||||
<v:(<T> ",")*> <e:T?> => match e {
|
<v:(<T> ",")*> <e:T?> => match e {
|
||||||
None => v,
|
None => v,
|
||||||
|
@ -1329,3 +1405,9 @@ CommaNonEmpty<T>: Vec<T> = {
|
||||||
v
|
v
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
Or<T1, T2>: T1 = {
|
||||||
|
T1,
|
||||||
|
T2
|
||||||
|
}
|
5
ptx/src/test/spirv_build/global_extern_array.ptx
Normal file
5
ptx/src/test/spirv_build/global_extern_array.ptx
Normal file
|
@ -0,0 +1,5 @@
|
||||||
|
.version 6.5
|
||||||
|
.target sm_30
|
||||||
|
.address_size 64
|
||||||
|
|
||||||
|
.extern .global .b32 foobar [1];
|
10
ptx/src/test/spirv_build/param_func_array_0.ptx
Normal file
10
ptx/src/test/spirv_build/param_func_array_0.ptx
Normal file
|
@ -0,0 +1,10 @@
|
||||||
|
.version 6.5
|
||||||
|
.target sm_30
|
||||||
|
.address_size 64
|
||||||
|
|
||||||
|
.visible .func foobar(
|
||||||
|
.param .b32 foobar[]
|
||||||
|
)
|
||||||
|
{
|
||||||
|
ret;
|
||||||
|
}
|
5
ptx/src/test/spirv_fail/const_ptr.ptx
Normal file
5
ptx/src/test/spirv_fail/const_ptr.ptx
Normal file
|
@ -0,0 +1,5 @@
|
||||||
|
.version 6.5
|
||||||
|
.target sm_30
|
||||||
|
.address_size 64
|
||||||
|
|
||||||
|
.const .b32 foobar [];
|
5
ptx/src/test/spirv_fail/global_ptr.ptx
Normal file
5
ptx/src/test/spirv_fail/global_ptr.ptx
Normal file
|
@ -0,0 +1,5 @@
|
||||||
|
.version 6.5
|
||||||
|
.target sm_30
|
||||||
|
.address_size 64
|
||||||
|
|
||||||
|
.global .b32 foobar [];
|
12
ptx/src/test/spirv_fail/local_ptr.txt
Normal file
12
ptx/src/test/spirv_fail/local_ptr.txt
Normal file
|
@ -0,0 +1,12 @@
|
||||||
|
.version 6.5
|
||||||
|
.target sm_30
|
||||||
|
.address_size 64
|
||||||
|
|
||||||
|
|
||||||
|
.visible .entry func()
|
||||||
|
{
|
||||||
|
|
||||||
|
.local .b32 foobar [1];
|
||||||
|
|
||||||
|
ret;
|
||||||
|
}
|
10
ptx/src/test/spirv_fail/param_entry_array_0.ptx
Normal file
10
ptx/src/test/spirv_fail/param_entry_array_0.ptx
Normal file
|
@ -0,0 +1,10 @@
|
||||||
|
.version 6.5
|
||||||
|
.target sm_30
|
||||||
|
.address_size 64
|
||||||
|
|
||||||
|
.visible .entry foobar(
|
||||||
|
.param .b32 foobar[]
|
||||||
|
)
|
||||||
|
{
|
||||||
|
ret;
|
||||||
|
}
|
10
ptx/src/test/spirv_fail/param_vector.ptx
Normal file
10
ptx/src/test/spirv_fail/param_vector.ptx
Normal file
|
@ -0,0 +1,10 @@
|
||||||
|
.version 6.5
|
||||||
|
.target sm_30
|
||||||
|
.address_size 64
|
||||||
|
|
||||||
|
.visible .func foobar(
|
||||||
|
.param .b32 .v2 foobar
|
||||||
|
)
|
||||||
|
{
|
||||||
|
ret;
|
||||||
|
}
|
5
ptx/src/test/spirv_fail/shared_ptr.ptx
Normal file
5
ptx/src/test/spirv_fail/shared_ptr.ptx
Normal file
|
@ -0,0 +1,5 @@
|
||||||
|
.version 6.5
|
||||||
|
.target sm_30
|
||||||
|
.address_size 64
|
||||||
|
|
||||||
|
extern .shared .b32 foobar [];
|
13
ptx/src/test/spirv_fail/shared_ptr2.ptx
Normal file
13
ptx/src/test/spirv_fail/shared_ptr2.ptx
Normal file
|
@ -0,0 +1,13 @@
|
||||||
|
.version 6.5
|
||||||
|
.target sm_30
|
||||||
|
.address_size 64
|
||||||
|
|
||||||
|
.extern .shared .b32 foobar1 [];
|
||||||
|
|
||||||
|
.visible .func _Z4dupaPf(
|
||||||
|
.param .b64 _Z4dupaPf_param_0
|
||||||
|
)
|
||||||
|
{
|
||||||
|
.shared .b32 foobar2 [];
|
||||||
|
ret;
|
||||||
|
}
|
24
ptx/src/test/spirv_run/extern_shared.ptx
Normal file
24
ptx/src/test/spirv_run/extern_shared.ptx
Normal file
|
@ -0,0 +1,24 @@
|
||||||
|
.version 6.5
|
||||||
|
.target sm_30
|
||||||
|
.address_size 64
|
||||||
|
|
||||||
|
.extern .shared .b32 shared_mem [];
|
||||||
|
|
||||||
|
.visible .entry extern_shared(
|
||||||
|
.param .u64 input,
|
||||||
|
.param .u64 output
|
||||||
|
)
|
||||||
|
{
|
||||||
|
.reg .u64 in_addr;
|
||||||
|
.reg .u64 out_addr;
|
||||||
|
.reg .u64 temp;
|
||||||
|
|
||||||
|
ld.param.u64 in_addr, [input];
|
||||||
|
ld.param.u64 out_addr, [output];
|
||||||
|
|
||||||
|
ld.global.u64 temp, [in_addr];
|
||||||
|
st.shared.u64 [shared_mem], temp;
|
||||||
|
ld.shared.u64 temp, [shared_mem];
|
||||||
|
st.global.u64 [out_addr], temp;
|
||||||
|
ret;
|
||||||
|
}
|
53
ptx/src/test/spirv_run/extern_shared.spvtxt
Normal file
53
ptx/src/test/spirv_run/extern_shared.spvtxt
Normal file
|
@ -0,0 +1,53 @@
|
||||||
|
OpCapability GenericPointer
|
||||||
|
OpCapability Linkage
|
||||||
|
OpCapability Addresses
|
||||||
|
OpCapability Kernel
|
||||||
|
OpCapability Int64
|
||||||
|
OpCapability Int8
|
||||||
|
%29 = OpExtInstImport "OpenCL.std"
|
||||||
|
OpMemoryModel Physical64 OpenCL
|
||||||
|
OpEntryPoint Kernel %1 "cvta"
|
||||||
|
%void = OpTypeVoid
|
||||||
|
%ulong = OpTypeInt 64 0
|
||||||
|
%32 = OpTypeFunction %void %ulong %ulong
|
||||||
|
%_ptr_Function_ulong = OpTypePointer Function %ulong
|
||||||
|
%float = OpTypeFloat 32
|
||||||
|
%_ptr_Function_float = OpTypePointer Function %float
|
||||||
|
%_ptr_CrossWorkgroup_float = OpTypePointer CrossWorkgroup %float
|
||||||
|
%1 = OpFunction %void None %32
|
||||||
|
%7 = OpFunctionParameter %ulong
|
||||||
|
%8 = OpFunctionParameter %ulong
|
||||||
|
%27 = OpLabel
|
||||||
|
%2 = OpVariable %_ptr_Function_ulong Function
|
||||||
|
%3 = OpVariable %_ptr_Function_ulong Function
|
||||||
|
%4 = OpVariable %_ptr_Function_ulong Function
|
||||||
|
%5 = OpVariable %_ptr_Function_ulong Function
|
||||||
|
%6 = OpVariable %_ptr_Function_float Function
|
||||||
|
OpStore %2 %7
|
||||||
|
OpStore %3 %8
|
||||||
|
%10 = OpLoad %ulong %2
|
||||||
|
%9 = OpCopyObject %ulong %10
|
||||||
|
OpStore %4 %9
|
||||||
|
%12 = OpLoad %ulong %3
|
||||||
|
%11 = OpCopyObject %ulong %12
|
||||||
|
OpStore %5 %11
|
||||||
|
%14 = OpLoad %ulong %4
|
||||||
|
%22 = OpCopyObject %ulong %14
|
||||||
|
%21 = OpCopyObject %ulong %22
|
||||||
|
%13 = OpCopyObject %ulong %21
|
||||||
|
OpStore %4 %13
|
||||||
|
%16 = OpLoad %ulong %5
|
||||||
|
%24 = OpCopyObject %ulong %16
|
||||||
|
%23 = OpCopyObject %ulong %24
|
||||||
|
%15 = OpCopyObject %ulong %23
|
||||||
|
OpStore %5 %15
|
||||||
|
%18 = OpLoad %ulong %4
|
||||||
|
%25 = OpConvertUToPtr %_ptr_CrossWorkgroup_float %18
|
||||||
|
%17 = OpLoad %float %25
|
||||||
|
OpStore %6 %17
|
||||||
|
%19 = OpLoad %ulong %5
|
||||||
|
%20 = OpLoad %float %6
|
||||||
|
%26 = OpConvertUToPtr %_ptr_CrossWorkgroup_float %19
|
||||||
|
OpStore %26 %20
|
||||||
|
OpReturn
|
||||||
|
OpFunctionEnd
|
45
ptx/src/test/spirv_run/extern_shared_call.ptx
Normal file
45
ptx/src/test/spirv_run/extern_shared_call.ptx
Normal file
|
@ -0,0 +1,45 @@
|
||||||
|
.version 6.5
|
||||||
|
.target sm_30
|
||||||
|
.address_size 64
|
||||||
|
|
||||||
|
.extern .shared .align 4 .b32 shared_mem[];
|
||||||
|
|
||||||
|
.func (.param .u64 output) incr_shared_2_param(
|
||||||
|
.param .u64 .ptr .shared shared_mem_addr
|
||||||
|
)
|
||||||
|
{
|
||||||
|
.reg .u64 temp;
|
||||||
|
ld.shared.u64 temp, [shared_mem_addr];
|
||||||
|
add.u64 temp, temp, 2;
|
||||||
|
st.param.u64 [output], temp;
|
||||||
|
ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
.func (.param .u64 output) incr_shared_2_global()
|
||||||
|
{
|
||||||
|
.reg .u64 temp;
|
||||||
|
ld.shared.u64 temp, [shared_mem];
|
||||||
|
add.u64 temp, temp, 2;
|
||||||
|
st.param.u64 [output], temp;
|
||||||
|
ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
.visible .entry extern_shared(
|
||||||
|
.param .u64 input,
|
||||||
|
.param .u64 output
|
||||||
|
)
|
||||||
|
{
|
||||||
|
.reg .u64 in_addr;
|
||||||
|
.reg .u64 out_addr;
|
||||||
|
.reg .u64 temp;
|
||||||
|
|
||||||
|
ld.param.u64 in_addr, [input];
|
||||||
|
ld.param.u64 out_addr, [output];
|
||||||
|
|
||||||
|
ld.global.u64 temp, [in_addr];
|
||||||
|
st.shared.u64 [shared_mem], temp;
|
||||||
|
ld.shared.u64 temp, [shared_mem];
|
||||||
|
st.global.u64 [out_addr], temp;
|
||||||
|
ret;
|
||||||
|
}
|
53
ptx/src/test/spirv_run/extern_shared_call.spvtxt
Normal file
53
ptx/src/test/spirv_run/extern_shared_call.spvtxt
Normal file
|
@ -0,0 +1,53 @@
|
||||||
|
OpCapability GenericPointer
|
||||||
|
OpCapability Linkage
|
||||||
|
OpCapability Addresses
|
||||||
|
OpCapability Kernel
|
||||||
|
OpCapability Int64
|
||||||
|
OpCapability Int8
|
||||||
|
%29 = OpExtInstImport "OpenCL.std"
|
||||||
|
OpMemoryModel Physical64 OpenCL
|
||||||
|
OpEntryPoint Kernel %1 "cvta"
|
||||||
|
%void = OpTypeVoid
|
||||||
|
%ulong = OpTypeInt 64 0
|
||||||
|
%32 = OpTypeFunction %void %ulong %ulong
|
||||||
|
%_ptr_Function_ulong = OpTypePointer Function %ulong
|
||||||
|
%float = OpTypeFloat 32
|
||||||
|
%_ptr_Function_float = OpTypePointer Function %float
|
||||||
|
%_ptr_CrossWorkgroup_float = OpTypePointer CrossWorkgroup %float
|
||||||
|
%1 = OpFunction %void None %32
|
||||||
|
%7 = OpFunctionParameter %ulong
|
||||||
|
%8 = OpFunctionParameter %ulong
|
||||||
|
%27 = OpLabel
|
||||||
|
%2 = OpVariable %_ptr_Function_ulong Function
|
||||||
|
%3 = OpVariable %_ptr_Function_ulong Function
|
||||||
|
%4 = OpVariable %_ptr_Function_ulong Function
|
||||||
|
%5 = OpVariable %_ptr_Function_ulong Function
|
||||||
|
%6 = OpVariable %_ptr_Function_float Function
|
||||||
|
OpStore %2 %7
|
||||||
|
OpStore %3 %8
|
||||||
|
%10 = OpLoad %ulong %2
|
||||||
|
%9 = OpCopyObject %ulong %10
|
||||||
|
OpStore %4 %9
|
||||||
|
%12 = OpLoad %ulong %3
|
||||||
|
%11 = OpCopyObject %ulong %12
|
||||||
|
OpStore %5 %11
|
||||||
|
%14 = OpLoad %ulong %4
|
||||||
|
%22 = OpCopyObject %ulong %14
|
||||||
|
%21 = OpCopyObject %ulong %22
|
||||||
|
%13 = OpCopyObject %ulong %21
|
||||||
|
OpStore %4 %13
|
||||||
|
%16 = OpLoad %ulong %5
|
||||||
|
%24 = OpCopyObject %ulong %16
|
||||||
|
%23 = OpCopyObject %ulong %24
|
||||||
|
%15 = OpCopyObject %ulong %23
|
||||||
|
OpStore %5 %15
|
||||||
|
%18 = OpLoad %ulong %4
|
||||||
|
%25 = OpConvertUToPtr %_ptr_CrossWorkgroup_float %18
|
||||||
|
%17 = OpLoad %float %25
|
||||||
|
OpStore %6 %17
|
||||||
|
%19 = OpLoad %ulong %5
|
||||||
|
%20 = OpLoad %float %6
|
||||||
|
%26 = OpConvertUToPtr %_ptr_CrossWorkgroup_float %19
|
||||||
|
OpStore %26 %20
|
||||||
|
OpReturn
|
||||||
|
OpFunctionEnd
|
|
@ -78,6 +78,7 @@ test_ptx!(sub, [2u64], [1u64]);
|
||||||
test_ptx!(min, [555i32, 444i32], [444i32]);
|
test_ptx!(min, [555i32, 444i32], [444i32]);
|
||||||
test_ptx!(max, [555i32, 444i32], [555i32]);
|
test_ptx!(max, [555i32, 444i32], [555i32]);
|
||||||
test_ptx!(global_array, [0xDEADu32], [1u32]);
|
test_ptx!(global_array, [0xDEADu32], [1u32]);
|
||||||
|
test_ptx!(extern_shared, [127u64], [127u64]);
|
||||||
|
|
||||||
struct DisplayError<T: Debug> {
|
struct DisplayError<T: Debug> {
|
||||||
err: T,
|
err: T,
|
||||||
|
|
|
@ -34,11 +34,7 @@ enum SpirvType {
|
||||||
|
|
||||||
impl SpirvType {
|
impl SpirvType {
|
||||||
fn new_pointer(t: ast::Type, sc: spirv::StorageClass) -> Self {
|
fn new_pointer(t: ast::Type, sc: spirv::StorageClass) -> Self {
|
||||||
let key = match t {
|
let key = t.into();
|
||||||
ast::Type::Scalar(typ) => SpirvType::Base(SpirvScalarKey::from(typ)),
|
|
||||||
ast::Type::Vector(typ, len) => SpirvType::Vector(SpirvScalarKey::from(typ), len),
|
|
||||||
ast::Type::Array(typ, len) => SpirvType::Array(SpirvScalarKey::from(typ), len),
|
|
||||||
};
|
|
||||||
SpirvType::Pointer(Box::new(key), sc)
|
SpirvType::Pointer(Box::new(key), sc)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -49,6 +45,20 @@ impl From<ast::Type> for SpirvType {
|
||||||
ast::Type::Scalar(t) => SpirvType::Base(t.into()),
|
ast::Type::Scalar(t) => SpirvType::Base(t.into()),
|
||||||
ast::Type::Vector(typ, len) => SpirvType::Vector(typ.into(), len),
|
ast::Type::Vector(typ, len) => SpirvType::Vector(typ.into(), len),
|
||||||
ast::Type::Array(t, len) => SpirvType::Array(t.into(), len),
|
ast::Type::Array(t, len) => SpirvType::Array(t.into(), len),
|
||||||
|
ast::Type::Pointer(typ, state_space) => {
|
||||||
|
SpirvType::Pointer(Box::new(SpirvType::Base(typ.into())), state_space.into())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Into<spirv::StorageClass> for ast::PointerStateSpace {
|
||||||
|
fn into(self) -> spirv::StorageClass {
|
||||||
|
match self {
|
||||||
|
ast::PointerStateSpace::Const => spirv::StorageClass::UniformConstant,
|
||||||
|
ast::PointerStateSpace::Global => spirv::StorageClass::CrossWorkgroup,
|
||||||
|
ast::PointerStateSpace::Shared => spirv::StorageClass::Workgroup,
|
||||||
|
ast::PointerStateSpace::Param => spirv::StorageClass::Function,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -354,6 +364,14 @@ impl TypeWordMap {
|
||||||
b.constant_composite(result_type, None, &components)
|
b.constant_composite(result_type, None, &components)
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
ast::Type::Pointer(typ, state_space) => {
|
||||||
|
let base = self.get_or_add_constant(b, &ast::Type::Scalar(*typ), &[])?;
|
||||||
|
let result_type = self.get_or_add(
|
||||||
|
b,
|
||||||
|
SpirvType::Pointer(Box::new(SpirvType::from(*typ)), (*state_space).into()),
|
||||||
|
);
|
||||||
|
b.variable(result_type, None, (*state_space).into(), Some(base))
|
||||||
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -415,13 +433,7 @@ pub fn to_spirv_module<'a>(
|
||||||
None => continue,
|
None => continue,
|
||||||
};
|
};
|
||||||
emit_function_body_ops(&mut builder, &mut map, opencl_id, &f.globals)?;
|
emit_function_body_ops(&mut builder, &mut map, opencl_id, &f.globals)?;
|
||||||
emit_function_header(
|
emit_function_header(&mut builder, &mut map, &id_defs, f.func_decl, &mut args_len)?;
|
||||||
&mut builder,
|
|
||||||
&mut map,
|
|
||||||
&id_defs,
|
|
||||||
f.func_directive,
|
|
||||||
&mut args_len,
|
|
||||||
)?;
|
|
||||||
emit_function_body_ops(&mut builder, &mut map, opencl_id, &f_body)?;
|
emit_function_body_ops(&mut builder, &mut map, opencl_id, &f_body)?;
|
||||||
builder.end_function()?;
|
builder.end_function()?;
|
||||||
}
|
}
|
||||||
|
@ -430,6 +442,202 @@ pub fn to_spirv_module<'a>(
|
||||||
Ok((builder.module(), args_len))
|
Ok((builder.module(), args_len))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type MultiHashMap<K, V> = HashMap<K, Vec<V>>;
|
||||||
|
|
||||||
|
fn multi_hash_map_append<K: Eq + std::hash::Hash, V>(m: &mut MultiHashMap<K, V>, key: K, value: V) {
|
||||||
|
match m.entry(key) {
|
||||||
|
hash_map::Entry::Occupied(mut entry) => {
|
||||||
|
entry.get_mut().push(value);
|
||||||
|
}
|
||||||
|
hash_map::Entry::Vacant(entry) => {
|
||||||
|
entry.insert(vec![value]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// PTX represents dynamically allocated shared local memory as
|
||||||
|
// .extern .shared .align 4 .b8 shared_mem[];
|
||||||
|
// In SPIRV/OpenCL world this is expressed as an additional argument
|
||||||
|
// This pass looks for all uses of .extern .shared and converts them to
|
||||||
|
// an additional method argument
|
||||||
|
fn convert_dynamic_shared_memory_usage<'input>(
|
||||||
|
new_id: &mut impl FnMut() -> spirv::Word,
|
||||||
|
id_defs: &mut GlobalStringIdResolver<'input>,
|
||||||
|
module: Vec<Directive<'input>>,
|
||||||
|
) -> Vec<Directive<'input>> {
|
||||||
|
let mut extern_shared_decls = HashSet::new();
|
||||||
|
for dir in module.iter() {
|
||||||
|
match dir {
|
||||||
|
Directive::Variable(var) => {
|
||||||
|
if let ast::VariableType::Shared(_) = var.v_type {
|
||||||
|
extern_shared_decls.insert(var.name);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if extern_shared_decls.len() == 0 {
|
||||||
|
return module;
|
||||||
|
}
|
||||||
|
let mut methods_using_extern_shared = HashSet::new();
|
||||||
|
let mut directly_called_by = MultiHashMap::new();
|
||||||
|
let module = module
|
||||||
|
.into_iter()
|
||||||
|
.map(|directive| match directive {
|
||||||
|
Directive::Method(Function {
|
||||||
|
func_decl,
|
||||||
|
globals,
|
||||||
|
body: Some(statements),
|
||||||
|
}) => {
|
||||||
|
let call_key = match func_decl {
|
||||||
|
ast::MethodDecl::Kernel(name, _) => CallgraphKey::Kernel(name),
|
||||||
|
ast::MethodDecl::Func(_, id, _) => CallgraphKey::Func(id),
|
||||||
|
};
|
||||||
|
let statements = statements
|
||||||
|
.into_iter()
|
||||||
|
.map(|statement| match statement {
|
||||||
|
Statement::Call(call) => {
|
||||||
|
multi_hash_map_append(&mut directly_called_by, call.func, call_key);
|
||||||
|
Statement::Call(call)
|
||||||
|
}
|
||||||
|
statement => statement.map_id(&mut |id| {
|
||||||
|
if extern_shared_decls.contains(&id) {
|
||||||
|
methods_using_extern_shared.insert(call_key);
|
||||||
|
}
|
||||||
|
id
|
||||||
|
}),
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
Directive::Method(Function {
|
||||||
|
func_decl,
|
||||||
|
globals,
|
||||||
|
body: Some(statements),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
directive => directive,
|
||||||
|
})
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
// If there's a chain `kernel` -> `fn1` -> `fn2`, where only `fn2` uses extern shared,
|
||||||
|
// make sure it gets propagated to `fn1` and `kernel`
|
||||||
|
get_callers_of_extern_shared(&mut methods_using_extern_shared, &directly_called_by);
|
||||||
|
// now visit every method declaration and inject those additional arguments
|
||||||
|
module
|
||||||
|
.into_iter()
|
||||||
|
.map(|directive| match directive {
|
||||||
|
Directive::Method(Function {
|
||||||
|
mut func_decl,
|
||||||
|
globals,
|
||||||
|
body: Some(statements),
|
||||||
|
}) => {
|
||||||
|
let call_key = match func_decl {
|
||||||
|
ast::MethodDecl::Kernel(name, _) => CallgraphKey::Kernel(name),
|
||||||
|
ast::MethodDecl::Func(_, id, _) => CallgraphKey::Func(id),
|
||||||
|
};
|
||||||
|
if !methods_using_extern_shared.contains(&call_key) {
|
||||||
|
return Directive::Method(Function {
|
||||||
|
func_decl,
|
||||||
|
globals,
|
||||||
|
body: Some(statements),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
let shared_id_param = new_id();
|
||||||
|
match &mut func_decl {
|
||||||
|
ast::MethodDecl::Func(_, _, input_args) => {
|
||||||
|
input_args.push(ast::Variable {
|
||||||
|
align: None,
|
||||||
|
v_type: ast::FnArgumentType::Shared,
|
||||||
|
array_init: Vec::new(),
|
||||||
|
name: shared_id_param,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
ast::MethodDecl::Kernel(_, input_args) => {
|
||||||
|
input_args.push(ast::Variable {
|
||||||
|
align: None,
|
||||||
|
v_type: ast::KernelArgumentType::Shared,
|
||||||
|
array_init: Vec::new(),
|
||||||
|
name: shared_id_param,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let statements = statements
|
||||||
|
.into_iter()
|
||||||
|
.map(|statement| match statement {
|
||||||
|
Statement::Call(mut call) => {
|
||||||
|
// We can safely skip checking call arguments,
|
||||||
|
// because there's simply no way to pass shared ptr
|
||||||
|
// without converting it to .b64 first
|
||||||
|
if methods_using_extern_shared.contains(&CallgraphKey::Func(call.func))
|
||||||
|
{
|
||||||
|
call.param_list
|
||||||
|
.push((shared_id_param, ast::FnArgumentType::Shared));
|
||||||
|
}
|
||||||
|
Statement::Call(call)
|
||||||
|
}
|
||||||
|
statement => statement.map_id(&mut |id| {
|
||||||
|
if extern_shared_decls.contains(&id) {
|
||||||
|
shared_id_param
|
||||||
|
} else {
|
||||||
|
id
|
||||||
|
}
|
||||||
|
}),
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
Directive::Method(Function {
|
||||||
|
func_decl,
|
||||||
|
globals,
|
||||||
|
body: Some(statements),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
directive => directive,
|
||||||
|
})
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_callers_of_extern_shared<'a>(
|
||||||
|
methods_using_extern_shared: &mut HashSet<CallgraphKey<'a>>,
|
||||||
|
directly_called_by: &MultiHashMap<spirv::Word, CallgraphKey<'a>>,
|
||||||
|
) {
|
||||||
|
let direct_uses_of_extern_shared = methods_using_extern_shared
|
||||||
|
.iter()
|
||||||
|
.filter_map(|method| {
|
||||||
|
if let CallgraphKey::Func(f_id) = method {
|
||||||
|
Some(*f_id)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
for fn_id in direct_uses_of_extern_shared {
|
||||||
|
get_callers_of_extern_shared_single(methods_using_extern_shared, directly_called_by, fn_id);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_callers_of_extern_shared_single<'a>(
|
||||||
|
methods_using_extern_shared: &mut HashSet<CallgraphKey<'a>>,
|
||||||
|
directly_called_by: &MultiHashMap<spirv::Word, CallgraphKey<'a>>,
|
||||||
|
fn_id: spirv::Word,
|
||||||
|
) {
|
||||||
|
if let Some(callers) = directly_called_by.get(&fn_id) {
|
||||||
|
for caller in callers {
|
||||||
|
if methods_using_extern_shared.insert(*caller) {
|
||||||
|
if let CallgraphKey::Func(caller_fn) = caller {
|
||||||
|
get_callers_of_extern_shared_single(
|
||||||
|
methods_using_extern_shared,
|
||||||
|
directly_called_by,
|
||||||
|
*caller_fn,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Hash, PartialEq, Eq, Copy, Clone)]
|
||||||
|
enum CallgraphKey<'input> {
|
||||||
|
Kernel(&'input str),
|
||||||
|
Func(spirv::Word),
|
||||||
|
}
|
||||||
|
|
||||||
fn emit_builtins(
|
fn emit_builtins(
|
||||||
builder: &mut dr::Builder,
|
builder: &mut dr::Builder,
|
||||||
map: &mut TypeWordMap,
|
map: &mut TypeWordMap,
|
||||||
|
@ -594,6 +802,7 @@ fn expand_fn_params<'a, 'b>(
|
||||||
let ss = match a.v_type {
|
let ss = match a.v_type {
|
||||||
ast::FnArgumentType::Reg(_) => StateSpace::Reg,
|
ast::FnArgumentType::Reg(_) => StateSpace::Reg,
|
||||||
ast::FnArgumentType::Param(_) => StateSpace::Param,
|
ast::FnArgumentType::Param(_) => StateSpace::Param,
|
||||||
|
ast::FnArgumentType::Shared => StateSpace::Shared,
|
||||||
};
|
};
|
||||||
ast::FnArgument {
|
ast::FnArgument {
|
||||||
name: fn_resolver.add_def(a.name, Some((ss, ast::Type::from(a.v_type.clone())))),
|
name: fn_resolver.add_def(a.name, Some((ss, ast::Type::from(a.v_type.clone())))),
|
||||||
|
@ -615,7 +824,7 @@ fn to_ssa<'input, 'b>(
|
||||||
Some(vec) => vec,
|
Some(vec) => vec,
|
||||||
None => {
|
None => {
|
||||||
return Ok(Function {
|
return Ok(Function {
|
||||||
func_directive: f_args,
|
func_decl: f_args,
|
||||||
body: None,
|
body: None,
|
||||||
globals: Vec::new(),
|
globals: Vec::new(),
|
||||||
})
|
})
|
||||||
|
@ -637,7 +846,7 @@ fn to_ssa<'input, 'b>(
|
||||||
let sorted_statements = normalize_variable_decls(labeled_statements);
|
let sorted_statements = normalize_variable_decls(labeled_statements);
|
||||||
let (f_body, globals) = extract_globals(sorted_statements);
|
let (f_body, globals) = extract_globals(sorted_statements);
|
||||||
Ok(Function {
|
Ok(Function {
|
||||||
func_directive: f_args,
|
func_decl: f_args,
|
||||||
globals: globals,
|
globals: globals,
|
||||||
body: Some(f_body),
|
body: Some(f_body),
|
||||||
})
|
})
|
||||||
|
@ -935,7 +1144,7 @@ fn insert_mem_ssa_statements<'a, 'b>(
|
||||||
let new_id = id_def.new_id(typ.clone());
|
let new_id = id_def.new_id(typ.clone());
|
||||||
result.push(Statement::Variable(ast::Variable {
|
result.push(Statement::Variable(ast::Variable {
|
||||||
align: p.align,
|
align: p.align,
|
||||||
v_type: ast::VariableType::Param(p.v_type.clone()),
|
v_type: ast::VariableType::Param(p.v_type.clone().to_param()),
|
||||||
name: p.name,
|
name: p.name,
|
||||||
array_init: p.array_init.clone(),
|
array_init: p.array_init.clone(),
|
||||||
}));
|
}));
|
||||||
|
@ -1878,26 +2087,33 @@ fn emit_variable(
|
||||||
map: &mut TypeWordMap,
|
map: &mut TypeWordMap,
|
||||||
var: &ast::Variable<ast::VariableType, spirv::Word>,
|
var: &ast::Variable<ast::VariableType, spirv::Word>,
|
||||||
) -> Result<(), TranslateError> {
|
) -> Result<(), TranslateError> {
|
||||||
let (should_init, st_class) = match var.v_type {
|
let (must_init, st_class) = match var.v_type {
|
||||||
ast::VariableType::Reg(_) | ast::VariableType::Param(_) | ast::VariableType::Local(_) => {
|
ast::VariableType::Reg(_) | ast::VariableType::Param(_) | ast::VariableType::Local(_) => {
|
||||||
(false, spirv::StorageClass::Function)
|
(false, spirv::StorageClass::Function)
|
||||||
}
|
}
|
||||||
ast::VariableType::Global(_) => (true, spirv::StorageClass::CrossWorkgroup),
|
ast::VariableType::Global(_) => (true, spirv::StorageClass::CrossWorkgroup),
|
||||||
|
ast::VariableType::Shared(_) => (false, spirv::StorageClass::Workgroup),
|
||||||
};
|
};
|
||||||
let type_id = map.get_or_add(
|
let initalizer = if var.array_init.len() > 0 {
|
||||||
builder,
|
|
||||||
SpirvType::new_pointer(ast::Type::from(var.v_type.clone()), st_class),
|
|
||||||
);
|
|
||||||
let initalizer = if should_init {
|
|
||||||
Some(map.get_or_add_constant(
|
Some(map.get_or_add_constant(
|
||||||
builder,
|
builder,
|
||||||
&ast::Type::from(var.v_type.clone()),
|
&ast::Type::from(var.v_type.clone()),
|
||||||
&*var.array_init,
|
&*var.array_init,
|
||||||
)?)
|
)?)
|
||||||
|
} else if must_init {
|
||||||
|
let type_id = map.get_or_add(
|
||||||
|
builder,
|
||||||
|
SpirvType::from(ast::Type::from(var.v_type.clone())),
|
||||||
|
);
|
||||||
|
Some(builder.constant_null(type_id, None))
|
||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
};
|
};
|
||||||
builder.variable(type_id, Some(var.name), st_class, initalizer);
|
let ptr_type_id = map.get_or_add(
|
||||||
|
builder,
|
||||||
|
SpirvType::new_pointer(ast::Type::from(var.v_type.clone()), st_class),
|
||||||
|
);
|
||||||
|
builder.variable(ptr_type_id, Some(var.name), st_class, initalizer);
|
||||||
if let Some(align) = var.align {
|
if let Some(align) = var.align {
|
||||||
builder.decorate(
|
builder.decorate(
|
||||||
var.name,
|
var.name,
|
||||||
|
@ -2537,7 +2753,8 @@ fn expand_map_variables<'a, 'b>(
|
||||||
ast::VariableType::Reg(_) => StateSpace::Reg,
|
ast::VariableType::Reg(_) => StateSpace::Reg,
|
||||||
ast::VariableType::Local(_) => StateSpace::Local,
|
ast::VariableType::Local(_) => StateSpace::Local,
|
||||||
ast::VariableType::Param(_) => StateSpace::ParamReg,
|
ast::VariableType::Param(_) => StateSpace::ParamReg,
|
||||||
ast::VariableType::Global(_) => todo!(),
|
ast::VariableType::Global(_) => StateSpace::Global,
|
||||||
|
ast::VariableType::Shared(_) => StateSpace::Shared,
|
||||||
};
|
};
|
||||||
match var.count {
|
match var.count {
|
||||||
Some(count) => {
|
Some(count) => {
|
||||||
|
@ -2888,6 +3105,69 @@ enum Statement<I, P: ast::ArgParams> {
|
||||||
Undef(ast::Type, spirv::Word),
|
Undef(ast::Type, spirv::Word),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl ExpandedStatement {
|
||||||
|
fn map_id(self, f: &mut impl FnMut(spirv::Word) -> spirv::Word) -> ExpandedStatement {
|
||||||
|
match self {
|
||||||
|
Statement::Label(id) => Statement::Label(f(id)),
|
||||||
|
Statement::Variable(mut var) => {
|
||||||
|
var.name = f(var.name);
|
||||||
|
Statement::Variable(var)
|
||||||
|
}
|
||||||
|
Statement::Instruction(inst) => inst
|
||||||
|
.visit_variable_extended(&mut |arg: ArgumentDescriptor<_>, _| Ok(f(arg.op)))
|
||||||
|
.unwrap(),
|
||||||
|
Statement::LoadVar(mut arg, typ) => {
|
||||||
|
arg.dst = f(arg.dst);
|
||||||
|
arg.src = f(arg.src);
|
||||||
|
Statement::LoadVar(arg, typ)
|
||||||
|
}
|
||||||
|
Statement::StoreVar(mut arg, typ) => {
|
||||||
|
arg.src1 = f(arg.src1);
|
||||||
|
arg.src2 = f(arg.src2);
|
||||||
|
Statement::StoreVar(arg, typ)
|
||||||
|
}
|
||||||
|
Statement::Call(mut call) => {
|
||||||
|
for (id, _) in call.ret_params.iter_mut() {
|
||||||
|
*id = f(*id);
|
||||||
|
}
|
||||||
|
call.func = f(call.func);
|
||||||
|
for (id, _) in call.param_list.iter_mut() {
|
||||||
|
*id = f(*id);
|
||||||
|
}
|
||||||
|
Statement::Call(call)
|
||||||
|
}
|
||||||
|
Statement::Composite(mut composite) => {
|
||||||
|
composite.dst = f(composite.dst);
|
||||||
|
composite.src_composite = f(composite.src_composite);
|
||||||
|
Statement::Composite(composite)
|
||||||
|
}
|
||||||
|
Statement::Conditional(mut conditional) => {
|
||||||
|
conditional.predicate = f(conditional.predicate);
|
||||||
|
conditional.if_true = f(conditional.if_true);
|
||||||
|
conditional.if_false = f(conditional.if_false);
|
||||||
|
Statement::Conditional(conditional)
|
||||||
|
}
|
||||||
|
Statement::Conversion(mut conv) => {
|
||||||
|
conv.dst = f(conv.dst);
|
||||||
|
conv.src = f(conv.src);
|
||||||
|
Statement::Conversion(conv)
|
||||||
|
}
|
||||||
|
Statement::Constant(mut constant) => {
|
||||||
|
constant.dst = f(constant.dst);
|
||||||
|
Statement::Constant(constant)
|
||||||
|
}
|
||||||
|
Statement::RetValue(data, id) => {
|
||||||
|
let id = f(id);
|
||||||
|
Statement::RetValue(data, id)
|
||||||
|
}
|
||||||
|
Statement::Undef(typ, id) => {
|
||||||
|
let id = f(id);
|
||||||
|
Statement::Undef(typ, id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
struct ResolvedCall<P: ast::ArgParams> {
|
struct ResolvedCall<P: ast::ArgParams> {
|
||||||
pub uniform: bool,
|
pub uniform: bool,
|
||||||
pub ret_params: Vec<(spirv::Word, ast::FnArgumentType)>,
|
pub ret_params: Vec<(spirv::Word, ast::FnArgumentType)>,
|
||||||
|
@ -3106,7 +3386,7 @@ enum Directive<'input> {
|
||||||
}
|
}
|
||||||
|
|
||||||
struct Function<'input> {
|
struct Function<'input> {
|
||||||
pub func_directive: ast::MethodDecl<'input, spirv::Word>,
|
pub func_decl: ast::MethodDecl<'input, spirv::Word>,
|
||||||
pub globals: Vec<ExpandedStatement>,
|
pub globals: Vec<ExpandedStatement>,
|
||||||
pub body: Option<Vec<ExpandedStatement>>,
|
pub body: Option<Vec<ExpandedStatement>>,
|
||||||
}
|
}
|
||||||
|
@ -3546,18 +3826,28 @@ impl ast::Type {
|
||||||
scalar_kind: scalar.kind(),
|
scalar_kind: scalar.kind(),
|
||||||
width: scalar.size_of(),
|
width: scalar.size_of(),
|
||||||
components: Vec::new(),
|
components: Vec::new(),
|
||||||
|
state_space: ast::PointerStateSpace::Global,
|
||||||
},
|
},
|
||||||
ast::Type::Vector(scalar, components) => TypeParts {
|
ast::Type::Vector(scalar, components) => TypeParts {
|
||||||
kind: TypeKind::Vector,
|
kind: TypeKind::Vector,
|
||||||
scalar_kind: scalar.kind(),
|
scalar_kind: scalar.kind(),
|
||||||
width: scalar.size_of(),
|
width: scalar.size_of(),
|
||||||
components: vec![*components as u32],
|
components: vec![*components as u32],
|
||||||
|
state_space: ast::PointerStateSpace::Global,
|
||||||
},
|
},
|
||||||
ast::Type::Array(scalar, components) => TypeParts {
|
ast::Type::Array(scalar, components) => TypeParts {
|
||||||
kind: TypeKind::Array,
|
kind: TypeKind::Array,
|
||||||
scalar_kind: scalar.kind(),
|
scalar_kind: scalar.kind(),
|
||||||
width: scalar.size_of(),
|
width: scalar.size_of(),
|
||||||
components: components.clone(),
|
components: components.clone(),
|
||||||
|
state_space: ast::PointerStateSpace::Global,
|
||||||
|
},
|
||||||
|
ast::Type::Pointer(scalar, state_space) => TypeParts {
|
||||||
|
kind: TypeKind::Pointer,
|
||||||
|
scalar_kind: scalar.kind(),
|
||||||
|
width: scalar.size_of(),
|
||||||
|
components: Vec::new(),
|
||||||
|
state_space: *state_space,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -3575,6 +3865,10 @@ impl ast::Type {
|
||||||
ast::ScalarType::from_parts(t.width, t.scalar_kind),
|
ast::ScalarType::from_parts(t.width, t.scalar_kind),
|
||||||
t.components,
|
t.components,
|
||||||
),
|
),
|
||||||
|
TypeKind::Pointer => ast::Type::Pointer(
|
||||||
|
ast::ScalarType::from_parts(t.width, t.scalar_kind),
|
||||||
|
t.state_space,
|
||||||
|
),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -3585,6 +3879,7 @@ struct TypeParts {
|
||||||
scalar_kind: ScalarKind,
|
scalar_kind: ScalarKind,
|
||||||
width: u8,
|
width: u8,
|
||||||
components: Vec<u32>,
|
components: Vec<u32>,
|
||||||
|
state_space: ast::PointerStateSpace,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Eq, PartialEq, Copy, Clone)]
|
#[derive(Eq, PartialEq, Copy, Clone)]
|
||||||
|
@ -3592,6 +3887,7 @@ enum TypeKind {
|
||||||
Scalar,
|
Scalar,
|
||||||
Vector,
|
Vector,
|
||||||
Array,
|
Array,
|
||||||
|
Pointer,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ast::Instruction<ExpandedArgParams> {
|
impl ast::Instruction<ExpandedArgParams> {
|
||||||
|
@ -3762,6 +4058,36 @@ impl ast::VariableParamType {
|
||||||
(ast::ScalarType::from(*t).size_of() as usize)
|
(ast::ScalarType::from(*t).size_of() as usize)
|
||||||
* (len.iter().fold(1, |x, y| x * (*y)) as usize)
|
* (len.iter().fold(1, |x, y| x * (*y)) as usize)
|
||||||
}
|
}
|
||||||
|
ast::VariableParamType::Pointer(_, _) => mem::size_of::<usize>()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ast::KernelArgumentType {
|
||||||
|
fn width(&self) -> usize {
|
||||||
|
match self {
|
||||||
|
ast::KernelArgumentType::Normal(t) => t.width(),
|
||||||
|
ast::KernelArgumentType::Shared => mem::size_of::<usize>(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<ast::KernelArgumentType> for ast::Type {
|
||||||
|
fn from(this: ast::KernelArgumentType) -> Self {
|
||||||
|
match this {
|
||||||
|
ast::KernelArgumentType::Normal(typ) => typ.into(),
|
||||||
|
ast::KernelArgumentType::Shared => ast::Type::Scalar(ast::ScalarType::B64),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ast::KernelArgumentType {
|
||||||
|
fn to_param(self) -> ast::VariableParamType {
|
||||||
|
match self {
|
||||||
|
ast::KernelArgumentType::Normal(p) => p,
|
||||||
|
ast::KernelArgumentType::Shared => {
|
||||||
|
ast::VariableParamType::Scalar(ast::ParamScalarType::B64)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -4598,6 +4924,7 @@ impl From<ast::FnArgumentType> for ast::VariableType {
|
||||||
match t {
|
match t {
|
||||||
ast::FnArgumentType::Reg(t) => ast::VariableType::Reg(t),
|
ast::FnArgumentType::Reg(t) => ast::VariableType::Reg(t),
|
||||||
ast::FnArgumentType::Param(t) => ast::VariableType::Param(t),
|
ast::FnArgumentType::Param(t) => ast::VariableType::Param(t),
|
||||||
|
ast::FnArgumentType::Shared => todo!(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -4648,6 +4975,17 @@ fn bitcast_physical_pointer(
|
||||||
ss: Option<ast::LdStateSpace>,
|
ss: Option<ast::LdStateSpace>,
|
||||||
) -> Result<Option<ConversionKind>, TranslateError> {
|
) -> Result<Option<ConversionKind>, TranslateError> {
|
||||||
match operand_type {
|
match operand_type {
|
||||||
|
// array decays to a pointer
|
||||||
|
ast::Type::Array(_, vec) => {
|
||||||
|
if vec.len() != 0 {
|
||||||
|
return Err(TranslateError::MismatchedType);
|
||||||
|
}
|
||||||
|
if let Some(space) = ss {
|
||||||
|
Ok(Some(ConversionKind::BitToPtr(space)))
|
||||||
|
} else {
|
||||||
|
Err(TranslateError::Unreachable)
|
||||||
|
}
|
||||||
|
}
|
||||||
ast::Type::Scalar(ast::ScalarType::B64)
|
ast::Type::Scalar(ast::ScalarType::B64)
|
||||||
| ast::Type::Scalar(ast::ScalarType::U64)
|
| ast::Type::Scalar(ast::ScalarType::U64)
|
||||||
| ast::Type::Scalar(ast::ScalarType::S64) => {
|
| ast::Type::Scalar(ast::ScalarType::S64) => {
|
||||||
|
@ -4882,7 +5220,10 @@ impl<'a> ast::MethodDecl<'a, spirv::Word> {
|
||||||
f(&ast::FnArgument {
|
f(&ast::FnArgument {
|
||||||
align: arg.align,
|
align: arg.align,
|
||||||
name: arg.name,
|
name: arg.name,
|
||||||
v_type: ast::FnArgumentType::Param(arg.v_type.clone()),
|
v_type: match arg.v_type.clone() {
|
||||||
|
ast::KernelArgumentType::Normal(typ) => ast::FnArgumentType::Param(typ),
|
||||||
|
ast::KernelArgumentType::Shared => ast::FnArgumentType::Shared,
|
||||||
|
},
|
||||||
array_init: arg.array_init.clone(),
|
array_init: arg.array_init.clone(),
|
||||||
})
|
})
|
||||||
}),
|
}),
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue