diff --git a/Cargo.lock b/Cargo.lock index f627511..c049067 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -820,8 +820,8 @@ dependencies = [ "ptx_parser", "quick-error", "rustc-hash 2.0.0", - "strum", - "strum_macros", + "strum 0.26.3", + "strum_macros 0.26.4", "tempfile", "thiserror 1.0.64", "unwrap_or", @@ -836,6 +836,7 @@ dependencies = [ "logos", "ptx_parser_macros", "rustc-hash 2.0.0", + "strum 0.27.1", "thiserror 1.0.64", "winnow", ] @@ -1047,6 +1048,15 @@ version = "0.26.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8fec0f0aef304996cf250b31b5a10dee7980c85da9d759361292b8bca5a18f06" +[[package]] +name = "strum" +version = "0.27.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f64def088c51c9510a8579e3c5d67c65349dcf755e5479ad3d010aa6454e2c32" +dependencies = [ + "strum_macros 0.27.1", +] + [[package]] name = "strum_macros" version = "0.26.4" @@ -1060,6 +1070,19 @@ dependencies = [ "syn 2.0.89", ] +[[package]] +name = "strum_macros" +version = "0.27.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c77a8c5abcaf0f9ce05d62342b7d298c346515365c36b673df4ebe3ced01fde8" +dependencies = [ + "heck", + "proc-macro2", + "quote", + "rustversion", + "syn 2.0.89", +] + [[package]] name = "syn" version = "1.0.109" diff --git a/compiler/src/error.rs b/compiler/src/error.rs index fc8d004..396f70d 100644 --- a/compiler/src/error.rs +++ b/compiler/src/error.rs @@ -1,6 +1,5 @@ use std::ffi::FromBytesUntilNulError; use std::io; -use std::path::PathBuf; use std::str::Utf8Error; use amd_comgr_sys::amd_comgr_status_s; @@ -10,22 +9,20 @@ use ptx_parser::PtxError; #[derive(Debug, thiserror::Error)] pub enum CompilerError { - #[error("HIP error: {0:?}")] + #[error("HIP error code: {0:?}")] HipError(hipErrorCode_t), - #[error("amd_comgr error: {0:?}")] + #[error("amd_comgr status code: {0:?}")] ComgrError(amd_comgr_status_s), - #[error("Not a regular file: {0}")] - CheckPathError(PathBuf), - #[error("Invalid output type: {0}")] - ParseOutputTypeError(String), - #[error("Error translating PTX: {0:?}")] - PtxTranslateError(TranslateError), - #[error("IO error: {0:?}")] - IoError(io::Error), - #[error("Error parsing file: {0:?}")] - ParseFileError(Utf8Error), - #[error("Error: {0}")] - GenericError(String) + #[error(transparent)] + IoError(#[from] io::Error), + #[error(transparent)] + Utf8Error(#[from] Utf8Error), + #[error("{message}")] + GenericError { + #[source] + cause: Option>, + message: String, + }, } impl From for CompilerError { @@ -42,32 +39,37 @@ impl From for CompilerError { impl From>> for CompilerError { fn from(causes: Vec) -> Self { - let errors: Vec = causes.iter().map(PtxError::to_string).collect(); - let msg = errors.join("\n"); - CompilerError::GenericError(msg) - } -} - -impl From for CompilerError { - fn from(cause: io::Error) -> Self { - CompilerError::IoError(cause) - } -} - -impl From for CompilerError { - fn from(cause: Utf8Error) -> Self { - CompilerError::ParseFileError(cause) + let errors: Vec = causes + .iter() + .map(|e| { + let msg = match e { + PtxError::UnrecognizedStatement(value) + | PtxError::UnrecognizedDirective(value) => value.unwrap_or("").to_string(), + other => other.to_string(), + }; + format!("PtxError::{}: {}", e.as_ref(), msg) + }) + .collect(); + let message = errors.join("\n"); + CompilerError::GenericError { + cause: None, + message, + } } } impl From for CompilerError { fn from(cause: TranslateError) -> Self { - CompilerError::PtxTranslateError(cause) + let message = format!("PTX TranslateError::{}", cause.as_ref()); + let cause = Some(Box::new(cause) as Box); + CompilerError::GenericError { cause, message } } } impl From for CompilerError { fn from(cause: FromBytesUntilNulError) -> Self { - CompilerError::GenericError(format!("{}", cause)) + let message = format!("{}", cause); + let cause = Some(Box::new(cause) as Box); + CompilerError::GenericError { cause, message } } -} \ No newline at end of file +} diff --git a/compiler/src/main.rs b/compiler/src/main.rs index 3a3d7e3..2d3fa8f 100644 --- a/compiler/src/main.rs +++ b/compiler/src/main.rs @@ -1,9 +1,11 @@ use std::env; +use std::error::Error; use std::ffi::{CStr, CString, OsStr}; use std::fs::{self, File}; use std::io::{self, Write}; use std::mem::MaybeUninit; use std::path::{Path, PathBuf}; +use std::process::ExitCode; use std::str::{self, FromStr}; use bpaf::Bpaf; @@ -26,7 +28,17 @@ pub struct Options { ptx_path: String, } -fn main() -> Result<(), CompilerError> { +fn main() -> ExitCode { + main_core().map_or_else( + |e| { + eprintln!("Error: {}", e); + ExitCode::FAILURE + }, + |_| ExitCode::SUCCESS, + ) +} + +fn main_core() -> Result<(), CompilerError> { let opts = options().run(); let output_type = opts.output_type.unwrap_or_default(); @@ -36,7 +48,7 @@ fn main() -> Result<(), CompilerError> { let output_path = match opts.output_path { Some(value) => value, - None => get_output_path(&ptx_path, &output_type)? + None => get_output_path(&ptx_path, &output_type)?, }; check_path(&output_path)?; @@ -48,7 +60,7 @@ fn main() -> Result<(), CompilerError> { OutputType::LlvmIrPreLinked => llvm.llvm_ir, OutputType::LlvmIrLinked => get_linked_bitcode(&llvm)?, OutputType::Elf => get_elf(&llvm)?, - OutputType::Assembly => get_assembly(&llvm)? + OutputType::Assembly => get_assembly(&llvm)?, }; write_to_file(&output, &output_path).map_err(CompilerError::from)?; @@ -56,7 +68,7 @@ fn main() -> Result<(), CompilerError> { } fn ptx_to_llvm(ptx: &str) -> Result { - let ast = ptx_parser::parse_module_checked(ptx).map_err(CompilerError::from).map_err(CompilerError::from)?; + let ast = ptx_parser::parse_module_checked(ptx).map_err(CompilerError::from)?; let module = ptx::to_llvm_module(ast).map_err(CompilerError::from)?; let bitcode = module.llvm_ir.write_bitcode_to_memory().to_vec(); let linked_bitcode = module.linked_bitcode().to_vec(); @@ -82,7 +94,11 @@ fn get_arch() -> Result { unsafe { hipGetDevicePropertiesR0600(dev_props.as_mut_ptr(), 0) }?; let dev_props = unsafe { dev_props.assume_init() }; let arch = dev_props.gcnArchName; - let arch: Vec = arch.to_vec().iter().map(|&v| i8::to_ne_bytes(v)[0]).collect(); + let arch: Vec = arch + .to_vec() + .iter() + .map(|&v| i8::to_ne_bytes(v)[0]) + .collect(); let arch = CStr::from_bytes_until_nul(arch.as_slice())?; Ok(CString::from(arch)) } @@ -94,26 +110,29 @@ fn get_linked_bitcode(llvm: &LLVMArtifacts) -> Result, CompilerError> { fn get_elf(llvm: &LLVMArtifacts) -> Result, CompilerError> { let arch = get_arch()?; - comgr::get_executable_as_bytes(&arch, &llvm.bitcode, &llvm.linked_bitcode).map_err(CompilerError::from) + comgr::get_executable_as_bytes(&arch, &llvm.bitcode, &llvm.linked_bitcode) + .map_err(CompilerError::from) } fn get_assembly(llvm: &LLVMArtifacts) -> Result, CompilerError> { let arch = get_arch()?; - comgr::get_assembly_as_bytes(&arch, &llvm.bitcode, &llvm.linked_bitcode).map_err(CompilerError::from) + comgr::get_assembly_as_bytes(&arch, &llvm.bitcode, &llvm.linked_bitcode) + .map_err(CompilerError::from) } fn check_path(path: &Path) -> Result<(), CompilerError> { if path.try_exists().map_err(CompilerError::from)? && !path.is_file() { - let error = CompilerError::CheckPathError(path.to_path_buf()); + let message = format!("Not a regular file: {:?}", path.to_path_buf()); + let error = CompilerError::GenericError { + cause: None, + message, + }; return Err(error); } Ok(()) } -fn get_output_path( - ptx_path: &PathBuf, - output_type: &OutputType, -) -> Result { +fn get_output_path(ptx_path: &PathBuf, output_type: &OutputType) -> Result { let current_dir = env::current_dir().map_err(CompilerError::from)?; let output_path = current_dir.join( ptx_path @@ -169,7 +188,13 @@ impl FromStr for OutputType { "ll_linked" => Ok(Self::LlvmIrLinked), "elf" => Ok(Self::Elf), "asm" => Ok(Self::Assembly), - _ => Err(CompilerError::ParseOutputTypeError(s.into())), + _ => { + let message = format!("Not a valid output type: {}", s); + Err(CompilerError::GenericError { + cause: None, + message, + }) + } } } } diff --git a/ptx/src/pass/mod.rs b/ptx/src/pass/mod.rs index 77d7e60..2abdac3 100644 --- a/ptx/src/pass/mod.rs +++ b/ptx/src/pass/mod.rs @@ -17,8 +17,8 @@ mod expand_operands; mod fix_special_registers2; mod hoist_globals; mod insert_explicit_load_store; -mod instruction_mode_to_global_mode; mod insert_implicit_conversions2; +mod instruction_mode_to_global_mode; mod normalize_basic_blocks; mod normalize_identifiers2; mod normalize_predicates2; @@ -31,7 +31,7 @@ static ZLUDA_PTX_IMPL: &'static [u8] = include_bytes!("../../lib/zluda_ptx_impl. const ZLUDA_PTX_PREFIX: &'static str = "__zluda_ptx_impl_"; quick_error! { - #[derive(Debug)] + #[derive(Debug, strum_macros::AsRefStr)] pub enum TranslateError { UnknownSymbol {} UntypedSymbol {} diff --git a/ptx_parser/Cargo.toml b/ptx_parser/Cargo.toml index 9032de5..3b96ac0 100644 --- a/ptx_parser/Cargo.toml +++ b/ptx_parser/Cargo.toml @@ -7,11 +7,12 @@ edition = "2021" [lib] [dependencies] +bitflags = "1.2" +derive_more = { version = "1", features = ["display"] } logos = "0.14" +ptx_parser_macros = { path = "../ptx_parser_macros" } +rustc-hash = "2.0.0" +strum = { version = "0.27.1", features = ["derive"] } +thiserror = "1.0" winnow = { version = "0.6.18" } #winnow = { version = "0.6.18", features = ["debug"] } -ptx_parser_macros = { path = "../ptx_parser_macros" } -thiserror = "1.0" -bitflags = "1.2" -rustc-hash = "2.0.0" -derive_more = { version = "1", features = ["display"] } diff --git a/ptx_parser/src/lib.rs b/ptx_parser/src/lib.rs index da46a8c..aea4dc6 100644 --- a/ptx_parser/src/lib.rs +++ b/ptx_parser/src/lib.rs @@ -1240,7 +1240,7 @@ impl ast::ParsedOperand { } } -#[derive(Debug, thiserror::Error)] +#[derive(Debug, thiserror::Error, strum::AsRefStr)] pub enum PtxError<'input> { #[error("{source}")] ParseInt {