Skip to content

Commit ec8509f

Browse files
Added tests for tril/triu usm_type, queue
1 parent e0b79ff commit ec8509f

File tree

1 file changed

+22
-0
lines changed

1 file changed

+22
-0
lines changed

dpctl/tests/test_usm_ndarray_ctor.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1503,6 +1503,28 @@ def test_triu(dtype):
15031503
assert np.array_equal(Ynp, dpt.asnumpy(Y))
15041504

15051505

1506+
@pytest.mark.parametrize("tri_fn", [dpt.tril, dpt.triu])
1507+
@pytest.mark.parametrize("usm_type", ["device", "shared", "host"])
1508+
def test_tri_usm_type(tri_fn, usm_type):
1509+
q = get_queue_or_skip()
1510+
dtype = dpt.uint16
1511+
1512+
shape = (2, 3, 4, 5, 5)
1513+
size = np.prod(shape)
1514+
X = dpt.reshape(
1515+
dpt.arange(size, dtype=dtype, usm_type=usm_type, sycl_queue=q), shape
1516+
)
1517+
Y = tri_fn(X) # main execution branch
1518+
assert Y.usm_type == X.usm_type
1519+
assert Y.sycl_queue == q
1520+
Y = tri_fn(X, k=-6) # special case of Y == X
1521+
assert Y.usm_type == X.usm_type
1522+
assert Y.sycl_queue == q
1523+
Y = tri_fn(X, k=6) # special case of Y == 0
1524+
assert Y.usm_type == X.usm_type
1525+
assert Y.sycl_queue == q
1526+
1527+
15061528
def test_tril_slice():
15071529
q = get_queue_or_skip()
15081530

0 commit comments

Comments
 (0)