File tree Expand file tree Collapse file tree 1 file changed +26
-0
lines changed Expand file tree Collapse file tree 1 file changed +26
-0
lines changed Original file line number Diff line number Diff line change @@ -455,6 +455,32 @@ def test_integer_strided_indexing():
455
455
assert (dpt .asnumpy (y ) == dpt .asnumpy (yc )).all ()
456
456
457
457
458
+ def test_TrueFalse_indexing ():
459
+ get_queue_or_skip ()
460
+ n0 , n1 = 2 , 3
461
+ x = dpt .ones ((n0 , n1 ))
462
+ for ind in [True , dpt .asarray (True )]:
463
+ y1 = x [ind ]
464
+ assert y1 .shape == (1 , n0 , n1 )
465
+ assert y1 ._pointer == x ._pointer
466
+ y2 = x [:, ind ]
467
+ assert y2 .shape == (n0 , 1 , n1 )
468
+ assert y2 ._pointer == x ._pointer
469
+ y3 = x [..., ind ]
470
+ assert y3 .shape == (n0 , n1 , 1 )
471
+ assert y3 ._pointer == x ._pointer
472
+ for ind in [False , dpt .asarray (False )]:
473
+ y1 = x [ind ]
474
+ assert y1 .shape == (0 , n0 , n1 )
475
+ assert y1 ._pointer == x ._pointer
476
+ y2 = x [:, ind ]
477
+ assert y2 .shape == (n0 , 0 , n1 )
478
+ assert y2 ._pointer == x ._pointer
479
+ y3 = x [..., ind ]
480
+ assert y3 .shape == (n0 , n1 , 0 )
481
+ assert y3 ._pointer == x ._pointer
482
+
483
+
458
484
@pytest .mark .parametrize (
459
485
"data_dt" ,
460
486
_all_dtypes ,
You can’t perform that action at this time.
0 commit comments