Skip to content

Commit 271b66a

Browse files
committed
Support to let users create Array from raw device pointers
The example show cases the usage using CUDA API, but it should be similar to OpenCL or CPU. In the case of OpenCL backend, the pointer would be cl_mem
1 parent 63ed015 commit 271b66a

File tree

2 files changed

+89
-1
lines changed

2 files changed

+89
-1
lines changed

src/core/array.rs

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,14 @@ extern "C" {
2828
aftype: c_uint,
2929
) -> c_int;
3030

31+
fn af_device_array(
32+
out: *mut af_array,
33+
data: *mut c_void,
34+
ndims: c_uint,
35+
dims: *const dim_t,
36+
aftype: c_uint,
37+
) -> c_int;
38+
3139
fn af_get_elements(out: *mut dim_t, arr: af_array) -> c_int;
3240

3341
fn af_get_type(out: *mut c_uint, arr: af_array) -> c_int;
@@ -254,6 +262,86 @@ where
254262
}
255263
}
256264

265+
/// Constructs a new Array object from device pointer
266+
///
267+
/// The example show cases the usage using CUDA API, but usage of this function will
268+
/// be similar in CPU and OpenCL backends also. In the case of OpenCL backend, the pointer
269+
/// would be cl_mem.
270+
///
271+
/// # Examples
272+
///
273+
/// An example of creating an Array device pointer using
274+
/// [rustacuda](https://github.com/bheisler/RustaCUDA) crate. The
275+
/// example has to be copied to a `bin` crate with following contents in Cargo.toml
276+
/// to run successfully. Note that, all required setup for rustacuda and arrayfire crate
277+
/// have to completed first.
278+
/// ```text
279+
/// [package]
280+
/// ....
281+
/// [dependencies]
282+
/// rustacuda = "0.1"
283+
/// rustacuda_derive = "0.1"
284+
/// rustacuda_core = "0.1"
285+
/// arrayfire = "3.7.*"
286+
/// ```
287+
///
288+
/// ```rust,ignore
289+
///use arrayfire::*;
290+
///use rustacuda::*;
291+
///use rustacuda::prelude::*;
292+
///
293+
///fn main() {
294+
/// let v: Vec<_> = (0u8 .. 100).map(f32::from).collect();
295+
///
296+
/// rustacuda::init(CudaFlags::empty());
297+
/// let device = Device::get_device(0).unwrap();
298+
/// let context = Context::create_and_push(ContextFlags::MAP_HOST | ContextFlags::SCHED_AUTO,
299+
/// device).unwrap();
300+
/// // Approach 1
301+
/// {
302+
/// let mut buffer = memory::DeviceBuffer::from_slice(&v).unwrap();
303+
///
304+
/// let array_dptr = Array::new_from_device_ptr(
305+
/// buffer.as_device_ptr().as_raw_mut(), dim4!(10, 10));
306+
///
307+
/// af_print!("array_dptr", &array_dptr);
308+
///
309+
/// array_dptr.lock(); // Needed to avoid free as arrayfire takes ownership
310+
/// }
311+
///
312+
/// // Approach 2
313+
/// {
314+
/// let mut dptr: *mut f32 = std::ptr::null_mut();
315+
/// unsafe {
316+
/// dptr = memory::cuda_malloc::<f32>(10*10).unwrap().as_raw_mut();
317+
/// }
318+
/// let array_dptr = Array::new_from_device_ptr(dptr, dim4!(10, 10));
319+
/// // note that values might be garbage in the memory pointed out by dptr
320+
/// // in this example as it is allocated but not initialized prior to passing
321+
/// // along to arrayfire::Array::new*
322+
///
323+
/// // After ArrayFire takes over ownership of the pointer, you can use other
324+
/// // arrayfire functions as usual.
325+
/// af_print!("array_dptr", &array_dptr);
326+
/// }
327+
///}
328+
/// ```
329+
pub fn new_from_device_ptr(dev_ptr: *mut T, dims: Dim4) -> Self {
330+
let aftype = T::get_af_dtype();
331+
unsafe {
332+
let mut temp: af_array = std::ptr::null_mut();
333+
let err_val = af_device_array(
334+
&mut temp as *mut af_array,
335+
dev_ptr as *mut c_void,
336+
dims.ndims() as c_uint,
337+
dims.get().as_ptr() as *const dim_t,
338+
aftype as c_uint,
339+
);
340+
HANDLE_ERROR(AfError::from(err_val));
341+
temp.into()
342+
}
343+
}
344+
257345
/// Returns the backend of the Array
258346
///
259347
/// # Return Values

src/core/util.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
use super::defines::{
2-
AfError, BinaryOp, ColorMap, ConvDomain, ConvMode, DType, InterpType, MatchType, MatProp,
2+
AfError, BinaryOp, ColorMap, ConvDomain, ConvMode, DType, InterpType, MatProp, MatchType,
33
RandomEngineType, SparseFormat,
44
};
55
use super::error::HANDLE_ERROR;

0 commit comments

Comments
 (0)