Skip to content

Commit 3cf70d8

Browse files
committed
Adding functions to copy from numba to arrayfire
1 parent ff7a689 commit 3cf70d8

File tree

2 files changed

+95
-15
lines changed

2 files changed

+95
-15
lines changed

arrayfire/interop.py

Lines changed: 66 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
1. numpy - numpy.ndarray
1616
2. pycuda - pycuda.gpuarray
1717
3. pyopencl - pyopencl.array
18+
4. numba - numba.cuda.cudadrv.devicearray.DeviceNDArray
1819
1920
"""
2021

@@ -58,26 +59,27 @@ def _cc_to_af_array(in_ptr, ndim, in_shape, in_dtype, is_device=False, copy = Tr
5859
else:
5960
raise RuntimeError("Unsupported ndim")
6061

62+
63+
_nptype_to_aftype = {'b1' : Dtype.b8,
64+
'u1' : Dtype.u8,
65+
'u2' : Dtype.u16,
66+
'i2' : Dtype.s16,
67+
's4' : Dtype.u32,
68+
'i4' : Dtype.s32,
69+
'f4' : Dtype.f32,
70+
'c8' : Dtype.c32,
71+
's8' : Dtype.u64,
72+
'i8' : Dtype.s64,
73+
'f8' : Dtype.f64,
74+
'c16' : Dtype.c64}
75+
6176
try:
6277
import numpy as np
6378
from numpy import ndarray as NumpyArray
6479
from .data import reorder
6580

6681
AF_NUMPY_FOUND=True
6782

68-
_nptype_to_aftype = {'b1' : Dtype.b8,
69-
'u1' : Dtype.u8,
70-
'u2' : Dtype.u16,
71-
'i2' : Dtype.s16,
72-
's4' : Dtype.u32,
73-
'i4' : Dtype.s32,
74-
'f4' : Dtype.f32,
75-
'c8' : Dtype.c32,
76-
's8' : Dtype.u64,
77-
'i8' : Dtype.s64,
78-
'f8' : Dtype.f64,
79-
'c16' : Dtype.c64}
80-
8183
def np_to_af_array(np_arr, copy=True):
8284
"""
8385
Convert numpy.ndarray to arrayfire.Array.
@@ -222,6 +224,48 @@ def pyopencl_to_af_array(pycl_arr, copy=True):
222224
except:
223225
AF_PYOPENCL_FOUND=False
224226

227+
try:
228+
import numba
229+
from numba import cuda
230+
NumbaCudaArray = cuda.cudadrv.devicearray.DeviceNDArray
231+
AF_NUMBA_FOUND=True
232+
233+
def numba_to_af_array(nb_arr, copy=True):
234+
"""
235+
Convert numba.gpuarray to arrayfire.Array
236+
237+
Parameters
238+
-----------
239+
nb_arr : numba.cuda.cudadrv.devicearray.DeviceNDArray()
240+
241+
copy : Bool specifying if array is to be copied.
242+
Default is true.
243+
Can only be False if array is fortran contiguous.
244+
245+
Returns
246+
----------
247+
af_arr : arrayfire.Array()
248+
249+
Note
250+
----------
251+
The input array is copied to af.Array
252+
"""
253+
254+
in_ptr = nb_arr.device_ctypes_pointer.value
255+
in_shape = nb_arr.shape
256+
in_dtype = _nptype_to_aftype[nb_arr.dtype.str[1:]]
257+
258+
if not copy and not nb_arr.flags.f_contiguous:
259+
raise RuntimeError("Copy can only be False when arr.flags.f_contiguous is True")
260+
261+
if (nb_arr.is_f_contiguous()):
262+
return _fc_to_af_array(in_ptr, in_shape, in_dtype, True, copy)
263+
elif (nb_arr.is_c_contiguous()):
264+
return _cc_to_af_array(in_ptr, nb_arr.ndim, in_shape, in_dtype, True, copy)
265+
else:
266+
return numba_to_af_array(nb_arr.copy())
267+
except:
268+
AF_NUMBA_FOUND=False
225269

226270
def to_array(in_array, copy = True):
227271
"""
@@ -231,8 +275,13 @@ def to_array(in_array, copy = True):
231275
-------------
232276
233277
in_array : array like object
234-
Can be one of numpy.ndarray, pycuda.GPUArray, pyopencl.Array, array.array, list
235-
278+
Can be one of the following:
279+
- numpy.ndarray
280+
- pycuda.GPUArray
281+
- pyopencl.Array
282+
- numba.cuda.cudadrv.devicearray.DeviceNDArray
283+
- array.array
284+
- list
236285
copy : Bool specifying if array is to be copied.
237286
Default is true.
238287
Can only be False if array is fortran contiguous.
@@ -248,4 +297,6 @@ def to_array(in_array, copy = True):
248297
return pycuda_to_af_array(in_array, copy)
249298
if AF_PYOPENCL_FOUND and isinstance(in_array, OpenclArray):
250299
return pyopencl_to_af_array(in_array, copy)
300+
if AF_NUMBA_FOUND and isinstance(in_array, NumbaCudaArray):
301+
return numba_to_af_array(in_array, copy)
251302
return Array(src=in_array)

arrayfire/tests/simple/interop.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,4 +95,33 @@ def simple_interop(verbose = False):
9595
# assert((n==n2).all())
9696
pass
9797

98+
if af.AF_NUMBA_FOUND and af.get_active_backend() == 'cuda':
99+
100+
import numba
101+
from numba import cuda
102+
103+
n = np.random.random((5,))
104+
c = cuda.to_device(n)
105+
a = af.to_array(c)
106+
n2 = np.array(a)
107+
assert((n==n2).all())
108+
109+
n = np.random.random((5,3))
110+
c = cuda.to_device(n)
111+
a = af.to_array(c)
112+
n2 = np.array(a)
113+
assert((n==n2).all())
114+
115+
n = np.random.random((5,3,2))
116+
c = cuda.to_device(n)
117+
a = af.to_array(c)
118+
n2 = np.array(a)
119+
assert((n==n2).all())
120+
121+
n = np.random.random((5,3,2,2))
122+
c = cuda.to_device(n)
123+
a = af.to_array(c)
124+
n2 = np.array(a)
125+
assert((n==n2).all())
126+
98127
_util.tests['interop'] = simple_interop

0 commit comments

Comments
 (0)