@@ -164,7 +164,7 @@ def test_basic_slice10():
164164
165165
166166def _all_equal (it1 , it2 ):
167- return all (dpt . asnumpy ( x ) == dpt . asnumpy ( y ) for x , y in zip (it1 , it2 ))
167+ return all (bool ( x == y ) for x , y in zip (it1 , it2 ))
168168
169169
170170def test_advanced_slice1 ():
@@ -175,8 +175,6 @@ def test_advanced_slice1():
175175 assert isinstance (y , dpt .usm_ndarray )
176176 assert y .shape == ii .shape
177177 assert y .strides == (1 ,)
178- # FIXME, once usm_ndarray.__equal__ is implemented,
179- # use of asnumpy should be removed
180178 assert _all_equal (
181179 (x [ii [k ]] for k in range (ii .shape [0 ])),
182180 (y [k ] for k in range (ii .shape [0 ])),
@@ -185,8 +183,6 @@ def test_advanced_slice1():
185183 assert isinstance (y , dpt .usm_ndarray )
186184 assert y .shape == ii .shape
187185 assert y .strides == (1 ,)
188- # FIXME, once usm_ndarray.__equal__ is implemented,
189- # use of asnumpy should be removed
190186 assert _all_equal (
191187 (x [ii [k ]] for k in range (ii .shape [0 ])),
192188 (y [k ] for k in range (ii .shape [0 ])),
@@ -201,8 +197,6 @@ def test_advanced_slice1_negative_strides():
201197 assert isinstance (y , dpt .usm_ndarray )
202198 assert y .shape == ii .shape
203199 assert y .strides == (1 ,)
204- # FIXME, once usm_ndarray.__equal__ is implemented,
205- # use of asnumpy should be removed
206200 assert _all_equal (
207201 (x [ii [k ]] for k in range (ii .shape [0 ])),
208202 (y [k ] for k in range (ii .shape [0 ])),
@@ -400,6 +394,16 @@ def test_advanced_slice13():
400394 assert (dpt .asnumpy (y ) == dpt .asnumpy (expected )).all ()
401395
402396
397+ def test_boolean_indexing_validation ():
398+ get_queue_or_skip ()
399+ x = dpt .zeros (10 , dtype = "i4" )
400+ ii = dpt .ones ((2 , 5 ), dtype = "?" )
401+ with pytest .raises (IndexError ):
402+ x [ii ]
403+ with pytest .raises (IndexError ):
404+ x [ii [0 , :]]
405+
406+
403407def test_integer_indexing_1d ():
404408 get_queue_or_skip ()
405409 x = dpt .arange (10 , dtype = "i4" )
@@ -482,6 +486,32 @@ def test_TrueFalse_indexing():
482486 assert y3 ._pointer == x ._pointer
483487
484488
489+ def test_mixed_index_getitem ():
490+ get_queue_or_skip ()
491+ x = dpt .reshape (dpt .arange (1000 , dtype = "i4" ), (10 , 10 , 10 ))
492+ i1b = dpt .ones (10 , dtype = "?" )
493+ info = x .__array_namespace__ ().__array_namespace_info__ ()
494+ ind_dt = info .default_dtypes (x .device )["indexing" ]
495+ i0 = dpt .asarray ([0 , 2 , 3 ], dtype = ind_dt )[:, dpt .newaxis ]
496+ i2 = dpt .asarray ([3 , 4 , 7 ], dtype = ind_dt )[:, dpt .newaxis ]
497+ y = x [i0 , i1b , i2 ]
498+ assert y .shape == (3 , dpt .sum (i1b , dtype = "i8" ))
499+
500+
501+ def test_mixed_index_setitem ():
502+ get_queue_or_skip ()
503+ x = dpt .reshape (dpt .arange (1000 , dtype = "i4" ), (10 , 10 , 10 ))
504+ i1b = dpt .ones (10 , dtype = "?" )
505+ info = x .__array_namespace__ ().__array_namespace_info__ ()
506+ ind_dt = info .default_dtypes (x .device )["indexing" ]
507+ i0 = dpt .asarray ([0 , 2 , 3 ], dtype = ind_dt )[:, dpt .newaxis ]
508+ i2 = dpt .asarray ([3 , 4 , 7 ], dtype = ind_dt )[:, dpt .newaxis ]
509+ v_shape = (3 , int (dpt .sum (i1b , dtype = "i8" )))
510+ canary = 7
511+ x [i0 , i1b , i2 ] = dpt .full (v_shape , canary , dtype = x .dtype )
512+ assert x [0 , 0 , 3 ] == canary
513+
514+
485515@pytest .mark .parametrize (
486516 "data_dt" ,
487517 _all_dtypes ,
0 commit comments