Skip to content

Commit

Permalink
re-factored to include error code
Browse files Browse the repository at this point in the history
  • Loading branch information
arshadm committed Jul 8, 2023
1 parent 92df0a2 commit c3456fc
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 36 deletions.
3 changes: 1 addition & 2 deletions build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,7 @@ fn main() {
println!("cargo:rustc-link-search={cuda_lib_dir}");
println!("cargo:rustc-link-search={cuda_lib_dir}/stubs");

// Tell cargo to tell rustc to link the system bzip2
// shared library.
// Tell cargo to tell rustc to link the cuda and nvrtc libraries
println!("cargo:rustc-link-lib=nvrtc");
println!("cargo:rustc-link-lib=cuda");

Expand Down
48 changes: 22 additions & 26 deletions src/driver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,37 +31,37 @@ pub type CudaMemory = CUdeviceptr;
pub struct Driver;

impl Driver {
pub fn init(device_number: u32) -> Result<(), &'static str> {
pub fn init(device_number: u32) -> Result<(), String> {
let cu_result = unsafe { cuInit(device_number) };
if cu_result != cudaError_enum_CUDA_SUCCESS {
return Err("Failed: cuInit");
return Err(format!("Failed: cuInit - {cu_result}"));
}

Ok(())
}

pub fn get_device(device_number: u32) -> Result<CudaDevice, &'static str> {
pub fn get_device(device_number: u32) -> Result<CudaDevice, String> {
let mut device = unsafe { zeroed::<CUdevice>() };
let cu_result = unsafe { cuDeviceGet(&mut device as *mut CUdevice, device_number as i32) };
if cu_result != cudaError_enum_CUDA_SUCCESS {
return Err("Failed: cuDeviceGet");
return Err(format!("Failed: cuDeviceGet - {cu_result}"));
}

Ok(device)
}

pub fn create_context(device: CudaDevice) -> Result<CudaContext, &'static str> {
pub fn create_context(device: CudaDevice) -> Result<CudaContext, String> {
let mut context = unsafe { zeroed::<CUcontext>() };

let cu_result = unsafe { cuCtxCreate_v2(&mut context as *mut CUcontext, 0, device) };
if cu_result != cudaError_enum_CUDA_SUCCESS {
return Err("Failed: cuCtxCreate_v2");
return Err(format!("Failed: cuCtxCreate_v2 {cu_result}"));
}

Ok(context)
}

pub fn load_module(ptx: CudaPtx) -> Result<CudaModule, &'static str> {
pub fn load_module(ptx: CudaPtx) -> Result<CudaModule, String> {
let mut module = unsafe { zeroed::<CUmodule>() };

let cu_result = unsafe {
Expand All @@ -74,8 +74,7 @@ impl Driver {
)
};
if cu_result != cudaError_enum_CUDA_SUCCESS {
println!("Error: {}", cu_result);
return Err("Failed: cuModuleLoadDataEx");
return Err(format!("Failed: cuModuleLoadDataEx - {cu_result}"));
}

Ok(module)
Expand All @@ -90,29 +89,26 @@ impl Driver {
/// # Safety
///
/// .
pub unsafe fn get_function(
module: CudaModule,
name: &str,
) -> Result<CudaFunction, &'static str> {
pub unsafe fn get_function(module: CudaModule, name: &str) -> Result<CudaFunction, String> {
let mut kernel = zeroed::<CUfunction>();
let name_str = CString::new(name).unwrap();

let cu_result =
cuModuleGetFunction(&mut kernel as *mut CUfunction, module, name_str.as_ptr());

if cu_result != cudaError_enum_CUDA_SUCCESS {
return Err("Failed: cuModuleGetFunction");
return Err(format!("Failed: cuModuleGetFunction - {cu_result}"));
}

Ok(kernel)
}

pub fn create_stream() -> Result<CudaStream, &'static str> {
pub fn create_stream() -> Result<CudaStream, String> {
let mut stream = unsafe { zeroed::<CUstream>() };

let cu_result = unsafe { cuStreamCreate(&mut stream as *mut CUstream, 0) };
if cu_result != cudaError_enum_CUDA_SUCCESS {
return Err("Failed: cuStreamCreate");
return Err(format!("Failed: cuStreamCreate - {cu_result}"));
}

Ok(stream)
Expand All @@ -135,7 +131,7 @@ impl Driver {
stream: CudaStream,
kernel_params: *mut *mut c_void,
extra: *mut *mut c_void,
) -> Result<(), &'static str> {
) -> Result<(), String> {
let cu_result = cuLaunchKernel(
kernel,
num_blocks.0,
Expand All @@ -150,18 +146,18 @@ impl Driver {
extra,
);
if cu_result != cudaError_enum_CUDA_SUCCESS {
return Err("Failed: cuLaunchKernel");
return Err(format!("Failed: cuLaunchKernel - {cu_result}"));
}

Ok(())
}

pub fn allocate_memory(size: usize) -> Result<CudaMemory, &'static str> {
pub fn allocate_memory(size: usize) -> Result<CudaMemory, String> {
let mut device_ptr = unsafe { zeroed::<CUdeviceptr>() };

let cu_result = unsafe { cuMemAlloc_v2(&mut device_ptr as *mut CUdeviceptr, size) };
if cu_result != cudaError_enum_CUDA_SUCCESS {
return Err("Failed: cuMemAlloc_v2");
return Err(format!("Failed: cuMemAlloc_v2 - {cu_result}"));
}

Ok(device_ptr)
Expand All @@ -180,10 +176,10 @@ impl Driver {
device_memory: CudaMemory,
host_memory: *const c_void,
size: usize,
) -> Result<(), &'static str> {
) -> Result<(), String> {
let cu_result = cuMemcpyHtoD_v2(device_memory, host_memory, size);
if cu_result != cudaError_enum_CUDA_SUCCESS {
return Err("Failed: cuMemcpyHtoD_v2");
return Err(format!("Failed: cuMemcpyHtoD_v2 - {cu_result}"));
}

Ok(())
Expand All @@ -202,19 +198,19 @@ impl Driver {
host_memory: *mut c_void,
device_memory: CudaMemory,
size: usize,
) -> Result<(), &'static str> {
) -> Result<(), String> {
let cu_result = cuMemcpyDtoH_v2(host_memory, device_memory, size);
if cu_result != cudaError_enum_CUDA_SUCCESS {
return Err("Failed: cuMemcpyDtoH_v2");
return Err(format!("Failed: cuMemcpyDtoH_v2 - {cu_result}"));
}

Ok(())
}

pub fn synchronize_context() -> Result<(), &'static str> {
pub fn synchronize_context() -> Result<(), String> {
let cu_result = unsafe { cuCtxSynchronize() };
if cu_result != cudaError_enum_CUDA_SUCCESS {
return Err("Failed: cuCtxSynchronize");
return Err(format!("Failed: cuCtxSynchronize - {cu_result}"));
}

Ok(())
Expand Down
16 changes: 8 additions & 8 deletions src/nvrtc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,18 +63,18 @@ impl Nvrtc {
/// # Safety
///
/// .
pub unsafe fn get_ptx(program: CudaProgram) -> Result<CudaPtx, &'static str> {
pub unsafe fn get_ptx(program: CudaProgram) -> Result<CudaPtx, String> {
let mut ptx_size: usize = 0;
let nvrtc_result = nvrtcGetPTXSize(program, &mut ptx_size as *mut usize);
if nvrtc_result != nvrtcResult_NVRTC_SUCCESS {
return Err("Failed: nvrtcGetPTXSize");
return Err(format!("Failed: nvrtcGetPTXSize - {nvrtc_result}"));
}
println!("PTX size: {}", ptx_size);

let mut buffer: Vec<c_char> = vec![0; ptx_size];
let nvrtc_result = nvrtcGetPTX(program, buffer.as_mut_ptr());
if nvrtc_result != nvrtcResult_NVRTC_SUCCESS {
return Err("Failed: nvrtcGetPTX");
return Err(format!("Failed: nvrtcGetPTX - {nvrtc_result}"));
}

Ok(buffer)
Expand All @@ -89,10 +89,10 @@ impl Nvrtc {
/// # Safety
///
/// .
pub unsafe fn destroy_program(mut program: CudaProgram) -> Result<(), &'static str> {
pub unsafe fn destroy_program(mut program: CudaProgram) -> Result<(), String> {
let nvrtc_result = nvrtcDestroyProgram(&mut program as *mut nvrtcProgram);
if nvrtc_result != nvrtcResult_NVRTC_SUCCESS {
return Err("Failed: nvrtcDestroyProgram");
return Err(format!("Failed: nvrtcDestroyProgram - {nvrtc_result}"));
}

Ok(())
Expand All @@ -107,17 +107,17 @@ impl Nvrtc {
/// # Safety
///
/// .
pub unsafe fn get_program_log(program: CudaProgram) -> Result<String, &'static str> {
pub unsafe fn get_program_log(program: CudaProgram) -> Result<String, String> {
let mut log_size: usize = 0;
let nvrtc_result = nvrtcGetProgramLogSize(program, &mut log_size as *mut usize);
if nvrtc_result != nvrtcResult_NVRTC_SUCCESS {
return Err("Failed: nvrtcGetProgramLogSize");
return Err(format!("Failed: nvrtcGetProgramLogSize - {nvrtc_result}"));
}

let mut raw_log: Vec<u8> = vec![0; log_size];
let nvrtc_result = nvrtcGetProgramLog(program, raw_log.as_mut_ptr() as *mut c_char);
if nvrtc_result != nvrtcResult_NVRTC_SUCCESS {
return Err("Failed: nvrtcGetProgramLog");
return Err(format!("Failed: nvrtcGetProgramLog - {nvrtc_result}"));
}

Ok(String::from_utf8(raw_log).unwrap())
Expand Down

0 comments on commit c3456fc

Please sign in to comment.