Support instruction modes (denormal and rounding) on AMD GPUs (#342)

This commit is contained in:
Andrzej Janik 2025-03-17 21:37:26 +01:00 committed by GitHub
parent 867e4728d5
commit d704e92c97
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
123 changed files with 6758 additions and 4120 deletions

View file

@ -1028,9 +1028,16 @@ pub struct ArithInteger {
#[derive(Copy, Clone)]
pub struct ArithFloat {
pub type_: ScalarType,
pub rounding: Option<RoundingMode>,
pub rounding: RoundingMode,
pub flush_to_zero: Option<bool>,
pub saturate: bool,
// From PTX documentation: https://docs.nvidia.com/cuda/parallel-thread-execution/#mixed-precision-floating-point-instructions-add
// Note that an add instruction with an explicit rounding modifier is treated conservatively by
// the code optimizer. An add instruction with no rounding modifier defaults to
// round-to-nearest-even and may be optimized aggressively by the code optimizer. In particular,
// mul/add sequences with no rounding modifiers may be optimized to use fused-multiply-add
// instructions on the target device.
pub is_fusable: bool,
}
#[derive(Copy, Clone, PartialEq, Eq)]
@ -1042,7 +1049,7 @@ pub enum LdStQualifier {
Release(MemScope),
}
#[derive(PartialEq, Eq, Copy, Clone)]
#[derive(PartialEq, Eq, Copy, Clone, Debug)]
pub enum RoundingMode {
NearestEven,
Zero,
@ -1456,6 +1463,7 @@ pub struct CvtDetails {
pub mode: CvtMode,
}
#[derive(Clone, Copy)]
pub enum CvtMode {
// int from int
ZeroExtend,
@ -1474,7 +1482,7 @@ pub enum CvtMode {
flush_to_zero: Option<bool>,
},
FPRound {
integer_rounding: Option<RoundingMode>,
integer_rounding: RoundingMode,
flush_to_zero: Option<bool>,
},
// int from float
@ -1528,7 +1536,7 @@ impl CvtDetails {
flush_to_zero,
},
Ordering::Equal => CvtMode::FPRound {
integer_rounding: rounding,
integer_rounding: rounding.unwrap_or(RoundingMode::NearestEven),
flush_to_zero,
},
Ordering::Greater => {

View file

@ -1909,9 +1909,10 @@ derive_parser!(
data: ast::ArithDetails::Float(
ast::ArithFloat {
type_: f32,
rounding: rnd.map(Into::into),
rounding: rnd.map(Into::into).unwrap_or(ast::RoundingMode::NearestEven),
flush_to_zero: Some(ftz),
saturate: sat
saturate: sat,
is_fusable: rnd.is_none()
}
),
arguments: AddArgs {
@ -1924,9 +1925,10 @@ derive_parser!(
data: ast::ArithDetails::Float(
ast::ArithFloat {
type_: f64,
rounding: rnd.map(Into::into),
rounding: rnd.map(Into::into).unwrap_or(ast::RoundingMode::NearestEven),
flush_to_zero: None,
saturate: false
saturate: false,
is_fusable: rnd.is_none()
}
),
arguments: AddArgs {
@ -1943,9 +1945,10 @@ derive_parser!(
data: ast::ArithDetails::Float(
ast::ArithFloat {
type_: f16,
rounding: rnd.map(Into::into),
rounding: rnd.map(Into::into).unwrap_or(ast::RoundingMode::NearestEven),
flush_to_zero: Some(ftz),
saturate: sat
saturate: sat,
is_fusable: rnd.is_none()
}
),
arguments: AddArgs {
@ -1958,9 +1961,10 @@ derive_parser!(
data: ast::ArithDetails::Float(
ast::ArithFloat {
type_: f16x2,
rounding: rnd.map(Into::into),
rounding: rnd.map(Into::into).unwrap_or(ast::RoundingMode::NearestEven),
flush_to_zero: Some(ftz),
saturate: sat
saturate: sat,
is_fusable: rnd.is_none()
}
),
arguments: AddArgs {
@ -1973,9 +1977,10 @@ derive_parser!(
data: ast::ArithDetails::Float(
ast::ArithFloat {
type_: bf16,
rounding: rnd.map(Into::into),
rounding: rnd.map(Into::into).unwrap_or(ast::RoundingMode::NearestEven),
flush_to_zero: None,
saturate: false
saturate: false,
is_fusable: rnd.is_none()
}
),
arguments: AddArgs {
@ -1988,9 +1993,10 @@ derive_parser!(
data: ast::ArithDetails::Float(
ast::ArithFloat {
type_: bf16x2,
rounding: rnd.map(Into::into),
rounding: rnd.map(Into::into).unwrap_or(ast::RoundingMode::NearestEven),
flush_to_zero: None,
saturate: false
saturate: false,
is_fusable: rnd.is_none()
}
),
arguments: AddArgs {
@ -2035,9 +2041,10 @@ derive_parser!(
data: ast::MulDetails::Float (
ast::ArithFloat {
type_: f32,
rounding: rnd.map(Into::into),
rounding: rnd.map(Into::into).unwrap_or(ast::RoundingMode::NearestEven),
flush_to_zero: Some(ftz),
saturate: sat,
is_fusable: rnd.is_none()
}
),
arguments: MulArgs { dst: d, src1: a, src2: b }
@ -2048,9 +2055,10 @@ derive_parser!(
data: ast::MulDetails::Float (
ast::ArithFloat {
type_: f64,
rounding: rnd.map(Into::into),
rounding: rnd.map(Into::into).unwrap_or(ast::RoundingMode::NearestEven),
flush_to_zero: None,
saturate: false,
is_fusable: rnd.is_none()
}
),
arguments: MulArgs { dst: d, src1: a, src2: b }
@ -2064,9 +2072,10 @@ derive_parser!(
data: ast::MulDetails::Float (
ast::ArithFloat {
type_: f16,
rounding: rnd.map(Into::into),
rounding: rnd.map(Into::into).unwrap_or(ast::RoundingMode::NearestEven),
flush_to_zero: Some(ftz),
saturate: sat,
is_fusable: rnd.is_none()
}
),
arguments: MulArgs { dst: d, src1: a, src2: b }
@ -2077,9 +2086,10 @@ derive_parser!(
data: ast::MulDetails::Float (
ast::ArithFloat {
type_: f16x2,
rounding: rnd.map(Into::into),
rounding: rnd.map(Into::into).unwrap_or(ast::RoundingMode::NearestEven),
flush_to_zero: Some(ftz),
saturate: sat,
is_fusable: rnd.is_none()
}
),
arguments: MulArgs { dst: d, src1: a, src2: b }
@ -2090,9 +2100,10 @@ derive_parser!(
data: ast::MulDetails::Float (
ast::ArithFloat {
type_: bf16,
rounding: rnd.map(Into::into),
rounding: rnd.map(Into::into).unwrap_or(ast::RoundingMode::NearestEven),
flush_to_zero: None,
saturate: false,
is_fusable: rnd.is_none()
}
),
arguments: MulArgs { dst: d, src1: a, src2: b }
@ -2103,9 +2114,10 @@ derive_parser!(
data: ast::MulDetails::Float (
ast::ArithFloat {
type_: bf16x2,
rounding: rnd.map(Into::into),
rounding: rnd.map(Into::into).unwrap_or(ast::RoundingMode::NearestEven),
flush_to_zero: None,
saturate: false,
is_fusable: rnd.is_none()
}
),
arguments: MulArgs { dst: d, src1: a, src2: b }
@ -2389,9 +2401,10 @@ derive_parser!(
data: ast::MadDetails::Float(
ast::ArithFloat {
type_: f32,
rounding: None,
rounding: ast::RoundingMode::NearestEven,
flush_to_zero: Some(ftz),
saturate: sat
saturate: sat,
is_fusable: false
}
),
arguments: MadArgs { dst: d, src1: a, src2: b, src3: c }
@ -2402,9 +2415,10 @@ derive_parser!(
data: ast::MadDetails::Float(
ast::ArithFloat {
type_: f32,
rounding: Some(rnd.into()),
rounding: rnd.into(),
flush_to_zero: Some(ftz),
saturate: sat
saturate: sat,
is_fusable: false
}
),
arguments: MadArgs { dst: d, src1: a, src2: b, src3: c }
@ -2415,9 +2429,10 @@ derive_parser!(
data: ast::MadDetails::Float(
ast::ArithFloat {
type_: f64,
rounding: Some(rnd.into()),
rounding: rnd.into(),
flush_to_zero: None,
saturate: false
saturate: false,
is_fusable: false
}
),
arguments: MadArgs { dst: d, src1: a, src2: b, src3: c }
@ -2432,9 +2447,10 @@ derive_parser!(
ast::Instruction::Fma {
data: ast::ArithFloat {
type_: f32,
rounding: Some(rnd.into()),
rounding: rnd.into(),
flush_to_zero: Some(ftz),
saturate: sat
saturate: sat,
is_fusable: false
},
arguments: FmaArgs { dst: d, src1: a, src2: b, src3: c }
}
@ -2443,9 +2459,10 @@ derive_parser!(
ast::Instruction::Fma {
data: ast::ArithFloat {
type_: f64,
rounding: Some(rnd.into()),
rounding: rnd.into(),
flush_to_zero: None,
saturate: false
saturate: false,
is_fusable: false
},
arguments: FmaArgs { dst: d, src1: a, src2: b, src3: c }
}
@ -2457,9 +2474,10 @@ derive_parser!(
ast::Instruction::Fma {
data: ast::ArithFloat {
type_: f16,
rounding: Some(rnd.into()),
rounding: rnd.into(),
flush_to_zero: Some(ftz),
saturate: sat
saturate: sat,
is_fusable: false
},
arguments: FmaArgs { dst: d, src1: a, src2: b, src3: c }
}
@ -2507,9 +2525,10 @@ derive_parser!(
data: ast::ArithDetails::Float(
ast::ArithFloat {
type_: f32,
rounding: rnd.map(Into::into),
rounding: rnd.map(Into::into).unwrap_or(ast::RoundingMode::NearestEven),
flush_to_zero: Some(ftz),
saturate: sat
saturate: sat,
is_fusable: rnd.is_none()
}
),
arguments: SubArgs { dst: d, src1: a, src2: b }
@ -2520,9 +2539,10 @@ derive_parser!(
data: ast::ArithDetails::Float(
ast::ArithFloat {
type_: f64,
rounding: rnd.map(Into::into),
rounding: rnd.map(Into::into).unwrap_or(ast::RoundingMode::NearestEven),
flush_to_zero: None,
saturate: false
saturate: false,
is_fusable: rnd.is_none()
}
),
arguments: SubArgs { dst: d, src1: a, src2: b }
@ -2536,9 +2556,10 @@ derive_parser!(
data: ast::ArithDetails::Float(
ast::ArithFloat {
type_: f16,
rounding: rnd.map(Into::into),
rounding: rnd.map(Into::into).unwrap_or(ast::RoundingMode::NearestEven),
flush_to_zero: Some(ftz),
saturate: sat
saturate: sat,
is_fusable: rnd.is_none()
}
),
arguments: SubArgs { dst: d, src1: a, src2: b }
@ -2549,9 +2570,10 @@ derive_parser!(
data: ast::ArithDetails::Float(
ast::ArithFloat {
type_: f16x2,
rounding: rnd.map(Into::into),
rounding: rnd.map(Into::into).unwrap_or(ast::RoundingMode::NearestEven),
flush_to_zero: Some(ftz),
saturate: sat
saturate: sat,
is_fusable: rnd.is_none()
}
),
arguments: SubArgs { dst: d, src1: a, src2: b }
@ -2562,9 +2584,10 @@ derive_parser!(
data: ast::ArithDetails::Float(
ast::ArithFloat {
type_: bf16,
rounding: rnd.map(Into::into),
rounding: rnd.map(Into::into).unwrap_or(ast::RoundingMode::NearestEven),
flush_to_zero: None,
saturate: false
saturate: false,
is_fusable: rnd.is_none()
}
),
arguments: SubArgs { dst: d, src1: a, src2: b }
@ -2575,9 +2598,10 @@ derive_parser!(
data: ast::ArithDetails::Float(
ast::ArithFloat {
type_: bf16x2,
rounding: rnd.map(Into::into),
rounding: rnd.map(Into::into).unwrap_or(ast::RoundingMode::NearestEven),
flush_to_zero: None,
saturate: false
saturate: false,
is_fusable: rnd.is_none()
}
),
arguments: SubArgs { dst: d, src1: a, src2: b }
@ -2880,7 +2904,7 @@ derive_parser!(
rsqrt.approx.f64 d, a => {
ast::Instruction::Rsqrt {
data: ast::TypeFtz {
flush_to_zero: None,
flush_to_zero: Some(false),
type_: f64
},
arguments: RsqrtArgs { dst: d, src: a }
@ -2889,7 +2913,7 @@ derive_parser!(
rsqrt.approx.ftz.f64 d, a => {
ast::Instruction::Rsqrt {
data: ast::TypeFtz {
flush_to_zero: None,
flush_to_zero: Some(true),
type_: f64
},
arguments: RsqrtArgs { dst: d, src: a }