@@ -1274,6 +1274,159 @@ def test_eye(dtype, usm_kind):
1274
1274
assert np .array_equal (Xnp , dpt .asnumpy (X ))
1275
1275
1276
1276
1277
+ @pytest .mark .parametrize ("dtype" , _all_dtypes [1 :])
1278
+ def test_tril (dtype ):
1279
+ try :
1280
+ q = dpctl .SyclQueue ()
1281
+ except dpctl .SyclQueueCreationError :
1282
+ pytest .skip ("Queue could not be created" )
1283
+
1284
+ if dtype in ["f8" , "c16" ] and q .sycl_device .has_aspect_fp64 is False :
1285
+ pytest .skip (
1286
+ "Device does not support double precision floating point type"
1287
+ )
1288
+ shape = (2 , 3 , 4 , 5 , 5 )
1289
+ X = dpt .reshape (
1290
+ dpt .arange (np .prod (shape ), dtype = dtype , sycl_queue = q ), shape
1291
+ )
1292
+ Y = dpt .tril (X )
1293
+ Xnp = np .arange (np .prod (shape ), dtype = dtype ).reshape (shape )
1294
+ Ynp = np .tril (Xnp )
1295
+ assert Y .dtype == Ynp .dtype
1296
+ assert np .array_equal (Ynp , dpt .asnumpy (Y ))
1297
+
1298
+
1299
+ @pytest .mark .parametrize ("dtype" , _all_dtypes [1 :])
1300
+ def test_triu (dtype ):
1301
+ try :
1302
+ q = dpctl .SyclQueue ()
1303
+ except dpctl .SyclQueueCreationError :
1304
+ pytest .skip ("Queue could not be created" )
1305
+
1306
+ if dtype in ["f8" , "c16" ] and q .sycl_device .has_aspect_fp64 is False :
1307
+ pytest .skip (
1308
+ "Device does not support double precision floating point type"
1309
+ )
1310
+ shape = (4 , 5 )
1311
+ X = dpt .reshape (
1312
+ dpt .arange (np .prod (shape ), dtype = dtype , sycl_queue = q ), shape
1313
+ )
1314
+ Y = dpt .triu (X , 1 )
1315
+ Xnp = np .arange (np .prod (shape ), dtype = dtype ).reshape (shape )
1316
+ Ynp = np .triu (Xnp , 1 )
1317
+ assert Y .dtype == Ynp .dtype
1318
+ assert np .array_equal (Ynp , dpt .asnumpy (Y ))
1319
+
1320
+
1321
+ def test_tril_slice ():
1322
+ try :
1323
+ q = dpctl .SyclQueue ()
1324
+ except dpctl .SyclQueueCreationError :
1325
+ pytest .skip ("Queue could not be created" )
1326
+ shape = (6 , 10 )
1327
+ X = dpt .reshape (
1328
+ dpt .arange (np .prod (shape ), dtype = "int" , sycl_queue = q ), shape
1329
+ )[1 :, ::- 2 ]
1330
+ Y = dpt .tril (X )
1331
+ Xnp = np .arange (np .prod (shape ), dtype = "int" ).reshape (shape )[1 :, ::- 2 ]
1332
+ Ynp = np .tril (Xnp )
1333
+ assert Y .dtype == Ynp .dtype
1334
+ assert np .array_equal (Ynp , dpt .asnumpy (Y ))
1335
+
1336
+
1337
+ def test_triu_permute_dims ():
1338
+ try :
1339
+ q = dpctl .SyclQueue ()
1340
+ except dpctl .SyclQueueCreationError :
1341
+ pytest .skip ("Queue could not be created" )
1342
+
1343
+ shape = (2 , 3 , 4 , 5 )
1344
+ X = dpt .permute_dims (
1345
+ dpt .reshape (
1346
+ dpt .arange (np .prod (shape ), dtype = "int" , sycl_queue = q ), shape
1347
+ ),
1348
+ (3 , 2 , 1 , 0 ),
1349
+ )
1350
+ Y = dpt .triu (X )
1351
+ Xnp = np .transpose (
1352
+ np .arange (np .prod (shape ), dtype = "int" ).reshape (shape ), (3 , 2 , 1 , 0 )
1353
+ )
1354
+ Ynp = np .triu (Xnp )
1355
+ assert Y .dtype == Ynp .dtype
1356
+ assert np .array_equal (Ynp , dpt .asnumpy (Y ))
1357
+
1358
+
1359
+ def test_tril_broadcast_to ():
1360
+ try :
1361
+ q = dpctl .SyclQueue ()
1362
+ except dpctl .SyclQueueCreationError :
1363
+ pytest .skip ("Queue could not be created" )
1364
+ shape = (5 , 5 )
1365
+ X = dpt .broadcast_to (dpt .ones ((1 ), dtype = "int" , sycl_queue = q ), shape )
1366
+ Y = dpt .tril (X )
1367
+ Xnp = np .broadcast_to (np .ones ((1 ), dtype = "int" ), shape )
1368
+ Ynp = np .tril (Xnp )
1369
+ assert Y .dtype == Ynp .dtype
1370
+ assert np .array_equal (Ynp , dpt .asnumpy (Y ))
1371
+
1372
+
1373
+ def test_triu_bool ():
1374
+ try :
1375
+ q = dpctl .SyclQueue ()
1376
+ except dpctl .SyclQueueCreationError :
1377
+ pytest .skip ("Queue could not be created" )
1378
+
1379
+ shape = (4 , 5 )
1380
+ X = dpt .ones ((shape ), dtype = "bool" , sycl_queue = q )
1381
+ Y = dpt .triu (X )
1382
+ Xnp = np .ones ((shape ), dtype = "bool" )
1383
+ Ynp = np .triu (Xnp )
1384
+ assert Y .dtype == Ynp .dtype
1385
+ assert np .array_equal (Ynp , dpt .asnumpy (Y ))
1386
+
1387
+
1388
+ @pytest .mark .parametrize ("order" , ["F" , "C" ])
1389
+ @pytest .mark .parametrize ("k" , [- 10 , - 2 , - 1 , 3 , 4 , 10 ])
1390
+ def test_triu_order_k (order , k ):
1391
+ try :
1392
+ q = dpctl .SyclQueue ()
1393
+ except dpctl .SyclQueueCreationError :
1394
+ pytest .skip ("Queue could not be created" )
1395
+ shape = (3 , 3 )
1396
+ X = dpt .reshape (
1397
+ dpt .arange (np .prod (shape ), dtype = "int" , sycl_queue = q ),
1398
+ shape ,
1399
+ order = order ,
1400
+ )
1401
+ Y = dpt .triu (X , k )
1402
+ Xnp = np .arange (np .prod (shape ), dtype = "int" ).reshape (shape , order = order )
1403
+ Ynp = np .triu (Xnp , k )
1404
+ assert Y .dtype == Ynp .dtype
1405
+ assert X .flags == Y .flags
1406
+ assert np .array_equal (Ynp , dpt .asnumpy (Y ))
1407
+
1408
+
1409
+ @pytest .mark .parametrize ("order" , ["F" , "C" ])
1410
+ @pytest .mark .parametrize ("k" , [- 10 , - 4 , - 3 , 1 , 2 , 10 ])
1411
+ def test_tril_order_k (order , k ):
1412
+ try :
1413
+ q = dpctl .SyclQueue ()
1414
+ except dpctl .SyclQueueCreationError :
1415
+ pytest .skip ("Queue could not be created" )
1416
+ shape = (3 , 3 )
1417
+ X = dpt .reshape (
1418
+ dpt .arange (np .prod (shape ), dtype = "int" , sycl_queue = q ),
1419
+ shape ,
1420
+ order = order ,
1421
+ )
1422
+ Y = dpt .tril (X , k )
1423
+ Xnp = np .arange (np .prod (shape ), dtype = "int" ).reshape (shape , order = order )
1424
+ Ynp = np .tril (Xnp , k )
1425
+ assert Y .dtype == Ynp .dtype
1426
+ assert X .flags == Y .flags
1427
+ assert np .array_equal (Ynp , dpt .asnumpy (Y ))
1428
+
1429
+
1277
1430
def test_common_arg_validation ():
1278
1431
order = "I"
1279
1432
# invalid order must raise ValueError
@@ -1306,3 +1459,7 @@ def test_common_arg_validation():
1306
1459
dpt .ones_like (X )
1307
1460
with pytest .raises (TypeError ):
1308
1461
dpt .full_like (X , 1 )
1462
+ with pytest .raises (TypeError ):
1463
+ dpt .tril (X )
1464
+ with pytest .raises (TypeError ):
1465
+ dpt .triu (X )
0 commit comments