Skip to content

Commit 9c90442

Browse files
Merge pull request #756 from IntelPython/bugfix/gh-729-reshape
Fix for errors reported in #729 when reshaping 0-elems array
2 parents a48f381 + 479ed60 commit 9c90442

File tree

3 files changed

+56
-7
lines changed

3 files changed

+56
-7
lines changed

dpctl/tensor/_reshape.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,10 @@ def reshape(X, newshape, order="C"):
104104
newshape = [v if d == -1 else d for d in newshape]
105105
if X.size != np.prod(newshape):
106106
raise ValueError("Can not reshape into {}".format(newshape))
107-
newsts = reshaped_strides(X.shape, X.strides, newshape, order=order)
107+
if X.size:
108+
newsts = reshaped_strides(X.shape, X.strides, newshape, order=order)
109+
else:
110+
newsts = (1,) * len(newshape)
108111
if newsts is None:
109112
# must perform a copy
110113
flat_res = dpt.usm_ndarray(

dpctl/tensor/_usmarray.pyx

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -428,6 +428,7 @@ cdef class usm_ndarray:
428428
cdef int contig_flag = 0
429429
cdef Py_ssize_t *shape_ptr = NULL
430430
cdef Py_ssize_t *strides_ptr = NULL
431+
cdef Py_ssize_t size = -1
431432
import operator
432433

433434
from ._reshape import reshaped_strides
@@ -439,15 +440,19 @@ cdef class usm_ndarray:
439440
raise TypeError(
440441
"Target shape must be a finite iterable of integers"
441442
)
442-
if not np.prod(new_shape) == shape_to_elem_count(self.nd_, self.shape_):
443+
size = shape_to_elem_count(self.nd_, self.shape_)
444+
if not np.prod(new_shape) == size:
443445
raise TypeError(
444446
f"Can not reshape array of size {self.size} into {new_shape}"
445447
)
446-
new_strides = reshaped_strides(
447-
self.shape,
448-
self.strides,
449-
new_shape
450-
)
448+
if size > 0:
449+
new_strides = reshaped_strides(
450+
self.shape,
451+
self.strides,
452+
new_shape
453+
)
454+
else:
455+
new_strides = (1,) * len(new_shape)
451456
if new_strides is None:
452457
raise AttributeError(
453458
"Incompatible shape for in-place modification. "

dpctl/tests/test_usm_ndarray_ctor.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -713,6 +713,21 @@ def relaxed_strides_equal(st1, st2, sh):
713713
X = dpt.usm_ndarray((4, 4), dtype="d")[::2, ::2]
714714
with pytest.raises(AttributeError):
715715
X.shape = (4,)
716+
X = dpt.usm_ndarray((0,), dtype="i4")
717+
X.shape = (0,)
718+
X.shape = (
719+
2,
720+
0,
721+
)
722+
X.shape = (
723+
0,
724+
2,
725+
)
726+
X.shape = (
727+
1,
728+
0,
729+
1,
730+
)
716731

717732

718733
def test_len():
@@ -816,6 +831,32 @@ def test_reshape():
816831
Y = dpt.reshape(X, X.shape)
817832
assert Y.flags == X.flags
818833

834+
A = dpt.usm_ndarray((0,), "i4")
835+
A1 = dpt.reshape(A, (0,))
836+
assert A1.shape == (0,)
837+
A2 = dpt.reshape(
838+
A,
839+
(
840+
2,
841+
0,
842+
),
843+
)
844+
assert A2.shape == (
845+
2,
846+
0,
847+
)
848+
A3 = dpt.reshape(A, (0, 2))
849+
assert A3.shape == (
850+
0,
851+
2,
852+
)
853+
A4 = dpt.reshape(A, (1, 0, 2))
854+
assert A4.shape == (
855+
1,
856+
0,
857+
2,
858+
)
859+
819860

820861
def test_transpose():
821862
n, m = 2, 3

0 commit comments

Comments
 (0)