mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-04-28 21:47:57 +03:00
Carry all 4 modes in cfg
This commit is contained in:
parent
5121bba285
commit
aaa31da026
1 changed file with 198 additions and 67 deletions
|
@ -1,11 +1,10 @@
|
||||||
use crate::pass::error_unreachable;
|
|
||||||
|
|
||||||
use super::BrachCondition;
|
use super::BrachCondition;
|
||||||
use super::Directive2;
|
use super::Directive2;
|
||||||
use super::Function2;
|
use super::Function2;
|
||||||
use super::SpirvWord;
|
use super::SpirvWord;
|
||||||
use super::Statement;
|
use super::Statement;
|
||||||
use super::TranslateError;
|
use super::TranslateError;
|
||||||
|
use crate::pass::error_unreachable;
|
||||||
use microlp::OptimizationDirection;
|
use microlp::OptimizationDirection;
|
||||||
use microlp::Problem;
|
use microlp::Problem;
|
||||||
use microlp::Variable;
|
use microlp::Variable;
|
||||||
|
@ -18,8 +17,11 @@ use rustc_hash::FxHashMap;
|
||||||
use rustc_hash::FxHashSet;
|
use rustc_hash::FxHashSet;
|
||||||
use std::hash::Hash;
|
use std::hash::Hash;
|
||||||
use std::iter;
|
use std::iter;
|
||||||
|
use std::mem;
|
||||||
|
use strum::EnumCount;
|
||||||
|
use strum_macros::{EnumCount, VariantArray};
|
||||||
|
|
||||||
#[derive(Default)]
|
#[derive(Default, PartialEq, Eq, Clone, Copy, Debug, VariantArray, EnumCount)]
|
||||||
enum DenormalMode {
|
enum DenormalMode {
|
||||||
#[default]
|
#[default]
|
||||||
FlushToZero,
|
FlushToZero,
|
||||||
|
@ -36,7 +38,13 @@ impl DenormalMode {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Default)]
|
/// Converts a [`DenormalMode`] into its `usize` discriminant.
///
/// Written as `From` rather than `Into`: the standard library's blanket impl
/// still provides `Into<usize> for DenormalMode`, so existing
/// `T: Into<usize>` bounds and `.into()` calls keep working, while
/// `usize::from(mode)` becomes available as well.
impl From<DenormalMode> for usize {
    fn from(mode: DenormalMode) -> usize {
        // Plain enum-to-integer cast: yields the variant's discriminant.
        mode as usize
    }
}
|
||||||
|
|
||||||
|
#[derive(Default, PartialEq, Eq, Clone, Copy, Debug, VariantArray, EnumCount)]
|
||||||
enum RoundingMode {
|
enum RoundingMode {
|
||||||
#[default]
|
#[default]
|
||||||
NearestEven,
|
NearestEven,
|
||||||
|
@ -65,20 +73,49 @@ impl RoundingMode {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Converts a [`RoundingMode`] into its `usize` discriminant.
///
/// Written as `From` rather than `Into`: the standard library's blanket impl
/// still provides `Into<usize> for RoundingMode`, so existing
/// `T: Into<usize>` bounds and `.into()` calls keep working, while
/// `usize::from(mode)` becomes available as well.
impl From<RoundingMode> for usize {
    fn from(mode: RoundingMode) -> usize {
        // Plain enum-to-integer cast: yields the variant's discriminant.
        mode as usize
    }
}
|
||||||
|
|
||||||
struct InstructionModes {
|
struct InstructionModes {
|
||||||
denormal_f32: Option<DenormalMode>,
|
denormal_f32: Option<DenormalMode>,
|
||||||
denormal_f16_f64: Option<DenormalMode>,
|
denormal_f16f64: Option<DenormalMode>,
|
||||||
rounding_f32: Option<RoundingMode>,
|
rounding_f32: Option<RoundingMode>,
|
||||||
rounding_f16_f64: Option<RoundingMode>,
|
rounding_f16f64: Option<RoundingMode>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl InstructionModes {
|
impl InstructionModes {
|
||||||
|
/// Folds the modes requested by a single instruction (`self`) into the
/// running `entry` and `exit` summaries of the basic block being scanned.
///
/// * `entry` records, per mode slot, the *first* mode requested: a slot is
///   written only while it is still `None`.
/// * `exit` records, per mode slot, the *most recent* mode requested: a slot
///   is overwritten every time `self` carries a value for it.
///
/// Slots for which `self` holds `None` leave both summaries untouched.
fn fold_into(self, entry: &mut Self, exit: &mut Self) {
    /// First writer wins: fills `slot` only if nothing has been stored yet.
    fn set_if_none<T: Copy>(slot: &mut Option<T>, value: Option<T>) {
        if slot.is_none() {
            *slot = value;
        }
    }
    /// Last writer wins: stores `value` whenever it is `Some`.
    ///
    /// The earlier version required `slot` to already be `Some` before
    /// overwriting it. Because the accumulators start out as
    /// `InstructionModes::none()`, that condition could never hold and the
    /// exit modes stayed `None` forever.
    fn set_if_some<T: Copy>(slot: &mut Option<T>, value: Option<T>) {
        if value.is_some() {
            *slot = value;
        }
    }
    set_if_none(&mut entry.denormal_f32, self.denormal_f32);
    set_if_none(&mut entry.denormal_f16f64, self.denormal_f16f64);
    set_if_none(&mut entry.rounding_f32, self.rounding_f32);
    set_if_none(&mut entry.rounding_f16f64, self.rounding_f16f64);
    set_if_some(&mut exit.denormal_f32, self.denormal_f32);
    set_if_some(&mut exit.denormal_f16f64, self.denormal_f16f64);
    set_if_some(&mut exit.rounding_f32, self.rounding_f32);
    set_if_some(&mut exit.rounding_f16f64, self.rounding_f16f64);
}
|
||||||
|
|
||||||
fn none() -> Self {
|
fn none() -> Self {
|
||||||
Self {
|
Self {
|
||||||
denormal_f32: None,
|
denormal_f32: None,
|
||||||
denormal_f16_f64: None,
|
denormal_f16f64: None,
|
||||||
rounding_f32: None,
|
rounding_f32: None,
|
||||||
rounding_f16_f64: None,
|
rounding_f16f64: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -89,8 +126,8 @@ impl InstructionModes {
|
||||||
) -> Self {
|
) -> Self {
|
||||||
if type_ != ast::ScalarType::F32 {
|
if type_ != ast::ScalarType::F32 {
|
||||||
Self {
|
Self {
|
||||||
denormal_f16_f64: denormal,
|
denormal_f16f64: denormal,
|
||||||
rounding_f16_f64: rounding,
|
rounding_f16f64: rounding,
|
||||||
..Self::none()
|
..Self::none()
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
@ -109,7 +146,7 @@ impl InstructionModes {
|
||||||
) -> Self {
|
) -> Self {
|
||||||
if type_ != ast::ScalarType::F32 {
|
if type_ != ast::ScalarType::F32 {
|
||||||
Self {
|
Self {
|
||||||
denormal_f16_f64: denormal,
|
denormal_f16f64: denormal,
|
||||||
rounding_f32: rounding,
|
rounding_f32: rounding,
|
||||||
..Self::none()
|
..Self::none()
|
||||||
}
|
}
|
||||||
|
@ -191,13 +228,13 @@ impl InstructionModes {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ControlFlowGraph<T: Eq + PartialEq> {
|
struct ControlFlowGraph {
|
||||||
entry_points: FxHashMap<SpirvWord, NodeIndex>,
|
entry_points: FxHashMap<SpirvWord, NodeIndex>,
|
||||||
basic_blocks: FxHashMap<SpirvWord, NodeIndex>,
|
basic_blocks: FxHashMap<SpirvWord, NodeIndex>,
|
||||||
graph: Graph<Node<ExtendedMode<T>>, ()>,
|
graph: Graph<Node, ()>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T: Eq + PartialEq> ControlFlowGraph<T> {
|
impl ControlFlowGraph {
|
||||||
fn new() -> Self {
|
fn new() -> Self {
|
||||||
Self {
|
Self {
|
||||||
entry_points: FxHashMap::default(),
|
entry_points: FxHashMap::default(),
|
||||||
|
@ -207,22 +244,14 @@ impl<T: Eq + PartialEq> ControlFlowGraph<T> {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn add_entry_basic_block(&mut self, label: SpirvWord) -> NodeIndex {
|
fn add_entry_basic_block(&mut self, label: SpirvWord) -> NodeIndex {
|
||||||
let idx = self.graph.add_node(Node {
|
let idx = self.graph.add_node(Node::entry(label));
|
||||||
label,
|
|
||||||
entry: Some(ExtendedMode::Entry(label)),
|
|
||||||
exit: Some(ExtendedMode::Entry(label)),
|
|
||||||
});
|
|
||||||
assert_eq!(self.entry_points.insert(label, idx), None);
|
assert_eq!(self.entry_points.insert(label, idx), None);
|
||||||
idx
|
idx
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_or_add_basic_block(&mut self, label: SpirvWord) -> NodeIndex {
|
fn get_or_add_basic_block(&mut self, label: SpirvWord) -> NodeIndex {
|
||||||
self.basic_blocks.get(&label).copied().unwrap_or_else(|| {
|
self.basic_blocks.get(&label).copied().unwrap_or_else(|| {
|
||||||
let idx = self.graph.add_node(Node {
|
let idx = self.graph.add_node(Node::new(label));
|
||||||
label,
|
|
||||||
entry: None,
|
|
||||||
exit: None,
|
|
||||||
});
|
|
||||||
self.basic_blocks.insert(label, idx);
|
self.basic_blocks.insert(label, idx);
|
||||||
idx
|
idx
|
||||||
})
|
})
|
||||||
|
@ -233,24 +262,90 @@ impl<T: Eq + PartialEq> ControlFlowGraph<T> {
|
||||||
self.graph.add_edge(from, to, ());
|
self.graph.add_edge(from, to, ());
|
||||||
}
|
}
|
||||||
|
|
||||||
fn set_modes(&mut self, node: NodeIndex, entry: T, exit: T) {
|
fn set_modes(&mut self, node: NodeIndex, entry: InstructionModes, exit: InstructionModes) {
|
||||||
self.graph[node].entry = Some(ExtendedMode::BasicBlock(entry));
|
self.graph[node].denormal_f32 = Mode {
|
||||||
self.graph[node].exit = Some(ExtendedMode::BasicBlock(exit));
|
entry: entry.denormal_f32.map(ExtendedMode::BasicBlock),
|
||||||
|
exit: exit.denormal_f32.map(ExtendedMode::BasicBlock),
|
||||||
|
};
|
||||||
|
self.graph[node].denormal_f16f64 = Mode {
|
||||||
|
entry: entry.denormal_f16f64.map(ExtendedMode::BasicBlock),
|
||||||
|
exit: exit.denormal_f16f64.map(ExtendedMode::BasicBlock),
|
||||||
|
};
|
||||||
|
self.graph[node].rounding_f32 = Mode {
|
||||||
|
entry: entry.rounding_f32.map(ExtendedMode::BasicBlock),
|
||||||
|
exit: exit.rounding_f32.map(ExtendedMode::BasicBlock),
|
||||||
|
};
|
||||||
|
self.graph[node].rounding_f16f64 = Mode {
|
||||||
|
entry: entry.rounding_f16f64.map(ExtendedMode::BasicBlock),
|
||||||
|
exit: exit.rounding_f16f64.map(ExtendedMode::BasicBlock),
|
||||||
|
};
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Clone, Copy)]
|
||||||
struct Node<T> {
|
struct Mode<T: Eq + PartialEq> {
|
||||||
|
entry: Option<ExtendedMode<T>>,
|
||||||
|
exit: Option<ExtendedMode<T>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: Eq + PartialEq> Mode<T> {
|
||||||
|
fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
entry: None,
|
||||||
|
exit: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn entry(label: SpirvWord) -> Self {
|
||||||
|
Self {
|
||||||
|
entry: Some(ExtendedMode::Entry(label)),
|
||||||
|
exit: Some(ExtendedMode::Entry(label)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct Node {
|
||||||
label: SpirvWord,
|
label: SpirvWord,
|
||||||
entry: Option<T>,
|
denormal_f32: Mode<DenormalMode>,
|
||||||
exit: Option<T>,
|
denormal_f16f64: Mode<DenormalMode>,
|
||||||
|
rounding_f32: Mode<RoundingMode>,
|
||||||
|
rounding_f16f64: Mode<RoundingMode>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Node {
|
||||||
|
fn entry(label: SpirvWord) -> Self {
|
||||||
|
Self {
|
||||||
|
label,
|
||||||
|
denormal_f32: Mode::entry(label),
|
||||||
|
denormal_f16f64: Mode::entry(label),
|
||||||
|
rounding_f32: Mode::entry(label),
|
||||||
|
rounding_f16f64: Mode::entry(label),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn new(label: SpirvWord) -> Self {
|
||||||
|
Self {
|
||||||
|
label,
|
||||||
|
denormal_f32: Mode::new(),
|
||||||
|
denormal_f16f64: Mode::new(),
|
||||||
|
rounding_f32: Mode::new(),
|
||||||
|
rounding_f16f64: Mode::new(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
trait EnumTuple {
|
||||||
|
const LENGTH: usize;
|
||||||
|
|
||||||
|
fn get(&self, x: usize) -> u8;
|
||||||
|
fn get_mut(&mut self, x: usize) -> &mut u8;
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn run<'input>(
|
pub(crate) fn run<'input>(
|
||||||
flat_resolver: &mut super::GlobalStringIdentResolver2<'input>,
|
flat_resolver: &mut super::GlobalStringIdentResolver2<'input>,
|
||||||
directives: Vec<super::Directive2<ast::Instruction<SpirvWord>, super::SpirvWord>>,
|
directives: Vec<super::Directive2<ast::Instruction<SpirvWord>, super::SpirvWord>>,
|
||||||
) -> Result<Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
|
) -> Result<Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
|
||||||
let mut cfg = ControlFlowGraph::<bool>::new();
|
let mut cfg = ControlFlowGraph::new();
|
||||||
for directive in directives.iter() {
|
for directive in directives.iter() {
|
||||||
match directive {
|
match directive {
|
||||||
super::Directive2::Method(Function2 {
|
super::Directive2::Method(Function2 {
|
||||||
|
@ -259,11 +354,18 @@ pub(crate) fn run<'input>(
|
||||||
..
|
..
|
||||||
}) => {
|
}) => {
|
||||||
let mut basic_block = Some(cfg.add_entry_basic_block(*name));
|
let mut basic_block = Some(cfg.add_entry_basic_block(*name));
|
||||||
|
let mut entry = InstructionModes::none();
|
||||||
|
let mut exit = InstructionModes::none();
|
||||||
for statement in body.iter() {
|
for statement in body.iter() {
|
||||||
match statement {
|
match statement {
|
||||||
Statement::Instruction(ast::Instruction::Bra { arguments }) => {
|
Statement::Instruction(ast::Instruction::Bra { arguments }) => {
|
||||||
let bb_index = basic_block.ok_or_else(error_unreachable)?;
|
let bb_index = basic_block.ok_or_else(error_unreachable)?;
|
||||||
cfg.add_jump(bb_index, arguments.src);
|
cfg.add_jump(bb_index, arguments.src);
|
||||||
|
cfg.set_modes(
|
||||||
|
bb_index,
|
||||||
|
mem::replace(&mut entry, InstructionModes::none()),
|
||||||
|
mem::replace(&mut exit, InstructionModes::none()),
|
||||||
|
);
|
||||||
basic_block = None;
|
basic_block = None;
|
||||||
}
|
}
|
||||||
Statement::Label(label) => {
|
Statement::Label(label) => {
|
||||||
|
@ -275,22 +377,31 @@ pub(crate) fn run<'input>(
|
||||||
let bb_index = basic_block.ok_or_else(error_unreachable)?;
|
let bb_index = basic_block.ok_or_else(error_unreachable)?;
|
||||||
cfg.add_jump(bb_index, *if_true);
|
cfg.add_jump(bb_index, *if_true);
|
||||||
cfg.add_jump(bb_index, *if_false);
|
cfg.add_jump(bb_index, *if_false);
|
||||||
|
cfg.set_modes(
|
||||||
|
bb_index,
|
||||||
|
mem::replace(&mut entry, InstructionModes::none()),
|
||||||
|
mem::replace(&mut exit, InstructionModes::none()),
|
||||||
|
);
|
||||||
basic_block = None;
|
basic_block = None;
|
||||||
}
|
}
|
||||||
Statement::Instruction(instruction) => {
|
Statement::Instruction(instruction) => {
|
||||||
let modes = get_modes(instruction);
|
let modes = get_modes(instruction);
|
||||||
|
modes.fold_into(&mut entry, &mut exit);
|
||||||
}
|
}
|
||||||
_ => continue,
|
_ => {}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
_ => continue,
|
_ => {}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
todo!()
|
todo!()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn compute<T: Copy + Eq>(graph: ControlFlowGraph<T>) -> PartialModeInsertion<T> {
|
fn compute_single_mode<T: Copy + Eq>(
|
||||||
|
graph: &ControlFlowGraph,
|
||||||
|
mut getter: impl FnMut(&Node) -> Mode<T>,
|
||||||
|
) -> PartialModeInsertion<T> {
|
||||||
let mut must_insert_mode = FxHashSet::<SpirvWord>::default();
|
let mut must_insert_mode = FxHashSet::<SpirvWord>::default();
|
||||||
let mut maybe_insert_mode = FxHashMap::default();
|
let mut maybe_insert_mode = FxHashMap::default();
|
||||||
let mut remaining = graph
|
let mut remaining = graph
|
||||||
|
@ -298,7 +409,8 @@ fn compute<T: Copy + Eq>(graph: ControlFlowGraph<T>) -> PartialModeInsertion<T>
|
||||||
.node_references()
|
.node_references()
|
||||||
.rev()
|
.rev()
|
||||||
.filter_map(|(index, node)| {
|
.filter_map(|(index, node)| {
|
||||||
node.entry
|
getter(node)
|
||||||
|
.entry
|
||||||
.as_ref()
|
.as_ref()
|
||||||
.map(|mode| match mode {
|
.map(|mode| match mode {
|
||||||
ExtendedMode::BasicBlock(mode) => Some((index, node.label, *mode)),
|
ExtendedMode::BasicBlock(mode) => Some((index, node.label, *mode)),
|
||||||
|
@ -316,7 +428,7 @@ fn compute<T: Copy + Eq>(graph: ControlFlowGraph<T>) -> PartialModeInsertion<T>
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
visited.insert(current);
|
visited.insert(current);
|
||||||
let exit_mode = graph.graph.node_weight(current).unwrap().exit;
|
let exit_mode = getter(graph.graph.node_weight(current).unwrap()).exit;
|
||||||
match exit_mode {
|
match exit_mode {
|
||||||
None => {
|
None => {
|
||||||
for predecessor in graph.graph.neighbors_directed(current, Direction::Incoming)
|
for predecessor in graph.graph.neighbors_directed(current, Direction::Incoming)
|
||||||
|
@ -355,7 +467,7 @@ struct PartialModeInsertion<T> {
|
||||||
bb_maybe_insert_mode: FxHashMap<SpirvWord, (T, FxHashSet<SpirvWord>)>,
|
bb_maybe_insert_mode: FxHashMap<SpirvWord, (T, FxHashSet<SpirvWord>)>,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn optimize<T: Copy + Into<usize> + TryFrom<usize> + std::fmt::Debug, const N: usize>(
|
fn optimize<T: Copy + Into<usize> + strum::VariantArray + std::fmt::Debug, const N: usize>(
|
||||||
partial: PartialModeInsertion<T>,
|
partial: PartialModeInsertion<T>,
|
||||||
) -> ModeInsertions<T> {
|
) -> ModeInsertions<T> {
|
||||||
let mut problem = Problem::new(OptimizationDirection::Maximize);
|
let mut problem = Problem::new(OptimizationDirection::Maximize);
|
||||||
|
@ -389,7 +501,7 @@ fn optimize<T: Copy + Into<usize> + TryFrom<usize> + std::fmt::Debug, const N: u
|
||||||
for (kernel, modes) in kernel_modes {
|
for (kernel, modes) in kernel_modes {
|
||||||
for (mode, var) in modes.into_iter().enumerate() {
|
for (mode, var) in modes.into_iter().enumerate() {
|
||||||
if solution[var] > 0.5 {
|
if solution[var] > 0.5 {
|
||||||
kernels.insert(kernel, T::try_from(mode).unwrap_or_else(|_| todo!()));
|
kernels.insert(kernel, T::VARIANTS[mode]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -642,6 +754,7 @@ fn get_modes<T: ast::Operand>(inst: &ast::Instruction<T>) -> InstructionModes {
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use int_enum::IntEnum;
|
use int_enum::IntEnum;
|
||||||
|
use strum::EnumCount;
|
||||||
|
|
||||||
#[repr(usize)]
|
#[repr(usize)]
|
||||||
#[derive(IntEnum, Eq, PartialEq, Copy, Clone, Debug)]
|
#[derive(IntEnum, Eq, PartialEq, Copy, Clone, Debug)]
|
||||||
|
@ -650,9 +763,27 @@ mod tests {
|
||||||
True = 1,
|
True = 1,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn ftz() -> InstructionModes {
|
||||||
|
InstructionModes {
|
||||||
|
denormal_f32: Some(DenormalMode::FlushToZero),
|
||||||
|
denormal_f16f64: None,
|
||||||
|
rounding_f32: None,
|
||||||
|
rounding_f16f64: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn preserve() -> InstructionModes {
|
||||||
|
InstructionModes {
|
||||||
|
denormal_f32: Some(DenormalMode::Preserve),
|
||||||
|
denormal_f16f64: None,
|
||||||
|
rounding_f32: None,
|
||||||
|
rounding_f16f64: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn transitive_mixed() {
|
fn transitive_mixed() {
|
||||||
let mut graph = ControlFlowGraph::<Bool>::new();
|
let mut graph = ControlFlowGraph::new();
|
||||||
let entry_id = SpirvWord(1);
|
let entry_id = SpirvWord(1);
|
||||||
let false_id = SpirvWord(2);
|
let false_id = SpirvWord(2);
|
||||||
let empty_id = SpirvWord(3);
|
let empty_id = SpirvWord(3);
|
||||||
|
@ -660,29 +791,29 @@ mod tests {
|
||||||
let entry = graph.add_entry_basic_block(entry_id);
|
let entry = graph.add_entry_basic_block(entry_id);
|
||||||
graph.add_jump(entry, false_id);
|
graph.add_jump(entry, false_id);
|
||||||
let false_ = graph.get_or_add_basic_block(false_id);
|
let false_ = graph.get_or_add_basic_block(false_id);
|
||||||
graph.set_modes(false_, Bool::False, Bool::False);
|
graph.set_modes(false_, ftz(), ftz());
|
||||||
graph.add_jump(false_, empty_id);
|
graph.add_jump(false_, empty_id);
|
||||||
let empty = graph.get_or_add_basic_block(empty_id);
|
let empty = graph.get_or_add_basic_block(empty_id);
|
||||||
graph.add_jump(empty, false2_id);
|
graph.add_jump(empty, false2_id);
|
||||||
let false2_ = graph.get_or_add_basic_block(false2_id);
|
let false2_ = graph.get_or_add_basic_block(false2_id);
|
||||||
graph.set_modes(false2_, Bool::False, Bool::False);
|
graph.set_modes(false2_, ftz(), ftz());
|
||||||
let partial_result = super::compute(graph);
|
let partial_result = super::compute_single_mode(&graph, |node| node.denormal_f32);
|
||||||
assert_eq!(partial_result.bb_must_insert_mode.len(), 0);
|
assert_eq!(partial_result.bb_must_insert_mode.len(), 0);
|
||||||
assert_eq!(partial_result.bb_maybe_insert_mode.len(), 1);
|
assert_eq!(partial_result.bb_maybe_insert_mode.len(), 1);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
partial_result.bb_maybe_insert_mode[&false_id],
|
partial_result.bb_maybe_insert_mode[&false_id],
|
||||||
(Bool::False, iter::once(entry_id).collect())
|
(DenormalMode::FlushToZero, iter::once(entry_id).collect())
|
||||||
);
|
);
|
||||||
|
|
||||||
let result = optimize::<Bool, 2>(partial_result);
|
let result = optimize::<DenormalMode, { DenormalMode::COUNT }>(partial_result);
|
||||||
assert_eq!(result.basic_blocks.len(), 0);
|
assert_eq!(result.basic_blocks.len(), 0);
|
||||||
assert_eq!(result.kernels.len(), 1);
|
assert_eq!(result.kernels.len(), 1);
|
||||||
assert_eq!(result.kernels[&entry_id], Bool::False);
|
assert_eq!(result.kernels[&entry_id], DenormalMode::FlushToZero);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn transitive_change_twice() {
|
fn transitive_change_twice() {
|
||||||
let mut graph = ControlFlowGraph::<Bool>::new();
|
let mut graph = ControlFlowGraph::new();
|
||||||
let entry_id = SpirvWord(1);
|
let entry_id = SpirvWord(1);
|
||||||
let false_id = SpirvWord(2);
|
let false_id = SpirvWord(2);
|
||||||
let empty_id = SpirvWord(3);
|
let empty_id = SpirvWord(3);
|
||||||
|
@ -690,30 +821,30 @@ mod tests {
|
||||||
let entry = graph.add_entry_basic_block(entry_id);
|
let entry = graph.add_entry_basic_block(entry_id);
|
||||||
graph.add_jump(entry, false_id);
|
graph.add_jump(entry, false_id);
|
||||||
let false_ = graph.get_or_add_basic_block(false_id);
|
let false_ = graph.get_or_add_basic_block(false_id);
|
||||||
graph.set_modes(false_, Bool::False, Bool::False);
|
graph.set_modes(false_, ftz(), ftz());
|
||||||
graph.add_jump(false_, empty_id);
|
graph.add_jump(false_, empty_id);
|
||||||
let empty = graph.get_or_add_basic_block(empty_id);
|
let empty = graph.get_or_add_basic_block(empty_id);
|
||||||
graph.add_jump(empty, true_id);
|
graph.add_jump(empty, true_id);
|
||||||
let true_ = graph.get_or_add_basic_block(true_id);
|
let true_ = graph.get_or_add_basic_block(true_id);
|
||||||
graph.set_modes(true_, Bool::True, Bool::True);
|
graph.set_modes(true_, preserve(), preserve());
|
||||||
let partial_result = super::compute(graph);
|
let partial_result = super::compute_single_mode(&graph, |node| node.denormal_f32);
|
||||||
assert_eq!(partial_result.bb_must_insert_mode.len(), 1);
|
assert_eq!(partial_result.bb_must_insert_mode.len(), 1);
|
||||||
assert!(partial_result.bb_must_insert_mode.contains(&true_id));
|
assert!(partial_result.bb_must_insert_mode.contains(&true_id));
|
||||||
assert_eq!(partial_result.bb_maybe_insert_mode.len(), 1);
|
assert_eq!(partial_result.bb_maybe_insert_mode.len(), 1);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
partial_result.bb_maybe_insert_mode[&false_id],
|
partial_result.bb_maybe_insert_mode[&false_id],
|
||||||
(Bool::False, iter::once(entry_id).collect())
|
(DenormalMode::FlushToZero, iter::once(entry_id).collect())
|
||||||
);
|
);
|
||||||
|
|
||||||
let result = optimize::<Bool, 2>(partial_result);
|
let result = optimize::<DenormalMode, { DenormalMode::COUNT }>(partial_result);
|
||||||
assert_eq!(result.basic_blocks, iter::once(true_id).collect());
|
assert_eq!(result.basic_blocks, iter::once(true_id).collect());
|
||||||
assert_eq!(result.kernels.len(), 1);
|
assert_eq!(result.kernels.len(), 1);
|
||||||
assert_eq!(result.kernels[&entry_id], Bool::False);
|
assert_eq!(result.kernels[&entry_id], DenormalMode::FlushToZero);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn transitive_change() {
|
fn transitive_change() {
|
||||||
let mut graph = ControlFlowGraph::<Bool>::new();
|
let mut graph = ControlFlowGraph::new();
|
||||||
let entry_id = SpirvWord(1);
|
let entry_id = SpirvWord(1);
|
||||||
let empty_id = SpirvWord(2);
|
let empty_id = SpirvWord(2);
|
||||||
let true_id = SpirvWord(3);
|
let true_id = SpirvWord(3);
|
||||||
|
@ -722,24 +853,24 @@ mod tests {
|
||||||
let empty = graph.get_or_add_basic_block(empty_id);
|
let empty = graph.get_or_add_basic_block(empty_id);
|
||||||
graph.add_jump(empty, true_id);
|
graph.add_jump(empty, true_id);
|
||||||
let true_ = graph.get_or_add_basic_block(true_id);
|
let true_ = graph.get_or_add_basic_block(true_id);
|
||||||
graph.set_modes(true_, Bool::True, Bool::True);
|
graph.set_modes(true_, preserve(), preserve());
|
||||||
let partial_result = super::compute(graph);
|
let partial_result = super::compute_single_mode(&graph, |node| node.denormal_f32);
|
||||||
assert_eq!(partial_result.bb_must_insert_mode.len(), 0);
|
assert_eq!(partial_result.bb_must_insert_mode.len(), 0);
|
||||||
assert_eq!(partial_result.bb_maybe_insert_mode.len(), 1);
|
assert_eq!(partial_result.bb_maybe_insert_mode.len(), 1);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
partial_result.bb_maybe_insert_mode[&true_id],
|
partial_result.bb_maybe_insert_mode[&true_id],
|
||||||
(Bool::True, iter::once(entry_id).collect())
|
(DenormalMode::Preserve, iter::once(entry_id).collect())
|
||||||
);
|
);
|
||||||
|
|
||||||
let result = optimize::<Bool, 2>(partial_result);
|
let result = optimize::<DenormalMode, { DenormalMode::COUNT }>(partial_result);
|
||||||
assert_eq!(result.basic_blocks.len(), 0);
|
assert_eq!(result.basic_blocks.len(), 0);
|
||||||
assert_eq!(result.kernels.len(), 1);
|
assert_eq!(result.kernels.len(), 1);
|
||||||
assert_eq!(result.kernels[&entry_id], Bool::True);
|
assert_eq!(result.kernels[&entry_id], DenormalMode::Preserve);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn codependency() {
|
fn codependency() {
|
||||||
let mut graph = ControlFlowGraph::<Bool>::new();
|
let mut graph = ControlFlowGraph::new();
|
||||||
let entry_id = SpirvWord(1);
|
let entry_id = SpirvWord(1);
|
||||||
let left_f_id = SpirvWord(2);
|
let left_f_id = SpirvWord(2);
|
||||||
let right_f_id = SpirvWord(3);
|
let right_f_id = SpirvWord(3);
|
||||||
|
@ -750,9 +881,9 @@ mod tests {
|
||||||
graph.add_jump(entry, left_f_id);
|
graph.add_jump(entry, left_f_id);
|
||||||
graph.add_jump(entry, right_f_id);
|
graph.add_jump(entry, right_f_id);
|
||||||
let left_f = graph.get_or_add_basic_block(left_f_id);
|
let left_f = graph.get_or_add_basic_block(left_f_id);
|
||||||
graph.set_modes(left_f, Bool::False, Bool::False);
|
graph.set_modes(left_f, ftz(), ftz());
|
||||||
let right_f = graph.get_or_add_basic_block(right_f_id);
|
let right_f = graph.get_or_add_basic_block(right_f_id);
|
||||||
graph.set_modes(right_f, Bool::False, Bool::False);
|
graph.set_modes(right_f, ftz(), ftz());
|
||||||
graph.add_jump(left_f, left_none_id);
|
graph.add_jump(left_f, left_none_id);
|
||||||
let left_none = graph.get_or_add_basic_block(left_none_id);
|
let left_none = graph.get_or_add_basic_block(left_none_id);
|
||||||
graph.add_jump(right_f, right_none_id);
|
graph.add_jump(right_f, right_none_id);
|
||||||
|
@ -766,21 +897,21 @@ mod tests {
|
||||||
// "{:?}",
|
// "{:?}",
|
||||||
// petgraph::dot::Dot::with_config(&graph.graph, &[petgraph::dot::Config::EdgeNoLabel])
|
// petgraph::dot::Dot::with_config(&graph.graph, &[petgraph::dot::Config::EdgeNoLabel])
|
||||||
//);
|
//);
|
||||||
let partial_result = super::compute(graph);
|
let partial_result = super::compute_single_mode(&graph, |node| node.denormal_f32);
|
||||||
assert_eq!(partial_result.bb_must_insert_mode.len(), 0);
|
assert_eq!(partial_result.bb_must_insert_mode.len(), 0);
|
||||||
assert_eq!(partial_result.bb_maybe_insert_mode.len(), 2);
|
assert_eq!(partial_result.bb_maybe_insert_mode.len(), 2);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
partial_result.bb_maybe_insert_mode[&left_f_id],
|
partial_result.bb_maybe_insert_mode[&left_f_id],
|
||||||
(Bool::False, iter::once(entry_id).collect())
|
(DenormalMode::FlushToZero, iter::once(entry_id).collect())
|
||||||
);
|
);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
partial_result.bb_maybe_insert_mode[&right_f_id],
|
partial_result.bb_maybe_insert_mode[&right_f_id],
|
||||||
(Bool::False, iter::once(entry_id).collect())
|
(DenormalMode::FlushToZero, iter::once(entry_id).collect())
|
||||||
);
|
);
|
||||||
|
|
||||||
let result = optimize::<Bool, 2>(partial_result);
|
let result = optimize::<DenormalMode, { DenormalMode::COUNT }>(partial_result);
|
||||||
assert_eq!(result.basic_blocks.len(), 0);
|
assert_eq!(result.basic_blocks.len(), 0);
|
||||||
assert_eq!(result.kernels.len(), 1);
|
assert_eq!(result.kernels.len(), 1);
|
||||||
assert_eq!(result.kernels[&entry_id], Bool::False);
|
assert_eq!(result.kernels[&entry_id], DenormalMode::FlushToZero);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue