@@ -11,20 +11,29 @@ def _check(hay_stack, needles, needles_np):
1111 assert hay_stack .dtype == needles .dtype
1212 assert hay_stack .ndim == 1
1313
14+ info_ = dpt .__array_namespace_info__ ()
15+ default_dts_dev = info_ .default_dtypes (hay_stack .device )
16+ index_dt = default_dts_dev ["indexing" ]
17+
1418 p_left = dpt .searchsorted (hay_stack , needles , side = "left" )
19+ assert p_left .dtype == index_dt
1520
1621 hs_np = dpt .asnumpy (hay_stack )
1722 ref_left = np .searchsorted (hs_np , needles_np , side = "left" )
1823 assert dpt .all (p_left == dpt .asarray (ref_left ))
1924
2025 p_right = dpt .searchsorted (hay_stack , needles , side = "right" )
26+ assert p_right .dtype == index_dt
27+
2128 ref_right = np .searchsorted (hs_np , needles_np , side = "right" )
2229 assert dpt .all (p_right == dpt .asarray (ref_right ))
2330
2431 sorter = dpt .arange (hay_stack .size )
2532 ps_left = dpt .searchsorted (hay_stack , needles , side = "left" , sorter = sorter )
33+ assert ps_left .dtype == index_dt
2634 assert dpt .all (ps_left == p_left )
2735 ps_right = dpt .searchsorted (hay_stack , needles , side = "right" , sorter = sorter )
36+ assert ps_right .dtype == index_dt
2837 assert dpt .all (ps_right == p_right )
2938
3039
0 commit comments