Skip to content

Commit 5c1a961

Browse files
Fixed array API test failure by adding validation
1 parent cc08b5d commit 5c1a961

File tree

2 files changed

+20
-1
lines changed

2 files changed

+20
-1
lines changed

dpctl/tensor/_copy_utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -492,6 +492,8 @@ def _extract_impl(ary, ary_mask, axis=0):
492492
dst = dpt.empty(
493493
dst_shape, dtype=ary.dtype, usm_type=ary.usm_type, device=ary.device
494494
)
495+
if dst.size == 0:
496+
return dst
495497
hev, _ = ti._extract(
496498
src=ary,
497499
cumsum=cumsum,

dpctl/tensor/_usmarray.pyx

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -764,6 +764,8 @@ cdef class usm_ndarray:
764764
ind, (<object>self).shape, (<object> self).strides,
765765
self.get_offset())
766766
cdef usm_ndarray res
767+
cdef int i = 0
768+
cdef bint matching = 1
767769

768770
if len(_meta) < 5:
769771
raise RuntimeError
@@ -787,7 +789,20 @@ cdef class usm_ndarray:
787789

788790
from ._copy_utils import _extract_impl, _nonzero_impl, _take_multi_index
789791
if len(adv_ind) == 1 and adv_ind[0].dtype == dpt_bool:
790-
return _extract_impl(res, adv_ind[0], axis=adv_ind_start_p)
792+
key_ = adv_ind[0]
793+
adv_ind_end_p = key_.ndim + adv_ind_start_p
794+
if adv_ind_end_p > res.ndim:
795+
raise IndexError("too many indices for the array")
796+
key_shape = key_.shape
797+
arr_shape = res.shape[adv_ind_start_p:adv_ind_end_p]
798+
for i in range(key_.ndim):
799+
if matching:
800+
if not key_shape[i] == arr_shape[i] and key_shape[i] > 0:
801+
matching = 0
802+
if not matching:
803+
raise IndexError("boolean index did not match indexed array in dimensions")
804+
res = _extract_impl(res, key_, axis=adv_ind_start_p)
805+
return res
791806

792807
if any(ind.dtype == dpt_bool for ind in adv_ind):
793808
adv_ind_int = list()
@@ -1152,6 +1167,8 @@ cdef class usm_ndarray:
11521167
if adv_ind_start_p < 0:
11531168
# basic slicing
11541169
if isinstance(rhs, usm_ndarray):
1170+
if Xv.size == 0:
1171+
return
11551172
_copy_from_usm_ndarray_to_usm_ndarray(Xv, rhs)
11561173
else:
11571174
if hasattr(rhs, "__sycl_usm_array_interface__"):

0 commit comments

Comments
 (0)