Continue working on ftz modes

This commit is contained in:
Andrzej Janik 2025-02-18 02:42:17 +00:00
parent 17529f951d
commit 5121bba285
15 changed files with 559 additions and 226 deletions

View file

@ -2,8 +2,8 @@ use super::*;
pub(super) fn run<'a, 'input>( pub(super) fn run<'a, 'input>(
resolver: &mut GlobalStringIdentResolver2<'input>, resolver: &mut GlobalStringIdentResolver2<'input>,
directives: Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>, directives: Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>,
) -> Result<Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> { ) -> Result<Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
directives directives
.into_iter() .into_iter()
.map(|directive| run_directive(resolver, directive)) .map(|directive| run_directive(resolver, directive))
@ -12,8 +12,8 @@ pub(super) fn run<'a, 'input>(
fn run_directive<'input>( fn run_directive<'input>(
resolver: &mut GlobalStringIdentResolver2, resolver: &mut GlobalStringIdentResolver2,
directive: Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>, directive: Directive2<ast::Instruction<SpirvWord>, SpirvWord>,
) -> Result<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> { ) -> Result<Directive2<ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
Ok(match directive { Ok(match directive {
var @ Directive2::Variable(..) => var, var @ Directive2::Variable(..) => var,
Directive2::Method(method) => Directive2::Method(run_method(resolver, method)?), Directive2::Method(method) => Directive2::Method(run_method(resolver, method)?),
@ -22,13 +22,13 @@ fn run_directive<'input>(
fn run_method<'input>( fn run_method<'input>(
resolver: &mut GlobalStringIdentResolver2, resolver: &mut GlobalStringIdentResolver2,
mut method: Function2<'input, ast::Instruction<SpirvWord>, SpirvWord>, mut method: Function2<ast::Instruction<SpirvWord>, SpirvWord>,
) -> Result<Function2<'input, ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> { ) -> Result<Function2<ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
let is_declaration = method.body.is_none(); let is_declaration = method.body.is_none();
let mut body = Vec::new(); let mut body = Vec::new();
let mut remap_returns = Vec::new(); let mut remap_returns = Vec::new();
if !method.func_decl.name.is_kernel() { if !method.is_kernel {
for arg in method.func_decl.return_arguments.iter_mut() { for arg in method.return_arguments.iter_mut() {
match arg.state_space { match arg.state_space {
ptx_parser::StateSpace::Param => { ptx_parser::StateSpace::Param => {
arg.state_space = ptx_parser::StateSpace::Reg; arg.state_space = ptx_parser::StateSpace::Reg;
@ -51,7 +51,7 @@ fn run_method<'input>(
_ => return Err(error_unreachable()), _ => return Err(error_unreachable()),
} }
} }
for arg in method.func_decl.input_arguments.iter_mut() { for arg in method.input_arguments.iter_mut() {
match arg.state_space { match arg.state_space {
ptx_parser::StateSpace::Param => { ptx_parser::StateSpace::Param => {
arg.state_space = ptx_parser::StateSpace::Reg; arg.state_space = ptx_parser::StateSpace::Reg;
@ -96,12 +96,14 @@ fn run_method<'input>(
}) })
.transpose()?; .transpose()?;
Ok(Function2 { Ok(Function2 {
func_decl: method.func_decl, return_arguments: method.return_arguments,
globals: method.globals, name: method.name,
input_arguments: method.input_arguments,
body, body,
import_as: method.import_as, import_as: method.import_as,
tuning: method.tuning, tuning: method.tuning,
linkage: method.linkage, linkage: method.linkage,
is_kernel: method.is_kernel,
}) })
} }

View file

@ -168,7 +168,7 @@ impl Deref for MemoryBuffer {
pub(super) fn run<'input>( pub(super) fn run<'input>(
id_defs: GlobalStringIdentResolver2<'input>, id_defs: GlobalStringIdentResolver2<'input>,
directives: Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>, directives: Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>,
) -> Result<MemoryBuffer, TranslateError> { ) -> Result<MemoryBuffer, TranslateError> {
let context = Context::new(); let context = Context::new();
let module = Module::new(&context, LLVM_UNNAMED); let module = Module::new(&context, LLVM_UNNAMED);
@ -218,24 +218,20 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> {
fn emit_method( fn emit_method(
&mut self, &mut self,
method: Function2<'input, ast::Instruction<SpirvWord>, SpirvWord>, method: Function2<ast::Instruction<SpirvWord>, SpirvWord>,
) -> Result<(), TranslateError> { ) -> Result<(), TranslateError> {
let func_decl = method.func_decl;
let name = method let name = method
.import_as .import_as
.as_deref() .as_deref()
.or_else(|| match func_decl.name { .or_else(|| self.id_defs.ident_map[&method.name].name.as_deref())
ast::MethodName::Kernel(name) => Some(name),
ast::MethodName::Func(id) => self.id_defs.ident_map[&id].name.as_deref(),
})
.ok_or_else(|| error_unreachable())?; .ok_or_else(|| error_unreachable())?;
let name = CString::new(name).map_err(|_| error_unreachable())?; let name = CString::new(name).map_err(|_| error_unreachable())?;
let mut fn_ = unsafe { LLVMGetNamedFunction(self.module, name.as_ptr()) }; let mut fn_ = unsafe { LLVMGetNamedFunction(self.module, name.as_ptr()) };
if fn_ == ptr::null_mut() { if fn_ == ptr::null_mut() {
let fn_type = get_function_type( let fn_type = get_function_type(
self.context, self.context,
func_decl.return_arguments.iter().map(|v| &v.v_type), method.return_arguments.iter().map(|v| &v.v_type),
func_decl method
.input_arguments .input_arguments
.iter() .iter()
.map(|v| get_input_argument_type(self.context, &v.v_type, v.state_space)), .map(|v| get_input_argument_type(self.context, &v.v_type, v.state_space)),
@ -245,15 +241,15 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> {
self.emit_fn_attribute(fn_, "uniform-work-group-size", "true"); self.emit_fn_attribute(fn_, "uniform-work-group-size", "true");
self.emit_fn_attribute(fn_, "no-trapping-math", "true"); self.emit_fn_attribute(fn_, "no-trapping-math", "true");
} }
if let ast::MethodName::Func(name) = func_decl.name { if !method.is_kernel {
self.resolver.register(name, fn_); self.resolver.register(method.name, fn_);
} }
for (i, param) in func_decl.input_arguments.iter().enumerate() { for (i, param) in method.input_arguments.iter().enumerate() {
let value = unsafe { LLVMGetParam(fn_, i as u32) }; let value = unsafe { LLVMGetParam(fn_, i as u32) };
let name = self.resolver.get_or_add(param.name); let name = self.resolver.get_or_add(param.name);
unsafe { LLVMSetValueName2(value, name.as_ptr().cast(), name.len()) }; unsafe { LLVMSetValueName2(value, name.as_ptr().cast(), name.len()) };
self.resolver.register(param.name, value); self.resolver.register(param.name, value);
if func_decl.name.is_kernel() { if method.is_kernel {
let attr_kind = unsafe { let attr_kind = unsafe {
LLVMGetEnumAttributeKindForName(b"byref".as_ptr().cast(), b"byref".len()) LLVMGetEnumAttributeKindForName(b"byref".as_ptr().cast(), b"byref".len())
}; };
@ -267,7 +263,7 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> {
unsafe { LLVMAddAttributeAtIndex(fn_, i as u32 + 1, attr) }; unsafe { LLVMAddAttributeAtIndex(fn_, i as u32 + 1, attr) };
} }
} }
let call_conv = if func_decl.name.is_kernel() { let call_conv = if method.is_kernel {
Self::kernel_call_convention() Self::kernel_call_convention()
} else { } else {
Self::func_call_convention() Self::func_call_convention()
@ -282,7 +278,7 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> {
unsafe { LLVMAppendBasicBlockInContext(self.context, fn_, LLVM_UNNAMED.as_ptr()) }; unsafe { LLVMAppendBasicBlockInContext(self.context, fn_, LLVM_UNNAMED.as_ptr()) };
unsafe { LLVMPositionBuilderAtEnd(self.builder.get(), real_bb) }; unsafe { LLVMPositionBuilderAtEnd(self.builder.get(), real_bb) };
let mut method_emitter = MethodEmitContext::new(self, fn_, variables_builder); let mut method_emitter = MethodEmitContext::new(self, fn_, variables_builder);
for var in func_decl.return_arguments { for var in method.return_arguments {
method_emitter.emit_variable(var)?; method_emitter.emit_variable(var)?;
} }
for statement in statements.iter() { for statement in statements.iter() {
@ -1558,7 +1554,7 @@ impl<'a> MethodEmitContext<'a> {
return self.emit_cvt_float_to_int( return self.emit_cvt_float_to_int(
data.from, data.from,
data.to, data.to,
integer_rounding.unwrap_or(ast::RoundingMode::NearestEven), integer_rounding,
arguments, arguments,
Some(LLVMBuildFPToSI), Some(LLVMBuildFPToSI),
) )

View file

@ -2,8 +2,8 @@ use super::*;
pub(super) fn run<'a, 'input>( pub(super) fn run<'a, 'input>(
resolver: &mut GlobalStringIdentResolver2<'input>, resolver: &mut GlobalStringIdentResolver2<'input>,
directives: Vec<UnconditionalDirective<'input>>, directives: Vec<UnconditionalDirective>,
) -> Result<Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> { ) -> Result<Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
directives directives
.into_iter() .into_iter()
.map(|directive| run_directive(resolver, directive)) .map(|directive| run_directive(resolver, directive))
@ -13,11 +13,10 @@ pub(super) fn run<'a, 'input>(
fn run_directive<'input>( fn run_directive<'input>(
resolver: &mut GlobalStringIdentResolver2<'input>, resolver: &mut GlobalStringIdentResolver2<'input>,
directive: Directive2< directive: Directive2<
'input,
ast::Instruction<ast::ParsedOperand<SpirvWord>>, ast::Instruction<ast::ParsedOperand<SpirvWord>>,
ast::ParsedOperand<SpirvWord>, ast::ParsedOperand<SpirvWord>,
>, >,
) -> Result<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> { ) -> Result<Directive2<ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
Ok(match directive { Ok(match directive {
Directive2::Variable(linking, var) => Directive2::Variable(linking, var), Directive2::Variable(linking, var) => Directive2::Variable(linking, var),
Directive2::Method(method) => Directive2::Method(run_method(resolver, method)?), Directive2::Method(method) => Directive2::Method(run_method(resolver, method)?),
@ -27,11 +26,10 @@ fn run_directive<'input>(
fn run_method<'input>( fn run_method<'input>(
resolver: &mut GlobalStringIdentResolver2<'input>, resolver: &mut GlobalStringIdentResolver2<'input>,
method: Function2< method: Function2<
'input,
ast::Instruction<ast::ParsedOperand<SpirvWord>>, ast::Instruction<ast::ParsedOperand<SpirvWord>>,
ast::ParsedOperand<SpirvWord>, ast::ParsedOperand<SpirvWord>,
>, >,
) -> Result<Function2<'input, ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> { ) -> Result<Function2<ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
let body = method let body = method
.body .body
.map(|statements| { .map(|statements| {
@ -43,12 +41,14 @@ fn run_method<'input>(
}) })
.transpose()?; .transpose()?;
Ok(Function2 { Ok(Function2 {
func_decl: method.func_decl, return_arguments: method.return_arguments,
globals: method.globals, name: method.name,
input_arguments: method.input_arguments,
body, body,
import_as: method.import_as, import_as: method.import_as,
tuning: method.tuning, tuning: method.tuning,
linkage: method.linkage, linkage: method.linkage,
is_kernel: method.is_kernel,
}) })
} }

View file

@ -1,30 +1,29 @@
use super::*; use super::*;
pub(super) fn run<'a, 'input>( pub(super) fn run<'a, 'input>(
resolver: &mut GlobalStringIdentResolver2<'input>, resolver: &'a mut GlobalStringIdentResolver2<'input>,
special_registers: &'a SpecialRegistersMap2, special_registers: &'a SpecialRegistersMap2,
directives: Vec<UnconditionalDirective<'input>>, directives: Vec<UnconditionalDirective>,
) -> Result<Vec<UnconditionalDirective<'input>>, TranslateError> { ) -> Result<Vec<UnconditionalDirective>, TranslateError> {
let declarations = SpecialRegistersMap2::generate_declarations(resolver); let mut result = Vec::with_capacity(SpecialRegistersMap2::len() + directives.len());
let mut result = Vec::with_capacity(declarations.len() + directives.len());
let mut sreg_to_function = let mut sreg_to_function =
FxHashMap::with_capacity_and_hasher(declarations.len(), Default::default()); FxHashMap::with_capacity_and_hasher(SpecialRegistersMap2::len(), Default::default());
for (sreg, declaration) in declarations { SpecialRegistersMap2::foreach_declaration(
let name = if let ast::MethodName::Func(name) = declaration.name { resolver,
name |sreg, (return_arguments, name, input_arguments)| {
} else { result.push(UnconditionalDirective::Method(UnconditionalFunction {
return Err(error_unreachable()); return_arguments,
}; name,
result.push(UnconditionalDirective::Method(UnconditionalFunction { input_arguments,
func_decl: declaration, body: None,
globals: Vec::new(), import_as: None,
body: None, tuning: Vec::new(),
import_as: None, linkage: ast::LinkingDirective::EXTERN,
tuning: Vec::new(), is_kernel: false,
linkage: ast::LinkingDirective::EXTERN, }));
})); sreg_to_function.insert(sreg, name);
sreg_to_function.insert(sreg, name); },
} );
let mut visitor = SpecialRegisterResolver { let mut visitor = SpecialRegisterResolver {
resolver, resolver,
special_registers, special_registers,
@ -39,8 +38,8 @@ pub(super) fn run<'a, 'input>(
fn run_directive<'a, 'input>( fn run_directive<'a, 'input>(
visitor: &mut SpecialRegisterResolver<'a, 'input>, visitor: &mut SpecialRegisterResolver<'a, 'input>,
directive: UnconditionalDirective<'input>, directive: UnconditionalDirective,
) -> Result<UnconditionalDirective<'input>, TranslateError> { ) -> Result<UnconditionalDirective, TranslateError> {
Ok(match directive { Ok(match directive {
var @ Directive2::Variable(..) => var, var @ Directive2::Variable(..) => var,
Directive2::Method(method) => Directive2::Method(run_method(visitor, method)?), Directive2::Method(method) => Directive2::Method(run_method(visitor, method)?),
@ -49,8 +48,8 @@ fn run_directive<'a, 'input>(
fn run_method<'a, 'input>( fn run_method<'a, 'input>(
visitor: &mut SpecialRegisterResolver<'a, 'input>, visitor: &mut SpecialRegisterResolver<'a, 'input>,
method: UnconditionalFunction<'input>, method: UnconditionalFunction,
) -> Result<UnconditionalFunction<'input>, TranslateError> { ) -> Result<UnconditionalFunction, TranslateError> {
let body = method let body = method
.body .body
.map(|statements| { .map(|statements| {
@ -62,12 +61,14 @@ fn run_method<'a, 'input>(
}) })
.transpose()?; .transpose()?;
Ok(Function2 { Ok(Function2 {
func_decl: method.func_decl, return_arguments: method.return_arguments,
globals: method.globals, name: method.name,
input_arguments: method.input_arguments,
body, body,
import_as: method.import_as, import_as: method.import_as,
tuning: method.tuning, tuning: method.tuning,
linkage: method.linkage, linkage: method.linkage,
is_kernel: method.is_kernel,
}) })
} }

View file

@ -1,8 +1,8 @@
use super::*; use super::*;
pub(super) fn run<'input>( pub(super) fn run<'input>(
directives: Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>, directives: Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>,
) -> Result<Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> { ) -> Result<Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
let mut result = Vec::with_capacity(directives.len()); let mut result = Vec::with_capacity(directives.len());
for mut directive in directives.into_iter() { for mut directive in directives.into_iter() {
run_directive(&mut result, &mut directive)?; run_directive(&mut result, &mut directive)?;
@ -12,8 +12,8 @@ pub(super) fn run<'input>(
} }
fn run_directive<'input>( fn run_directive<'input>(
result: &mut Vec<Directive2<'input, ptx_parser::Instruction<SpirvWord>, SpirvWord>>, result: &mut Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>,
directive: &mut Directive2<'input, ptx_parser::Instruction<SpirvWord>, SpirvWord>, directive: &mut Directive2<ast::Instruction<SpirvWord>, SpirvWord>,
) -> Result<(), TranslateError> { ) -> Result<(), TranslateError> {
match directive { match directive {
Directive2::Variable(..) => {} Directive2::Variable(..) => {}
@ -23,8 +23,8 @@ fn run_directive<'input>(
} }
fn run_function<'input>( fn run_function<'input>(
result: &mut Vec<Directive2<'input, ptx_parser::Instruction<SpirvWord>, SpirvWord>>, result: &mut Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>,
function: &mut Function2<'input, ptx_parser::Instruction<SpirvWord>, SpirvWord>, function: &mut Function2<ast::Instruction<SpirvWord>, SpirvWord>,
) { ) {
function.body = function.body.take().map(|statements| { function.body = function.body.take().map(|statements| {
statements statements

View file

@ -11,8 +11,8 @@ use super::*;
// pass, so we do nothing there // pass, so we do nothing there
pub(super) fn run<'a, 'input>( pub(super) fn run<'a, 'input>(
resolver: &mut GlobalStringIdentResolver2<'input>, resolver: &mut GlobalStringIdentResolver2<'input>,
directives: Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>, directives: Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>,
) -> Result<Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> { ) -> Result<Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
directives directives
.into_iter() .into_iter()
.map(|directive| run_directive(resolver, directive)) .map(|directive| run_directive(resolver, directive))
@ -21,8 +21,8 @@ pub(super) fn run<'a, 'input>(
fn run_directive<'a, 'input>( fn run_directive<'a, 'input>(
resolver: &mut GlobalStringIdentResolver2<'input>, resolver: &mut GlobalStringIdentResolver2<'input>,
directive: Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>, directive: Directive2<ast::Instruction<SpirvWord>, SpirvWord>,
) -> Result<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> { ) -> Result<Directive2<ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
Ok(match directive { Ok(match directive {
var @ Directive2::Variable(..) => var, var @ Directive2::Variable(..) => var,
Directive2::Method(method) => { Directive2::Method(method) => {
@ -34,12 +34,11 @@ fn run_directive<'a, 'input>(
fn run_method<'a, 'input>( fn run_method<'a, 'input>(
mut visitor: InsertMemSSAVisitor<'a, 'input>, mut visitor: InsertMemSSAVisitor<'a, 'input>,
method: Function2<'input, ast::Instruction<SpirvWord>, SpirvWord>, mut method: Function2<ast::Instruction<SpirvWord>, SpirvWord>,
) -> Result<Function2<'input, ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> { ) -> Result<Function2<ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
let mut func_decl = method.func_decl; let is_kernel = method.is_kernel;
let is_kernel = func_decl.name.is_kernel();
if is_kernel { if is_kernel {
for arg in func_decl.input_arguments.iter_mut() { for arg in method.input_arguments.iter_mut() {
let old_name = arg.name; let old_name = arg.name;
let old_space = arg.state_space; let old_space = arg.state_space;
let new_space = ast::StateSpace::ParamEntry; let new_space = ast::StateSpace::ParamEntry;
@ -51,10 +50,10 @@ fn run_method<'a, 'input>(
arg.state_space = new_space; arg.state_space = new_space;
} }
}; };
for arg in func_decl.return_arguments.iter_mut() { for arg in method.return_arguments.iter_mut() {
visitor.visit_variable(arg)?; visitor.visit_variable(arg)?;
} }
let return_arguments = &func_decl.return_arguments[..]; let return_arguments = &method.return_arguments[..];
let body = method let body = method
.body .body
.map(move |statements| { .map(move |statements| {
@ -66,12 +65,14 @@ fn run_method<'a, 'input>(
}) })
.transpose()?; .transpose()?;
Ok(Function2 { Ok(Function2 {
func_decl: func_decl, return_arguments: method.return_arguments,
globals: method.globals, name: method.name,
input_arguments: method.input_arguments,
body, body,
import_as: method.import_as, import_as: method.import_as,
tuning: method.tuning, tuning: method.tuning,
linkage: method.linkage, linkage: method.linkage,
is_kernel: method.is_kernel,
}) })
} }

View file

@ -1,3 +1,5 @@
use crate::pass::error_unreachable;
use super::BrachCondition; use super::BrachCondition;
use super::Directive2; use super::Directive2;
use super::Function2; use super::Function2;
@ -17,6 +19,178 @@ use rustc_hash::FxHashSet;
use std::hash::Hash; use std::hash::Hash;
use std::iter; use std::iter;
#[derive(Default)]
enum DenormalMode {
#[default]
FlushToZero,
Preserve,
}
impl DenormalMode {
fn from_ftz(ftz: bool) -> Self {
if ftz {
DenormalMode::FlushToZero
} else {
DenormalMode::Preserve
}
}
}
#[derive(Default)]
enum RoundingMode {
#[default]
NearestEven,
Zero,
NegativeInf,
PositiveInf,
}
impl RoundingMode {
fn to_ast(self) -> ast::RoundingMode {
match self {
RoundingMode::NearestEven => ast::RoundingMode::NearestEven,
RoundingMode::Zero => ast::RoundingMode::Zero,
RoundingMode::NegativeInf => ast::RoundingMode::NegativeInf,
RoundingMode::PositiveInf => ast::RoundingMode::PositiveInf,
}
}
fn from_ast(rnd: ast::RoundingMode) -> Self {
match rnd {
ast::RoundingMode::NearestEven => RoundingMode::NearestEven,
ast::RoundingMode::Zero => RoundingMode::Zero,
ast::RoundingMode::NegativeInf => RoundingMode::NegativeInf,
ast::RoundingMode::PositiveInf => RoundingMode::PositiveInf,
}
}
}
struct InstructionModes {
denormal_f32: Option<DenormalMode>,
denormal_f16_f64: Option<DenormalMode>,
rounding_f32: Option<RoundingMode>,
rounding_f16_f64: Option<RoundingMode>,
}
impl InstructionModes {
fn none() -> Self {
Self {
denormal_f32: None,
denormal_f16_f64: None,
rounding_f32: None,
rounding_f16_f64: None,
}
}
fn new(
type_: ast::ScalarType,
denormal: Option<DenormalMode>,
rounding: Option<RoundingMode>,
) -> Self {
if type_ != ast::ScalarType::F32 {
Self {
denormal_f16_f64: denormal,
rounding_f16_f64: rounding,
..Self::none()
}
} else {
Self {
denormal_f32: denormal,
rounding_f32: rounding,
..Self::none()
}
}
}
fn mixed_ftz_f32(
type_: ast::ScalarType,
denormal: Option<DenormalMode>,
rounding: Option<RoundingMode>,
) -> Self {
if type_ != ast::ScalarType::F32 {
Self {
denormal_f16_f64: denormal,
rounding_f32: rounding,
..Self::none()
}
} else {
Self {
denormal_f32: denormal,
rounding_f32: rounding,
..Self::none()
}
}
}
fn from_arith_float(arith: &ast::ArithFloat) -> InstructionModes {
let denormal = arith.flush_to_zero.map(DenormalMode::from_ftz);
let rounding = Some(RoundingMode::from_ast(arith.rounding));
InstructionModes::new(arith.type_, denormal, rounding)
}
fn from_ftz(type_: ast::ScalarType, ftz: Option<bool>) -> Self {
Self::new(type_, ftz.map(DenormalMode::from_ftz), None)
}
fn from_ftz_f32(ftz: bool) -> Self {
Self::new(
ast::ScalarType::F32,
Some(DenormalMode::from_ftz(ftz)),
None,
)
}
fn from_rcp(data: ast::RcpData) -> InstructionModes {
let rounding = match data.kind {
ast::RcpKind::Approx => None,
ast::RcpKind::Compliant(rnd) => Some(RoundingMode::from_ast(rnd)),
};
let denormal = data.flush_to_zero.map(DenormalMode::from_ftz);
InstructionModes::new(data.type_, denormal, rounding)
}
fn from_cvt(cvt: &ast::CvtDetails) -> InstructionModes {
match cvt.mode {
ast::CvtMode::ZeroExtend
| ast::CvtMode::SignExtend
| ast::CvtMode::Truncate
| ast::CvtMode::Bitcast
| ast::CvtMode::SaturateUnsignedToSigned
| ast::CvtMode::SaturateSignedToUnsigned => Self::none(),
ast::CvtMode::FPExtend { flush_to_zero } => {
Self::from_ftz(ast::ScalarType::F32, flush_to_zero)
}
ast::CvtMode::FPTruncate {
rounding,
flush_to_zero,
}
| ast::CvtMode::FPRound {
integer_rounding: rounding,
flush_to_zero,
} => Self::mixed_ftz_f32(
cvt.to,
flush_to_zero.map(DenormalMode::from_ftz),
Some(RoundingMode::from_ast(rounding)),
),
ast::CvtMode::SignedFromFP {
flush_to_zero,
rounding,
}
| ast::CvtMode::UnsignedFromFP {
flush_to_zero,
rounding,
} => Self::new(
cvt.from,
flush_to_zero.map(DenormalMode::from_ftz),
Some(RoundingMode::from_ast(rounding)),
),
ast::CvtMode::FPFromSigned(rnd) | ast::CvtMode::FPFromUnsigned(rnd) => {
Self::new(cvt.to, None, Some(RoundingMode::from_ast(rnd)))
}
}
}
}
struct ControlFlowGraph<T: Eq + PartialEq> { struct ControlFlowGraph<T: Eq + PartialEq> {
entry_points: FxHashMap<SpirvWord, NodeIndex>, entry_points: FxHashMap<SpirvWord, NodeIndex>,
basic_blocks: FxHashMap<SpirvWord, NodeIndex>, basic_blocks: FxHashMap<SpirvWord, NodeIndex>,
@ -74,19 +248,40 @@ struct Node<T> {
pub(crate) fn run<'input>( pub(crate) fn run<'input>(
flat_resolver: &mut super::GlobalStringIdentResolver2<'input>, flat_resolver: &mut super::GlobalStringIdentResolver2<'input>,
directives: Vec<super::Directive2<'input, ast::Instruction<SpirvWord>, super::SpirvWord>>, directives: Vec<super::Directive2<ast::Instruction<SpirvWord>, super::SpirvWord>>,
) -> Result<Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> { ) -> Result<Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
let mut cfg = ControlFlowGraph::<bool>::new(); let mut cfg = ControlFlowGraph::<bool>::new();
let mut node_idx_to_name = FxHashMap::<NodeIndex<u32>, SpirvWord>::default();
for directive in directives.iter() { for directive in directives.iter() {
match directive { match directive {
super::Directive2::Method(Function2 { super::Directive2::Method(Function2 {
func_decl: ast::MethodDeclaration { name, .. }, name,
body, body: Some(body),
.. ..
}) => { }) => {
let mut basic_block = Some(cfg.add_entry_basic_block(*name));
for statement in body.iter() { for statement in body.iter() {
todo!() match statement {
Statement::Instruction(ast::Instruction::Bra { arguments }) => {
let bb_index = basic_block.ok_or_else(error_unreachable)?;
cfg.add_jump(bb_index, arguments.src);
basic_block = None;
}
Statement::Label(label) => {
basic_block = Some(cfg.get_or_add_basic_block(*label));
}
Statement::Conditional(BrachCondition {
if_true, if_false, ..
}) => {
let bb_index = basic_block.ok_or_else(error_unreachable)?;
cfg.add_jump(bb_index, *if_true);
cfg.add_jump(bb_index, *if_false);
basic_block = None;
}
Statement::Instruction(instruction) => {
let modes = get_modes(instruction);
}
_ => continue,
}
} }
} }
_ => continue, _ => continue,
@ -280,6 +475,169 @@ impl<T: Copy + Eq + Hash> UniqueVec<T> {
} }
} }
fn get_modes<T: ast::Operand>(inst: &ast::Instruction<T>) -> InstructionModes {
match inst {
// TODO: review it when implementing virtual calls
ast::Instruction::Call { .. }
| ast::Instruction::Mov { .. }
| ast::Instruction::Ld { .. }
| ast::Instruction::St { .. }
| ast::Instruction::PrmtSlow { .. }
| ast::Instruction::Prmt { .. }
| ast::Instruction::Activemask { .. }
| ast::Instruction::Membar { .. }
| ast::Instruction::Trap {}
| ast::Instruction::Not { .. }
| ast::Instruction::Or { .. }
| ast::Instruction::And { .. }
| ast::Instruction::Bra { .. }
| ast::Instruction::Clz { .. }
| ast::Instruction::Brev { .. }
| ast::Instruction::Popc { .. }
| ast::Instruction::Xor { .. }
| ast::Instruction::Rem { .. }
| ast::Instruction::Bfe { .. }
| ast::Instruction::Bfi { .. }
| ast::Instruction::Shr { .. }
| ast::Instruction::Shl { .. }
| ast::Instruction::Selp { .. }
| ast::Instruction::Ret { .. }
| ast::Instruction::Bar { .. }
| ast::Instruction::Cvta { .. }
| ast::Instruction::Atom { .. }
| ast::Instruction::AtomCas { .. } => InstructionModes::none(),
ast::Instruction::Add {
data: ast::ArithDetails::Integer(_),
..
}
| ast::Instruction::Sub {
data: ast::ArithDetails::Integer(..),
..
}
| ast::Instruction::Mul {
data: ast::MulDetails::Integer { .. },
..
}
| ast::Instruction::Mad {
data: ast::MadDetails::Integer { .. },
..
}
| ast::Instruction::Min {
data: ast::MinMaxDetails::Signed(..) | ast::MinMaxDetails::Unsigned(..),
..
}
| ast::Instruction::Max {
data: ast::MinMaxDetails::Signed(..) | ast::MinMaxDetails::Unsigned(..),
..
}
| ast::Instruction::Div {
data: ast::DivDetails::Signed(..) | ast::DivDetails::Unsigned(..),
..
} => InstructionModes::none(),
ast::Instruction::Fma { data, .. }
| ast::Instruction::Sub {
data: ast::ArithDetails::Float(data),
..
}
| ast::Instruction::Mul {
data: ast::MulDetails::Float(data),
..
}
| ast::Instruction::Mad {
data: ast::MadDetails::Float(data),
..
}
| ast::Instruction::Add {
data: ast::ArithDetails::Float(data),
..
} => InstructionModes::from_arith_float(data),
ast::Instruction::Setp {
data:
ast::SetpData {
type_,
flush_to_zero,
..
},
..
}
| ast::Instruction::SetpBool {
data:
ast::SetpBoolData {
base:
ast::SetpData {
type_,
flush_to_zero,
..
},
..
},
..
}
| ast::Instruction::Neg {
data: ast::TypeFtz {
type_,
flush_to_zero,
},
..
}
| ast::Instruction::Ex2 {
data: ast::TypeFtz {
type_,
flush_to_zero,
},
..
}
| ast::Instruction::Rsqrt {
data: ast::TypeFtz {
type_,
flush_to_zero,
},
..
}
| ast::Instruction::Abs {
data: ast::TypeFtz {
type_,
flush_to_zero,
},
..
}
| ast::Instruction::Min {
data:
ast::MinMaxDetails::Float(ast::MinMaxFloat {
type_,
flush_to_zero,
..
}),
..
}
| ast::Instruction::Max {
data:
ast::MinMaxDetails::Float(ast::MinMaxFloat {
type_,
flush_to_zero,
..
}),
..
}
| ast::Instruction::Div {
data:
ast::DivDetails::Float(ast::DivFloatDetails {
type_,
flush_to_zero,
..
}),
..
} => InstructionModes::from_ftz(*type_, *flush_to_zero),
ast::Instruction::Sin { data, .. }
| ast::Instruction::Cos { data, .. }
| ast::Instruction::Lg2 { data, .. } => InstructionModes::from_ftz_f32(data.flush_to_zero),
ast::Instruction::Rcp { data, .. } | ast::Instruction::Sqrt { data, .. } => {
InstructionModes::from_rcp(*data)
}
ast::Instruction::Cvt { data, .. } => InstructionModes::from_cvt(data),
}
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;

View file

@ -19,8 +19,8 @@ use ptx_parser as ast;
*/ */
pub(super) fn run<'input>( pub(super) fn run<'input>(
resolver: &mut GlobalStringIdentResolver2<'input>, resolver: &mut GlobalStringIdentResolver2<'input>,
directives: Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>, directives: Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>,
) -> Result<Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> { ) -> Result<Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
directives directives
.into_iter() .into_iter()
.map(|directive| run_directive(resolver, directive)) .map(|directive| run_directive(resolver, directive))
@ -29,8 +29,8 @@ pub(super) fn run<'input>(
fn run_directive<'a, 'input>( fn run_directive<'a, 'input>(
resolver: &mut GlobalStringIdentResolver2<'input>, resolver: &mut GlobalStringIdentResolver2<'input>,
directive: Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>, directive: Directive2<ast::Instruction<SpirvWord>, SpirvWord>,
) -> Result<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> { ) -> Result<Directive2<ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
Ok(match directive { Ok(match directive {
var @ Directive2::Variable(..) => var, var @ Directive2::Variable(..) => var,
Directive2::Method(mut method) => { Directive2::Method(mut method) => {

View file

@ -44,7 +44,7 @@ pub fn to_llvm_module<'input>(ast: ast::Module<'input>) -> Result<Module, Transl
let mut scoped_resolver = ScopedResolver::new(&mut flat_resolver); let mut scoped_resolver = ScopedResolver::new(&mut flat_resolver);
let sreg_map = SpecialRegistersMap2::new(&mut scoped_resolver)?; let sreg_map = SpecialRegistersMap2::new(&mut scoped_resolver)?;
let directives = normalize_identifiers2::run(&mut scoped_resolver, ast.directives)?; let directives = normalize_identifiers2::run(&mut scoped_resolver, ast.directives)?;
let directives = replace_known_functions::run(&flat_resolver, directives); let directives = replace_known_functions::run(&mut flat_resolver, directives);
let directives = normalize_predicates2::run(&mut flat_resolver, directives)?; let directives = normalize_predicates2::run(&mut flat_resolver, directives)?;
let directives = resolve_function_pointers::run(directives)?; let directives = resolve_function_pointers::run(directives)?;
let directives = fix_special_registers2::run(&mut flat_resolver, &sreg_map, directives)?; let directives = fix_special_registers2::run(&mut flat_resolver, &sreg_map, directives)?;
@ -559,22 +559,23 @@ type NormalizedStatement = Statement<
ast::ParsedOperand<SpirvWord>, ast::ParsedOperand<SpirvWord>,
>; >;
enum Directive2<'input, Instruction, Operand: ast::Operand> { enum Directive2<Instruction, Operand: ast::Operand> {
Variable(ast::LinkingDirective, ast::Variable<SpirvWord>), Variable(ast::LinkingDirective, ast::Variable<SpirvWord>),
Method(Function2<'input, Instruction, Operand>), Method(Function2<Instruction, Operand>),
} }
struct Function2<'input, Instruction, Operand: ast::Operand> { struct Function2<Instruction, Operand: ast::Operand> {
pub func_decl: ast::MethodDeclaration<'input, SpirvWord>, pub return_arguments: Vec<ast::Variable<Operand::Ident>>,
pub globals: Vec<ast::Variable<SpirvWord>>, pub name: Operand::Ident,
pub input_arguments: Vec<ast::Variable<Operand::Ident>>,
pub body: Option<Vec<Statement<Instruction, Operand>>>, pub body: Option<Vec<Statement<Instruction, Operand>>>,
is_kernel: bool,
import_as: Option<String>, import_as: Option<String>,
tuning: Vec<ast::TuningDirective>, tuning: Vec<ast::TuningDirective>,
linkage: ast::LinkingDirective, linkage: ast::LinkingDirective,
} }
type NormalizedDirective2<'input> = Directive2< type NormalizedDirective2 = Directive2<
'input,
( (
Option<ast::PredAt<SpirvWord>>, Option<ast::PredAt<SpirvWord>>,
ast::Instruction<ast::ParsedOperand<SpirvWord>>, ast::Instruction<ast::ParsedOperand<SpirvWord>>,
@ -582,8 +583,7 @@ type NormalizedDirective2<'input> = Directive2<
ast::ParsedOperand<SpirvWord>, ast::ParsedOperand<SpirvWord>,
>; >;
type NormalizedFunction2<'input> = Function2< type NormalizedFunction2 = Function2<
'input,
( (
Option<ast::PredAt<SpirvWord>>, Option<ast::PredAt<SpirvWord>>,
ast::Instruction<ast::ParsedOperand<SpirvWord>>, ast::Instruction<ast::ParsedOperand<SpirvWord>>,
@ -591,17 +591,11 @@ type NormalizedFunction2<'input> = Function2<
ast::ParsedOperand<SpirvWord>, ast::ParsedOperand<SpirvWord>,
>; >;
type UnconditionalDirective<'input> = Directive2< type UnconditionalDirective =
'input, Directive2<ast::Instruction<ast::ParsedOperand<SpirvWord>>, ast::ParsedOperand<SpirvWord>>;
ast::Instruction<ast::ParsedOperand<SpirvWord>>,
ast::ParsedOperand<SpirvWord>,
>;
type UnconditionalFunction<'input> = Function2< type UnconditionalFunction =
'input, Function2<ast::Instruction<ast::ParsedOperand<SpirvWord>>, ast::ParsedOperand<SpirvWord>>;
ast::Instruction<ast::ParsedOperand<SpirvWord>>,
ast::ParsedOperand<SpirvWord>,
>;
struct GlobalStringIdentResolver2<'input> { struct GlobalStringIdentResolver2<'input> {
pub(crate) current_id: SpirvWord, pub(crate) current_id: SpirvWord,
@ -807,47 +801,45 @@ impl SpecialRegistersMap2 {
self.id_to_reg.get(&id).copied() self.id_to_reg.get(&id).copied()
} }
fn generate_declarations<'a, 'input>( fn len() -> usize {
PtxSpecialRegister::iter().len()
}
fn foreach_declaration<'a, 'input>(
resolver: &'a mut GlobalStringIdentResolver2<'input>, resolver: &'a mut GlobalStringIdentResolver2<'input>,
) -> impl ExactSizeIterator< mut fn_: impl FnMut(
Item = (
PtxSpecialRegister, PtxSpecialRegister,
ast::MethodDeclaration<'input, SpirvWord>, (
Vec<ast::Variable<SpirvWord>>,
SpirvWord,
Vec<ast::Variable<SpirvWord>>,
),
), ),
> + 'a { ) {
PtxSpecialRegister::iter().map(|sreg| { for sreg in PtxSpecialRegister::iter() {
let external_fn_name = [ZLUDA_PTX_PREFIX, sreg.get_unprefixed_function_name()].concat(); let external_fn_name = [ZLUDA_PTX_PREFIX, sreg.get_unprefixed_function_name()].concat();
let name = let name = resolver.register_named(Cow::Owned(external_fn_name), None);
ast::MethodName::Func(resolver.register_named(Cow::Owned(external_fn_name), None));
let return_type = sreg.get_function_return_type(); let return_type = sreg.get_function_return_type();
let input_type = sreg.get_function_input_type(); let input_type = sreg.get_function_input_type();
( let return_arguments = vec![ast::Variable {
sreg, align: None,
ast::MethodDeclaration { v_type: return_type.into(),
return_arguments: vec![ast::Variable { state_space: ast::StateSpace::Reg,
align: None, name: resolver.register_unnamed(Some((return_type.into(), ast::StateSpace::Reg))),
v_type: return_type.into(), array_init: Vec::new(),
state_space: ast::StateSpace::Reg, }];
name: resolver let input_arguments = input_type
.register_unnamed(Some((return_type.into(), ast::StateSpace::Reg))), .into_iter()
array_init: Vec::new(), .map(|type_| ast::Variable {
}], align: None,
name: name, v_type: type_.into(),
input_arguments: input_type state_space: ast::StateSpace::Reg,
.into_iter() name: resolver.register_unnamed(Some((type_.into(), ast::StateSpace::Reg))),
.map(|type_| ast::Variable { array_init: Vec::new(),
align: None, })
v_type: type_.into(), .collect::<Vec<_>>();
state_space: ast::StateSpace::Reg, fn_(sreg, (return_arguments, name, input_arguments));
name: resolver }
.register_unnamed(Some((type_.into(), ast::StateSpace::Reg))),
array_init: Vec::new(),
})
.collect::<Vec<_>>(),
shared_mem: None,
},
)
})
} }
} }

View file

@ -4,7 +4,7 @@ use ptx_parser as ast;
pub(crate) fn run<'input, 'b>( pub(crate) fn run<'input, 'b>(
resolver: &mut ScopedResolver<'input, 'b>, resolver: &mut ScopedResolver<'input, 'b>,
directives: Vec<ast::Directive<'input, ast::ParsedOperand<&'input str>>>, directives: Vec<ast::Directive<'input, ast::ParsedOperand<&'input str>>>,
) -> Result<Vec<NormalizedDirective2<'input>>, TranslateError> { ) -> Result<Vec<NormalizedDirective2>, TranslateError> {
resolver.start_scope(); resolver.start_scope();
let result = directives let result = directives
.into_iter() .into_iter()
@ -17,7 +17,7 @@ pub(crate) fn run<'input, 'b>(
fn run_directive<'input, 'b>( fn run_directive<'input, 'b>(
resolver: &mut ScopedResolver<'input, 'b>, resolver: &mut ScopedResolver<'input, 'b>,
directive: ast::Directive<'input, ast::ParsedOperand<&'input str>>, directive: ast::Directive<'input, ast::ParsedOperand<&'input str>>,
) -> Result<NormalizedDirective2<'input>, TranslateError> { ) -> Result<NormalizedDirective2, TranslateError> {
Ok(match directive { Ok(match directive {
ast::Directive::Variable(linking, var) => { ast::Directive::Variable(linking, var) => {
NormalizedDirective2::Variable(linking, run_variable(resolver, var)?) NormalizedDirective2::Variable(linking, run_variable(resolver, var)?)
@ -32,15 +32,11 @@ fn run_method<'input, 'b>(
resolver: &mut ScopedResolver<'input, 'b>, resolver: &mut ScopedResolver<'input, 'b>,
linkage: ast::LinkingDirective, linkage: ast::LinkingDirective,
method: ast::Function<'input, &'input str, ast::Statement<ast::ParsedOperand<&'input str>>>, method: ast::Function<'input, &'input str, ast::Statement<ast::ParsedOperand<&'input str>>>,
) -> Result<NormalizedFunction2<'input>, TranslateError> { ) -> Result<NormalizedFunction2, TranslateError> {
let name = match method.func_directive.name { let is_kernel = method.func_directive.name.is_kernel();
ast::MethodName::Kernel(name) => ast::MethodName::Kernel(name), let name = resolver.add_or_get_in_current_scope_untyped(method.func_directive.name.text())?;
ast::MethodName::Func(text) => {
ast::MethodName::Func(resolver.add_or_get_in_current_scope_untyped(text)?)
}
};
resolver.start_scope(); resolver.start_scope();
let func_decl = run_function_decl(resolver, method.func_directive, name)?; let (return_arguments, input_arguments) = run_function_decl(resolver, method.func_directive)?;
let body = method let body = method
.body .body
.map(|statements| { .map(|statements| {
@ -51,20 +47,21 @@ fn run_method<'input, 'b>(
.transpose()?; .transpose()?;
resolver.end_scope(); resolver.end_scope();
Ok(Function2 { Ok(Function2 {
func_decl, return_arguments,
globals: Vec::new(), name,
input_arguments,
body, body,
import_as: None, import_as: None,
tuning: method.tuning, tuning: method.tuning,
linkage, linkage,
is_kernel,
}) })
} }
fn run_function_decl<'input, 'b>( fn run_function_decl<'input, 'b>(
resolver: &mut ScopedResolver<'input, 'b>, resolver: &mut ScopedResolver<'input, 'b>,
func_directive: ast::MethodDeclaration<'input, &'input str>, func_directive: ast::MethodDeclaration<'input, &'input str>,
name: ast::MethodName<'input, SpirvWord>, ) -> Result<(Vec<ast::Variable<SpirvWord>>, Vec<ast::Variable<SpirvWord>>), TranslateError> {
) -> Result<ast::MethodDeclaration<'input, SpirvWord>, TranslateError> {
assert!(func_directive.shared_mem.is_none()); assert!(func_directive.shared_mem.is_none());
let return_arguments = func_directive let return_arguments = func_directive
.return_arguments .return_arguments
@ -76,12 +73,7 @@ fn run_function_decl<'input, 'b>(
.into_iter() .into_iter()
.map(|var| run_variable(resolver, var)) .map(|var| run_variable(resolver, var))
.collect::<Result<Vec<_>, _>>()?; .collect::<Result<Vec<_>, _>>()?;
Ok(ast::MethodDeclaration { Ok((return_arguments, input_arguments))
return_arguments,
name,
input_arguments,
shared_mem: None,
})
} }
fn run_variable<'input, 'b>( fn run_variable<'input, 'b>(

View file

@ -3,8 +3,8 @@ use ptx_parser as ast;
pub(crate) fn run<'input>( pub(crate) fn run<'input>(
resolver: &mut GlobalStringIdentResolver2<'input>, resolver: &mut GlobalStringIdentResolver2<'input>,
directives: Vec<NormalizedDirective2<'input>>, directives: Vec<NormalizedDirective2>,
) -> Result<Vec<UnconditionalDirective<'input>>, TranslateError> { ) -> Result<Vec<UnconditionalDirective>, TranslateError> {
directives directives
.into_iter() .into_iter()
.map(|directive| run_directive(resolver, directive)) .map(|directive| run_directive(resolver, directive))
@ -13,8 +13,8 @@ pub(crate) fn run<'input>(
fn run_directive<'input>( fn run_directive<'input>(
resolver: &mut GlobalStringIdentResolver2<'input>, resolver: &mut GlobalStringIdentResolver2<'input>,
directive: NormalizedDirective2<'input>, directive: NormalizedDirective2,
) -> Result<UnconditionalDirective<'input>, TranslateError> { ) -> Result<UnconditionalDirective, TranslateError> {
Ok(match directive { Ok(match directive {
Directive2::Variable(linking, var) => Directive2::Variable(linking, var), Directive2::Variable(linking, var) => Directive2::Variable(linking, var),
Directive2::Method(method) => Directive2::Method(run_method(resolver, method)?), Directive2::Method(method) => Directive2::Method(run_method(resolver, method)?),
@ -23,8 +23,8 @@ fn run_directive<'input>(
fn run_method<'input>( fn run_method<'input>(
resolver: &mut GlobalStringIdentResolver2<'input>, resolver: &mut GlobalStringIdentResolver2<'input>,
method: NormalizedFunction2<'input>, method: NormalizedFunction2,
) -> Result<UnconditionalFunction<'input>, TranslateError> { ) -> Result<UnconditionalFunction, TranslateError> {
let body = method let body = method
.body .body
.map(|statements| { .map(|statements| {
@ -36,12 +36,14 @@ fn run_method<'input>(
}) })
.transpose()?; .transpose()?;
Ok(Function2 { Ok(Function2 {
func_decl: method.func_decl, return_arguments: method.return_arguments,
globals: method.globals, name: method.name,
input_arguments: method.input_arguments,
body, body,
import_as: method.import_as, import_as: method.import_as,
tuning: method.tuning, tuning: method.tuning,
linkage: method.linkage, linkage: method.linkage,
is_kernel: method.is_kernel,
}) })
} }

View file

@ -2,8 +2,8 @@ use super::*;
pub(super) fn run<'input>( pub(super) fn run<'input>(
resolver: &mut GlobalStringIdentResolver2<'input>, resolver: &mut GlobalStringIdentResolver2<'input>,
directives: Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>, directives: Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>,
) -> Result<Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> { ) -> Result<Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
let mut fn_declarations = FxHashMap::default(); let mut fn_declarations = FxHashMap::default();
let remapped_directives = directives let remapped_directives = directives
.into_iter() .into_iter()
@ -13,17 +13,14 @@ pub(super) fn run<'input>(
.into_iter() .into_iter()
.map(|(_, (return_arguments, name, input_arguments))| { .map(|(_, (return_arguments, name, input_arguments))| {
Directive2::Method(Function2 { Directive2::Method(Function2 {
func_decl: ast::MethodDeclaration { return_arguments,
return_arguments, name: name,
name: ast::MethodName::Func(name), input_arguments,
input_arguments,
shared_mem: None,
},
globals: Vec::new(),
body: None, body: None,
import_as: None, import_as: None,
tuning: Vec::new(), tuning: Vec::new(),
linkage: ast::LinkingDirective::EXTERN, linkage: ast::LinkingDirective::EXTERN,
is_kernel: false,
}) })
}) })
.collect::<Vec<_>>(); .collect::<Vec<_>>();
@ -41,8 +38,8 @@ fn run_directive<'input>(
Vec<ast::Variable<SpirvWord>>, Vec<ast::Variable<SpirvWord>>,
), ),
>, >,
directive: Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>, directive: Directive2<ast::Instruction<SpirvWord>, SpirvWord>,
) -> Result<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> { ) -> Result<Directive2<ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
Ok(match directive { Ok(match directive {
var @ Directive2::Variable(..) => var, var @ Directive2::Variable(..) => var,
Directive2::Method(mut method) => { Directive2::Method(mut method) => {

View file

@ -1,14 +1,15 @@
use std::borrow::Cow;
use super::{GlobalStringIdentResolver2, NormalizedDirective2, SpirvWord}; use super::{GlobalStringIdentResolver2, NormalizedDirective2, SpirvWord};
pub(crate) fn run<'input>( pub(crate) fn run<'input>(
resolver: &GlobalStringIdentResolver2<'input>, resolver: &mut GlobalStringIdentResolver2<'input>,
mut directives: Vec<NormalizedDirective2<'input>>, mut directives: Vec<NormalizedDirective2>,
) -> Vec<NormalizedDirective2<'input>> { ) -> Vec<NormalizedDirective2> {
for directive in directives.iter_mut() { for directive in directives.iter_mut() {
match directive { match directive {
NormalizedDirective2::Method(func) => { NormalizedDirective2::Method(func) => {
func.import_as = replace_with_ptx_impl(resolver, func.name);
replace_with_ptx_impl(resolver, &func.func_decl.name, func.import_as.take());
} }
_ => {} _ => {}
} }
@ -17,22 +18,16 @@ pub(crate) fn run<'input>(
} }
fn replace_with_ptx_impl<'input>( fn replace_with_ptx_impl<'input>(
resolver: &GlobalStringIdentResolver2<'input>, resolver: &mut GlobalStringIdentResolver2<'input>,
fn_name: &ptx_parser::MethodName<'input, SpirvWord>, fn_name: SpirvWord,
name: Option<String>, ) {
) -> Option<String> {
let known_names = ["__assertfail"]; let known_names = ["__assertfail"];
match name { if let Some(super::IdentEntry {
Some(name) if known_names.contains(&&*name) => Some(format!("__zluda_ptx_impl_{}", name)), name: Some(name), ..
Some(name) => Some(name), }) = resolver.ident_map.get_mut(&fn_name)
None => match fn_name { {
ptx_parser::MethodName::Func(name) => match resolver.ident_map.get(name) { if known_names.contains(&&**name) {
Some(super::IdentEntry { *name = Cow::Owned(format!("__zluda_ptx_impl_{}", name));
name: Some(name), .. }
}) => Some(format!("__zluda_ptx_impl_{}", name)),
_ => None,
},
ptx_parser::MethodName::Kernel(..) => None,
},
} }
} }

View file

@ -3,8 +3,8 @@ use ptx_parser as ast;
use rustc_hash::FxHashSet; use rustc_hash::FxHashSet;
pub(crate) fn run<'input>( pub(crate) fn run<'input>(
directives: Vec<UnconditionalDirective<'input>>, directives: Vec<UnconditionalDirective>,
) -> Result<Vec<UnconditionalDirective<'input>>, TranslateError> { ) -> Result<Vec<UnconditionalDirective>, TranslateError> {
let mut functions = FxHashSet::default(); let mut functions = FxHashSet::default();
directives directives
.into_iter() .into_iter()
@ -14,19 +14,13 @@ pub(crate) fn run<'input>(
fn run_directive<'input>( fn run_directive<'input>(
functions: &mut FxHashSet<SpirvWord>, functions: &mut FxHashSet<SpirvWord>,
directive: UnconditionalDirective<'input>, directive: UnconditionalDirective,
) -> Result<UnconditionalDirective<'input>, TranslateError> { ) -> Result<UnconditionalDirective, TranslateError> {
Ok(match directive { Ok(match directive {
var @ Directive2::Variable(..) => var, var @ Directive2::Variable(..) => var,
Directive2::Method(method) => { Directive2::Method(method) => {
{ if !method.is_kernel {
let func_decl = &method.func_decl; functions.insert(method.name);
match func_decl.name {
ptx_parser::MethodName::Kernel(_) => {}
ptx_parser::MethodName::Func(name) => {
functions.insert(name);
}
}
} }
Directive2::Method(run_method(functions, method)?) Directive2::Method(run_method(functions, method)?)
} }
@ -35,8 +29,8 @@ fn run_directive<'input>(
fn run_method<'input>( fn run_method<'input>(
functions: &mut FxHashSet<SpirvWord>, functions: &mut FxHashSet<SpirvWord>,
method: UnconditionalFunction<'input>, method: UnconditionalFunction,
) -> Result<UnconditionalFunction<'input>, TranslateError> { ) -> Result<UnconditionalFunction, TranslateError> {
let body = method let body = method
.body .body
.map(|statements| { .map(|statements| {
@ -47,12 +41,14 @@ fn run_method<'input>(
}) })
.transpose()?; .transpose()?;
Ok(Function2 { Ok(Function2 {
func_decl: method.func_decl, return_arguments: method.return_arguments,
globals: method.globals, name: method.name,
input_arguments: method.input_arguments,
body, body,
import_as: method.import_as, import_as: method.import_as,
tuning: method.tuning, tuning: method.tuning,
linkage: method.linkage, linkage: method.linkage,
is_kernel: method.is_kernel,
}) })
} }

View file

@ -1028,7 +1028,7 @@ pub struct ArithFloat {
// round-to-nearest-even and may be optimized aggressively by the code optimizer. In particular, // round-to-nearest-even and may be optimized aggressively by the code optimizer. In particular,
// mul/add sequences with no rounding modifiers may be optimized to use fused-multiply-add // mul/add sequences with no rounding modifiers may be optimized to use fused-multiply-add
// instructions on the target device. // instructions on the target device.
pub is_fusable: bool pub is_fusable: bool,
} }
#[derive(Copy, Clone, PartialEq, Eq)] #[derive(Copy, Clone, PartialEq, Eq)]
@ -1447,6 +1447,7 @@ pub struct CvtDetails {
pub mode: CvtMode, pub mode: CvtMode,
} }
#[derive(Clone, Copy)]
pub enum CvtMode { pub enum CvtMode {
// int from int // int from int
ZeroExtend, ZeroExtend,
@ -1465,7 +1466,7 @@ pub enum CvtMode {
flush_to_zero: Option<bool>, flush_to_zero: Option<bool>,
}, },
FPRound { FPRound {
integer_rounding: Option<RoundingMode>, integer_rounding: RoundingMode,
flush_to_zero: Option<bool>, flush_to_zero: Option<bool>,
}, },
// int from float // int from float
@ -1519,7 +1520,7 @@ impl CvtDetails {
flush_to_zero, flush_to_zero,
}, },
Ordering::Equal => CvtMode::FPRound { Ordering::Equal => CvtMode::FPRound {
integer_rounding: rounding, integer_rounding: rounding.unwrap_or(RoundingMode::NearestEven),
flush_to_zero, flush_to_zero,
}, },
Ordering::Greater => { Ordering::Greater => {