Skip to content

Commit 3048f3e

Browse files
Fixes #729
The generated reshaped_strides routine is meant for non-empty arrays, and so empty ones must be handled differently. Tests added as well.
1 parent a48f381 commit 3048f3e

File tree

2 files changed

+30
-1
lines changed

2 files changed

+30
-1
lines changed

dpctl/tensor/_reshape.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,10 @@ def reshape(X, newshape, order="C"):
104104
newshape = [v if d == -1 else d for d in newshape]
105105
if X.size != np.prod(newshape):
106106
raise ValueError("Can not reshape into {}".format(newshape))
107-
newsts = reshaped_strides(X.shape, X.strides, newshape, order=order)
107+
if X.size:
108+
newsts = reshaped_strides(X.shape, X.strides, newshape, order=order)
109+
else:
110+
newsts = (1,) * len(newshape)
108111
if newsts is None:
109112
# must perform a copy
110113
flat_res = dpt.usm_ndarray(

dpctl/tests/test_usm_ndarray_ctor.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -816,6 +816,32 @@ def test_reshape():
816816
Y = dpt.reshape(X, X.shape)
817817
assert Y.flags == X.flags
818818

819+
A = dpt.usm_ndarray((0,), "i4")
820+
A1 = dpt.reshape(A, (0,))
821+
assert A1.shape == (0,)
822+
A2 = dpt.reshape(
823+
A,
824+
(
825+
2,
826+
0,
827+
),
828+
)
829+
assert A2.shape == (
830+
2,
831+
0,
832+
)
833+
A3 = dpt.reshape(A, (0, 2))
834+
assert A3.shape == (
835+
0,
836+
2,
837+
)
838+
A4 = dpt.reshape(A, (1, 0, 2))
839+
assert A4.shape == (
840+
1,
841+
0,
842+
2,
843+
)
844+
819845

820846
def test_transpose():
821847
n, m = 2, 3

0 commit comments

Comments
 (0)