mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-04-28 21:47:57 +03:00
Continue working on ftz modes
This commit is contained in:
parent
17529f951d
commit
5121bba285
15 changed files with 559 additions and 226 deletions
|
@ -2,8 +2,8 @@ use super::*;
|
||||||
|
|
||||||
pub(super) fn run<'a, 'input>(
|
pub(super) fn run<'a, 'input>(
|
||||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||||
directives: Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>,
|
directives: Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>,
|
||||||
) -> Result<Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
|
) -> Result<Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
|
||||||
directives
|
directives
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.map(|directive| run_directive(resolver, directive))
|
.map(|directive| run_directive(resolver, directive))
|
||||||
|
@ -12,8 +12,8 @@ pub(super) fn run<'a, 'input>(
|
||||||
|
|
||||||
fn run_directive<'input>(
|
fn run_directive<'input>(
|
||||||
resolver: &mut GlobalStringIdentResolver2,
|
resolver: &mut GlobalStringIdentResolver2,
|
||||||
directive: Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>,
|
directive: Directive2<ast::Instruction<SpirvWord>, SpirvWord>,
|
||||||
) -> Result<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
|
) -> Result<Directive2<ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
|
||||||
Ok(match directive {
|
Ok(match directive {
|
||||||
var @ Directive2::Variable(..) => var,
|
var @ Directive2::Variable(..) => var,
|
||||||
Directive2::Method(method) => Directive2::Method(run_method(resolver, method)?),
|
Directive2::Method(method) => Directive2::Method(run_method(resolver, method)?),
|
||||||
|
@ -22,13 +22,13 @@ fn run_directive<'input>(
|
||||||
|
|
||||||
fn run_method<'input>(
|
fn run_method<'input>(
|
||||||
resolver: &mut GlobalStringIdentResolver2,
|
resolver: &mut GlobalStringIdentResolver2,
|
||||||
mut method: Function2<'input, ast::Instruction<SpirvWord>, SpirvWord>,
|
mut method: Function2<ast::Instruction<SpirvWord>, SpirvWord>,
|
||||||
) -> Result<Function2<'input, ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
|
) -> Result<Function2<ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
|
||||||
let is_declaration = method.body.is_none();
|
let is_declaration = method.body.is_none();
|
||||||
let mut body = Vec::new();
|
let mut body = Vec::new();
|
||||||
let mut remap_returns = Vec::new();
|
let mut remap_returns = Vec::new();
|
||||||
if !method.func_decl.name.is_kernel() {
|
if !method.is_kernel {
|
||||||
for arg in method.func_decl.return_arguments.iter_mut() {
|
for arg in method.return_arguments.iter_mut() {
|
||||||
match arg.state_space {
|
match arg.state_space {
|
||||||
ptx_parser::StateSpace::Param => {
|
ptx_parser::StateSpace::Param => {
|
||||||
arg.state_space = ptx_parser::StateSpace::Reg;
|
arg.state_space = ptx_parser::StateSpace::Reg;
|
||||||
|
@ -51,7 +51,7 @@ fn run_method<'input>(
|
||||||
_ => return Err(error_unreachable()),
|
_ => return Err(error_unreachable()),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
for arg in method.func_decl.input_arguments.iter_mut() {
|
for arg in method.input_arguments.iter_mut() {
|
||||||
match arg.state_space {
|
match arg.state_space {
|
||||||
ptx_parser::StateSpace::Param => {
|
ptx_parser::StateSpace::Param => {
|
||||||
arg.state_space = ptx_parser::StateSpace::Reg;
|
arg.state_space = ptx_parser::StateSpace::Reg;
|
||||||
|
@ -96,12 +96,14 @@ fn run_method<'input>(
|
||||||
})
|
})
|
||||||
.transpose()?;
|
.transpose()?;
|
||||||
Ok(Function2 {
|
Ok(Function2 {
|
||||||
func_decl: method.func_decl,
|
return_arguments: method.return_arguments,
|
||||||
globals: method.globals,
|
name: method.name,
|
||||||
|
input_arguments: method.input_arguments,
|
||||||
body,
|
body,
|
||||||
import_as: method.import_as,
|
import_as: method.import_as,
|
||||||
tuning: method.tuning,
|
tuning: method.tuning,
|
||||||
linkage: method.linkage,
|
linkage: method.linkage,
|
||||||
|
is_kernel: method.is_kernel,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -168,7 +168,7 @@ impl Deref for MemoryBuffer {
|
||||||
|
|
||||||
pub(super) fn run<'input>(
|
pub(super) fn run<'input>(
|
||||||
id_defs: GlobalStringIdentResolver2<'input>,
|
id_defs: GlobalStringIdentResolver2<'input>,
|
||||||
directives: Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>,
|
directives: Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>,
|
||||||
) -> Result<MemoryBuffer, TranslateError> {
|
) -> Result<MemoryBuffer, TranslateError> {
|
||||||
let context = Context::new();
|
let context = Context::new();
|
||||||
let module = Module::new(&context, LLVM_UNNAMED);
|
let module = Module::new(&context, LLVM_UNNAMED);
|
||||||
|
@ -218,24 +218,20 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> {
|
||||||
|
|
||||||
fn emit_method(
|
fn emit_method(
|
||||||
&mut self,
|
&mut self,
|
||||||
method: Function2<'input, ast::Instruction<SpirvWord>, SpirvWord>,
|
method: Function2<ast::Instruction<SpirvWord>, SpirvWord>,
|
||||||
) -> Result<(), TranslateError> {
|
) -> Result<(), TranslateError> {
|
||||||
let func_decl = method.func_decl;
|
|
||||||
let name = method
|
let name = method
|
||||||
.import_as
|
.import_as
|
||||||
.as_deref()
|
.as_deref()
|
||||||
.or_else(|| match func_decl.name {
|
.or_else(|| self.id_defs.ident_map[&method.name].name.as_deref())
|
||||||
ast::MethodName::Kernel(name) => Some(name),
|
|
||||||
ast::MethodName::Func(id) => self.id_defs.ident_map[&id].name.as_deref(),
|
|
||||||
})
|
|
||||||
.ok_or_else(|| error_unreachable())?;
|
.ok_or_else(|| error_unreachable())?;
|
||||||
let name = CString::new(name).map_err(|_| error_unreachable())?;
|
let name = CString::new(name).map_err(|_| error_unreachable())?;
|
||||||
let mut fn_ = unsafe { LLVMGetNamedFunction(self.module, name.as_ptr()) };
|
let mut fn_ = unsafe { LLVMGetNamedFunction(self.module, name.as_ptr()) };
|
||||||
if fn_ == ptr::null_mut() {
|
if fn_ == ptr::null_mut() {
|
||||||
let fn_type = get_function_type(
|
let fn_type = get_function_type(
|
||||||
self.context,
|
self.context,
|
||||||
func_decl.return_arguments.iter().map(|v| &v.v_type),
|
method.return_arguments.iter().map(|v| &v.v_type),
|
||||||
func_decl
|
method
|
||||||
.input_arguments
|
.input_arguments
|
||||||
.iter()
|
.iter()
|
||||||
.map(|v| get_input_argument_type(self.context, &v.v_type, v.state_space)),
|
.map(|v| get_input_argument_type(self.context, &v.v_type, v.state_space)),
|
||||||
|
@ -245,15 +241,15 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> {
|
||||||
self.emit_fn_attribute(fn_, "uniform-work-group-size", "true");
|
self.emit_fn_attribute(fn_, "uniform-work-group-size", "true");
|
||||||
self.emit_fn_attribute(fn_, "no-trapping-math", "true");
|
self.emit_fn_attribute(fn_, "no-trapping-math", "true");
|
||||||
}
|
}
|
||||||
if let ast::MethodName::Func(name) = func_decl.name {
|
if !method.is_kernel {
|
||||||
self.resolver.register(name, fn_);
|
self.resolver.register(method.name, fn_);
|
||||||
}
|
}
|
||||||
for (i, param) in func_decl.input_arguments.iter().enumerate() {
|
for (i, param) in method.input_arguments.iter().enumerate() {
|
||||||
let value = unsafe { LLVMGetParam(fn_, i as u32) };
|
let value = unsafe { LLVMGetParam(fn_, i as u32) };
|
||||||
let name = self.resolver.get_or_add(param.name);
|
let name = self.resolver.get_or_add(param.name);
|
||||||
unsafe { LLVMSetValueName2(value, name.as_ptr().cast(), name.len()) };
|
unsafe { LLVMSetValueName2(value, name.as_ptr().cast(), name.len()) };
|
||||||
self.resolver.register(param.name, value);
|
self.resolver.register(param.name, value);
|
||||||
if func_decl.name.is_kernel() {
|
if method.is_kernel {
|
||||||
let attr_kind = unsafe {
|
let attr_kind = unsafe {
|
||||||
LLVMGetEnumAttributeKindForName(b"byref".as_ptr().cast(), b"byref".len())
|
LLVMGetEnumAttributeKindForName(b"byref".as_ptr().cast(), b"byref".len())
|
||||||
};
|
};
|
||||||
|
@ -267,7 +263,7 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> {
|
||||||
unsafe { LLVMAddAttributeAtIndex(fn_, i as u32 + 1, attr) };
|
unsafe { LLVMAddAttributeAtIndex(fn_, i as u32 + 1, attr) };
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
let call_conv = if func_decl.name.is_kernel() {
|
let call_conv = if method.is_kernel {
|
||||||
Self::kernel_call_convention()
|
Self::kernel_call_convention()
|
||||||
} else {
|
} else {
|
||||||
Self::func_call_convention()
|
Self::func_call_convention()
|
||||||
|
@ -282,7 +278,7 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> {
|
||||||
unsafe { LLVMAppendBasicBlockInContext(self.context, fn_, LLVM_UNNAMED.as_ptr()) };
|
unsafe { LLVMAppendBasicBlockInContext(self.context, fn_, LLVM_UNNAMED.as_ptr()) };
|
||||||
unsafe { LLVMPositionBuilderAtEnd(self.builder.get(), real_bb) };
|
unsafe { LLVMPositionBuilderAtEnd(self.builder.get(), real_bb) };
|
||||||
let mut method_emitter = MethodEmitContext::new(self, fn_, variables_builder);
|
let mut method_emitter = MethodEmitContext::new(self, fn_, variables_builder);
|
||||||
for var in func_decl.return_arguments {
|
for var in method.return_arguments {
|
||||||
method_emitter.emit_variable(var)?;
|
method_emitter.emit_variable(var)?;
|
||||||
}
|
}
|
||||||
for statement in statements.iter() {
|
for statement in statements.iter() {
|
||||||
|
@ -1558,7 +1554,7 @@ impl<'a> MethodEmitContext<'a> {
|
||||||
return self.emit_cvt_float_to_int(
|
return self.emit_cvt_float_to_int(
|
||||||
data.from,
|
data.from,
|
||||||
data.to,
|
data.to,
|
||||||
integer_rounding.unwrap_or(ast::RoundingMode::NearestEven),
|
integer_rounding,
|
||||||
arguments,
|
arguments,
|
||||||
Some(LLVMBuildFPToSI),
|
Some(LLVMBuildFPToSI),
|
||||||
)
|
)
|
||||||
|
|
|
@ -2,8 +2,8 @@ use super::*;
|
||||||
|
|
||||||
pub(super) fn run<'a, 'input>(
|
pub(super) fn run<'a, 'input>(
|
||||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||||
directives: Vec<UnconditionalDirective<'input>>,
|
directives: Vec<UnconditionalDirective>,
|
||||||
) -> Result<Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
|
) -> Result<Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
|
||||||
directives
|
directives
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.map(|directive| run_directive(resolver, directive))
|
.map(|directive| run_directive(resolver, directive))
|
||||||
|
@ -13,11 +13,10 @@ pub(super) fn run<'a, 'input>(
|
||||||
fn run_directive<'input>(
|
fn run_directive<'input>(
|
||||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||||
directive: Directive2<
|
directive: Directive2<
|
||||||
'input,
|
|
||||||
ast::Instruction<ast::ParsedOperand<SpirvWord>>,
|
ast::Instruction<ast::ParsedOperand<SpirvWord>>,
|
||||||
ast::ParsedOperand<SpirvWord>,
|
ast::ParsedOperand<SpirvWord>,
|
||||||
>,
|
>,
|
||||||
) -> Result<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
|
) -> Result<Directive2<ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
|
||||||
Ok(match directive {
|
Ok(match directive {
|
||||||
Directive2::Variable(linking, var) => Directive2::Variable(linking, var),
|
Directive2::Variable(linking, var) => Directive2::Variable(linking, var),
|
||||||
Directive2::Method(method) => Directive2::Method(run_method(resolver, method)?),
|
Directive2::Method(method) => Directive2::Method(run_method(resolver, method)?),
|
||||||
|
@ -27,11 +26,10 @@ fn run_directive<'input>(
|
||||||
fn run_method<'input>(
|
fn run_method<'input>(
|
||||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||||
method: Function2<
|
method: Function2<
|
||||||
'input,
|
|
||||||
ast::Instruction<ast::ParsedOperand<SpirvWord>>,
|
ast::Instruction<ast::ParsedOperand<SpirvWord>>,
|
||||||
ast::ParsedOperand<SpirvWord>,
|
ast::ParsedOperand<SpirvWord>,
|
||||||
>,
|
>,
|
||||||
) -> Result<Function2<'input, ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
|
) -> Result<Function2<ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
|
||||||
let body = method
|
let body = method
|
||||||
.body
|
.body
|
||||||
.map(|statements| {
|
.map(|statements| {
|
||||||
|
@ -43,12 +41,14 @@ fn run_method<'input>(
|
||||||
})
|
})
|
||||||
.transpose()?;
|
.transpose()?;
|
||||||
Ok(Function2 {
|
Ok(Function2 {
|
||||||
func_decl: method.func_decl,
|
return_arguments: method.return_arguments,
|
||||||
globals: method.globals,
|
name: method.name,
|
||||||
|
input_arguments: method.input_arguments,
|
||||||
body,
|
body,
|
||||||
import_as: method.import_as,
|
import_as: method.import_as,
|
||||||
tuning: method.tuning,
|
tuning: method.tuning,
|
||||||
linkage: method.linkage,
|
linkage: method.linkage,
|
||||||
|
is_kernel: method.is_kernel,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,30 +1,29 @@
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
pub(super) fn run<'a, 'input>(
|
pub(super) fn run<'a, 'input>(
|
||||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
resolver: &'a mut GlobalStringIdentResolver2<'input>,
|
||||||
special_registers: &'a SpecialRegistersMap2,
|
special_registers: &'a SpecialRegistersMap2,
|
||||||
directives: Vec<UnconditionalDirective<'input>>,
|
directives: Vec<UnconditionalDirective>,
|
||||||
) -> Result<Vec<UnconditionalDirective<'input>>, TranslateError> {
|
) -> Result<Vec<UnconditionalDirective>, TranslateError> {
|
||||||
let declarations = SpecialRegistersMap2::generate_declarations(resolver);
|
let mut result = Vec::with_capacity(SpecialRegistersMap2::len() + directives.len());
|
||||||
let mut result = Vec::with_capacity(declarations.len() + directives.len());
|
|
||||||
let mut sreg_to_function =
|
let mut sreg_to_function =
|
||||||
FxHashMap::with_capacity_and_hasher(declarations.len(), Default::default());
|
FxHashMap::with_capacity_and_hasher(SpecialRegistersMap2::len(), Default::default());
|
||||||
for (sreg, declaration) in declarations {
|
SpecialRegistersMap2::foreach_declaration(
|
||||||
let name = if let ast::MethodName::Func(name) = declaration.name {
|
resolver,
|
||||||
name
|
|sreg, (return_arguments, name, input_arguments)| {
|
||||||
} else {
|
result.push(UnconditionalDirective::Method(UnconditionalFunction {
|
||||||
return Err(error_unreachable());
|
return_arguments,
|
||||||
};
|
name,
|
||||||
result.push(UnconditionalDirective::Method(UnconditionalFunction {
|
input_arguments,
|
||||||
func_decl: declaration,
|
body: None,
|
||||||
globals: Vec::new(),
|
import_as: None,
|
||||||
body: None,
|
tuning: Vec::new(),
|
||||||
import_as: None,
|
linkage: ast::LinkingDirective::EXTERN,
|
||||||
tuning: Vec::new(),
|
is_kernel: false,
|
||||||
linkage: ast::LinkingDirective::EXTERN,
|
}));
|
||||||
}));
|
sreg_to_function.insert(sreg, name);
|
||||||
sreg_to_function.insert(sreg, name);
|
},
|
||||||
}
|
);
|
||||||
let mut visitor = SpecialRegisterResolver {
|
let mut visitor = SpecialRegisterResolver {
|
||||||
resolver,
|
resolver,
|
||||||
special_registers,
|
special_registers,
|
||||||
|
@ -39,8 +38,8 @@ pub(super) fn run<'a, 'input>(
|
||||||
|
|
||||||
fn run_directive<'a, 'input>(
|
fn run_directive<'a, 'input>(
|
||||||
visitor: &mut SpecialRegisterResolver<'a, 'input>,
|
visitor: &mut SpecialRegisterResolver<'a, 'input>,
|
||||||
directive: UnconditionalDirective<'input>,
|
directive: UnconditionalDirective,
|
||||||
) -> Result<UnconditionalDirective<'input>, TranslateError> {
|
) -> Result<UnconditionalDirective, TranslateError> {
|
||||||
Ok(match directive {
|
Ok(match directive {
|
||||||
var @ Directive2::Variable(..) => var,
|
var @ Directive2::Variable(..) => var,
|
||||||
Directive2::Method(method) => Directive2::Method(run_method(visitor, method)?),
|
Directive2::Method(method) => Directive2::Method(run_method(visitor, method)?),
|
||||||
|
@ -49,8 +48,8 @@ fn run_directive<'a, 'input>(
|
||||||
|
|
||||||
fn run_method<'a, 'input>(
|
fn run_method<'a, 'input>(
|
||||||
visitor: &mut SpecialRegisterResolver<'a, 'input>,
|
visitor: &mut SpecialRegisterResolver<'a, 'input>,
|
||||||
method: UnconditionalFunction<'input>,
|
method: UnconditionalFunction,
|
||||||
) -> Result<UnconditionalFunction<'input>, TranslateError> {
|
) -> Result<UnconditionalFunction, TranslateError> {
|
||||||
let body = method
|
let body = method
|
||||||
.body
|
.body
|
||||||
.map(|statements| {
|
.map(|statements| {
|
||||||
|
@ -62,12 +61,14 @@ fn run_method<'a, 'input>(
|
||||||
})
|
})
|
||||||
.transpose()?;
|
.transpose()?;
|
||||||
Ok(Function2 {
|
Ok(Function2 {
|
||||||
func_decl: method.func_decl,
|
return_arguments: method.return_arguments,
|
||||||
globals: method.globals,
|
name: method.name,
|
||||||
|
input_arguments: method.input_arguments,
|
||||||
body,
|
body,
|
||||||
import_as: method.import_as,
|
import_as: method.import_as,
|
||||||
tuning: method.tuning,
|
tuning: method.tuning,
|
||||||
linkage: method.linkage,
|
linkage: method.linkage,
|
||||||
|
is_kernel: method.is_kernel,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,8 +1,8 @@
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
pub(super) fn run<'input>(
|
pub(super) fn run<'input>(
|
||||||
directives: Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>,
|
directives: Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>,
|
||||||
) -> Result<Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
|
) -> Result<Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
|
||||||
let mut result = Vec::with_capacity(directives.len());
|
let mut result = Vec::with_capacity(directives.len());
|
||||||
for mut directive in directives.into_iter() {
|
for mut directive in directives.into_iter() {
|
||||||
run_directive(&mut result, &mut directive)?;
|
run_directive(&mut result, &mut directive)?;
|
||||||
|
@ -12,8 +12,8 @@ pub(super) fn run<'input>(
|
||||||
}
|
}
|
||||||
|
|
||||||
fn run_directive<'input>(
|
fn run_directive<'input>(
|
||||||
result: &mut Vec<Directive2<'input, ptx_parser::Instruction<SpirvWord>, SpirvWord>>,
|
result: &mut Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>,
|
||||||
directive: &mut Directive2<'input, ptx_parser::Instruction<SpirvWord>, SpirvWord>,
|
directive: &mut Directive2<ast::Instruction<SpirvWord>, SpirvWord>,
|
||||||
) -> Result<(), TranslateError> {
|
) -> Result<(), TranslateError> {
|
||||||
match directive {
|
match directive {
|
||||||
Directive2::Variable(..) => {}
|
Directive2::Variable(..) => {}
|
||||||
|
@ -23,8 +23,8 @@ fn run_directive<'input>(
|
||||||
}
|
}
|
||||||
|
|
||||||
fn run_function<'input>(
|
fn run_function<'input>(
|
||||||
result: &mut Vec<Directive2<'input, ptx_parser::Instruction<SpirvWord>, SpirvWord>>,
|
result: &mut Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>,
|
||||||
function: &mut Function2<'input, ptx_parser::Instruction<SpirvWord>, SpirvWord>,
|
function: &mut Function2<ast::Instruction<SpirvWord>, SpirvWord>,
|
||||||
) {
|
) {
|
||||||
function.body = function.body.take().map(|statements| {
|
function.body = function.body.take().map(|statements| {
|
||||||
statements
|
statements
|
||||||
|
|
|
@ -11,8 +11,8 @@ use super::*;
|
||||||
// pass, so we do nothing there
|
// pass, so we do nothing there
|
||||||
pub(super) fn run<'a, 'input>(
|
pub(super) fn run<'a, 'input>(
|
||||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||||
directives: Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>,
|
directives: Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>,
|
||||||
) -> Result<Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
|
) -> Result<Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
|
||||||
directives
|
directives
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.map(|directive| run_directive(resolver, directive))
|
.map(|directive| run_directive(resolver, directive))
|
||||||
|
@ -21,8 +21,8 @@ pub(super) fn run<'a, 'input>(
|
||||||
|
|
||||||
fn run_directive<'a, 'input>(
|
fn run_directive<'a, 'input>(
|
||||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||||
directive: Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>,
|
directive: Directive2<ast::Instruction<SpirvWord>, SpirvWord>,
|
||||||
) -> Result<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
|
) -> Result<Directive2<ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
|
||||||
Ok(match directive {
|
Ok(match directive {
|
||||||
var @ Directive2::Variable(..) => var,
|
var @ Directive2::Variable(..) => var,
|
||||||
Directive2::Method(method) => {
|
Directive2::Method(method) => {
|
||||||
|
@ -34,12 +34,11 @@ fn run_directive<'a, 'input>(
|
||||||
|
|
||||||
fn run_method<'a, 'input>(
|
fn run_method<'a, 'input>(
|
||||||
mut visitor: InsertMemSSAVisitor<'a, 'input>,
|
mut visitor: InsertMemSSAVisitor<'a, 'input>,
|
||||||
method: Function2<'input, ast::Instruction<SpirvWord>, SpirvWord>,
|
mut method: Function2<ast::Instruction<SpirvWord>, SpirvWord>,
|
||||||
) -> Result<Function2<'input, ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
|
) -> Result<Function2<ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
|
||||||
let mut func_decl = method.func_decl;
|
let is_kernel = method.is_kernel;
|
||||||
let is_kernel = func_decl.name.is_kernel();
|
|
||||||
if is_kernel {
|
if is_kernel {
|
||||||
for arg in func_decl.input_arguments.iter_mut() {
|
for arg in method.input_arguments.iter_mut() {
|
||||||
let old_name = arg.name;
|
let old_name = arg.name;
|
||||||
let old_space = arg.state_space;
|
let old_space = arg.state_space;
|
||||||
let new_space = ast::StateSpace::ParamEntry;
|
let new_space = ast::StateSpace::ParamEntry;
|
||||||
|
@ -51,10 +50,10 @@ fn run_method<'a, 'input>(
|
||||||
arg.state_space = new_space;
|
arg.state_space = new_space;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
for arg in func_decl.return_arguments.iter_mut() {
|
for arg in method.return_arguments.iter_mut() {
|
||||||
visitor.visit_variable(arg)?;
|
visitor.visit_variable(arg)?;
|
||||||
}
|
}
|
||||||
let return_arguments = &func_decl.return_arguments[..];
|
let return_arguments = &method.return_arguments[..];
|
||||||
let body = method
|
let body = method
|
||||||
.body
|
.body
|
||||||
.map(move |statements| {
|
.map(move |statements| {
|
||||||
|
@ -66,12 +65,14 @@ fn run_method<'a, 'input>(
|
||||||
})
|
})
|
||||||
.transpose()?;
|
.transpose()?;
|
||||||
Ok(Function2 {
|
Ok(Function2 {
|
||||||
func_decl: func_decl,
|
return_arguments: method.return_arguments,
|
||||||
globals: method.globals,
|
name: method.name,
|
||||||
|
input_arguments: method.input_arguments,
|
||||||
body,
|
body,
|
||||||
import_as: method.import_as,
|
import_as: method.import_as,
|
||||||
tuning: method.tuning,
|
tuning: method.tuning,
|
||||||
linkage: method.linkage,
|
linkage: method.linkage,
|
||||||
|
is_kernel: method.is_kernel,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,3 +1,5 @@
|
||||||
|
use crate::pass::error_unreachable;
|
||||||
|
|
||||||
use super::BrachCondition;
|
use super::BrachCondition;
|
||||||
use super::Directive2;
|
use super::Directive2;
|
||||||
use super::Function2;
|
use super::Function2;
|
||||||
|
@ -17,6 +19,178 @@ use rustc_hash::FxHashSet;
|
||||||
use std::hash::Hash;
|
use std::hash::Hash;
|
||||||
use std::iter;
|
use std::iter;
|
||||||
|
|
||||||
|
#[derive(Default)]
|
||||||
|
enum DenormalMode {
|
||||||
|
#[default]
|
||||||
|
FlushToZero,
|
||||||
|
Preserve,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl DenormalMode {
|
||||||
|
fn from_ftz(ftz: bool) -> Self {
|
||||||
|
if ftz {
|
||||||
|
DenormalMode::FlushToZero
|
||||||
|
} else {
|
||||||
|
DenormalMode::Preserve
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Default)]
|
||||||
|
enum RoundingMode {
|
||||||
|
#[default]
|
||||||
|
NearestEven,
|
||||||
|
Zero,
|
||||||
|
NegativeInf,
|
||||||
|
PositiveInf,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl RoundingMode {
|
||||||
|
fn to_ast(self) -> ast::RoundingMode {
|
||||||
|
match self {
|
||||||
|
RoundingMode::NearestEven => ast::RoundingMode::NearestEven,
|
||||||
|
RoundingMode::Zero => ast::RoundingMode::Zero,
|
||||||
|
RoundingMode::NegativeInf => ast::RoundingMode::NegativeInf,
|
||||||
|
RoundingMode::PositiveInf => ast::RoundingMode::PositiveInf,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn from_ast(rnd: ast::RoundingMode) -> Self {
|
||||||
|
match rnd {
|
||||||
|
ast::RoundingMode::NearestEven => RoundingMode::NearestEven,
|
||||||
|
ast::RoundingMode::Zero => RoundingMode::Zero,
|
||||||
|
ast::RoundingMode::NegativeInf => RoundingMode::NegativeInf,
|
||||||
|
ast::RoundingMode::PositiveInf => RoundingMode::PositiveInf,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct InstructionModes {
|
||||||
|
denormal_f32: Option<DenormalMode>,
|
||||||
|
denormal_f16_f64: Option<DenormalMode>,
|
||||||
|
rounding_f32: Option<RoundingMode>,
|
||||||
|
rounding_f16_f64: Option<RoundingMode>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl InstructionModes {
|
||||||
|
fn none() -> Self {
|
||||||
|
Self {
|
||||||
|
denormal_f32: None,
|
||||||
|
denormal_f16_f64: None,
|
||||||
|
rounding_f32: None,
|
||||||
|
rounding_f16_f64: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn new(
|
||||||
|
type_: ast::ScalarType,
|
||||||
|
denormal: Option<DenormalMode>,
|
||||||
|
rounding: Option<RoundingMode>,
|
||||||
|
) -> Self {
|
||||||
|
if type_ != ast::ScalarType::F32 {
|
||||||
|
Self {
|
||||||
|
denormal_f16_f64: denormal,
|
||||||
|
rounding_f16_f64: rounding,
|
||||||
|
..Self::none()
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
Self {
|
||||||
|
denormal_f32: denormal,
|
||||||
|
rounding_f32: rounding,
|
||||||
|
..Self::none()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn mixed_ftz_f32(
|
||||||
|
type_: ast::ScalarType,
|
||||||
|
denormal: Option<DenormalMode>,
|
||||||
|
rounding: Option<RoundingMode>,
|
||||||
|
) -> Self {
|
||||||
|
if type_ != ast::ScalarType::F32 {
|
||||||
|
Self {
|
||||||
|
denormal_f16_f64: denormal,
|
||||||
|
rounding_f32: rounding,
|
||||||
|
..Self::none()
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
Self {
|
||||||
|
denormal_f32: denormal,
|
||||||
|
rounding_f32: rounding,
|
||||||
|
..Self::none()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn from_arith_float(arith: &ast::ArithFloat) -> InstructionModes {
|
||||||
|
let denormal = arith.flush_to_zero.map(DenormalMode::from_ftz);
|
||||||
|
let rounding = Some(RoundingMode::from_ast(arith.rounding));
|
||||||
|
InstructionModes::new(arith.type_, denormal, rounding)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn from_ftz(type_: ast::ScalarType, ftz: Option<bool>) -> Self {
|
||||||
|
Self::new(type_, ftz.map(DenormalMode::from_ftz), None)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn from_ftz_f32(ftz: bool) -> Self {
|
||||||
|
Self::new(
|
||||||
|
ast::ScalarType::F32,
|
||||||
|
Some(DenormalMode::from_ftz(ftz)),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn from_rcp(data: ast::RcpData) -> InstructionModes {
|
||||||
|
let rounding = match data.kind {
|
||||||
|
ast::RcpKind::Approx => None,
|
||||||
|
ast::RcpKind::Compliant(rnd) => Some(RoundingMode::from_ast(rnd)),
|
||||||
|
};
|
||||||
|
let denormal = data.flush_to_zero.map(DenormalMode::from_ftz);
|
||||||
|
InstructionModes::new(data.type_, denormal, rounding)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn from_cvt(cvt: &ast::CvtDetails) -> InstructionModes {
|
||||||
|
match cvt.mode {
|
||||||
|
ast::CvtMode::ZeroExtend
|
||||||
|
| ast::CvtMode::SignExtend
|
||||||
|
| ast::CvtMode::Truncate
|
||||||
|
| ast::CvtMode::Bitcast
|
||||||
|
| ast::CvtMode::SaturateUnsignedToSigned
|
||||||
|
| ast::CvtMode::SaturateSignedToUnsigned => Self::none(),
|
||||||
|
ast::CvtMode::FPExtend { flush_to_zero } => {
|
||||||
|
Self::from_ftz(ast::ScalarType::F32, flush_to_zero)
|
||||||
|
}
|
||||||
|
ast::CvtMode::FPTruncate {
|
||||||
|
rounding,
|
||||||
|
flush_to_zero,
|
||||||
|
}
|
||||||
|
| ast::CvtMode::FPRound {
|
||||||
|
integer_rounding: rounding,
|
||||||
|
flush_to_zero,
|
||||||
|
} => Self::mixed_ftz_f32(
|
||||||
|
cvt.to,
|
||||||
|
flush_to_zero.map(DenormalMode::from_ftz),
|
||||||
|
Some(RoundingMode::from_ast(rounding)),
|
||||||
|
),
|
||||||
|
ast::CvtMode::SignedFromFP {
|
||||||
|
flush_to_zero,
|
||||||
|
rounding,
|
||||||
|
}
|
||||||
|
| ast::CvtMode::UnsignedFromFP {
|
||||||
|
flush_to_zero,
|
||||||
|
rounding,
|
||||||
|
} => Self::new(
|
||||||
|
cvt.from,
|
||||||
|
flush_to_zero.map(DenormalMode::from_ftz),
|
||||||
|
Some(RoundingMode::from_ast(rounding)),
|
||||||
|
),
|
||||||
|
ast::CvtMode::FPFromSigned(rnd) | ast::CvtMode::FPFromUnsigned(rnd) => {
|
||||||
|
Self::new(cvt.to, None, Some(RoundingMode::from_ast(rnd)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
struct ControlFlowGraph<T: Eq + PartialEq> {
|
struct ControlFlowGraph<T: Eq + PartialEq> {
|
||||||
entry_points: FxHashMap<SpirvWord, NodeIndex>,
|
entry_points: FxHashMap<SpirvWord, NodeIndex>,
|
||||||
basic_blocks: FxHashMap<SpirvWord, NodeIndex>,
|
basic_blocks: FxHashMap<SpirvWord, NodeIndex>,
|
||||||
|
@ -74,19 +248,40 @@ struct Node<T> {
|
||||||
|
|
||||||
pub(crate) fn run<'input>(
|
pub(crate) fn run<'input>(
|
||||||
flat_resolver: &mut super::GlobalStringIdentResolver2<'input>,
|
flat_resolver: &mut super::GlobalStringIdentResolver2<'input>,
|
||||||
directives: Vec<super::Directive2<'input, ast::Instruction<SpirvWord>, super::SpirvWord>>,
|
directives: Vec<super::Directive2<ast::Instruction<SpirvWord>, super::SpirvWord>>,
|
||||||
) -> Result<Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
|
) -> Result<Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
|
||||||
let mut cfg = ControlFlowGraph::<bool>::new();
|
let mut cfg = ControlFlowGraph::<bool>::new();
|
||||||
let mut node_idx_to_name = FxHashMap::<NodeIndex<u32>, SpirvWord>::default();
|
|
||||||
for directive in directives.iter() {
|
for directive in directives.iter() {
|
||||||
match directive {
|
match directive {
|
||||||
super::Directive2::Method(Function2 {
|
super::Directive2::Method(Function2 {
|
||||||
func_decl: ast::MethodDeclaration { name, .. },
|
name,
|
||||||
body,
|
body: Some(body),
|
||||||
..
|
..
|
||||||
}) => {
|
}) => {
|
||||||
|
let mut basic_block = Some(cfg.add_entry_basic_block(*name));
|
||||||
for statement in body.iter() {
|
for statement in body.iter() {
|
||||||
todo!()
|
match statement {
|
||||||
|
Statement::Instruction(ast::Instruction::Bra { arguments }) => {
|
||||||
|
let bb_index = basic_block.ok_or_else(error_unreachable)?;
|
||||||
|
cfg.add_jump(bb_index, arguments.src);
|
||||||
|
basic_block = None;
|
||||||
|
}
|
||||||
|
Statement::Label(label) => {
|
||||||
|
basic_block = Some(cfg.get_or_add_basic_block(*label));
|
||||||
|
}
|
||||||
|
Statement::Conditional(BrachCondition {
|
||||||
|
if_true, if_false, ..
|
||||||
|
}) => {
|
||||||
|
let bb_index = basic_block.ok_or_else(error_unreachable)?;
|
||||||
|
cfg.add_jump(bb_index, *if_true);
|
||||||
|
cfg.add_jump(bb_index, *if_false);
|
||||||
|
basic_block = None;
|
||||||
|
}
|
||||||
|
Statement::Instruction(instruction) => {
|
||||||
|
let modes = get_modes(instruction);
|
||||||
|
}
|
||||||
|
_ => continue,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
_ => continue,
|
_ => continue,
|
||||||
|
@ -280,6 +475,169 @@ impl<T: Copy + Eq + Hash> UniqueVec<T> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn get_modes<T: ast::Operand>(inst: &ast::Instruction<T>) -> InstructionModes {
|
||||||
|
match inst {
|
||||||
|
// TODO: review it when implementing virtual calls
|
||||||
|
ast::Instruction::Call { .. }
|
||||||
|
| ast::Instruction::Mov { .. }
|
||||||
|
| ast::Instruction::Ld { .. }
|
||||||
|
| ast::Instruction::St { .. }
|
||||||
|
| ast::Instruction::PrmtSlow { .. }
|
||||||
|
| ast::Instruction::Prmt { .. }
|
||||||
|
| ast::Instruction::Activemask { .. }
|
||||||
|
| ast::Instruction::Membar { .. }
|
||||||
|
| ast::Instruction::Trap {}
|
||||||
|
| ast::Instruction::Not { .. }
|
||||||
|
| ast::Instruction::Or { .. }
|
||||||
|
| ast::Instruction::And { .. }
|
||||||
|
| ast::Instruction::Bra { .. }
|
||||||
|
| ast::Instruction::Clz { .. }
|
||||||
|
| ast::Instruction::Brev { .. }
|
||||||
|
| ast::Instruction::Popc { .. }
|
||||||
|
| ast::Instruction::Xor { .. }
|
||||||
|
| ast::Instruction::Rem { .. }
|
||||||
|
| ast::Instruction::Bfe { .. }
|
||||||
|
| ast::Instruction::Bfi { .. }
|
||||||
|
| ast::Instruction::Shr { .. }
|
||||||
|
| ast::Instruction::Shl { .. }
|
||||||
|
| ast::Instruction::Selp { .. }
|
||||||
|
| ast::Instruction::Ret { .. }
|
||||||
|
| ast::Instruction::Bar { .. }
|
||||||
|
| ast::Instruction::Cvta { .. }
|
||||||
|
| ast::Instruction::Atom { .. }
|
||||||
|
| ast::Instruction::AtomCas { .. } => InstructionModes::none(),
|
||||||
|
ast::Instruction::Add {
|
||||||
|
data: ast::ArithDetails::Integer(_),
|
||||||
|
..
|
||||||
|
}
|
||||||
|
| ast::Instruction::Sub {
|
||||||
|
data: ast::ArithDetails::Integer(..),
|
||||||
|
..
|
||||||
|
}
|
||||||
|
| ast::Instruction::Mul {
|
||||||
|
data: ast::MulDetails::Integer { .. },
|
||||||
|
..
|
||||||
|
}
|
||||||
|
| ast::Instruction::Mad {
|
||||||
|
data: ast::MadDetails::Integer { .. },
|
||||||
|
..
|
||||||
|
}
|
||||||
|
| ast::Instruction::Min {
|
||||||
|
data: ast::MinMaxDetails::Signed(..) | ast::MinMaxDetails::Unsigned(..),
|
||||||
|
..
|
||||||
|
}
|
||||||
|
| ast::Instruction::Max {
|
||||||
|
data: ast::MinMaxDetails::Signed(..) | ast::MinMaxDetails::Unsigned(..),
|
||||||
|
..
|
||||||
|
}
|
||||||
|
| ast::Instruction::Div {
|
||||||
|
data: ast::DivDetails::Signed(..) | ast::DivDetails::Unsigned(..),
|
||||||
|
..
|
||||||
|
} => InstructionModes::none(),
|
||||||
|
ast::Instruction::Fma { data, .. }
|
||||||
|
| ast::Instruction::Sub {
|
||||||
|
data: ast::ArithDetails::Float(data),
|
||||||
|
..
|
||||||
|
}
|
||||||
|
| ast::Instruction::Mul {
|
||||||
|
data: ast::MulDetails::Float(data),
|
||||||
|
..
|
||||||
|
}
|
||||||
|
| ast::Instruction::Mad {
|
||||||
|
data: ast::MadDetails::Float(data),
|
||||||
|
..
|
||||||
|
}
|
||||||
|
| ast::Instruction::Add {
|
||||||
|
data: ast::ArithDetails::Float(data),
|
||||||
|
..
|
||||||
|
} => InstructionModes::from_arith_float(data),
|
||||||
|
ast::Instruction::Setp {
|
||||||
|
data:
|
||||||
|
ast::SetpData {
|
||||||
|
type_,
|
||||||
|
flush_to_zero,
|
||||||
|
..
|
||||||
|
},
|
||||||
|
..
|
||||||
|
}
|
||||||
|
| ast::Instruction::SetpBool {
|
||||||
|
data:
|
||||||
|
ast::SetpBoolData {
|
||||||
|
base:
|
||||||
|
ast::SetpData {
|
||||||
|
type_,
|
||||||
|
flush_to_zero,
|
||||||
|
..
|
||||||
|
},
|
||||||
|
..
|
||||||
|
},
|
||||||
|
..
|
||||||
|
}
|
||||||
|
| ast::Instruction::Neg {
|
||||||
|
data: ast::TypeFtz {
|
||||||
|
type_,
|
||||||
|
flush_to_zero,
|
||||||
|
},
|
||||||
|
..
|
||||||
|
}
|
||||||
|
| ast::Instruction::Ex2 {
|
||||||
|
data: ast::TypeFtz {
|
||||||
|
type_,
|
||||||
|
flush_to_zero,
|
||||||
|
},
|
||||||
|
..
|
||||||
|
}
|
||||||
|
| ast::Instruction::Rsqrt {
|
||||||
|
data: ast::TypeFtz {
|
||||||
|
type_,
|
||||||
|
flush_to_zero,
|
||||||
|
},
|
||||||
|
..
|
||||||
|
}
|
||||||
|
| ast::Instruction::Abs {
|
||||||
|
data: ast::TypeFtz {
|
||||||
|
type_,
|
||||||
|
flush_to_zero,
|
||||||
|
},
|
||||||
|
..
|
||||||
|
}
|
||||||
|
| ast::Instruction::Min {
|
||||||
|
data:
|
||||||
|
ast::MinMaxDetails::Float(ast::MinMaxFloat {
|
||||||
|
type_,
|
||||||
|
flush_to_zero,
|
||||||
|
..
|
||||||
|
}),
|
||||||
|
..
|
||||||
|
}
|
||||||
|
| ast::Instruction::Max {
|
||||||
|
data:
|
||||||
|
ast::MinMaxDetails::Float(ast::MinMaxFloat {
|
||||||
|
type_,
|
||||||
|
flush_to_zero,
|
||||||
|
..
|
||||||
|
}),
|
||||||
|
..
|
||||||
|
}
|
||||||
|
| ast::Instruction::Div {
|
||||||
|
data:
|
||||||
|
ast::DivDetails::Float(ast::DivFloatDetails {
|
||||||
|
type_,
|
||||||
|
flush_to_zero,
|
||||||
|
..
|
||||||
|
}),
|
||||||
|
..
|
||||||
|
} => InstructionModes::from_ftz(*type_, *flush_to_zero),
|
||||||
|
ast::Instruction::Sin { data, .. }
|
||||||
|
| ast::Instruction::Cos { data, .. }
|
||||||
|
| ast::Instruction::Lg2 { data, .. } => InstructionModes::from_ftz_f32(data.flush_to_zero),
|
||||||
|
ast::Instruction::Rcp { data, .. } | ast::Instruction::Sqrt { data, .. } => {
|
||||||
|
InstructionModes::from_rcp(*data)
|
||||||
|
}
|
||||||
|
ast::Instruction::Cvt { data, .. } => InstructionModes::from_cvt(data),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
|
@ -19,8 +19,8 @@ use ptx_parser as ast;
|
||||||
*/
|
*/
|
||||||
pub(super) fn run<'input>(
|
pub(super) fn run<'input>(
|
||||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||||
directives: Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>,
|
directives: Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>,
|
||||||
) -> Result<Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
|
) -> Result<Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
|
||||||
directives
|
directives
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.map(|directive| run_directive(resolver, directive))
|
.map(|directive| run_directive(resolver, directive))
|
||||||
|
@ -29,8 +29,8 @@ pub(super) fn run<'input>(
|
||||||
|
|
||||||
fn run_directive<'a, 'input>(
|
fn run_directive<'a, 'input>(
|
||||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||||
directive: Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>,
|
directive: Directive2<ast::Instruction<SpirvWord>, SpirvWord>,
|
||||||
) -> Result<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
|
) -> Result<Directive2<ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
|
||||||
Ok(match directive {
|
Ok(match directive {
|
||||||
var @ Directive2::Variable(..) => var,
|
var @ Directive2::Variable(..) => var,
|
||||||
Directive2::Method(mut method) => {
|
Directive2::Method(mut method) => {
|
||||||
|
|
|
@ -44,7 +44,7 @@ pub fn to_llvm_module<'input>(ast: ast::Module<'input>) -> Result<Module, Transl
|
||||||
let mut scoped_resolver = ScopedResolver::new(&mut flat_resolver);
|
let mut scoped_resolver = ScopedResolver::new(&mut flat_resolver);
|
||||||
let sreg_map = SpecialRegistersMap2::new(&mut scoped_resolver)?;
|
let sreg_map = SpecialRegistersMap2::new(&mut scoped_resolver)?;
|
||||||
let directives = normalize_identifiers2::run(&mut scoped_resolver, ast.directives)?;
|
let directives = normalize_identifiers2::run(&mut scoped_resolver, ast.directives)?;
|
||||||
let directives = replace_known_functions::run(&flat_resolver, directives);
|
let directives = replace_known_functions::run(&mut flat_resolver, directives);
|
||||||
let directives = normalize_predicates2::run(&mut flat_resolver, directives)?;
|
let directives = normalize_predicates2::run(&mut flat_resolver, directives)?;
|
||||||
let directives = resolve_function_pointers::run(directives)?;
|
let directives = resolve_function_pointers::run(directives)?;
|
||||||
let directives = fix_special_registers2::run(&mut flat_resolver, &sreg_map, directives)?;
|
let directives = fix_special_registers2::run(&mut flat_resolver, &sreg_map, directives)?;
|
||||||
|
@ -559,22 +559,23 @@ type NormalizedStatement = Statement<
|
||||||
ast::ParsedOperand<SpirvWord>,
|
ast::ParsedOperand<SpirvWord>,
|
||||||
>;
|
>;
|
||||||
|
|
||||||
enum Directive2<'input, Instruction, Operand: ast::Operand> {
|
enum Directive2<Instruction, Operand: ast::Operand> {
|
||||||
Variable(ast::LinkingDirective, ast::Variable<SpirvWord>),
|
Variable(ast::LinkingDirective, ast::Variable<SpirvWord>),
|
||||||
Method(Function2<'input, Instruction, Operand>),
|
Method(Function2<Instruction, Operand>),
|
||||||
}
|
}
|
||||||
|
|
||||||
struct Function2<'input, Instruction, Operand: ast::Operand> {
|
struct Function2<Instruction, Operand: ast::Operand> {
|
||||||
pub func_decl: ast::MethodDeclaration<'input, SpirvWord>,
|
pub return_arguments: Vec<ast::Variable<Operand::Ident>>,
|
||||||
pub globals: Vec<ast::Variable<SpirvWord>>,
|
pub name: Operand::Ident,
|
||||||
|
pub input_arguments: Vec<ast::Variable<Operand::Ident>>,
|
||||||
pub body: Option<Vec<Statement<Instruction, Operand>>>,
|
pub body: Option<Vec<Statement<Instruction, Operand>>>,
|
||||||
|
is_kernel: bool,
|
||||||
import_as: Option<String>,
|
import_as: Option<String>,
|
||||||
tuning: Vec<ast::TuningDirective>,
|
tuning: Vec<ast::TuningDirective>,
|
||||||
linkage: ast::LinkingDirective,
|
linkage: ast::LinkingDirective,
|
||||||
}
|
}
|
||||||
|
|
||||||
type NormalizedDirective2<'input> = Directive2<
|
type NormalizedDirective2 = Directive2<
|
||||||
'input,
|
|
||||||
(
|
(
|
||||||
Option<ast::PredAt<SpirvWord>>,
|
Option<ast::PredAt<SpirvWord>>,
|
||||||
ast::Instruction<ast::ParsedOperand<SpirvWord>>,
|
ast::Instruction<ast::ParsedOperand<SpirvWord>>,
|
||||||
|
@ -582,8 +583,7 @@ type NormalizedDirective2<'input> = Directive2<
|
||||||
ast::ParsedOperand<SpirvWord>,
|
ast::ParsedOperand<SpirvWord>,
|
||||||
>;
|
>;
|
||||||
|
|
||||||
type NormalizedFunction2<'input> = Function2<
|
type NormalizedFunction2 = Function2<
|
||||||
'input,
|
|
||||||
(
|
(
|
||||||
Option<ast::PredAt<SpirvWord>>,
|
Option<ast::PredAt<SpirvWord>>,
|
||||||
ast::Instruction<ast::ParsedOperand<SpirvWord>>,
|
ast::Instruction<ast::ParsedOperand<SpirvWord>>,
|
||||||
|
@ -591,17 +591,11 @@ type NormalizedFunction2<'input> = Function2<
|
||||||
ast::ParsedOperand<SpirvWord>,
|
ast::ParsedOperand<SpirvWord>,
|
||||||
>;
|
>;
|
||||||
|
|
||||||
type UnconditionalDirective<'input> = Directive2<
|
type UnconditionalDirective =
|
||||||
'input,
|
Directive2<ast::Instruction<ast::ParsedOperand<SpirvWord>>, ast::ParsedOperand<SpirvWord>>;
|
||||||
ast::Instruction<ast::ParsedOperand<SpirvWord>>,
|
|
||||||
ast::ParsedOperand<SpirvWord>,
|
|
||||||
>;
|
|
||||||
|
|
||||||
type UnconditionalFunction<'input> = Function2<
|
type UnconditionalFunction =
|
||||||
'input,
|
Function2<ast::Instruction<ast::ParsedOperand<SpirvWord>>, ast::ParsedOperand<SpirvWord>>;
|
||||||
ast::Instruction<ast::ParsedOperand<SpirvWord>>,
|
|
||||||
ast::ParsedOperand<SpirvWord>,
|
|
||||||
>;
|
|
||||||
|
|
||||||
struct GlobalStringIdentResolver2<'input> {
|
struct GlobalStringIdentResolver2<'input> {
|
||||||
pub(crate) current_id: SpirvWord,
|
pub(crate) current_id: SpirvWord,
|
||||||
|
@ -807,47 +801,45 @@ impl SpecialRegistersMap2 {
|
||||||
self.id_to_reg.get(&id).copied()
|
self.id_to_reg.get(&id).copied()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn generate_declarations<'a, 'input>(
|
fn len() -> usize {
|
||||||
|
PtxSpecialRegister::iter().len()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn foreach_declaration<'a, 'input>(
|
||||||
resolver: &'a mut GlobalStringIdentResolver2<'input>,
|
resolver: &'a mut GlobalStringIdentResolver2<'input>,
|
||||||
) -> impl ExactSizeIterator<
|
mut fn_: impl FnMut(
|
||||||
Item = (
|
|
||||||
PtxSpecialRegister,
|
PtxSpecialRegister,
|
||||||
ast::MethodDeclaration<'input, SpirvWord>,
|
(
|
||||||
|
Vec<ast::Variable<SpirvWord>>,
|
||||||
|
SpirvWord,
|
||||||
|
Vec<ast::Variable<SpirvWord>>,
|
||||||
|
),
|
||||||
),
|
),
|
||||||
> + 'a {
|
) {
|
||||||
PtxSpecialRegister::iter().map(|sreg| {
|
for sreg in PtxSpecialRegister::iter() {
|
||||||
let external_fn_name = [ZLUDA_PTX_PREFIX, sreg.get_unprefixed_function_name()].concat();
|
let external_fn_name = [ZLUDA_PTX_PREFIX, sreg.get_unprefixed_function_name()].concat();
|
||||||
let name =
|
let name = resolver.register_named(Cow::Owned(external_fn_name), None);
|
||||||
ast::MethodName::Func(resolver.register_named(Cow::Owned(external_fn_name), None));
|
|
||||||
let return_type = sreg.get_function_return_type();
|
let return_type = sreg.get_function_return_type();
|
||||||
let input_type = sreg.get_function_input_type();
|
let input_type = sreg.get_function_input_type();
|
||||||
(
|
let return_arguments = vec![ast::Variable {
|
||||||
sreg,
|
align: None,
|
||||||
ast::MethodDeclaration {
|
v_type: return_type.into(),
|
||||||
return_arguments: vec![ast::Variable {
|
state_space: ast::StateSpace::Reg,
|
||||||
align: None,
|
name: resolver.register_unnamed(Some((return_type.into(), ast::StateSpace::Reg))),
|
||||||
v_type: return_type.into(),
|
array_init: Vec::new(),
|
||||||
state_space: ast::StateSpace::Reg,
|
}];
|
||||||
name: resolver
|
let input_arguments = input_type
|
||||||
.register_unnamed(Some((return_type.into(), ast::StateSpace::Reg))),
|
.into_iter()
|
||||||
array_init: Vec::new(),
|
.map(|type_| ast::Variable {
|
||||||
}],
|
align: None,
|
||||||
name: name,
|
v_type: type_.into(),
|
||||||
input_arguments: input_type
|
state_space: ast::StateSpace::Reg,
|
||||||
.into_iter()
|
name: resolver.register_unnamed(Some((type_.into(), ast::StateSpace::Reg))),
|
||||||
.map(|type_| ast::Variable {
|
array_init: Vec::new(),
|
||||||
align: None,
|
})
|
||||||
v_type: type_.into(),
|
.collect::<Vec<_>>();
|
||||||
state_space: ast::StateSpace::Reg,
|
fn_(sreg, (return_arguments, name, input_arguments));
|
||||||
name: resolver
|
}
|
||||||
.register_unnamed(Some((type_.into(), ast::StateSpace::Reg))),
|
|
||||||
array_init: Vec::new(),
|
|
||||||
})
|
|
||||||
.collect::<Vec<_>>(),
|
|
||||||
shared_mem: None,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -4,7 +4,7 @@ use ptx_parser as ast;
|
||||||
pub(crate) fn run<'input, 'b>(
|
pub(crate) fn run<'input, 'b>(
|
||||||
resolver: &mut ScopedResolver<'input, 'b>,
|
resolver: &mut ScopedResolver<'input, 'b>,
|
||||||
directives: Vec<ast::Directive<'input, ast::ParsedOperand<&'input str>>>,
|
directives: Vec<ast::Directive<'input, ast::ParsedOperand<&'input str>>>,
|
||||||
) -> Result<Vec<NormalizedDirective2<'input>>, TranslateError> {
|
) -> Result<Vec<NormalizedDirective2>, TranslateError> {
|
||||||
resolver.start_scope();
|
resolver.start_scope();
|
||||||
let result = directives
|
let result = directives
|
||||||
.into_iter()
|
.into_iter()
|
||||||
|
@ -17,7 +17,7 @@ pub(crate) fn run<'input, 'b>(
|
||||||
fn run_directive<'input, 'b>(
|
fn run_directive<'input, 'b>(
|
||||||
resolver: &mut ScopedResolver<'input, 'b>,
|
resolver: &mut ScopedResolver<'input, 'b>,
|
||||||
directive: ast::Directive<'input, ast::ParsedOperand<&'input str>>,
|
directive: ast::Directive<'input, ast::ParsedOperand<&'input str>>,
|
||||||
) -> Result<NormalizedDirective2<'input>, TranslateError> {
|
) -> Result<NormalizedDirective2, TranslateError> {
|
||||||
Ok(match directive {
|
Ok(match directive {
|
||||||
ast::Directive::Variable(linking, var) => {
|
ast::Directive::Variable(linking, var) => {
|
||||||
NormalizedDirective2::Variable(linking, run_variable(resolver, var)?)
|
NormalizedDirective2::Variable(linking, run_variable(resolver, var)?)
|
||||||
|
@ -32,15 +32,11 @@ fn run_method<'input, 'b>(
|
||||||
resolver: &mut ScopedResolver<'input, 'b>,
|
resolver: &mut ScopedResolver<'input, 'b>,
|
||||||
linkage: ast::LinkingDirective,
|
linkage: ast::LinkingDirective,
|
||||||
method: ast::Function<'input, &'input str, ast::Statement<ast::ParsedOperand<&'input str>>>,
|
method: ast::Function<'input, &'input str, ast::Statement<ast::ParsedOperand<&'input str>>>,
|
||||||
) -> Result<NormalizedFunction2<'input>, TranslateError> {
|
) -> Result<NormalizedFunction2, TranslateError> {
|
||||||
let name = match method.func_directive.name {
|
let is_kernel = method.func_directive.name.is_kernel();
|
||||||
ast::MethodName::Kernel(name) => ast::MethodName::Kernel(name),
|
let name = resolver.add_or_get_in_current_scope_untyped(method.func_directive.name.text())?;
|
||||||
ast::MethodName::Func(text) => {
|
|
||||||
ast::MethodName::Func(resolver.add_or_get_in_current_scope_untyped(text)?)
|
|
||||||
}
|
|
||||||
};
|
|
||||||
resolver.start_scope();
|
resolver.start_scope();
|
||||||
let func_decl = run_function_decl(resolver, method.func_directive, name)?;
|
let (return_arguments, input_arguments) = run_function_decl(resolver, method.func_directive)?;
|
||||||
let body = method
|
let body = method
|
||||||
.body
|
.body
|
||||||
.map(|statements| {
|
.map(|statements| {
|
||||||
|
@ -51,20 +47,21 @@ fn run_method<'input, 'b>(
|
||||||
.transpose()?;
|
.transpose()?;
|
||||||
resolver.end_scope();
|
resolver.end_scope();
|
||||||
Ok(Function2 {
|
Ok(Function2 {
|
||||||
func_decl,
|
return_arguments,
|
||||||
globals: Vec::new(),
|
name,
|
||||||
|
input_arguments,
|
||||||
body,
|
body,
|
||||||
import_as: None,
|
import_as: None,
|
||||||
tuning: method.tuning,
|
tuning: method.tuning,
|
||||||
linkage,
|
linkage,
|
||||||
|
is_kernel,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn run_function_decl<'input, 'b>(
|
fn run_function_decl<'input, 'b>(
|
||||||
resolver: &mut ScopedResolver<'input, 'b>,
|
resolver: &mut ScopedResolver<'input, 'b>,
|
||||||
func_directive: ast::MethodDeclaration<'input, &'input str>,
|
func_directive: ast::MethodDeclaration<'input, &'input str>,
|
||||||
name: ast::MethodName<'input, SpirvWord>,
|
) -> Result<(Vec<ast::Variable<SpirvWord>>, Vec<ast::Variable<SpirvWord>>), TranslateError> {
|
||||||
) -> Result<ast::MethodDeclaration<'input, SpirvWord>, TranslateError> {
|
|
||||||
assert!(func_directive.shared_mem.is_none());
|
assert!(func_directive.shared_mem.is_none());
|
||||||
let return_arguments = func_directive
|
let return_arguments = func_directive
|
||||||
.return_arguments
|
.return_arguments
|
||||||
|
@ -76,12 +73,7 @@ fn run_function_decl<'input, 'b>(
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.map(|var| run_variable(resolver, var))
|
.map(|var| run_variable(resolver, var))
|
||||||
.collect::<Result<Vec<_>, _>>()?;
|
.collect::<Result<Vec<_>, _>>()?;
|
||||||
Ok(ast::MethodDeclaration {
|
Ok((return_arguments, input_arguments))
|
||||||
return_arguments,
|
|
||||||
name,
|
|
||||||
input_arguments,
|
|
||||||
shared_mem: None,
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn run_variable<'input, 'b>(
|
fn run_variable<'input, 'b>(
|
||||||
|
|
|
@ -3,8 +3,8 @@ use ptx_parser as ast;
|
||||||
|
|
||||||
pub(crate) fn run<'input>(
|
pub(crate) fn run<'input>(
|
||||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||||
directives: Vec<NormalizedDirective2<'input>>,
|
directives: Vec<NormalizedDirective2>,
|
||||||
) -> Result<Vec<UnconditionalDirective<'input>>, TranslateError> {
|
) -> Result<Vec<UnconditionalDirective>, TranslateError> {
|
||||||
directives
|
directives
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.map(|directive| run_directive(resolver, directive))
|
.map(|directive| run_directive(resolver, directive))
|
||||||
|
@ -13,8 +13,8 @@ pub(crate) fn run<'input>(
|
||||||
|
|
||||||
fn run_directive<'input>(
|
fn run_directive<'input>(
|
||||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||||
directive: NormalizedDirective2<'input>,
|
directive: NormalizedDirective2,
|
||||||
) -> Result<UnconditionalDirective<'input>, TranslateError> {
|
) -> Result<UnconditionalDirective, TranslateError> {
|
||||||
Ok(match directive {
|
Ok(match directive {
|
||||||
Directive2::Variable(linking, var) => Directive2::Variable(linking, var),
|
Directive2::Variable(linking, var) => Directive2::Variable(linking, var),
|
||||||
Directive2::Method(method) => Directive2::Method(run_method(resolver, method)?),
|
Directive2::Method(method) => Directive2::Method(run_method(resolver, method)?),
|
||||||
|
@ -23,8 +23,8 @@ fn run_directive<'input>(
|
||||||
|
|
||||||
fn run_method<'input>(
|
fn run_method<'input>(
|
||||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||||
method: NormalizedFunction2<'input>,
|
method: NormalizedFunction2,
|
||||||
) -> Result<UnconditionalFunction<'input>, TranslateError> {
|
) -> Result<UnconditionalFunction, TranslateError> {
|
||||||
let body = method
|
let body = method
|
||||||
.body
|
.body
|
||||||
.map(|statements| {
|
.map(|statements| {
|
||||||
|
@ -36,12 +36,14 @@ fn run_method<'input>(
|
||||||
})
|
})
|
||||||
.transpose()?;
|
.transpose()?;
|
||||||
Ok(Function2 {
|
Ok(Function2 {
|
||||||
func_decl: method.func_decl,
|
return_arguments: method.return_arguments,
|
||||||
globals: method.globals,
|
name: method.name,
|
||||||
|
input_arguments: method.input_arguments,
|
||||||
body,
|
body,
|
||||||
import_as: method.import_as,
|
import_as: method.import_as,
|
||||||
tuning: method.tuning,
|
tuning: method.tuning,
|
||||||
linkage: method.linkage,
|
linkage: method.linkage,
|
||||||
|
is_kernel: method.is_kernel,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -2,8 +2,8 @@ use super::*;
|
||||||
|
|
||||||
pub(super) fn run<'input>(
|
pub(super) fn run<'input>(
|
||||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||||
directives: Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>,
|
directives: Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>,
|
||||||
) -> Result<Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
|
) -> Result<Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
|
||||||
let mut fn_declarations = FxHashMap::default();
|
let mut fn_declarations = FxHashMap::default();
|
||||||
let remapped_directives = directives
|
let remapped_directives = directives
|
||||||
.into_iter()
|
.into_iter()
|
||||||
|
@ -13,17 +13,14 @@ pub(super) fn run<'input>(
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.map(|(_, (return_arguments, name, input_arguments))| {
|
.map(|(_, (return_arguments, name, input_arguments))| {
|
||||||
Directive2::Method(Function2 {
|
Directive2::Method(Function2 {
|
||||||
func_decl: ast::MethodDeclaration {
|
return_arguments,
|
||||||
return_arguments,
|
name: name,
|
||||||
name: ast::MethodName::Func(name),
|
input_arguments,
|
||||||
input_arguments,
|
|
||||||
shared_mem: None,
|
|
||||||
},
|
|
||||||
globals: Vec::new(),
|
|
||||||
body: None,
|
body: None,
|
||||||
import_as: None,
|
import_as: None,
|
||||||
tuning: Vec::new(),
|
tuning: Vec::new(),
|
||||||
linkage: ast::LinkingDirective::EXTERN,
|
linkage: ast::LinkingDirective::EXTERN,
|
||||||
|
is_kernel: false,
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
.collect::<Vec<_>>();
|
.collect::<Vec<_>>();
|
||||||
|
@ -41,8 +38,8 @@ fn run_directive<'input>(
|
||||||
Vec<ast::Variable<SpirvWord>>,
|
Vec<ast::Variable<SpirvWord>>,
|
||||||
),
|
),
|
||||||
>,
|
>,
|
||||||
directive: Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>,
|
directive: Directive2<ast::Instruction<SpirvWord>, SpirvWord>,
|
||||||
) -> Result<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
|
) -> Result<Directive2<ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
|
||||||
Ok(match directive {
|
Ok(match directive {
|
||||||
var @ Directive2::Variable(..) => var,
|
var @ Directive2::Variable(..) => var,
|
||||||
Directive2::Method(mut method) => {
|
Directive2::Method(mut method) => {
|
||||||
|
|
|
@ -1,14 +1,15 @@
|
||||||
|
use std::borrow::Cow;
|
||||||
|
|
||||||
use super::{GlobalStringIdentResolver2, NormalizedDirective2, SpirvWord};
|
use super::{GlobalStringIdentResolver2, NormalizedDirective2, SpirvWord};
|
||||||
|
|
||||||
pub(crate) fn run<'input>(
|
pub(crate) fn run<'input>(
|
||||||
resolver: &GlobalStringIdentResolver2<'input>,
|
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||||
mut directives: Vec<NormalizedDirective2<'input>>,
|
mut directives: Vec<NormalizedDirective2>,
|
||||||
) -> Vec<NormalizedDirective2<'input>> {
|
) -> Vec<NormalizedDirective2> {
|
||||||
for directive in directives.iter_mut() {
|
for directive in directives.iter_mut() {
|
||||||
match directive {
|
match directive {
|
||||||
NormalizedDirective2::Method(func) => {
|
NormalizedDirective2::Method(func) => {
|
||||||
func.import_as =
|
replace_with_ptx_impl(resolver, func.name);
|
||||||
replace_with_ptx_impl(resolver, &func.func_decl.name, func.import_as.take());
|
|
||||||
}
|
}
|
||||||
_ => {}
|
_ => {}
|
||||||
}
|
}
|
||||||
|
@ -17,22 +18,16 @@ pub(crate) fn run<'input>(
|
||||||
}
|
}
|
||||||
|
|
||||||
fn replace_with_ptx_impl<'input>(
|
fn replace_with_ptx_impl<'input>(
|
||||||
resolver: &GlobalStringIdentResolver2<'input>,
|
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||||
fn_name: &ptx_parser::MethodName<'input, SpirvWord>,
|
fn_name: SpirvWord,
|
||||||
name: Option<String>,
|
) {
|
||||||
) -> Option<String> {
|
|
||||||
let known_names = ["__assertfail"];
|
let known_names = ["__assertfail"];
|
||||||
match name {
|
if let Some(super::IdentEntry {
|
||||||
Some(name) if known_names.contains(&&*name) => Some(format!("__zluda_ptx_impl_{}", name)),
|
name: Some(name), ..
|
||||||
Some(name) => Some(name),
|
}) = resolver.ident_map.get_mut(&fn_name)
|
||||||
None => match fn_name {
|
{
|
||||||
ptx_parser::MethodName::Func(name) => match resolver.ident_map.get(name) {
|
if known_names.contains(&&**name) {
|
||||||
Some(super::IdentEntry {
|
*name = Cow::Owned(format!("__zluda_ptx_impl_{}", name));
|
||||||
name: Some(name), ..
|
}
|
||||||
}) => Some(format!("__zluda_ptx_impl_{}", name)),
|
|
||||||
_ => None,
|
|
||||||
},
|
|
||||||
ptx_parser::MethodName::Kernel(..) => None,
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,8 +3,8 @@ use ptx_parser as ast;
|
||||||
use rustc_hash::FxHashSet;
|
use rustc_hash::FxHashSet;
|
||||||
|
|
||||||
pub(crate) fn run<'input>(
|
pub(crate) fn run<'input>(
|
||||||
directives: Vec<UnconditionalDirective<'input>>,
|
directives: Vec<UnconditionalDirective>,
|
||||||
) -> Result<Vec<UnconditionalDirective<'input>>, TranslateError> {
|
) -> Result<Vec<UnconditionalDirective>, TranslateError> {
|
||||||
let mut functions = FxHashSet::default();
|
let mut functions = FxHashSet::default();
|
||||||
directives
|
directives
|
||||||
.into_iter()
|
.into_iter()
|
||||||
|
@ -14,19 +14,13 @@ pub(crate) fn run<'input>(
|
||||||
|
|
||||||
fn run_directive<'input>(
|
fn run_directive<'input>(
|
||||||
functions: &mut FxHashSet<SpirvWord>,
|
functions: &mut FxHashSet<SpirvWord>,
|
||||||
directive: UnconditionalDirective<'input>,
|
directive: UnconditionalDirective,
|
||||||
) -> Result<UnconditionalDirective<'input>, TranslateError> {
|
) -> Result<UnconditionalDirective, TranslateError> {
|
||||||
Ok(match directive {
|
Ok(match directive {
|
||||||
var @ Directive2::Variable(..) => var,
|
var @ Directive2::Variable(..) => var,
|
||||||
Directive2::Method(method) => {
|
Directive2::Method(method) => {
|
||||||
{
|
if !method.is_kernel {
|
||||||
let func_decl = &method.func_decl;
|
functions.insert(method.name);
|
||||||
match func_decl.name {
|
|
||||||
ptx_parser::MethodName::Kernel(_) => {}
|
|
||||||
ptx_parser::MethodName::Func(name) => {
|
|
||||||
functions.insert(name);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
Directive2::Method(run_method(functions, method)?)
|
Directive2::Method(run_method(functions, method)?)
|
||||||
}
|
}
|
||||||
|
@ -35,8 +29,8 @@ fn run_directive<'input>(
|
||||||
|
|
||||||
fn run_method<'input>(
|
fn run_method<'input>(
|
||||||
functions: &mut FxHashSet<SpirvWord>,
|
functions: &mut FxHashSet<SpirvWord>,
|
||||||
method: UnconditionalFunction<'input>,
|
method: UnconditionalFunction,
|
||||||
) -> Result<UnconditionalFunction<'input>, TranslateError> {
|
) -> Result<UnconditionalFunction, TranslateError> {
|
||||||
let body = method
|
let body = method
|
||||||
.body
|
.body
|
||||||
.map(|statements| {
|
.map(|statements| {
|
||||||
|
@ -47,12 +41,14 @@ fn run_method<'input>(
|
||||||
})
|
})
|
||||||
.transpose()?;
|
.transpose()?;
|
||||||
Ok(Function2 {
|
Ok(Function2 {
|
||||||
func_decl: method.func_decl,
|
return_arguments: method.return_arguments,
|
||||||
globals: method.globals,
|
name: method.name,
|
||||||
|
input_arguments: method.input_arguments,
|
||||||
body,
|
body,
|
||||||
import_as: method.import_as,
|
import_as: method.import_as,
|
||||||
tuning: method.tuning,
|
tuning: method.tuning,
|
||||||
linkage: method.linkage,
|
linkage: method.linkage,
|
||||||
|
is_kernel: method.is_kernel,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1028,7 +1028,7 @@ pub struct ArithFloat {
|
||||||
// round-to-nearest-even and may be optimized aggressively by the code optimizer. In particular,
|
// round-to-nearest-even and may be optimized aggressively by the code optimizer. In particular,
|
||||||
// mul/add sequences with no rounding modifiers may be optimized to use fused-multiply-add
|
// mul/add sequences with no rounding modifiers may be optimized to use fused-multiply-add
|
||||||
// instructions on the target device.
|
// instructions on the target device.
|
||||||
pub is_fusable: bool
|
pub is_fusable: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Copy, Clone, PartialEq, Eq)]
|
#[derive(Copy, Clone, PartialEq, Eq)]
|
||||||
|
@ -1447,6 +1447,7 @@ pub struct CvtDetails {
|
||||||
pub mode: CvtMode,
|
pub mode: CvtMode,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Copy)]
|
||||||
pub enum CvtMode {
|
pub enum CvtMode {
|
||||||
// int from int
|
// int from int
|
||||||
ZeroExtend,
|
ZeroExtend,
|
||||||
|
@ -1465,7 +1466,7 @@ pub enum CvtMode {
|
||||||
flush_to_zero: Option<bool>,
|
flush_to_zero: Option<bool>,
|
||||||
},
|
},
|
||||||
FPRound {
|
FPRound {
|
||||||
integer_rounding: Option<RoundingMode>,
|
integer_rounding: RoundingMode,
|
||||||
flush_to_zero: Option<bool>,
|
flush_to_zero: Option<bool>,
|
||||||
},
|
},
|
||||||
// int from float
|
// int from float
|
||||||
|
@ -1519,7 +1520,7 @@ impl CvtDetails {
|
||||||
flush_to_zero,
|
flush_to_zero,
|
||||||
},
|
},
|
||||||
Ordering::Equal => CvtMode::FPRound {
|
Ordering::Equal => CvtMode::FPRound {
|
||||||
integer_rounding: rounding,
|
integer_rounding: rounding.unwrap_or(RoundingMode::NearestEven),
|
||||||
flush_to_zero,
|
flush_to_zero,
|
||||||
},
|
},
|
||||||
Ordering::Greater => {
|
Ordering::Greater => {
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue