Skip to content

Commit d469c84

Browse files
Merge pull request #619 from IntelPython/improve_copy_utils_coverage
2 parents 42c25df + 2315489 commit d469c84

File tree

4 files changed

+24
-5
lines changed

4 files changed

+24
-5
lines changed

dpctl/tensor/_copy_utils.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -194,9 +194,6 @@ def copy_same_shape(dst, src):
194194
sh_i, dst_st, dst_disp, src_st, src_disp = contract_iter2(
195195
dst.shape, dst.strides, src.strides
196196
)
197-
# sh_i, dst_st, dst_disp, src_st, src_disp = (
198-
# dst.shape, dst.strides, 0, src.strides, 0
199-
# )
200197
src_iface = src.__sycl_usm_array_interface__
201198
dst_iface = dst.__sycl_usm_array_interface__
202199
src_iface["shape"] = tuple()
@@ -250,6 +247,7 @@ def copy_from_usm_ndarray_to_usm_ndarray(dst, src):
250247
)
251248
else:
252249
src_same_shape = src
250+
src_same_shape.shape = common_shape
253251

254252
copy_same_shape(dst, src_same_shape)
255253

dpctl/tensor/_reshape.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,10 @@ def reshaped_strides(old_sh, old_sts, new_sh, order="C"):
6666
]
6767
]
6868
valid = all(
69-
[check_st == old_st for check_st, old_st in zip(check_sts, old_sts)]
69+
[
70+
check_st == old_st or old_dim == 1
71+
for check_st, old_st, old_dim in zip(check_sts, old_sts, old_sh)
72+
]
7073
)
7174
return new_sts if valid else None
7275

dpctl/tensor/_usmarray.pyx

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -471,7 +471,6 @@ cdef class usm_ndarray:
471471
PyMem_Free(self.shape_)
472472
if (self.strides_):
473473
PyMem_Free(self.strides_)
474-
print(contig_flag)
475474
self.flags_ = contig_flag
476475
self.nd_ = new_nd
477476
self.shape_ = shape_ptr

dpctl/tests/test_usm_ndarray_ctor.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -666,6 +666,25 @@ def test_setitem_scalar(dtype, usm_type):
666666
)
667667

668668

669+
def test_setitem_errors():
670+
X = dpt.usm_ndarray((4,), dtype="u1")
671+
Y = dpt.usm_ndarray((4, 2), dtype="u1")
672+
with pytest.raises(ValueError):
673+
X[:] = Y
674+
with pytest.raises(ValueError):
675+
X[:] = Y[:, 0:1]
676+
X[:] = Y[None, :, 0]
677+
678+
679+
def test_setitem_different_dtypes():
680+
X = dpt.from_numpy(np.ones(10, "f4"))
681+
Y = dpt.from_numpy(np.zeros(10, "f4"))
682+
Z = dpt.usm_ndarray((20,), "d")
683+
Z[::2] = X
684+
Z[1::2] = Y
685+
assert np.allclose(dpt.asnumpy(Z), np.tile(np.array([1, 0], "d"), 10))
686+
687+
669688
def test_shape_setter():
670689
def cc_strides(sh):
671690
return np.empty(sh, dtype="u1").strides

0 commit comments

Comments
 (0)