Skip to content

Commit

Permalink
Implement static typing for dynamically-loaded CUDA DLLs
Browse files Browse the repository at this point in the history
  • Loading branch information
vosen committed Jan 28, 2022
1 parent 07aa110 commit 89bc406
Show file tree
Hide file tree
Showing 6 changed files with 205 additions and 81 deletions.
69 changes: 39 additions & 30 deletions cuda_base/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,11 @@ use quote::{format_ident, quote, ToTokens};
use rustc_hash::{FxHashMap, FxHashSet};
use syn::parse::{Parse, ParseStream};
use syn::punctuated::Punctuated;
use syn::token::Brace;
use syn::visit_mut::VisitMut;
use syn::{
bracketed, parse_macro_input, Abi, Fields, File, FnArg, ForeignItem, ForeignItemFn, Ident,
Item, ItemForeignMod, ItemMacro, LitStr, Macro, MacroDelimiter, PatType, Path, PathArguments,
PathSegment, ReturnType, Signature, Token, Type, TypeArray, TypePath, TypePtr,
Item, ItemForeignMod, LitStr, PatType, Path, PathArguments, PathSegment, ReturnType, Signature,
Token, Type, TypeArray, TypePath, TypePtr,
};

const CUDA_RS: &'static str = include_str! {"cuda.rs"};
Expand Down Expand Up @@ -109,8 +108,11 @@ impl VisitMut for FixAbi {
// Then macro goes through every function in rust.rs, and for every fn `foo`:
// * if `foo` is contained in `override_fns` then pass it into `override_macro`
// * if `foo` is not contained in `override_fns` pass it to `normal_macro`
// Both `override_macro` and `normal_macro` expect this format:
// macro_foo!("system" fn cuCtxDetach(ctx: CUcontext) -> CUresult)
// Both `override_macro` and `normal_macro` expect semicolon-separated list:
// macro_foo!(
// "system" fn cuCtxDetach(ctx: CUcontext) -> CUresult;
// "system" fn cuCtxDetach(ctx: CUcontext) -> CUresult
// )
// Additionally, it does a fixup of CUDA types so they get prefixed with `type_path`
#[proc_macro]
pub fn cuda_function_declarations(tokens: TokenStream) -> TokenStream {
Expand All @@ -121,7 +123,7 @@ pub fn cuda_function_declarations(tokens: TokenStream) -> TokenStream {
.iter()
.map(ToString::to_string)
.collect::<FxHashSet<_>>();
cuda_module
let (normal_macro_args, override_macro_args): (Vec<_>, Vec<_>) = cuda_module
.items
.into_iter()
.filter_map(|item| match item {
Expand All @@ -136,12 +138,7 @@ pub fn cuda_function_declarations(tokens: TokenStream) -> TokenStream {
},
..
}) => {
let path = if override_fns.contains(&ident.to_string()) {
&input.override_macro
} else {
&input.normal_macro
}
.clone();
let use_normal_macro = !override_fns.contains(&ident.to_string());
let inputs = inputs
.into_iter()
.map(|fn_arg| match fn_arg {
Expand All @@ -158,30 +155,42 @@ pub fn cuda_function_declarations(tokens: TokenStream) -> TokenStream {
ReturnType::Default => unreachable!(),
};
let type_path = input.type_path.clone();
let tokens = quote! {
"system" fn #ident(#inputs) -> #type_path :: #output
};
Some(Item::Macro(ItemMacro {
attrs: Vec::new(),
ident: None,
mac: Macro {
path,
bang_token: Token![!](Span::call_site()),
delimiter: MacroDelimiter::Brace(Brace {
span: Span::call_site(),
}),
tokens,
Some((
quote! {
"system" fn #ident(#inputs) -> #type_path :: #output
},
semi_token: None,
}))
use_normal_macro,
))
}
_ => unreachable!(),
},
_ => None,
})
.map(Item::into_token_stream)
.collect::<proc_macro2::TokenStream>()
.into()
.partition(|(_, use_normal_macro)| *use_normal_macro);
let mut result = proc_macro2::TokenStream::new();
if !normal_macro_args.is_empty() {
let punctuated_normal_macro_args = to_punctuated::<Token![;]>(normal_macro_args);
let macro_ = &input.normal_macro;
result.extend(iter::once(quote! {
#macro_ ! (#punctuated_normal_macro_args);
}));
}
if !override_macro_args.is_empty() {
let punctuated_override_macro_args = to_punctuated::<Token![;]>(override_macro_args);
let macro_ = &input.override_macro;
result.extend(iter::once(quote! {
#macro_ ! (#punctuated_override_macro_args);
}));
}
result.into()
}

fn to_punctuated<P: ToTokens + Default>(
elms: Vec<(proc_macro2::TokenStream, bool)>,
) -> proc_macro2::TokenStream {
let mut collection = Punctuated::<proc_macro2::TokenStream, P>::new();
collection.extend(elms.into_iter().map(|(token_stream, _)| token_stream));
collection.into_token_stream()
}

fn prepend_cuda_path_to_type(base_path: &Path, type_: Box<Type>) -> Box<Type> {
Expand Down
1 change: 0 additions & 1 deletion zluda_dump/src/format.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
extern crate cuda_types;
use std::{
ffi::{c_void, CStr},
fmt::LowerHex,
Expand Down
135 changes: 87 additions & 48 deletions zluda_dump/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use cuda_types::{
CUdevice, CUdevice_attribute, CUfunction, CUjit_option, CUmodule, CUresult, CUuuid,
};
use paste::paste;
use side_by_side::CudaDynamicFns;
use std::io;
use std::{
collections::HashMap, env, error::Error, ffi::c_void, fs, path::PathBuf, ptr::NonNull, rc::Rc,
Expand All @@ -10,47 +11,50 @@ use std::{

#[macro_use]
extern crate lazy_static;
extern crate cuda_types;

macro_rules! extern_redirect {
($abi:literal fn $fn_name:ident( $($arg_id:ident : $arg_type:ty),* ) -> $ret_type:path) => {
#[no_mangle]
pub extern $abi fn $fn_name ( $( $arg_id : $arg_type),* ) -> $ret_type {
let original_fn = |fn_ptr| {
let typed_fn = unsafe { std::mem::transmute::<_, extern "system" fn( $( $arg_id : $arg_type),* ) -> $ret_type>(fn_ptr) };
typed_fn($( $arg_id ),*)
};
let get_formatted_args = Box::new(move |writer: &mut dyn std::io::Write| {
(paste! { format :: [<write_ $fn_name>] }) (
writer
$(,$arg_id)*
)
});
crate::handle_cuda_function_call(stringify!($fn_name), original_fn, get_formatted_args)
}
($($abi:literal fn $fn_name:ident( $($arg_id:ident : $arg_type:ty),* ) -> $ret_type:path);*) => {
$(
#[no_mangle]
pub extern $abi fn $fn_name ( $( $arg_id : $arg_type),* ) -> $ret_type {
let original_fn = |dynamic_fns: &mut crate::side_by_side::CudaDynamicFns| {
dynamic_fns.$fn_name($( $arg_id ),*)
};
let get_formatted_args = Box::new(move |writer: &mut dyn std::io::Write| {
(paste! { format :: [<write_ $fn_name>] }) (
writer
$(,$arg_id)*
)
});
crate::handle_cuda_function_call(stringify!($fn_name), original_fn, get_formatted_args)
}
)*
};
}

macro_rules! extern_redirect_with_post {
($abi:literal fn $fn_name:ident( $($arg_id:ident : $arg_type:ty),* ) -> $ret_type:path) => {
#[no_mangle]
pub extern "system" fn $fn_name ( $( $arg_id : $arg_type),* ) -> $ret_type {
let original_fn = |fn_ptr| {
let typed_fn = unsafe { std::mem::transmute::<_, extern "system" fn( $( $arg_id : $arg_type),* ) -> $ret_type>(fn_ptr) };
typed_fn($( $arg_id ),*)
};
let get_formatted_args = Box::new(move |writer: &mut dyn std::io::Write| {
(paste! { format :: [<write_ $fn_name>] }) (
writer
$(,$arg_id)*
($($abi:literal fn $fn_name:ident( $($arg_id:ident : $arg_type:ty),* ) -> $ret_type:path);*) => {
$(
#[no_mangle]
pub extern "system" fn $fn_name ( $( $arg_id : $arg_type),* ) -> $ret_type {
let original_fn = |dynamic_fns: &mut crate::side_by_side::CudaDynamicFns| {
dynamic_fns.$fn_name($( $arg_id ),*)
};
let get_formatted_args = Box::new(move |writer: &mut dyn std::io::Write| {
(paste! { format :: [<write_ $fn_name>] }) (
writer
$(,$arg_id)*
)
});
crate::handle_cuda_function_call_with_probes(
stringify!($fn_name),
|| (), original_fn,
get_formatted_args,
move |logger, state, _, cuda_result| paste! { [<$fn_name _Post>] } ( $( $arg_id ),* , logger, state, cuda_result )
)
});
crate::handle_cuda_function_call_with_probes(
stringify!($fn_name),
|| (), original_fn,
get_formatted_args,
move |logger, state, _, cuda_result| paste! { [<$fn_name _Post>] } ( $( $arg_id ),* , logger, state, cuda_result )
)
}
}
)*
};
}

Expand All @@ -77,6 +81,7 @@ mod log;
#[cfg_attr(windows, path = "os_win.rs")]
#[cfg_attr(not(windows), path = "os_unix.rs")]
mod os;
mod side_by_side;
mod trace;

lazy_static! {
Expand Down Expand Up @@ -127,7 +132,8 @@ impl<T> LateInit<T> {

struct GlobalDelayedState {
settings: Settings,
libcuda_handle: NonNull<c_void>,
libcuda: CudaDynamicFns,
side_by_side_lib: Option<CudaDynamicFns>,
cuda_state: trace::StateTracker,
}

Expand All @@ -139,21 +145,39 @@ impl GlobalDelayedState {
) -> (LateInit<Self>, log::FunctionLogger<'a>) {
let (mut fn_logger, settings) =
factory.get_first_logger_and_init_settings(func, arguments_writer);
let maybe_libcuda_handle = unsafe { os::load_cuda_library(&settings.libcuda_path) };
let libcuda_handle = match NonNull::new(maybe_libcuda_handle) {
Some(h) => h,
let libcuda = match unsafe { CudaDynamicFns::load_library(&settings.libcuda_path) } {
Some(libcuda) => libcuda,
None => {
fn_logger.log(log::LogEntry::ErrorBox(
format!("Invalid CUDA library at path {}", &settings.libcuda_path).into(),
));
return (LateInit::Error, fn_logger);
}
};
let side_by_side_lib = settings
.side_by_side_path
.as_ref()
.and_then(|side_by_side_path| {
match unsafe { CudaDynamicFns::load_library(&*side_by_side_path) } {
Some(fns) => Some(fns),
None => {
fn_logger.log(log::LogEntry::ErrorBox(
format!(
"Invalid side-by-side CUDA library at path {}",
&side_by_side_path
)
.into(),
));
None
}
}
});
let cuda_state = trace::StateTracker::new(&settings);
let delayed_state = GlobalDelayedState {
settings,
libcuda_handle,
libcuda,
cuda_state,
side_by_side_lib,
};
(LateInit::Success(delayed_state), fn_logger)
}
Expand All @@ -163,6 +187,7 @@ struct Settings {
dump_dir: Option<PathBuf>,
libcuda_path: String,
override_cc_major: Option<u32>,
side_by_side_path: Option<String>,
}

impl Settings {
Expand All @@ -179,7 +204,7 @@ impl Settings {
None
}
};
let libcuda_path = match env::var("ZLUDA_DUMP_LIBCUDA_FILE") {
let libcuda_path = match env::var("ZLUDA_CUDA_LIB") {
Err(env::VarError::NotPresent) => os::LIBCUDA_DEFAULT_PATH.to_owned(),
Err(e) => {
logger.log(log::LogEntry::ErrorBox(Box::new(e) as _));
Expand All @@ -201,10 +226,19 @@ impl Settings {
Ok(cc) => Some(cc),
},
};
let side_by_side_path = match env::var("ZLUDA_SIDE_BY_SIDE_LIB") {
Err(env::VarError::NotPresent) => None,
Err(e) => {
logger.log(log::LogEntry::ErrorBox(Box::new(e) as _));
None
}
Ok(env_string) => Some(env_string),
};
Settings {
dump_dir,
libcuda_path,
override_cc_major,
side_by_side_path,
}
}

Expand Down Expand Up @@ -241,7 +275,7 @@ pub struct ModuleDump {

fn handle_cuda_function_call(
func: &'static str,
original_cuda_fn: impl FnOnce(NonNull<c_void>) -> CUresult,
original_cuda_fn: impl FnOnce(&mut CudaDynamicFns) -> Option<CUresult>,
arguments_writer: Box<dyn FnMut(&mut dyn std::io::Write) -> std::io::Result<()>>,
) -> CUresult {
handle_cuda_function_call_with_probes(
Expand All @@ -256,7 +290,7 @@ fn handle_cuda_function_call(
fn handle_cuda_function_call_with_probes<T, PostFn>(
func: &'static str,
pre_probe: impl FnOnce() -> T,
original_cuda_fn: impl FnOnce(NonNull<c_void>) -> CUresult,
original_cuda_fn: impl FnOnce(&mut CudaDynamicFns) -> Option<CUresult>,
arguments_writer: Box<dyn FnMut(&mut dyn std::io::Write) -> std::io::Result<()>>,
post_probe: PostFn,
) -> CUresult
Expand All @@ -283,13 +317,18 @@ where
(logger, global_state.delayed_state.as_mut().unwrap())
}
};
let name = std::ffi::CString::new(func).unwrap();
let fn_ptr =
unsafe { os::get_proc_address(delayed_state.libcuda_handle.as_ptr(), name.as_c_str()) };
let fn_ptr = NonNull::new(fn_ptr).unwrap();
let pre_result = pre_probe();
let cu_result = original_cuda_fn(fn_ptr);
logger.result = Some(cu_result);
let maybe_cu_result = original_cuda_fn(&mut delayed_state.libcuda);
let cu_result = match maybe_cu_result {
Some(result) => result,
None => {
logger.log(log::LogEntry::ErrorBox(
format!("No function {} in the underlying CUDA library", func).into(),
));
CUresult::CUDA_ERROR_UNKNOWN
}
};
logger.result = maybe_cu_result;
post_probe(
&mut logger,
&mut delayed_state.cuda_state,
Expand Down
2 changes: 1 addition & 1 deletion zluda_dump/src/os_unix.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use std::mem;

pub(crate) const LIBCUDA_DEFAULT_PATH: &'static str = b"/usr/lib/x86_64-linux-gnu/libcuda.so.1\0";

pub unsafe fn load_cuda_library(libcuda_path: &str) -> *mut c_void {
pub unsafe fn load_library(libcuda_path: &str) -> *mut c_void {
let libcuda_path = CString::new(libcuda_path).unwrap();
libc::dlopen(
libcuda_path.as_ptr() as *const _,
Expand Down
2 changes: 1 addition & 1 deletion zluda_dump/src/os_win.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ impl PlatformLibrary {
}
}

pub unsafe fn load_cuda_library(libcuda_path: &str) -> *mut c_void {
pub unsafe fn load_library(libcuda_path: &str) -> *mut c_void {
let libcuda_path_uf16 = libcuda_path
.encode_utf16()
.chain(std::iter::once(0))
Expand Down
Loading

0 comments on commit 89bc406

Please sign in to comment.