Skip to content

Commit e83b5a8

Browse files
committed
Support to let users create Array from raw device pointers
The example show cases the usage using CUDA API, but it should be similar to OpenCL or CPU. In the case of OpenCL backend, the pointer would be cl_mem
1 parent 1c49295 commit e83b5a8

File tree

17 files changed

+689
-9
lines changed

17 files changed

+689
-9
lines changed

Cargo.toml

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,15 @@ exclude = [
1515
"arrayfire/*",
1616
]
1717

18+
[workspace]
19+
members = [
20+
"cuda-interop",
21+
]
22+
23+
[lib]
24+
name = "arrayfire"
25+
path = "src/lib.rs"
26+
1827
[package.metadata.docs.rs]
1928
rustdoc-args = [ "--html-in-header", "./scripts/mathjax.script", ]
2029

@@ -53,10 +62,6 @@ serde_derive = "1.0"
5362
serde = "1.0"
5463
rustc_version = "0.2"
5564

56-
[lib]
57-
name = "arrayfire"
58-
path = "src/lib.rs"
59-
6065
[[example]]
6166
name = "helloworld"
6267
path = "examples/helloworld.rs"

cuda-interop/Cargo.toml

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
[package]
2+
name = "af-cuda-interop"
3+
version = "0.1.0"
4+
authors = ["Pradeep Garigipati <pradeep@arrayfire.com>"]
5+
edition = "2018"
6+
7+
[dependencies]
8+
libc = "0.2"
9+
arrayfire = { path = "../" }
10+
cuda-runtime-sys = "0.3.0-alpha.1"
11+
12+
[dev-dependencies]
13+
rustacuda = "0.1"
14+
rustacuda_core = "0.1"
15+
16+
[[example]]
17+
name = "custom_kernel"
18+
path = "examples/custom_kernel.rs"

cuda-interop/examples/cuda_af_app.rs

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
use arrayfire::{af_print, dim4, info, set_device, Array};
2+
use rustacuda::prelude::*;
3+
use rustacuda::*;
4+
5+
fn main() {
6+
// MAKE SURE to do all rustacuda initilization before arrayfire API's
7+
// first call. It seems like some CUDA context state is getting messed up
8+
// if we mix CUDA context init(device, context, module, stream) with ArrayFire API
9+
match rustacuda::init(CudaFlags::empty()) {
10+
Ok(()) => {}
11+
Err(e) => panic!("rustacuda init failure: {:?}", e),
12+
}
13+
let device = match Device::get_device(0) {
14+
Ok(d) => d,
15+
Err(e) => panic!("Failed to get device: {:?}", e),
16+
};
17+
let _context =
18+
match Context::create_and_push(ContextFlags::MAP_HOST | ContextFlags::SCHED_AUTO, device) {
19+
Ok(c) => c,
20+
Err(e) => panic!("Failed to create context: {:?}", e),
21+
};
22+
let stream = match Stream::new(StreamFlags::NON_BLOCKING, None) {
23+
Ok(s) => s,
24+
Err(e) => panic!("Failed to create stream: {:?}", e),
25+
};
26+
27+
let mut in_x = DeviceBuffer::from_slice(&[1.0f32; 10]).unwrap();
28+
let mut in_y = DeviceBuffer::from_slice(&[2.0f32; 10]).unwrap();
29+
30+
// wait for any prior kernels to finish before passing
31+
// the device pointers to ArrayFire
32+
match stream.synchronize() {
33+
Ok(()) => {}
34+
Err(e) => panic!("Stream sync failure: {:?}", e),
35+
};
36+
37+
set_device(0);
38+
info();
39+
40+
let x = Array::new_from_device_ptr(in_x.as_device_ptr().as_raw_mut(), dim4!(10));
41+
let y = Array::new_from_device_ptr(in_y.as_device_ptr().as_raw_mut(), dim4!(10));
42+
43+
// Lock so that ArrayFire doesn't free pointers from RustaCUDA
44+
// But we have to make sure these pointers stay in valid scope
45+
// as long as the associated ArrayFire Array objects are valid
46+
x.lock();
47+
y.lock();
48+
49+
af_print!("x", x);
50+
af_print!("y", y);
51+
52+
let o = x + y;
53+
af_print!("out", o);
54+
55+
let _o_dptr = unsafe { o.device_ptr() }; // Calls an implicit lock
56+
57+
// User has to call unlock if they want to relenquish control to ArrayFire
58+
59+
// Once the non-arrayfire operations are done, call unlock.
60+
o.unlock(); // After this, there is no guarantee that value of o_dptr is valid
61+
}
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
use arrayfire as af;
2+
use rustacuda::prelude::*;
3+
use rustacuda::*;
4+
5+
use std::ffi::CString;
6+
7+
fn main() {
8+
// MAKE SURE to do all rustacuda initilization before arrayfire API's
9+
// first call. It seems like some CUDA context state is getting messed up
10+
// if we mix CUDA context init(device, context, module, stream) with ArrayFire API
11+
match rustacuda::init(CudaFlags::empty()) {
12+
Ok(()) => {}
13+
Err(e) => panic!("rustacuda init failure: {:?}", e),
14+
}
15+
let device = match Device::get_device(0) {
16+
Ok(d) => d,
17+
Err(e) => panic!("Failed to get device: {:?}", e),
18+
};
19+
let _context =
20+
match Context::create_and_push(ContextFlags::MAP_HOST | ContextFlags::SCHED_AUTO, device) {
21+
Ok(c) => c,
22+
Err(e) => panic!("Failed to create context: {:?}", e),
23+
};
24+
let ptx = CString::new(include_str!("./resources/add.ptx")).unwrap();
25+
let module = match Module::load_from_string(&ptx) {
26+
Ok(m) => m,
27+
Err(e) => panic!("Failed to load module from string: {:?}", e),
28+
};
29+
let stream = match Stream::new(StreamFlags::NON_BLOCKING, None) {
30+
Ok(s) => s,
31+
Err(e) => panic!("Failed to create stream: {:?}", e),
32+
};
33+
34+
af::set_device(0);
35+
af::info();
36+
37+
let num: i32 = 10;
38+
let x = af::constant(1f32, af::dim4!(10));
39+
let y = af::constant(2f32, af::dim4!(10));
40+
let out = af::constant(0f32, af::dim4!(10));
41+
42+
af::af_print!("x", x);
43+
af::af_print!("y", y);
44+
af::af_print!("out(init)", out);
45+
46+
//TODO Figure out how to use Stream returned by ArrayFire with Rustacuda
47+
// let af_id = get_device();
48+
// let cuda_id = get_device_native_id(af_id);
49+
// let af_cuda_stream = get_stream(cuda_id);
50+
51+
//TODO Figure out how to use Stream returned by ArrayFire with Rustacuda
52+
// let stream = Stream {inner: mem::transmute(af_cuda_stream)};
53+
54+
// Run a custom CUDA kernel in the ArrayFire CUDA stream
55+
unsafe {
56+
// Obtain device pointers from ArrayFire using Array::device() method
57+
let d_x: *mut f32 = x.device_ptr() as *mut f32;
58+
let d_y: *mut f32 = y.device_ptr() as *mut f32;
59+
let d_o: *mut f32 = out.device_ptr() as *mut f32;
60+
61+
match launch!(module.sum<<<1, 1, 0, stream>>>(
62+
memory::DevicePointer::wrap(d_x),
63+
memory::DevicePointer::wrap(d_y),
64+
memory::DevicePointer::wrap(d_o),
65+
num
66+
)) {
67+
Ok(()) => {}
68+
Err(e) => panic!("Kernel Launch failure: {:?}", e),
69+
}
70+
71+
// wait for the kernel to finish as it is async call
72+
match stream.synchronize() {
73+
Ok(()) => {}
74+
Err(e) => panic!("Stream sync failure: {:?}", e),
75+
};
76+
77+
// Return control of Array memory to ArrayFire using unlock
78+
x.unlock();
79+
y.unlock();
80+
out.unlock();
81+
}
82+
af::af_print!("sum after kernel launch", out);
83+
}
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
extern "C" __constant__ int my_constant = 314;
2+
3+
extern "C" __global__ void sum(const float* x, const float* y, float* out, int count) {
4+
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < count; i += blockDim.x * gridDim.x) {
5+
out[i] = x[i] + y[i];
6+
}
7+
}
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
//
2+
// Generated by NVIDIA NVVM Compiler
3+
//
4+
// Compiler Build ID: CL-24817639
5+
// Cuda compilation tools, release 10.0, V10.0.130
6+
// Based on LLVM 3.4svn
7+
//
8+
9+
.version 3.2
10+
.target sm_20
11+
.address_size 64
12+
13+
// .globl sum
14+
.const .align 4 .u32 my_constant = 314;
15+
16+
.visible .entry sum(
17+
.param .u64 sum_param_0,
18+
.param .u64 sum_param_1,
19+
.param .u64 sum_param_2,
20+
.param .u32 sum_param_3
21+
)
22+
{
23+
.reg .pred %p<3>;
24+
.reg .f32 %f<4>;
25+
.reg .b32 %r<11>;
26+
.reg .b64 %rd<11>;
27+
28+
29+
ld.param.u64 %rd4, [sum_param_0];
30+
ld.param.u64 %rd5, [sum_param_1];
31+
ld.param.u64 %rd6, [sum_param_2];
32+
ld.param.u32 %r6, [sum_param_3];
33+
mov.u32 %r1, %ntid.x;
34+
mov.u32 %r7, %ctaid.x;
35+
mov.u32 %r8, %tid.x;
36+
mad.lo.s32 %r10, %r1, %r7, %r8;
37+
setp.ge.s32 %p1, %r10, %r6;
38+
@%p1 bra BB0_3;
39+
40+
cvta.to.global.u64 %rd1, %rd6;
41+
cvta.to.global.u64 %rd2, %rd5;
42+
cvta.to.global.u64 %rd3, %rd4;
43+
mov.u32 %r9, %nctaid.x;
44+
mul.lo.s32 %r3, %r9, %r1;
45+
46+
BB0_2:
47+
mul.wide.s32 %rd7, %r10, 4;
48+
add.s64 %rd8, %rd3, %rd7;
49+
add.s64 %rd9, %rd2, %rd7;
50+
ld.global.f32 %f1, [%rd9];
51+
ld.global.f32 %f2, [%rd8];
52+
add.f32 %f3, %f2, %f1;
53+
add.s64 %rd10, %rd1, %rd7;
54+
st.global.f32 [%rd10], %f3;
55+
add.s32 %r10, %r3, %r10;
56+
setp.lt.s32 %p2, %r10, %r6;
57+
@%p2 bra BB0_2;
58+
59+
BB0_3:
60+
ret;
61+
}

cuda-interop/src/lib.rs

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
//! af-cuda-interop package is to used only when the application intends to mix
2+
//! arrayfire code with raw CUDA code.
3+
4+
use arrayfire::{handle_error_general, AfError};
5+
use cuda_runtime_sys::cudaStream_t;
6+
use libc::c_int;
7+
8+
extern "C" {
9+
fn afcu_get_native_id(native_id: *mut c_int, id: c_int) -> c_int;
10+
fn afcu_set_native_id(native_id: c_int) -> c_int;
11+
fn afcu_get_stream(out: *mut cudaStream_t, id: c_int) -> c_int;
12+
}
13+
14+
/// Get active device's id in CUDA context
15+
///
16+
/// # Parameters
17+
///
18+
/// - `id` is the integer identifier of concerned CUDA device as per ArrayFire context
19+
///
20+
/// # Return Values
21+
///
22+
/// Integer identifier of device in CUDA context
23+
pub fn get_device_native_id(id: i32) -> i32 {
24+
unsafe {
25+
let mut temp: i32 = 0;
26+
let err_val = afcu_get_native_id(&mut temp as *mut c_int, id);
27+
handle_error_general(AfError::from(err_val));
28+
temp
29+
}
30+
}
31+
32+
/// Set active device using CUDA context's id
33+
///
34+
/// # Parameters
35+
///
36+
/// - `id` is the identifier of GPU in CUDA context
37+
pub fn set_device_native_id(native_id: i32) {
38+
unsafe {
39+
let err_val = afcu_set_native_id(native_id);
40+
handle_error_general(AfError::from(err_val));
41+
}
42+
}
43+
44+
/// Get CUDA stream of active CUDA device
45+
///
46+
/// # Parameters
47+
///
48+
/// - `id` is the identifier of device in ArrayFire context
49+
///
50+
/// # Return Values
51+
///
52+
/// [cudaStream_t](https://docs.rs/cuda-runtime-sys/0.3.0-alpha.1/cuda_runtime_sys/type.cudaStream_t.html) handle.
53+
pub fn get_stream(native_id: i32) -> cudaStream_t {
54+
unsafe {
55+
let mut ret_val: cudaStream_t = std::ptr::null_mut();
56+
let err_val = afcu_get_stream(&mut ret_val as *mut cudaStream_t, native_id);
57+
handle_error_general(AfError::from(err_val));
58+
ret_val
59+
}
60+
}

scripts/generate_documentation.sh

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,6 @@
33
# this script meant to be run from the root of arrayfire-rust
44

55
cargo rustdoc -p arrayfire -- --html-in-header ./scripts/mathjax.script
6+
cargo rustdoc -p af-cuda-interop -- --html-in-header ./scripts/mathjax.script
67

7-
mdbook build tutorials-book && cp -r tutorials-book/book ./target/doc/arrayfire/
8+
mdbook build tutorials-book && cp -r tutorials-book/book ./target/doc/

src/blas/mod.rs

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
use super::core::{af_array, AfError, Array, FloatingPoint, HasAfEnum, MatProp, HANDLE_ERROR};
1+
use super::core::{
2+
af_array, AfError, Array, CublasMathMode, FloatingPoint, HasAfEnum, MatProp, HANDLE_ERROR,
3+
};
24

35
use libc::{c_int, c_uint, c_void};
46
use std::vec::Vec;
@@ -32,6 +34,8 @@ extern "C" {
3234

3335
fn af_transpose(out: *mut af_array, arr: af_array, conjugate: bool) -> c_int;
3436
fn af_transpose_inplace(arr: af_array, conjugate: bool) -> c_int;
37+
38+
fn afcu_cublasSetMathMode(mode: c_int) -> c_int;
3539
}
3640

3741
/// BLAS general matrix multiply (GEMM) of two Array objects
@@ -237,3 +241,20 @@ pub fn transpose_inplace<T: HasAfEnum>(arr: &mut Array<T>, conjugate: bool) {
237241
HANDLE_ERROR(AfError::from(err_val));
238242
}
239243
}
244+
245+
/// Sets the cuBLAS math mode for the internal handle.
246+
///
247+
/// See the cuBLAS documentation for additional details
248+
///
249+
/// # Parameters
250+
///
251+
/// - `mode` takes a value of [CublasMathMode](./enum.CublasMathMode.html) enum
252+
pub fn set_cublas_mode(mode: CublasMathMode) {
253+
unsafe {
254+
afcu_cublasSetMathMode(mode as c_int);
255+
//let err_val = afcu_cublasSetMathMode(mode as c_int);
256+
// FIXME(wonder if this something to throw off,
257+
// the program state is not invalid or anything
258+
// HANDLE_ERROR(AfError::from(err_val));
259+
}
260+
}

0 commit comments

Comments
 (0)