Implement most of twin modes and mode-changing jumps

This commit is contained in:
Andrzej Janik 2025-03-08 00:54:31 +00:00
parent 2b65701f02
commit 7bd26aa480
14 changed files with 888 additions and 367 deletions

View file

@ -20,6 +20,8 @@ strum_macros = "0.26"
petgraph = "0.7.1" petgraph = "0.7.1"
microlp = "0.2.10" microlp = "0.2.10"
int-enum = "1.1" int-enum = "1.1"
smallvec = "1.13"
unwrap_or = "1.0.1"
[dev-dependencies] [dev-dependencies]
hip_runtime-sys = { path = "../ext/hip_runtime-sys" } hip_runtime-sys = { path = "../ext/hip_runtime-sys" }

View file

@ -2264,17 +2264,7 @@ impl<'a> MethodEmitContext<'a> {
let intrinsic = c"llvm.amdgcn.s.setreg"; let intrinsic = c"llvm.amdgcn.s.setreg";
let llvm_i32 = get_scalar_type(self.context, ast::ScalarType::B32); let llvm_i32 = get_scalar_type(self.context, ast::ScalarType::B32);
let (hwreg, value) = match mode_reg { let (hwreg, value) = match mode_reg {
ModeRegister::DenormalF32(ftz) => { ModeRegister::Denormal { f32, f16f64 } => {
let (reg, offset, size) = (1, 4, 2u32);
let hwreg = reg | (offset << 6) | ((size - 1) << 11);
(hwreg, if ftz { 0u32 } else { 3 })
}
ModeRegister::DenormalF16F64(ftz) => {
let (reg, offset, size) = (1, 6, 2u32);
let hwreg = reg | (offset << 6) | ((size - 1) << 11);
(hwreg, if ftz { 0 } else { 3 })
}
ModeRegister::DenormalBoth { f32, f16f64 } => {
let (reg, offset, size) = (1, 4, 4u32); let (reg, offset, size) = (1, 4, 4u32);
let hwreg = reg | (offset << 6) | ((size - 1) << 11); let hwreg = reg | (offset << 6) | ((size - 1) << 11);
let f32 = if f32 { 0 } else { 3 }; let f32 = if f32 { 0 } else { 3 };
@ -2282,9 +2272,7 @@ impl<'a> MethodEmitContext<'a> {
let value = f32 | f16f64 << 2; let value = f32 | f16f64 << 2;
(hwreg, value) (hwreg, value)
} }
ModeRegister::RoundingF32(rounding_mode) => todo!(), ModeRegister::Rounding { .. } => todo!(),
ModeRegister::RoundingF16F64(rounding_mode) => todo!(),
ModeRegister::RoundingBoth { f32, f16f64 } => todo!(),
}; };
let hwreg_llvm = unsafe { LLVMConstInt(llvm_i32, hwreg as _, 0) }; let hwreg_llvm = unsafe { LLVMConstInt(llvm_i32, hwreg as _, 0) };
let value_llvm = unsafe { LLVMConstInt(llvm_i32, value as _, 0) }; let value_llvm = unsafe { LLVMConstInt(llvm_i32, value as _, 0) };

View file

@ -51,8 +51,8 @@ fn run_method<'input>(
is_kernel: method.is_kernel, is_kernel: method.is_kernel,
flush_to_zero_f32: method.flush_to_zero_f32, flush_to_zero_f32: method.flush_to_zero_f32,
flush_to_zero_f16f64: method.flush_to_zero_f16f64, flush_to_zero_f16f64: method.flush_to_zero_f16f64,
roundind_mode_f32: method.roundind_mode_f32, rounding_mode_f32: method.rounding_mode_f32,
roundind_mode_f16f64: method.roundind_mode_f16f64, rounding_mode_f16f64: method.rounding_mode_f16f64,
}) })
} }

View file

@ -22,8 +22,8 @@ pub(super) fn run<'a, 'input>(
is_kernel: false, is_kernel: false,
flush_to_zero_f32: false, flush_to_zero_f32: false,
flush_to_zero_f16f64: false, flush_to_zero_f16f64: false,
roundind_mode_f32: ptx_parser::RoundingMode::NearestEven, rounding_mode_f32: ptx_parser::RoundingMode::NearestEven,
roundind_mode_f16f64: ptx_parser::RoundingMode::NearestEven, rounding_mode_f16f64: ptx_parser::RoundingMode::NearestEven,
})); }));
sreg_to_function.insert(sreg, name); sreg_to_function.insert(sreg, name);
}, },

View file

@ -0,0 +1,21 @@
.version 6.5
.target sm_50
.address_size 64
.func use_modes();
.visible .entry kernel()
{
.reg .f32 temp;
add.rz.ftz.f32 temp, temp, temp;
call use_modes;
ret;
}
.func use_modes()
{
.reg .f32 temp;
add.rm.f32 temp, temp, temp;
ret;
}

View file

@ -0,0 +1,15 @@
.version 6.5
.target sm_30
.address_size 64
.visible .entry add()
{
.reg .f32 temp<3>;
add.ftz.f16 temp2, temp1, temp0;
add.ftz.f32 temp2, temp1, temp0;
add.f16 temp2, temp1, temp0;
add.f32 temp2, temp1, temp0;
ret;
}

View file

@ -0,0 +1,230 @@
use super::*;
use int_enum::IntEnum;
use strum::EnumCount;
#[repr(usize)]
#[derive(IntEnum, Eq, PartialEq, Copy, Clone, Debug)]
enum Bool {
False = 0,
True = 1,
}
fn ftz() -> InstructionModes {
InstructionModes {
denormal_f32: Some(DenormalMode::FlushToZero),
denormal_f16f64: None,
rounding_f32: None,
rounding_f16f64: None,
}
}
fn preserve() -> InstructionModes {
InstructionModes {
denormal_f32: Some(DenormalMode::Preserve),
denormal_f16f64: None,
rounding_f32: None,
rounding_f16f64: None,
}
}
#[test]
fn transitive_mixed() {
let mut graph = ControlFlowGraph::new();
let entry_id = SpirvWord(1);
let false_id = SpirvWord(2);
let empty_id = SpirvWord(3);
let false2_id = SpirvWord(4);
let entry = graph.add_entry_basic_block(entry_id);
graph.add_jump(entry, false_id);
let false_ = graph.get_or_add_basic_block(false_id);
graph.set_modes(false_, ftz(), ftz());
graph.add_jump(false_, empty_id);
let empty = graph.get_or_add_basic_block(empty_id);
graph.add_jump(empty, false2_id);
let false2_ = graph.get_or_add_basic_block(false2_id);
graph.set_modes(false2_, ftz(), ftz());
let partial_result = super::compute_single_mode(&graph, |node| node.denormal_f32);
assert_eq!(partial_result.bb_must_insert_mode.len(), 0);
assert_eq!(partial_result.bb_maybe_insert_mode.len(), 1);
assert_eq!(
partial_result.bb_maybe_insert_mode[&false_id],
(DenormalMode::FlushToZero, iter::once(entry_id).collect())
);
let result = optimize::<DenormalMode, { DenormalMode::COUNT }>(partial_result);
assert_eq!(result.basic_blocks.len(), 0);
assert_eq!(result.kernels.len(), 1);
assert_eq!(result.kernels[&entry_id], DenormalMode::FlushToZero);
}
#[test]
fn transitive_change_twice() {
let mut graph = ControlFlowGraph::new();
let entry_id = SpirvWord(1);
let false_id = SpirvWord(2);
let empty_id = SpirvWord(3);
let true_id = SpirvWord(4);
let entry = graph.add_entry_basic_block(entry_id);
graph.add_jump(entry, false_id);
let false_ = graph.get_or_add_basic_block(false_id);
graph.set_modes(false_, ftz(), ftz());
graph.add_jump(false_, empty_id);
let empty = graph.get_or_add_basic_block(empty_id);
graph.add_jump(empty, true_id);
let true_ = graph.get_or_add_basic_block(true_id);
graph.set_modes(true_, preserve(), preserve());
let partial_result = super::compute_single_mode(&graph, |node| node.denormal_f32);
assert_eq!(partial_result.bb_must_insert_mode.len(), 1);
assert!(partial_result.bb_must_insert_mode.contains(&true_id));
assert_eq!(partial_result.bb_maybe_insert_mode.len(), 1);
assert_eq!(
partial_result.bb_maybe_insert_mode[&false_id],
(DenormalMode::FlushToZero, iter::once(entry_id).collect())
);
let result = optimize::<DenormalMode, { DenormalMode::COUNT }>(partial_result);
assert_eq!(result.basic_blocks, iter::once(true_id).collect());
assert_eq!(result.kernels.len(), 1);
assert_eq!(result.kernels[&entry_id], DenormalMode::FlushToZero);
}
#[test]
fn transitive_change() {
let mut graph = ControlFlowGraph::new();
let entry_id = SpirvWord(1);
let empty_id = SpirvWord(2);
let true_id = SpirvWord(3);
let entry = graph.add_entry_basic_block(entry_id);
graph.add_jump(entry, empty_id);
let empty = graph.get_or_add_basic_block(empty_id);
graph.add_jump(empty, true_id);
let true_ = graph.get_or_add_basic_block(true_id);
graph.set_modes(true_, preserve(), preserve());
let partial_result = super::compute_single_mode(&graph, |node| node.denormal_f32);
assert_eq!(partial_result.bb_must_insert_mode.len(), 0);
assert_eq!(partial_result.bb_maybe_insert_mode.len(), 1);
assert_eq!(
partial_result.bb_maybe_insert_mode[&true_id],
(DenormalMode::Preserve, iter::once(entry_id).collect())
);
let result = optimize::<DenormalMode, { DenormalMode::COUNT }>(partial_result);
assert_eq!(result.basic_blocks.len(), 0);
assert_eq!(result.kernels.len(), 1);
assert_eq!(result.kernels[&entry_id], DenormalMode::Preserve);
}
#[test]
fn codependency() {
let mut graph = ControlFlowGraph::new();
let entry_id = SpirvWord(1);
let left_f_id = SpirvWord(2);
let right_f_id = SpirvWord(3);
let left_none_id = SpirvWord(4);
let mid_none_id = SpirvWord(5);
let right_none_id = SpirvWord(6);
let entry = graph.add_entry_basic_block(entry_id);
graph.add_jump(entry, left_f_id);
graph.add_jump(entry, right_f_id);
let left_f = graph.get_or_add_basic_block(left_f_id);
graph.set_modes(left_f, ftz(), ftz());
let right_f = graph.get_or_add_basic_block(right_f_id);
graph.set_modes(right_f, ftz(), ftz());
graph.add_jump(left_f, left_none_id);
let left_none = graph.get_or_add_basic_block(left_none_id);
graph.add_jump(right_f, right_none_id);
let right_none = graph.get_or_add_basic_block(right_none_id);
graph.add_jump(left_none, mid_none_id);
graph.add_jump(right_none, mid_none_id);
let mid_none = graph.get_or_add_basic_block(mid_none_id);
graph.add_jump(mid_none, left_none_id);
graph.add_jump(mid_none, right_none_id);
//println!(
// "{:?}",
// petgraph::dot::Dot::with_config(&graph.graph, &[petgraph::dot::Config::EdgeNoLabel])
//);
let partial_result = super::compute_single_mode(&graph, |node| node.denormal_f32);
assert_eq!(partial_result.bb_must_insert_mode.len(), 0);
assert_eq!(partial_result.bb_maybe_insert_mode.len(), 2);
assert_eq!(
partial_result.bb_maybe_insert_mode[&left_f_id],
(DenormalMode::FlushToZero, iter::once(entry_id).collect())
);
assert_eq!(
partial_result.bb_maybe_insert_mode[&right_f_id],
(DenormalMode::FlushToZero, iter::once(entry_id).collect())
);
let result = optimize::<DenormalMode, { DenormalMode::COUNT }>(partial_result);
assert_eq!(result.basic_blocks.len(), 0);
assert_eq!(result.kernels.len(), 1);
assert_eq!(result.kernels[&entry_id], DenormalMode::FlushToZero);
}
static FOLD_DENORMAL_PTX: &'static str = include_str!("fold_denormal.ptx");
#[test]
fn fold_denormal() {
let method = compile_methods(FOLD_DENORMAL_PTX).pop().unwrap();
assert_eq!(true, method.flush_to_zero_f32);
assert_eq!(true, method.flush_to_zero_f16f64);
let method_body = method.body.unwrap();
assert!(matches!(
&*method_body,
[
Statement::Label(..),
Statement::Variable(..),
Statement::Variable(..),
Statement::Variable(..),
Statement::Instruction(ast::Instruction::Add { .. }),
Statement::Instruction(ast::Instruction::Add { .. }),
Statement::SetMode(ModeRegister::Denormal {
f32: false,
f16f64: false
}),
Statement::Instruction(ast::Instruction::Add { .. }),
Statement::Instruction(ast::Instruction::Add { .. }),
Statement::Instruction(ast::Instruction::Ret { .. }),
]
));
}
fn compile_methods(ptx: &str) -> Vec<Function2<ast::Instruction<SpirvWord>, SpirvWord>> {
use crate::pass::*;
let module = ptx_parser::parse_module_checked(ptx).unwrap();
let mut flat_resolver = GlobalStringIdentResolver2::new(SpirvWord(1));
let mut scoped_resolver = ScopedResolver::new(&mut flat_resolver);
let directives = normalize_identifiers2::run(&mut scoped_resolver, module.directives).unwrap();
let directives = normalize_predicates2::run(&mut flat_resolver, directives).unwrap();
let directives = expand_operands::run(&mut flat_resolver, directives).unwrap();
let directives = normalize_basic_blocks::run(&mut flat_resolver, directives);
let directives = super::run(&mut flat_resolver, directives).unwrap();
directives
.into_iter()
.filter_map(|s| match s {
Directive2::Method(m) => Some(m),
_ => None,
})
.collect::<Vec<_>>()
}
static CALL_WITH_MODE_PTX: &'static str = include_str!("call_with_mode.ptx");
#[test]
fn call_with_mode() {
let methods = compile_methods(CALL_WITH_MODE_PTX);
assert!(matches!(methods[0].body, None));
assert!(matches!(
&**methods[1].body.as_ref().unwrap(),
[
Statement::Label(..),
Statement::SetMode(ModeRegister::Denormal {
f32: false,
f16f64: false
}),
Statement::Instruction(ast::Instruction::Ret { .. }),
]
));
}

View file

@ -202,16 +202,14 @@ enum Statement<I, P: ast::Operand> {
SetMode(ModeRegister), SetMode(ModeRegister),
} }
#[derive(Eq, PartialEq, Clone, Copy)]
#[cfg_attr(test, derive(Debug))]
enum ModeRegister { enum ModeRegister {
DenormalF32(bool), Denormal {
DenormalF16F64(bool),
DenormalBoth {
f32: bool, f32: bool,
f16f64: bool, f16f64: bool,
}, },
RoundingF32(ast::RoundingMode), Rounding {
RoundingF16F64(ast::RoundingMode),
RoundingBoth {
f32: ast::RoundingMode, f32: ast::RoundingMode,
f16f64: ast::RoundingMode, f16f64: ast::RoundingMode,
}, },
@ -594,8 +592,8 @@ struct Function2<Instruction, Operand: ast::Operand> {
linkage: ast::LinkingDirective, linkage: ast::LinkingDirective,
flush_to_zero_f32: bool, flush_to_zero_f32: bool,
flush_to_zero_f16f64: bool, flush_to_zero_f16f64: bool,
roundind_mode_f32: ast::RoundingMode, rounding_mode_f32: ast::RoundingMode,
roundind_mode_f16f64: ast::RoundingMode, rounding_mode_f16f64: ast::RoundingMode,
} }
type NormalizedDirective2 = Directive2< type NormalizedDirective2 = Directive2<

View file

@ -1,7 +1,7 @@
use super::*; use super::*;
// This pass normalized ptx modules in two ways that makes mode computation pass // This pass normalized ptx modules in two ways that makes mode computation pass
// and code emissions passes much simpler: // and code emissions passes much simpler:
// * Inserts label at the start of every function // * Inserts label at the start of every function
// This makes control flow graph simpler in mode computation block: we can // This makes control flow graph simpler in mode computation block: we can
// represent kernels as separate nodes with its own separate entry/exit mode // represent kernels as separate nodes with its own separate entry/exit mode
@ -46,6 +46,9 @@ fn is_block_terminator(instruction: &Statement<ast::Instruction<SpirvWord>, Spir
match instruction { match instruction {
Statement::Conditional(..) Statement::Conditional(..)
| Statement::Instruction(ast::Instruction::Bra { .. }) | Statement::Instruction(ast::Instruction::Bra { .. })
// Normally call is not a terminator, but we treat it as such because it
// makes the instruction modes to global modes pass possible
| Statement::Instruction(ast::Instruction::Call { .. })
| Statement::Instruction(ast::Instruction::Ret { .. }) => true, | Statement::Instruction(ast::Instruction::Ret { .. }) => true,
_ => false, _ => false,
} }

View file

@ -57,8 +57,8 @@ fn run_method<'input, 'b>(
tuning: method.tuning, tuning: method.tuning,
flush_to_zero_f32: false, flush_to_zero_f32: false,
flush_to_zero_f16f64: false, flush_to_zero_f16f64: false,
roundind_mode_f32: ptx_parser::RoundingMode::NearestEven, rounding_mode_f32: ptx_parser::RoundingMode::NearestEven,
roundind_mode_f16f64: ptx_parser::RoundingMode::NearestEven, rounding_mode_f16f64: ptx_parser::RoundingMode::NearestEven,
}) })
} }

View file

@ -46,8 +46,8 @@ fn run_method<'input>(
is_kernel: method.is_kernel, is_kernel: method.is_kernel,
flush_to_zero_f32: method.flush_to_zero_f32, flush_to_zero_f32: method.flush_to_zero_f32,
flush_to_zero_f16f64: method.flush_to_zero_f16f64, flush_to_zero_f16f64: method.flush_to_zero_f16f64,
roundind_mode_f32: method.roundind_mode_f32, rounding_mode_f32: method.rounding_mode_f32,
roundind_mode_f16f64: method.roundind_mode_f16f64, rounding_mode_f16f64: method.rounding_mode_f16f64,
}) })
} }

View file

@ -23,8 +23,8 @@ pub(super) fn run<'input>(
is_kernel: false, is_kernel: false,
flush_to_zero_f32: false, flush_to_zero_f32: false,
flush_to_zero_f16f64: false, flush_to_zero_f16f64: false,
roundind_mode_f32: ptx_parser::RoundingMode::NearestEven, rounding_mode_f32: ptx_parser::RoundingMode::NearestEven,
roundind_mode_f16f64: ptx_parser::RoundingMode::NearestEven, rounding_mode_f16f64: ptx_parser::RoundingMode::NearestEven,
}) })
}) })
.collect::<Vec<_>>(); .collect::<Vec<_>>();

View file

@ -1049,7 +1049,7 @@ pub enum LdStQualifier {
Release(MemScope), Release(MemScope),
} }
#[derive(PartialEq, Eq, Copy, Clone)] #[derive(PartialEq, Eq, Copy, Clone, Debug)]
pub enum RoundingMode { pub enum RoundingMode {
NearestEven, NearestEven,
Zero, Zero,