diff --git a/ptx/Cargo.toml b/ptx/Cargo.toml index 6cf9c79..2f645e7 100644 --- a/ptx/Cargo.toml +++ b/ptx/Cargo.toml @@ -20,6 +20,8 @@ strum_macros = "0.26" petgraph = "0.7.1" microlp = "0.2.10" int-enum = "1.1" +smallvec = "1.13" +unwrap_or = "1.0.1" [dev-dependencies] hip_runtime-sys = { path = "../ext/hip_runtime-sys" } diff --git a/ptx/src/pass/emit_llvm.rs b/ptx/src/pass/emit_llvm.rs index 73d7ced..66ceb75 100644 --- a/ptx/src/pass/emit_llvm.rs +++ b/ptx/src/pass/emit_llvm.rs @@ -2264,17 +2264,7 @@ impl<'a> MethodEmitContext<'a> { let intrinsic = c"llvm.amdgcn.s.setreg"; let llvm_i32 = get_scalar_type(self.context, ast::ScalarType::B32); let (hwreg, value) = match mode_reg { - ModeRegister::DenormalF32(ftz) => { - let (reg, offset, size) = (1, 4, 2u32); - let hwreg = reg | (offset << 6) | ((size - 1) << 11); - (hwreg, if ftz { 0u32 } else { 3 }) - } - ModeRegister::DenormalF16F64(ftz) => { - let (reg, offset, size) = (1, 6, 2u32); - let hwreg = reg | (offset << 6) | ((size - 1) << 11); - (hwreg, if ftz { 0 } else { 3 }) - } - ModeRegister::DenormalBoth { f32, f16f64 } => { + ModeRegister::Denormal { f32, f16f64 } => { let (reg, offset, size) = (1, 4, 4u32); let hwreg = reg | (offset << 6) | ((size - 1) << 11); let f32 = if f32 { 0 } else { 3 }; @@ -2282,9 +2272,7 @@ impl<'a> MethodEmitContext<'a> { let value = f32 | f16f64 << 2; (hwreg, value) } - ModeRegister::RoundingF32(rounding_mode) => todo!(), - ModeRegister::RoundingF16F64(rounding_mode) => todo!(), - ModeRegister::RoundingBoth { f32, f16f64 } => todo!(), + ModeRegister::Rounding { .. 
} => todo!(), }; let hwreg_llvm = unsafe { LLVMConstInt(llvm_i32, hwreg as _, 0) }; let value_llvm = unsafe { LLVMConstInt(llvm_i32, value as _, 0) }; diff --git a/ptx/src/pass/expand_operands.rs b/ptx/src/pass/expand_operands.rs index c87dd92..a9ede33 100644 --- a/ptx/src/pass/expand_operands.rs +++ b/ptx/src/pass/expand_operands.rs @@ -51,8 +51,8 @@ fn run_method<'input>( is_kernel: method.is_kernel, flush_to_zero_f32: method.flush_to_zero_f32, flush_to_zero_f16f64: method.flush_to_zero_f16f64, - roundind_mode_f32: method.roundind_mode_f32, - roundind_mode_f16f64: method.roundind_mode_f16f64, + rounding_mode_f32: method.rounding_mode_f32, + rounding_mode_f16f64: method.rounding_mode_f16f64, }) } diff --git a/ptx/src/pass/fix_special_registers2.rs b/ptx/src/pass/fix_special_registers2.rs index ad484fd..78e66c9 100644 --- a/ptx/src/pass/fix_special_registers2.rs +++ b/ptx/src/pass/fix_special_registers2.rs @@ -22,8 +22,8 @@ pub(super) fn run<'a, 'input>( is_kernel: false, flush_to_zero_f32: false, flush_to_zero_f16f64: false, - roundind_mode_f32: ptx_parser::RoundingMode::NearestEven, - roundind_mode_f16f64: ptx_parser::RoundingMode::NearestEven, + rounding_mode_f32: ptx_parser::RoundingMode::NearestEven, + rounding_mode_f16f64: ptx_parser::RoundingMode::NearestEven, })); sreg_to_function.insert(sreg, name); }, diff --git a/ptx/src/pass/insert_ftz_control/call_with_mode.ptx b/ptx/src/pass/insert_ftz_control/call_with_mode.ptx new file mode 100644 index 0000000..cfff97c --- /dev/null +++ b/ptx/src/pass/insert_ftz_control/call_with_mode.ptx @@ -0,0 +1,21 @@ +.version 6.5 +.target sm_50 +.address_size 64 + +.func use_modes(); + +.visible .entry kernel() +{ + .reg .f32 temp; + + add.rz.ftz.f32 temp, temp, temp; + call use_modes; + ret; +} + +.func use_modes() +{ + .reg .f32 temp; + add.rm.f32 temp, temp, temp; + ret; +} diff --git a/ptx/src/pass/insert_ftz_control/fold_denormal.ptx b/ptx/src/pass/insert_ftz_control/fold_denormal.ptx new file mode 100644 index 
0000000..1fa161a --- /dev/null +++ b/ptx/src/pass/insert_ftz_control/fold_denormal.ptx @@ -0,0 +1,15 @@ +.version 6.5 +.target sm_30 +.address_size 64 + +.visible .entry add() +{ + .reg .f32 temp<3>; + + add.ftz.f16 temp2, temp1, temp0; + add.ftz.f32 temp2, temp1, temp0; + + add.f16 temp2, temp1, temp0; + add.f32 temp2, temp1, temp0; + ret; +} diff --git a/ptx/src/pass/insert_ftz_control.rs b/ptx/src/pass/insert_ftz_control/mod.rs similarity index 55% rename from ptx/src/pass/insert_ftz_control.rs rename to ptx/src/pass/insert_ftz_control/mod.rs index 649d332..a7097dd 100644 --- a/ptx/src/pass/insert_ftz_control.rs +++ b/ptx/src/pass/insert_ftz_control/mod.rs @@ -16,12 +16,14 @@ use petgraph::Graph; use ptx_parser as ast; use rustc_hash::FxHashMap; use rustc_hash::FxHashSet; +use smallvec::SmallVec; use std::hash::Hash; use std::iter; use std::mem; use std::u32; use strum::EnumCount; use strum_macros::{EnumCount, VariantArray}; +use unwrap_or::unwrap_some_or; #[derive(Default, PartialEq, Eq, Clone, Copy, Debug, VariantArray, EnumCount)] enum DenormalMode { @@ -47,6 +49,12 @@ impl DenormalMode { } } +impl Into for DenormalMode { + fn into(self) -> bool { + self.to_ftz() + } +} + impl Into for DenormalMode { fn into(self) -> usize { self as usize @@ -82,6 +90,12 @@ impl RoundingMode { } } +impl Into for RoundingMode { + fn into(self) -> ast::RoundingMode { + self.to_ast() + } +} + impl Into for RoundingMode { fn into(self) -> usize { self as usize @@ -233,6 +247,10 @@ impl InstructionModes { struct ControlFlowGraph { entry_points: FxHashMap, basic_blocks: FxHashMap, + // map function -> return label + call_returns: FxHashMap>, + // map function -> return basic blocks + function_rets: FxHashMap>, graph: Graph, } @@ -241,6 +259,8 @@ impl ControlFlowGraph { Self { entry_points: FxHashMap::default(), basic_blocks: FxHashMap::default(), + call_returns: FxHashMap::default(), + function_rets: FxHashMap::default(), graph: Graph::new(), } } @@ -259,9 +279,10 @@ impl 
ControlFlowGraph { }) } - fn add_jump(&mut self, from: NodeIndex, to: SpirvWord) { + fn add_jump(&mut self, from: NodeIndex, to: SpirvWord) -> NodeIndex { let to = self.get_or_add_basic_block(to); self.graph.add_edge(from, to, ()); + to } fn set_modes(&mut self, node: NodeIndex, entry: InstructionModes, exit: InstructionModes) { @@ -275,6 +296,23 @@ impl ControlFlowGraph { node.rounding_f32.exit = exit.rounding_f32.map(ExtendedMode::BasicBlock); node.rounding_f16f64.exit = exit.rounding_f16f64.map(ExtendedMode::BasicBlock); } + + fn fixup_function_calls(&mut self) { + for (function, sources) in self.function_rets.iter() { + for target in self + .call_returns + .get(function) + .iter() + .map(|vec| vec.iter()) + .flatten() + .copied() + { + for source in sources { + self.graph.add_edge(*source, target, ()); + } + } + } + } } #[derive(Clone, Copy)] @@ -349,13 +387,26 @@ pub(crate) fn run<'input>( is_kernel, .. }) => { - let (mut bb_state, body_iter) = + let (mut bb_state, mut body_iter) = BasicBlockState::new(&mut cfg, *name, body, *is_kernel)?; - for statement in body_iter { + while let Some(statement) = body_iter.next() { match statement { Statement::Instruction(ast::Instruction::Bra { arguments }) => { bb_state.end(&[arguments.src]); } + Statement::Instruction(ast::Instruction::Call { + arguments: ast::CallArgs { func, .. }, + .. + }) => { + let after_call_label = match body_iter.peek() { + Some(Statement::Label(l)) => *l, + _ => return Err(error_unreachable()), + }; + bb_state.record_call(*func, after_call_label)?; + } + Statement::Instruction(ast::Instruction::Ret { .. 
}) => { + bb_state.record_ret(*name)?; + } Statement::Label(label) => { bb_state.start(*label); } @@ -375,6 +426,7 @@ pub(crate) fn run<'input>( _ => {} } } + cfg.fixup_function_calls(); let denormal_f32 = compute_single_mode(&cfg, |node| node.denormal_f32); let denormal_f16f64 = compute_single_mode(&cfg, |node| node.denormal_f16f64); let rounding_f32 = compute_single_mode(&cfg, |node| node.rounding_f32); @@ -383,155 +435,447 @@ pub(crate) fn run<'input>( let denormal_f16f64 = optimize::(denormal_f16f64); let rounding_f32 = optimize::(rounding_f32); let rounding_f16f64 = optimize::(rounding_f16f64); - insert_mode_control( + let denormal = join_modes( flat_resolver, - &mut directives, &cfg, denormal_f32, + |node| node.denormal_f32.entry, + |node| node.denormal_f32.exit, denormal_f16f64, - rounding_f32, - rounding_f16f64, + |node| node.denormal_f16f64.entry, + |node| node.denormal_f16f64.exit, )?; + let rounding = join_modes( + flat_resolver, + &cfg, + rounding_f32, + |node| node.rounding_f32.entry, + |node| node.rounding_f32.exit, + rounding_f16f64, + |node| node.rounding_f16f64.entry, + |node| node.rounding_f16f64.exit, + )?; + let all_modes = FullModeInsertion::new(flat_resolver, denormal, rounding)?; + let directives = insert_mode_control(directives, all_modes)?; Ok(directives) } -fn insert_mode_control<'input>( +// For every basic block this pass computes: +// - Name of mode prologue basic block. Mode prologue is a basic block which +// contains single instruction that sets mode to the desired value. It will +// be later inserted just before the basic block and all jumps that require +// mode change will go through this basic block +// - Entry mode: what is the mode for both f32 and f16f64 at the first instruction. +// This will be used when emitting instructions in the basic block. When we +// emit an instruction we get its modes, check if they are different and if so +// decide: do we emit new mode set statement or we fold into previous mode set.
+// We don't need to compute exit mode because this will be computed naturally +// when emitting instructions in a basic block. We need exit mode to know if we +// jump directly to the next bb or jump to mode prologue +fn join_modes<'input, T: Eq + PartialEq + Copy + Default>( flat_resolver: &mut super::GlobalStringIdentResolver2<'input>, - directives: &mut [Directive2, SpirvWord>], cfg: &ControlFlowGraph, - denormal_f32: ModeInsertions, - denormal_f16f64: ModeInsertions, - rounding_f32: ModeInsertions, - rounding_f16f64: ModeInsertions, -) -> Result<(), TranslateError> { - for directive in directives.iter_mut() { - let body_ptr = match directive { - Directive2::Variable(..) | Directive2::Method(Function2 { body: None, .. }) => continue, - Directive2::Method(Function2 { - name, - body: Some(body), - flush_to_zero_f32, - flush_to_zero_f16f64, - roundind_mode_f32: rounding_mode_f32, - roundind_mode_f16f64: rounding_mode_f16f64, - .. - }) => { - *flush_to_zero_f32 = denormal_f32 - .kernels - .get(name) - .copied() - .unwrap_or(DenormalMode::default()) - .to_ftz(); - *flush_to_zero_f16f64 = denormal_f16f64 - .kernels - .get(name) - .copied() - .unwrap_or(DenormalMode::default()) - .to_ftz(); - *rounding_mode_f32 = rounding_f32 - .kernels - .get(name) - .copied() - .unwrap_or(RoundingMode::default()) - .to_ast(); - *rounding_mode_f16f64 = rounding_f16f64 - .kernels - .get(name) - .copied() - .unwrap_or(RoundingMode::default()) - .to_ast(); - body + f32_insertions: MandatoryModeInsertions, + mut f32_enter_view: impl FnMut(&Node) -> Option>, + mut f32_exit_view: impl FnMut(&Node) -> Option>, + f16f64_insertions: MandatoryModeInsertions, + mut f16f64_enter_view: impl FnMut(&Node) -> Option>, + mut f16f64_exit_view: impl FnMut(&Node) -> Option>, +) -> Result, TranslateError> { + // Returns None if there are multiple conflicting modes + fn get_incoming_mode( + cfg: &ControlFlowGraph, + kernels: &FxHashMap, + node: NodeIndex, + mut exit_getter: impl FnMut(&Node) -> Option>, + 
) -> Result, TranslateError> { + let mut mode: Option = None; + let mut visited = iter::once(node).collect::>(); + let mut to_visit = cfg + .graph + .neighbors_directed(node, Direction::Incoming) + .map(|x| x) + .collect::>(); + while let Some(node) = to_visit.pop() { + if !visited.insert(node) { + continue; } - }; - let mut old_body = mem::replace(body_ptr, Vec::new()); - let mut result = Vec::with_capacity(old_body.len()); - let mut bb_state = BasicBlockControlState::new( - &denormal_f32, - &denormal_f16f64, - &rounding_f32, - &rounding_f16f64, - ); - for statement in old_body.into_iter() { - match &statement { - Statement::Label(label) => { - bb_state.start(*label); + let x = &cfg.graph[node]; + match (mode, exit_getter(x)) { + (_, None) => { + for next in cfg.graph.neighbors_directed(node, Direction::Incoming) { + if !visited.insert(next) { + to_visit.push(next); + } + } } - Statement::Instruction(instruction) => { - let modes = get_modes(&instruction); - bb_state.insert(&mut result, modes)?; + (existing_mode, Some(new_mode)) => { + let new_mode = match new_mode { + ExtendedMode::BasicBlock(new_mode) => new_mode, + ExtendedMode::Entry(kernel) => { + *kernels.get(&kernel).ok_or_else(error_unreachable)? 
+ } + }; + if let Some(existing_mode) = existing_mode { + if existing_mode != new_mode { + return Ok(None); + } + } + mode = Some(new_mode); } - _ => {} } - result.push(statement); } - *body_ptr = result; + mode.map(Some).ok_or_else(error_unreachable) } - Ok(()) + let basic_blocks = cfg + .graph + .node_references() + .into_iter() + .map(|(node, basic_block)| { + let requires_prologue = f32_insertions.basic_blocks.contains(&basic_block.label) + || f16f64_insertions.basic_blocks.contains(&basic_block.label); + let prologue: Option = if requires_prologue { + Some(flat_resolver.register_unnamed(None)) + } else { + None + }; + let f32 = match f32_enter_view(&basic_block) { + Some(ExtendedMode::BasicBlock(mode)) => Some(mode), + Some(ExtendedMode::Entry(kernel)) => Some( + f32_insertions + .kernels + .get(&kernel) + .copied() + .ok_or_else(error_unreachable)?, + ), + // None means that no instruction in the basic block sets mode, but + // another basic block might rely on this instruction transitively + // passing a mode + None => None, + }; + let f16f64 = match f16f64_enter_view(&basic_block) { + Some(ExtendedMode::BasicBlock(mode)) => Some(mode), + Some(ExtendedMode::Entry(kernel)) => Some( + f16f64_insertions + .kernels + .get(&kernel) + .copied() + .ok_or_else(error_unreachable)?, + ), + None => None, + }; + let twin_mode = match (f32, f16f64) { + (Some(f32), Some(f16f64)) => Some(TwinMode { f32, f16f64 }), + (None, Some(f16f64)) => { + let f32 = get_incoming_mode(cfg, &f32_insertions.kernels, node, |node| { + f32_exit_view(node) + })?; + let f32 = f32.unwrap_or_default(); + Some(TwinMode { f32, f16f64 }) + } + (Some(f32), None) => { + let f16f64 = + get_incoming_mode(cfg, &f16f64_insertions.kernels, node, |node| { + f16f64_exit_view(node) + })?; + let f16f64 = f16f64.unwrap_or_default(); + Some(TwinMode { f32, f16f64 }) + } + (None, None) => None, + }; + Ok(( + basic_block.label, + BasicBlockEntryState { + prologue, + twin_mode, + }, + )) + }) + .collect::, _>>()?; 
+ Ok(TwinModeInsertions { basic_blocks }) } -struct BasicBlockControlState<'a> { - global_denormal_f32: &'a ModeInsertions, - global_denormal_f16f64: &'a ModeInsertions, - global_rounding_f32: &'a ModeInsertions, - global_rounding_f16f64: &'a ModeInsertions, - basic_block: SpirvWord, - denormal_f32: RegisterState, - denormal_f16f64: RegisterState, - foldable_rounding_f32: Option, - foldable_rounding_f16f64: Option, +struct TwinModeInsertions { + basic_blocks: FxHashMap>, +} + +struct FullModeInsertion { + basic_blocks: FxHashMap, +} + +impl FullModeInsertion { + fn new( + flat_resolver: &mut super::GlobalStringIdentResolver2, + denormal: TwinModeInsertions, + rounding: TwinModeInsertions, + ) -> Result { + let denormal = denormal.basic_blocks; + let rounding = rounding.basic_blocks; + if denormal.len() != rounding.len() { + return Err(error_unreachable()); + } + let basic_blocks = denormal + .into_iter() + .map(|(bb, denormal)| { + let rounding = rounding.get(&bb).copied().ok_or_else(error_unreachable)?; + let dual_prologue = if denormal.prologue.is_some() && rounding.prologue.is_some() { + Some(flat_resolver.register_unnamed(None)) + } else { + None + }; + Ok(( + bb, + FullBasicBlockEntryState { + dual_prologue, + denormal, + rounding, + }, + )) + }) + .collect::, _>>()?; + Ok(Self { basic_blocks }) + } +} + +struct FullBasicBlockEntryState { + dual_prologue: Option, + denormal: BasicBlockEntryState, + rounding: BasicBlockEntryState, } #[derive(Clone, Copy)] -enum RegisterState { - Inherited, - Unknown, - Value(Option, T), +struct BasicBlockEntryState { + prologue: Option, + // It is None in case where no instructions in the basic block uses mode + twin_mode: Option>, +} + +#[derive(Clone, Copy, Default)] +struct TwinMode { + f32: T, + f16f64: T, +} + +fn insert_mode_control<'input>( + directives: Vec, SpirvWord>>, + global_modes: FullModeInsertion, +) -> Result, SpirvWord>>, TranslateError> { + let directives_len = directives.len(); + directives + .into_iter() + 
.map(|mut directive| { + let mut new_directives = SmallVec::<[_; 4]>::new(); + let (fn_name, initial_mode, body_ptr) = match directive { + Directive2::Variable(..) | Directive2::Method(Function2 { body: None, .. }) => { + new_directives.push(directive); + return Ok(new_directives); + } + Directive2::Method(Function2 { + name, + body: Some(ref mut body), + ref mut flush_to_zero_f32, + ref mut flush_to_zero_f16f64, + ref mut rounding_mode_f32, + ref mut rounding_mode_f16f64, + .. + }) => { + let initial_mode = global_modes + .basic_blocks + .get(&name) + .ok_or_else(error_unreachable)?; + *flush_to_zero_f32 = initial_mode + .denormal + .twin_mode + .unwrap_or_default() + .f32 + .to_ftz(); + *flush_to_zero_f16f64 = initial_mode + .denormal + .twin_mode + .unwrap_or_default() + .f16f64 + .to_ftz(); + *rounding_mode_f32 = initial_mode + .rounding + .twin_mode + .unwrap_or_default() + .f32 + .to_ast(); + *rounding_mode_f16f64 = initial_mode + .rounding + .twin_mode + .unwrap_or_default() + .f16f64 + .to_ast(); + (name, initial_mode, body) + } + }; + let old_body = mem::replace(body_ptr, Vec::new()); + let mut result = Vec::with_capacity(old_body.len()); + let mut bb_state = BasicBlockControlState::new(&global_modes, fn_name, initial_mode); + for mut statement in old_body.into_iter() { + match &mut statement { + Statement::Label(label) => { + bb_state.start(*label, &mut result)?; + } + Statement::Instruction(ast::Instruction::Call { + arguments: ast::CallArgs { func, .. }, + .. + }) => { + bb_state.redirect_jump(func)?; + } + Statement::Conditional(BrachCondition { + if_true, if_false, .. 
+ }) => { + bb_state.redirect_jump(if_true)?; + bb_state.redirect_jump(if_false)?; + } + Statement::Instruction(ast::Instruction::Bra { + arguments: ptx_parser::BraArgs { src }, + }) => { + bb_state.redirect_jump(src)?; + } + Statement::Instruction(instruction) => { + let modes = get_modes(&instruction); + bb_state.insert(&mut result, modes)?; + } + _ => {} + } + result.push(statement); + } + *body_ptr = result; + new_directives.push(directive); + Ok(new_directives) + }) + .try_fold(Vec::with_capacity(directives_len), |mut acc, d| { + acc.extend(d?); + Ok(acc) + }) +} + +struct BasicBlockControlState<'a> { + global_modes: &'a FullModeInsertion, + denormal_f32: RegisterState, + denormal_f16f64: RegisterState, + rounding_f32: RegisterState, + rounding_f16f64: RegisterState, + current_bb: SpirvWord, +} + +#[derive(Clone, Copy)] +struct RegisterState { + current_value: Option, + // This is slightly subtle: this value is Some iff there's a SetMode in this + // basic block setting this mode, but on which no instruction relies + last_foldable: Option, } impl RegisterState { - fn empty() -> Self { - Self::Unknown + fn single(t: T) -> Self { + RegisterState { + last_foldable: None, + current_value: Some(t), + } } - fn new(must_insert: bool) -> Self { - if must_insert { - Self::Unknown - } else { - Self::Inherited + fn empty() -> Self { + RegisterState { + last_foldable: None, + current_value: None, + } + } + + fn new(computed: &BasicBlockEntryState) -> (RegisterState, RegisterState) + where + U: Into, + { + match computed.twin_mode { + Some(ref mode) => ( + RegisterState::single(mode.f32.into()), + RegisterState::single(mode.f16f64.into()), + ), + None => (RegisterState::empty(), RegisterState::empty()), } } } impl<'a> BasicBlockControlState<'a> { fn new( - global_denormal_f32: &'a ModeInsertions, - global_denormal_f16f64: &'a ModeInsertions, - global_rounding_f32: &'a ModeInsertions, - global_rounding_f16f64: &'a ModeInsertions, + global_modes: &'a FullModeInsertion, +
current_bb: SpirvWord, + initial_mode: &FullBasicBlockEntryState, ) -> Self { + let (denormal_f32, denormal_f16f64) = RegisterState::new(&initial_mode.denormal); + let (rounding_f32, rounding_f16f64) = RegisterState::new(&initial_mode.rounding); BasicBlockControlState { - global_denormal_f32, - global_denormal_f16f64, - global_rounding_f32, - global_rounding_f16f64, - basic_block: SpirvWord(u32::MAX), - denormal_f32: RegisterState::empty(), - denormal_f16f64: RegisterState::empty(), - foldable_rounding_f32: None, - foldable_rounding_f16f64: None, + global_modes, + denormal_f32, + denormal_f16f64, + rounding_f32, + rounding_f16f64, + current_bb, } } - fn start(&mut self, label: SpirvWord) { - self.denormal_f32 = - RegisterState::new(self.global_denormal_f32.basic_blocks.contains(&label)); - self.denormal_f16f64 = - RegisterState::new(self.global_denormal_f16f64.basic_blocks.contains(&label)); - self.foldable_rounding_f32 = None; - self.foldable_rounding_f16f64 = None; - self.basic_block = label; + fn start( + &mut self, + basic_block: SpirvWord, + statements: &mut Vec, SpirvWord>>, + ) -> Result<(), TranslateError> { + let bb_state = self + .global_modes + .basic_blocks + .get(&basic_block) + .ok_or_else(error_unreachable)?; + + let (denormal_f32, denormal_f16f64) = RegisterState::new(&bb_state.denormal); + self.denormal_f32 = denormal_f32; + self.denormal_f16f64 = denormal_f16f64; + let (rounding_f32, rounding_f16f64) = RegisterState::new(&bb_state.rounding); + self.rounding_f32 = rounding_f32; + self.rounding_f16f64 = rounding_f16f64; + if let Some(prologue) = bb_state.dual_prologue { + statements.push(Statement::Label(prologue)); + let denormal = bb_state.denormal.twin_mode.ok_or_else(error_unreachable)?; + statements.push(Statement::SetMode(ModeRegister::Denormal { + f32: denormal.f32.to_ftz(), + f16f64: denormal.f16f64.to_ftz(), + })); + let rounding = bb_state.rounding.twin_mode.ok_or_else(error_unreachable)?; + 
statements.push(Statement::SetMode(ModeRegister::Rounding { + f32: rounding.f32.to_ast(), + f16f64: rounding.f16f64.to_ast(), + })); + statements.push(Statement::Instruction(ast::Instruction::Bra { + arguments: ast::BraArgs { src: basic_block }, + })); + } + if let Some(prologue) = bb_state.denormal.prologue { + statements.push(Statement::Label(prologue)); + let denormal = bb_state.denormal.twin_mode.ok_or_else(error_unreachable)?; + statements.push(Statement::SetMode(ModeRegister::Denormal { + f32: denormal.f32.to_ftz(), + f16f64: denormal.f16f64.to_ftz(), + })); + statements.push(Statement::Instruction(ast::Instruction::Bra { + arguments: ast::BraArgs { src: basic_block }, + })); + } + if let Some(prologue) = bb_state.rounding.prologue { + statements.push(Statement::Label(prologue)); + let rounding = bb_state.rounding.twin_mode.ok_or_else(error_unreachable)?; + statements.push(Statement::SetMode(ModeRegister::Rounding { + f32: rounding.f32.to_ast(), + f16f64: rounding.f16f64.to_ast(), + })); + statements.push(Statement::Instruction(ast::Instruction::Bra { + arguments: ast::BraArgs { src: basic_block }, + })); + } + Ok(()) } + /* fn add_or_fold_mode_set( &mut self, result: &mut Vec, SpirvWord>>, @@ -552,6 +896,7 @@ impl<'a> BasicBlockControlState<'a> { result.push(Statement::SetMode(ModeRegister::DenormalF32(new_mode))); Some(result.len() - 1) } + */ fn insert( &mut self, @@ -571,46 +916,54 @@ impl<'a> BasicBlockControlState<'a> { result: &mut Vec, SpirvWord>>, mode: Option, ) -> Result<(), TranslateError> { - if let Some(new_mode) = mode { - let register_state = View::get_register(self); - match register_state { - RegisterState::Inherited => { - View::set_register(self, RegisterState::Value(None, new_mode)); - } - RegisterState::Unknown => { - View::set_register( - self, - RegisterState::Value( - Some(self.add_or_fold_mode_set2::(result, new_mode)), - new_mode, - ), - ); - } - RegisterState::Value(_, old_value) => { - if new_mode == old_value { - return Ok(()); - 
} - View::set_register( - self, - RegisterState::Value( - Some(self.add_or_fold_mode_set2::(result, new_mode)), - new_mode, - ), - ); - } - } + fn set_fold_index(bb: &mut BasicBlockControlState, index: Option) { + let mut reg = View::get_register(bb); + reg.last_foldable = index; + View::set_register(bb, reg); } + let new_mode = unwrap_some_or!(mode, return Ok(())); + // if let Some(new_mode) = mode { + let register_state = View::get_register(self); + match register_state.current_value { + Some(old) if old == new_mode => { + set_fold_index::(self, None); + } + _ => match register_state.last_foldable { + // fold successful + Some(index) => { + if let Some(Statement::SetMode(mode_set)) = result.get_mut(index) { + View::set_single_mode(mode_set, new_mode)?; + set_fold_index::(self, None); + } else { + return Err(error_unreachable()); + } + } + // fold failed, insert new instruction + None => { + result.push(Statement::SetMode(View::new_mode( + new_mode, + View::TwinView::get_register(self) + .current_value + .ok_or_else(error_unreachable)?, + ))); + set_fold_index::(self, Some(result.len() - 1)); + } + }, + } + //} Ok(()) } // Return the index of the last insertion of SetMode with this mode + /* fn add_or_fold_mode_set2( &self, result: &mut Vec, SpirvWord>>, new_mode: View::Value, - ) -> usize { - // try and fold into the other mode set in struction - if let RegisterState::Value(Some(twin_index), _) = View::TwinView::get_register(self) { + ) -> Result<(), TranslateError> { + // try and fold into the other mode set instruction + View::get_register(bb) + if let RegisterState { last_foldable: } = View::TwinView::get_register(self) { if let Some(Statement::SetMode(register_mode)) = result.get_mut(twin_index) { if let Some(twin_mode) = View::TwinView::get_single_mode(register_mode) { *register_mode = View::new_mode(new_mode, Some(twin_mode)); @@ -621,6 +974,58 @@ impl<'a> BasicBlockControlState<'a> { result.push(Statement::SetMode(View::new_mode(new_mode, None))); 
result.len() - 1 } + */ + + fn redirect_jump(&self, jump_target: &mut SpirvWord) -> Result<(), TranslateError> { + let target = self + .global_modes + .basic_blocks + .get(jump_target) + .ok_or_else(error_unreachable)?; + let jump_to_denormal = match ( + self.denormal_f32.current_value, + self.denormal_f16f64.current_value, + ) { + (None, None) => false, + (Some(current_f32), Some(current_f16f64)) => { + if let Some(target_mode) = target.denormal.twin_mode { + target_mode.f32.to_ftz() != current_f32 + || target_mode.f16f64.to_ftz() != current_f16f64 + } else { + false + } + } + _ => return Err(error_unreachable()), + }; + let jump_to_rounding = match ( + self.rounding_f32.current_value, + self.rounding_f16f64.current_value, + ) { + (None, None) => false, + (Some(current_f32), Some(current_f16f64)) => { + if let Some(target_mode) = target.rounding.twin_mode { + target_mode.f32.to_ast() != current_f32 + || target_mode.f16f64.to_ast() != current_f16f64 + } else { + false + } + } + _ => return Err(error_unreachable()), + }; + match (jump_to_denormal, jump_to_rounding) { + (true, false) => { + *jump_target = target.denormal.prologue.ok_or_else(error_unreachable)?; + } + (false, true) => { + *jump_target = target.rounding.prologue.ok_or_else(error_unreachable)?; + } + (true, true) => { + *jump_target = target.dual_prologue.ok_or_else(error_unreachable)?; + } + (false, false) => {} + } + Ok(()) + } } trait ModeView { @@ -629,8 +1034,8 @@ trait ModeView { fn get_register(bb: &BasicBlockControlState) -> RegisterState; fn set_register(bb: &mut BasicBlockControlState, reg: RegisterState); - fn new_mode(t: Self::Value, other: Option) -> ModeRegister; - fn get_single_mode(reg: &ModeRegister) -> Option; + fn new_mode(t: Self::Value, other: Self::Value) -> ModeRegister; + fn set_single_mode(reg: &mut ModeRegister, x: Self::Value) -> Result<(), TranslateError>; } struct DenormalF32View; @@ -647,18 +1052,16 @@ impl ModeView for DenormalF32View { bb.denormal_f32 = reg; } - fn 
new_mode(f32: Self::Value, f16f64: Option) -> ModeRegister { - match f16f64 { - Some(f16f64) => ModeRegister::DenormalBoth { f32, f16f64 }, - None => ModeRegister::DenormalF32(f32), - } + fn new_mode(f32: Self::Value, f16f64: Self::Value) -> ModeRegister { + ModeRegister::Denormal { f32, f16f64 } } - fn get_single_mode(reg: &ModeRegister) -> Option { + fn set_single_mode(reg: &mut ModeRegister, x: Self::Value) -> Result<(), TranslateError> { match reg { - ModeRegister::DenormalF32(value) => Some(*value), - _ => None, + ModeRegister::Denormal { f32, f16f64: _ } => *f32 = x, + ModeRegister::Rounding { .. } => return Err(error_unreachable()), } + Ok(()) } } @@ -676,18 +1079,16 @@ impl ModeView for DenormalF16F64View { bb.denormal_f16f64 = reg; } - fn new_mode(f16f64: Self::Value, f32: Option) -> ModeRegister { - match f32 { - Some(f32) => ModeRegister::DenormalBoth { f16f64, f32 }, - None => ModeRegister::DenormalF16F64(f16f64), - } + fn new_mode(f16f64: Self::Value, f32: Self::Value) -> ModeRegister { + ModeRegister::Denormal { f32, f16f64 } } - fn get_single_mode(reg: &ModeRegister) -> Option { + fn set_single_mode(reg: &mut ModeRegister, x: Self::Value) -> Result<(), TranslateError> { match reg { - ModeRegister::DenormalF16F64(value) => Some(*value), - _ => None, + ModeRegister::Denormal { f32: _, f16f64 } => *f16f64 = x, + ModeRegister::Rounding { .. 
} => return Err(error_unreachable()), } + Ok(()) } } @@ -709,7 +1110,9 @@ impl<'a> BasicBlockState<'a> { ) -> Result< ( BasicBlockState<'a>, - impl Iterator, SpirvWord>>, + std::iter::Peekable< + impl Iterator, SpirvWord>>, + >, ), TranslateError, > { @@ -732,7 +1135,7 @@ impl<'a> BasicBlockState<'a> { } _ => return Err(error_unreachable()), }; - Ok((bb_state, body_iter)) + Ok((bb_state, body_iter.peekable())) } fn start(&mut self, label: SpirvWord) { @@ -740,11 +1143,11 @@ impl<'a> BasicBlockState<'a> { self.node_index = Some(self.cfg.get_or_add_basic_block(label)); } - fn end(&mut self, jumps: &[SpirvWord]) { + fn end(&mut self, jumps: &[SpirvWord]) -> Option { let node_index = self.node_index.take(); let node_index = match node_index { Some(x) => x, - None => return, + None => return None, }; for target in jumps { self.cfg.add_jump(node_index, *target); @@ -754,6 +1157,30 @@ impl<'a> BasicBlockState<'a> { mem::replace(&mut self.entry, InstructionModes::none()), mem::replace(&mut self.exit, InstructionModes::none()), ); + Some(node_index) + } + + fn record_call( + &mut self, + fn_call: SpirvWord, + after_call_label: SpirvWord, + ) -> Result<(), TranslateError> { + let node_index = self.node_index.ok_or_else(error_unreachable)?; + let after_call_label = self.cfg.add_jump(node_index, after_call_label); + let call_returns = self + .cfg + .call_returns + .entry(fn_call) + .or_insert_with(|| Vec::new()); + call_returns.push(after_call_label); + Ok(()) + } + + fn record_ret(&mut self, fn_name: SpirvWord) -> Result<(), TranslateError> { + let node_index = self.node_index.ok_or_else(error_unreachable)?; + let function_rets = self.cfg.function_rets.entry(fn_name).or_insert(Vec::new()); + function_rets.push(node_index); + Ok(()) } fn append(&mut self, modes: InstructionModes) { @@ -838,7 +1265,7 @@ struct PartialModeInsertion { fn optimize + strum::VariantArray + std::fmt::Debug, const N: usize>( partial: PartialModeInsertion, -) -> ModeInsertions { +) -> 
MandatoryModeInsertions { let mut problem = Problem::new(OptimizationDirection::Maximize); let mut kernel_modes = FxHashMap::default(); let basic_block_variables = partial @@ -875,7 +1302,7 @@ fn optimize + strum::VariantArray + std::fmt::Debug, const } } } - ModeInsertions { + MandatoryModeInsertions { basic_blocks, kernels, } @@ -908,7 +1335,7 @@ fn one_of(problem: &mut Problem) -> [Variable; N] { result } -struct ModeInsertions { +struct MandatoryModeInsertions { basic_blocks: FxHashSet, kernels: FxHashMap, } @@ -1122,167 +1549,4 @@ fn get_modes(inst: &ast::Instruction) -> InstructionModes { } #[cfg(test)] -mod tests { - use super::*; - use int_enum::IntEnum; - use strum::EnumCount; - - #[repr(usize)] - #[derive(IntEnum, Eq, PartialEq, Copy, Clone, Debug)] - enum Bool { - False = 0, - True = 1, - } - - fn ftz() -> InstructionModes { - InstructionModes { - denormal_f32: Some(DenormalMode::FlushToZero), - denormal_f16f64: None, - rounding_f32: None, - rounding_f16f64: None, - } - } - - fn preserve() -> InstructionModes { - InstructionModes { - denormal_f32: Some(DenormalMode::Preserve), - denormal_f16f64: None, - rounding_f32: None, - rounding_f16f64: None, - } - } - - #[test] - fn transitive_mixed() { - let mut graph = ControlFlowGraph::new(); - let entry_id = SpirvWord(1); - let false_id = SpirvWord(2); - let empty_id = SpirvWord(3); - let false2_id = SpirvWord(4); - let entry = graph.add_entry_basic_block(entry_id); - graph.add_jump(entry, false_id); - let false_ = graph.get_or_add_basic_block(false_id); - graph.set_modes(false_, ftz(), ftz()); - graph.add_jump(false_, empty_id); - let empty = graph.get_or_add_basic_block(empty_id); - graph.add_jump(empty, false2_id); - let false2_ = graph.get_or_add_basic_block(false2_id); - graph.set_modes(false2_, ftz(), ftz()); - let partial_result = super::compute_single_mode(&graph, |node| node.denormal_f32); - assert_eq!(partial_result.bb_must_insert_mode.len(), 0); - assert_eq!(partial_result.bb_maybe_insert_mode.len(), 
1); - assert_eq!( - partial_result.bb_maybe_insert_mode[&false_id], - (DenormalMode::FlushToZero, iter::once(entry_id).collect()) - ); - - let result = optimize::(partial_result); - assert_eq!(result.basic_blocks.len(), 0); - assert_eq!(result.kernels.len(), 1); - assert_eq!(result.kernels[&entry_id], DenormalMode::FlushToZero); - } - - #[test] - fn transitive_change_twice() { - let mut graph = ControlFlowGraph::new(); - let entry_id = SpirvWord(1); - let false_id = SpirvWord(2); - let empty_id = SpirvWord(3); - let true_id = SpirvWord(4); - let entry = graph.add_entry_basic_block(entry_id); - graph.add_jump(entry, false_id); - let false_ = graph.get_or_add_basic_block(false_id); - graph.set_modes(false_, ftz(), ftz()); - graph.add_jump(false_, empty_id); - let empty = graph.get_or_add_basic_block(empty_id); - graph.add_jump(empty, true_id); - let true_ = graph.get_or_add_basic_block(true_id); - graph.set_modes(true_, preserve(), preserve()); - let partial_result = super::compute_single_mode(&graph, |node| node.denormal_f32); - assert_eq!(partial_result.bb_must_insert_mode.len(), 1); - assert!(partial_result.bb_must_insert_mode.contains(&true_id)); - assert_eq!(partial_result.bb_maybe_insert_mode.len(), 1); - assert_eq!( - partial_result.bb_maybe_insert_mode[&false_id], - (DenormalMode::FlushToZero, iter::once(entry_id).collect()) - ); - - let result = optimize::(partial_result); - assert_eq!(result.basic_blocks, iter::once(true_id).collect()); - assert_eq!(result.kernels.len(), 1); - assert_eq!(result.kernels[&entry_id], DenormalMode::FlushToZero); - } - - #[test] - fn transitive_change() { - let mut graph = ControlFlowGraph::new(); - let entry_id = SpirvWord(1); - let empty_id = SpirvWord(2); - let true_id = SpirvWord(3); - let entry = graph.add_entry_basic_block(entry_id); - graph.add_jump(entry, empty_id); - let empty = graph.get_or_add_basic_block(empty_id); - graph.add_jump(empty, true_id); - let true_ = graph.get_or_add_basic_block(true_id); - 
graph.set_modes(true_, preserve(), preserve()); - let partial_result = super::compute_single_mode(&graph, |node| node.denormal_f32); - assert_eq!(partial_result.bb_must_insert_mode.len(), 0); - assert_eq!(partial_result.bb_maybe_insert_mode.len(), 1); - assert_eq!( - partial_result.bb_maybe_insert_mode[&true_id], - (DenormalMode::Preserve, iter::once(entry_id).collect()) - ); - - let result = optimize::(partial_result); - assert_eq!(result.basic_blocks.len(), 0); - assert_eq!(result.kernels.len(), 1); - assert_eq!(result.kernels[&entry_id], DenormalMode::Preserve); - } - - #[test] - fn codependency() { - let mut graph = ControlFlowGraph::new(); - let entry_id = SpirvWord(1); - let left_f_id = SpirvWord(2); - let right_f_id = SpirvWord(3); - let left_none_id = SpirvWord(4); - let mid_none_id = SpirvWord(5); - let right_none_id = SpirvWord(6); - let entry = graph.add_entry_basic_block(entry_id); - graph.add_jump(entry, left_f_id); - graph.add_jump(entry, right_f_id); - let left_f = graph.get_or_add_basic_block(left_f_id); - graph.set_modes(left_f, ftz(), ftz()); - let right_f = graph.get_or_add_basic_block(right_f_id); - graph.set_modes(right_f, ftz(), ftz()); - graph.add_jump(left_f, left_none_id); - let left_none = graph.get_or_add_basic_block(left_none_id); - graph.add_jump(right_f, right_none_id); - let right_none = graph.get_or_add_basic_block(right_none_id); - graph.add_jump(left_none, mid_none_id); - graph.add_jump(right_none, mid_none_id); - let mid_none = graph.get_or_add_basic_block(mid_none_id); - graph.add_jump(mid_none, left_none_id); - graph.add_jump(mid_none, right_none_id); - //println!( - // "{:?}", - // petgraph::dot::Dot::with_config(&graph.graph, &[petgraph::dot::Config::EdgeNoLabel]) - //); - let partial_result = super::compute_single_mode(&graph, |node| node.denormal_f32); - assert_eq!(partial_result.bb_must_insert_mode.len(), 0); - assert_eq!(partial_result.bb_maybe_insert_mode.len(), 2); - assert_eq!( - 
partial_result.bb_maybe_insert_mode[&left_f_id], - (DenormalMode::FlushToZero, iter::once(entry_id).collect()) - ); - assert_eq!( - partial_result.bb_maybe_insert_mode[&right_f_id], - (DenormalMode::FlushToZero, iter::once(entry_id).collect()) - ); - - let result = optimize::(partial_result); - assert_eq!(result.basic_blocks.len(), 0); - assert_eq!(result.kernels.len(), 1); - assert_eq!(result.kernels[&entry_id], DenormalMode::FlushToZero); - } -} +mod test; \ No newline at end of file diff --git a/ptx/src/pass/insert_ftz_control/test.rs b/ptx/src/pass/insert_ftz_control/test.rs new file mode 100644 index 0000000..33ab51c --- /dev/null +++ b/ptx/src/pass/insert_ftz_control/test.rs @@ -0,0 +1,230 @@ +use super::*; +use int_enum::IntEnum; +use strum::EnumCount; + +#[repr(usize)] +#[derive(IntEnum, Eq, PartialEq, Copy, Clone, Debug)] +enum Bool { + False = 0, + True = 1, +} + +fn ftz() -> InstructionModes { + InstructionModes { + denormal_f32: Some(DenormalMode::FlushToZero), + denormal_f16f64: None, + rounding_f32: None, + rounding_f16f64: None, + } +} + +fn preserve() -> InstructionModes { + InstructionModes { + denormal_f32: Some(DenormalMode::Preserve), + denormal_f16f64: None, + rounding_f32: None, + rounding_f16f64: None, + } +} + +#[test] +fn transitive_mixed() { + let mut graph = ControlFlowGraph::new(); + let entry_id = SpirvWord(1); + let false_id = SpirvWord(2); + let empty_id = SpirvWord(3); + let false2_id = SpirvWord(4); + let entry = graph.add_entry_basic_block(entry_id); + graph.add_jump(entry, false_id); + let false_ = graph.get_or_add_basic_block(false_id); + graph.set_modes(false_, ftz(), ftz()); + graph.add_jump(false_, empty_id); + let empty = graph.get_or_add_basic_block(empty_id); + graph.add_jump(empty, false2_id); + let false2_ = graph.get_or_add_basic_block(false2_id); + graph.set_modes(false2_, ftz(), ftz()); + let partial_result = super::compute_single_mode(&graph, |node| node.denormal_f32); + 
assert_eq!(partial_result.bb_must_insert_mode.len(), 0); + assert_eq!(partial_result.bb_maybe_insert_mode.len(), 1); + assert_eq!( + partial_result.bb_maybe_insert_mode[&false_id], + (DenormalMode::FlushToZero, iter::once(entry_id).collect()) + ); + + let result = optimize::(partial_result); + assert_eq!(result.basic_blocks.len(), 0); + assert_eq!(result.kernels.len(), 1); + assert_eq!(result.kernels[&entry_id], DenormalMode::FlushToZero); +} + +#[test] +fn transitive_change_twice() { + let mut graph = ControlFlowGraph::new(); + let entry_id = SpirvWord(1); + let false_id = SpirvWord(2); + let empty_id = SpirvWord(3); + let true_id = SpirvWord(4); + let entry = graph.add_entry_basic_block(entry_id); + graph.add_jump(entry, false_id); + let false_ = graph.get_or_add_basic_block(false_id); + graph.set_modes(false_, ftz(), ftz()); + graph.add_jump(false_, empty_id); + let empty = graph.get_or_add_basic_block(empty_id); + graph.add_jump(empty, true_id); + let true_ = graph.get_or_add_basic_block(true_id); + graph.set_modes(true_, preserve(), preserve()); + let partial_result = super::compute_single_mode(&graph, |node| node.denormal_f32); + assert_eq!(partial_result.bb_must_insert_mode.len(), 1); + assert!(partial_result.bb_must_insert_mode.contains(&true_id)); + assert_eq!(partial_result.bb_maybe_insert_mode.len(), 1); + assert_eq!( + partial_result.bb_maybe_insert_mode[&false_id], + (DenormalMode::FlushToZero, iter::once(entry_id).collect()) + ); + + let result = optimize::(partial_result); + assert_eq!(result.basic_blocks, iter::once(true_id).collect()); + assert_eq!(result.kernels.len(), 1); + assert_eq!(result.kernels[&entry_id], DenormalMode::FlushToZero); +} + +#[test] +fn transitive_change() { + let mut graph = ControlFlowGraph::new(); + let entry_id = SpirvWord(1); + let empty_id = SpirvWord(2); + let true_id = SpirvWord(3); + let entry = graph.add_entry_basic_block(entry_id); + graph.add_jump(entry, empty_id); + let empty = 
graph.get_or_add_basic_block(empty_id); + graph.add_jump(empty, true_id); + let true_ = graph.get_or_add_basic_block(true_id); + graph.set_modes(true_, preserve(), preserve()); + let partial_result = super::compute_single_mode(&graph, |node| node.denormal_f32); + assert_eq!(partial_result.bb_must_insert_mode.len(), 0); + assert_eq!(partial_result.bb_maybe_insert_mode.len(), 1); + assert_eq!( + partial_result.bb_maybe_insert_mode[&true_id], + (DenormalMode::Preserve, iter::once(entry_id).collect()) + ); + + let result = optimize::(partial_result); + assert_eq!(result.basic_blocks.len(), 0); + assert_eq!(result.kernels.len(), 1); + assert_eq!(result.kernels[&entry_id], DenormalMode::Preserve); +} + +#[test] +fn codependency() { + let mut graph = ControlFlowGraph::new(); + let entry_id = SpirvWord(1); + let left_f_id = SpirvWord(2); + let right_f_id = SpirvWord(3); + let left_none_id = SpirvWord(4); + let mid_none_id = SpirvWord(5); + let right_none_id = SpirvWord(6); + let entry = graph.add_entry_basic_block(entry_id); + graph.add_jump(entry, left_f_id); + graph.add_jump(entry, right_f_id); + let left_f = graph.get_or_add_basic_block(left_f_id); + graph.set_modes(left_f, ftz(), ftz()); + let right_f = graph.get_or_add_basic_block(right_f_id); + graph.set_modes(right_f, ftz(), ftz()); + graph.add_jump(left_f, left_none_id); + let left_none = graph.get_or_add_basic_block(left_none_id); + graph.add_jump(right_f, right_none_id); + let right_none = graph.get_or_add_basic_block(right_none_id); + graph.add_jump(left_none, mid_none_id); + graph.add_jump(right_none, mid_none_id); + let mid_none = graph.get_or_add_basic_block(mid_none_id); + graph.add_jump(mid_none, left_none_id); + graph.add_jump(mid_none, right_none_id); + //println!( + // "{:?}", + // petgraph::dot::Dot::with_config(&graph.graph, &[petgraph::dot::Config::EdgeNoLabel]) + //); + let partial_result = super::compute_single_mode(&graph, |node| node.denormal_f32); + 
assert_eq!(partial_result.bb_must_insert_mode.len(), 0); + assert_eq!(partial_result.bb_maybe_insert_mode.len(), 2); + assert_eq!( + partial_result.bb_maybe_insert_mode[&left_f_id], + (DenormalMode::FlushToZero, iter::once(entry_id).collect()) + ); + assert_eq!( + partial_result.bb_maybe_insert_mode[&right_f_id], + (DenormalMode::FlushToZero, iter::once(entry_id).collect()) + ); + + let result = optimize::(partial_result); + assert_eq!(result.basic_blocks.len(), 0); + assert_eq!(result.kernels.len(), 1); + assert_eq!(result.kernels[&entry_id], DenormalMode::FlushToZero); +} + +static FOLD_DENORMAL_PTX: &'static str = include_str!("fold_denormal.ptx"); + +#[test] +fn fold_denormal() { + let method = compile_methods(FOLD_DENORMAL_PTX).pop().unwrap(); + assert_eq!(true, method.flush_to_zero_f32); + assert_eq!(true, method.flush_to_zero_f16f64); + let method_body = method.body.unwrap(); + assert!(matches!( + &*method_body, + [ + Statement::Label(..), + Statement::Variable(..), + Statement::Variable(..), + Statement::Variable(..), + Statement::Instruction(ast::Instruction::Add { .. }), + Statement::Instruction(ast::Instruction::Add { .. }), + Statement::SetMode(ModeRegister::Denormal { + f32: false, + f16f64: false + }), + Statement::Instruction(ast::Instruction::Add { .. }), + Statement::Instruction(ast::Instruction::Add { .. }), + Statement::Instruction(ast::Instruction::Ret { .. 
}), + ] + )); +} + +fn compile_methods(ptx: &str) -> Vec, SpirvWord>> { + use crate::pass::*; + + let module = ptx_parser::parse_module_checked(ptx).unwrap(); + let mut flat_resolver = GlobalStringIdentResolver2::new(SpirvWord(1)); + let mut scoped_resolver = ScopedResolver::new(&mut flat_resolver); + let directives = normalize_identifiers2::run(&mut scoped_resolver, module.directives).unwrap(); + let directives = normalize_predicates2::run(&mut flat_resolver, directives).unwrap(); + let directives = expand_operands::run(&mut flat_resolver, directives).unwrap(); + let directives = normalize_basic_blocks::run(&mut flat_resolver, directives); + let directives = super::run(&mut flat_resolver, directives).unwrap(); + directives + .into_iter() + .filter_map(|s| match s { + Directive2::Method(m) => Some(m), + _ => None, + }) + .collect::>() +} + +static CALL_WITH_MODE_PTX: &'static str = include_str!("call_with_mode.ptx"); + +#[test] +fn call_with_mode() { + let methods = compile_methods(CALL_WITH_MODE_PTX); + assert!(matches!(methods[0].body, None)); + + assert!(matches!( + &**methods[1].body.as_ref().unwrap(), + [ + Statement::Label(..), + Statement::SetMode(ModeRegister::Denormal { + f32: false, + f16f64: false + }), + Statement::Instruction(ast::Instruction::Ret { .. 
}), + ] + )); +} diff --git a/ptx/src/pass/mod.rs b/ptx/src/pass/mod.rs index 1a094fb..d6e9aa4 100644 --- a/ptx/src/pass/mod.rs +++ b/ptx/src/pass/mod.rs @@ -202,16 +202,14 @@ enum Statement { SetMode(ModeRegister), } +#[derive(Eq, PartialEq, Clone, Copy)] +#[cfg_attr(test, derive(Debug))] enum ModeRegister { - DenormalF32(bool), - DenormalF16F64(bool), - DenormalBoth { + Denormal { f32: bool, f16f64: bool, }, - RoundingF32(ast::RoundingMode), - RoundingF16F64(ast::RoundingMode), - RoundingBoth { + Rounding { f32: ast::RoundingMode, f16f64: ast::RoundingMode, }, @@ -594,8 +592,8 @@ struct Function2 { linkage: ast::LinkingDirective, flush_to_zero_f32: bool, flush_to_zero_f16f64: bool, - roundind_mode_f32: ast::RoundingMode, - roundind_mode_f16f64: ast::RoundingMode, + rounding_mode_f32: ast::RoundingMode, + rounding_mode_f16f64: ast::RoundingMode, } type NormalizedDirective2 = Directive2< diff --git a/ptx/src/pass/normalize_basic_blocks.rs b/ptx/src/pass/normalize_basic_blocks.rs index c87a8ad..1591bad 100644 --- a/ptx/src/pass/normalize_basic_blocks.rs +++ b/ptx/src/pass/normalize_basic_blocks.rs @@ -1,7 +1,7 @@ use super::*; // This pass normalized ptx modules in two ways that makes mode computation pass -// and code emissions passes much simpler: +// and code emissions passes much simpler: // * Inserts label at the start of every function // This makes control flow graph simpler in mode computation block: we can // represent kernels as separate nodes with its own separate entry/exit mode @@ -46,6 +46,9 @@ fn is_block_terminator(instruction: &Statement, Spir match instruction { Statement::Conditional(..) | Statement::Instruction(ast::Instruction::Bra { .. }) + // Normally call is not a terminator, but we treat it as such because it + // makes the instruction modes to global modes pass possible + | Statement::Instruction(ast::Instruction::Call { .. }) | Statement::Instruction(ast::Instruction::Ret { .. 
}) => true, _ => false, } diff --git a/ptx/src/pass/normalize_identifiers2.rs b/ptx/src/pass/normalize_identifiers2.rs index f5ef55c..810ef3e 100644 --- a/ptx/src/pass/normalize_identifiers2.rs +++ b/ptx/src/pass/normalize_identifiers2.rs @@ -57,8 +57,8 @@ fn run_method<'input, 'b>( tuning: method.tuning, flush_to_zero_f32: false, flush_to_zero_f16f64: false, - roundind_mode_f32: ptx_parser::RoundingMode::NearestEven, - roundind_mode_f16f64: ptx_parser::RoundingMode::NearestEven, + rounding_mode_f32: ptx_parser::RoundingMode::NearestEven, + rounding_mode_f16f64: ptx_parser::RoundingMode::NearestEven, }) } diff --git a/ptx/src/pass/normalize_predicates2.rs b/ptx/src/pass/normalize_predicates2.rs index f8be688..ae41021 100644 --- a/ptx/src/pass/normalize_predicates2.rs +++ b/ptx/src/pass/normalize_predicates2.rs @@ -46,8 +46,8 @@ fn run_method<'input>( is_kernel: method.is_kernel, flush_to_zero_f32: method.flush_to_zero_f32, flush_to_zero_f16f64: method.flush_to_zero_f16f64, - roundind_mode_f32: method.roundind_mode_f32, - roundind_mode_f16f64: method.roundind_mode_f16f64, + rounding_mode_f32: method.rounding_mode_f32, + rounding_mode_f16f64: method.rounding_mode_f16f64, }) } diff --git a/ptx/src/pass/replace_instructions_with_function_calls.rs b/ptx/src/pass/replace_instructions_with_function_calls.rs index f54c134..0f9311a 100644 --- a/ptx/src/pass/replace_instructions_with_function_calls.rs +++ b/ptx/src/pass/replace_instructions_with_function_calls.rs @@ -23,8 +23,8 @@ pub(super) fn run<'input>( is_kernel: false, flush_to_zero_f32: false, flush_to_zero_f16f64: false, - roundind_mode_f32: ptx_parser::RoundingMode::NearestEven, - roundind_mode_f16f64: ptx_parser::RoundingMode::NearestEven, + rounding_mode_f32: ptx_parser::RoundingMode::NearestEven, + rounding_mode_f16f64: ptx_parser::RoundingMode::NearestEven, }) }) .collect::>(); diff --git a/ptx_parser/src/ast.rs b/ptx_parser/src/ast.rs index 4d9f23d..55b950a 100644 --- a/ptx_parser/src/ast.rs +++ 
b/ptx_parser/src/ast.rs @@ -1049,7 +1049,7 @@ pub enum LdStQualifier { Release(MemScope), } -#[derive(PartialEq, Eq, Copy, Clone)] +#[derive(PartialEq, Eq, Copy, Clone, Debug)] pub enum RoundingMode { NearestEven, Zero,