Skip to content

Commit e2ba101

Browse files
committed
FEAT: Adding function to copy to existing ndarray
1 parent 252f9b7 commit e2ba101

File tree

3 files changed

+67
-4
lines changed

3 files changed

+67
-4
lines changed

arrayfire/array.py

Lines changed: 54 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -296,7 +296,6 @@ def _get_assign_dims(key, idims):
296296
else:
297297
raise IndexError("Invalid type while assigning to arrayfire.array")
298298

299-
300299
def transpose(a, conj=False):
301300
"""
302301
Perform the transpose on an input.
@@ -504,7 +503,9 @@ def __init__(self, src=None, dims=None, dtype=None, is_device=False, offset=None
504503
if(offset is None and strides is None):
505504
self.arr = _create_array(buf, numdims, idims, to_dtype[_type_char], is_device)
506505
else:
507-
self.arr = _create_strided_array(buf, numdims, idims, to_dtype[_type_char], is_device, offset, strides)
506+
self.arr = _create_strided_array(buf, numdims, idims,
507+
to_dtype[_type_char],
508+
is_device, offset, strides)
508509

509510
else:
510511

@@ -1159,6 +1160,19 @@ def __setitem__(self, key, val):
11591160
except RuntimeError as e:
11601161
raise IndexError(str(e))
11611162

1163+
def _reorder(self):
1164+
"""
1165+
Returns a reordered array to help interoperate with row major formats.
1166+
"""
1167+
ndims = self.numdims()
1168+
if (ndims == 1):
1169+
return self
1170+
1171+
rdims = tuple(reversed(range(ndims))) + tuple(range(ndims, 4))
1172+
out = Array()
1173+
safe_call(backend.get().af_reorder(c_pointer(out.arr), self.arr, *rdims))
1174+
return out
1175+
11621176
def to_ctype(self, row_major=False, return_shape=False):
11631177
"""
11641178
Return the data as a ctype C array after copying to host memory
@@ -1312,6 +1326,44 @@ def __array__(self):
13121326
safe_call(backend.get().af_get_data_ptr(c_void_ptr_t(res.ctypes.data), self.arr))
13131327
return res
13141328

1329+
def to_ndarray(self, output=None):
1330+
"""
1331+
Parameters
1332+
-----------
1333+
output: optional: numpy. default: None
1334+
1335+
Returns
1336+
----------
1337+
If output is None: Constructs a numpy.array from arrayfire.Array
1338+
If output is not None: copies content of af.array into numpy array.
1339+
1340+
Note
1341+
------
1342+
1343+
- An exception is thrown when output is not None and it is not contiguous.
1344+
- When output is None, The returned array is in fortran contiguous order.
1345+
"""
1346+
if output is None:
1347+
return self.__array__()
1348+
1349+
if (output.dtype != to_typecode[self.type()]):
1350+
raise TypeError("Output is not the same type as the array")
1351+
1352+
if (output.size != self.elements()):
1353+
raise RuntimeError("Output size does not match that of input")
1354+
1355+
flags = output.flags
1356+
tmp = None
1357+
if flags['F_CONTIGUOUS']:
1358+
tmp = self
1359+
elif flags['C_CONTIGUOUS']:
1360+
tmp = self._reorder()
1361+
else:
1362+
raise RuntimeError("When output is not None, it must be contiguous")
1363+
1364+
safe_call(backend.get().af_get_data_ptr(c_void_ptr_t(output.ctypes.data), tmp.arr))
1365+
return output
1366+
13151367
def display(a, precision=4):
13161368
"""
13171369
Displays the contents of an array.

arrayfire/data.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -397,7 +397,6 @@ def tile(a, d0, d1=1, d2=1, d3=1):
397397
safe_call(backend.get().af_tile(c_pointer(out.arr), a.arr, d0, d1, d2, d3))
398398
return out
399399

400-
401400
def reorder(a, d0=1, d1=0, d2=2, d3=3):
402401
"""
403402
Reorder the dimensions of the input.

arrayfire/tests/simple/interop.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,21 +18,33 @@ def simple_interop(verbose = False):
1818
a = af.to_array(n)
1919
n2 = np.array(a)
2020
assert((n==n2).all())
21+
n2[:] = 0
22+
a.to_ndarray(n2)
23+
assert((n==n2).all())
2124

2225
n = np.random.random((5,3))
2326
a = af.to_array(n)
2427
n2 = np.array(a)
2528
assert((n==n2).all())
29+
n2[:] = 0
30+
a.to_ndarray(n2)
31+
assert((n==n2).all())
2632

2733
n = np.random.random((5,3,2))
2834
a = af.to_array(n)
2935
n2 = np.array(a)
3036
assert((n==n2).all())
37+
n2[:] = 0
38+
a.to_ndarray(n2)
39+
assert((n==n2).all())
3140

32-
n = np.random.random((5,3,2, 2))
41+
n = np.random.random((5,3,2,2))
3342
a = af.to_array(n)
3443
n2 = np.array(a)
3544
assert((n==n2).all())
45+
n2[:] = 0
46+
a.to_ndarray(n2)
47+
assert((n==n2).all())
3648

3749
if af.AF_PYCUDA_FOUND and af.get_active_backend() == 'cuda':
3850
import pycuda.autoinit

0 commit comments

Comments
 (0)