Skip to content

Commit 2315489

Browse files
One more test added, to cover different dtypes copy with strides
1 parent 346392e commit 2315489

File tree

2 files changed

+9
-3
lines changed

2 files changed

+9
-3
lines changed

dpctl/tensor/_copy_utils.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -194,9 +194,6 @@ def copy_same_shape(dst, src):
194194
sh_i, dst_st, dst_disp, src_st, src_disp = contract_iter2(
195195
dst.shape, dst.strides, src.strides
196196
)
197-
# sh_i, dst_st, dst_disp, src_st, src_disp = (
198-
# dst.shape, dst.strides, 0, src.strides, 0
199-
# )
200197
src_iface = src.__sycl_usm_array_interface__
201198
dst_iface = dst.__sycl_usm_array_interface__
202199
src_iface["shape"] = tuple()

dpctl/tests/test_usm_ndarray_ctor.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -676,6 +676,15 @@ def test_setitem_errors():
676676
X[:] = Y[None, :, 0]
677677

678678

679+
def test_setitem_different_dtypes():
680+
X = dpt.from_numpy(np.ones(10, "f4"))
681+
Y = dpt.from_numpy(np.zeros(10, "f4"))
682+
Z = dpt.usm_ndarray((20,), "d")
683+
Z[::2] = X
684+
Z[1::2] = Y
685+
assert np.allclose(dpt.asnumpy(Z), np.tile(np.array([1, 0], "d"), 10))
686+
687+
679688
def test_shape_setter():
680689
def cc_strides(sh):
681690
return np.empty(sh, dtype="u1").strides

0 commit comments

Comments
 (0)