mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-04-28 21:47:57 +03:00
Automate generation of HIP bindings
This commit is contained in:
parent
122676bb13
commit
3ec7bffdc5
7 changed files with 12759 additions and 7454 deletions
|
@ -28,7 +28,8 @@ RUN DEBIAN_FRONTEND=noninteractive apt-get update -y && DEBIAN_FRONTEND=noninter
|
||||||
nvidia-utils-${CUDA_DRIVER} \
|
nvidia-utils-${CUDA_DRIVER} \
|
||||||
cuda-cudart-dev-${CUDA_PKG_VERSION} \
|
cuda-cudart-dev-${CUDA_PKG_VERSION} \
|
||||||
cuda-cudart-${CUDA_PKG_VERSION} \
|
cuda-cudart-${CUDA_PKG_VERSION} \
|
||||||
cuda-profiler-api-${CUDA_PKG_VERSION}
|
cuda-profiler-api-${CUDA_PKG_VERSION} \
|
||||||
|
cuda-nvcc-${CUDA_PKG_VERSION}
|
||||||
|
|
||||||
ARG ROCM_VERSION=6.2.2
|
ARG ROCM_VERSION=6.2.2
|
||||||
RUN mkdir --parents --mode=0755 /etc/apt/keyrings && \
|
RUN mkdir --parents --mode=0755 /etc/apt/keyrings && \
|
||||||
|
@ -41,7 +42,8 @@ RUN mkdir --parents --mode=0755 /etc/apt/keyrings && \
|
||||||
rocm-gdb \
|
rocm-gdb \
|
||||||
rocm-smi-lib \
|
rocm-smi-lib \
|
||||||
rocm-llvm-dev \
|
rocm-llvm-dev \
|
||||||
hip-runtime-amd && \
|
hip-runtime-amd && \
|
||||||
|
hip-dev && \
|
||||||
echo '/opt/rocm/lib' > /etc/ld.so.conf.d/rocm.conf && \
|
echo '/opt/rocm/lib' > /etc/ld.so.conf.d/rocm.conf && \
|
||||||
ldconfig
|
ldconfig
|
||||||
|
|
||||||
|
|
1
ext/hip_runtime-sys/README
vendored
1
ext/hip_runtime-sys/README
vendored
|
@ -1 +0,0 @@
|
||||||
bindgen --rust-target 1.77 /opt/rocm/include/hip/hip_runtime_api.h -o hip_runtime_api.rs --no-layout-tests --default-enum-style=newtype --allowlist-function "hip.*" --allowlist-type "hip.*" --no-derive-debug --must-use-type hipError_t --new-type-alias "^hipDeviceptr_t$" --allowlist-var "^hip.*$" -- -I/opt/rocm/include -D__HIP_PLATFORM_AMD__
|
|
|
@ -1,2 +0,0 @@
|
||||||
#define __HIP_PLATFORM_HCC__
|
|
||||||
#include <hip/hip_runtime_api.h>
|
|
7422
ext/hip_runtime-sys/src/hip_runtime_api.rs
vendored
7422
ext/hip_runtime-sys/src/hip_runtime_api.rs
vendored
File diff suppressed because it is too large
Load diff
12659
ext/hip_runtime-sys/src/lib.rs
vendored
12659
ext/hip_runtime-sys/src/lib.rs
vendored
File diff suppressed because it is too large
Load diff
|
@ -58,7 +58,7 @@ macro_rules! remap_attribute {
|
||||||
paste::paste! { hipDeviceAttribute_t:: [< hipDeviceAttribute $($word:camel)* >] }
|
paste::paste! { hipDeviceAttribute_t:: [< hipDeviceAttribute $($word:camel)* >] }
|
||||||
}
|
}
|
||||||
)*
|
)*
|
||||||
_ => return Err(hipErrorCode_t::hipErrorNotSupported)
|
_ => return Err(hipErrorCode_t::NotSupported)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -245,7 +245,7 @@ pub(crate) fn get_luid(
|
||||||
let luid = unsafe {
|
let luid = unsafe {
|
||||||
luid.cast::<[i8; 8]>()
|
luid.cast::<[i8; 8]>()
|
||||||
.as_mut()
|
.as_mut()
|
||||||
.ok_or(hipErrorCode_t::hipErrorInvalidValue)
|
.ok_or(hipErrorCode_t::InvalidValue)
|
||||||
}?;
|
}?;
|
||||||
let mut properties = unsafe { mem::zeroed() };
|
let mut properties = unsafe { mem::zeroed() };
|
||||||
unsafe { hipGetDevicePropertiesR0600(&mut properties, dev) }?;
|
unsafe { hipGetDevicePropertiesR0600(&mut properties, dev) }?;
|
||||||
|
|
|
@ -10,33 +10,34 @@ use syn::{
|
||||||
|
|
||||||
fn main() {
|
fn main() {
|
||||||
let crate_root = PathBuf::from_str(env!("CARGO_MANIFEST_DIR")).unwrap();
|
let crate_root = PathBuf::from_str(env!("CARGO_MANIFEST_DIR")).unwrap();
|
||||||
|
generate_hip_runtime(
|
||||||
|
&crate_root,
|
||||||
|
&["..", "ext", "hip_runtime-sys", "src", "lib.rs"],
|
||||||
|
);
|
||||||
let cuda_header = bindgen::Builder::default()
|
let cuda_header = bindgen::Builder::default()
|
||||||
.use_core()
|
.use_core()
|
||||||
.header_contents("cuda_wrapper.h", include_str!("../build/cuda_wrapper.h"))
|
.rust_target(bindgen::RustTarget::Stable_1_77)
|
||||||
.no_partialeq("CUDA_HOST_NODE_PARAMS_st")
|
.layout_tests(false)
|
||||||
.derive_eq(true)
|
|
||||||
.allowlist_type("^CU.*")
|
|
||||||
.allowlist_function("^cu.*")
|
|
||||||
.allowlist_var("^CU.*")
|
|
||||||
.default_enum_style(bindgen::EnumVariation::NewType {
|
.default_enum_style(bindgen::EnumVariation::NewType {
|
||||||
is_bitfield: false,
|
is_bitfield: false,
|
||||||
is_global: false,
|
is_global: false,
|
||||||
})
|
})
|
||||||
.layout_tests(false)
|
.derive_eq(true)
|
||||||
.new_type_alias(r"^CUdevice_v\d+$")
|
.header_contents("cuda_wrapper.h", include_str!("../build/cuda_wrapper.h"))
|
||||||
.new_type_alias(r"^CUdeviceptr_v\d+$")
|
.allowlist_type("^CU.*")
|
||||||
|
.allowlist_function("^cu.*")
|
||||||
|
.allowlist_var("^CU.*")
|
||||||
.must_use_type("cudaError_enum")
|
.must_use_type("cudaError_enum")
|
||||||
.constified_enum("cudaError_enum")
|
.constified_enum("cudaError_enum")
|
||||||
|
.no_partialeq("CUDA_HOST_NODE_PARAMS_st")
|
||||||
|
.new_type_alias(r"^CUdevice_v\d+$")
|
||||||
|
.new_type_alias(r"^CUdeviceptr_v\d+$")
|
||||||
.clang_args(["-I/usr/local/cuda/include"])
|
.clang_args(["-I/usr/local/cuda/include"])
|
||||||
.generate()
|
.generate()
|
||||||
.unwrap()
|
.unwrap()
|
||||||
.to_string();
|
.to_string();
|
||||||
let module: syn::File = syn::parse_str(&cuda_header).unwrap();
|
let module: syn::File = syn::parse_str(&cuda_header).unwrap();
|
||||||
generate_functions(
|
generate_functions(&crate_root, &["..", "cuda_base", "src", "cuda.rs"], &module);
|
||||||
&crate_root,
|
|
||||||
&["..", "cuda_base", "src", "cuda.rs"],
|
|
||||||
&module,
|
|
||||||
);
|
|
||||||
generate_types(&crate_root, &["..", "cuda_types", "src", "lib.rs"], &module);
|
generate_types(&crate_root, &["..", "cuda_types", "src", "lib.rs"], &module);
|
||||||
generate_display(
|
generate_display(
|
||||||
&crate_root,
|
&crate_root,
|
||||||
|
@ -46,6 +47,68 @@ fn main() {
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn generate_hip_runtime(output: &PathBuf, path: &[&str]) {
|
||||||
|
let hiprt_header = bindgen::Builder::default()
|
||||||
|
.use_core()
|
||||||
|
.rust_target(bindgen::RustTarget::Stable_1_77)
|
||||||
|
.layout_tests(false)
|
||||||
|
.default_enum_style(bindgen::EnumVariation::NewType {
|
||||||
|
is_bitfield: false,
|
||||||
|
is_global: false,
|
||||||
|
})
|
||||||
|
.derive_eq(true)
|
||||||
|
.header("/opt/rocm/include/hip/hip_runtime_api.h")
|
||||||
|
.allowlist_type("^hip.*")
|
||||||
|
.allowlist_function("^hip.*")
|
||||||
|
.allowlist_var("^hip.*")
|
||||||
|
.must_use_type("hipError_t")
|
||||||
|
.constified_enum("hipError_t")
|
||||||
|
.new_type_alias("^hipDeviceptr_t$")
|
||||||
|
.new_type_alias("^hipModule_t$")
|
||||||
|
.clang_args(["-I/opt/rocm/include", "-D__HIP_PLATFORM_AMD__"])
|
||||||
|
.generate()
|
||||||
|
.unwrap()
|
||||||
|
.to_string();
|
||||||
|
let mut module: syn::File = syn::parse_str(&hiprt_header).unwrap();
|
||||||
|
let mut converter = ConvertIntoRustResult {
|
||||||
|
type_: "hipError_t",
|
||||||
|
underlying_type: "hipError_t",
|
||||||
|
new_error_type: "hipErrorCode_t",
|
||||||
|
error_prefix: ("hipError", "Error"),
|
||||||
|
success: ("hipSuccess", "Success"),
|
||||||
|
constants: Vec::new(),
|
||||||
|
};
|
||||||
|
module.items = module
|
||||||
|
.items
|
||||||
|
.into_iter()
|
||||||
|
.filter_map(|item| match item {
|
||||||
|
Item::Const(const_) => converter.get_const(const_).map(Item::Const),
|
||||||
|
Item::Use(use_) => converter.get_use(use_).map(Item::Use),
|
||||||
|
Item::Type(type_) => converter.get_type(type_).map(Item::Type),
|
||||||
|
item => Some(item),
|
||||||
|
})
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
converter.flush(&mut module.items);
|
||||||
|
add_send_sync(&mut module.items, &["hipModule_t"]);
|
||||||
|
let mut output = output.clone();
|
||||||
|
output.extend(path);
|
||||||
|
write_rust_to_file(output, &prettyplease::unparse(&module))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn add_send_sync(items: &mut Vec<Item>, arg: &[&str]) {
|
||||||
|
for type_ in arg {
|
||||||
|
let type_ = Ident::new(type_, Span::call_site());
|
||||||
|
items.extend([
|
||||||
|
parse_quote! {
|
||||||
|
unsafe impl Send for #type_ {}
|
||||||
|
},
|
||||||
|
parse_quote! {
|
||||||
|
unsafe impl Sync for #type_ {}
|
||||||
|
},
|
||||||
|
]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn generate_functions(output: &PathBuf, path: &[&str], module: &syn::File) {
|
fn generate_functions(output: &PathBuf, path: &[&str], module: &syn::File) {
|
||||||
let fns_ = module.items.iter().filter_map(|item| match item {
|
let fns_ = module.items.iter().filter_map(|item| match item {
|
||||||
Item::ForeignMod(extern_) => match &*extern_.items {
|
Item::ForeignMod(extern_) => match &*extern_.items {
|
||||||
|
@ -73,7 +136,7 @@ fn generate_types(output: &PathBuf, path: &[&str], module: &syn::File) {
|
||||||
type_: "CUresult",
|
type_: "CUresult",
|
||||||
underlying_type: "cudaError_enum",
|
underlying_type: "cudaError_enum",
|
||||||
new_error_type: "CUerror",
|
new_error_type: "CUerror",
|
||||||
error_prefix: ("CUDA_ERROR", "ERROR"),
|
error_prefix: ("CUDA_ERROR_", "ERROR_"),
|
||||||
success: ("CUDA_SUCCESS", "SUCCESS"),
|
success: ("CUDA_SUCCESS", "SUCCESS"),
|
||||||
constants: Vec::new(),
|
constants: Vec::new(),
|
||||||
};
|
};
|
||||||
|
@ -84,6 +147,7 @@ fn generate_types(output: &PathBuf, path: &[&str], module: &syn::File) {
|
||||||
Item::ForeignMod(_) => None,
|
Item::ForeignMod(_) => None,
|
||||||
Item::Const(const_) => converter.get_const(const_).map(Item::Const),
|
Item::Const(const_) => converter.get_const(const_).map(Item::Const),
|
||||||
Item::Use(use_) => converter.get_use(use_).map(Item::Use),
|
Item::Use(use_) => converter.get_use(use_).map(Item::Use),
|
||||||
|
Item::Type(type_) => converter.get_type(type_).map(Item::Type),
|
||||||
Item::Struct(mut struct_) => {
|
Item::Struct(mut struct_) => {
|
||||||
let ident_string = struct_.ident.to_string();
|
let ident_string = struct_.ident.to_string();
|
||||||
match &*ident_string {
|
match &*ident_string {
|
||||||
|
@ -105,6 +169,13 @@ fn generate_types(output: &PathBuf, path: &[&str], module: &syn::File) {
|
||||||
})
|
})
|
||||||
.collect::<Vec<_>>();
|
.collect::<Vec<_>>();
|
||||||
converter.flush(&mut module.items);
|
converter.flush(&mut module.items);
|
||||||
|
module.items.push(parse_quote! {
|
||||||
|
impl From<hip_runtime_sys::hipErrorCode_t> for CUerror {
|
||||||
|
fn from(error: hip_runtime_sys::hipErrorCode_t) -> Self {
|
||||||
|
Self(error.0)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
syn::visit_mut::visit_file_mut(&mut FixAbi, &mut module);
|
syn::visit_mut::visit_file_mut(&mut FixAbi, &mut module);
|
||||||
let mut output = output.clone();
|
let mut output = output.clone();
|
||||||
output.extend(path);
|
output.extend(path);
|
||||||
|
@ -163,9 +234,9 @@ impl ConvertIntoRustResult {
|
||||||
const #success: #type_ = #type_::Ok(());
|
const #success: #type_ = #type_::Ok(());
|
||||||
});
|
});
|
||||||
} else {
|
} else {
|
||||||
let old_prefix_len = self.underlying_type.len() + 1 + self.error_prefix.0.len() + 1;
|
let old_prefix_len = self.underlying_type.len() + 1 + self.error_prefix.0.len();
|
||||||
let variant_ident =
|
let variant_ident =
|
||||||
format_ident!("{}_{}", self.error_prefix.1, &ident[old_prefix_len..]);
|
format_ident!("{}{}", self.error_prefix.1, &ident[old_prefix_len..]);
|
||||||
let error_ident = format_ident!("{}", &ident[old_prefix_len..]);
|
let error_ident = format_ident!("{}", &ident[old_prefix_len..]);
|
||||||
let expr = &const_.expr;
|
let expr = &const_.expr;
|
||||||
result_variants.push(quote! {
|
result_variants.push(quote! {
|
||||||
|
@ -193,15 +264,17 @@ impl ConvertIntoRustResult {
|
||||||
const _: fn() = || {
|
const _: fn() = || {
|
||||||
let _ = std::mem::transmute::<#type_, u32>;
|
let _ = std::mem::transmute::<#type_, u32>;
|
||||||
};
|
};
|
||||||
|
|
||||||
impl From<hip_runtime_sys::hipErrorCode_t> for #new_error_type {
|
|
||||||
fn from(error: hip_runtime_sys::hipErrorCode_t) -> Self {
|
|
||||||
Self(error.0)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
items.extend(extra_items);
|
items.extend(extra_items);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn get_type(&self, type_: syn::ItemType) -> Option<syn::ItemType> {
|
||||||
|
if type_.ident.to_string() == self.type_ {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
Some(type_)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
struct FixAbi;
|
struct FixAbi;
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue