Fix most of the remaining bugs

This commit is contained in:
Andrzej Janik 2025-03-13 19:22:11 +00:00
parent 87fe601494
commit 04fbfea80a
3 changed files with 99 additions and 24 deletions

View file

@ -405,7 +405,19 @@ impl ControlFlowGraph {
node.rounding_f16f64.exit = exit.rounding_f16f64.map(ExtendedMode::BasicBlock); node.rounding_f16f64.exit = exit.rounding_f16f64.map(ExtendedMode::BasicBlock);
} }
fn fixup_function_calls(&mut self) { fn fixup_function_calls(&mut self) -> Result<(), TranslateError> {
for (fn_, follow_on_labels) in self.call_returns.iter() {
let connecting_bb = match self.functions_rets.get(fn_) {
Some(return_bb) => *return_bb,
// function is just a declaration
None => *self.basic_blocks.get(fn_).ok_or_else(error_unreachable)?,
};
for follow_on_label in follow_on_labels {
self.graph.add_edge(connecting_bb, *follow_on_label, ());
}
}
Ok(())
/*
for (function, source) in self.functions_rets.iter() { for (function, source) in self.functions_rets.iter() {
for target in self for target in self
.call_returns .call_returns
@ -418,6 +430,7 @@ impl ControlFlowGraph {
self.graph.add_edge(*source, target, ()); self.graph.add_edge(*source, target, ());
} }
} }
*/
} }
} }
@ -481,6 +494,7 @@ impl ResolvedControlFlowGraph {
} }
} }
} }
// This should happen only for orphaned basic blocks
mode.map(Resolved::Value).ok_or_else(error_unreachable) mode.map(Resolved::Value).ok_or_else(error_unreachable)
} }
fn resolve_mode<T: Eq + PartialEq + Copy + Default>( fn resolve_mode<T: Eq + PartialEq + Copy + Default>(
@ -730,7 +744,6 @@ pub(crate) fn run<'input>(
_ => return Err(error_unreachable()), _ => return Err(error_unreachable()),
}; };
bb_state.record_call(*func, after_call_label)?; bb_state.record_call(*func, after_call_label)?;
//body_iter.next();
} }
Statement::RetValue(..) Statement::RetValue(..)
| Statement::Instruction(ast::Instruction::Ret { .. }) => { | Statement::Instruction(ast::Instruction::Ret { .. }) => {
@ -761,7 +774,7 @@ pub(crate) fn run<'input>(
"{:?}", "{:?}",
petgraph::dot::Dot::with_config(&cfg.graph, &[petgraph::dot::Config::EdgeNoLabel]) petgraph::dot::Dot::with_config(&cfg.graph, &[petgraph::dot::Config::EdgeNoLabel])
); );
cfg.fixup_function_calls(); cfg.fixup_function_calls()?;
println!( println!(
"{:?}", "{:?}",
petgraph::dot::Dot::with_config(&cfg.graph, &[petgraph::dot::Config::EdgeNoLabel]) petgraph::dot::Dot::with_config(&cfg.graph, &[petgraph::dot::Config::EdgeNoLabel])
@ -784,6 +797,7 @@ pub(crate) fn run<'input>(
)?; )?;
let temp = join_modes2( let temp = join_modes2(
flat_resolver, flat_resolver,
&directives,
cfg, cfg,
denormal_f32, denormal_f32,
denormal_f16f64, denormal_f16f64,
@ -816,6 +830,7 @@ pub(crate) fn run<'input>(
fn join_modes2( fn join_modes2(
flat_resolver: &mut super::GlobalStringIdentResolver2, flat_resolver: &mut super::GlobalStringIdentResolver2,
directives: &Vec<super::Directive2<ast::Instruction<SpirvWord>, super::SpirvWord>>,
cfg: ResolvedControlFlowGraph, cfg: ResolvedControlFlowGraph,
mandatory_denormal_f32: MandatoryModeInsertions<DenormalMode>, mandatory_denormal_f32: MandatoryModeInsertions<DenormalMode>,
mandatory_denormal_f16f64: MandatoryModeInsertions<DenormalMode>, mandatory_denormal_f16f64: MandatoryModeInsertions<DenormalMode>,
@ -877,6 +892,47 @@ fn join_modes2(
)) ))
}) })
.collect::<Result<FxHashMap<_, _>, _>>()?; .collect::<Result<FxHashMap<_, _>, _>>()?;
let temp = directives
.iter()
.filter_map(|directive| match directive {
Directive2::Method(Function2 {
name,
body: None,
is_kernel: false,
..
}) => {
let fn_bb = match cfg.basic_blocks.get(name) {
Some(bb) => bb,
None => return None,
};
let weights = cfg.graph.node_weight(*fn_bb).unwrap();
let modes = ResolvedInstructionModes {
denormal_f32: weights.denormal_f32.exit.map(DenormalMode::to_ftz),
denormal_f16f64: weights.denormal_f16f64.exit.map(DenormalMode::to_ftz),
rounding_f32: weights.rounding_f32.exit.map(RoundingMode::to_ast),
rounding_f16f64: weights.rounding_f16f64.exit.map(RoundingMode::to_ast),
};
Some(Ok((*name, modes)))
}
Directive2::Method(Function2 {
name,
body: Some(_),
is_kernel: false,
..
}) => {
let ret_bb = cfg.functions_rets.get(name).unwrap();
let weights = cfg.graph.node_weight(*ret_bb).unwrap();
let modes = ResolvedInstructionModes {
denormal_f32: weights.denormal_f32.exit.map(DenormalMode::to_ftz),
denormal_f16f64: weights.denormal_f16f64.exit.map(DenormalMode::to_ftz),
rounding_f32: weights.rounding_f32.exit.map(RoundingMode::to_ast),
rounding_f16f64: weights.rounding_f16f64.exit.map(RoundingMode::to_ast),
};
Some(Ok((*name, modes)))
}
_ => None,
})
.collect::<Result<FxHashMap<_, _>, _>>()?;
let functions_exit_modes = cfg let functions_exit_modes = cfg
.functions_rets .functions_rets
.into_iter() .into_iter()
@ -893,7 +949,7 @@ fn join_modes2(
.collect::<Result<FxHashMap<_, _>, _>>()?; .collect::<Result<FxHashMap<_, _>, _>>()?;
Ok(FullModeInsertion2 { Ok(FullModeInsertion2 {
basic_blocks, basic_blocks,
functions_exit_modes, functions_exit_modes: temp,
}) })
} }
@ -1149,13 +1205,13 @@ fn emit_mode_prelude(
.denormal .denormal
.twin_mode .twin_mode
.f32 .f32
.unwrap_of_default() .unwrap_or_default()
.to_ftz(), .to_ftz(),
f16f64: fn_mode_state f16f64: fn_mode_state
.denormal .denormal
.twin_mode .twin_mode
.f16f64 .f16f64
.unwrap_of_default() .unwrap_or_default()
.to_ftz(), .to_ftz(),
}, },
ModeRegister::Rounding { ModeRegister::Rounding {
@ -1163,13 +1219,13 @@ fn emit_mode_prelude(
.rounding .rounding
.twin_mode .twin_mode
.f32 .f32
.unwrap_of_default() .unwrap_or_default()
.to_ast(), .to_ast(),
f16f64: fn_mode_state f16f64: fn_mode_state
.rounding .rounding
.twin_mode .twin_mode
.f16f64 .f16f64
.unwrap_of_default() .unwrap_or_default()
.to_ast(), .to_ast(),
}, },
] ]
@ -1186,13 +1242,13 @@ fn emit_mode_prelude(
.denormal .denormal
.twin_mode .twin_mode
.f32 .f32
.unwrap_of_default() .unwrap_or_default()
.to_ftz(), .to_ftz(),
f16f64: fn_mode_state f16f64: fn_mode_state
.denormal .denormal
.twin_mode .twin_mode
.f16f64 .f16f64
.unwrap_of_default() .unwrap_or_default()
.to_ftz(), .to_ftz(),
}] }]
.into_iter(), .into_iter(),
@ -1208,13 +1264,13 @@ fn emit_mode_prelude(
.rounding .rounding
.twin_mode .twin_mode
.f32 .f32
.unwrap_of_default() .unwrap_or_default()
.to_ast(), .to_ast(),
f16f64: fn_mode_state f16f64: fn_mode_state
.rounding .rounding
.twin_mode .twin_mode
.f16f64 .f16f64
.unwrap_of_default() .unwrap_or_default()
.to_ast(), .to_ast(),
}] }]
.into_iter(), .into_iter(),
@ -1409,21 +1465,21 @@ impl<'a> BasicBlockControlState<'a> {
if let Some(prologue) = bb_state.dual_prologue { if let Some(prologue) = bb_state.dual_prologue {
statements.push(Statement::Label(prologue)); statements.push(Statement::Label(prologue));
statements.push(Statement::SetMode(ModeRegister::Denormal { statements.push(Statement::SetMode(ModeRegister::Denormal {
f32: bb_state.denormal.twin_mode.f32.unwrap_of_default().to_ftz(), f32: bb_state.denormal.twin_mode.f32.unwrap_or_default().to_ftz(),
f16f64: bb_state f16f64: bb_state
.denormal .denormal
.twin_mode .twin_mode
.f16f64 .f16f64
.unwrap_of_default() .unwrap_or_default()
.to_ftz(), .to_ftz(),
})); }));
statements.push(Statement::SetMode(ModeRegister::Rounding { statements.push(Statement::SetMode(ModeRegister::Rounding {
f32: bb_state.rounding.twin_mode.f32.unwrap_of_default().to_ast(), f32: bb_state.rounding.twin_mode.f32.unwrap_or_default().to_ast(),
f16f64: bb_state f16f64: bb_state
.rounding .rounding
.twin_mode .twin_mode
.f16f64 .f16f64
.unwrap_of_default() .unwrap_or_default()
.to_ast(), .to_ast(),
})); }));
statements.push(Statement::Instruction(ast::Instruction::Bra { statements.push(Statement::Instruction(ast::Instruction::Bra {
@ -1433,12 +1489,12 @@ impl<'a> BasicBlockControlState<'a> {
if let Some(prologue) = bb_state.denormal.prologue { if let Some(prologue) = bb_state.denormal.prologue {
statements.push(Statement::Label(prologue)); statements.push(Statement::Label(prologue));
statements.push(Statement::SetMode(ModeRegister::Denormal { statements.push(Statement::SetMode(ModeRegister::Denormal {
f32: bb_state.denormal.twin_mode.f32.unwrap_of_default().to_ftz(), f32: bb_state.denormal.twin_mode.f32.unwrap_or_default().to_ftz(),
f16f64: bb_state f16f64: bb_state
.denormal .denormal
.twin_mode .twin_mode
.f16f64 .f16f64
.unwrap_of_default() .unwrap_or_default()
.to_ftz(), .to_ftz(),
})); }));
statements.push(Statement::Instruction(ast::Instruction::Bra { statements.push(Statement::Instruction(ast::Instruction::Bra {
@ -1448,12 +1504,12 @@ impl<'a> BasicBlockControlState<'a> {
if let Some(prologue) = bb_state.rounding.prologue { if let Some(prologue) = bb_state.rounding.prologue {
statements.push(Statement::Label(prologue)); statements.push(Statement::Label(prologue));
statements.push(Statement::SetMode(ModeRegister::Rounding { statements.push(Statement::SetMode(ModeRegister::Rounding {
f32: bb_state.rounding.twin_mode.f32.unwrap_of_default().to_ast(), f32: bb_state.rounding.twin_mode.f32.unwrap_or_default().to_ast(),
f16f64: bb_state f16f64: bb_state
.rounding .rounding
.twin_mode .twin_mode
.f16f64 .f16f64
.unwrap_of_default() .unwrap_or_default()
.to_ast(), .to_ast(),
})); }));
statements.push(Statement::Instruction(ast::Instruction::Bra { statements.push(Statement::Instruction(ast::Instruction::Bra {
@ -1516,8 +1572,15 @@ impl<'a> BasicBlockControlState<'a> {
new_mode, new_mode,
View::TwinView::get_register(self) View::TwinView::get_register(self)
.current_value .current_value
.ok_or_else(error_unreachable)?, .unwrap_or(View::ComputeValue::default().into()),
))); )));
View::set_register(
self,
RegisterState {
current_value: Resolved::Value(new_mode),
last_foldable: None,
},
);
set_fold_index::<View::TwinView>(self, Some(result.len() - 1)); set_fold_index::<View::TwinView>(self, Some(result.len() - 1));
} }
}, },
@ -1579,7 +1642,7 @@ enum Resolved<T> {
} }
impl<T: Default> Resolved<T> { impl<T: Default> Resolved<T> {
fn unwrap_of_default(self) -> T { fn unwrap_or_default(self) -> T {
match self { match self {
Resolved::Conflict => T::default(), Resolved::Conflict => T::default(),
Resolved::Value(t) => t, Resolved::Value(t) => t,
@ -1599,6 +1662,13 @@ impl<T: Eq + PartialEq> Resolved<T> {
} }
impl<T> Resolved<T> { impl<T> Resolved<T> {
fn unwrap_or(self, if_fail: T) -> T {
match self {
Resolved::Conflict => if_fail,
Resolved::Value(t) => t,
}
}
fn map<U, F>(self, f: F) -> Resolved<U> fn map<U, F>(self, f: F) -> Resolved<U>
where where
F: FnOnce(T) -> U, F: FnOnce(T) -> U,
@ -1621,6 +1691,7 @@ impl<T> Resolved<T> {
} }
trait ModeView { trait ModeView {
type ComputeValue: Default + Into<Self::Value>;
type Value: PartialEq + Eq + Copy + Clone; type Value: PartialEq + Eq + Copy + Clone;
type TwinView: ModeView<Value = Self::Value>; type TwinView: ModeView<Value = Self::Value>;
@ -1633,6 +1704,7 @@ trait ModeView {
struct DenormalF32View; struct DenormalF32View;
impl ModeView for DenormalF32View { impl ModeView for DenormalF32View {
type ComputeValue = DenormalMode;
type Value = bool; type Value = bool;
type TwinView = DenormalF16F64View; type TwinView = DenormalF16F64View;
@ -1660,6 +1732,7 @@ impl ModeView for DenormalF32View {
struct DenormalF16F64View; struct DenormalF16F64View;
impl ModeView for DenormalF16F64View { impl ModeView for DenormalF16F64View {
type ComputeValue = DenormalMode;
type Value = bool; type Value = bool;
type TwinView = DenormalF32View; type TwinView = DenormalF32View;
@ -1687,6 +1760,7 @@ impl ModeView for DenormalF16F64View {
struct RoundingF32View; struct RoundingF32View;
impl ModeView for RoundingF32View { impl ModeView for RoundingF32View {
type ComputeValue = RoundingMode;
type Value = ast::RoundingMode; type Value = ast::RoundingMode;
type TwinView = RoundingF16F64View; type TwinView = RoundingF16F64View;
@ -1714,6 +1788,7 @@ impl ModeView for RoundingF32View {
struct RoundingF16F64View; struct RoundingF16F64View;
impl ModeView for RoundingF16F64View { impl ModeView for RoundingF16F64View {
type ComputeValue = RoundingMode;
type Value = ast::RoundingMode; type Value = ast::RoundingMode;
type TwinView = RoundingF32View; type TwinView = RoundingF32View;

View file

@ -52,9 +52,9 @@ pub fn to_llvm_module<'input>(ast: ast::Module<'input>) -> Result<Module, Transl
let directives = expand_operands::run(&mut flat_resolver, directives)?; let directives = expand_operands::run(&mut flat_resolver, directives)?;
let directives = deparamize_functions::run(&mut flat_resolver, directives)?; let directives = deparamize_functions::run(&mut flat_resolver, directives)?;
let directives = normalize_basic_blocks::run(&mut flat_resolver, directives)?; let directives = normalize_basic_blocks::run(&mut flat_resolver, directives)?;
let directives = insert_ftz_control::run(&mut flat_resolver, directives)?;
let directives = insert_explicit_load_store::run(&mut flat_resolver, directives)?; let directives = insert_explicit_load_store::run(&mut flat_resolver, directives)?;
let directives = insert_implicit_conversions2::run(&mut flat_resolver, directives)?; let directives = insert_implicit_conversions2::run(&mut flat_resolver, directives)?;
let directives = insert_ftz_control::run(&mut flat_resolver, directives)?;
let directives = replace_instructions_with_function_calls::run(&mut flat_resolver, directives)?; let directives = replace_instructions_with_function_calls::run(&mut flat_resolver, directives)?;
let directives = hoist_globals::run(directives)?; let directives = hoist_globals::run(directives)?;
let llvm_ir = emit_llvm::run(flat_resolver, directives)?; let llvm_ir = emit_llvm::run(flat_resolver, directives)?;

View file

@ -54,7 +54,7 @@ pub(crate) fn run(
} }
} }
TerminatorKind::Fake => match statement { TerminatorKind::Fake => match statement {
// if it happens that there is a label after a call just reuse it // If there's a label after a call just reuse it
Statement::Label(label) => { Statement::Label(label) => {
result.push(Statement::Instruction(ast::Instruction::Bra { result.push(Statement::Instruction(ast::Instruction::Bra {
arguments: ast::BraArgs { src: label }, arguments: ast::BraArgs { src: label },