Bugfixing

This commit is contained in:
Andrzej Janik 2025-03-10 22:27:08 +00:00
parent 7bd26aa480
commit c86473b396
5 changed files with 192 additions and 44 deletions

View file

@ -10,12 +10,20 @@
add.rz.ftz.f32 temp, temp, temp;
call use_modes;
add.rp.ftz.f32 temp, temp, temp;
ret;
}
// Test helper: a function with two basic blocks that use different f32
// rounding modes, and therefore two separate `ret;` exits.
.func use_modes()
{
.reg .f32 temp;
// `pred` is never assigned in this function; presumably the test only needs
// both branch targets to be reachable in the control flow graph.
.reg .pred pred;
@pred bra SET_RM;
@!pred bra SET_RZ;
SET_RM:
// .rm: round towards negative infinity
add.rm.f32 temp, temp, temp;
ret;
SET_RZ:
// .rz: round towards zero
add.rz.f32 temp, temp, temp;
ret;
}

View file

@ -20,7 +20,6 @@ use smallvec::SmallVec;
use std::hash::Hash;
use std::iter;
use std::mem;
use std::u32;
use strum::EnumCount;
use strum_macros::{EnumCount, VariantArray};
use unwrap_or::unwrap_some_or;
@ -250,7 +249,7 @@ struct ControlFlowGraph {
// map function -> return label
call_returns: FxHashMap<SpirvWord, Vec<NodeIndex>>,
// map function -> return basic blocks
function_rets: FxHashMap<SpirvWord, Vec<NodeIndex>>,
functions_rets: FxHashMap<SpirvWord, NodeIndex>,
graph: Graph<Node, ()>,
}
@ -260,7 +259,7 @@ impl ControlFlowGraph {
entry_points: FxHashMap::default(),
basic_blocks: FxHashMap::default(),
call_returns: FxHashMap::default(),
function_rets: FxHashMap::default(),
functions_rets: FxHashMap::default(),
graph: Graph::new(),
}
}
@ -298,7 +297,7 @@ impl ControlFlowGraph {
}
fn fixup_function_calls(&mut self) {
for (function, sources) in self.function_rets.iter() {
for (function, source) in self.functions_rets.iter() {
for target in self
.call_returns
.get(function)
@ -307,15 +306,15 @@ impl ControlFlowGraph {
.flatten()
.copied()
{
for source in sources {
self.graph.add_edge(*source, target, ());
}
self.graph.add_edge(*source, target, ());
}
}
}
}
#[derive(Clone, Copy)]
//#[cfg_attr(test, derive(Debug))]
#[derive(Debug)]
struct Mode<T: Eq + PartialEq> {
entry: Option<ExtendedMode<T>>,
exit: Option<ExtendedMode<T>>,
@ -337,6 +336,8 @@ impl<T: Eq + PartialEq> Mode<T> {
}
}
//#[cfg_attr(test, derive(Debug))]
#[derive(Debug)]
struct Node {
label: SpirvWord,
denormal_f32: Mode<DenormalMode>,
@ -376,7 +377,7 @@ trait EnumTuple {
pub(crate) fn run<'input>(
flat_resolver: &mut super::GlobalStringIdentResolver2<'input>,
mut directives: Vec<super::Directive2<ast::Instruction<SpirvWord>, super::SpirvWord>>,
directives: Vec<super::Directive2<ast::Instruction<SpirvWord>, super::SpirvWord>>,
) -> Result<Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
let mut cfg = ControlFlowGraph::new();
for directive in directives.iter() {
@ -398,13 +399,17 @@ pub(crate) fn run<'input>(
arguments: ast::CallArgs { func, .. },
..
}) => {
let after_call_label = match body_iter.peek() {
Some(Statement::Label(l)) => *l,
let after_call_label = match body_iter.next() {
Some(Statement::Instruction(ast::Instruction::Bra {
arguments: ast::BraArgs { src },
})) => *src,
_ => return Err(error_unreachable()),
};
bb_state.record_call(*func, after_call_label)?;
//body_iter.next();
}
Statement::Instruction(ast::Instruction::Ret { .. }) => {
Statement::RetValue(..)
| Statement::Instruction(ast::Instruction::Ret { .. }) => {
bb_state.record_ret(*name)?;
}
Statement::Label(label) => {
@ -426,7 +431,15 @@ pub(crate) fn run<'input>(
_ => {}
}
}
println!(
"{:?}",
petgraph::dot::Dot::with_config(&cfg.graph, &[petgraph::dot::Config::EdgeNoLabel])
);
cfg.fixup_function_calls();
println!(
"{:?}",
petgraph::dot::Dot::with_config(&cfg.graph, &[petgraph::dot::Config::EdgeNoLabel])
);
let denormal_f32 = compute_single_mode(&cfg, |node| node.denormal_f32);
let denormal_f16f64 = compute_single_mode(&cfg, |node| node.denormal_f16f64);
let rounding_f32 = compute_single_mode(&cfg, |node| node.rounding_f32);
@ -434,7 +447,8 @@ pub(crate) fn run<'input>(
let denormal_f32 = optimize::<DenormalMode, { DenormalMode::COUNT }>(denormal_f32);
let denormal_f16f64 = optimize::<DenormalMode, { DenormalMode::COUNT }>(denormal_f16f64);
let rounding_f32 = optimize::<RoundingMode, { RoundingMode::COUNT }>(rounding_f32);
let rounding_f16f64 = optimize::<RoundingMode, { RoundingMode::COUNT }>(rounding_f16f64);
let rounding_f16f64: MandatoryModeInsertions<RoundingMode> =
optimize::<RoundingMode, { RoundingMode::COUNT }>(rounding_f16f64);
let denormal = join_modes(
flat_resolver,
&cfg,
@ -483,7 +497,7 @@ fn join_modes<'input, T: Eq + PartialEq + Copy + Default>(
mut f16f64_exit_view: impl FnMut(&Node) -> Option<ExtendedMode<T>>,
) -> Result<TwinModeInsertions<T>, TranslateError> {
// Returns None if there are multiple conflicting modes
fn get_incoming_mode<T: Eq + PartialEq + Copy>(
fn get_incoming_mode<T: Eq + PartialEq + Copy + Default>(
cfg: &ControlFlowGraph,
kernels: &FxHashMap<SpirvWord, T>,
node: NodeIndex,
@ -500,11 +514,11 @@ fn join_modes<'input, T: Eq + PartialEq + Copy + Default>(
if !visited.insert(node) {
continue;
}
let x = &cfg.graph[node];
match (mode, exit_getter(x)) {
let node_data = &cfg.graph[node];
match (mode, exit_getter(node_data)) {
(_, None) => {
for next in cfg.graph.neighbors_directed(node, Direction::Incoming) {
if !visited.insert(next) {
if !visited.contains(&next) {
to_visit.push(next);
}
}
@ -513,7 +527,7 @@ fn join_modes<'input, T: Eq + PartialEq + Copy + Default>(
let new_mode = match new_mode {
ExtendedMode::BasicBlock(new_mode) => new_mode,
ExtendedMode::Entry(kernel) => {
*kernels.get(&kernel).ok_or_else(error_unreachable)?
kernels.get(&kernel).copied().unwrap_or_default()
}
};
if let Some(existing_mode) = existing_mode {
@ -546,7 +560,7 @@ fn join_modes<'input, T: Eq + PartialEq + Copy + Default>(
.kernels
.get(&kernel)
.copied()
.ok_or_else(error_unreachable)?,
.unwrap_or_default(),
),
// None means that no instruction in the basic block sets mode, but
// another basic block might rely on this instruction transitively
@ -560,7 +574,7 @@ fn join_modes<'input, T: Eq + PartialEq + Copy + Default>(
.kernels
.get(&kernel)
.copied()
.ok_or_else(error_unreachable)?,
.unwrap_or_default(),
),
None => None,
};
@ -713,7 +727,9 @@ fn insert_mode_control<'input>(
let old_body = mem::replace(body_ptr, Vec::new());
let mut result = Vec::with_capacity(old_body.len());
let mut bb_state = BasicBlockControlState::new(&global_modes, fn_name, initial_mode);
for mut statement in old_body.into_iter() {
let mut old_body = old_body.into_iter();
while let Some(mut statement) = old_body.next() {
let mut call_target = None;
match &mut statement {
Statement::Label(label) => {
bb_state.start(*label, &mut result)?;
@ -723,6 +739,7 @@ fn insert_mode_control<'input>(
..
}) => {
bb_state.redirect_jump(func)?;
call_target = Some(*func);
}
Statement::Conditional(BrachCondition {
if_true, if_false, ..
@ -742,6 +759,16 @@ fn insert_mode_control<'input>(
_ => {}
}
result.push(statement);
if let Some(call_target) = call_target {
if let Some(Statement::Instruction(ast::Instruction::Bra {
arguments: ast::BraArgs { src: post_call_label },
})) = old_body.next()
{
// get return block for the function, if there is a mode
// change between caller and callee then apply it here
todo!()
}
}
}
*body_ptr = result;
new_directives.push(directive);
@ -1165,8 +1192,8 @@ impl<'a> BasicBlockState<'a> {
fn_call: SpirvWord,
after_call_label: SpirvWord,
) -> Result<(), TranslateError> {
let node_index = self.node_index.ok_or_else(error_unreachable)?;
let after_call_label = self.cfg.add_jump(node_index, after_call_label);
self.end(&[fn_call]).ok_or_else(error_unreachable)?;
let after_call_label = self.cfg.get_or_add_basic_block(after_call_label);
let call_returns = self
.cfg
.call_returns
@ -1178,8 +1205,11 @@ impl<'a> BasicBlockState<'a> {
fn record_ret(&mut self, fn_name: SpirvWord) -> Result<(), TranslateError> {
let node_index = self.node_index.ok_or_else(error_unreachable)?;
let function_rets = self.cfg.function_rets.entry(fn_name).or_insert(Vec::new());
function_rets.push(node_index);
let previous_function_ret = self.cfg.functions_rets.insert(fn_name, node_index);
// This pass relies on there being only a single `ret;` in a function
if previous_function_ret.is_some() {
return Err(error_unreachable());
}
Ok(())
}
@ -1263,7 +1293,11 @@ struct PartialModeInsertion<T> {
bb_maybe_insert_mode: FxHashMap<SpirvWord, (T, FxHashSet<SpirvWord>)>,
}
fn optimize<T: Copy + Into<usize> + strum::VariantArray + std::fmt::Debug, const N: usize>(
// Only returns kernel mode insertions if a kernel is relevant to the optimization problem
fn optimize<
T: Copy + Into<usize> + strum::VariantArray + std::fmt::Debug + Default,
const N: usize,
>(
partial: PartialModeInsertion<T>,
) -> MandatoryModeInsertions<T> {
let mut problem = Problem::new(OptimizationDirection::Maximize);
@ -1341,6 +1375,8 @@ struct MandatoryModeInsertions<T> {
}
#[derive(Eq, PartialEq, Clone, Copy)]
//#[cfg_attr(test, derive(Debug))]
#[derive(Debug)]
enum ExtendedMode<T: Eq + PartialEq> {
BasicBlock(T),
Entry(SpirvWord),
@ -1549,4 +1585,4 @@ fn get_modes<T: ast::Operand>(inst: &ast::Instruction<T>) -> InstructionModes {
}
#[cfg(test)]
mod test;
mod test;

View file

@ -198,7 +198,7 @@ fn compile_methods(ptx: &str) -> Vec<Function2<ast::Instruction<SpirvWord>, Spir
let directives = normalize_identifiers2::run(&mut scoped_resolver, module.directives).unwrap();
let directives = normalize_predicates2::run(&mut flat_resolver, directives).unwrap();
let directives = expand_operands::run(&mut flat_resolver, directives).unwrap();
let directives = normalize_basic_blocks::run(&mut flat_resolver, directives);
let directives = normalize_basic_blocks::run(&mut flat_resolver, directives).unwrap();
let directives = super::run(&mut flat_resolver, directives).unwrap();
directives
.into_iter()
@ -220,10 +220,37 @@ fn call_with_mode() {
&**methods[1].body.as_ref().unwrap(),
[
Statement::Label(..),
Statement::Variable(..),
Statement::Instruction(ast::Instruction::Add { .. }),
Statement::Instruction(ast::Instruction::Call { .. }),
Statement::Instruction(ast::Instruction::Bra { .. }),
Statement::Label(..),
// Dual prelude
Statement::SetMode(ModeRegister::Denormal {
f32: false,
f16f64: false
f32: true,
f16f64: true
}),
Statement::SetMode(ModeRegister::Rounding {
f32: ast::RoundingMode::PositiveInf,
f16f64: ast::RoundingMode::NearestEven
}),
Statement::Instruction(ast::Instruction::Bra { .. }),
// Denormal prelude
Statement::Label(..),
Statement::SetMode(ModeRegister::Denormal {
f32: true,
f16f64: true
}),
Statement::Instruction(ast::Instruction::Bra { .. }),
// Rounding prelude
Statement::Label(..),
Statement::SetMode(ModeRegister::Rounding {
f32: ast::RoundingMode::PositiveInf,
f16f64: ast::RoundingMode::NearestEven
}),
Statement::Instruction(ast::Instruction::Bra { .. }),
Statement::Label(..),
Statement::Instruction(ast::Instruction::Add { .. }),
Statement::Instruction(ast::Instruction::Ret { .. }),
]
));

View file

@ -51,9 +51,9 @@ pub fn to_llvm_module<'input>(ast: ast::Module<'input>) -> Result<Module, Transl
let directives = fix_special_registers2::run(&mut flat_resolver, &sreg_map, directives)?;
let directives = expand_operands::run(&mut flat_resolver, directives)?;
let directives = deparamize_functions::run(&mut flat_resolver, directives)?;
let directives = normalize_basic_blocks::run(&mut flat_resolver, directives)?;
let directives = insert_explicit_load_store::run(&mut flat_resolver, directives)?;
let directives = insert_implicit_conversions2::run(&mut flat_resolver, directives)?;
let directives = normalize_basic_blocks::run(&mut flat_resolver, directives);
let directives = insert_ftz_control::run(&mut flat_resolver, directives)?;
let directives = replace_instructions_with_function_calls::run(&mut flat_resolver, directives)?;
let directives = hoist_globals::run(directives)?;

View file

@ -1,16 +1,23 @@
use super::*;
// This pass normalized ptx modules in two ways that makes mode computation pass
// This pass normalizes ptx modules in several ways that make the mode computation
// pass and code emission passes much simpler:
// * Inserts label at the start of every function
// This makes control flow graph simpler in mode computation block: we can
// represent kernels as separate nodes with its own separate entry/exit mode
// * Inserts label at the start of every basic block
// * Inserts explicit jumps before labels
// * Functions get a single `ret;` exit point - this is because mode computation
// logic requires it. Control flow graph constructed by mode computation
// models function calls as jumps into and then from another function.
// If this cfg allowed multiple return basic blocks then there would be cases
// where we want to insert mode setting instruction along the edge between
// `ret;` and bb in the caller. This is only possible if there's a single
// edge from the function's `ret;` to the caller
pub(crate) fn run(
flat_resolver: &mut GlobalStringIdentResolver2<'_>,
mut directives: Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>,
) -> Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>> {
) -> Result<Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
for directive in directives.iter_mut() {
let body_ref = match directive {
Directive2::Method(Function2 {
@ -20,8 +27,9 @@ pub(crate) fn run(
};
let body = std::mem::replace(body_ref, Vec::new());
let mut result = Vec::with_capacity(body.len());
let mut needs_label = false;
let mut previous_instruction_was_terminator = TerminatorKind::Not;
let mut body_iterator = body.into_iter();
let mut return_statements = Vec::new();
match body_iterator.next() {
Some(Statement::Label(_)) => {}
Some(statement) => {
@ -31,25 +39,94 @@ pub(crate) fn run(
None => {}
}
for statement in body_iterator {
if needs_label && !matches!(statement, Statement::Label(..)) {
result.push(Statement::Label(flat_resolver.register_unnamed(None)));
match previous_instruction_was_terminator {
TerminatorKind::Not => match statement {
Statement::Label(label) => {
result.push(Statement::Instruction(ast::Instruction::Bra {
arguments: ast::BraArgs { src: label },
}))
}
_ => {}
},
TerminatorKind::Real => {
if !matches!(statement, Statement::Label(..)) {
result.push(Statement::Label(flat_resolver.register_unnamed(None)));
}
}
TerminatorKind::Fake => match statement {
// if it happens that there is a label after a call just reuse it
Statement::Label(label) => {
result.push(Statement::Instruction(ast::Instruction::Bra {
arguments: ast::BraArgs { src: label },
}))
}
_ => {
let label = flat_resolver.register_unnamed(None);
result.push(Statement::Instruction(ast::Instruction::Bra {
arguments: ast::BraArgs { src: label },
}));
result.push(Statement::Label(label));
}
},
}
needs_label = is_block_terminator(&statement);
match statement {
Statement::RetValue(..) => {
return Err(error_unreachable());
}
Statement::Instruction(ast::Instruction::Ret { .. }) => {
return_statements.push(result.len())
}
_ => {}
}
previous_instruction_was_terminator = is_block_terminator(&statement);
result.push(statement);
}
convert_from_multiple_returns_to_single_return(
flat_resolver,
&mut result,
return_statements,
)?;
*body_ref = result;
}
directives
Ok(directives)
}
fn is_block_terminator(instruction: &Statement<ast::Instruction<SpirvWord>, SpirvWord>) -> bool {
match instruction {
/// Classification of a statement with respect to basic block boundaries, as
/// computed by `is_block_terminator`.
enum TerminatorKind {
/// The statement does not end a basic block.
Not,
/// The statement ends a basic block: a conditional, `bra` or `ret`.
Real,
/// A `call`: not a true terminator, but treated as one by this pass. An
/// explicit `bra` to the following label is emitted right after it.
Fake,
}
fn convert_from_multiple_returns_to_single_return(
flat_resolver: &mut GlobalStringIdentResolver2<'_>,
result: &mut Vec<Statement<ptx_parser::Instruction<SpirvWord>, SpirvWord>>,
return_statements: Vec<usize>,
) -> Result<(), TranslateError> {
Ok(if return_statements.len() > 1 {
let ret_bb = flat_resolver.register_unnamed(None);
result.push(Statement::Label(ret_bb));
result.push(Statement::Instruction(ast::Instruction::Ret {
data: ast::RetData { uniform: false },
}));
for ret_index in return_statements {
let statement = result.get_mut(ret_index).ok_or_else(error_unreachable)?;
*statement = Statement::Instruction(ast::Instruction::Bra {
arguments: ast::BraArgs { src: ret_bb },
});
}
})
}
fn is_block_terminator(
statement: &Statement<ast::Instruction<SpirvWord>, SpirvWord>,
) -> TerminatorKind {
match statement {
Statement::Conditional(..)
| Statement::Instruction(ast::Instruction::Bra { .. })
// Normally call is not a terminator, but we treat it as such because it
// makes the instruction modes to global modes pass possible
| Statement::Instruction(ast::Instruction::Call { .. })
| Statement::Instruction(ast::Instruction::Ret { .. }) => true,
_ => false,
// makes the "instruction modes to global modes" pass possible
| Statement::Instruction(ast::Instruction::Ret { .. }) => TerminatorKind::Real,
Statement::Instruction(ast::Instruction::Call { .. }) => TerminatorKind::Fake,
_ => TerminatorKind::Not,
}
}