Skip to content

Commit 7e1bf48

Browse files
Added test_reshape, fixed error uncovered by it (including ones in copying)
1 parent 4f20687 commit 7e1bf48

File tree

3 files changed

+40
-6
lines changed

3 files changed

+40
-6
lines changed

dpctl/tensor/_copy_utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,9 @@ def copy_same_dtype(dst, src):
129129
return
130130

131131
if (dst.flags & 1) and (src.flags & 1):
132-
dst.usm_data.copy_from_device(src.usm_data)
132+
dst_mem = dpm.as_usm_memory(dst)
133+
src_mem = dpm.as_usm_memory(src)
134+
dst_mem.copy_from_device(src_mem)
133135
return
134136

135137
# simplify strides

dpctl/tensor/_reshape.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import collections.abc
1+
import operator
22

33
import numpy as np
44

@@ -59,14 +59,27 @@ def reshaped_strides(old_sh, old_sts, new_sh, order="C"):
5959
def reshape(X, newshape, order="C"):
6060
if type(X) is not dpt.usm_ndarray:
6161
raise TypeError
62-
if X.size != np.prod(newshape):
63-
raise ValueError("Can not reshape into {}".format(newshape))
64-
if not isinstance(newshape, collections.abc.Sized):
62+
if not isinstance(newshape, (list, tuple)):
6563
newshape = (newshape,)
6664
if order not in ["C", "F"]:
67-
return ValueError(
65+
raise ValueError(
6866
f"Keyword 'order' not recognized. Expecting 'C' or 'F', got {order}"
6967
)
68+
newshape = [operator.index(d) for d in newshape]
69+
negative_ones_count = 0
70+
for i in range(len(newshape)):
71+
if newshape[i] == -1:
72+
negative_ones_count = negative_ones_count + 1
73+
if (newshape[i] < -1) or negative_ones_count > 1:
74+
raise ValueError(
75+
"Target shape should have at most 1 negative "
76+
"value which can only be -1"
77+
)
78+
if negative_ones_count:
79+
v = X.size // (-np.prod(newshape))
80+
newshape = [v if d == -1 else d for d in newshape]
81+
if X.size != np.prod(newshape):
82+
raise ValueError("Can not reshape into {}".format(newshape))
7083
newsts = reshaped_strides(X.shape, X.strides, newshape, order=order)
7184
if newsts is None:
7285
# must perform a copy

dpctl/tests/test_usm_ndarray_ctor.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -775,3 +775,22 @@ def test_ctor_invalid():
775775
m = dpm.MemoryUSMShared(64)
776776
with pytest.raises(ValueError):
777777
dpt.usm_ndarray((4,), dtype="u1", buffer=m, strides={"not": "valid"})
778+
779+
780+
def test_reshape():
781+
X = dpt.usm_ndarray((5, 5), "i4")
782+
# can be done as views
783+
Y = dpt.reshape(X, (25,))
784+
assert Y.shape == (25,)
785+
Z = X[::2, ::2]
786+
# requires a copy
787+
W = dpt.reshape(Z, (Z.size,), order="F")
788+
assert W.shape == (Z.size,)
789+
with pytest.raises(TypeError):
790+
dpt.reshape("invalid")
791+
with pytest.raises(ValueError):
792+
dpt.reshape(Z, (2, 2, 2, 2, 2))
793+
with pytest.raises(ValueError):
794+
dpt.reshape(Z, Z.shape, order="invalid")
795+
W = dpt.reshape(Z, (-1,), order="C")
796+
assert W.shape == (Z.size,)

0 commit comments

Comments
 (0)