Skip to content

Commit b9c560a

Browse files
committed
Added tests for tril() and triu() functions
1 parent d9371ff commit b9c560a

File tree

1 file changed

+157
-0
lines changed

1 file changed

+157
-0
lines changed

dpctl/tests/test_usm_ndarray_ctor.py

Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1274,6 +1274,159 @@ def test_eye(dtype, usm_kind):
12741274
assert np.array_equal(Xnp, dpt.asnumpy(X))
12751275

12761276

1277+
@pytest.mark.parametrize("dtype", _all_dtypes[1:])
1278+
def test_tril(dtype):
1279+
try:
1280+
q = dpctl.SyclQueue()
1281+
except dpctl.SyclQueueCreationError:
1282+
pytest.skip("Queue could not be created")
1283+
1284+
if dtype in ["f8", "c16"] and q.sycl_device.has_aspect_fp64 is False:
1285+
pytest.skip(
1286+
"Device does not support double precision floating point type"
1287+
)
1288+
shape = (2, 3, 4, 5, 5)
1289+
X = dpt.reshape(
1290+
dpt.arange(np.prod(shape), dtype=dtype, sycl_queue=q), shape
1291+
)
1292+
Y = dpt.tril(X)
1293+
Xnp = np.arange(np.prod(shape), dtype=dtype).reshape(shape)
1294+
Ynp = np.tril(Xnp)
1295+
assert Y.dtype == Ynp.dtype
1296+
assert np.array_equal(Ynp, dpt.asnumpy(Y))
1297+
1298+
1299+
@pytest.mark.parametrize("dtype", _all_dtypes[1:])
1300+
def test_triu(dtype):
1301+
try:
1302+
q = dpctl.SyclQueue()
1303+
except dpctl.SyclQueueCreationError:
1304+
pytest.skip("Queue could not be created")
1305+
1306+
if dtype in ["f8", "c16"] and q.sycl_device.has_aspect_fp64 is False:
1307+
pytest.skip(
1308+
"Device does not support double precision floating point type"
1309+
)
1310+
shape = (4, 5)
1311+
X = dpt.reshape(
1312+
dpt.arange(np.prod(shape), dtype=dtype, sycl_queue=q), shape
1313+
)
1314+
Y = dpt.triu(X, 1)
1315+
Xnp = np.arange(np.prod(shape), dtype=dtype).reshape(shape)
1316+
Ynp = np.triu(Xnp, 1)
1317+
assert Y.dtype == Ynp.dtype
1318+
assert np.array_equal(Ynp, dpt.asnumpy(Y))
1319+
1320+
1321+
def test_tril_slice():
1322+
try:
1323+
q = dpctl.SyclQueue()
1324+
except dpctl.SyclQueueCreationError:
1325+
pytest.skip("Queue could not be created")
1326+
shape = (6, 10)
1327+
X = dpt.reshape(
1328+
dpt.arange(np.prod(shape), dtype="int", sycl_queue=q), shape
1329+
)[1:, ::-2]
1330+
Y = dpt.tril(X)
1331+
Xnp = np.arange(np.prod(shape), dtype="int").reshape(shape)[1:, ::-2]
1332+
Ynp = np.tril(Xnp)
1333+
assert Y.dtype == Ynp.dtype
1334+
assert np.array_equal(Ynp, dpt.asnumpy(Y))
1335+
1336+
1337+
def test_triu_permute_dims():
1338+
try:
1339+
q = dpctl.SyclQueue()
1340+
except dpctl.SyclQueueCreationError:
1341+
pytest.skip("Queue could not be created")
1342+
1343+
shape = (2, 3, 4, 5)
1344+
X = dpt.permute_dims(
1345+
dpt.reshape(
1346+
dpt.arange(np.prod(shape), dtype="int", sycl_queue=q), shape
1347+
),
1348+
(3, 2, 1, 0),
1349+
)
1350+
Y = dpt.triu(X)
1351+
Xnp = np.transpose(
1352+
np.arange(np.prod(shape), dtype="int").reshape(shape), (3, 2, 1, 0)
1353+
)
1354+
Ynp = np.triu(Xnp)
1355+
assert Y.dtype == Ynp.dtype
1356+
assert np.array_equal(Ynp, dpt.asnumpy(Y))
1357+
1358+
1359+
def test_tril_broadcast_to():
1360+
try:
1361+
q = dpctl.SyclQueue()
1362+
except dpctl.SyclQueueCreationError:
1363+
pytest.skip("Queue could not be created")
1364+
shape = (5, 5)
1365+
X = dpt.broadcast_to(dpt.ones((1), dtype="int", sycl_queue=q), shape)
1366+
Y = dpt.tril(X)
1367+
Xnp = np.broadcast_to(np.ones((1), dtype="int"), shape)
1368+
Ynp = np.tril(Xnp)
1369+
assert Y.dtype == Ynp.dtype
1370+
assert np.array_equal(Ynp, dpt.asnumpy(Y))
1371+
1372+
1373+
def test_triu_bool():
1374+
try:
1375+
q = dpctl.SyclQueue()
1376+
except dpctl.SyclQueueCreationError:
1377+
pytest.skip("Queue could not be created")
1378+
1379+
shape = (4, 5)
1380+
X = dpt.ones((shape), dtype="bool", sycl_queue=q)
1381+
Y = dpt.triu(X)
1382+
Xnp = np.ones((shape), dtype="bool")
1383+
Ynp = np.triu(Xnp)
1384+
assert Y.dtype == Ynp.dtype
1385+
assert np.array_equal(Ynp, dpt.asnumpy(Y))
1386+
1387+
1388+
@pytest.mark.parametrize("order", ["F", "C"])
1389+
@pytest.mark.parametrize("k", [-10, -2, -1, 3, 4, 10])
1390+
def test_triu_order_k(order, k):
1391+
try:
1392+
q = dpctl.SyclQueue()
1393+
except dpctl.SyclQueueCreationError:
1394+
pytest.skip("Queue could not be created")
1395+
shape = (3, 3)
1396+
X = dpt.reshape(
1397+
dpt.arange(np.prod(shape), dtype="int", sycl_queue=q),
1398+
shape,
1399+
order=order,
1400+
)
1401+
Y = dpt.triu(X, k)
1402+
Xnp = np.arange(np.prod(shape), dtype="int").reshape(shape, order=order)
1403+
Ynp = np.triu(Xnp, k)
1404+
assert Y.dtype == Ynp.dtype
1405+
assert X.flags == Y.flags
1406+
assert np.array_equal(Ynp, dpt.asnumpy(Y))
1407+
1408+
1409+
@pytest.mark.parametrize("order", ["F", "C"])
1410+
@pytest.mark.parametrize("k", [-10, -4, -3, 1, 2, 10])
1411+
def test_tril_order_k(order, k):
1412+
try:
1413+
q = dpctl.SyclQueue()
1414+
except dpctl.SyclQueueCreationError:
1415+
pytest.skip("Queue could not be created")
1416+
shape = (3, 3)
1417+
X = dpt.reshape(
1418+
dpt.arange(np.prod(shape), dtype="int", sycl_queue=q),
1419+
shape,
1420+
order=order,
1421+
)
1422+
Y = dpt.tril(X, k)
1423+
Xnp = np.arange(np.prod(shape), dtype="int").reshape(shape, order=order)
1424+
Ynp = np.tril(Xnp, k)
1425+
assert Y.dtype == Ynp.dtype
1426+
assert X.flags == Y.flags
1427+
assert np.array_equal(Ynp, dpt.asnumpy(Y))
1428+
1429+
12771430
def test_common_arg_validation():
12781431
order = "I"
12791432
# invalid order must raise ValueError
@@ -1306,3 +1459,7 @@ def test_common_arg_validation():
13061459
dpt.ones_like(X)
13071460
with pytest.raises(TypeError):
13081461
dpt.full_like(X, 1)
1462+
with pytest.raises(TypeError):
1463+
dpt.tril(X)
1464+
with pytest.raises(TypeError):
1465+
dpt.triu(X)

0 commit comments

Comments
 (0)