mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-04-28 13:37:57 +03:00
Fix most of the remaining bugs
This commit is contained in:
parent
87fe601494
commit
04fbfea80a
3 changed files with 99 additions and 24 deletions
|
@ -405,7 +405,19 @@ impl ControlFlowGraph {
|
||||||
node.rounding_f16f64.exit = exit.rounding_f16f64.map(ExtendedMode::BasicBlock);
|
node.rounding_f16f64.exit = exit.rounding_f16f64.map(ExtendedMode::BasicBlock);
|
||||||
}
|
}
|
||||||
|
|
||||||
fn fixup_function_calls(&mut self) {
|
fn fixup_function_calls(&mut self) -> Result<(), TranslateError> {
|
||||||
|
for (fn_, follow_on_labels) in self.call_returns.iter() {
|
||||||
|
let connecting_bb = match self.functions_rets.get(fn_) {
|
||||||
|
Some(return_bb) => *return_bb,
|
||||||
|
// function is just a declaration
|
||||||
|
None => *self.basic_blocks.get(fn_).ok_or_else(error_unreachable)?,
|
||||||
|
};
|
||||||
|
for follow_on_label in follow_on_labels {
|
||||||
|
self.graph.add_edge(connecting_bb, *follow_on_label, ());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
/*
|
||||||
for (function, source) in self.functions_rets.iter() {
|
for (function, source) in self.functions_rets.iter() {
|
||||||
for target in self
|
for target in self
|
||||||
.call_returns
|
.call_returns
|
||||||
|
@ -418,6 +430,7 @@ impl ControlFlowGraph {
|
||||||
self.graph.add_edge(*source, target, ());
|
self.graph.add_edge(*source, target, ());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
*/
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -481,6 +494,7 @@ impl ResolvedControlFlowGraph {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
// This should happen only for orphaned basic blocks
|
||||||
mode.map(Resolved::Value).ok_or_else(error_unreachable)
|
mode.map(Resolved::Value).ok_or_else(error_unreachable)
|
||||||
}
|
}
|
||||||
fn resolve_mode<T: Eq + PartialEq + Copy + Default>(
|
fn resolve_mode<T: Eq + PartialEq + Copy + Default>(
|
||||||
|
@ -730,7 +744,6 @@ pub(crate) fn run<'input>(
|
||||||
_ => return Err(error_unreachable()),
|
_ => return Err(error_unreachable()),
|
||||||
};
|
};
|
||||||
bb_state.record_call(*func, after_call_label)?;
|
bb_state.record_call(*func, after_call_label)?;
|
||||||
//body_iter.next();
|
|
||||||
}
|
}
|
||||||
Statement::RetValue(..)
|
Statement::RetValue(..)
|
||||||
| Statement::Instruction(ast::Instruction::Ret { .. }) => {
|
| Statement::Instruction(ast::Instruction::Ret { .. }) => {
|
||||||
|
@ -761,7 +774,7 @@ pub(crate) fn run<'input>(
|
||||||
"{:?}",
|
"{:?}",
|
||||||
petgraph::dot::Dot::with_config(&cfg.graph, &[petgraph::dot::Config::EdgeNoLabel])
|
petgraph::dot::Dot::with_config(&cfg.graph, &[petgraph::dot::Config::EdgeNoLabel])
|
||||||
);
|
);
|
||||||
cfg.fixup_function_calls();
|
cfg.fixup_function_calls()?;
|
||||||
println!(
|
println!(
|
||||||
"{:?}",
|
"{:?}",
|
||||||
petgraph::dot::Dot::with_config(&cfg.graph, &[petgraph::dot::Config::EdgeNoLabel])
|
petgraph::dot::Dot::with_config(&cfg.graph, &[petgraph::dot::Config::EdgeNoLabel])
|
||||||
|
@ -784,6 +797,7 @@ pub(crate) fn run<'input>(
|
||||||
)?;
|
)?;
|
||||||
let temp = join_modes2(
|
let temp = join_modes2(
|
||||||
flat_resolver,
|
flat_resolver,
|
||||||
|
&directives,
|
||||||
cfg,
|
cfg,
|
||||||
denormal_f32,
|
denormal_f32,
|
||||||
denormal_f16f64,
|
denormal_f16f64,
|
||||||
|
@ -816,6 +830,7 @@ pub(crate) fn run<'input>(
|
||||||
|
|
||||||
fn join_modes2(
|
fn join_modes2(
|
||||||
flat_resolver: &mut super::GlobalStringIdentResolver2,
|
flat_resolver: &mut super::GlobalStringIdentResolver2,
|
||||||
|
directives: &Vec<super::Directive2<ast::Instruction<SpirvWord>, super::SpirvWord>>,
|
||||||
cfg: ResolvedControlFlowGraph,
|
cfg: ResolvedControlFlowGraph,
|
||||||
mandatory_denormal_f32: MandatoryModeInsertions<DenormalMode>,
|
mandatory_denormal_f32: MandatoryModeInsertions<DenormalMode>,
|
||||||
mandatory_denormal_f16f64: MandatoryModeInsertions<DenormalMode>,
|
mandatory_denormal_f16f64: MandatoryModeInsertions<DenormalMode>,
|
||||||
|
@ -877,6 +892,47 @@ fn join_modes2(
|
||||||
))
|
))
|
||||||
})
|
})
|
||||||
.collect::<Result<FxHashMap<_, _>, _>>()?;
|
.collect::<Result<FxHashMap<_, _>, _>>()?;
|
||||||
|
let temp = directives
|
||||||
|
.iter()
|
||||||
|
.filter_map(|directive| match directive {
|
||||||
|
Directive2::Method(Function2 {
|
||||||
|
name,
|
||||||
|
body: None,
|
||||||
|
is_kernel: false,
|
||||||
|
..
|
||||||
|
}) => {
|
||||||
|
let fn_bb = match cfg.basic_blocks.get(name) {
|
||||||
|
Some(bb) => bb,
|
||||||
|
None => return None,
|
||||||
|
};
|
||||||
|
let weights = cfg.graph.node_weight(*fn_bb).unwrap();
|
||||||
|
let modes = ResolvedInstructionModes {
|
||||||
|
denormal_f32: weights.denormal_f32.exit.map(DenormalMode::to_ftz),
|
||||||
|
denormal_f16f64: weights.denormal_f16f64.exit.map(DenormalMode::to_ftz),
|
||||||
|
rounding_f32: weights.rounding_f32.exit.map(RoundingMode::to_ast),
|
||||||
|
rounding_f16f64: weights.rounding_f16f64.exit.map(RoundingMode::to_ast),
|
||||||
|
};
|
||||||
|
Some(Ok((*name, modes)))
|
||||||
|
}
|
||||||
|
Directive2::Method(Function2 {
|
||||||
|
name,
|
||||||
|
body: Some(_),
|
||||||
|
is_kernel: false,
|
||||||
|
..
|
||||||
|
}) => {
|
||||||
|
let ret_bb = cfg.functions_rets.get(name).unwrap();
|
||||||
|
let weights = cfg.graph.node_weight(*ret_bb).unwrap();
|
||||||
|
let modes = ResolvedInstructionModes {
|
||||||
|
denormal_f32: weights.denormal_f32.exit.map(DenormalMode::to_ftz),
|
||||||
|
denormal_f16f64: weights.denormal_f16f64.exit.map(DenormalMode::to_ftz),
|
||||||
|
rounding_f32: weights.rounding_f32.exit.map(RoundingMode::to_ast),
|
||||||
|
rounding_f16f64: weights.rounding_f16f64.exit.map(RoundingMode::to_ast),
|
||||||
|
};
|
||||||
|
Some(Ok((*name, modes)))
|
||||||
|
}
|
||||||
|
_ => None,
|
||||||
|
})
|
||||||
|
.collect::<Result<FxHashMap<_, _>, _>>()?;
|
||||||
let functions_exit_modes = cfg
|
let functions_exit_modes = cfg
|
||||||
.functions_rets
|
.functions_rets
|
||||||
.into_iter()
|
.into_iter()
|
||||||
|
@ -893,7 +949,7 @@ fn join_modes2(
|
||||||
.collect::<Result<FxHashMap<_, _>, _>>()?;
|
.collect::<Result<FxHashMap<_, _>, _>>()?;
|
||||||
Ok(FullModeInsertion2 {
|
Ok(FullModeInsertion2 {
|
||||||
basic_blocks,
|
basic_blocks,
|
||||||
functions_exit_modes,
|
functions_exit_modes: temp,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1149,13 +1205,13 @@ fn emit_mode_prelude(
|
||||||
.denormal
|
.denormal
|
||||||
.twin_mode
|
.twin_mode
|
||||||
.f32
|
.f32
|
||||||
.unwrap_of_default()
|
.unwrap_or_default()
|
||||||
.to_ftz(),
|
.to_ftz(),
|
||||||
f16f64: fn_mode_state
|
f16f64: fn_mode_state
|
||||||
.denormal
|
.denormal
|
||||||
.twin_mode
|
.twin_mode
|
||||||
.f16f64
|
.f16f64
|
||||||
.unwrap_of_default()
|
.unwrap_or_default()
|
||||||
.to_ftz(),
|
.to_ftz(),
|
||||||
},
|
},
|
||||||
ModeRegister::Rounding {
|
ModeRegister::Rounding {
|
||||||
|
@ -1163,13 +1219,13 @@ fn emit_mode_prelude(
|
||||||
.rounding
|
.rounding
|
||||||
.twin_mode
|
.twin_mode
|
||||||
.f32
|
.f32
|
||||||
.unwrap_of_default()
|
.unwrap_or_default()
|
||||||
.to_ast(),
|
.to_ast(),
|
||||||
f16f64: fn_mode_state
|
f16f64: fn_mode_state
|
||||||
.rounding
|
.rounding
|
||||||
.twin_mode
|
.twin_mode
|
||||||
.f16f64
|
.f16f64
|
||||||
.unwrap_of_default()
|
.unwrap_or_default()
|
||||||
.to_ast(),
|
.to_ast(),
|
||||||
},
|
},
|
||||||
]
|
]
|
||||||
|
@ -1186,13 +1242,13 @@ fn emit_mode_prelude(
|
||||||
.denormal
|
.denormal
|
||||||
.twin_mode
|
.twin_mode
|
||||||
.f32
|
.f32
|
||||||
.unwrap_of_default()
|
.unwrap_or_default()
|
||||||
.to_ftz(),
|
.to_ftz(),
|
||||||
f16f64: fn_mode_state
|
f16f64: fn_mode_state
|
||||||
.denormal
|
.denormal
|
||||||
.twin_mode
|
.twin_mode
|
||||||
.f16f64
|
.f16f64
|
||||||
.unwrap_of_default()
|
.unwrap_or_default()
|
||||||
.to_ftz(),
|
.to_ftz(),
|
||||||
}]
|
}]
|
||||||
.into_iter(),
|
.into_iter(),
|
||||||
|
@ -1208,13 +1264,13 @@ fn emit_mode_prelude(
|
||||||
.rounding
|
.rounding
|
||||||
.twin_mode
|
.twin_mode
|
||||||
.f32
|
.f32
|
||||||
.unwrap_of_default()
|
.unwrap_or_default()
|
||||||
.to_ast(),
|
.to_ast(),
|
||||||
f16f64: fn_mode_state
|
f16f64: fn_mode_state
|
||||||
.rounding
|
.rounding
|
||||||
.twin_mode
|
.twin_mode
|
||||||
.f16f64
|
.f16f64
|
||||||
.unwrap_of_default()
|
.unwrap_or_default()
|
||||||
.to_ast(),
|
.to_ast(),
|
||||||
}]
|
}]
|
||||||
.into_iter(),
|
.into_iter(),
|
||||||
|
@ -1409,21 +1465,21 @@ impl<'a> BasicBlockControlState<'a> {
|
||||||
if let Some(prologue) = bb_state.dual_prologue {
|
if let Some(prologue) = bb_state.dual_prologue {
|
||||||
statements.push(Statement::Label(prologue));
|
statements.push(Statement::Label(prologue));
|
||||||
statements.push(Statement::SetMode(ModeRegister::Denormal {
|
statements.push(Statement::SetMode(ModeRegister::Denormal {
|
||||||
f32: bb_state.denormal.twin_mode.f32.unwrap_of_default().to_ftz(),
|
f32: bb_state.denormal.twin_mode.f32.unwrap_or_default().to_ftz(),
|
||||||
f16f64: bb_state
|
f16f64: bb_state
|
||||||
.denormal
|
.denormal
|
||||||
.twin_mode
|
.twin_mode
|
||||||
.f16f64
|
.f16f64
|
||||||
.unwrap_of_default()
|
.unwrap_or_default()
|
||||||
.to_ftz(),
|
.to_ftz(),
|
||||||
}));
|
}));
|
||||||
statements.push(Statement::SetMode(ModeRegister::Rounding {
|
statements.push(Statement::SetMode(ModeRegister::Rounding {
|
||||||
f32: bb_state.rounding.twin_mode.f32.unwrap_of_default().to_ast(),
|
f32: bb_state.rounding.twin_mode.f32.unwrap_or_default().to_ast(),
|
||||||
f16f64: bb_state
|
f16f64: bb_state
|
||||||
.rounding
|
.rounding
|
||||||
.twin_mode
|
.twin_mode
|
||||||
.f16f64
|
.f16f64
|
||||||
.unwrap_of_default()
|
.unwrap_or_default()
|
||||||
.to_ast(),
|
.to_ast(),
|
||||||
}));
|
}));
|
||||||
statements.push(Statement::Instruction(ast::Instruction::Bra {
|
statements.push(Statement::Instruction(ast::Instruction::Bra {
|
||||||
|
@ -1433,12 +1489,12 @@ impl<'a> BasicBlockControlState<'a> {
|
||||||
if let Some(prologue) = bb_state.denormal.prologue {
|
if let Some(prologue) = bb_state.denormal.prologue {
|
||||||
statements.push(Statement::Label(prologue));
|
statements.push(Statement::Label(prologue));
|
||||||
statements.push(Statement::SetMode(ModeRegister::Denormal {
|
statements.push(Statement::SetMode(ModeRegister::Denormal {
|
||||||
f32: bb_state.denormal.twin_mode.f32.unwrap_of_default().to_ftz(),
|
f32: bb_state.denormal.twin_mode.f32.unwrap_or_default().to_ftz(),
|
||||||
f16f64: bb_state
|
f16f64: bb_state
|
||||||
.denormal
|
.denormal
|
||||||
.twin_mode
|
.twin_mode
|
||||||
.f16f64
|
.f16f64
|
||||||
.unwrap_of_default()
|
.unwrap_or_default()
|
||||||
.to_ftz(),
|
.to_ftz(),
|
||||||
}));
|
}));
|
||||||
statements.push(Statement::Instruction(ast::Instruction::Bra {
|
statements.push(Statement::Instruction(ast::Instruction::Bra {
|
||||||
|
@ -1448,12 +1504,12 @@ impl<'a> BasicBlockControlState<'a> {
|
||||||
if let Some(prologue) = bb_state.rounding.prologue {
|
if let Some(prologue) = bb_state.rounding.prologue {
|
||||||
statements.push(Statement::Label(prologue));
|
statements.push(Statement::Label(prologue));
|
||||||
statements.push(Statement::SetMode(ModeRegister::Rounding {
|
statements.push(Statement::SetMode(ModeRegister::Rounding {
|
||||||
f32: bb_state.rounding.twin_mode.f32.unwrap_of_default().to_ast(),
|
f32: bb_state.rounding.twin_mode.f32.unwrap_or_default().to_ast(),
|
||||||
f16f64: bb_state
|
f16f64: bb_state
|
||||||
.rounding
|
.rounding
|
||||||
.twin_mode
|
.twin_mode
|
||||||
.f16f64
|
.f16f64
|
||||||
.unwrap_of_default()
|
.unwrap_or_default()
|
||||||
.to_ast(),
|
.to_ast(),
|
||||||
}));
|
}));
|
||||||
statements.push(Statement::Instruction(ast::Instruction::Bra {
|
statements.push(Statement::Instruction(ast::Instruction::Bra {
|
||||||
|
@ -1516,8 +1572,15 @@ impl<'a> BasicBlockControlState<'a> {
|
||||||
new_mode,
|
new_mode,
|
||||||
View::TwinView::get_register(self)
|
View::TwinView::get_register(self)
|
||||||
.current_value
|
.current_value
|
||||||
.ok_or_else(error_unreachable)?,
|
.unwrap_or(View::ComputeValue::default().into()),
|
||||||
)));
|
)));
|
||||||
|
View::set_register(
|
||||||
|
self,
|
||||||
|
RegisterState {
|
||||||
|
current_value: Resolved::Value(new_mode),
|
||||||
|
last_foldable: None,
|
||||||
|
},
|
||||||
|
);
|
||||||
set_fold_index::<View::TwinView>(self, Some(result.len() - 1));
|
set_fold_index::<View::TwinView>(self, Some(result.len() - 1));
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
@ -1579,7 +1642,7 @@ enum Resolved<T> {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T: Default> Resolved<T> {
|
impl<T: Default> Resolved<T> {
|
||||||
fn unwrap_of_default(self) -> T {
|
fn unwrap_or_default(self) -> T {
|
||||||
match self {
|
match self {
|
||||||
Resolved::Conflict => T::default(),
|
Resolved::Conflict => T::default(),
|
||||||
Resolved::Value(t) => t,
|
Resolved::Value(t) => t,
|
||||||
|
@ -1599,6 +1662,13 @@ impl<T: Eq + PartialEq> Resolved<T> {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T> Resolved<T> {
|
impl<T> Resolved<T> {
|
||||||
|
fn unwrap_or(self, if_fail: T) -> T {
|
||||||
|
match self {
|
||||||
|
Resolved::Conflict => if_fail,
|
||||||
|
Resolved::Value(t) => t,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn map<U, F>(self, f: F) -> Resolved<U>
|
fn map<U, F>(self, f: F) -> Resolved<U>
|
||||||
where
|
where
|
||||||
F: FnOnce(T) -> U,
|
F: FnOnce(T) -> U,
|
||||||
|
@ -1621,6 +1691,7 @@ impl<T> Resolved<T> {
|
||||||
}
|
}
|
||||||
|
|
||||||
trait ModeView {
|
trait ModeView {
|
||||||
|
type ComputeValue: Default + Into<Self::Value>;
|
||||||
type Value: PartialEq + Eq + Copy + Clone;
|
type Value: PartialEq + Eq + Copy + Clone;
|
||||||
type TwinView: ModeView<Value = Self::Value>;
|
type TwinView: ModeView<Value = Self::Value>;
|
||||||
|
|
||||||
|
@ -1633,6 +1704,7 @@ trait ModeView {
|
||||||
struct DenormalF32View;
|
struct DenormalF32View;
|
||||||
|
|
||||||
impl ModeView for DenormalF32View {
|
impl ModeView for DenormalF32View {
|
||||||
|
type ComputeValue = DenormalMode;
|
||||||
type Value = bool;
|
type Value = bool;
|
||||||
type TwinView = DenormalF16F64View;
|
type TwinView = DenormalF16F64View;
|
||||||
|
|
||||||
|
@ -1660,6 +1732,7 @@ impl ModeView for DenormalF32View {
|
||||||
struct DenormalF16F64View;
|
struct DenormalF16F64View;
|
||||||
|
|
||||||
impl ModeView for DenormalF16F64View {
|
impl ModeView for DenormalF16F64View {
|
||||||
|
type ComputeValue = DenormalMode;
|
||||||
type Value = bool;
|
type Value = bool;
|
||||||
type TwinView = DenormalF32View;
|
type TwinView = DenormalF32View;
|
||||||
|
|
||||||
|
@ -1687,6 +1760,7 @@ impl ModeView for DenormalF16F64View {
|
||||||
struct RoundingF32View;
|
struct RoundingF32View;
|
||||||
|
|
||||||
impl ModeView for RoundingF32View {
|
impl ModeView for RoundingF32View {
|
||||||
|
type ComputeValue = RoundingMode;
|
||||||
type Value = ast::RoundingMode;
|
type Value = ast::RoundingMode;
|
||||||
type TwinView = RoundingF16F64View;
|
type TwinView = RoundingF16F64View;
|
||||||
|
|
||||||
|
@ -1714,6 +1788,7 @@ impl ModeView for RoundingF32View {
|
||||||
struct RoundingF16F64View;
|
struct RoundingF16F64View;
|
||||||
|
|
||||||
impl ModeView for RoundingF16F64View {
|
impl ModeView for RoundingF16F64View {
|
||||||
|
type ComputeValue = RoundingMode;
|
||||||
type Value = ast::RoundingMode;
|
type Value = ast::RoundingMode;
|
||||||
type TwinView = RoundingF32View;
|
type TwinView = RoundingF32View;
|
||||||
|
|
||||||
|
|
|
@ -52,9 +52,9 @@ pub fn to_llvm_module<'input>(ast: ast::Module<'input>) -> Result<Module, Transl
|
||||||
let directives = expand_operands::run(&mut flat_resolver, directives)?;
|
let directives = expand_operands::run(&mut flat_resolver, directives)?;
|
||||||
let directives = deparamize_functions::run(&mut flat_resolver, directives)?;
|
let directives = deparamize_functions::run(&mut flat_resolver, directives)?;
|
||||||
let directives = normalize_basic_blocks::run(&mut flat_resolver, directives)?;
|
let directives = normalize_basic_blocks::run(&mut flat_resolver, directives)?;
|
||||||
|
let directives = insert_ftz_control::run(&mut flat_resolver, directives)?;
|
||||||
let directives = insert_explicit_load_store::run(&mut flat_resolver, directives)?;
|
let directives = insert_explicit_load_store::run(&mut flat_resolver, directives)?;
|
||||||
let directives = insert_implicit_conversions2::run(&mut flat_resolver, directives)?;
|
let directives = insert_implicit_conversions2::run(&mut flat_resolver, directives)?;
|
||||||
let directives = insert_ftz_control::run(&mut flat_resolver, directives)?;
|
|
||||||
let directives = replace_instructions_with_function_calls::run(&mut flat_resolver, directives)?;
|
let directives = replace_instructions_with_function_calls::run(&mut flat_resolver, directives)?;
|
||||||
let directives = hoist_globals::run(directives)?;
|
let directives = hoist_globals::run(directives)?;
|
||||||
let llvm_ir = emit_llvm::run(flat_resolver, directives)?;
|
let llvm_ir = emit_llvm::run(flat_resolver, directives)?;
|
||||||
|
|
|
@ -54,7 +54,7 @@ pub(crate) fn run(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
TerminatorKind::Fake => match statement {
|
TerminatorKind::Fake => match statement {
|
||||||
// if it happens that there is a label after a call just reuse it
|
// If there's a label after a call just reuse it
|
||||||
Statement::Label(label) => {
|
Statement::Label(label) => {
|
||||||
result.push(Statement::Instruction(ast::Instruction::Bra {
|
result.push(Statement::Instruction(ast::Instruction::Bra {
|
||||||
arguments: ast::BraArgs { src: label },
|
arguments: ast::BraArgs { src: label },
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue