diff --git a/ptx/src/pass/insert_ftz_control/mod.rs b/ptx/src/pass/insert_ftz_control/mod.rs index 00c2f86..24120a4 100644 --- a/ptx/src/pass/insert_ftz_control/mod.rs +++ b/ptx/src/pass/insert_ftz_control/mod.rs @@ -405,7 +405,19 @@ impl ControlFlowGraph { node.rounding_f16f64.exit = exit.rounding_f16f64.map(ExtendedMode::BasicBlock); } - fn fixup_function_calls(&mut self) { + fn fixup_function_calls(&mut self) -> Result<(), TranslateError> { + for (fn_, follow_on_labels) in self.call_returns.iter() { + let connecting_bb = match self.functions_rets.get(fn_) { + Some(return_bb) => *return_bb, + // function is just a declaration + None => *self.basic_blocks.get(fn_).ok_or_else(error_unreachable)?, + }; + for follow_on_label in follow_on_labels { + self.graph.add_edge(connecting_bb, *follow_on_label, ()); + } + } + Ok(()) + /* for (function, source) in self.functions_rets.iter() { for target in self .call_returns @@ -418,6 +430,7 @@ impl ControlFlowGraph { self.graph.add_edge(*source, target, ()); } } + */ } } @@ -481,6 +494,7 @@ impl ResolvedControlFlowGraph { } } } + // This should happen only for orphaned basic blocks mode.map(Resolved::Value).ok_or_else(error_unreachable) } fn resolve_mode( @@ -730,7 +744,6 @@ pub(crate) fn run<'input>( _ => return Err(error_unreachable()), }; bb_state.record_call(*func, after_call_label)?; - //body_iter.next(); } Statement::RetValue(..) | Statement::Instruction(ast::Instruction::Ret { .. }) => { @@ -761,7 +774,7 @@ pub(crate) fn run<'input>( "{:?}", petgraph::dot::Dot::with_config(&cfg.graph, &[petgraph::dot::Config::EdgeNoLabel]) ); - cfg.fixup_function_calls(); + cfg.fixup_function_calls()?; println!( "{:?}", petgraph::dot::Dot::with_config(&cfg.graph, &[petgraph::dot::Config::EdgeNoLabel]) @@ -784,6 +797,7 @@ pub(crate) fn run<'input>( )?; let temp = join_modes2( flat_resolver, + &directives, cfg, denormal_f32, denormal_f16f64, @@ -816,6 +830,7 @@ pub(crate) fn run<'input>( fn join_modes2( flat_resolver: &mut super::GlobalStringIdentResolver2, + directives: &Vec, super::SpirvWord>>, cfg: ResolvedControlFlowGraph, mandatory_denormal_f32: MandatoryModeInsertions, mandatory_denormal_f16f64: MandatoryModeInsertions, @@ -877,6 +892,47 @@ fn join_modes2( )) }) .collect::, _>>()?; + let temp = directives + .iter() + .filter_map(|directive| match directive { + Directive2::Method(Function2 { + name, + body: None, + is_kernel: false, + .. + }) => { + let fn_bb = match cfg.basic_blocks.get(name) { + Some(bb) => bb, + None => return None, + }; + let weights = cfg.graph.node_weight(*fn_bb).unwrap(); + let modes = ResolvedInstructionModes { + denormal_f32: weights.denormal_f32.exit.map(DenormalMode::to_ftz), + denormal_f16f64: weights.denormal_f16f64.exit.map(DenormalMode::to_ftz), + rounding_f32: weights.rounding_f32.exit.map(RoundingMode::to_ast), + rounding_f16f64: weights.rounding_f16f64.exit.map(RoundingMode::to_ast), + }; + Some(Ok((*name, modes))) + } + Directive2::Method(Function2 { + name, + body: Some(_), + is_kernel: false, + .. + }) => { + let ret_bb = cfg.functions_rets.get(name).unwrap(); + let weights = cfg.graph.node_weight(*ret_bb).unwrap(); + let modes = ResolvedInstructionModes { + denormal_f32: weights.denormal_f32.exit.map(DenormalMode::to_ftz), + denormal_f16f64: weights.denormal_f16f64.exit.map(DenormalMode::to_ftz), + rounding_f32: weights.rounding_f32.exit.map(RoundingMode::to_ast), + rounding_f16f64: weights.rounding_f16f64.exit.map(RoundingMode::to_ast), + }; + Some(Ok((*name, modes))) + } + _ => None, + }) + .collect::, _>>()?; let functions_exit_modes = cfg .functions_rets .into_iter() @@ -893,7 +949,7 @@ fn join_modes2( .collect::, _>>()?; Ok(FullModeInsertion2 { basic_blocks, - functions_exit_modes, + functions_exit_modes: temp, }) } @@ -1149,13 +1205,13 @@ fn emit_mode_prelude( .denormal .twin_mode .f32 - .unwrap_of_default() + .unwrap_or_default() .to_ftz(), f16f64: fn_mode_state .denormal .twin_mode .f16f64 - .unwrap_of_default() + .unwrap_or_default() .to_ftz(), }, ModeRegister::Rounding { @@ -1163,13 +1219,13 @@ fn emit_mode_prelude( .rounding .twin_mode .f32 - .unwrap_of_default() + .unwrap_or_default() .to_ast(), f16f64: fn_mode_state .rounding .twin_mode .f16f64 - .unwrap_of_default() + .unwrap_or_default() .to_ast(), }, ] @@ -1186,13 +1242,13 @@ fn emit_mode_prelude( .denormal .twin_mode .f32 - .unwrap_of_default() + .unwrap_or_default() .to_ftz(), f16f64: fn_mode_state .denormal .twin_mode .f16f64 - .unwrap_of_default() + .unwrap_or_default() .to_ftz(), }] .into_iter(), @@ -1208,13 +1264,13 @@ fn emit_mode_prelude( .rounding .twin_mode .f32 - .unwrap_of_default() + .unwrap_or_default() .to_ast(), f16f64: fn_mode_state .rounding .twin_mode .f16f64 - .unwrap_of_default() + .unwrap_or_default() .to_ast(), }] .into_iter(), @@ -1409,21 +1465,21 @@ impl<'a> BasicBlockControlState<'a> { if let Some(prologue) = bb_state.dual_prologue { statements.push(Statement::Label(prologue)); statements.push(Statement::SetMode(ModeRegister::Denormal { - f32: bb_state.denormal.twin_mode.f32.unwrap_of_default().to_ftz(), + f32: bb_state.denormal.twin_mode.f32.unwrap_or_default().to_ftz(), f16f64: bb_state .denormal .twin_mode .f16f64 - .unwrap_of_default() + .unwrap_or_default() .to_ftz(), })); statements.push(Statement::SetMode(ModeRegister::Rounding { - f32: bb_state.rounding.twin_mode.f32.unwrap_of_default().to_ast(), + f32: bb_state.rounding.twin_mode.f32.unwrap_or_default().to_ast(), f16f64: bb_state .rounding .twin_mode .f16f64 - .unwrap_of_default() + .unwrap_or_default() .to_ast(), })); statements.push(Statement::Instruction(ast::Instruction::Bra { @@ -1433,12 +1489,12 @@ impl<'a> BasicBlockControlState<'a> { if let Some(prologue) = bb_state.denormal.prologue { statements.push(Statement::Label(prologue)); statements.push(Statement::SetMode(ModeRegister::Denormal { - f32: bb_state.denormal.twin_mode.f32.unwrap_of_default().to_ftz(), + f32: bb_state.denormal.twin_mode.f32.unwrap_or_default().to_ftz(), f16f64: bb_state .denormal .twin_mode .f16f64 - .unwrap_of_default() + .unwrap_or_default() .to_ftz(), })); statements.push(Statement::Instruction(ast::Instruction::Bra { @@ -1448,12 +1504,12 @@ impl<'a> BasicBlockControlState<'a> { if let Some(prologue) = bb_state.rounding.prologue { statements.push(Statement::Label(prologue)); statements.push(Statement::SetMode(ModeRegister::Rounding { - f32: bb_state.rounding.twin_mode.f32.unwrap_of_default().to_ast(), + f32: bb_state.rounding.twin_mode.f32.unwrap_or_default().to_ast(), f16f64: bb_state .rounding .twin_mode .f16f64 - .unwrap_of_default() + .unwrap_or_default() .to_ast(), })); statements.push(Statement::Instruction(ast::Instruction::Bra { @@ -1516,8 +1572,15 @@ impl<'a> BasicBlockControlState<'a> { new_mode, View::TwinView::get_register(self) .current_value - .ok_or_else(error_unreachable)?, + .unwrap_or(View::ComputeValue::default().into()), ))); + View::set_register( + self, + RegisterState { + current_value: Resolved::Value(new_mode), + last_foldable: None, + }, + ); set_fold_index::(self, Some(result.len() - 1)); } }, @@ -1579,7 +1642,7 @@ enum Resolved { } impl Resolved { - fn unwrap_of_default(self) -> T { + fn unwrap_or_default(self) -> T { match self { Resolved::Conflict => T::default(), Resolved::Value(t) => t, @@ -1599,6 +1662,13 @@ impl Resolved { } impl Resolved { + fn unwrap_or(self, if_fail: T) -> T { + match self { + Resolved::Conflict => if_fail, + Resolved::Value(t) => t, + } + } + fn map(self, f: F) -> Resolved where F: FnOnce(T) -> U, @@ -1621,6 +1691,7 @@ impl Resolved { } trait ModeView { + type ComputeValue: Default + Into; type Value: PartialEq + Eq + Copy + Clone; type TwinView: ModeView; @@ -1633,6 +1704,7 @@ trait ModeView { struct DenormalF32View; impl ModeView for DenormalF32View { + type ComputeValue = DenormalMode; type Value = bool; type TwinView = DenormalF16F64View; @@ -1660,6 +1732,7 @@ impl ModeView for DenormalF32View { struct DenormalF16F64View; impl ModeView for DenormalF16F64View { + type ComputeValue = DenormalMode; type Value = bool; type TwinView = DenormalF32View; @@ -1687,6 +1760,7 @@ impl ModeView for DenormalF16F64View { struct RoundingF32View; impl ModeView for RoundingF32View { + type ComputeValue = RoundingMode; type Value = ast::RoundingMode; type TwinView = RoundingF16F64View; @@ -1714,6 +1788,7 @@ impl ModeView for RoundingF32View { struct RoundingF16F64View; impl ModeView for RoundingF16F64View { + type ComputeValue = RoundingMode; type Value = ast::RoundingMode; type TwinView = RoundingF32View; diff --git a/ptx/src/pass/mod.rs b/ptx/src/pass/mod.rs index 10741c0..20a716e 100644 --- a/ptx/src/pass/mod.rs +++ b/ptx/src/pass/mod.rs @@ -52,9 +52,9 @@ pub fn to_llvm_module<'input>(ast: ast::Module<'input>) -> Result match statement { - // if it happens that there is a label after a call just reuse it + // If there's a label after a call just reuse it Statement::Label(label) => { result.push(Statement::Instruction(ast::Instruction::Bra { arguments: ast::BraArgs { src: label },