Parse ld, add, ret

This commit is contained in:
Andrzej Janik 2024-08-16 16:02:26 +02:00
parent 0da45ea7d8
commit 0112880f27
3 changed files with 313 additions and 50 deletions

View file

@ -2,9 +2,10 @@ use gen_impl::parser;
use proc_macro2::{Span, TokenStream};
use quote::{format_ident, quote, ToTokens};
use rustc_hash::{FxHashMap, FxHashSet};
use std::{collections::hash_map, hash::Hash, rc::Rc};
use std::{collections::hash_map, hash::Hash, iter, rc::Rc};
use syn::{
parse_macro_input, punctuated::Punctuated, Ident, ItemEnum, Token, Type, TypePath, Variant,
parse_macro_input, parse_quote, punctuated::Punctuated, Ident, ItemEnum, Token, Type, TypePath,
Variant,
};
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#vectors
@ -176,10 +177,15 @@ impl SingleOpcodeDefinition {
})
.chain(self.arguments.0.iter().map(|arg| {
let name = &arg.ident;
if arg.optional {
quote! { #name : Option<ParsedOperand<'input>> }
let arg_type = if arg.unified {
quote! { (ParsedOperand<'input>, bool) }
} else {
quote! { #name : ParsedOperand<'input> }
quote! { ParsedOperand<'input> }
};
if arg.optional {
quote! { #name : Option<#arg_type> }
} else {
quote! { #name : #arg_type }
}
}))
}
@ -477,7 +483,8 @@ fn emit_parse_function(
#type_name :: #variant => Some(#value),
}
});
let modifier_names = all_modifier.iter().map(|m| m.dot_capitalized());
let modifier_names = iter::once(Ident::new("DotUnified", Span::call_site()))
.chain(all_modifier.iter().map(|m| m.dot_capitalized()));
quote! {
impl<'input> #type_name<'input> {
fn opcode_text(self) -> Option<&'static str> {
@ -550,7 +557,16 @@ fn emit_definition_parser(
}
}
}
DotModifierRef::Direct { type_: Some(_), .. } => { todo!() }
DotModifierRef::Direct { optional: false, type_: Some(type_), name, value } => {
let variable = name.ident();
let variant = value.dot_capitalized();
let parsed_variant = value.variant_capitalized();
quote! {
any.verify(|t| *t == #token_type :: #variant).parse_next(&mut stream)?;
#variable = #type_ :: #parsed_variant;
}
}
DotModifierRef::Direct { optional: true, type_: Some(_), .. } => { todo!() }
DotModifierRef::Indirect { optional, value, .. } => {
let variants = value.alternatives.iter().map(|alt| {
let type_ = value.type_.as_ref().unwrap();
@ -669,7 +685,7 @@ fn emit_definition_parser(
DotModifierRef::Direct {
optional: false,
name,
type_: Some(type_),
type_: Some(_),
..
} => {
let variable = name.ident();
@ -700,11 +716,11 @@ fn emit_definition_parser(
let comma = if idx == 0 {
quote! { empty }
} else {
quote! { any.verify(|t| *t == #token_type::Comma) }
quote! { any.verify(|t| *t == #token_type::Comma).void() }
};
let pre_bracket = if arg.pre_bracket {
quote! {
any.verify(|t| *t == #token_type::LBracket).map(|_| ())
any.verify(|t| *t == #token_type::LBracket).void()
}
} else {
quote! {
@ -713,7 +729,7 @@ fn emit_definition_parser(
};
let pre_pipe = if arg.pre_pipe {
quote! {
any.verify(|t| *t == #token_type::Or).map(|_| ())
any.verify(|t| *t == #token_type::Or).void()
}
} else {
quote! {
@ -736,24 +752,42 @@ fn emit_definition_parser(
};
let post_bracket = if arg.post_bracket {
quote! {
any.verify(|t| *t == #token_type::RBracket).map(|_| ())
any.verify(|t| *t == #token_type::RBracket).void()
}
} else {
quote! {
empty
}
};
let parser = quote! {
(#comma, #pre_bracket, #pre_pipe, #can_be_negated, #operand, #post_bracket)
};
let arg_name = &arg.ident;
if arg.optional {
let unified = if arg.unified {
quote! {
let #arg_name = opt(#parser.map(|(_, _, _, _, name, _)| name)).parse_next(stream)?;
opt(any.verify(|t| *t == #token_type::DotUnified).void()).map(|u| u.is_some())
}
} else {
quote! {
let #arg_name = #parser.map(|(_, _, _, _, name, _)| name).parse_next(stream)?;
empty
}
};
let pattern = quote! {
(#comma, #pre_bracket, #pre_pipe, #can_be_negated, #operand, #post_bracket, #unified)
};
let arg_name = &arg.ident;
let inner_parser = if arg.unified {
quote! {
#pattern.map(|(_, _, _, _, name, _, unified)| (name, unified))
}
} else {
quote! {
#pattern.map(|(_, _, _, _, name, _, _)| name)
}
};
if arg.optional {
quote! {
let #arg_name = opt(#inner_parser).parse_next(stream)?;
}
} else {
quote! {
let #arg_name = #inner_parser.parse_next(stream)?;
}
}
});
@ -812,6 +846,10 @@ fn write_definitions_into_tokens<'a>(
};
variants.push(arg);
}
variants.push(parse_quote! {
#[token(".unified")]
DotUnified
});
(all_opcodes, all_modifiers)
}

View file

@ -25,6 +25,46 @@ pub enum StCacheOperator {
Writethrough,
}
#[derive(Copy, Clone, PartialEq, Eq)]
/// Cache operator for `ld` instructions (PTX `.cop` qualifier).
/// Variants correspond to the raw PTX suffixes via
/// `From<RawLdCacheOperator>`: `.ca` -> Cached, `.cg` -> L2Only,
/// `.cs` -> Streaming, `.lu` -> LastUse, `.cv` -> Uncached.
pub enum LdCacheOperator {
Cached,
L2Only,
Streaming,
LastUse,
Uncached,
}
#[derive(Copy, Clone)]
/// Operand details for arithmetic instructions (e.g. `add`),
/// split by integer vs. floating-point flavor.
pub enum ArithDetails {
Integer(ArithInteger),
Float(ArithFloat),
}
impl ArithDetails {
/// Returns the scalar operand type, regardless of whether the
/// operation is the integer or the floating-point variant.
pub fn type_(&self) -> super::ScalarType {
match self {
ArithDetails::Integer(t) => t.type_,
ArithDetails::Float(arith) => arith.type_,
}
}
}
#[derive(Copy, Clone)]
/// Integer-arithmetic details: operand type plus the `.sat`
/// (saturating) qualifier.
pub struct ArithInteger {
// Scalar operand type (e.g. u32, s32).
pub type_: super::ScalarType,
// True when the instruction carries the `.sat` modifier.
pub saturate: bool,
}
#[derive(Copy, Clone)]
/// Floating-point-arithmetic details: operand type plus the optional
/// `.rnd`, `.ftz` and `.sat` qualifiers.
pub struct ArithFloat {
// Scalar operand type (e.g. f32, f64, f16x2).
pub type_: super::ScalarType,
// `.rnd` rounding mode; None when the instruction omits it.
pub rounding: Option<RoundingMode>,
// `.ftz` flush-to-zero; None where the qualifier is not applicable
// to the operand type (e.g. f64 / bf16 forms).
pub flush_to_zero: Option<bool>,
// True when the instruction carries the `.sat` modifier.
pub saturate: bool,
}
#[derive(Copy, Clone, PartialEq, Eq)]
pub enum LdStQualifier {
Weak,
@ -33,3 +73,11 @@ pub enum LdStQualifier {
Acquire(MemScope),
Release(MemScope),
}
#[derive(PartialEq, Eq, Copy, Clone)]
/// Floating-point rounding mode (PTX `.rnd` qualifier). Variants map
/// from the raw suffixes via `From<RawFloatRounding>`:
/// `.rn` -> NearestEven, `.rz` -> Zero, `.rm` -> NegativeInf,
/// `.rp` -> PositiveInf.
pub enum RoundingMode {
NearestEven,
Zero,
NegativeInf,
PositiveInf,
}

View file

@ -73,7 +73,7 @@ gen::generate_instruction_type!(
},
Add {
type: { data.type_().into() },
data: ArithDetails,
data: ast::ArithDetails,
arguments<T>: {
dst: T,
src1: T,
@ -101,7 +101,7 @@ gen::generate_instruction_type!(
pub struct LdDetails {
pub qualifier: ast::LdStQualifier,
pub state_space: StateSpace,
pub caching: LdCacheOperator,
pub caching: ast::LdCacheOperator,
pub typ: Type,
pub non_coherent: bool,
}
@ -145,15 +145,6 @@ pub enum RoundingMode {
PositiveInf,
}
#[derive(Copy, Clone, PartialEq, Eq)]
pub enum LdCacheOperator {
Cached,
L2Only,
Streaming,
LastUse,
Uncached,
}
#[derive(PartialEq, Eq, Clone, Hash)]
pub enum Type {
// .param.b32 foo;
@ -203,6 +194,18 @@ impl From<RawStCacheOperator> for ast::StCacheOperator {
}
}
/// Converts the raw lexer-level `.cop` token into the semantic AST
/// cache-operator enum. One-to-one, total mapping — a new raw variant
/// will fail to compile here until it is given a semantic meaning.
impl From<RawLdCacheOperator> for ast::LdCacheOperator {
fn from(value: RawLdCacheOperator) -> Self {
match value {
RawLdCacheOperator::Ca => ast::LdCacheOperator::Cached,
RawLdCacheOperator::Cg => ast::LdCacheOperator::L2Only,
RawLdCacheOperator::Cs => ast::LdCacheOperator::Streaming,
RawLdCacheOperator::Lu => ast::LdCacheOperator::LastUse,
RawLdCacheOperator::Cv => ast::LdCacheOperator::Uncached,
}
}
}
impl From<RawLdStQualifier> for ast::LdStQualifier {
fn from(value: RawLdStQualifier) -> Self {
match value {
@ -212,6 +215,17 @@ impl From<RawLdStQualifier> for ast::LdStQualifier {
}
}
/// Converts the raw lexer-level `.rnd` token into the semantic AST
/// rounding mode. One-to-one, total mapping; an exhaustive match so
/// new raw variants cannot be silently dropped.
impl From<RawFloatRounding> for ast::RoundingMode {
fn from(value: RawFloatRounding) -> Self {
match value {
RawFloatRounding::Rn => ast::RoundingMode::NearestEven,
RawFloatRounding::Rz => ast::RoundingMode::Zero,
RawFloatRounding::Rm => ast::RoundingMode::NegativeInf,
RawFloatRounding::Rp => ast::RoundingMode::PositiveInf,
}
}
}
type PtxParserState = Vec<PtxError>;
type PtxParser<'a, 'input> = Stateful<&'a [Token<'input>], PtxParserState>;
@ -334,6 +348,12 @@ fn immediate_value<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<as
.parse_next(stream)
}
/// Parses a function body: a sequence of semicolon-terminated
/// instructions, collected into a `Vec`.
///
/// NOTE(review): `repeat(3..` imposes a minimum of three instructions;
/// a body with fewer will fail to parse. This looks like a hard-coded
/// bound for the current test input in `main` — confirm whether the
/// lower bound should be `0..` for general use.
fn fn_body<'a, 'input>(
stream: &mut PtxParser<'a, 'input>,
) -> PResult<Vec<Instruction<ParsedOperand<'input>>>> {
// Each instruction must be followed by a `;` token.
repeat(3.., terminated(parse_instruction, Token::Semicolon)).parse_next(stream)
}
impl<Ident> ast::ParsedOperand<Ident> {
fn parse<'a, 'input>(
stream: &mut PtxParser<'a, 'input>,
@ -518,7 +538,7 @@ impl<'input, I: Stream<Token = Self> + StreamIsPartial, E: ParserError<I>> Parse
for Token<'input>
{
fn parse_next(&mut self, input: &mut I) -> PResult<Self, E> {
any.parse_next(input)
any.verify(|t| t == self).parse_next(input)
}
}
@ -540,14 +560,14 @@ derive_parser!(
Comma,
#[token(".")]
Dot,
#[token(";")]
Semicolon,
#[regex(r"[a-zA-Z][a-zA-Z0-9_$]*|[_$%][a-zA-Z0-9_$]+", |lex| lex.slice(), priority = 0)]
Ident(&'input str),
#[token("|")]
Or,
#[token("!")]
Not,
#[token(";")]
Semicolon,
#[token("[")]
LBracket,
#[token("]")]
@ -675,23 +695,82 @@ derive_parser!(
// https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-ld
ld{.weak}{.ss}{.cop}{.level::eviction_priority}{.level::cache_hint}{.level::prefetch_size}{.vec}.type d, [a]{.unified}{, cache_policy} => {
todo!()
let (a, unified) = a;
if level_eviction_priority.is_some() || level_cache_hint || level_prefetch_size.is_some() || unified || cache_policy.is_some() {
state.push(PtxError::Todo);
}
Instruction::Ld {
data: LdDetails {
qualifier: weak.unwrap_or(RawLdStQualifier::Weak).into(),
state_space: ss.unwrap_or(StateSpace::Generic),
caching: cop.unwrap_or(RawLdCacheOperator::Ca).into(),
typ: Type::maybe_vector(vec, type_),
non_coherent: false
},
arguments: LdArgs { dst:d, src:a }
}
}
ld.volatile{.ss}{.level::prefetch_size}{.vec}.type d, [a] => {
todo!()
if level_prefetch_size.is_some() {
state.push(PtxError::Todo);
}
Instruction::Ld {
data: LdDetails {
qualifier: volatile.into(),
state_space: ss.unwrap_or(StateSpace::Generic),
caching: ast::LdCacheOperator::Cached,
typ: Type::maybe_vector(vec, type_),
non_coherent: false
},
arguments: LdArgs { dst:d, src:a }
}
}
ld.relaxed.scope{.ss}{.level::eviction_priority}{.level::cache_hint}{.level::prefetch_size}{.vec}.type d, [a]{, cache_policy} => {
todo!()
if level_eviction_priority.is_some() || level_cache_hint || level_prefetch_size.is_some() || cache_policy.is_some() {
state.push(PtxError::Todo);
}
Instruction::Ld {
data: LdDetails {
qualifier: ast::LdStQualifier::Relaxed(scope),
state_space: ss.unwrap_or(StateSpace::Generic),
caching: ast::LdCacheOperator::Cached,
typ: Type::maybe_vector(vec, type_),
non_coherent: false
},
arguments: LdArgs { dst:d, src:a }
}
}
ld.acquire.scope{.ss}{.level::eviction_priority}{.level::cache_hint}{.level::prefetch_size}{.vec}.type d, [a]{, cache_policy} => {
todo!()
if level_eviction_priority.is_some() || level_cache_hint || level_prefetch_size.is_some() || cache_policy.is_some() {
state.push(PtxError::Todo);
}
Instruction::Ld {
data: LdDetails {
qualifier: ast::LdStQualifier::Acquire(scope),
state_space: ss.unwrap_or(StateSpace::Generic),
caching: ast::LdCacheOperator::Cached,
typ: Type::maybe_vector(vec, type_),
non_coherent: false
},
arguments: LdArgs { dst:d, src:a }
}
}
ld.mmio.relaxed.sys{.global}.type d, [a] => {
todo!()
state.push(PtxError::Todo);
Instruction::Ld {
data: LdDetails {
qualifier: ast::LdStQualifier::Relaxed(MemScope::Sys),
state_space: global.unwrap_or(StateSpace::Generic),
caching: ast::LdCacheOperator::Cached,
typ: type_.into(),
non_coherent: false
},
arguments: LdArgs { dst:d, src:a }
}
}
.ss: StateSpace = { .const, .global, .local, .param{::entry, ::func}, .shared{::cta, ::cluster} };
.cop: RawCacheOp = { .ca, .cg, .cs, .lu, .cv };
.cop: RawLdCacheOperator = { .ca, .cg, .cs, .lu, .cv };
.level::eviction_priority: EvictionPriority =
{ .L1::evict_normal, .L1::evict_unchanged, .L1::evict_first, .L1::evict_last, .L1::no_allocate };
.level::cache_hint = { .L2::cache_hint };
@ -702,47 +781,144 @@ derive_parser!(
.u8, .u16, .u32, .u64,
.s8, .s16, .s32, .s64,
.f32, .f64 };
RawLdStQualifier = { .weak, .volatile };
StateSpace = { .global };
// https://docs.nvidia.com/cuda/parallel-thread-execution/#integer-arithmetic-instructions-add
add.type d, a, b => {
todo!()
Instruction::Add {
data: ast::ArithDetails::Integer(
ast::ArithInteger {
type_,
saturate: false
}
),
arguments: AddArgs {
dst: d, src1: a, src2: b
}
}
}
add{.sat}.s32 d, a, b => {
todo!()
Instruction::Add {
data: ast::ArithDetails::Integer(
ast::ArithInteger {
type_: s32,
saturate: sat
}
),
arguments: AddArgs {
dst: d, src1: a, src2: b
}
}
}
.type: ScalarType = { .u16, .u32, .u64,
.s16, .s64,
.u16x2, .s16x2 };
ScalarType = { .s32 };
// https://docs.nvidia.com/cuda/parallel-thread-execution/#floating-point-instructions-add
add{.rnd}{.ftz}{.sat}.f32 d, a, b => {
todo!()
Instruction::Add {
data: ast::ArithDetails::Float(
ast::ArithFloat {
type_: f32,
rounding: rnd.map(Into::into),
flush_to_zero: Some(ftz),
saturate: sat
}
),
arguments: AddArgs {
dst: d, src1: a, src2: b
}
}
}
add{.rnd}.f64 d, a, b => {
todo!()
Instruction::Add {
data: ast::ArithDetails::Float(
ast::ArithFloat {
type_: f64,
rounding: rnd.map(Into::into),
flush_to_zero: None,
saturate: false
}
),
arguments: AddArgs {
dst: d, src1: a, src2: b
}
}
}
.rnd: RawFloatRounding = { .rn, .rz, .rm, .rp };
ScalarType = { .f32, .f64 };
// https://docs.nvidia.com/cuda/parallel-thread-execution/#half-precision-floating-point-instructions-add
add{.rnd}{.ftz}{.sat}.f16 d, a, b => {
todo!()
Instruction::Add {
data: ast::ArithDetails::Float(
ast::ArithFloat {
type_: f16,
rounding: rnd.map(Into::into),
flush_to_zero: Some(ftz),
saturate: sat
}
),
arguments: AddArgs {
dst: d, src1: a, src2: b
}
}
}
add{.rnd}{.ftz}{.sat}.f16x2 d, a, b => {
todo!()
Instruction::Add {
data: ast::ArithDetails::Float(
ast::ArithFloat {
type_: f16x2,
rounding: rnd.map(Into::into),
flush_to_zero: Some(ftz),
saturate: sat
}
),
arguments: AddArgs {
dst: d, src1: a, src2: b
}
}
}
add{.rnd}.bf16 d, a, b => {
todo!()
Instruction::Add {
data: ast::ArithDetails::Float(
ast::ArithFloat {
type_: bf16,
rounding: rnd.map(Into::into),
flush_to_zero: None,
saturate: false
}
),
arguments: AddArgs {
dst: d, src1: a, src2: b
}
}
}
add{.rnd}.bf16x2 d, a, b => {
todo!()
Instruction::Add {
data: ast::ArithDetails::Float(
ast::ArithFloat {
type_: bf16x2,
rounding: rnd.map(Into::into),
flush_to_zero: None,
saturate: false
}
),
arguments: AddArgs {
dst: d, src1: a, src2: b
}
}
}
.rnd: RawFloatRounding = { .rn };
ScalarType = { .f16, .f16x2, .bf16, .bf16x2 };
ret => {
todo!()
ret{.uni} => {
Instruction::Ret { data: RetData { uniform: uni } }
}
);
@ -776,7 +952,8 @@ fn main() {
input: &tokens[..],
state: Vec::new(),
};
parse_instruction(&mut stream).unwrap();
let fn_body = fn_body.parse(stream).unwrap();
println!("{}", fn_body.len());
//parse_prefix(&mut lexer);
let mut parser = &*tokens;
println!("{}", mem::size_of::<Token>());