mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-04-28 13:37:57 +03:00
Implement most of twin modes and mode-changing jumps
This commit is contained in:
parent
2b65701f02
commit
7bd26aa480
14 changed files with 888 additions and 367 deletions
|
@ -20,6 +20,8 @@ strum_macros = "0.26"
|
|||
petgraph = "0.7.1"
|
||||
microlp = "0.2.10"
|
||||
int-enum = "1.1"
|
||||
smallvec = "1.13"
|
||||
unwrap_or = "1.0.1"
|
||||
|
||||
[dev-dependencies]
|
||||
hip_runtime-sys = { path = "../ext/hip_runtime-sys" }
|
||||
|
|
|
@ -2264,17 +2264,7 @@ impl<'a> MethodEmitContext<'a> {
|
|||
let intrinsic = c"llvm.amdgcn.s.setreg";
|
||||
let llvm_i32 = get_scalar_type(self.context, ast::ScalarType::B32);
|
||||
let (hwreg, value) = match mode_reg {
|
||||
ModeRegister::DenormalF32(ftz) => {
|
||||
let (reg, offset, size) = (1, 4, 2u32);
|
||||
let hwreg = reg | (offset << 6) | ((size - 1) << 11);
|
||||
(hwreg, if ftz { 0u32 } else { 3 })
|
||||
}
|
||||
ModeRegister::DenormalF16F64(ftz) => {
|
||||
let (reg, offset, size) = (1, 6, 2u32);
|
||||
let hwreg = reg | (offset << 6) | ((size - 1) << 11);
|
||||
(hwreg, if ftz { 0 } else { 3 })
|
||||
}
|
||||
ModeRegister::DenormalBoth { f32, f16f64 } => {
|
||||
ModeRegister::Denormal { f32, f16f64 } => {
|
||||
let (reg, offset, size) = (1, 4, 4u32);
|
||||
let hwreg = reg | (offset << 6) | ((size - 1) << 11);
|
||||
let f32 = if f32 { 0 } else { 3 };
|
||||
|
@ -2282,9 +2272,7 @@ impl<'a> MethodEmitContext<'a> {
|
|||
let value = f32 | f16f64 << 2;
|
||||
(hwreg, value)
|
||||
}
|
||||
ModeRegister::RoundingF32(rounding_mode) => todo!(),
|
||||
ModeRegister::RoundingF16F64(rounding_mode) => todo!(),
|
||||
ModeRegister::RoundingBoth { f32, f16f64 } => todo!(),
|
||||
ModeRegister::Rounding { .. } => todo!(),
|
||||
};
|
||||
let hwreg_llvm = unsafe { LLVMConstInt(llvm_i32, hwreg as _, 0) };
|
||||
let value_llvm = unsafe { LLVMConstInt(llvm_i32, value as _, 0) };
|
||||
|
|
|
@ -51,8 +51,8 @@ fn run_method<'input>(
|
|||
is_kernel: method.is_kernel,
|
||||
flush_to_zero_f32: method.flush_to_zero_f32,
|
||||
flush_to_zero_f16f64: method.flush_to_zero_f16f64,
|
||||
roundind_mode_f32: method.roundind_mode_f32,
|
||||
roundind_mode_f16f64: method.roundind_mode_f16f64,
|
||||
rounding_mode_f32: method.rounding_mode_f32,
|
||||
rounding_mode_f16f64: method.rounding_mode_f16f64,
|
||||
})
|
||||
}
|
||||
|
||||
|
|
|
@ -22,8 +22,8 @@ pub(super) fn run<'a, 'input>(
|
|||
is_kernel: false,
|
||||
flush_to_zero_f32: false,
|
||||
flush_to_zero_f16f64: false,
|
||||
roundind_mode_f32: ptx_parser::RoundingMode::NearestEven,
|
||||
roundind_mode_f16f64: ptx_parser::RoundingMode::NearestEven,
|
||||
rounding_mode_f32: ptx_parser::RoundingMode::NearestEven,
|
||||
rounding_mode_f16f64: ptx_parser::RoundingMode::NearestEven,
|
||||
}));
|
||||
sreg_to_function.insert(sreg, name);
|
||||
},
|
||||
|
|
21
ptx/src/pass/insert_ftz_control/call_with_mode.ptx
Normal file
21
ptx/src/pass/insert_ftz_control/call_with_mode.ptx
Normal file
|
@ -0,0 +1,21 @@
|
|||
.version 6.5
|
||||
.target sm_50
|
||||
.address_size 64
|
||||
|
||||
.func use_modes();
|
||||
|
||||
// Kernel entry of the call_with_mode fixture: a single f32 add carrying the
// .rz and .ftz modifiers, followed by a call to use_modes and a ret.
.visible .entry kernel()
|
||||
{
|
||||
.reg .f32 temp;
|
||||
|
||||
add.rz.ftz.f32 temp, temp, temp;
|
||||
call use_modes;
|
||||
|
||||
ret;
|
||||
}
|
||||
|
||||
// Callee of the fixture kernel: its only arithmetic instruction is an f32 add
// with the .rm modifier and without .ftz, i.e. different modifiers than the
// add in the calling kernel.
.func use_modes()
|
||||
{
|
||||
.reg .f32 temp;
|
||||
|
||||
add.rm.f32 temp, temp, temp;
|
||||
ret;
|
||||
}
|
15
ptx/src/pass/insert_ftz_control/fold_denormal.ptx
Normal file
15
ptx/src/pass/insert_ftz_control/fold_denormal.ptx
Normal file
|
@ -0,0 +1,15 @@
|
|||
.version 6.5
|
||||
.target sm_30
|
||||
.address_size 64
|
||||
|
||||
// Fixture kernel for the fold_denormal test: two adds with .ftz (f16, then
// f32) followed by the same two adds without .ftz.
// NOTE(review): the temp registers are declared .f32 but are also used by the
// .f16 adds -- presumably acceptable because the test only runs the parser and
// the early passes; confirm.
.visible .entry add()
|
||||
{
|
||||
.reg .f32 temp<3>;
|
||||
|
||||
add.ftz.f16 temp2, temp1, temp0;
|
||||
add.ftz.f32 temp2, temp1, temp0;
|
||||
|
||||
add.f16 temp2, temp1, temp0;
|
||||
add.f32 temp2, temp1, temp0;
|
||||
ret;
|
||||
}
|
File diff suppressed because it is too large
Load diff
230
ptx/src/pass/insert_ftz_control/test.rs
Normal file
230
ptx/src/pass/insert_ftz_control/test.rs
Normal file
|
@ -0,0 +1,230 @@
|
|||
use super::*;
|
||||
use int_enum::IntEnum;
|
||||
use strum::EnumCount;
|
||||
|
||||
#[repr(usize)]
|
||||
#[derive(IntEnum, Eq, PartialEq, Copy, Clone, Debug)]
|
||||
enum Bool {
|
||||
False = 0,
|
||||
True = 1,
|
||||
}
|
||||
|
||||
fn ftz() -> InstructionModes {
|
||||
InstructionModes {
|
||||
denormal_f32: Some(DenormalMode::FlushToZero),
|
||||
denormal_f16f64: None,
|
||||
rounding_f32: None,
|
||||
rounding_f16f64: None,
|
||||
}
|
||||
}
|
||||
|
||||
fn preserve() -> InstructionModes {
|
||||
InstructionModes {
|
||||
denormal_f32: Some(DenormalMode::Preserve),
|
||||
denormal_f16f64: None,
|
||||
rounding_f32: None,
|
||||
rounding_f16f64: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Linear graph `entry -> ftz -> (no modes) -> ftz`.
///
/// Checks that `compute_single_mode` reports no must-insert block and a single
/// maybe-insert entry (the first ftz block, attributed to the entry kernel),
/// and that `optimize` then assigns `FlushToZero` to the kernel without
/// listing any basic block.
#[test]
fn transitive_mixed() {
    let (entry_id, first_ftz_id, empty_id, second_ftz_id) =
        (SpirvWord(1), SpirvWord(2), SpirvWord(3), SpirvWord(4));
    let mut cfg = ControlFlowGraph::new();

    // Build the chain; the order of graph calls mirrors the block order.
    let entry_bb = cfg.add_entry_basic_block(entry_id);
    cfg.add_jump(entry_bb, first_ftz_id);
    let first_ftz_bb = cfg.get_or_add_basic_block(first_ftz_id);
    cfg.set_modes(first_ftz_bb, ftz(), ftz());
    cfg.add_jump(first_ftz_bb, empty_id);
    let empty_bb = cfg.get_or_add_basic_block(empty_id);
    cfg.add_jump(empty_bb, second_ftz_id);
    let second_ftz_bb = cfg.get_or_add_basic_block(second_ftz_id);
    cfg.set_modes(second_ftz_bb, ftz(), ftz());

    let partial = super::compute_single_mode(&cfg, |node| node.denormal_f32);
    assert_eq!(partial.bb_must_insert_mode.len(), 0);
    assert_eq!(partial.bb_maybe_insert_mode.len(), 1);
    let expected_maybe = (DenormalMode::FlushToZero, iter::once(entry_id).collect());
    assert_eq!(partial.bb_maybe_insert_mode[&first_ftz_id], expected_maybe);

    let optimized = optimize::<DenormalMode, { DenormalMode::COUNT }>(partial);
    assert_eq!(optimized.basic_blocks.len(), 0);
    assert_eq!(optimized.kernels.len(), 1);
    assert_eq!(optimized.kernels[&entry_id], DenormalMode::FlushToZero);
}
|
||||
|
||||
/// Linear graph `entry -> ftz -> (no modes) -> preserve`.
///
/// Checks that the preserve block is the only must-insert block, that the ftz
/// block is the only maybe-insert entry (attributed to the entry kernel), and
/// that `optimize` keeps the preserve block in `basic_blocks` while assigning
/// `FlushToZero` to the kernel.
#[test]
fn transitive_change_twice() {
    let (entry_id, ftz_id, empty_id, preserve_id) =
        (SpirvWord(1), SpirvWord(2), SpirvWord(3), SpirvWord(4));
    let mut cfg = ControlFlowGraph::new();

    let entry_bb = cfg.add_entry_basic_block(entry_id);
    cfg.add_jump(entry_bb, ftz_id);
    let ftz_bb = cfg.get_or_add_basic_block(ftz_id);
    cfg.set_modes(ftz_bb, ftz(), ftz());
    cfg.add_jump(ftz_bb, empty_id);
    let empty_bb = cfg.get_or_add_basic_block(empty_id);
    cfg.add_jump(empty_bb, preserve_id);
    let preserve_bb = cfg.get_or_add_basic_block(preserve_id);
    cfg.set_modes(preserve_bb, preserve(), preserve());

    let partial = super::compute_single_mode(&cfg, |node| node.denormal_f32);
    assert_eq!(partial.bb_must_insert_mode.len(), 1);
    assert!(partial.bb_must_insert_mode.contains(&preserve_id));
    assert_eq!(partial.bb_maybe_insert_mode.len(), 1);
    let expected_maybe = (DenormalMode::FlushToZero, iter::once(entry_id).collect());
    assert_eq!(partial.bb_maybe_insert_mode[&ftz_id], expected_maybe);

    let optimized = optimize::<DenormalMode, { DenormalMode::COUNT }>(partial);
    assert_eq!(optimized.basic_blocks, iter::once(preserve_id).collect());
    assert_eq!(optimized.kernels.len(), 1);
    assert_eq!(optimized.kernels[&entry_id], DenormalMode::FlushToZero);
}
|
||||
|
||||
/// Linear graph `entry -> (no modes) -> preserve`.
///
/// Checks that nothing is a must-insert block, that the preserve block is the
/// only maybe-insert entry (attributed to the entry kernel), and that
/// `optimize` assigns `Preserve` to the kernel without listing any basic block.
#[test]
fn transitive_change() {
    let (entry_id, empty_id, preserve_id) = (SpirvWord(1), SpirvWord(2), SpirvWord(3));
    let mut cfg = ControlFlowGraph::new();

    let entry_bb = cfg.add_entry_basic_block(entry_id);
    cfg.add_jump(entry_bb, empty_id);
    let empty_bb = cfg.get_or_add_basic_block(empty_id);
    cfg.add_jump(empty_bb, preserve_id);
    let preserve_bb = cfg.get_or_add_basic_block(preserve_id);
    cfg.set_modes(preserve_bb, preserve(), preserve());

    let partial = super::compute_single_mode(&cfg, |node| node.denormal_f32);
    assert_eq!(partial.bb_must_insert_mode.len(), 0);
    assert_eq!(partial.bb_maybe_insert_mode.len(), 1);
    let expected_maybe = (DenormalMode::Preserve, iter::once(entry_id).collect());
    assert_eq!(partial.bb_maybe_insert_mode[&preserve_id], expected_maybe);

    let optimized = optimize::<DenormalMode, { DenormalMode::COUNT }>(partial);
    assert_eq!(optimized.basic_blocks.len(), 0);
    assert_eq!(optimized.kernels.len(), 1);
    assert_eq!(optimized.kernels[&entry_id], DenormalMode::Preserve);
}
|
||||
|
||||
/// Graph with two ftz branches off the entry that each lead into mode-less
/// blocks, which in turn form cycles through a shared middle block:
///
/// ```text
/// entry -> left_f  -> left_none  -> mid_none -> left_none
/// entry -> right_f -> right_none -> mid_none -> right_none
/// ```
///
/// Checks that neither ftz block is a must-insert block, that both are
/// maybe-insert entries attributed to the entry kernel, and that `optimize`
/// assigns `FlushToZero` to the kernel without listing any basic block.
#[test]
fn codependency() {
    let entry_id = SpirvWord(1);
    let (left_f_id, right_f_id) = (SpirvWord(2), SpirvWord(3));
    let (left_none_id, mid_none_id, right_none_id) = (SpirvWord(4), SpirvWord(5), SpirvWord(6));
    let mut cfg = ControlFlowGraph::new();

    let entry_bb = cfg.add_entry_basic_block(entry_id);
    cfg.add_jump(entry_bb, left_f_id);
    cfg.add_jump(entry_bb, right_f_id);
    let left_f_bb = cfg.get_or_add_basic_block(left_f_id);
    cfg.set_modes(left_f_bb, ftz(), ftz());
    let right_f_bb = cfg.get_or_add_basic_block(right_f_id);
    cfg.set_modes(right_f_bb, ftz(), ftz());
    cfg.add_jump(left_f_bb, left_none_id);
    let left_none_bb = cfg.get_or_add_basic_block(left_none_id);
    cfg.add_jump(right_f_bb, right_none_id);
    let right_none_bb = cfg.get_or_add_basic_block(right_none_id);
    cfg.add_jump(left_none_bb, mid_none_id);
    cfg.add_jump(right_none_bb, mid_none_id);
    let mid_none_bb = cfg.get_or_add_basic_block(mid_none_id);
    cfg.add_jump(mid_none_bb, left_none_id);
    cfg.add_jump(mid_none_bb, right_none_id);
    // Debugging hint: the graph can be dumped with petgraph::dot::Dot::with_config
    // on the inner graph, using petgraph::dot::Config::EdgeNoLabel.

    let partial = super::compute_single_mode(&cfg, |node| node.denormal_f32);
    assert_eq!(partial.bb_must_insert_mode.len(), 0);
    assert_eq!(partial.bb_maybe_insert_mode.len(), 2);
    for ftz_block_id in [left_f_id, right_f_id] {
        assert_eq!(
            partial.bb_maybe_insert_mode[&ftz_block_id],
            (DenormalMode::FlushToZero, iter::once(entry_id).collect())
        );
    }

    let optimized = optimize::<DenormalMode, { DenormalMode::COUNT }>(partial);
    assert_eq!(optimized.basic_blocks.len(), 0);
    assert_eq!(optimized.kernels.len(), 1);
    assert_eq!(optimized.kernels[&entry_id], DenormalMode::FlushToZero);
}
|
||||
|
||||
static FOLD_DENORMAL_PTX: &'static str = include_str!("fold_denormal.ptx");
|
||||
|
||||
#[test]
|
||||
fn fold_denormal() {
|
||||
let method = compile_methods(FOLD_DENORMAL_PTX).pop().unwrap();
|
||||
assert_eq!(true, method.flush_to_zero_f32);
|
||||
assert_eq!(true, method.flush_to_zero_f16f64);
|
||||
let method_body = method.body.unwrap();
|
||||
assert!(matches!(
|
||||
&*method_body,
|
||||
[
|
||||
Statement::Label(..),
|
||||
Statement::Variable(..),
|
||||
Statement::Variable(..),
|
||||
Statement::Variable(..),
|
||||
Statement::Instruction(ast::Instruction::Add { .. }),
|
||||
Statement::Instruction(ast::Instruction::Add { .. }),
|
||||
Statement::SetMode(ModeRegister::Denormal {
|
||||
f32: false,
|
||||
f16f64: false
|
||||
}),
|
||||
Statement::Instruction(ast::Instruction::Add { .. }),
|
||||
Statement::Instruction(ast::Instruction::Add { .. }),
|
||||
Statement::Instruction(ast::Instruction::Ret { .. }),
|
||||
]
|
||||
));
|
||||
}
|
||||
|
||||
fn compile_methods(ptx: &str) -> Vec<Function2<ast::Instruction<SpirvWord>, SpirvWord>> {
|
||||
use crate::pass::*;
|
||||
|
||||
let module = ptx_parser::parse_module_checked(ptx).unwrap();
|
||||
let mut flat_resolver = GlobalStringIdentResolver2::new(SpirvWord(1));
|
||||
let mut scoped_resolver = ScopedResolver::new(&mut flat_resolver);
|
||||
let directives = normalize_identifiers2::run(&mut scoped_resolver, module.directives).unwrap();
|
||||
let directives = normalize_predicates2::run(&mut flat_resolver, directives).unwrap();
|
||||
let directives = expand_operands::run(&mut flat_resolver, directives).unwrap();
|
||||
let directives = normalize_basic_blocks::run(&mut flat_resolver, directives);
|
||||
let directives = super::run(&mut flat_resolver, directives).unwrap();
|
||||
directives
|
||||
.into_iter()
|
||||
.filter_map(|s| match s {
|
||||
Directive2::Method(m) => Some(m),
|
||||
_ => None,
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
}
|
||||
|
||||
static CALL_WITH_MODE_PTX: &'static str = include_str!("call_with_mode.ptx");
|
||||
|
||||
#[test]
|
||||
fn call_with_mode() {
|
||||
let methods = compile_methods(CALL_WITH_MODE_PTX);
|
||||
assert!(matches!(methods[0].body, None));
|
||||
|
||||
assert!(matches!(
|
||||
&**methods[1].body.as_ref().unwrap(),
|
||||
[
|
||||
Statement::Label(..),
|
||||
Statement::SetMode(ModeRegister::Denormal {
|
||||
f32: false,
|
||||
f16f64: false
|
||||
}),
|
||||
Statement::Instruction(ast::Instruction::Ret { .. }),
|
||||
]
|
||||
));
|
||||
}
|
|
@ -202,16 +202,14 @@ enum Statement<I, P: ast::Operand> {
|
|||
SetMode(ModeRegister),
|
||||
}
|
||||
|
||||
#[derive(Eq, PartialEq, Clone, Copy)]
|
||||
#[cfg_attr(test, derive(Debug))]
|
||||
enum ModeRegister {
|
||||
DenormalF32(bool),
|
||||
DenormalF16F64(bool),
|
||||
DenormalBoth {
|
||||
Denormal {
|
||||
f32: bool,
|
||||
f16f64: bool,
|
||||
},
|
||||
RoundingF32(ast::RoundingMode),
|
||||
RoundingF16F64(ast::RoundingMode),
|
||||
RoundingBoth {
|
||||
Rounding {
|
||||
f32: ast::RoundingMode,
|
||||
f16f64: ast::RoundingMode,
|
||||
},
|
||||
|
@ -594,8 +592,8 @@ struct Function2<Instruction, Operand: ast::Operand> {
|
|||
linkage: ast::LinkingDirective,
|
||||
flush_to_zero_f32: bool,
|
||||
flush_to_zero_f16f64: bool,
|
||||
roundind_mode_f32: ast::RoundingMode,
|
||||
roundind_mode_f16f64: ast::RoundingMode,
|
||||
rounding_mode_f32: ast::RoundingMode,
|
||||
rounding_mode_f16f64: ast::RoundingMode,
|
||||
}
|
||||
|
||||
type NormalizedDirective2 = Directive2<
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
use super::*;
|
||||
|
||||
// This pass normalizes ptx modules in two ways that make the mode computation pass
|
||||
// and code emissions passes much simpler:
|
||||
// and code emissions passes much simpler:
|
||||
// * Inserts label at the start of every function
|
||||
// This makes control flow graph simpler in mode computation block: we can
|
||||
// represent kernels as separate nodes with its own separate entry/exit mode
|
||||
|
@ -46,6 +46,9 @@ fn is_block_terminator(instruction: &Statement<ast::Instruction<SpirvWord>, Spir
|
|||
match instruction {
|
||||
Statement::Conditional(..)
|
||||
| Statement::Instruction(ast::Instruction::Bra { .. })
|
||||
// Normally call is not a terminator, but we treat it as such because it
|
||||
// makes the instruction modes to global modes pass possible
|
||||
| Statement::Instruction(ast::Instruction::Call { .. })
|
||||
| Statement::Instruction(ast::Instruction::Ret { .. }) => true,
|
||||
_ => false,
|
||||
}
|
||||
|
|
|
@ -57,8 +57,8 @@ fn run_method<'input, 'b>(
|
|||
tuning: method.tuning,
|
||||
flush_to_zero_f32: false,
|
||||
flush_to_zero_f16f64: false,
|
||||
roundind_mode_f32: ptx_parser::RoundingMode::NearestEven,
|
||||
roundind_mode_f16f64: ptx_parser::RoundingMode::NearestEven,
|
||||
rounding_mode_f32: ptx_parser::RoundingMode::NearestEven,
|
||||
rounding_mode_f16f64: ptx_parser::RoundingMode::NearestEven,
|
||||
})
|
||||
}
|
||||
|
||||
|
|
|
@ -46,8 +46,8 @@ fn run_method<'input>(
|
|||
is_kernel: method.is_kernel,
|
||||
flush_to_zero_f32: method.flush_to_zero_f32,
|
||||
flush_to_zero_f16f64: method.flush_to_zero_f16f64,
|
||||
roundind_mode_f32: method.roundind_mode_f32,
|
||||
roundind_mode_f16f64: method.roundind_mode_f16f64,
|
||||
rounding_mode_f32: method.rounding_mode_f32,
|
||||
rounding_mode_f16f64: method.rounding_mode_f16f64,
|
||||
})
|
||||
}
|
||||
|
||||
|
|
|
@ -23,8 +23,8 @@ pub(super) fn run<'input>(
|
|||
is_kernel: false,
|
||||
flush_to_zero_f32: false,
|
||||
flush_to_zero_f16f64: false,
|
||||
roundind_mode_f32: ptx_parser::RoundingMode::NearestEven,
|
||||
roundind_mode_f16f64: ptx_parser::RoundingMode::NearestEven,
|
||||
rounding_mode_f32: ptx_parser::RoundingMode::NearestEven,
|
||||
rounding_mode_f16f64: ptx_parser::RoundingMode::NearestEven,
|
||||
})
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
|
|
@ -1049,7 +1049,7 @@ pub enum LdStQualifier {
|
|||
Release(MemScope),
|
||||
}
|
||||
|
||||
#[derive(PartialEq, Eq, Copy, Clone)]
|
||||
#[derive(PartialEq, Eq, Copy, Clone, Debug)]
|
||||
pub enum RoundingMode {
|
||||
NearestEven,
|
||||
Zero,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue