Add mode-setting wrappers to functions

This commit is contained in:
Andrzej Janik 2025-03-13 16:49:06 +00:00
parent a0d4b7eeb2
commit 87fe601494
3 changed files with 323 additions and 94 deletions

View file

@ -734,7 +734,9 @@ pub(crate) fn run<'input>(
} }
Statement::RetValue(..) Statement::RetValue(..)
| Statement::Instruction(ast::Instruction::Ret { .. }) => { | Statement::Instruction(ast::Instruction::Ret { .. }) => {
bb_state.record_ret(*name)?; if !is_kernel {
bb_state.record_ret(*name)?;
}
} }
Statement::Label(label) => { Statement::Label(label) => {
bb_state.start(*label); bb_state.start(*label);
@ -808,7 +810,7 @@ pub(crate) fn run<'input>(
)?; )?;
let all_modes = FullModeInsertion::new(flat_resolver, denormal, rounding)?; let all_modes = FullModeInsertion::new(flat_resolver, denormal, rounding)?;
*/ */
let directives = insert_mode_control(directives, temp)?; let directives = insert_mode_control(flat_resolver, directives, temp)?;
Ok(directives) Ok(directives)
} }
@ -1018,47 +1020,48 @@ struct TwinMode<T> {
} }
fn insert_mode_control( fn insert_mode_control(
flat_resolver: &mut super::GlobalStringIdentResolver2,
directives: Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>, directives: Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>,
global_modes: FullModeInsertion2, global_modes: FullModeInsertion2,
) -> Result<Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> { ) -> Result<Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
let directives_len = directives.len(); let directives_len = directives.len();
directives directives
.into_iter() .into_iter()
.map(|mut directive| { .map(|directive| {
let mut new_directives = SmallVec::<[_; 4]>::new(); let mut new_directives = SmallVec::<[_; 4]>::new();
let (fn_name, initial_mode, body_ptr) = match directive { let (mut method, initial_mode) = match directive {
Directive2::Variable(..) | Directive2::Method(Function2 { body: None, .. }) => { Directive2::Variable(..) | Directive2::Method(Function2 { body: None, .. }) => {
new_directives.push(directive); new_directives.push(directive);
return Ok(new_directives); return Ok(new_directives);
} }
Directive2::Method(Function2 { Directive2::Method(
name, mut method @ Function2 {
body: Some(ref mut body), name,
ref mut flush_to_zero_f32, body: Some(_),
ref mut flush_to_zero_f16f64, ..
ref mut rounding_mode_f32, },
ref mut rounding_mode_f16f64, ) => {
..
}) => {
let initial_mode = global_modes let initial_mode = global_modes
.basic_blocks .basic_blocks
.get(&name) .get(&name)
.ok_or_else(error_unreachable)?; .ok_or_else(error_unreachable)?;
let denormal_mode = initial_mode.denormal.twin_mode; let denormal_mode = initial_mode.denormal.twin_mode;
let rounding_mode = initial_mode.rounding.twin_mode; let rounding_mode = initial_mode.rounding.twin_mode;
*flush_to_zero_f32 = denormal_mode.f32.ok_or_else(error_unreachable)?.to_ftz(); method.flush_to_zero_f32 =
*flush_to_zero_f16f64 = denormal_mode.f32.ok_or_else(error_unreachable)?.to_ftz();
method.flush_to_zero_f16f64 =
denormal_mode.f16f64.ok_or_else(error_unreachable)?.to_ftz(); denormal_mode.f16f64.ok_or_else(error_unreachable)?.to_ftz();
*rounding_mode_f32 = rounding_mode.f32.ok_or_else(error_unreachable)?.to_ast(); method.rounding_mode_f32 =
*rounding_mode_f16f64 = rounding_mode.f32.ok_or_else(error_unreachable)?.to_ast();
method.rounding_mode_f16f64 =
rounding_mode.f16f64.ok_or_else(error_unreachable)?.to_ast(); rounding_mode.f16f64.ok_or_else(error_unreachable)?.to_ast();
(name, initial_mode, body) (method, initial_mode)
} }
}; };
emit_mode_prelude(fn_name, &mut new_directives); emit_mode_prelude(flat_resolver, &method, &global_modes, &mut new_directives)?;
let old_body = mem::replace(body_ptr, Vec::new()); let old_body = method.body.take().unwrap();
let mut result = Vec::with_capacity(old_body.len()); let mut result = Vec::with_capacity(old_body.len());
let mut bb_state = BasicBlockControlState::new(&global_modes, fn_name, initial_mode); let mut bb_state = BasicBlockControlState::new(&global_modes, initial_mode);
let mut old_body = old_body.into_iter(); let mut old_body = old_body.into_iter();
while let Some(mut statement) = old_body.next() { while let Some(mut statement) = old_body.next() {
let mut call_target = None; let mut call_target = None;
@ -1115,8 +1118,8 @@ fn insert_mode_control(
} }
} }
} }
*body_ptr = result; method.body = Some(result);
new_directives.push(directive); new_directives.push(Directive2::Method(method));
Ok(new_directives) Ok(new_directives)
}) })
.try_fold(Vec::with_capacity(directives_len), |mut acc, d| { .try_fold(Vec::with_capacity(directives_len), |mut acc, d| {
@ -1126,35 +1129,219 @@ fn insert_mode_control(
} }
fn emit_mode_prelude( fn emit_mode_prelude(
fn_name: SpirvWord, flat_resolver: &mut super::GlobalStringIdentResolver2,
global_modes: FullModeInsertion2, method: &Function2<ast::Instruction<SpirvWord>, SpirvWord>,
global_modes: &FullModeInsertion2,
new_directives: &mut SmallVec<[Directive2<ptx_parser::Instruction<SpirvWord>, SpirvWord>; 4]>, new_directives: &mut SmallVec<[Directive2<ptx_parser::Instruction<SpirvWord>, SpirvWord>; 4]>,
) -> Result<(), TranslateError> { ) -> Result<(), TranslateError> {
let fn_mode_state = global_modes.basic_blocks.get(&fn_name).ok_or_else(error_unreachable)?; let fn_mode_state = global_modes
.basic_blocks
.get(&method.name)
.ok_or_else(error_unreachable)?;
if let Some(dual_prologue) = fn_mode_state.dual_prologue { if let Some(dual_prologue) = fn_mode_state.dual_prologue {
new_directives.push(Directive2::Method( new_directives.push(create_fn_wrapper(
Function2 { flat_resolver,
return_arguments: todo!(), method,
name: dual_prologue, dual_prologue,
input_arguments: todo!(), [
body: todo!(), ModeRegister::Denormal {
is_kernel: false, f32: fn_mode_state
import_as: None, .denormal
tuning: Vec::new(), .twin_mode
linkage: ast::LinkingDirective::NONE, .f32
flush_to_zero_f32: false, .unwrap_of_default()
flush_to_zero_f16f64: false, .to_ftz(),
rounding_mode_f32: ptx_parser::RoundingMode::NearestEven, f16f64: fn_mode_state
rounding_mode_f16f64: ptx_parser::RoundingMode::NearestEven, .denormal
} .twin_mode
.f16f64
.unwrap_of_default()
.to_ftz(),
},
ModeRegister::Rounding {
f32: fn_mode_state
.rounding
.twin_mode
.f32
.unwrap_of_default()
.to_ast(),
f16f64: fn_mode_state
.rounding
.twin_mode
.f16f64
.unwrap_of_default()
.to_ast(),
},
]
.into_iter(),
)); ));
} }
if let Some(prologue) = fn_mode_state.denormal.prologue { if let Some(prologue) = fn_mode_state.denormal.prologue {
todo!() new_directives.push(create_fn_wrapper(
flat_resolver,
method,
prologue,
[ModeRegister::Denormal {
f32: fn_mode_state
.denormal
.twin_mode
.f32
.unwrap_of_default()
.to_ftz(),
f16f64: fn_mode_state
.denormal
.twin_mode
.f16f64
.unwrap_of_default()
.to_ftz(),
}]
.into_iter(),
));
} }
if let Some(prologue) = fn_mode_state.rounding.prologue { if let Some(prologue) = fn_mode_state.rounding.prologue {
todo!() new_directives.push(create_fn_wrapper(
flat_resolver,
method,
prologue,
[ModeRegister::Rounding {
f32: fn_mode_state
.rounding
.twin_mode
.f32
.unwrap_of_default()
.to_ast(),
f16f64: fn_mode_state
.rounding
.twin_mode
.f16f64
.unwrap_of_default()
.to_ast(),
}]
.into_iter(),
));
} }
Ok(())
}
fn create_fn_wrapper(
flat_resolver: &mut super::GlobalStringIdentResolver2,
method: &Function2<ast::Instruction<SpirvWord>, SpirvWord>,
name: SpirvWord,
modes: impl ExactSizeIterator<Item = ModeRegister>,
) -> Directive2<ast::Instruction<SpirvWord>, SpirvWord> {
// * Label
// * return argument registers
// * input argument registers
// * Load input arguments
// * set modes
// * call
// * return with value
let return_arguments = rename_variables(flat_resolver, &method.return_arguments);
let input_arguments = rename_variables(flat_resolver, &method.input_arguments);
let mut body = Vec::with_capacity(
1 + (input_arguments.len() * 2) + return_arguments.len() + modes.len() + 2,
);
body.push(Statement::Label(flat_resolver.register_unnamed(None)));
let return_variables = append_variables(flat_resolver, &mut body, &return_arguments);
let input_variables = append_variables(flat_resolver, &mut body, &input_arguments);
for (index, input_reg) in input_variables.iter().enumerate() {
body.push(Statement::Instruction(ast::Instruction::Ld {
data: ast::LdDetails {
qualifier: ast::LdStQualifier::Weak,
state_space: input_arguments[index].state_space,
caching: ast::LdCacheOperator::Cached,
typ: input_arguments[index].v_type.clone(),
non_coherent: false,
},
arguments: ast::LdArgs {
src: input_arguments[index].name,
dst: *input_reg,
},
}));
}
body.extend(modes.map(|mode_set| Statement::SetMode(mode_set)));
// Out of order because we want to use return_variables before they are moved
let ret_statement = if return_arguments.is_empty() {
Statement::Instruction(ast::Instruction::Ret {
data: ast::RetData { uniform: false },
})
} else {
Statement::RetValue(
ast::RetData { uniform: false },
return_variables
.iter()
.enumerate()
.map(|(index, var)| (*var, method.return_arguments[index].v_type.clone()))
.collect(),
)
};
body.push(Statement::Instruction(ast::Instruction::Call {
data: ast::CallDetails {
uniform: false,
return_arguments: return_arguments
.iter()
.map(|arg| (arg.v_type.clone(), arg.state_space))
.collect(),
input_arguments: input_arguments
.iter()
.map(|arg| (arg.v_type.clone(), arg.state_space))
.collect(),
},
arguments: ast::CallArgs {
return_arguments: return_variables,
func: method.name,
input_arguments: input_variables,
},
}));
body.push(ret_statement);
Directive2::Method(Function2 {
return_arguments,
name,
input_arguments,
body: Some(body),
is_kernel: false,
import_as: None,
tuning: Vec::new(),
linkage: ast::LinkingDirective::NONE,
flush_to_zero_f32: false,
flush_to_zero_f16f64: false,
rounding_mode_f32: ptx_parser::RoundingMode::NearestEven,
rounding_mode_f16f64: ptx_parser::RoundingMode::NearestEven,
})
}
fn rename_variables(
flat_resolver: &mut super::GlobalStringIdentResolver2,
variables: &Vec<ast::Variable<SpirvWord>>,
) -> Vec<ast::Variable<SpirvWord>> {
variables
.iter()
.cloned()
.map(|arg| ast::Variable {
name: flat_resolver.register_unnamed(Some((arg.v_type.clone(), arg.state_space))),
..arg
})
.collect()
}
fn append_variables<'a, 'input: 'a>(
flat_resolver: &'a mut super::GlobalStringIdentResolver2<'input>,
body: &mut Vec<Statement<ast::Instruction<SpirvWord>, SpirvWord>>,
arguments: &'a Vec<ast::Variable<SpirvWord>>,
) -> Vec<SpirvWord> {
let mut result = Vec::with_capacity(arguments.len());
for arg in arguments {
let name = flat_resolver.register_unnamed(Some((arg.v_type.clone(), ast::StateSpace::Reg)));
body.push(Statement::Variable(ast::Variable {
align: None,
v_type: arg.v_type.clone(),
state_space: ast::StateSpace::Reg,
name,
array_init: Vec::new(),
}));
result.push(name);
}
result
} }
struct BasicBlockControlState<'a> { struct BasicBlockControlState<'a> {
@ -1163,7 +1350,6 @@ struct BasicBlockControlState<'a> {
denormal_f16f64: RegisterState<bool>, denormal_f16f64: RegisterState<bool>,
rounding_f32: RegisterState<ast::RoundingMode>, rounding_f32: RegisterState<ast::RoundingMode>,
rounding_f16f64: RegisterState<ast::RoundingMode>, rounding_f16f64: RegisterState<ast::RoundingMode>,
current_bb: SpirvWord,
} }
#[derive(Clone, Copy)] #[derive(Clone, Copy)]
@ -1175,20 +1361,6 @@ struct RegisterState<T> {
} }
impl<T> RegisterState<T> { impl<T> RegisterState<T> {
fn single(t: T) -> Self {
RegisterState {
last_foldable: None,
current_value: Resolved::Value(t),
}
}
fn conflict() -> Self {
RegisterState {
last_foldable: None,
current_value: Resolved::Conflict,
}
}
fn new<U>(value: Resolved<U>) -> RegisterState<T> fn new<U>(value: Resolved<U>) -> RegisterState<T>
where where
U: Into<T>, U: Into<T>,
@ -1201,11 +1373,7 @@ impl<T> RegisterState<T> {
} }
impl<'a> BasicBlockControlState<'a> { impl<'a> BasicBlockControlState<'a> {
fn new( fn new(global_modes: &'a FullModeInsertion2, initial_mode: &FullBasicBlockEntryState) -> Self {
global_modes: &'a FullModeInsertion2,
current_bb: SpirvWord,
initial_mode: &FullBasicBlockEntryState,
) -> Self {
let denormal_f32 = RegisterState::new(initial_mode.denormal.twin_mode.f32); let denormal_f32 = RegisterState::new(initial_mode.denormal.twin_mode.f32);
let denormal_f16f64 = RegisterState::new(initial_mode.denormal.twin_mode.f16f64); let denormal_f16f64 = RegisterState::new(initial_mode.denormal.twin_mode.f16f64);
let rounding_f32 = RegisterState::new(initial_mode.rounding.twin_mode.f32); let rounding_f32 = RegisterState::new(initial_mode.rounding.twin_mode.f32);
@ -1216,7 +1384,6 @@ impl<'a> BasicBlockControlState<'a> {
denormal_f16f64, denormal_f16f64,
rounding_f32, rounding_f32,
rounding_f16f64, rounding_f16f64,
current_bb,
} }
} }
@ -1405,12 +1572,6 @@ fn redirect_jump_impl(
Ok(()) Ok(())
} }
struct ModeJumpTargets {
dual_prologue: Option<SpirvWord>,
denormal: Option<SpirvWord>,
rounding: Option<SpirvWord>,
}
#[derive(Copy, Clone)] #[derive(Copy, Clone)]
enum Resolved<T> { enum Resolved<T> {
Conflict, Conflict,
@ -1457,13 +1618,6 @@ impl<T> Resolved<T> {
Resolved::Conflict => Err(err()), Resolved::Conflict => Err(err()),
} }
} }
fn has_value(&self) -> bool {
match self {
Resolved::Value(_) => true,
Resolved::Conflict => false,
}
}
} }
trait ModeView { trait ModeView {

View file

@ -258,30 +258,83 @@ fn call_with_mode() {
)); ));
let [to_fn0] = calls(method_1); let [to_fn0] = calls(method_1);
let [_, dual_prelude, _, _, add] = labels(method_1); let [_, dual_prelude, _, _, add] = labels(method_1);
let [post_call, post_prelude_0, post_prelude_1, post_prelude_2] = branches(method_1); let [post_call, post_prelude_dual, post_prelude_denormal, post_prelude_rounding] =
branches(method_1);
assert_eq!(methods[0].name, to_fn0); assert_eq!(methods[0].name, to_fn0);
assert_eq!(post_call, dual_prelude); assert_eq!(post_call, dual_prelude);
assert_eq!(post_prelude_0, add); assert_eq!(post_prelude_dual, add);
assert_eq!(post_prelude_1, add); assert_eq!(post_prelude_denormal, add);
assert_eq!(post_prelude_2, add); assert_eq!(post_prelude_rounding, add);
let method_2 = methods[2].body.as_ref().unwrap(); let method_2 = methods[2].body.as_ref().unwrap();
assert!(matches!( assert!(matches!(
&**method_2, &**method_2,
[ [
Statement::Label(..), Statement::Label(..),
Statement::Variable(..),
Statement::Variable(..),
Statement::Conditional(..),
Statement::Label(..),
Statement::Conditional(..),
Statement::Label(..),
Statement::Instruction(ast::Instruction::Bra { .. }),
Statement::Label(..),
// Dual prelude
Statement::SetMode(ModeRegister::Denormal { Statement::SetMode(ModeRegister::Denormal {
f32: true, f32: false,
f16f64: true f16f64: true
}), }),
Statement::SetMode(ModeRegister::Rounding { Statement::SetMode(ModeRegister::Rounding {
f32: ast::RoundingMode::PositiveInf, f32: ast::RoundingMode::NegativeInf,
f16f64: ast::RoundingMode::NearestEven f16f64: ast::RoundingMode::NearestEven
}), }),
Statement::Instruction(ast::Instruction::Call { .. }), Statement::Instruction(ast::Instruction::Bra { .. }),
// Denormal prelude
Statement::Label(..),
Statement::SetMode(ModeRegister::Denormal {
f32: false,
f16f64: true
}),
Statement::Instruction(ast::Instruction::Bra { .. }),
// Rounding prelude
Statement::Label(..),
Statement::SetMode(ModeRegister::Rounding {
f32: ast::RoundingMode::NegativeInf,
f16f64: ast::RoundingMode::NearestEven
}),
Statement::Instruction(ast::Instruction::Bra { .. }),
Statement::Label(..),
Statement::Instruction(ast::Instruction::Add { .. }),
Statement::Instruction(ast::Instruction::Bra { .. }),
Statement::Label(..),
Statement::SetMode(ModeRegister::Denormal {
f32: false,
f16f64: true
}),
Statement::Instruction(ast::Instruction::Bra { .. }),
Statement::Label(..),
Statement::Instruction(ast::Instruction::Add { .. }),
Statement::Instruction(ast::Instruction::Bra { .. }),
Statement::Label(..),
Statement::Instruction(ast::Instruction::Ret { .. }), Statement::Instruction(ast::Instruction::Ret { .. }),
] ]
)); ));
let [(if_rm_true, if_rm_false), (if_rz_true, if_rz_false)] = conditionals(method_2);
let [_, conditional2, post_conditional2, prelude_dual, _, _, add1, add2_set_denormal, add2, ret] =
labels(method_2);
let [post_conditional2_jump, post_prelude_dual, post_prelude_denormal, post_prelude_rounding, post_add1, post_add2_set_denormal, post_add2] =
branches(method_2);
assert_eq!(if_rm_true, prelude_dual);
assert_eq!(if_rm_false, conditional2);
assert_eq!(if_rz_true, post_conditional2);
assert_eq!(if_rz_false, add2_set_denormal);
assert_eq!(post_conditional2_jump, prelude_dual);
assert_eq!(post_prelude_dual, add1);
assert_eq!(post_prelude_denormal, add1);
assert_eq!(post_prelude_rounding, add1);
assert_eq!(post_add1, ret);
assert_eq!(post_add2_set_denormal, add2);
assert_eq!(post_add2, ret);
} }
fn branches<const N: usize>( fn branches<const N: usize>(
@ -303,10 +356,12 @@ fn labels<const N: usize>(
fn_: &Vec<Statement<ast::Instruction<SpirvWord>, SpirvWord>>, fn_: &Vec<Statement<ast::Instruction<SpirvWord>, SpirvWord>>,
) -> [SpirvWord; N] { ) -> [SpirvWord; N] {
fn_.iter() fn_.iter()
.filter_map(|s: &Statement<ptx_parser::Instruction<SpirvWord>, SpirvWord>| match s { .filter_map(
Statement::Label(label) => Some(*label), |s: &Statement<ptx_parser::Instruction<SpirvWord>, SpirvWord>| match s {
_ => None, Statement::Label(label) => Some(*label),
}) _ => None,
},
)
.collect::<Vec<_>>() .collect::<Vec<_>>()
.try_into() .try_into()
.unwrap() .unwrap()
@ -317,7 +372,25 @@ fn calls<const N: usize>(
) -> [SpirvWord; N] { ) -> [SpirvWord; N] {
fn_.iter() fn_.iter()
.filter_map(|s| match s { .filter_map(|s| match s {
Statement::Instruction(ast::Instruction::Call { arguments: ast::CallArgs { func,.. }, .. }) => Some(*func), Statement::Instruction(ast::Instruction::Call {
arguments: ast::CallArgs { func, .. },
..
}) => Some(*func),
_ => None,
})
.collect::<Vec<_>>()
.try_into()
.unwrap()
}
fn conditionals<const N: usize>(
fn_: &Vec<Statement<ast::Instruction<SpirvWord>, SpirvWord>>,
) -> [(SpirvWord, SpirvWord); N] {
fn_.iter()
.filter_map(|s| match s {
Statement::Conditional(BrachCondition {
if_true, if_false, ..
}) => Some((*if_true, *if_false)),
_ => None, _ => None,
}) })
.collect::<Vec<_>>() .collect::<Vec<_>>()

View file

@ -7,7 +7,7 @@ use super::*;
// represent kernels as separate nodes with its own separate entry/exit mode // represent kernels as separate nodes with its own separate entry/exit mode
// * Inserts label at the start of every basic block // * Inserts label at the start of every basic block
// * Insert explicit jumps before labels // * Insert explicit jumps before labels
// * Functions get a single `ret;` exit point - this is because mode computation // * Non-.entry methods get a single `ret;` exit point - this is because mode computation
// logic requires it. Control flow graph constructed by mode computation // logic requires it. Control flow graph constructed by mode computation
// models function calls as jumps into and then from another function. // models function calls as jumps into and then from another function.
// If this cfg allowed multiple return basic blocks then there would be cases // If this cfg allowed multiple return basic blocks then there would be cases
@ -19,10 +19,10 @@ pub(crate) fn run(
mut directives: Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>, mut directives: Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>,
) -> Result<Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> { ) -> Result<Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
for directive in directives.iter_mut() { for directive in directives.iter_mut() {
let body_ref = match directive { let (body_ref, is_kernel) = match directive {
Directive2::Method(Function2 { Directive2::Method(Function2 {
body: Some(body), .. body: Some(body), is_kernel, ..
}) => body, }) => (body, *is_kernel),
_ => continue, _ => continue,
}; };
let body = std::mem::replace(body_ref, Vec::new()); let body = std::mem::replace(body_ref, Vec::new());
@ -74,7 +74,9 @@ pub(crate) fn run(
return Err(error_unreachable()); return Err(error_unreachable());
} }
Statement::Instruction(ast::Instruction::Ret { .. }) => { Statement::Instruction(ast::Instruction::Ret { .. }) => {
return_statements.push(result.len()) if !is_kernel {
return_statements.push(result.len());
}
} }
_ => {} _ => {}
} }