Skip to content

Commit

Permalink
Merge pull request #1598 from IntelPython/searchsorted-always-return-…
Browse files Browse the repository at this point in the history
…indexing-type

Output of searchsorted must always have default indexing data type
  • Loading branch information
oleksandr-pavlyk authored Mar 20, 2024
2 parents f3caaa1 + f92bfc4 commit ccb64a5
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 4 deletions.
10 changes: 6 additions & 4 deletions dpctl/tensor/_searchsorted.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@

from ._copy_utils import _empty_like_orderK
from ._ctors import empty
from ._data_types import int32, int64
from ._tensor_impl import _copy_usm_ndarray_into_usm_ndarray as ti_copy
from ._tensor_impl import _take as ti_take
from ._tensor_impl import (
default_device_index_type as ti_default_device_index_type,
)
from ._tensor_sorting_impl import _searchsorted_left, _searchsorted_right
from ._type_utils import iinfo, isdtype, result_type
from ._type_utils import isdtype, result_type
from ._usmarray import usm_ndarray


Expand Down Expand Up @@ -141,9 +143,9 @@ def searchsorted(
x2 = x2_buf

dst_usm_type = du.get_coerced_usm_type([x1.usm_type, x2.usm_type])
dst_dt = int32 if x2.size <= iinfo(int32).max else int64
index_dt = ti_default_device_index_type(q)

dst = _empty_like_orderK(x2, dst_dt, usm_type=dst_usm_type)
dst = _empty_like_orderK(x2, index_dt, usm_type=dst_usm_type)

if side == "left":
ht_ev, _ = _searchsorted_left(
Expand Down
9 changes: 9 additions & 0 deletions dpctl/tests/test_usm_ndarray_searchsorted.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,29 @@ def _check(hay_stack, needles, needles_np):
assert hay_stack.dtype == needles.dtype
assert hay_stack.ndim == 1

info_ = dpt.__array_namespace_info__()
default_dts_dev = info_.default_dtypes(hay_stack.device)
index_dt = default_dts_dev["indexing"]

p_left = dpt.searchsorted(hay_stack, needles, side="left")
assert p_left.dtype == index_dt

hs_np = dpt.asnumpy(hay_stack)
ref_left = np.searchsorted(hs_np, needles_np, side="left")
assert dpt.all(p_left == dpt.asarray(ref_left))

p_right = dpt.searchsorted(hay_stack, needles, side="right")
assert p_right.dtype == index_dt

ref_right = np.searchsorted(hs_np, needles_np, side="right")
assert dpt.all(p_right == dpt.asarray(ref_right))

sorter = dpt.arange(hay_stack.size)
ps_left = dpt.searchsorted(hay_stack, needles, side="left", sorter=sorter)
assert ps_left.dtype == index_dt
assert dpt.all(ps_left == p_left)
ps_right = dpt.searchsorted(hay_stack, needles, side="right", sorter=sorter)
assert ps_right.dtype == index_dt
assert dpt.all(ps_right == p_right)


Expand Down

0 comments on commit ccb64a5

Please sign in to comment.