Automate generation of HIP bindings

This commit is contained in:
Andrzej Janik 2024-11-22 16:58:55 +00:00
parent 122676bb13
commit 3ec7bffdc5
7 changed files with 12759 additions and 7454 deletions

View file

@ -28,7 +28,8 @@ RUN DEBIAN_FRONTEND=noninteractive apt-get update -y && DEBIAN_FRONTEND=noninter
nvidia-utils-${CUDA_DRIVER} \ nvidia-utils-${CUDA_DRIVER} \
cuda-cudart-dev-${CUDA_PKG_VERSION} \ cuda-cudart-dev-${CUDA_PKG_VERSION} \
cuda-cudart-${CUDA_PKG_VERSION} \ cuda-cudart-${CUDA_PKG_VERSION} \
cuda-profiler-api-${CUDA_PKG_VERSION} cuda-profiler-api-${CUDA_PKG_VERSION} \
cuda-nvcc-${CUDA_PKG_VERSION}
ARG ROCM_VERSION=6.2.2 ARG ROCM_VERSION=6.2.2
RUN mkdir --parents --mode=0755 /etc/apt/keyrings && \ RUN mkdir --parents --mode=0755 /etc/apt/keyrings && \
@ -41,7 +42,8 @@ RUN mkdir --parents --mode=0755 /etc/apt/keyrings && \
rocm-gdb \ rocm-gdb \
rocm-smi-lib \ rocm-smi-lib \
rocm-llvm-dev \ rocm-llvm-dev \
hip-runtime-amd && \ hip-runtime-amd && \
hip-dev && \
echo '/opt/rocm/lib' > /etc/ld.so.conf.d/rocm.conf && \ echo '/opt/rocm/lib' > /etc/ld.so.conf.d/rocm.conf && \
ldconfig ldconfig

View file

@ -1 +0,0 @@
bindgen --rust-target 1.77 /opt/rocm/include/hip/hip_runtime_api.h -o hip_runtime_api.rs --no-layout-tests --default-enum-style=newtype --allowlist-function "hip.*" --allowlist-type "hip.*" --no-derive-debug --must-use-type hipError_t --new-type-alias "^hipDeviceptr_t$" --allowlist-var "^hip.*$" -- -I/opt/rocm/include -D__HIP_PLATFORM_AMD__

View file

@ -1,2 +0,0 @@
#define __HIP_PLATFORM_HCC__
#include <hip/hip_runtime_api.h>

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

View file

@ -58,7 +58,7 @@ macro_rules! remap_attribute {
paste::paste! { hipDeviceAttribute_t:: [< hipDeviceAttribute $($word:camel)* >] } paste::paste! { hipDeviceAttribute_t:: [< hipDeviceAttribute $($word:camel)* >] }
} }
)* )*
_ => return Err(hipErrorCode_t::hipErrorNotSupported) _ => return Err(hipErrorCode_t::NotSupported)
} }
} }
} }
@ -245,7 +245,7 @@ pub(crate) fn get_luid(
let luid = unsafe { let luid = unsafe {
luid.cast::<[i8; 8]>() luid.cast::<[i8; 8]>()
.as_mut() .as_mut()
.ok_or(hipErrorCode_t::hipErrorInvalidValue) .ok_or(hipErrorCode_t::InvalidValue)
}?; }?;
let mut properties = unsafe { mem::zeroed() }; let mut properties = unsafe { mem::zeroed() };
unsafe { hipGetDevicePropertiesR0600(&mut properties, dev) }?; unsafe { hipGetDevicePropertiesR0600(&mut properties, dev) }?;

View file

@ -10,33 +10,34 @@ use syn::{
fn main() { fn main() {
let crate_root = PathBuf::from_str(env!("CARGO_MANIFEST_DIR")).unwrap(); let crate_root = PathBuf::from_str(env!("CARGO_MANIFEST_DIR")).unwrap();
generate_hip_runtime(
&crate_root,
&["..", "ext", "hip_runtime-sys", "src", "lib.rs"],
);
let cuda_header = bindgen::Builder::default() let cuda_header = bindgen::Builder::default()
.use_core() .use_core()
.header_contents("cuda_wrapper.h", include_str!("../build/cuda_wrapper.h")) .rust_target(bindgen::RustTarget::Stable_1_77)
.no_partialeq("CUDA_HOST_NODE_PARAMS_st") .layout_tests(false)
.derive_eq(true)
.allowlist_type("^CU.*")
.allowlist_function("^cu.*")
.allowlist_var("^CU.*")
.default_enum_style(bindgen::EnumVariation::NewType { .default_enum_style(bindgen::EnumVariation::NewType {
is_bitfield: false, is_bitfield: false,
is_global: false, is_global: false,
}) })
.layout_tests(false) .derive_eq(true)
.new_type_alias(r"^CUdevice_v\d+$") .header_contents("cuda_wrapper.h", include_str!("../build/cuda_wrapper.h"))
.new_type_alias(r"^CUdeviceptr_v\d+$") .allowlist_type("^CU.*")
.allowlist_function("^cu.*")
.allowlist_var("^CU.*")
.must_use_type("cudaError_enum") .must_use_type("cudaError_enum")
.constified_enum("cudaError_enum") .constified_enum("cudaError_enum")
.no_partialeq("CUDA_HOST_NODE_PARAMS_st")
.new_type_alias(r"^CUdevice_v\d+$")
.new_type_alias(r"^CUdeviceptr_v\d+$")
.clang_args(["-I/usr/local/cuda/include"]) .clang_args(["-I/usr/local/cuda/include"])
.generate() .generate()
.unwrap() .unwrap()
.to_string(); .to_string();
let module: syn::File = syn::parse_str(&cuda_header).unwrap(); let module: syn::File = syn::parse_str(&cuda_header).unwrap();
generate_functions( generate_functions(&crate_root, &["..", "cuda_base", "src", "cuda.rs"], &module);
&crate_root,
&["..", "cuda_base", "src", "cuda.rs"],
&module,
);
generate_types(&crate_root, &["..", "cuda_types", "src", "lib.rs"], &module); generate_types(&crate_root, &["..", "cuda_types", "src", "lib.rs"], &module);
generate_display( generate_display(
&crate_root, &crate_root,
@ -46,6 +47,68 @@ fn main() {
) )
} }
fn generate_hip_runtime(output: &PathBuf, path: &[&str]) {
let hiprt_header = bindgen::Builder::default()
.use_core()
.rust_target(bindgen::RustTarget::Stable_1_77)
.layout_tests(false)
.default_enum_style(bindgen::EnumVariation::NewType {
is_bitfield: false,
is_global: false,
})
.derive_eq(true)
.header("/opt/rocm/include/hip/hip_runtime_api.h")
.allowlist_type("^hip.*")
.allowlist_function("^hip.*")
.allowlist_var("^hip.*")
.must_use_type("hipError_t")
.constified_enum("hipError_t")
.new_type_alias("^hipDeviceptr_t$")
.new_type_alias("^hipModule_t$")
.clang_args(["-I/opt/rocm/include", "-D__HIP_PLATFORM_AMD__"])
.generate()
.unwrap()
.to_string();
let mut module: syn::File = syn::parse_str(&hiprt_header).unwrap();
let mut converter = ConvertIntoRustResult {
type_: "hipError_t",
underlying_type: "hipError_t",
new_error_type: "hipErrorCode_t",
error_prefix: ("hipError", "Error"),
success: ("hipSuccess", "Success"),
constants: Vec::new(),
};
module.items = module
.items
.into_iter()
.filter_map(|item| match item {
Item::Const(const_) => converter.get_const(const_).map(Item::Const),
Item::Use(use_) => converter.get_use(use_).map(Item::Use),
Item::Type(type_) => converter.get_type(type_).map(Item::Type),
item => Some(item),
})
.collect::<Vec<_>>();
converter.flush(&mut module.items);
add_send_sync(&mut module.items, &["hipModule_t"]);
let mut output = output.clone();
output.extend(path);
write_rust_to_file(output, &prettyplease::unparse(&module))
}
fn add_send_sync(items: &mut Vec<Item>, arg: &[&str]) {
for type_ in arg {
let type_ = Ident::new(type_, Span::call_site());
items.extend([
parse_quote! {
unsafe impl Send for #type_ {}
},
parse_quote! {
unsafe impl Sync for #type_ {}
},
]);
}
}
fn generate_functions(output: &PathBuf, path: &[&str], module: &syn::File) { fn generate_functions(output: &PathBuf, path: &[&str], module: &syn::File) {
let fns_ = module.items.iter().filter_map(|item| match item { let fns_ = module.items.iter().filter_map(|item| match item {
Item::ForeignMod(extern_) => match &*extern_.items { Item::ForeignMod(extern_) => match &*extern_.items {
@ -73,7 +136,7 @@ fn generate_types(output: &PathBuf, path: &[&str], module: &syn::File) {
type_: "CUresult", type_: "CUresult",
underlying_type: "cudaError_enum", underlying_type: "cudaError_enum",
new_error_type: "CUerror", new_error_type: "CUerror",
error_prefix: ("CUDA_ERROR", "ERROR"), error_prefix: ("CUDA_ERROR_", "ERROR_"),
success: ("CUDA_SUCCESS", "SUCCESS"), success: ("CUDA_SUCCESS", "SUCCESS"),
constants: Vec::new(), constants: Vec::new(),
}; };
@ -84,6 +147,7 @@ fn generate_types(output: &PathBuf, path: &[&str], module: &syn::File) {
Item::ForeignMod(_) => None, Item::ForeignMod(_) => None,
Item::Const(const_) => converter.get_const(const_).map(Item::Const), Item::Const(const_) => converter.get_const(const_).map(Item::Const),
Item::Use(use_) => converter.get_use(use_).map(Item::Use), Item::Use(use_) => converter.get_use(use_).map(Item::Use),
Item::Type(type_) => converter.get_type(type_).map(Item::Type),
Item::Struct(mut struct_) => { Item::Struct(mut struct_) => {
let ident_string = struct_.ident.to_string(); let ident_string = struct_.ident.to_string();
match &*ident_string { match &*ident_string {
@ -105,6 +169,13 @@ fn generate_types(output: &PathBuf, path: &[&str], module: &syn::File) {
}) })
.collect::<Vec<_>>(); .collect::<Vec<_>>();
converter.flush(&mut module.items); converter.flush(&mut module.items);
module.items.push(parse_quote! {
impl From<hip_runtime_sys::hipErrorCode_t> for CUerror {
fn from(error: hip_runtime_sys::hipErrorCode_t) -> Self {
Self(error.0)
}
}
});
syn::visit_mut::visit_file_mut(&mut FixAbi, &mut module); syn::visit_mut::visit_file_mut(&mut FixAbi, &mut module);
let mut output = output.clone(); let mut output = output.clone();
output.extend(path); output.extend(path);
@ -163,9 +234,9 @@ impl ConvertIntoRustResult {
const #success: #type_ = #type_::Ok(()); const #success: #type_ = #type_::Ok(());
}); });
} else { } else {
let old_prefix_len = self.underlying_type.len() + 1 + self.error_prefix.0.len() + 1; let old_prefix_len = self.underlying_type.len() + 1 + self.error_prefix.0.len();
let variant_ident = let variant_ident =
format_ident!("{}_{}", self.error_prefix.1, &ident[old_prefix_len..]); format_ident!("{}{}", self.error_prefix.1, &ident[old_prefix_len..]);
let error_ident = format_ident!("{}", &ident[old_prefix_len..]); let error_ident = format_ident!("{}", &ident[old_prefix_len..]);
let expr = &const_.expr; let expr = &const_.expr;
result_variants.push(quote! { result_variants.push(quote! {
@ -193,15 +264,17 @@ impl ConvertIntoRustResult {
const _: fn() = || { const _: fn() = || {
let _ = std::mem::transmute::<#type_, u32>; let _ = std::mem::transmute::<#type_, u32>;
}; };
impl From<hip_runtime_sys::hipErrorCode_t> for #new_error_type {
fn from(error: hip_runtime_sys::hipErrorCode_t) -> Self {
Self(error.0)
}
}
}; };
items.extend(extra_items); items.extend(extra_items);
} }
fn get_type(&self, type_: syn::ItemType) -> Option<syn::ItemType> {
if type_.ident.to_string() == self.type_ {
None
} else {
Some(type_)
}
}
} }
struct FixAbi; struct FixAbi;