Skip to content

Commit 8c48886

Browse files
Added tests for True/False indexing of usm_ndarray
1 parent eccd026 commit 8c48886

File tree

1 file changed

+26
-0
lines changed

1 file changed

+26
-0
lines changed

dpctl/tests/test_usm_ndarray_indexing.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -455,6 +455,32 @@ def test_integer_strided_indexing():
455455
assert (dpt.asnumpy(y) == dpt.asnumpy(yc)).all()
456456

457457

458+
def test_TrueFalse_indexing():
459+
get_queue_or_skip()
460+
n0, n1 = 2, 3
461+
x = dpt.ones((n0, n1))
462+
for ind in [True, dpt.asarray(True)]:
463+
y1 = x[ind]
464+
assert y1.shape == (1, n0, n1)
465+
assert y1._pointer == x._pointer
466+
y2 = x[:, ind]
467+
assert y2.shape == (n0, 1, n1)
468+
assert y2._pointer == x._pointer
469+
y3 = x[..., ind]
470+
assert y3.shape == (n0, n1, 1)
471+
assert y3._pointer == x._pointer
472+
for ind in [False, dpt.asarray(False)]:
473+
y1 = x[ind]
474+
assert y1.shape == (0, n0, n1)
475+
assert y1._pointer == x._pointer
476+
y2 = x[:, ind]
477+
assert y2.shape == (n0, 0, n1)
478+
assert y2._pointer == x._pointer
479+
y3 = x[..., ind]
480+
assert y3.shape == (n0, n1, 0)
481+
assert y3._pointer == x._pointer
482+
483+
458484
@pytest.mark.parametrize(
459485
"data_dt",
460486
_all_dtypes,

0 commit comments

Comments
 (0)