mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-04-28 13:37:57 +03:00
Fix most of the remaining bugs
This commit is contained in:
parent
87fe601494
commit
04fbfea80a
3 changed files with 99 additions and 24 deletions
|
@ -405,7 +405,19 @@ impl ControlFlowGraph {
|
|||
node.rounding_f16f64.exit = exit.rounding_f16f64.map(ExtendedMode::BasicBlock);
|
||||
}
|
||||
|
||||
fn fixup_function_calls(&mut self) {
|
||||
fn fixup_function_calls(&mut self) -> Result<(), TranslateError> {
|
||||
for (fn_, follow_on_labels) in self.call_returns.iter() {
|
||||
let connecting_bb = match self.functions_rets.get(fn_) {
|
||||
Some(return_bb) => *return_bb,
|
||||
// function is just a declaration
|
||||
None => *self.basic_blocks.get(fn_).ok_or_else(error_unreachable)?,
|
||||
};
|
||||
for follow_on_label in follow_on_labels {
|
||||
self.graph.add_edge(connecting_bb, *follow_on_label, ());
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
/*
|
||||
for (function, source) in self.functions_rets.iter() {
|
||||
for target in self
|
||||
.call_returns
|
||||
|
@ -418,6 +430,7 @@ impl ControlFlowGraph {
|
|||
self.graph.add_edge(*source, target, ());
|
||||
}
|
||||
}
|
||||
*/
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -481,6 +494,7 @@ impl ResolvedControlFlowGraph {
|
|||
}
|
||||
}
|
||||
}
|
||||
// This should happen only for orphaned basic blocks
|
||||
mode.map(Resolved::Value).ok_or_else(error_unreachable)
|
||||
}
|
||||
fn resolve_mode<T: Eq + PartialEq + Copy + Default>(
|
||||
|
@ -730,7 +744,6 @@ pub(crate) fn run<'input>(
|
|||
_ => return Err(error_unreachable()),
|
||||
};
|
||||
bb_state.record_call(*func, after_call_label)?;
|
||||
//body_iter.next();
|
||||
}
|
||||
Statement::RetValue(..)
|
||||
| Statement::Instruction(ast::Instruction::Ret { .. }) => {
|
||||
|
@ -761,7 +774,7 @@ pub(crate) fn run<'input>(
|
|||
"{:?}",
|
||||
petgraph::dot::Dot::with_config(&cfg.graph, &[petgraph::dot::Config::EdgeNoLabel])
|
||||
);
|
||||
cfg.fixup_function_calls();
|
||||
cfg.fixup_function_calls()?;
|
||||
println!(
|
||||
"{:?}",
|
||||
petgraph::dot::Dot::with_config(&cfg.graph, &[petgraph::dot::Config::EdgeNoLabel])
|
||||
|
@ -784,6 +797,7 @@ pub(crate) fn run<'input>(
|
|||
)?;
|
||||
let temp = join_modes2(
|
||||
flat_resolver,
|
||||
&directives,
|
||||
cfg,
|
||||
denormal_f32,
|
||||
denormal_f16f64,
|
||||
|
@ -816,6 +830,7 @@ pub(crate) fn run<'input>(
|
|||
|
||||
fn join_modes2(
|
||||
flat_resolver: &mut super::GlobalStringIdentResolver2,
|
||||
directives: &Vec<super::Directive2<ast::Instruction<SpirvWord>, super::SpirvWord>>,
|
||||
cfg: ResolvedControlFlowGraph,
|
||||
mandatory_denormal_f32: MandatoryModeInsertions<DenormalMode>,
|
||||
mandatory_denormal_f16f64: MandatoryModeInsertions<DenormalMode>,
|
||||
|
@ -877,6 +892,47 @@ fn join_modes2(
|
|||
))
|
||||
})
|
||||
.collect::<Result<FxHashMap<_, _>, _>>()?;
|
||||
let temp = directives
|
||||
.iter()
|
||||
.filter_map(|directive| match directive {
|
||||
Directive2::Method(Function2 {
|
||||
name,
|
||||
body: None,
|
||||
is_kernel: false,
|
||||
..
|
||||
}) => {
|
||||
let fn_bb = match cfg.basic_blocks.get(name) {
|
||||
Some(bb) => bb,
|
||||
None => return None,
|
||||
};
|
||||
let weights = cfg.graph.node_weight(*fn_bb).unwrap();
|
||||
let modes = ResolvedInstructionModes {
|
||||
denormal_f32: weights.denormal_f32.exit.map(DenormalMode::to_ftz),
|
||||
denormal_f16f64: weights.denormal_f16f64.exit.map(DenormalMode::to_ftz),
|
||||
rounding_f32: weights.rounding_f32.exit.map(RoundingMode::to_ast),
|
||||
rounding_f16f64: weights.rounding_f16f64.exit.map(RoundingMode::to_ast),
|
||||
};
|
||||
Some(Ok((*name, modes)))
|
||||
}
|
||||
Directive2::Method(Function2 {
|
||||
name,
|
||||
body: Some(_),
|
||||
is_kernel: false,
|
||||
..
|
||||
}) => {
|
||||
let ret_bb = cfg.functions_rets.get(name).unwrap();
|
||||
let weights = cfg.graph.node_weight(*ret_bb).unwrap();
|
||||
let modes = ResolvedInstructionModes {
|
||||
denormal_f32: weights.denormal_f32.exit.map(DenormalMode::to_ftz),
|
||||
denormal_f16f64: weights.denormal_f16f64.exit.map(DenormalMode::to_ftz),
|
||||
rounding_f32: weights.rounding_f32.exit.map(RoundingMode::to_ast),
|
||||
rounding_f16f64: weights.rounding_f16f64.exit.map(RoundingMode::to_ast),
|
||||
};
|
||||
Some(Ok((*name, modes)))
|
||||
}
|
||||
_ => None,
|
||||
})
|
||||
.collect::<Result<FxHashMap<_, _>, _>>()?;
|
||||
let functions_exit_modes = cfg
|
||||
.functions_rets
|
||||
.into_iter()
|
||||
|
@ -893,7 +949,7 @@ fn join_modes2(
|
|||
.collect::<Result<FxHashMap<_, _>, _>>()?;
|
||||
Ok(FullModeInsertion2 {
|
||||
basic_blocks,
|
||||
functions_exit_modes,
|
||||
functions_exit_modes: temp,
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -1149,13 +1205,13 @@ fn emit_mode_prelude(
|
|||
.denormal
|
||||
.twin_mode
|
||||
.f32
|
||||
.unwrap_of_default()
|
||||
.unwrap_or_default()
|
||||
.to_ftz(),
|
||||
f16f64: fn_mode_state
|
||||
.denormal
|
||||
.twin_mode
|
||||
.f16f64
|
||||
.unwrap_of_default()
|
||||
.unwrap_or_default()
|
||||
.to_ftz(),
|
||||
},
|
||||
ModeRegister::Rounding {
|
||||
|
@ -1163,13 +1219,13 @@ fn emit_mode_prelude(
|
|||
.rounding
|
||||
.twin_mode
|
||||
.f32
|
||||
.unwrap_of_default()
|
||||
.unwrap_or_default()
|
||||
.to_ast(),
|
||||
f16f64: fn_mode_state
|
||||
.rounding
|
||||
.twin_mode
|
||||
.f16f64
|
||||
.unwrap_of_default()
|
||||
.unwrap_or_default()
|
||||
.to_ast(),
|
||||
},
|
||||
]
|
||||
|
@ -1186,13 +1242,13 @@ fn emit_mode_prelude(
|
|||
.denormal
|
||||
.twin_mode
|
||||
.f32
|
||||
.unwrap_of_default()
|
||||
.unwrap_or_default()
|
||||
.to_ftz(),
|
||||
f16f64: fn_mode_state
|
||||
.denormal
|
||||
.twin_mode
|
||||
.f16f64
|
||||
.unwrap_of_default()
|
||||
.unwrap_or_default()
|
||||
.to_ftz(),
|
||||
}]
|
||||
.into_iter(),
|
||||
|
@ -1208,13 +1264,13 @@ fn emit_mode_prelude(
|
|||
.rounding
|
||||
.twin_mode
|
||||
.f32
|
||||
.unwrap_of_default()
|
||||
.unwrap_or_default()
|
||||
.to_ast(),
|
||||
f16f64: fn_mode_state
|
||||
.rounding
|
||||
.twin_mode
|
||||
.f16f64
|
||||
.unwrap_of_default()
|
||||
.unwrap_or_default()
|
||||
.to_ast(),
|
||||
}]
|
||||
.into_iter(),
|
||||
|
@ -1409,21 +1465,21 @@ impl<'a> BasicBlockControlState<'a> {
|
|||
if let Some(prologue) = bb_state.dual_prologue {
|
||||
statements.push(Statement::Label(prologue));
|
||||
statements.push(Statement::SetMode(ModeRegister::Denormal {
|
||||
f32: bb_state.denormal.twin_mode.f32.unwrap_of_default().to_ftz(),
|
||||
f32: bb_state.denormal.twin_mode.f32.unwrap_or_default().to_ftz(),
|
||||
f16f64: bb_state
|
||||
.denormal
|
||||
.twin_mode
|
||||
.f16f64
|
||||
.unwrap_of_default()
|
||||
.unwrap_or_default()
|
||||
.to_ftz(),
|
||||
}));
|
||||
statements.push(Statement::SetMode(ModeRegister::Rounding {
|
||||
f32: bb_state.rounding.twin_mode.f32.unwrap_of_default().to_ast(),
|
||||
f32: bb_state.rounding.twin_mode.f32.unwrap_or_default().to_ast(),
|
||||
f16f64: bb_state
|
||||
.rounding
|
||||
.twin_mode
|
||||
.f16f64
|
||||
.unwrap_of_default()
|
||||
.unwrap_or_default()
|
||||
.to_ast(),
|
||||
}));
|
||||
statements.push(Statement::Instruction(ast::Instruction::Bra {
|
||||
|
@ -1433,12 +1489,12 @@ impl<'a> BasicBlockControlState<'a> {
|
|||
if let Some(prologue) = bb_state.denormal.prologue {
|
||||
statements.push(Statement::Label(prologue));
|
||||
statements.push(Statement::SetMode(ModeRegister::Denormal {
|
||||
f32: bb_state.denormal.twin_mode.f32.unwrap_of_default().to_ftz(),
|
||||
f32: bb_state.denormal.twin_mode.f32.unwrap_or_default().to_ftz(),
|
||||
f16f64: bb_state
|
||||
.denormal
|
||||
.twin_mode
|
||||
.f16f64
|
||||
.unwrap_of_default()
|
||||
.unwrap_or_default()
|
||||
.to_ftz(),
|
||||
}));
|
||||
statements.push(Statement::Instruction(ast::Instruction::Bra {
|
||||
|
@ -1448,12 +1504,12 @@ impl<'a> BasicBlockControlState<'a> {
|
|||
if let Some(prologue) = bb_state.rounding.prologue {
|
||||
statements.push(Statement::Label(prologue));
|
||||
statements.push(Statement::SetMode(ModeRegister::Rounding {
|
||||
f32: bb_state.rounding.twin_mode.f32.unwrap_of_default().to_ast(),
|
||||
f32: bb_state.rounding.twin_mode.f32.unwrap_or_default().to_ast(),
|
||||
f16f64: bb_state
|
||||
.rounding
|
||||
.twin_mode
|
||||
.f16f64
|
||||
.unwrap_of_default()
|
||||
.unwrap_or_default()
|
||||
.to_ast(),
|
||||
}));
|
||||
statements.push(Statement::Instruction(ast::Instruction::Bra {
|
||||
|
@ -1516,8 +1572,15 @@ impl<'a> BasicBlockControlState<'a> {
|
|||
new_mode,
|
||||
View::TwinView::get_register(self)
|
||||
.current_value
|
||||
.ok_or_else(error_unreachable)?,
|
||||
.unwrap_or(View::ComputeValue::default().into()),
|
||||
)));
|
||||
View::set_register(
|
||||
self,
|
||||
RegisterState {
|
||||
current_value: Resolved::Value(new_mode),
|
||||
last_foldable: None,
|
||||
},
|
||||
);
|
||||
set_fold_index::<View::TwinView>(self, Some(result.len() - 1));
|
||||
}
|
||||
},
|
||||
|
@ -1579,7 +1642,7 @@ enum Resolved<T> {
|
|||
}
|
||||
|
||||
impl<T: Default> Resolved<T> {
|
||||
fn unwrap_of_default(self) -> T {
|
||||
fn unwrap_or_default(self) -> T {
|
||||
match self {
|
||||
Resolved::Conflict => T::default(),
|
||||
Resolved::Value(t) => t,
|
||||
|
@ -1599,6 +1662,13 @@ impl<T: Eq + PartialEq> Resolved<T> {
|
|||
}
|
||||
|
||||
impl<T> Resolved<T> {
|
||||
fn unwrap_or(self, if_fail: T) -> T {
|
||||
match self {
|
||||
Resolved::Conflict => if_fail,
|
||||
Resolved::Value(t) => t,
|
||||
}
|
||||
}
|
||||
|
||||
fn map<U, F>(self, f: F) -> Resolved<U>
|
||||
where
|
||||
F: FnOnce(T) -> U,
|
||||
|
@ -1621,6 +1691,7 @@ impl<T> Resolved<T> {
|
|||
}
|
||||
|
||||
trait ModeView {
|
||||
type ComputeValue: Default + Into<Self::Value>;
|
||||
type Value: PartialEq + Eq + Copy + Clone;
|
||||
type TwinView: ModeView<Value = Self::Value>;
|
||||
|
||||
|
@ -1633,6 +1704,7 @@ trait ModeView {
|
|||
struct DenormalF32View;
|
||||
|
||||
impl ModeView for DenormalF32View {
|
||||
type ComputeValue = DenormalMode;
|
||||
type Value = bool;
|
||||
type TwinView = DenormalF16F64View;
|
||||
|
||||
|
@ -1660,6 +1732,7 @@ impl ModeView for DenormalF32View {
|
|||
struct DenormalF16F64View;
|
||||
|
||||
impl ModeView for DenormalF16F64View {
|
||||
type ComputeValue = DenormalMode;
|
||||
type Value = bool;
|
||||
type TwinView = DenormalF32View;
|
||||
|
||||
|
@ -1687,6 +1760,7 @@ impl ModeView for DenormalF16F64View {
|
|||
struct RoundingF32View;
|
||||
|
||||
impl ModeView for RoundingF32View {
|
||||
type ComputeValue = RoundingMode;
|
||||
type Value = ast::RoundingMode;
|
||||
type TwinView = RoundingF16F64View;
|
||||
|
||||
|
@ -1714,6 +1788,7 @@ impl ModeView for RoundingF32View {
|
|||
struct RoundingF16F64View;
|
||||
|
||||
impl ModeView for RoundingF16F64View {
|
||||
type ComputeValue = RoundingMode;
|
||||
type Value = ast::RoundingMode;
|
||||
type TwinView = RoundingF32View;
|
||||
|
||||
|
|
|
@ -52,9 +52,9 @@ pub fn to_llvm_module<'input>(ast: ast::Module<'input>) -> Result<Module, Transl
|
|||
let directives = expand_operands::run(&mut flat_resolver, directives)?;
|
||||
let directives = deparamize_functions::run(&mut flat_resolver, directives)?;
|
||||
let directives = normalize_basic_blocks::run(&mut flat_resolver, directives)?;
|
||||
let directives = insert_ftz_control::run(&mut flat_resolver, directives)?;
|
||||
let directives = insert_explicit_load_store::run(&mut flat_resolver, directives)?;
|
||||
let directives = insert_implicit_conversions2::run(&mut flat_resolver, directives)?;
|
||||
let directives = insert_ftz_control::run(&mut flat_resolver, directives)?;
|
||||
let directives = replace_instructions_with_function_calls::run(&mut flat_resolver, directives)?;
|
||||
let directives = hoist_globals::run(directives)?;
|
||||
let llvm_ir = emit_llvm::run(flat_resolver, directives)?;
|
||||
|
|
|
@ -54,7 +54,7 @@ pub(crate) fn run(
|
|||
}
|
||||
}
|
||||
TerminatorKind::Fake => match statement {
|
||||
// if it happens that there is a label after a call just reuse it
|
||||
// If there's a label after a call just reuse it
|
||||
Statement::Label(label) => {
|
||||
result.push(Statement::Instruction(ast::Instruction::Bra {
|
||||
arguments: ast::BraArgs { src: label },
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue