Skip to content

Commit 346392e

Browse files
Removed stray print in _usmarray.pyx
Improved validation for the routine computing strides of the new reshaped array for a view, or None if view is not possible. This fixes an exception raised for ``` X = dpt.usm_ndarray((1,), "i4") X.shape = (1,) # used to raise, not works as expected ```
1 parent 20fc56b commit 346392e

File tree

2 files changed

+4
-2
lines changed

2 files changed

+4
-2
lines changed

dpctl/tensor/_reshape.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,10 @@ def reshaped_strides(old_sh, old_sts, new_sh, order="C"):
6666
]
6767
]
6868
valid = all(
69-
[check_st == old_st for check_st, old_st in zip(check_sts, old_sts)]
69+
[
70+
check_st == old_st or old_dim == 1
71+
for check_st, old_st, old_dim in zip(check_sts, old_sts, old_sh)
72+
]
7073
)
7174
return new_sts if valid else None
7275

dpctl/tensor/_usmarray.pyx

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -471,7 +471,6 @@ cdef class usm_ndarray:
471471
PyMem_Free(self.shape_)
472472
if (self.strides_):
473473
PyMem_Free(self.strides_)
474-
print(contig_flag)
475474
self.flags_ = contig_flag
476475
self.nd_ = new_nd
477476
self.shape_ = shape_ptr

0 commit comments

Comments
 (0)