mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-04-28 21:47:57 +03:00
Add mode-setting wrappers to functions
This commit is contained in:
parent
a0d4b7eeb2
commit
87fe601494
3 changed files with 323 additions and 94 deletions
|
@ -734,7 +734,9 @@ pub(crate) fn run<'input>(
|
||||||
}
|
}
|
||||||
Statement::RetValue(..)
|
Statement::RetValue(..)
|
||||||
| Statement::Instruction(ast::Instruction::Ret { .. }) => {
|
| Statement::Instruction(ast::Instruction::Ret { .. }) => {
|
||||||
bb_state.record_ret(*name)?;
|
if !is_kernel {
|
||||||
|
bb_state.record_ret(*name)?;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
Statement::Label(label) => {
|
Statement::Label(label) => {
|
||||||
bb_state.start(*label);
|
bb_state.start(*label);
|
||||||
|
@ -808,7 +810,7 @@ pub(crate) fn run<'input>(
|
||||||
)?;
|
)?;
|
||||||
let all_modes = FullModeInsertion::new(flat_resolver, denormal, rounding)?;
|
let all_modes = FullModeInsertion::new(flat_resolver, denormal, rounding)?;
|
||||||
*/
|
*/
|
||||||
let directives = insert_mode_control(directives, temp)?;
|
let directives = insert_mode_control(flat_resolver, directives, temp)?;
|
||||||
Ok(directives)
|
Ok(directives)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1018,47 +1020,48 @@ struct TwinMode<T> {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn insert_mode_control(
|
fn insert_mode_control(
|
||||||
|
flat_resolver: &mut super::GlobalStringIdentResolver2,
|
||||||
directives: Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>,
|
directives: Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>,
|
||||||
global_modes: FullModeInsertion2,
|
global_modes: FullModeInsertion2,
|
||||||
) -> Result<Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
|
) -> Result<Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
|
||||||
let directives_len = directives.len();
|
let directives_len = directives.len();
|
||||||
directives
|
directives
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.map(|mut directive| {
|
.map(|directive| {
|
||||||
let mut new_directives = SmallVec::<[_; 4]>::new();
|
let mut new_directives = SmallVec::<[_; 4]>::new();
|
||||||
let (fn_name, initial_mode, body_ptr) = match directive {
|
let (mut method, initial_mode) = match directive {
|
||||||
Directive2::Variable(..) | Directive2::Method(Function2 { body: None, .. }) => {
|
Directive2::Variable(..) | Directive2::Method(Function2 { body: None, .. }) => {
|
||||||
new_directives.push(directive);
|
new_directives.push(directive);
|
||||||
return Ok(new_directives);
|
return Ok(new_directives);
|
||||||
}
|
}
|
||||||
Directive2::Method(Function2 {
|
Directive2::Method(
|
||||||
name,
|
mut method @ Function2 {
|
||||||
body: Some(ref mut body),
|
name,
|
||||||
ref mut flush_to_zero_f32,
|
body: Some(_),
|
||||||
ref mut flush_to_zero_f16f64,
|
..
|
||||||
ref mut rounding_mode_f32,
|
},
|
||||||
ref mut rounding_mode_f16f64,
|
) => {
|
||||||
..
|
|
||||||
}) => {
|
|
||||||
let initial_mode = global_modes
|
let initial_mode = global_modes
|
||||||
.basic_blocks
|
.basic_blocks
|
||||||
.get(&name)
|
.get(&name)
|
||||||
.ok_or_else(error_unreachable)?;
|
.ok_or_else(error_unreachable)?;
|
||||||
let denormal_mode = initial_mode.denormal.twin_mode;
|
let denormal_mode = initial_mode.denormal.twin_mode;
|
||||||
let rounding_mode = initial_mode.rounding.twin_mode;
|
let rounding_mode = initial_mode.rounding.twin_mode;
|
||||||
*flush_to_zero_f32 = denormal_mode.f32.ok_or_else(error_unreachable)?.to_ftz();
|
method.flush_to_zero_f32 =
|
||||||
*flush_to_zero_f16f64 =
|
denormal_mode.f32.ok_or_else(error_unreachable)?.to_ftz();
|
||||||
|
method.flush_to_zero_f16f64 =
|
||||||
denormal_mode.f16f64.ok_or_else(error_unreachable)?.to_ftz();
|
denormal_mode.f16f64.ok_or_else(error_unreachable)?.to_ftz();
|
||||||
*rounding_mode_f32 = rounding_mode.f32.ok_or_else(error_unreachable)?.to_ast();
|
method.rounding_mode_f32 =
|
||||||
*rounding_mode_f16f64 =
|
rounding_mode.f32.ok_or_else(error_unreachable)?.to_ast();
|
||||||
|
method.rounding_mode_f16f64 =
|
||||||
rounding_mode.f16f64.ok_or_else(error_unreachable)?.to_ast();
|
rounding_mode.f16f64.ok_or_else(error_unreachable)?.to_ast();
|
||||||
(name, initial_mode, body)
|
(method, initial_mode)
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
emit_mode_prelude(fn_name, &mut new_directives);
|
emit_mode_prelude(flat_resolver, &method, &global_modes, &mut new_directives)?;
|
||||||
let old_body = mem::replace(body_ptr, Vec::new());
|
let old_body = method.body.take().unwrap();
|
||||||
let mut result = Vec::with_capacity(old_body.len());
|
let mut result = Vec::with_capacity(old_body.len());
|
||||||
let mut bb_state = BasicBlockControlState::new(&global_modes, fn_name, initial_mode);
|
let mut bb_state = BasicBlockControlState::new(&global_modes, initial_mode);
|
||||||
let mut old_body = old_body.into_iter();
|
let mut old_body = old_body.into_iter();
|
||||||
while let Some(mut statement) = old_body.next() {
|
while let Some(mut statement) = old_body.next() {
|
||||||
let mut call_target = None;
|
let mut call_target = None;
|
||||||
|
@ -1115,8 +1118,8 @@ fn insert_mode_control(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
*body_ptr = result;
|
method.body = Some(result);
|
||||||
new_directives.push(directive);
|
new_directives.push(Directive2::Method(method));
|
||||||
Ok(new_directives)
|
Ok(new_directives)
|
||||||
})
|
})
|
||||||
.try_fold(Vec::with_capacity(directives_len), |mut acc, d| {
|
.try_fold(Vec::with_capacity(directives_len), |mut acc, d| {
|
||||||
|
@ -1126,35 +1129,219 @@ fn insert_mode_control(
|
||||||
}
|
}
|
||||||
|
|
||||||
fn emit_mode_prelude(
|
fn emit_mode_prelude(
|
||||||
fn_name: SpirvWord,
|
flat_resolver: &mut super::GlobalStringIdentResolver2,
|
||||||
global_modes: FullModeInsertion2,
|
method: &Function2<ast::Instruction<SpirvWord>, SpirvWord>,
|
||||||
|
global_modes: &FullModeInsertion2,
|
||||||
new_directives: &mut SmallVec<[Directive2<ptx_parser::Instruction<SpirvWord>, SpirvWord>; 4]>,
|
new_directives: &mut SmallVec<[Directive2<ptx_parser::Instruction<SpirvWord>, SpirvWord>; 4]>,
|
||||||
) -> Result<(), TranslateError> {
|
) -> Result<(), TranslateError> {
|
||||||
let fn_mode_state = global_modes.basic_blocks.get(&fn_name).ok_or_else(error_unreachable)?;
|
let fn_mode_state = global_modes
|
||||||
|
.basic_blocks
|
||||||
|
.get(&method.name)
|
||||||
|
.ok_or_else(error_unreachable)?;
|
||||||
if let Some(dual_prologue) = fn_mode_state.dual_prologue {
|
if let Some(dual_prologue) = fn_mode_state.dual_prologue {
|
||||||
new_directives.push(Directive2::Method(
|
new_directives.push(create_fn_wrapper(
|
||||||
Function2 {
|
flat_resolver,
|
||||||
return_arguments: todo!(),
|
method,
|
||||||
name: dual_prologue,
|
dual_prologue,
|
||||||
input_arguments: todo!(),
|
[
|
||||||
body: todo!(),
|
ModeRegister::Denormal {
|
||||||
is_kernel: false,
|
f32: fn_mode_state
|
||||||
import_as: None,
|
.denormal
|
||||||
tuning: Vec::new(),
|
.twin_mode
|
||||||
linkage: ast::LinkingDirective::NONE,
|
.f32
|
||||||
flush_to_zero_f32: false,
|
.unwrap_of_default()
|
||||||
flush_to_zero_f16f64: false,
|
.to_ftz(),
|
||||||
rounding_mode_f32: ptx_parser::RoundingMode::NearestEven,
|
f16f64: fn_mode_state
|
||||||
rounding_mode_f16f64: ptx_parser::RoundingMode::NearestEven,
|
.denormal
|
||||||
}
|
.twin_mode
|
||||||
|
.f16f64
|
||||||
|
.unwrap_of_default()
|
||||||
|
.to_ftz(),
|
||||||
|
},
|
||||||
|
ModeRegister::Rounding {
|
||||||
|
f32: fn_mode_state
|
||||||
|
.rounding
|
||||||
|
.twin_mode
|
||||||
|
.f32
|
||||||
|
.unwrap_of_default()
|
||||||
|
.to_ast(),
|
||||||
|
f16f64: fn_mode_state
|
||||||
|
.rounding
|
||||||
|
.twin_mode
|
||||||
|
.f16f64
|
||||||
|
.unwrap_of_default()
|
||||||
|
.to_ast(),
|
||||||
|
},
|
||||||
|
]
|
||||||
|
.into_iter(),
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
if let Some(prologue) = fn_mode_state.denormal.prologue {
|
if let Some(prologue) = fn_mode_state.denormal.prologue {
|
||||||
todo!()
|
new_directives.push(create_fn_wrapper(
|
||||||
|
flat_resolver,
|
||||||
|
method,
|
||||||
|
prologue,
|
||||||
|
[ModeRegister::Denormal {
|
||||||
|
f32: fn_mode_state
|
||||||
|
.denormal
|
||||||
|
.twin_mode
|
||||||
|
.f32
|
||||||
|
.unwrap_of_default()
|
||||||
|
.to_ftz(),
|
||||||
|
f16f64: fn_mode_state
|
||||||
|
.denormal
|
||||||
|
.twin_mode
|
||||||
|
.f16f64
|
||||||
|
.unwrap_of_default()
|
||||||
|
.to_ftz(),
|
||||||
|
}]
|
||||||
|
.into_iter(),
|
||||||
|
));
|
||||||
}
|
}
|
||||||
if let Some(prologue) = fn_mode_state.rounding.prologue {
|
if let Some(prologue) = fn_mode_state.rounding.prologue {
|
||||||
todo!()
|
new_directives.push(create_fn_wrapper(
|
||||||
|
flat_resolver,
|
||||||
|
method,
|
||||||
|
prologue,
|
||||||
|
[ModeRegister::Rounding {
|
||||||
|
f32: fn_mode_state
|
||||||
|
.rounding
|
||||||
|
.twin_mode
|
||||||
|
.f32
|
||||||
|
.unwrap_of_default()
|
||||||
|
.to_ast(),
|
||||||
|
f16f64: fn_mode_state
|
||||||
|
.rounding
|
||||||
|
.twin_mode
|
||||||
|
.f16f64
|
||||||
|
.unwrap_of_default()
|
||||||
|
.to_ast(),
|
||||||
|
}]
|
||||||
|
.into_iter(),
|
||||||
|
));
|
||||||
}
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn create_fn_wrapper(
|
||||||
|
flat_resolver: &mut super::GlobalStringIdentResolver2,
|
||||||
|
method: &Function2<ast::Instruction<SpirvWord>, SpirvWord>,
|
||||||
|
name: SpirvWord,
|
||||||
|
modes: impl ExactSizeIterator<Item = ModeRegister>,
|
||||||
|
) -> Directive2<ast::Instruction<SpirvWord>, SpirvWord> {
|
||||||
|
// * Label
|
||||||
|
// * return argument registers
|
||||||
|
// * input argument registers
|
||||||
|
// * Load input arguments
|
||||||
|
// * set modes
|
||||||
|
// * call
|
||||||
|
// * return with value
|
||||||
|
let return_arguments = rename_variables(flat_resolver, &method.return_arguments);
|
||||||
|
let input_arguments = rename_variables(flat_resolver, &method.input_arguments);
|
||||||
|
let mut body = Vec::with_capacity(
|
||||||
|
1 + (input_arguments.len() * 2) + return_arguments.len() + modes.len() + 2,
|
||||||
|
);
|
||||||
|
body.push(Statement::Label(flat_resolver.register_unnamed(None)));
|
||||||
|
let return_variables = append_variables(flat_resolver, &mut body, &return_arguments);
|
||||||
|
let input_variables = append_variables(flat_resolver, &mut body, &input_arguments);
|
||||||
|
for (index, input_reg) in input_variables.iter().enumerate() {
|
||||||
|
body.push(Statement::Instruction(ast::Instruction::Ld {
|
||||||
|
data: ast::LdDetails {
|
||||||
|
qualifier: ast::LdStQualifier::Weak,
|
||||||
|
state_space: input_arguments[index].state_space,
|
||||||
|
caching: ast::LdCacheOperator::Cached,
|
||||||
|
typ: input_arguments[index].v_type.clone(),
|
||||||
|
non_coherent: false,
|
||||||
|
},
|
||||||
|
arguments: ast::LdArgs {
|
||||||
|
src: input_arguments[index].name,
|
||||||
|
dst: *input_reg,
|
||||||
|
},
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
body.extend(modes.map(|mode_set| Statement::SetMode(mode_set)));
|
||||||
|
// Out of order because we want to use return_variables before they are moved
|
||||||
|
let ret_statement = if return_arguments.is_empty() {
|
||||||
|
Statement::Instruction(ast::Instruction::Ret {
|
||||||
|
data: ast::RetData { uniform: false },
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
Statement::RetValue(
|
||||||
|
ast::RetData { uniform: false },
|
||||||
|
return_variables
|
||||||
|
.iter()
|
||||||
|
.enumerate()
|
||||||
|
.map(|(index, var)| (*var, method.return_arguments[index].v_type.clone()))
|
||||||
|
.collect(),
|
||||||
|
)
|
||||||
|
};
|
||||||
|
body.push(Statement::Instruction(ast::Instruction::Call {
|
||||||
|
data: ast::CallDetails {
|
||||||
|
uniform: false,
|
||||||
|
return_arguments: return_arguments
|
||||||
|
.iter()
|
||||||
|
.map(|arg| (arg.v_type.clone(), arg.state_space))
|
||||||
|
.collect(),
|
||||||
|
input_arguments: input_arguments
|
||||||
|
.iter()
|
||||||
|
.map(|arg| (arg.v_type.clone(), arg.state_space))
|
||||||
|
.collect(),
|
||||||
|
},
|
||||||
|
arguments: ast::CallArgs {
|
||||||
|
return_arguments: return_variables,
|
||||||
|
func: method.name,
|
||||||
|
input_arguments: input_variables,
|
||||||
|
},
|
||||||
|
}));
|
||||||
|
body.push(ret_statement);
|
||||||
|
Directive2::Method(Function2 {
|
||||||
|
return_arguments,
|
||||||
|
name,
|
||||||
|
input_arguments,
|
||||||
|
body: Some(body),
|
||||||
|
is_kernel: false,
|
||||||
|
import_as: None,
|
||||||
|
tuning: Vec::new(),
|
||||||
|
linkage: ast::LinkingDirective::NONE,
|
||||||
|
flush_to_zero_f32: false,
|
||||||
|
flush_to_zero_f16f64: false,
|
||||||
|
rounding_mode_f32: ptx_parser::RoundingMode::NearestEven,
|
||||||
|
rounding_mode_f16f64: ptx_parser::RoundingMode::NearestEven,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn rename_variables(
|
||||||
|
flat_resolver: &mut super::GlobalStringIdentResolver2,
|
||||||
|
variables: &Vec<ast::Variable<SpirvWord>>,
|
||||||
|
) -> Vec<ast::Variable<SpirvWord>> {
|
||||||
|
variables
|
||||||
|
.iter()
|
||||||
|
.cloned()
|
||||||
|
.map(|arg| ast::Variable {
|
||||||
|
name: flat_resolver.register_unnamed(Some((arg.v_type.clone(), arg.state_space))),
|
||||||
|
..arg
|
||||||
|
})
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn append_variables<'a, 'input: 'a>(
|
||||||
|
flat_resolver: &'a mut super::GlobalStringIdentResolver2<'input>,
|
||||||
|
body: &mut Vec<Statement<ast::Instruction<SpirvWord>, SpirvWord>>,
|
||||||
|
arguments: &'a Vec<ast::Variable<SpirvWord>>,
|
||||||
|
) -> Vec<SpirvWord> {
|
||||||
|
let mut result = Vec::with_capacity(arguments.len());
|
||||||
|
for arg in arguments {
|
||||||
|
let name = flat_resolver.register_unnamed(Some((arg.v_type.clone(), ast::StateSpace::Reg)));
|
||||||
|
body.push(Statement::Variable(ast::Variable {
|
||||||
|
align: None,
|
||||||
|
v_type: arg.v_type.clone(),
|
||||||
|
state_space: ast::StateSpace::Reg,
|
||||||
|
name,
|
||||||
|
array_init: Vec::new(),
|
||||||
|
}));
|
||||||
|
result.push(name);
|
||||||
|
}
|
||||||
|
result
|
||||||
}
|
}
|
||||||
|
|
||||||
struct BasicBlockControlState<'a> {
|
struct BasicBlockControlState<'a> {
|
||||||
|
@ -1163,7 +1350,6 @@ struct BasicBlockControlState<'a> {
|
||||||
denormal_f16f64: RegisterState<bool>,
|
denormal_f16f64: RegisterState<bool>,
|
||||||
rounding_f32: RegisterState<ast::RoundingMode>,
|
rounding_f32: RegisterState<ast::RoundingMode>,
|
||||||
rounding_f16f64: RegisterState<ast::RoundingMode>,
|
rounding_f16f64: RegisterState<ast::RoundingMode>,
|
||||||
current_bb: SpirvWord,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Copy)]
|
#[derive(Clone, Copy)]
|
||||||
|
@ -1175,20 +1361,6 @@ struct RegisterState<T> {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T> RegisterState<T> {
|
impl<T> RegisterState<T> {
|
||||||
fn single(t: T) -> Self {
|
|
||||||
RegisterState {
|
|
||||||
last_foldable: None,
|
|
||||||
current_value: Resolved::Value(t),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn conflict() -> Self {
|
|
||||||
RegisterState {
|
|
||||||
last_foldable: None,
|
|
||||||
current_value: Resolved::Conflict,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn new<U>(value: Resolved<U>) -> RegisterState<T>
|
fn new<U>(value: Resolved<U>) -> RegisterState<T>
|
||||||
where
|
where
|
||||||
U: Into<T>,
|
U: Into<T>,
|
||||||
|
@ -1201,11 +1373,7 @@ impl<T> RegisterState<T> {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a> BasicBlockControlState<'a> {
|
impl<'a> BasicBlockControlState<'a> {
|
||||||
fn new(
|
fn new(global_modes: &'a FullModeInsertion2, initial_mode: &FullBasicBlockEntryState) -> Self {
|
||||||
global_modes: &'a FullModeInsertion2,
|
|
||||||
current_bb: SpirvWord,
|
|
||||||
initial_mode: &FullBasicBlockEntryState,
|
|
||||||
) -> Self {
|
|
||||||
let denormal_f32 = RegisterState::new(initial_mode.denormal.twin_mode.f32);
|
let denormal_f32 = RegisterState::new(initial_mode.denormal.twin_mode.f32);
|
||||||
let denormal_f16f64 = RegisterState::new(initial_mode.denormal.twin_mode.f16f64);
|
let denormal_f16f64 = RegisterState::new(initial_mode.denormal.twin_mode.f16f64);
|
||||||
let rounding_f32 = RegisterState::new(initial_mode.rounding.twin_mode.f32);
|
let rounding_f32 = RegisterState::new(initial_mode.rounding.twin_mode.f32);
|
||||||
|
@ -1216,7 +1384,6 @@ impl<'a> BasicBlockControlState<'a> {
|
||||||
denormal_f16f64,
|
denormal_f16f64,
|
||||||
rounding_f32,
|
rounding_f32,
|
||||||
rounding_f16f64,
|
rounding_f16f64,
|
||||||
current_bb,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1405,12 +1572,6 @@ fn redirect_jump_impl(
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ModeJumpTargets {
|
|
||||||
dual_prologue: Option<SpirvWord>,
|
|
||||||
denormal: Option<SpirvWord>,
|
|
||||||
rounding: Option<SpirvWord>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Copy, Clone)]
|
#[derive(Copy, Clone)]
|
||||||
enum Resolved<T> {
|
enum Resolved<T> {
|
||||||
Conflict,
|
Conflict,
|
||||||
|
@ -1457,13 +1618,6 @@ impl<T> Resolved<T> {
|
||||||
Resolved::Conflict => Err(err()),
|
Resolved::Conflict => Err(err()),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn has_value(&self) -> bool {
|
|
||||||
match self {
|
|
||||||
Resolved::Value(_) => true,
|
|
||||||
Resolved::Conflict => false,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
trait ModeView {
|
trait ModeView {
|
||||||
|
|
|
@ -258,30 +258,83 @@ fn call_with_mode() {
|
||||||
));
|
));
|
||||||
let [to_fn0] = calls(method_1);
|
let [to_fn0] = calls(method_1);
|
||||||
let [_, dual_prelude, _, _, add] = labels(method_1);
|
let [_, dual_prelude, _, _, add] = labels(method_1);
|
||||||
let [post_call, post_prelude_0, post_prelude_1, post_prelude_2] = branches(method_1);
|
let [post_call, post_prelude_dual, post_prelude_denormal, post_prelude_rounding] =
|
||||||
|
branches(method_1);
|
||||||
assert_eq!(methods[0].name, to_fn0);
|
assert_eq!(methods[0].name, to_fn0);
|
||||||
assert_eq!(post_call, dual_prelude);
|
assert_eq!(post_call, dual_prelude);
|
||||||
assert_eq!(post_prelude_0, add);
|
assert_eq!(post_prelude_dual, add);
|
||||||
assert_eq!(post_prelude_1, add);
|
assert_eq!(post_prelude_denormal, add);
|
||||||
assert_eq!(post_prelude_2, add);
|
assert_eq!(post_prelude_rounding, add);
|
||||||
|
|
||||||
let method_2 = methods[2].body.as_ref().unwrap();
|
let method_2 = methods[2].body.as_ref().unwrap();
|
||||||
assert!(matches!(
|
assert!(matches!(
|
||||||
&**method_2,
|
&**method_2,
|
||||||
[
|
[
|
||||||
Statement::Label(..),
|
Statement::Label(..),
|
||||||
|
Statement::Variable(..),
|
||||||
|
Statement::Variable(..),
|
||||||
|
Statement::Conditional(..),
|
||||||
|
Statement::Label(..),
|
||||||
|
Statement::Conditional(..),
|
||||||
|
Statement::Label(..),
|
||||||
|
Statement::Instruction(ast::Instruction::Bra { .. }),
|
||||||
|
Statement::Label(..),
|
||||||
|
// Dual prelude
|
||||||
Statement::SetMode(ModeRegister::Denormal {
|
Statement::SetMode(ModeRegister::Denormal {
|
||||||
f32: true,
|
f32: false,
|
||||||
f16f64: true
|
f16f64: true
|
||||||
}),
|
}),
|
||||||
Statement::SetMode(ModeRegister::Rounding {
|
Statement::SetMode(ModeRegister::Rounding {
|
||||||
f32: ast::RoundingMode::PositiveInf,
|
f32: ast::RoundingMode::NegativeInf,
|
||||||
f16f64: ast::RoundingMode::NearestEven
|
f16f64: ast::RoundingMode::NearestEven
|
||||||
}),
|
}),
|
||||||
Statement::Instruction(ast::Instruction::Call { .. }),
|
Statement::Instruction(ast::Instruction::Bra { .. }),
|
||||||
|
// Denormal prelude
|
||||||
|
Statement::Label(..),
|
||||||
|
Statement::SetMode(ModeRegister::Denormal {
|
||||||
|
f32: false,
|
||||||
|
f16f64: true
|
||||||
|
}),
|
||||||
|
Statement::Instruction(ast::Instruction::Bra { .. }),
|
||||||
|
// Rounding prelude
|
||||||
|
Statement::Label(..),
|
||||||
|
Statement::SetMode(ModeRegister::Rounding {
|
||||||
|
f32: ast::RoundingMode::NegativeInf,
|
||||||
|
f16f64: ast::RoundingMode::NearestEven
|
||||||
|
}),
|
||||||
|
Statement::Instruction(ast::Instruction::Bra { .. }),
|
||||||
|
Statement::Label(..),
|
||||||
|
Statement::Instruction(ast::Instruction::Add { .. }),
|
||||||
|
Statement::Instruction(ast::Instruction::Bra { .. }),
|
||||||
|
Statement::Label(..),
|
||||||
|
Statement::SetMode(ModeRegister::Denormal {
|
||||||
|
f32: false,
|
||||||
|
f16f64: true
|
||||||
|
}),
|
||||||
|
Statement::Instruction(ast::Instruction::Bra { .. }),
|
||||||
|
Statement::Label(..),
|
||||||
|
Statement::Instruction(ast::Instruction::Add { .. }),
|
||||||
|
Statement::Instruction(ast::Instruction::Bra { .. }),
|
||||||
|
Statement::Label(..),
|
||||||
Statement::Instruction(ast::Instruction::Ret { .. }),
|
Statement::Instruction(ast::Instruction::Ret { .. }),
|
||||||
]
|
]
|
||||||
));
|
));
|
||||||
|
let [(if_rm_true, if_rm_false), (if_rz_true, if_rz_false)] = conditionals(method_2);
|
||||||
|
let [_, conditional2, post_conditional2, prelude_dual, _, _, add1, add2_set_denormal, add2, ret] =
|
||||||
|
labels(method_2);
|
||||||
|
let [post_conditional2_jump, post_prelude_dual, post_prelude_denormal, post_prelude_rounding, post_add1, post_add2_set_denormal, post_add2] =
|
||||||
|
branches(method_2);
|
||||||
|
assert_eq!(if_rm_true, prelude_dual);
|
||||||
|
assert_eq!(if_rm_false, conditional2);
|
||||||
|
assert_eq!(if_rz_true, post_conditional2);
|
||||||
|
assert_eq!(if_rz_false, add2_set_denormal);
|
||||||
|
assert_eq!(post_conditional2_jump, prelude_dual);
|
||||||
|
assert_eq!(post_prelude_dual, add1);
|
||||||
|
assert_eq!(post_prelude_denormal, add1);
|
||||||
|
assert_eq!(post_prelude_rounding, add1);
|
||||||
|
assert_eq!(post_add1, ret);
|
||||||
|
assert_eq!(post_add2_set_denormal, add2);
|
||||||
|
assert_eq!(post_add2, ret);
|
||||||
}
|
}
|
||||||
|
|
||||||
fn branches<const N: usize>(
|
fn branches<const N: usize>(
|
||||||
|
@ -303,10 +356,12 @@ fn labels<const N: usize>(
|
||||||
fn_: &Vec<Statement<ast::Instruction<SpirvWord>, SpirvWord>>,
|
fn_: &Vec<Statement<ast::Instruction<SpirvWord>, SpirvWord>>,
|
||||||
) -> [SpirvWord; N] {
|
) -> [SpirvWord; N] {
|
||||||
fn_.iter()
|
fn_.iter()
|
||||||
.filter_map(|s: &Statement<ptx_parser::Instruction<SpirvWord>, SpirvWord>| match s {
|
.filter_map(
|
||||||
Statement::Label(label) => Some(*label),
|
|s: &Statement<ptx_parser::Instruction<SpirvWord>, SpirvWord>| match s {
|
||||||
_ => None,
|
Statement::Label(label) => Some(*label),
|
||||||
})
|
_ => None,
|
||||||
|
},
|
||||||
|
)
|
||||||
.collect::<Vec<_>>()
|
.collect::<Vec<_>>()
|
||||||
.try_into()
|
.try_into()
|
||||||
.unwrap()
|
.unwrap()
|
||||||
|
@ -317,7 +372,25 @@ fn calls<const N: usize>(
|
||||||
) -> [SpirvWord; N] {
|
) -> [SpirvWord; N] {
|
||||||
fn_.iter()
|
fn_.iter()
|
||||||
.filter_map(|s| match s {
|
.filter_map(|s| match s {
|
||||||
Statement::Instruction(ast::Instruction::Call { arguments: ast::CallArgs { func,.. }, .. }) => Some(*func),
|
Statement::Instruction(ast::Instruction::Call {
|
||||||
|
arguments: ast::CallArgs { func, .. },
|
||||||
|
..
|
||||||
|
}) => Some(*func),
|
||||||
|
_ => None,
|
||||||
|
})
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
.try_into()
|
||||||
|
.unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn conditionals<const N: usize>(
|
||||||
|
fn_: &Vec<Statement<ast::Instruction<SpirvWord>, SpirvWord>>,
|
||||||
|
) -> [(SpirvWord, SpirvWord); N] {
|
||||||
|
fn_.iter()
|
||||||
|
.filter_map(|s| match s {
|
||||||
|
Statement::Conditional(BrachCondition {
|
||||||
|
if_true, if_false, ..
|
||||||
|
}) => Some((*if_true, *if_false)),
|
||||||
_ => None,
|
_ => None,
|
||||||
})
|
})
|
||||||
.collect::<Vec<_>>()
|
.collect::<Vec<_>>()
|
||||||
|
|
|
@ -7,7 +7,7 @@ use super::*;
|
||||||
// represent kernels as separate nodes with its own separate entry/exit mode
|
// represent kernels as separate nodes with its own separate entry/exit mode
|
||||||
// * Inserts label at the start of every basic block
|
// * Inserts label at the start of every basic block
|
||||||
// * Insert explicit jumps before labels
|
// * Insert explicit jumps before labels
|
||||||
// * Functions get a single `ret;` exit point - this is because mode computation
|
// * Non-.entry methods get a single `ret;` exit point - this is because mode computation
|
||||||
// logic requires it. Control flow graph constructed by mode computation
|
// logic requires it. Control flow graph constructed by mode computation
|
||||||
// models function calls as jumps into and then from another function.
|
// models function calls as jumps into and then from another function.
|
||||||
// If this cfg allowed multiple return basic blocks then there would be cases
|
// If this cfg allowed multiple return basic blocks then there would be cases
|
||||||
|
@ -19,10 +19,10 @@ pub(crate) fn run(
|
||||||
mut directives: Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>,
|
mut directives: Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>,
|
||||||
) -> Result<Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
|
) -> Result<Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
|
||||||
for directive in directives.iter_mut() {
|
for directive in directives.iter_mut() {
|
||||||
let body_ref = match directive {
|
let (body_ref, is_kernel) = match directive {
|
||||||
Directive2::Method(Function2 {
|
Directive2::Method(Function2 {
|
||||||
body: Some(body), ..
|
body: Some(body), is_kernel, ..
|
||||||
}) => body,
|
}) => (body, *is_kernel),
|
||||||
_ => continue,
|
_ => continue,
|
||||||
};
|
};
|
||||||
let body = std::mem::replace(body_ref, Vec::new());
|
let body = std::mem::replace(body_ref, Vec::new());
|
||||||
|
@ -74,7 +74,9 @@ pub(crate) fn run(
|
||||||
return Err(error_unreachable());
|
return Err(error_unreachable());
|
||||||
}
|
}
|
||||||
Statement::Instruction(ast::Instruction::Ret { .. }) => {
|
Statement::Instruction(ast::Instruction::Ret { .. }) => {
|
||||||
return_statements.push(result.len())
|
if !is_kernel {
|
||||||
|
return_statements.push(result.len());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
_ => {}
|
_ => {}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue