Skip to content

Commit ccb64a5

Browse files
Merge pull request #1598 from IntelPython/searchsorted-always-return-indexing-type
Output of searchsorted must always have default indexing data type
2 parents f3caaa1 + f92bfc4 commit ccb64a5

File tree

2 files changed

+15
-4
lines changed

2 files changed

+15
-4
lines changed

dpctl/tensor/_searchsorted.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,13 @@
55

66
from ._copy_utils import _empty_like_orderK
77
from ._ctors import empty
8-
from ._data_types import int32, int64
98
from ._tensor_impl import _copy_usm_ndarray_into_usm_ndarray as ti_copy
109
from ._tensor_impl import _take as ti_take
10+
from ._tensor_impl import (
11+
default_device_index_type as ti_default_device_index_type,
12+
)
1113
from ._tensor_sorting_impl import _searchsorted_left, _searchsorted_right
12-
from ._type_utils import iinfo, isdtype, result_type
14+
from ._type_utils import isdtype, result_type
1315
from ._usmarray import usm_ndarray
1416

1517

@@ -141,9 +143,9 @@ def searchsorted(
141143
x2 = x2_buf
142144

143145
dst_usm_type = du.get_coerced_usm_type([x1.usm_type, x2.usm_type])
144-
dst_dt = int32 if x2.size <= iinfo(int32).max else int64
146+
index_dt = ti_default_device_index_type(q)
145147

146-
dst = _empty_like_orderK(x2, dst_dt, usm_type=dst_usm_type)
148+
dst = _empty_like_orderK(x2, index_dt, usm_type=dst_usm_type)
147149

148150
if side == "left":
149151
ht_ev, _ = _searchsorted_left(

dpctl/tests/test_usm_ndarray_searchsorted.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,20 +11,29 @@ def _check(hay_stack, needles, needles_np):
1111
assert hay_stack.dtype == needles.dtype
1212
assert hay_stack.ndim == 1
1313

14+
info_ = dpt.__array_namespace_info__()
15+
default_dts_dev = info_.default_dtypes(hay_stack.device)
16+
index_dt = default_dts_dev["indexing"]
17+
1418
p_left = dpt.searchsorted(hay_stack, needles, side="left")
19+
assert p_left.dtype == index_dt
1520

1621
hs_np = dpt.asnumpy(hay_stack)
1722
ref_left = np.searchsorted(hs_np, needles_np, side="left")
1823
assert dpt.all(p_left == dpt.asarray(ref_left))
1924

2025
p_right = dpt.searchsorted(hay_stack, needles, side="right")
26+
assert p_right.dtype == index_dt
27+
2128
ref_right = np.searchsorted(hs_np, needles_np, side="right")
2229
assert dpt.all(p_right == dpt.asarray(ref_right))
2330

2431
sorter = dpt.arange(hay_stack.size)
2532
ps_left = dpt.searchsorted(hay_stack, needles, side="left", sorter=sorter)
33+
assert ps_left.dtype == index_dt
2634
assert dpt.all(ps_left == p_left)
2735
ps_right = dpt.searchsorted(hay_stack, needles, side="right", sorter=sorter)
36+
assert ps_right.dtype == index_dt
2837
assert dpt.all(ps_right == p_right)
2938

3039

0 commit comments

Comments
 (0)