@@ -764,6 +764,8 @@ cdef class usm_ndarray:
764
764
ind, (< object > self ).shape, (< object > self ).strides,
765
765
self .get_offset())
766
766
cdef usm_ndarray res
767
+ cdef int i = 0
768
+ cdef bint matching = 1
767
769
768
770
if len (_meta) < 5 :
769
771
raise RuntimeError
@@ -787,7 +789,20 @@ cdef class usm_ndarray:
787
789
788
790
from ._copy_utils import _extract_impl, _nonzero_impl, _take_multi_index
789
791
if len (adv_ind) == 1 and adv_ind[0 ].dtype == dpt_bool:
790
- return _extract_impl(res, adv_ind[0 ], axis = adv_ind_start_p)
792
+ key_ = adv_ind[0 ]
793
+ adv_ind_end_p = key_.ndim + adv_ind_start_p
794
+ if adv_ind_end_p > res.ndim:
795
+ raise IndexError (" too many indices for the array" )
796
+ key_shape = key_.shape
797
+ arr_shape = res.shape[adv_ind_start_p:adv_ind_end_p]
798
+ for i in range (key_.ndim):
799
+ if matching:
800
+ if not key_shape[i] == arr_shape[i] and key_shape[i] > 0 :
801
+ matching = 0
802
+ if not matching:
803
+ raise IndexError (" boolean index did not match indexed array in dimensions" )
804
+ res = _extract_impl(res, key_, axis = adv_ind_start_p)
805
+ return res
791
806
792
807
if any (ind.dtype == dpt_bool for ind in adv_ind):
793
808
adv_ind_int = list ()
@@ -1152,6 +1167,8 @@ cdef class usm_ndarray:
1152
1167
if adv_ind_start_p < 0 :
1153
1168
# basic slicing
1154
1169
if isinstance (rhs, usm_ndarray):
1170
+ if Xv.size == 0 :
1171
+ return
1155
1172
_copy_from_usm_ndarray_to_usm_ndarray(Xv, rhs)
1156
1173
else :
1157
1174
if hasattr (rhs, " __sycl_usm_array_interface__" ):
0 commit comments