Skip to content

Commit c162c2a

Browse files
Added tests for assignment of Pythons scalar.
1 parent d6733b9 commit c162c2a

File tree

1 file changed

+12
-0
lines changed

1 file changed

+12
-0
lines changed

dpctl/tests/test_usm_ndarray_indexing.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1206,3 +1206,15 @@ def test_nonzero():
12061206
x = dpt.concat((dpt.zeros(3), dpt.ones(4), dpt.zeros(3)))
12071207
(i,) = dpt.nonzero(x)
12081208
assert (dpt.asnumpy(i) == np.array([3, 4, 5, 6])).all()
1209+
1210+
1211+
def test_assign_scalar():
1212+
get_queue_or_skip()
1213+
x = dpt.arange(-5, 5, dtype="i8")
1214+
cond = dpt.asarray(
1215+
[True, True, True, True, True, False, False, False, False, False]
1216+
)
1217+
x[cond] = 0 # no error expected
1218+
x[dpt.nonzero(cond)] = -1
1219+
expected = np.array([-1, -1, -1, -1, -1, 0, 1, 2, 3, 4], dtype=x.dtype)
1220+
assert (dpt.asnumpy(x) == expected).all()

0 commit comments

Comments
 (0)