Skip to content

Commit 5114b25

Browse files
Added dpt.copy, and dpt.asnumpy (alias of dpt.to_numpy) (#595)
1 parent f2a12e1 commit 5114b25

File tree

3 files changed

+107
-1
lines changed

3 files changed

+107
-1
lines changed

dpctl/tensor/__init__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,16 +29,19 @@
2929
3030
"""
3131

32-
from dpctl.tensor._copy_utils import astype
32+
from dpctl.tensor._copy_utils import astype, copy
3333
from dpctl.tensor._copy_utils import copy_from_numpy as from_numpy
34+
from dpctl.tensor._copy_utils import copy_to_numpy as asnumpy
3435
from dpctl.tensor._copy_utils import copy_to_numpy as to_numpy
3536
from dpctl.tensor._reshape import reshape
3637
from dpctl.tensor._usmarray import usm_ndarray
3738

3839
__all__ = [
3940
"usm_ndarray",
4041
"astype",
42+
"copy",
4143
"reshape",
4244
"from_numpy",
4345
"to_numpy",
46+
"asnumpy",
4447
]

dpctl/tensor/_copy_utils.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,71 @@ def copy_from_usm_ndarray_to_usm_ndarray(dst, src):
253253
copy_same_shape(dst, src_same_shape)
254254

255255

256+
def copy(usm_ary, order="K"):
257+
"""
258+
Creates a copy of given instance of `usm_ndarray`.
259+
260+
Memory layour of the copy is controlled by `order` keyword,
261+
following NumPy's conventions. The `order` keywords can be
262+
one of the following:
263+
264+
"C": C-contiguous memory layout
265+
"F": Fotrant-contiguous memory layout
266+
"A": Fotrant-contiguous if the input array is
267+
F-contiguous, and C-contiguous otherwise
268+
"K": match the layout of `usm_ary` as closely
269+
as possible.
270+
271+
"""
272+
if not isinstance(usm_ary, dpt.usm_ndarray):
273+
return TypeError(
274+
"Expected object of type dpt.usm_ndarray, got {}".format(
275+
type(usm_ary)
276+
)
277+
)
278+
copy_order = "C"
279+
if order == "C":
280+
pass
281+
elif order == "F":
282+
copy_order = order
283+
elif order == "A":
284+
if usm_ary.flags & 2:
285+
copy_order = "F"
286+
elif order == "K":
287+
if usm_ary.flags & 2:
288+
copy_order = "F"
289+
else:
290+
raise ValueError(
291+
"Unrecognized value of the order keyword. "
292+
"Recognized values are 'A', 'C', 'F', or 'K'"
293+
)
294+
c_contig = usm_ary.flags & 1
295+
f_contig = usm_ary.flags & 2
296+
R = dpt.usm_ndarray(
297+
usm_ary.shape,
298+
dtype=usm_ary.dtype,
299+
buffer=usm_ary.usm_type,
300+
order=copy_order,
301+
buffer_ctor_kwargs={"queue": usm_ary.sycl_queue},
302+
)
303+
if order == "K" and (not c_contig and not f_contig):
304+
original_strides = usm_ary.strides
305+
ind = sorted(
306+
range(usm_ary.ndim),
307+
key=lambda i: abs(original_strides[i]),
308+
reverse=True,
309+
)
310+
new_strides = tuple(R.strides[ind[i]] for i in ind)
311+
R = dpt.usm_ndarray(
312+
usm_ary.shape,
313+
dtype=usm_ary.dtype,
314+
buffer=R.usm_data,
315+
strides=new_strides,
316+
)
317+
copy_same_dtype(R, usm_ary)
318+
return R
319+
320+
256321
def astype(usm_ary, newdtype, order="K", casting="unsafe", copy=True):
257322
"""
258323
astype(usm_array, new_dtype, order="K", casting="unsafe", copy=True)
@@ -267,6 +332,11 @@ def astype(usm_ary, newdtype, order="K", casting="unsafe", copy=True):
267332
type(usm_ary)
268333
)
269334
)
335+
if not isinstance(order, str) or order not in ["A", "C", "F", "K"]:
336+
raise ValueError(
337+
"Unrecognized value of the order keyword. "
338+
"Recognized values are 'A', 'C', 'F', or 'K'"
339+
)
270340
ary_dtype = usm_ary.dtype
271341
target_dtype = np.dtype(newdtype)
272342
if not np.can_cast(ary_dtype, target_dtype, casting=casting):
@@ -294,6 +364,11 @@ def astype(usm_ary, newdtype, order="K", casting="unsafe", copy=True):
294364
elif order == "K":
295365
if usm_ary.flags & 2:
296366
copy_order = "F"
367+
else:
368+
raise ValueError(
369+
"Unrecognized value of the order keyword. "
370+
"Recognized values are 'A', 'C', 'F', or 'K'"
371+
)
297372
R = dpt.usm_ndarray(
298373
usm_ary.shape,
299374
dtype=target_dtype,

dpctl/tests/test_usm_ndarray_ctor.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -768,6 +768,34 @@ def test_astype():
768768
assert Y.usm_data is X.usm_data
769769

770770

771+
def test_astype_invalid_order():
772+
X = dpt.usm_ndarray(5, "i4")
773+
with pytest.raises(ValueError):
774+
dpt.astype(X, "i4", order="WRONG")
775+
776+
777+
def test_copy():
778+
X = dpt.usm_ndarray((5, 5), "i4")[2:4, 1:4]
779+
X[:] = 42
780+
Yc = dpt.copy(X, order="C")
781+
Yf = dpt.copy(X, order="F")
782+
Ya = dpt.copy(X, order="A")
783+
Yk = dpt.copy(X, order="K")
784+
assert Yc.usm_data is not X.usm_data
785+
assert Yf.usm_data is not X.usm_data
786+
assert Ya.usm_data is not X.usm_data
787+
assert Yk.usm_data is not X.usm_data
788+
assert Yc.strides == (3, 1)
789+
assert Yf.strides == (1, 2)
790+
assert Ya.strides == (3, 1)
791+
assert Yk.strides == (3, 1)
792+
ref = np.full(X.shape, 42, dtype=X.dtype)
793+
assert np.array_equal(dpt.asnumpy(Yc), ref)
794+
assert np.array_equal(dpt.asnumpy(Yf), ref)
795+
assert np.array_equal(dpt.asnumpy(Ya), ref)
796+
assert np.array_equal(dpt.asnumpy(Yk), ref)
797+
798+
771799
def test_ctor_invalid():
772800
m = dpm.MemoryUSMShared(12)
773801
with pytest.raises(ValueError):

0 commit comments

Comments
 (0)