mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-04-28 21:47:57 +03:00
Apply computed denormal modes to basic blocks
This commit is contained in:
parent
aaa31da026
commit
82ca92c5c3
14 changed files with 626 additions and 123 deletions
|
@ -95,16 +95,7 @@ fn run_method<'input>(
|
||||||
Ok::<_, TranslateError>(body)
|
Ok::<_, TranslateError>(body)
|
||||||
})
|
})
|
||||||
.transpose()?;
|
.transpose()?;
|
||||||
Ok(Function2 {
|
Ok(Function2 { body, ..method })
|
||||||
return_arguments: method.return_arguments,
|
|
||||||
name: method.name,
|
|
||||||
input_arguments: method.input_arguments,
|
|
||||||
body,
|
|
||||||
import_as: method.import_as,
|
|
||||||
tuning: method.tuning,
|
|
||||||
linkage: method.linkage,
|
|
||||||
is_kernel: method.is_kernel,
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn run_statement<'input>(
|
fn run_statement<'input>(
|
||||||
|
|
|
@ -243,6 +243,19 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> {
|
||||||
}
|
}
|
||||||
if !method.is_kernel {
|
if !method.is_kernel {
|
||||||
self.resolver.register(method.name, fn_);
|
self.resolver.register(method.name, fn_);
|
||||||
|
self.emit_fn_attribute(fn_, "denormal-fp-math-f32", "dynamic");
|
||||||
|
self.emit_fn_attribute(fn_, "denormal-fp-math", "dynamic");
|
||||||
|
} else {
|
||||||
|
self.emit_fn_attribute(
|
||||||
|
fn_,
|
||||||
|
"denormal-fp-math-f32",
|
||||||
|
llvm_ftz(method.flush_to_zero_f32),
|
||||||
|
);
|
||||||
|
self.emit_fn_attribute(
|
||||||
|
fn_,
|
||||||
|
"denormal-fp-math",
|
||||||
|
llvm_ftz(method.flush_to_zero_f16f64),
|
||||||
|
);
|
||||||
}
|
}
|
||||||
for (i, param) in method.input_arguments.iter().enumerate() {
|
for (i, param) in method.input_arguments.iter().enumerate() {
|
||||||
let value = unsafe { LLVMGetParam(fn_, i as u32) };
|
let value = unsafe { LLVMGetParam(fn_, i as u32) };
|
||||||
|
@ -413,6 +426,14 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Maps a flush-to-zero flag to the string used as the value of the
/// `denormal-fp-math` / `denormal-fp-math-f32` LLVM function attributes.
///
/// Returns `"preserve-sign"` when `ftz` is `true` and `"ieee"` otherwise.
fn llvm_ftz(ftz: bool) -> &'static str {
    match ftz {
        true => "preserve-sign",
        false => "ieee",
    }
}
|
||||||
|
|
||||||
fn get_input_argument_type(
|
fn get_input_argument_type(
|
||||||
context: LLVMContextRef,
|
context: LLVMContextRef,
|
||||||
v_type: &ast::Type,
|
v_type: &ast::Type,
|
||||||
|
@ -469,6 +490,7 @@ impl<'a> MethodEmitContext<'a> {
|
||||||
Statement::FunctionPointer(_) => todo!(),
|
Statement::FunctionPointer(_) => todo!(),
|
||||||
Statement::VectorRead(vector_read) => self.emit_vector_read(vector_read)?,
|
Statement::VectorRead(vector_read) => self.emit_vector_read(vector_read)?,
|
||||||
Statement::VectorWrite(vector_write) => self.emit_vector_write(vector_write)?,
|
Statement::VectorWrite(vector_write) => self.emit_vector_write(vector_write)?,
|
||||||
|
Statement::SetMode(mode_reg) => self.emit_set_mode(mode_reg)?,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1124,7 +1146,7 @@ impl<'a> MethodEmitContext<'a> {
|
||||||
let cos = self.emit_intrinsic(
|
let cos = self.emit_intrinsic(
|
||||||
c"llvm.cos.f32",
|
c"llvm.cos.f32",
|
||||||
Some(arguments.dst),
|
Some(arguments.dst),
|
||||||
&ast::ScalarType::F32.into(),
|
Some(&ast::ScalarType::F32.into()),
|
||||||
vec![(self.resolver.value(arguments.src)?, llvm_f32)],
|
vec![(self.resolver.value(arguments.src)?, llvm_f32)],
|
||||||
)?;
|
)?;
|
||||||
unsafe { LLVMZludaSetFastMathFlags(cos, LLVMZludaFastMathApproxFunc) }
|
unsafe { LLVMZludaSetFastMathFlags(cos, LLVMZludaFastMathApproxFunc) }
|
||||||
|
@ -1377,7 +1399,7 @@ impl<'a> MethodEmitContext<'a> {
|
||||||
let sin = self.emit_intrinsic(
|
let sin = self.emit_intrinsic(
|
||||||
c"llvm.sin.f32",
|
c"llvm.sin.f32",
|
||||||
Some(arguments.dst),
|
Some(arguments.dst),
|
||||||
&ast::ScalarType::F32.into(),
|
Some(&ast::ScalarType::F32.into()),
|
||||||
vec![(self.resolver.value(arguments.src)?, llvm_f32)],
|
vec![(self.resolver.value(arguments.src)?, llvm_f32)],
|
||||||
)?;
|
)?;
|
||||||
unsafe { LLVMZludaSetFastMathFlags(sin, LLVMZludaFastMathApproxFunc) }
|
unsafe { LLVMZludaSetFastMathFlags(sin, LLVMZludaFastMathApproxFunc) }
|
||||||
|
@ -1388,12 +1410,12 @@ impl<'a> MethodEmitContext<'a> {
|
||||||
&mut self,
|
&mut self,
|
||||||
name: &CStr,
|
name: &CStr,
|
||||||
dst: Option<SpirvWord>,
|
dst: Option<SpirvWord>,
|
||||||
return_type: &ast::Type,
|
return_type: Option<&ast::Type>,
|
||||||
arguments: Vec<(LLVMValueRef, LLVMTypeRef)>,
|
arguments: Vec<(LLVMValueRef, LLVMTypeRef)>,
|
||||||
) -> Result<LLVMValueRef, TranslateError> {
|
) -> Result<LLVMValueRef, TranslateError> {
|
||||||
let fn_type = get_function_type(
|
let fn_type = get_function_type(
|
||||||
self.context,
|
self.context,
|
||||||
iter::once(return_type),
|
return_type.into_iter(),
|
||||||
arguments.iter().map(|(_, type_)| Ok(*type_)),
|
arguments.iter().map(|(_, type_)| Ok(*type_)),
|
||||||
)?;
|
)?;
|
||||||
let mut fn_ = unsafe { LLVMGetNamedFunction(self.module, name.as_ptr()) };
|
let mut fn_ = unsafe { LLVMGetNamedFunction(self.module, name.as_ptr()) };
|
||||||
|
@ -1612,7 +1634,7 @@ impl<'a> MethodEmitContext<'a> {
|
||||||
let clamped = self.emit_intrinsic(
|
let clamped = self.emit_intrinsic(
|
||||||
c"llvm.umin",
|
c"llvm.umin",
|
||||||
None,
|
None,
|
||||||
&from.into(),
|
Some(&from.into()),
|
||||||
vec![
|
vec![
|
||||||
(self.resolver.value(arguments.src)?, from_llvm),
|
(self.resolver.value(arguments.src)?, from_llvm),
|
||||||
(max, from_llvm),
|
(max, from_llvm),
|
||||||
|
@ -1642,7 +1664,7 @@ impl<'a> MethodEmitContext<'a> {
|
||||||
let zero_clamped = self.emit_intrinsic(
|
let zero_clamped = self.emit_intrinsic(
|
||||||
unsafe { CStr::from_bytes_with_nul_unchecked(zero_clamp_intrinsic.as_bytes()) },
|
unsafe { CStr::from_bytes_with_nul_unchecked(zero_clamp_intrinsic.as_bytes()) },
|
||||||
None,
|
None,
|
||||||
&from.into(),
|
Some(&from.into()),
|
||||||
vec![
|
vec![
|
||||||
(self.resolver.value(arguments.src)?, from_llvm),
|
(self.resolver.value(arguments.src)?, from_llvm),
|
||||||
(zero, from_llvm),
|
(zero, from_llvm),
|
||||||
|
@ -1661,7 +1683,7 @@ impl<'a> MethodEmitContext<'a> {
|
||||||
let fully_clamped = self.emit_intrinsic(
|
let fully_clamped = self.emit_intrinsic(
|
||||||
unsafe { CStr::from_bytes_with_nul_unchecked(max_clamp_intrinsic.as_bytes()) },
|
unsafe { CStr::from_bytes_with_nul_unchecked(max_clamp_intrinsic.as_bytes()) },
|
||||||
None,
|
None,
|
||||||
&from.into(),
|
Some(&from.into()),
|
||||||
vec![(zero_clamped, from_llvm), (max, from_llvm)],
|
vec![(zero_clamped, from_llvm), (max, from_llvm)],
|
||||||
)?;
|
)?;
|
||||||
let resize_fn = if to.layout().size() >= from.layout().size() {
|
let resize_fn = if to.layout().size() >= from.layout().size() {
|
||||||
|
@ -1701,7 +1723,7 @@ impl<'a> MethodEmitContext<'a> {
|
||||||
let rounded_float = self.emit_intrinsic(
|
let rounded_float = self.emit_intrinsic(
|
||||||
unsafe { CStr::from_bytes_with_nul_unchecked(intrinsic.as_bytes()) },
|
unsafe { CStr::from_bytes_with_nul_unchecked(intrinsic.as_bytes()) },
|
||||||
None,
|
None,
|
||||||
&from.into(),
|
Some(&from.into()),
|
||||||
vec![(
|
vec![(
|
||||||
self.resolver.value(arguments.src)?,
|
self.resolver.value(arguments.src)?,
|
||||||
get_scalar_type(self.context, from),
|
get_scalar_type(self.context, from),
|
||||||
|
@ -1770,7 +1792,7 @@ impl<'a> MethodEmitContext<'a> {
|
||||||
self.emit_intrinsic(
|
self.emit_intrinsic(
|
||||||
intrinsic,
|
intrinsic,
|
||||||
Some(arguments.dst),
|
Some(arguments.dst),
|
||||||
&data.type_.into(),
|
Some(&data.type_.into()),
|
||||||
vec![(self.resolver.value(arguments.src)?, type_)],
|
vec![(self.resolver.value(arguments.src)?, type_)],
|
||||||
)?;
|
)?;
|
||||||
Ok(())
|
Ok(())
|
||||||
|
@ -1791,7 +1813,7 @@ impl<'a> MethodEmitContext<'a> {
|
||||||
self.emit_intrinsic(
|
self.emit_intrinsic(
|
||||||
intrinsic,
|
intrinsic,
|
||||||
Some(arguments.dst),
|
Some(arguments.dst),
|
||||||
&data.type_.into(),
|
Some(&data.type_.into()),
|
||||||
vec![(self.resolver.value(arguments.src)?, type_)],
|
vec![(self.resolver.value(arguments.src)?, type_)],
|
||||||
)?;
|
)?;
|
||||||
Ok(())
|
Ok(())
|
||||||
|
@ -1813,7 +1835,7 @@ impl<'a> MethodEmitContext<'a> {
|
||||||
self.emit_intrinsic(
|
self.emit_intrinsic(
|
||||||
intrinsic,
|
intrinsic,
|
||||||
Some(arguments.dst),
|
Some(arguments.dst),
|
||||||
&data.type_.into(),
|
Some(&data.type_.into()),
|
||||||
vec![(self.resolver.value(arguments.src)?, type_)],
|
vec![(self.resolver.value(arguments.src)?, type_)],
|
||||||
)?;
|
)?;
|
||||||
Ok(())
|
Ok(())
|
||||||
|
@ -1935,7 +1957,7 @@ impl<'a> MethodEmitContext<'a> {
|
||||||
self.emit_intrinsic(
|
self.emit_intrinsic(
|
||||||
intrinsic,
|
intrinsic,
|
||||||
Some(arguments.dst),
|
Some(arguments.dst),
|
||||||
&data.type_.into(),
|
Some(&data.type_.into()),
|
||||||
vec![(
|
vec![(
|
||||||
self.resolver.value(arguments.src)?,
|
self.resolver.value(arguments.src)?,
|
||||||
get_scalar_type(self.context, data.type_),
|
get_scalar_type(self.context, data.type_),
|
||||||
|
@ -1952,7 +1974,7 @@ impl<'a> MethodEmitContext<'a> {
|
||||||
self.emit_intrinsic(
|
self.emit_intrinsic(
|
||||||
c"llvm.amdgcn.log.f32",
|
c"llvm.amdgcn.log.f32",
|
||||||
Some(arguments.dst),
|
Some(arguments.dst),
|
||||||
&ast::ScalarType::F32.into(),
|
Some(&ast::ScalarType::F32.into()),
|
||||||
vec![(
|
vec![(
|
||||||
self.resolver.value(arguments.src)?,
|
self.resolver.value(arguments.src)?,
|
||||||
get_scalar_type(self.context, ast::ScalarType::F32.into()),
|
get_scalar_type(self.context, ast::ScalarType::F32.into()),
|
||||||
|
@ -2007,7 +2029,7 @@ impl<'a> MethodEmitContext<'a> {
|
||||||
self.emit_intrinsic(
|
self.emit_intrinsic(
|
||||||
intrinsic,
|
intrinsic,
|
||||||
Some(arguments.dst),
|
Some(arguments.dst),
|
||||||
&type_.into(),
|
Some(&type_.into()),
|
||||||
vec![(self.resolver.value(arguments.src)?, llvm_type)],
|
vec![(self.resolver.value(arguments.src)?, llvm_type)],
|
||||||
)?;
|
)?;
|
||||||
Ok(())
|
Ok(())
|
||||||
|
@ -2031,7 +2053,7 @@ impl<'a> MethodEmitContext<'a> {
|
||||||
self.emit_intrinsic(
|
self.emit_intrinsic(
|
||||||
unsafe { CStr::from_bytes_with_nul_unchecked(intrinsic.as_bytes()) },
|
unsafe { CStr::from_bytes_with_nul_unchecked(intrinsic.as_bytes()) },
|
||||||
Some(arguments.dst),
|
Some(arguments.dst),
|
||||||
&data.type_().into(),
|
Some(&data.type_().into()),
|
||||||
vec![
|
vec![
|
||||||
(self.resolver.value(arguments.src1)?, llvm_type),
|
(self.resolver.value(arguments.src1)?, llvm_type),
|
||||||
(self.resolver.value(arguments.src2)?, llvm_type),
|
(self.resolver.value(arguments.src2)?, llvm_type),
|
||||||
|
@ -2058,7 +2080,7 @@ impl<'a> MethodEmitContext<'a> {
|
||||||
self.emit_intrinsic(
|
self.emit_intrinsic(
|
||||||
unsafe { CStr::from_bytes_with_nul_unchecked(intrinsic.as_bytes()) },
|
unsafe { CStr::from_bytes_with_nul_unchecked(intrinsic.as_bytes()) },
|
||||||
Some(arguments.dst),
|
Some(arguments.dst),
|
||||||
&data.type_().into(),
|
Some(&data.type_().into()),
|
||||||
vec![
|
vec![
|
||||||
(self.resolver.value(arguments.src1)?, llvm_type),
|
(self.resolver.value(arguments.src1)?, llvm_type),
|
||||||
(self.resolver.value(arguments.src2)?, llvm_type),
|
(self.resolver.value(arguments.src2)?, llvm_type),
|
||||||
|
@ -2076,7 +2098,7 @@ impl<'a> MethodEmitContext<'a> {
|
||||||
self.emit_intrinsic(
|
self.emit_intrinsic(
|
||||||
unsafe { CStr::from_bytes_with_nul_unchecked(intrinsic.as_bytes()) },
|
unsafe { CStr::from_bytes_with_nul_unchecked(intrinsic.as_bytes()) },
|
||||||
Some(arguments.dst),
|
Some(arguments.dst),
|
||||||
&data.type_.into(),
|
Some(&data.type_.into()),
|
||||||
vec![
|
vec![
|
||||||
(
|
(
|
||||||
self.resolver.value(arguments.src1)?,
|
self.resolver.value(arguments.src1)?,
|
||||||
|
@ -2197,12 +2219,49 @@ impl<'a> MethodEmitContext<'a> {
|
||||||
self.emit_intrinsic(
|
self.emit_intrinsic(
|
||||||
unsafe { CStr::from_bytes_with_nul_unchecked(llvm_intrinsic.as_bytes()) },
|
unsafe { CStr::from_bytes_with_nul_unchecked(llvm_intrinsic.as_bytes()) },
|
||||||
Some(arguments.dst),
|
Some(arguments.dst),
|
||||||
&data.type_.into(),
|
Some(&data.type_.into()),
|
||||||
intrinsic_arguments,
|
intrinsic_arguments,
|
||||||
)?;
|
)?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn emit_set_mode(&mut self, mode_reg: ModeRegister) -> Result<(), TranslateError> {
|
||||||
|
let intrinsic = c"llvm.amdgcn.s.setreg";
|
||||||
|
let llvm_i32 = get_scalar_type(self.context, ast::ScalarType::B32);
|
||||||
|
let (hwreg, value) = match mode_reg {
|
||||||
|
ModeRegister::DenormalF32(ftz) => {
|
||||||
|
let (reg, offset, size) = (1, 4, 2u32);
|
||||||
|
let hwreg = reg | (offset << 6) | ((size - 1) << 11);
|
||||||
|
(hwreg, if ftz { 0u32 } else { 3 })
|
||||||
|
}
|
||||||
|
ModeRegister::DenormalF16F64(ftz) => {
|
||||||
|
let (reg, offset, size) = (1, 6, 2u32);
|
||||||
|
let hwreg = reg | (offset << 6) | ((size - 1) << 11);
|
||||||
|
(hwreg, if ftz { 0 } else { 3 })
|
||||||
|
}
|
||||||
|
ModeRegister::DenormalBoth { f32, f16f64 } => {
|
||||||
|
let (reg, offset, size) = (1, 4, 4u32);
|
||||||
|
let hwreg = reg | (offset << 6) | ((size - 1) << 11);
|
||||||
|
let f32 = if f32 { 0 } else { 3 };
|
||||||
|
let f16f64 = if f16f64 { 0 } else { 3 };
|
||||||
|
let value = f32 | f16f64 << 2;
|
||||||
|
(hwreg, value)
|
||||||
|
}
|
||||||
|
ModeRegister::RoundingF32(rounding_mode) => todo!(),
|
||||||
|
ModeRegister::RoundingF16F64(rounding_mode) => todo!(),
|
||||||
|
ModeRegister::RoundingBoth { f32, f16f64 } => todo!(),
|
||||||
|
};
|
||||||
|
let hwreg_llvm = unsafe { LLVMConstInt(llvm_i32, hwreg as _, 0) };
|
||||||
|
let value_llvm = unsafe { LLVMConstInt(llvm_i32, value as _, 0) };
|
||||||
|
self.emit_intrinsic(
|
||||||
|
intrinsic,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
vec![(hwreg_llvm, llvm_i32), (value_llvm, llvm_i32)],
|
||||||
|
)?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
// Currently unused, LLVM 18 (ROCm 6.2) does not support `llvm.set.rounding`
|
// Currently unused, LLVM 18 (ROCm 6.2) does not support `llvm.set.rounding`
|
||||||
// Should be available in LLVM 19
|
// Should be available in LLVM 19
|
||||||
|
|
|
@ -41,14 +41,18 @@ fn run_method<'input>(
|
||||||
})
|
})
|
||||||
.transpose()?;
|
.transpose()?;
|
||||||
Ok(Function2 {
|
Ok(Function2 {
|
||||||
|
body,
|
||||||
return_arguments: method.return_arguments,
|
return_arguments: method.return_arguments,
|
||||||
name: method.name,
|
name: method.name,
|
||||||
input_arguments: method.input_arguments,
|
input_arguments: method.input_arguments,
|
||||||
body,
|
|
||||||
import_as: method.import_as,
|
import_as: method.import_as,
|
||||||
tuning: method.tuning,
|
tuning: method.tuning,
|
||||||
linkage: method.linkage,
|
linkage: method.linkage,
|
||||||
is_kernel: method.is_kernel,
|
is_kernel: method.is_kernel,
|
||||||
|
flush_to_zero_f32: method.flush_to_zero_f32,
|
||||||
|
flush_to_zero_f16f64: method.flush_to_zero_f16f64,
|
||||||
|
roundind_mode_f32: method.roundind_mode_f32,
|
||||||
|
roundind_mode_f16f64: method.roundind_mode_f16f64,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -20,6 +20,10 @@ pub(super) fn run<'a, 'input>(
|
||||||
tuning: Vec::new(),
|
tuning: Vec::new(),
|
||||||
linkage: ast::LinkingDirective::EXTERN,
|
linkage: ast::LinkingDirective::EXTERN,
|
||||||
is_kernel: false,
|
is_kernel: false,
|
||||||
|
flush_to_zero_f32: false,
|
||||||
|
flush_to_zero_f16f64: false,
|
||||||
|
roundind_mode_f32: ptx_parser::RoundingMode::NearestEven,
|
||||||
|
roundind_mode_f16f64: ptx_parser::RoundingMode::NearestEven,
|
||||||
}));
|
}));
|
||||||
sreg_to_function.insert(sreg, name);
|
sreg_to_function.insert(sreg, name);
|
||||||
},
|
},
|
||||||
|
@ -60,16 +64,7 @@ fn run_method<'a, 'input>(
|
||||||
Ok::<_, TranslateError>(result)
|
Ok::<_, TranslateError>(result)
|
||||||
})
|
})
|
||||||
.transpose()?;
|
.transpose()?;
|
||||||
Ok(Function2 {
|
Ok(Function2 { body, ..method })
|
||||||
return_arguments: method.return_arguments,
|
|
||||||
name: method.name,
|
|
||||||
input_arguments: method.input_arguments,
|
|
||||||
body,
|
|
||||||
import_as: method.import_as,
|
|
||||||
tuning: method.tuning,
|
|
||||||
linkage: method.linkage,
|
|
||||||
is_kernel: method.is_kernel,
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn run_statement<'a, 'input>(
|
fn run_statement<'a, 'input>(
|
||||||
|
|
|
@ -64,16 +64,7 @@ fn run_method<'a, 'input>(
|
||||||
Ok::<_, TranslateError>(result)
|
Ok::<_, TranslateError>(result)
|
||||||
})
|
})
|
||||||
.transpose()?;
|
.transpose()?;
|
||||||
Ok(Function2 {
|
Ok(Function2 { body, ..method })
|
||||||
return_arguments: method.return_arguments,
|
|
||||||
name: method.name,
|
|
||||||
input_arguments: method.input_arguments,
|
|
||||||
body,
|
|
||||||
import_as: method.import_as,
|
|
||||||
tuning: method.tuning,
|
|
||||||
linkage: method.linkage,
|
|
||||||
is_kernel: method.is_kernel,
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn run_statement<'a, 'input>(
|
fn run_statement<'a, 'input>(
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
use super::BrachCondition;
|
use super::BrachCondition;
|
||||||
use super::Directive2;
|
use super::Directive2;
|
||||||
use super::Function2;
|
use super::Function2;
|
||||||
|
use super::ModeRegister;
|
||||||
use super::SpirvWord;
|
use super::SpirvWord;
|
||||||
use super::Statement;
|
use super::Statement;
|
||||||
use super::TranslateError;
|
use super::TranslateError;
|
||||||
|
@ -18,6 +19,7 @@ use rustc_hash::FxHashSet;
|
||||||
use std::hash::Hash;
|
use std::hash::Hash;
|
||||||
use std::iter;
|
use std::iter;
|
||||||
use std::mem;
|
use std::mem;
|
||||||
|
use std::u32;
|
||||||
use strum::EnumCount;
|
use strum::EnumCount;
|
||||||
use strum_macros::{EnumCount, VariantArray};
|
use strum_macros::{EnumCount, VariantArray};
|
||||||
|
|
||||||
|
@ -36,6 +38,13 @@ impl DenormalMode {
|
||||||
DenormalMode::Preserve
|
DenormalMode::Preserve
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn to_ftz(self) -> bool {
|
||||||
|
match self {
|
||||||
|
DenormalMode::FlushToZero => true,
|
||||||
|
DenormalMode::Preserve => false,
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Into<usize> for DenormalMode {
|
impl Into<usize> for DenormalMode {
|
||||||
|
@ -94,20 +103,19 @@ impl InstructionModes {
|
||||||
_ => {}
|
_ => {}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
fn set_if_some<T: Copy>(source: &mut Option<T>, value: Option<T>) {
|
fn set_if_any<T: Copy>(source: &mut Option<T>, value: Option<T>) {
|
||||||
match (source, value) {
|
if let Some(x) = value {
|
||||||
(Some(ref mut x), Some(y)) => *x = y,
|
*source = Some(x);
|
||||||
_ => {}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
set_if_none(&mut entry.denormal_f32, self.denormal_f32);
|
set_if_none(&mut entry.denormal_f32, self.denormal_f32);
|
||||||
set_if_none(&mut entry.denormal_f16f64, self.denormal_f16f64);
|
set_if_none(&mut entry.denormal_f16f64, self.denormal_f16f64);
|
||||||
set_if_none(&mut entry.rounding_f32, self.rounding_f32);
|
set_if_none(&mut entry.rounding_f32, self.rounding_f32);
|
||||||
set_if_none(&mut entry.rounding_f16f64, self.rounding_f16f64);
|
set_if_none(&mut entry.rounding_f16f64, self.rounding_f16f64);
|
||||||
set_if_some(&mut exit.denormal_f32, self.denormal_f32);
|
set_if_any(&mut exit.denormal_f32, self.denormal_f32);
|
||||||
set_if_some(&mut exit.denormal_f16f64, self.denormal_f16f64);
|
set_if_any(&mut exit.denormal_f16f64, self.denormal_f16f64);
|
||||||
set_if_some(&mut exit.rounding_f32, self.rounding_f32);
|
set_if_any(&mut exit.rounding_f32, self.rounding_f32);
|
||||||
set_if_some(&mut exit.rounding_f16f64, self.rounding_f16f64);
|
set_if_any(&mut exit.rounding_f16f64, self.rounding_f16f64);
|
||||||
}
|
}
|
||||||
|
|
||||||
fn none() -> Self {
|
fn none() -> Self {
|
||||||
|
@ -209,18 +217,12 @@ impl InstructionModes {
|
||||||
flush_to_zero.map(DenormalMode::from_ftz),
|
flush_to_zero.map(DenormalMode::from_ftz),
|
||||||
Some(RoundingMode::from_ast(rounding)),
|
Some(RoundingMode::from_ast(rounding)),
|
||||||
),
|
),
|
||||||
ast::CvtMode::SignedFromFP {
|
// float to int contains rounding field, but it's not a rounding
|
||||||
flush_to_zero,
|
// mode but rather round-to-int operation that will be applied
|
||||||
rounding,
|
ast::CvtMode::SignedFromFP { flush_to_zero, .. }
|
||||||
|
| ast::CvtMode::UnsignedFromFP { flush_to_zero, .. } => {
|
||||||
|
Self::new(cvt.from, flush_to_zero.map(DenormalMode::from_ftz), None)
|
||||||
}
|
}
|
||||||
| ast::CvtMode::UnsignedFromFP {
|
|
||||||
flush_to_zero,
|
|
||||||
rounding,
|
|
||||||
} => Self::new(
|
|
||||||
cvt.from,
|
|
||||||
flush_to_zero.map(DenormalMode::from_ftz),
|
|
||||||
Some(RoundingMode::from_ast(rounding)),
|
|
||||||
),
|
|
||||||
ast::CvtMode::FPFromSigned(rnd) | ast::CvtMode::FPFromUnsigned(rnd) => {
|
ast::CvtMode::FPFromSigned(rnd) | ast::CvtMode::FPFromUnsigned(rnd) => {
|
||||||
Self::new(cvt.to, None, Some(RoundingMode::from_ast(rnd)))
|
Self::new(cvt.to, None, Some(RoundingMode::from_ast(rnd)))
|
||||||
}
|
}
|
||||||
|
@ -263,22 +265,15 @@ impl ControlFlowGraph {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn set_modes(&mut self, node: NodeIndex, entry: InstructionModes, exit: InstructionModes) {
|
fn set_modes(&mut self, node: NodeIndex, entry: InstructionModes, exit: InstructionModes) {
|
||||||
self.graph[node].denormal_f32 = Mode {
|
let node = &mut self.graph[node];
|
||||||
entry: entry.denormal_f32.map(ExtendedMode::BasicBlock),
|
node.denormal_f32.entry = entry.denormal_f32.map(ExtendedMode::BasicBlock);
|
||||||
exit: exit.denormal_f32.map(ExtendedMode::BasicBlock),
|
node.denormal_f16f64.entry = entry.denormal_f16f64.map(ExtendedMode::BasicBlock);
|
||||||
};
|
node.rounding_f32.entry = entry.rounding_f32.map(ExtendedMode::BasicBlock);
|
||||||
self.graph[node].denormal_f16f64 = Mode {
|
node.rounding_f16f64.entry = entry.rounding_f16f64.map(ExtendedMode::BasicBlock);
|
||||||
entry: entry.denormal_f16f64.map(ExtendedMode::BasicBlock),
|
node.denormal_f32.exit = exit.denormal_f32.map(ExtendedMode::BasicBlock);
|
||||||
exit: exit.denormal_f16f64.map(ExtendedMode::BasicBlock),
|
node.denormal_f16f64.exit = exit.denormal_f16f64.map(ExtendedMode::BasicBlock);
|
||||||
};
|
node.rounding_f32.exit = exit.rounding_f32.map(ExtendedMode::BasicBlock);
|
||||||
self.graph[node].rounding_f32 = Mode {
|
node.rounding_f16f64.exit = exit.rounding_f16f64.map(ExtendedMode::BasicBlock);
|
||||||
entry: entry.rounding_f32.map(ExtendedMode::BasicBlock),
|
|
||||||
exit: exit.rounding_f32.map(ExtendedMode::BasicBlock),
|
|
||||||
};
|
|
||||||
self.graph[node].rounding_f16f64 = Mode {
|
|
||||||
entry: entry.rounding_f16f64.map(ExtendedMode::BasicBlock),
|
|
||||||
exit: exit.rounding_f16f64.map(ExtendedMode::BasicBlock),
|
|
||||||
};
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -343,7 +338,7 @@ trait EnumTuple {
|
||||||
|
|
||||||
pub(crate) fn run<'input>(
|
pub(crate) fn run<'input>(
|
||||||
flat_resolver: &mut super::GlobalStringIdentResolver2<'input>,
|
flat_resolver: &mut super::GlobalStringIdentResolver2<'input>,
|
||||||
directives: Vec<super::Directive2<ast::Instruction<SpirvWord>, super::SpirvWord>>,
|
mut directives: Vec<super::Directive2<ast::Instruction<SpirvWord>, super::SpirvWord>>,
|
||||||
) -> Result<Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
|
) -> Result<Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
|
||||||
let mut cfg = ControlFlowGraph::new();
|
let mut cfg = ControlFlowGraph::new();
|
||||||
for directive in directives.iter() {
|
for directive in directives.iter() {
|
||||||
|
@ -351,42 +346,39 @@ pub(crate) fn run<'input>(
|
||||||
super::Directive2::Method(Function2 {
|
super::Directive2::Method(Function2 {
|
||||||
name,
|
name,
|
||||||
body: Some(body),
|
body: Some(body),
|
||||||
|
is_kernel,
|
||||||
..
|
..
|
||||||
}) => {
|
}) => {
|
||||||
let mut basic_block = Some(cfg.add_entry_basic_block(*name));
|
// TODO: implement for non-kernels
|
||||||
let mut entry = InstructionModes::none();
|
if !*is_kernel {
|
||||||
let mut exit = InstructionModes::none();
|
todo!()
|
||||||
|
}
|
||||||
|
let entry_index = cfg.add_entry_basic_block(*name);
|
||||||
|
let mut bb_state = BasicBlockState::new(&mut cfg);
|
||||||
|
let mut body_iter = body.iter();
|
||||||
|
match body_iter.next() {
|
||||||
|
Some(Statement::Label(label)) => {
|
||||||
|
bb_state.cfg.add_jump(entry_index, *label);
|
||||||
|
bb_state.start(*label);
|
||||||
|
}
|
||||||
|
_ => return Err(error_unreachable()),
|
||||||
|
};
|
||||||
for statement in body.iter() {
|
for statement in body.iter() {
|
||||||
match statement {
|
match statement {
|
||||||
Statement::Instruction(ast::Instruction::Bra { arguments }) => {
|
Statement::Instruction(ast::Instruction::Bra { arguments }) => {
|
||||||
let bb_index = basic_block.ok_or_else(error_unreachable)?;
|
bb_state.end(&[arguments.src]);
|
||||||
cfg.add_jump(bb_index, arguments.src);
|
|
||||||
cfg.set_modes(
|
|
||||||
bb_index,
|
|
||||||
mem::replace(&mut entry, InstructionModes::none()),
|
|
||||||
mem::replace(&mut exit, InstructionModes::none()),
|
|
||||||
);
|
|
||||||
basic_block = None;
|
|
||||||
}
|
}
|
||||||
Statement::Label(label) => {
|
Statement::Label(label) => {
|
||||||
basic_block = Some(cfg.get_or_add_basic_block(*label));
|
bb_state.start(*label);
|
||||||
}
|
}
|
||||||
Statement::Conditional(BrachCondition {
|
Statement::Conditional(BrachCondition {
|
||||||
if_true, if_false, ..
|
if_true, if_false, ..
|
||||||
}) => {
|
}) => {
|
||||||
let bb_index = basic_block.ok_or_else(error_unreachable)?;
|
bb_state.end(&[*if_true, *if_false]);
|
||||||
cfg.add_jump(bb_index, *if_true);
|
|
||||||
cfg.add_jump(bb_index, *if_false);
|
|
||||||
cfg.set_modes(
|
|
||||||
bb_index,
|
|
||||||
mem::replace(&mut entry, InstructionModes::none()),
|
|
||||||
mem::replace(&mut exit, InstructionModes::none()),
|
|
||||||
);
|
|
||||||
basic_block = None;
|
|
||||||
}
|
}
|
||||||
Statement::Instruction(instruction) => {
|
Statement::Instruction(instruction) => {
|
||||||
let modes = get_modes(instruction);
|
let modes = get_modes(instruction);
|
||||||
modes.fold_into(&mut entry, &mut exit);
|
bb_state.append(modes);
|
||||||
}
|
}
|
||||||
_ => {}
|
_ => {}
|
||||||
}
|
}
|
||||||
|
@ -395,7 +387,370 @@ pub(crate) fn run<'input>(
|
||||||
_ => {}
|
_ => {}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
todo!()
|
let denormal_f32 = compute_single_mode(&cfg, |node| node.denormal_f32);
|
||||||
|
let denormal_f16f64 = compute_single_mode(&cfg, |node| node.denormal_f16f64);
|
||||||
|
let rounding_f32 = compute_single_mode(&cfg, |node| node.rounding_f32);
|
||||||
|
let rounding_f16f64 = compute_single_mode(&cfg, |node| node.rounding_f16f64);
|
||||||
|
let denormal_f32 = optimize::<DenormalMode, { DenormalMode::COUNT }>(denormal_f32);
|
||||||
|
let denormal_f16f64 = optimize::<DenormalMode, { DenormalMode::COUNT }>(denormal_f16f64);
|
||||||
|
let rounding_f32 = optimize::<RoundingMode, { RoundingMode::COUNT }>(rounding_f32);
|
||||||
|
let rounding_f16f64 = optimize::<RoundingMode, { RoundingMode::COUNT }>(rounding_f16f64);
|
||||||
|
insert_mode_control(
|
||||||
|
flat_resolver,
|
||||||
|
&mut directives,
|
||||||
|
&cfg,
|
||||||
|
denormal_f32,
|
||||||
|
denormal_f16f64,
|
||||||
|
rounding_f32,
|
||||||
|
rounding_f16f64,
|
||||||
|
)?;
|
||||||
|
Ok(directives)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn insert_mode_control<'input>(
|
||||||
|
flat_resolver: &mut super::GlobalStringIdentResolver2<'input>,
|
||||||
|
directives: &mut [Directive2<ast::Instruction<SpirvWord>, SpirvWord>],
|
||||||
|
cfg: &ControlFlowGraph,
|
||||||
|
denormal_f32: ModeInsertions<DenormalMode>,
|
||||||
|
denormal_f16f64: ModeInsertions<DenormalMode>,
|
||||||
|
rounding_f32: ModeInsertions<RoundingMode>,
|
||||||
|
rounding_f16f64: ModeInsertions<RoundingMode>,
|
||||||
|
) -> Result<(), TranslateError> {
|
||||||
|
for directive in directives.iter_mut() {
|
||||||
|
let body_ptr = match directive {
|
||||||
|
Directive2::Variable(..) | Directive2::Method(Function2 { body: None, .. }) => continue,
|
||||||
|
Directive2::Method(Function2 {
|
||||||
|
name,
|
||||||
|
body: Some(body),
|
||||||
|
flush_to_zero_f32,
|
||||||
|
flush_to_zero_f16f64,
|
||||||
|
roundind_mode_f32: rounding_mode_f32,
|
||||||
|
roundind_mode_f16f64: rounding_mode_f16f64,
|
||||||
|
..
|
||||||
|
}) => {
|
||||||
|
*flush_to_zero_f32 = denormal_f32
|
||||||
|
.kernels
|
||||||
|
.get(name)
|
||||||
|
.copied()
|
||||||
|
.unwrap_or(DenormalMode::default())
|
||||||
|
.to_ftz();
|
||||||
|
*flush_to_zero_f16f64 = denormal_f16f64
|
||||||
|
.kernels
|
||||||
|
.get(name)
|
||||||
|
.copied()
|
||||||
|
.unwrap_or(DenormalMode::default())
|
||||||
|
.to_ftz();
|
||||||
|
*rounding_mode_f32 = rounding_f32
|
||||||
|
.kernels
|
||||||
|
.get(name)
|
||||||
|
.copied()
|
||||||
|
.unwrap_or(RoundingMode::default())
|
||||||
|
.to_ast();
|
||||||
|
*rounding_mode_f16f64 = rounding_f16f64
|
||||||
|
.kernels
|
||||||
|
.get(name)
|
||||||
|
.copied()
|
||||||
|
.unwrap_or(RoundingMode::default())
|
||||||
|
.to_ast();
|
||||||
|
body
|
||||||
|
}
|
||||||
|
};
|
||||||
|
let mut old_body = mem::replace(body_ptr, Vec::new());
|
||||||
|
let mut result = Vec::with_capacity(old_body.len());
|
||||||
|
let mut bb_state = BasicBlockControlState::new(
|
||||||
|
&denormal_f32,
|
||||||
|
&denormal_f16f64,
|
||||||
|
&rounding_f32,
|
||||||
|
&rounding_f16f64,
|
||||||
|
);
|
||||||
|
for statement in old_body.into_iter() {
|
||||||
|
match &statement {
|
||||||
|
Statement::Label(label) => {
|
||||||
|
bb_state.start(*label);
|
||||||
|
}
|
||||||
|
Statement::Instruction(instruction) => {
|
||||||
|
let modes = get_modes(&instruction);
|
||||||
|
bb_state.insert(&mut result, modes)?;
|
||||||
|
}
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
result.push(statement);
|
||||||
|
}
|
||||||
|
*body_ptr = result;
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
struct BasicBlockControlState<'a> {
|
||||||
|
global_denormal_f32: &'a ModeInsertions<DenormalMode>,
|
||||||
|
global_denormal_f16f64: &'a ModeInsertions<DenormalMode>,
|
||||||
|
global_rounding_f32: &'a ModeInsertions<RoundingMode>,
|
||||||
|
global_rounding_f16f64: &'a ModeInsertions<RoundingMode>,
|
||||||
|
basic_block: SpirvWord,
|
||||||
|
denormal_f32: RegisterState<bool>,
|
||||||
|
denormal_f16f64: RegisterState<bool>,
|
||||||
|
foldable_rounding_f32: Option<usize>,
|
||||||
|
foldable_rounding_f16f64: Option<usize>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// What is known about a mode register at the current point in a basic block.
#[derive(Clone, Copy)]
enum RegisterState<T> {
    /// Produced by `new(false)`.
    // NOTE(review): presumably "the value flowing in from predecessors is
    // acceptable" -- confirm against `insert_one`.
    Inherited,
    /// Produced by `empty()` and `new(true)`.
    Unknown,
    /// A known value, optionally paired with an index.
    // NOTE(review): the index looks like the position of the statement that
    // set the value in the output body -- confirm against the users.
    Value(Option<usize>, T),
}

impl<T> RegisterState<T> {
    /// State used before any basic block has been started.
    fn empty() -> Self {
        RegisterState::Unknown
    }

    /// Initial state for a basic block: `Unknown` when `must_insert` is set,
    /// `Inherited` otherwise.
    fn new(must_insert: bool) -> Self {
        match must_insert {
            true => RegisterState::Unknown,
            false => RegisterState::Inherited,
        }
    }
}
|
||||||
|
|
||||||
|
impl<'a> BasicBlockControlState<'a> {
|
||||||
|
fn new(
|
||||||
|
global_denormal_f32: &'a ModeInsertions<DenormalMode>,
|
||||||
|
global_denormal_f16f64: &'a ModeInsertions<DenormalMode>,
|
||||||
|
global_rounding_f32: &'a ModeInsertions<RoundingMode>,
|
||||||
|
global_rounding_f16f64: &'a ModeInsertions<RoundingMode>,
|
||||||
|
) -> Self {
|
||||||
|
BasicBlockControlState {
|
||||||
|
global_denormal_f32,
|
||||||
|
global_denormal_f16f64,
|
||||||
|
global_rounding_f32,
|
||||||
|
global_rounding_f16f64,
|
||||||
|
basic_block: SpirvWord(u32::MAX),
|
||||||
|
denormal_f32: RegisterState::empty(),
|
||||||
|
denormal_f16f64: RegisterState::empty(),
|
||||||
|
foldable_rounding_f32: None,
|
||||||
|
foldable_rounding_f16f64: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn start(&mut self, label: SpirvWord) {
|
||||||
|
self.denormal_f32 =
|
||||||
|
RegisterState::new(self.global_denormal_f32.basic_blocks.contains(&label));
|
||||||
|
self.denormal_f32 =
|
||||||
|
RegisterState::new(self.global_denormal_f16f64.basic_blocks.contains(&label));
|
||||||
|
self.foldable_rounding_f32 = None;
|
||||||
|
self.foldable_rounding_f16f64 = None;
|
||||||
|
self.basic_block = label;
|
||||||
|
}
|
||||||
|
|
||||||
|
fn add_or_fold_mode_set(
|
||||||
|
&mut self,
|
||||||
|
result: &mut Vec<Statement<ast::Instruction<SpirvWord>, SpirvWord>>,
|
||||||
|
new_mode: bool,
|
||||||
|
) -> Option<usize> {
|
||||||
|
// try and fold into the other mode set
|
||||||
|
if let RegisterState::Value(Some(other_index), other_value) = self.denormal_f16f64 {
|
||||||
|
if let Some(Statement::SetMode(ModeRegister::DenormalF16F64(_))) =
|
||||||
|
result.get_mut(other_index)
|
||||||
|
{
|
||||||
|
result[other_index] = Statement::SetMode(ModeRegister::DenormalBoth {
|
||||||
|
f32: new_mode,
|
||||||
|
f16f64: other_value,
|
||||||
|
});
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
result.push(Statement::SetMode(ModeRegister::DenormalF32(new_mode)));
|
||||||
|
Some(result.len() - 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn insert(
|
||||||
|
&mut self,
|
||||||
|
result: &mut Vec<Statement<ast::Instruction<SpirvWord>, SpirvWord>>,
|
||||||
|
modes: InstructionModes,
|
||||||
|
) -> Result<(), TranslateError> {
|
||||||
|
self.insert_one::<DenormalF32View>(result, modes.denormal_f32.map(DenormalMode::to_ftz))?;
|
||||||
|
self.insert_one::<DenormalF16F64View>(
|
||||||
|
result,
|
||||||
|
modes.denormal_f16f64.map(DenormalMode::to_ftz),
|
||||||
|
)?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn insert_one<View: ModeView>(
|
||||||
|
&mut self,
|
||||||
|
result: &mut Vec<Statement<ast::Instruction<SpirvWord>, SpirvWord>>,
|
||||||
|
mode: Option<View::Value>,
|
||||||
|
) -> Result<(), TranslateError> {
|
||||||
|
if let Some(new_mode) = mode {
|
||||||
|
let register_state = View::get_register(self);
|
||||||
|
match register_state {
|
||||||
|
RegisterState::Inherited => {
|
||||||
|
View::set_register(self, RegisterState::Value(None, new_mode));
|
||||||
|
}
|
||||||
|
RegisterState::Unknown => {
|
||||||
|
View::set_register(
|
||||||
|
self,
|
||||||
|
RegisterState::Value(
|
||||||
|
Some(self.add_or_fold_mode_set2::<View>(result, new_mode)),
|
||||||
|
new_mode,
|
||||||
|
),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
RegisterState::Value(_, old_value) => {
|
||||||
|
if new_mode == old_value {
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
View::set_register(
|
||||||
|
self,
|
||||||
|
RegisterState::Value(
|
||||||
|
Some(self.add_or_fold_mode_set2::<View>(result, new_mode)),
|
||||||
|
new_mode,
|
||||||
|
),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return the index of the last insertion of SetMode with this mode
|
||||||
|
fn add_or_fold_mode_set2<View: ModeView>(
|
||||||
|
&self,
|
||||||
|
result: &mut Vec<Statement<ast::Instruction<SpirvWord>, SpirvWord>>,
|
||||||
|
new_mode: View::Value,
|
||||||
|
) -> usize {
|
||||||
|
// try and fold into the other mode set in struction
|
||||||
|
if let RegisterState::Value(Some(twin_index), _) = View::TwinView::get_register(self) {
|
||||||
|
if let Some(Statement::SetMode(register_mode)) = result.get_mut(twin_index) {
|
||||||
|
if let Some(twin_mode) = View::TwinView::get_single_mode(register_mode) {
|
||||||
|
*register_mode = View::new_mode(new_mode, Some(twin_mode));
|
||||||
|
return twin_index;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
result.push(Statement::SetMode(View::new_mode(new_mode, None)));
|
||||||
|
result.len() - 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
trait ModeView {
|
||||||
|
type Value: PartialEq + Eq + Copy + Clone;
|
||||||
|
type TwinView: ModeView<Value = Self::Value>;
|
||||||
|
|
||||||
|
fn get_register(bb: &BasicBlockControlState) -> RegisterState<Self::Value>;
|
||||||
|
fn set_register(bb: &mut BasicBlockControlState, reg: RegisterState<Self::Value>);
|
||||||
|
fn new_mode(t: Self::Value, other: Option<Self::Value>) -> ModeRegister;
|
||||||
|
fn get_single_mode(reg: &ModeRegister) -> Option<Self::Value>;
|
||||||
|
}
|
||||||
|
|
||||||
|
struct DenormalF32View;
|
||||||
|
|
||||||
|
impl ModeView for DenormalF32View {
|
||||||
|
type Value = bool;
|
||||||
|
type TwinView = DenormalF16F64View;
|
||||||
|
|
||||||
|
fn get_register(bb: &BasicBlockControlState) -> RegisterState<Self::Value> {
|
||||||
|
bb.denormal_f32
|
||||||
|
}
|
||||||
|
|
||||||
|
fn set_register(bb: &mut BasicBlockControlState, reg: RegisterState<Self::Value>) {
|
||||||
|
bb.denormal_f32 = reg;
|
||||||
|
}
|
||||||
|
|
||||||
|
fn new_mode(f32: Self::Value, f16f64: Option<Self::Value>) -> ModeRegister {
|
||||||
|
match f16f64 {
|
||||||
|
Some(f16f64) => ModeRegister::DenormalBoth { f32, f16f64 },
|
||||||
|
None => ModeRegister::DenormalF32(f32),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_single_mode(reg: &ModeRegister) -> Option<Self::Value> {
|
||||||
|
match reg {
|
||||||
|
ModeRegister::DenormalF32(value) => Some(*value),
|
||||||
|
_ => None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct DenormalF16F64View;
|
||||||
|
|
||||||
|
impl ModeView for DenormalF16F64View {
|
||||||
|
type Value = bool;
|
||||||
|
type TwinView = DenormalF32View;
|
||||||
|
|
||||||
|
fn get_register(bb: &BasicBlockControlState) -> RegisterState<Self::Value> {
|
||||||
|
bb.denormal_f16f64
|
||||||
|
}
|
||||||
|
|
||||||
|
fn set_register(bb: &mut BasicBlockControlState, reg: RegisterState<Self::Value>) {
|
||||||
|
bb.denormal_f16f64 = reg;
|
||||||
|
}
|
||||||
|
|
||||||
|
fn new_mode(f16f64: Self::Value, f32: Option<Self::Value>) -> ModeRegister {
|
||||||
|
match f32 {
|
||||||
|
Some(f32) => ModeRegister::DenormalBoth { f16f64, f32 },
|
||||||
|
None => ModeRegister::DenormalF16F64(f16f64),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_single_mode(reg: &ModeRegister) -> Option<Self::Value> {
|
||||||
|
match reg {
|
||||||
|
ModeRegister::DenormalF16F64(value) => Some(*value),
|
||||||
|
_ => None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct BasicBlockState<'a> {
|
||||||
|
cfg: &'a mut ControlFlowGraph,
|
||||||
|
node_index: Option<NodeIndex>,
|
||||||
|
// If it's a kernel basic block then we don't track entry instruction mode
|
||||||
|
entry: InstructionModes,
|
||||||
|
exit: InstructionModes,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a> BasicBlockState<'a> {
|
||||||
|
fn new(cfg: &'a mut ControlFlowGraph) -> BasicBlockState<'a> {
|
||||||
|
Self {
|
||||||
|
cfg,
|
||||||
|
node_index: None,
|
||||||
|
entry: InstructionModes::none(),
|
||||||
|
exit: InstructionModes::none(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn start(&mut self, label: SpirvWord) {
|
||||||
|
self.end(&[]);
|
||||||
|
self.node_index = Some(self.cfg.get_or_add_basic_block(label));
|
||||||
|
}
|
||||||
|
|
||||||
|
fn end(&mut self, jumps: &[SpirvWord]) {
|
||||||
|
let node_index = self.node_index.take();
|
||||||
|
let node_index = match node_index {
|
||||||
|
Some(x) => x,
|
||||||
|
None => return,
|
||||||
|
};
|
||||||
|
for target in jumps {
|
||||||
|
self.cfg.add_jump(node_index, *target);
|
||||||
|
}
|
||||||
|
self.cfg.set_modes(
|
||||||
|
node_index,
|
||||||
|
mem::replace(&mut self.entry, InstructionModes::none()),
|
||||||
|
mem::replace(&mut self.exit, InstructionModes::none()),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn append(&mut self, modes: InstructionModes) {
|
||||||
|
modes.fold_into(&mut self.entry, &mut self.exit);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a> Drop for BasicBlockState<'a> {
|
||||||
|
fn drop(&mut self) {
|
||||||
|
self.end(&[]);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn compute_single_mode<T: Copy + Eq>(
|
fn compute_single_mode<T: Copy + Eq>(
|
||||||
|
@ -424,10 +779,9 @@ fn compute_single_mode<T: Copy + Eq>(
|
||||||
UniqueVec::new(graph.graph.neighbors_directed(index, Direction::Incoming));
|
UniqueVec::new(graph.graph.neighbors_directed(index, Direction::Incoming));
|
||||||
let mut visited = FxHashSet::default();
|
let mut visited = FxHashSet::default();
|
||||||
while let Some(current) = to_visit.pop() {
|
while let Some(current) = to_visit.pop() {
|
||||||
if visited.contains(¤t) {
|
if !visited.insert(current) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
visited.insert(current);
|
|
||||||
let exit_mode = getter(graph.graph.node_weight(current).unwrap()).exit;
|
let exit_mode = getter(graph.graph.node_weight(current).unwrap()).exit;
|
||||||
match exit_mode {
|
match exit_mode {
|
||||||
None => {
|
None => {
|
||||||
|
@ -462,6 +816,7 @@ fn compute_single_mode<T: Copy + Eq>(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
struct PartialModeInsertion<T> {
|
struct PartialModeInsertion<T> {
|
||||||
bb_must_insert_mode: FxHashSet<SpirvWord>,
|
bb_must_insert_mode: FxHashSet<SpirvWord>,
|
||||||
bb_maybe_insert_mode: FxHashMap<SpirvWord, (T, FxHashSet<SpirvWord>)>,
|
bb_maybe_insert_mode: FxHashMap<SpirvWord, (T, FxHashSet<SpirvWord>)>,
|
||||||
|
@ -498,10 +853,11 @@ fn optimize<T: Copy + Into<usize> + strum::VariantArray + std::fmt::Debug, const
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
let mut kernels = FxHashMap::default();
|
let mut kernels = FxHashMap::default();
|
||||||
for (kernel, modes) in kernel_modes {
|
'iterate_kernels: for (kernel, modes) in kernel_modes {
|
||||||
for (mode, var) in modes.into_iter().enumerate() {
|
for (mode, var) in modes.into_iter().enumerate() {
|
||||||
if solution[var] > 0.5 {
|
if solution[var] > 0.5 {
|
||||||
kernels.insert(kernel, T::VARIANTS[mode]);
|
kernels.insert(kernel, T::VARIANTS[mode]);
|
||||||
|
continue 'iterate_kernels;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -19,6 +19,7 @@ mod hoist_globals;
|
||||||
mod insert_explicit_load_store;
|
mod insert_explicit_load_store;
|
||||||
mod insert_ftz_control;
|
mod insert_ftz_control;
|
||||||
mod insert_implicit_conversions2;
|
mod insert_implicit_conversions2;
|
||||||
|
mod normalize_basic_blocks;
|
||||||
mod normalize_identifiers2;
|
mod normalize_identifiers2;
|
||||||
mod normalize_predicates2;
|
mod normalize_predicates2;
|
||||||
mod replace_instructions_with_function_calls;
|
mod replace_instructions_with_function_calls;
|
||||||
|
@ -52,6 +53,7 @@ pub fn to_llvm_module<'input>(ast: ast::Module<'input>) -> Result<Module, Transl
|
||||||
let directives = deparamize_functions::run(&mut flat_resolver, directives)?;
|
let directives = deparamize_functions::run(&mut flat_resolver, directives)?;
|
||||||
let directives = insert_explicit_load_store::run(&mut flat_resolver, directives)?;
|
let directives = insert_explicit_load_store::run(&mut flat_resolver, directives)?;
|
||||||
let directives = insert_implicit_conversions2::run(&mut flat_resolver, directives)?;
|
let directives = insert_implicit_conversions2::run(&mut flat_resolver, directives)?;
|
||||||
|
let directives = normalize_basic_blocks::run(&mut flat_resolver, directives);
|
||||||
let directives = insert_ftz_control::run(&mut flat_resolver, directives)?;
|
let directives = insert_ftz_control::run(&mut flat_resolver, directives)?;
|
||||||
let directives = replace_instructions_with_function_calls::run(&mut flat_resolver, directives)?;
|
let directives = replace_instructions_with_function_calls::run(&mut flat_resolver, directives)?;
|
||||||
let directives = hoist_globals::run(directives)?;
|
let directives = hoist_globals::run(directives)?;
|
||||||
|
@ -197,6 +199,22 @@ enum Statement<I, P: ast::Operand> {
|
||||||
FunctionPointer(FunctionPointerDetails),
|
FunctionPointer(FunctionPointerDetails),
|
||||||
VectorRead(VectorRead),
|
VectorRead(VectorRead),
|
||||||
VectorWrite(VectorWrite),
|
VectorWrite(VectorWrite),
|
||||||
|
SetMode(ModeRegister),
|
||||||
|
}
|
||||||
|
|
||||||
|
enum ModeRegister {
|
||||||
|
DenormalF32(bool),
|
||||||
|
DenormalF16F64(bool),
|
||||||
|
DenormalBoth {
|
||||||
|
f32: bool,
|
||||||
|
f16f64: bool,
|
||||||
|
},
|
||||||
|
RoundingF32(ast::RoundingMode),
|
||||||
|
RoundingF16F64(ast::RoundingMode),
|
||||||
|
RoundingBoth {
|
||||||
|
f32: ast::RoundingMode,
|
||||||
|
f16f64: ast::RoundingMode,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T: ast::Operand<Ident = SpirvWord>> Statement<ast::Instruction<T>, T> {
|
impl<T: ast::Operand<Ident = SpirvWord>> Statement<ast::Instruction<T>, T> {
|
||||||
|
@ -469,6 +487,7 @@ impl<T: ast::Operand<Ident = SpirvWord>> Statement<ast::Instruction<T>, T> {
|
||||||
let src = visitor.visit_ident(src, None, false, false)?;
|
let src = visitor.visit_ident(src, None, false, false)?;
|
||||||
Statement::FunctionPointer(FunctionPointerDetails { dst, src })
|
Statement::FunctionPointer(FunctionPointerDetails { dst, src })
|
||||||
}
|
}
|
||||||
|
Statement::SetMode(mode_register) => Statement::SetMode(mode_register),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -573,6 +592,10 @@ struct Function2<Instruction, Operand: ast::Operand> {
|
||||||
import_as: Option<String>,
|
import_as: Option<String>,
|
||||||
tuning: Vec<ast::TuningDirective>,
|
tuning: Vec<ast::TuningDirective>,
|
||||||
linkage: ast::LinkingDirective,
|
linkage: ast::LinkingDirective,
|
||||||
|
flush_to_zero_f32: bool,
|
||||||
|
flush_to_zero_f16f64: bool,
|
||||||
|
roundind_mode_f32: ast::RoundingMode,
|
||||||
|
roundind_mode_f16f64: ast::RoundingMode,
|
||||||
}
|
}
|
||||||
|
|
||||||
type NormalizedDirective2 = Directive2<
|
type NormalizedDirective2 = Directive2<
|
||||||
|
|
52
ptx/src/pass/normalize_basic_blocks.rs
Normal file
52
ptx/src/pass/normalize_basic_blocks.rs
Normal file
|
@ -0,0 +1,52 @@
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
// This pass normalizes PTX modules in two ways that make the mode computation pass
|
||||||
|
// and the code emission passes much simpler:
|
||||||
|
// * Inserts label at the start of every function
|
||||||
|
// This makes control flow graph simpler in mode computation block: we can
|
||||||
|
// represent kernels as separate nodes with its own separate entry/exit mode
|
||||||
|
// * Inserts label at the start of every basic block
|
||||||
|
|
||||||
|
pub(crate) fn run(
|
||||||
|
flat_resolver: &mut GlobalStringIdentResolver2<'_>,
|
||||||
|
mut directives: Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>,
|
||||||
|
) -> Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>> {
|
||||||
|
for directive in directives.iter_mut() {
|
||||||
|
let body_ref = match directive {
|
||||||
|
Directive2::Method(Function2 {
|
||||||
|
body: Some(body), ..
|
||||||
|
}) => body,
|
||||||
|
_ => continue,
|
||||||
|
};
|
||||||
|
let body = std::mem::replace(body_ref, Vec::new());
|
||||||
|
let mut result = Vec::with_capacity(body.len());
|
||||||
|
let mut needs_label = false;
|
||||||
|
let mut body_iterator = body.into_iter();
|
||||||
|
match body_iterator.next() {
|
||||||
|
Some(Statement::Label(_)) => {}
|
||||||
|
Some(statement) => {
|
||||||
|
result.push(Statement::Label(flat_resolver.register_unnamed(None)));
|
||||||
|
result.push(statement);
|
||||||
|
}
|
||||||
|
None => {}
|
||||||
|
}
|
||||||
|
for statement in body_iterator {
|
||||||
|
if needs_label && !matches!(statement, Statement::Label(..)) {
|
||||||
|
result.push(Statement::Label(flat_resolver.register_unnamed(None)));
|
||||||
|
}
|
||||||
|
needs_label = is_block_terminator(&statement);
|
||||||
|
result.push(statement);
|
||||||
|
}
|
||||||
|
*body_ref = result;
|
||||||
|
}
|
||||||
|
directives
|
||||||
|
}
|
||||||
|
|
||||||
|
fn is_block_terminator(instruction: &Statement<ast::Instruction<SpirvWord>, SpirvWord>) -> bool {
|
||||||
|
match instruction {
|
||||||
|
Statement::Conditional(..)
|
||||||
|
| Statement::Instruction(ast::Instruction::Bra { .. })
|
||||||
|
| Statement::Instruction(ast::Instruction::Ret { .. }) => true,
|
||||||
|
_ => false,
|
||||||
|
}
|
||||||
|
}
|
|
@ -52,9 +52,13 @@ fn run_method<'input, 'b>(
|
||||||
input_arguments,
|
input_arguments,
|
||||||
body,
|
body,
|
||||||
import_as: None,
|
import_as: None,
|
||||||
tuning: method.tuning,
|
|
||||||
linkage,
|
linkage,
|
||||||
is_kernel,
|
is_kernel,
|
||||||
|
tuning: method.tuning,
|
||||||
|
flush_to_zero_f32: false,
|
||||||
|
flush_to_zero_f16f64: false,
|
||||||
|
roundind_mode_f32: ptx_parser::RoundingMode::NearestEven,
|
||||||
|
roundind_mode_f16f64: ptx_parser::RoundingMode::NearestEven,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -36,14 +36,18 @@ fn run_method<'input>(
|
||||||
})
|
})
|
||||||
.transpose()?;
|
.transpose()?;
|
||||||
Ok(Function2 {
|
Ok(Function2 {
|
||||||
|
body,
|
||||||
return_arguments: method.return_arguments,
|
return_arguments: method.return_arguments,
|
||||||
name: method.name,
|
name: method.name,
|
||||||
input_arguments: method.input_arguments,
|
input_arguments: method.input_arguments,
|
||||||
body,
|
|
||||||
import_as: method.import_as,
|
import_as: method.import_as,
|
||||||
tuning: method.tuning,
|
tuning: method.tuning,
|
||||||
linkage: method.linkage,
|
linkage: method.linkage,
|
||||||
is_kernel: method.is_kernel,
|
is_kernel: method.is_kernel,
|
||||||
|
flush_to_zero_f32: method.flush_to_zero_f32,
|
||||||
|
flush_to_zero_f16f64: method.flush_to_zero_f16f64,
|
||||||
|
roundind_mode_f32: method.roundind_mode_f32,
|
||||||
|
roundind_mode_f16f64: method.roundind_mode_f16f64,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -21,6 +21,10 @@ pub(super) fn run<'input>(
|
||||||
tuning: Vec::new(),
|
tuning: Vec::new(),
|
||||||
linkage: ast::LinkingDirective::EXTERN,
|
linkage: ast::LinkingDirective::EXTERN,
|
||||||
is_kernel: false,
|
is_kernel: false,
|
||||||
|
flush_to_zero_f32: false,
|
||||||
|
flush_to_zero_f16f64: false,
|
||||||
|
roundind_mode_f32: ptx_parser::RoundingMode::NearestEven,
|
||||||
|
roundind_mode_f16f64: ptx_parser::RoundingMode::NearestEven,
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
.collect::<Vec<_>>();
|
.collect::<Vec<_>>();
|
||||||
|
|
|
@ -40,16 +40,7 @@ fn run_method<'input>(
|
||||||
.collect::<Result<Vec<_>, _>>()
|
.collect::<Result<Vec<_>, _>>()
|
||||||
})
|
})
|
||||||
.transpose()?;
|
.transpose()?;
|
||||||
Ok(Function2 {
|
Ok(Function2 { body, ..method })
|
||||||
return_arguments: method.return_arguments,
|
|
||||||
name: method.name,
|
|
||||||
input_arguments: method.input_arguments,
|
|
||||||
body,
|
|
||||||
import_as: method.import_as,
|
|
||||||
tuning: method.tuning,
|
|
||||||
linkage: method.linkage,
|
|
||||||
is_kernel: method.is_kernel,
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn run_statement<'input>(
|
fn run_statement<'input>(
|
||||||
|
|
27
ptx/src/test/spirv_run/malformed_label.ptx
Normal file
27
ptx/src/test/spirv_run/malformed_label.ptx
Normal file
|
@ -0,0 +1,27 @@
|
||||||
|
.version 6.5
|
||||||
|
.target sm_30
|
||||||
|
.address_size 64
|
||||||
|
|
||||||
|
// Test kernel: loads a u64 from `input`, adds 1 and stores the result to `output`.
// The load placed right after `bra BB0;` is deliberately not preceded by a label.
.visible .entry malformed_label(
|
||||||
|
.param .u64 input,
|
||||||
|
.param .u64 output
|
||||||
|
)
|
||||||
|
{
|
||||||
|
.reg .u64 in_addr;
|
||||||
|
.reg .u64 out_addr;
|
||||||
|
.reg .u64 temp;
|
||||||
|
.reg .u64 temp2;
|
||||||
|
|
||||||
|
ld.param.u64 in_addr, [input];
|
||||||
|
ld.param.u64 out_addr, [output];
|
||||||
|
|
||||||
|
bra BB0;
|
||||||
|
// this basic block does not start with a label
|
||||||
|
ld.u64 temp, [out_addr];
|
||||||
|
|
||||||
|
BB0:
|
||||||
|
ld.u64 temp, [in_addr];
|
||||||
|
|
||||||
|
add.u64 temp2, temp, 1;
|
||||||
|
|
||||||
|
st.u64 [out_addr], temp2;
|
||||||
|
|
||||||
|
ret;
|
||||||
|
|
||||||
|
}
|
|
@ -186,6 +186,8 @@ test_ptx!(
|
||||||
[0x800000u32, 0xFFFFFF]
|
[0x800000u32, 0xFFFFFF]
|
||||||
);
|
);
|
||||||
|
|
||||||
|
test_ptx!(malformed_label, [2u64], [3u64]);
|
||||||
|
|
||||||
test_ptx!(assertfail);
|
test_ptx!(assertfail);
|
||||||
test_ptx!(func_ptr);
|
test_ptx!(func_ptr);
|
||||||
test_ptx!(lanemask_lt);
|
test_ptx!(lanemask_lt);
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue