Skip to content

Commit cb138ce

Browse files
committed
Adds test for correct order="K" behavior in where
1 parent 91e646d commit cb138ce

File tree

1 file changed

+36
-0
lines changed

1 file changed

+36
-0
lines changed

dpctl/tests/test_usm_ndarray_search_functions.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -370,3 +370,39 @@ def test_where_compute_follows_data():
370370
dpt.where(dpt.empty((1,), dtype="i4", sycl_queue=q3), x1, x2)
371371
with pytest.raises(ExecutionPlacementError):
372372
dpt.where(x1, x1, x2)
373+
374+
375+
def test_where_order():
376+
get_queue_or_skip()
377+
378+
test_sh = (
379+
20,
380+
20,
381+
)
382+
test_sh2 = tuple(2 * dim for dim in test_sh)
383+
n = test_sh[-1]
384+
385+
for dt1, dt2 in zip(["i4", "i4", "f4"], ["i4", "f4", "i4"]):
386+
ar1 = dpt.zeros(test_sh, dtype=dt1, order="C")
387+
ar2 = dpt.ones(test_sh, dtype=dt2, order="C")
388+
condition = dpt.zeros(test_sh, dtype="?", order="C")
389+
res = dpt.where(condition, ar1, ar2)
390+
assert res.flags.c_contiguous
391+
392+
ar1 = dpt.ones(test_sh, dtype=dt1, order="F")
393+
ar2 = dpt.ones(test_sh, dtype=dt2, order="F")
394+
condition = dpt.zeros(test_sh, dtype="?", order="F")
395+
res = dpt.where(condition, ar1, ar2)
396+
assert res.flags.f_contiguous
397+
398+
ar1 = dpt.ones(test_sh2, dtype=dt1, order="C")[:20, ::-2]
399+
ar2 = dpt.ones(test_sh2, dtype=dt2, order="C")[:20, ::-2]
400+
condition = dpt.zeros(test_sh2, dtype="?", order="C")[:20, ::-2]
401+
res = dpt.where(condition, ar1, ar2)
402+
assert res.strides == (n, -1)
403+
404+
ar1 = dpt.ones(test_sh2, dtype=dt1, order="C")[:20, ::-2].mT
405+
ar2 = dpt.ones(test_sh2, dtype=dt2, order="C")[:20, ::-2].mT
406+
condition = dpt.zeros(test_sh2, dtype="?", order="C")[:20, ::-2].mT
407+
res = dpt.where(condition, ar1, ar2)
408+
assert res.strides == (-1, n)

0 commit comments

Comments
 (0)