Skip to content

Commit 479ed60

Browse files
reshaped_strides is also called from shape setter
Special case setting shape for zero-element arrays
1 parent 3048f3e commit 479ed60

File tree

2 files changed

+26
-6
lines changed

2 files changed

+26
-6
lines changed

dpctl/tensor/_usmarray.pyx

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -428,6 +428,7 @@ cdef class usm_ndarray:
428428
cdef int contig_flag = 0
429429
cdef Py_ssize_t *shape_ptr = NULL
430430
cdef Py_ssize_t *strides_ptr = NULL
431+
cdef Py_ssize_t size = -1
431432
import operator
432433

433434
from ._reshape import reshaped_strides
@@ -439,15 +440,19 @@ cdef class usm_ndarray:
439440
raise TypeError(
440441
"Target shape must be a finite iterable of integers"
441442
)
442-
if not np.prod(new_shape) == shape_to_elem_count(self.nd_, self.shape_):
443+
size = shape_to_elem_count(self.nd_, self.shape_)
444+
if not np.prod(new_shape) == size:
443445
raise TypeError(
444446
f"Can not reshape array of size {self.size} into {new_shape}"
445447
)
446-
new_strides = reshaped_strides(
447-
self.shape,
448-
self.strides,
449-
new_shape
450-
)
448+
if size > 0:
449+
new_strides = reshaped_strides(
450+
self.shape,
451+
self.strides,
452+
new_shape
453+
)
454+
else:
455+
new_strides = (1,) * len(new_shape)
451456
if new_strides is None:
452457
raise AttributeError(
453458
"Incompatible shape for in-place modification. "

dpctl/tests/test_usm_ndarray_ctor.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -713,6 +713,21 @@ def relaxed_strides_equal(st1, st2, sh):
713713
X = dpt.usm_ndarray((4, 4), dtype="d")[::2, ::2]
714714
with pytest.raises(AttributeError):
715715
X.shape = (4,)
716+
X = dpt.usm_ndarray((0,), dtype="i4")
717+
X.shape = (0,)
718+
X.shape = (
719+
2,
720+
0,
721+
)
722+
X.shape = (
723+
0,
724+
2,
725+
)
726+
X.shape = (
727+
1,
728+
0,
729+
1,
730+
)
716731

717732

718733
def test_len():

0 commit comments

Comments
 (0)