Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement dpnp.searchsorted #1751

Merged
merged 8 commits into from
Mar 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 5 additions & 7 deletions dpnp/backend/include/dpnp_iface_fptr.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -332,13 +332,11 @@ enum class DPNPFuncName : size_t
DPNP_FN_RNG_ZIPF_EXT, /**< Used in numpy.random.zipf() impl, requires extra
parameters */
DPNP_FN_SEARCHSORTED, /**< Used in numpy.searchsorted() impl */
DPNP_FN_SEARCHSORTED_EXT, /**< Used in numpy.searchsorted() impl, requires
extra parameters */
DPNP_FN_SIGN, /**< Used in numpy.sign() impl */
DPNP_FN_SIN, /**< Used in numpy.sin() impl */
DPNP_FN_SINH, /**< Used in numpy.sinh() impl */
DPNP_FN_SORT, /**< Used in numpy.sort() impl */
DPNP_FN_SQRT, /**< Used in numpy.sqrt() impl */
DPNP_FN_SIGN, /**< Used in numpy.sign() impl */
DPNP_FN_SIN, /**< Used in numpy.sin() impl */
DPNP_FN_SINH, /**< Used in numpy.sinh() impl */
DPNP_FN_SORT, /**< Used in numpy.sort() impl */
DPNP_FN_SQRT, /**< Used in numpy.sqrt() impl */
DPNP_FN_SQRT_EXT, /**< Used in numpy.sqrt() impl, requires extra parameters
*/
DPNP_FN_SQUARE, /**< Used in numpy.square() impl */
Expand Down
20 changes: 0 additions & 20 deletions dpnp/backend/kernels/dpnp_krnl_sorting.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -403,17 +403,6 @@ void (*dpnp_searchsorted_default_c)(void *,
const size_t) =
dpnp_searchsorted_c<_DataType, _IndexingType>;

template <typename _DataType, typename _IndexingType>
DPCTLSyclEventRef (*dpnp_searchsorted_ext_c)(DPCTLSyclQueueRef,
void *,
const void *,
const void *,
bool,
const size_t,
const size_t,
const DPCTLEventVectorRef) =
dpnp_searchsorted_c<_DataType, _IndexingType>;

template <typename _DataType>
class dpnp_sort_c_kernel;

Expand Down Expand Up @@ -507,15 +496,6 @@ void func_map_init_sorting(func_map_t &fmap)
fmap[DPNPFuncName::DPNP_FN_SEARCHSORTED][eft_DBL][eft_DBL] = {
eft_DBL, (void *)dpnp_searchsorted_default_c<double, int64_t>};

fmap[DPNPFuncName::DPNP_FN_SEARCHSORTED_EXT][eft_INT][eft_INT] = {
eft_INT, (void *)dpnp_searchsorted_ext_c<int32_t, int64_t>};
fmap[DPNPFuncName::DPNP_FN_SEARCHSORTED_EXT][eft_LNG][eft_LNG] = {
eft_LNG, (void *)dpnp_searchsorted_ext_c<int64_t, int64_t>};
fmap[DPNPFuncName::DPNP_FN_SEARCHSORTED_EXT][eft_FLT][eft_FLT] = {
eft_FLT, (void *)dpnp_searchsorted_ext_c<float, int64_t>};
fmap[DPNPFuncName::DPNP_FN_SEARCHSORTED_EXT][eft_DBL][eft_DBL] = {
eft_DBL, (void *)dpnp_searchsorted_ext_c<double, int64_t>};

fmap[DPNPFuncName::DPNP_FN_SORT][eft_INT][eft_INT] = {
eft_INT, (void *)dpnp_sort_default_c<int32_t>};
fmap[DPNPFuncName::DPNP_FN_SORT][eft_LNG][eft_LNG] = {
Expand Down
1 change: 0 additions & 1 deletion dpnp/dpnp_algo/dpnp_algo.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,6 @@ cdef extern from "dpnp_iface_fptr.hpp" namespace "DPNPFuncName": # need this na
DPNP_FN_RNG_WALD_EXT
DPNP_FN_RNG_WEIBULL_EXT
DPNP_FN_RNG_ZIPF_EXT
DPNP_FN_SEARCHSORTED_EXT
DPNP_FN_TRACE_EXT
DPNP_FN_TRAPZ_EXT

Expand Down
50 changes: 0 additions & 50 deletions dpnp/dpnp_algo/dpnp_algo_sorting.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ and the rest of the library

__all__ += [
"dpnp_partition",
"dpnp_searchsorted",
]


Expand All @@ -49,14 +48,6 @@ ctypedef c_dpctl.DPCTLSyclEventRef(*fptr_dpnp_partition_t)(c_dpctl.DPCTLSyclQueu
const shape_elem_type * ,
const size_t,
const c_dpctl.DPCTLEventVectorRef)
ctypedef c_dpctl.DPCTLSyclEventRef(*fptr_dpnp_searchsorted_t)(c_dpctl.DPCTLSyclQueueRef,
void * ,
const void * ,
const void * ,
bool,
const size_t,
const size_t,
const c_dpctl.DPCTLEventVectorRef)


cpdef utils.dpnp_descriptor dpnp_partition(utils.dpnp_descriptor arr, int kth, axis=-1, kind='introselect', order=None):
Expand Down Expand Up @@ -98,44 +89,3 @@ cpdef utils.dpnp_descriptor dpnp_partition(utils.dpnp_descriptor arr, int kth, a
c_dpctl.DPCTLEvent_Delete(event_ref)

return result


cpdef utils.dpnp_descriptor dpnp_searchsorted(utils.dpnp_descriptor arr, utils.dpnp_descriptor v, side='left'):
if side is 'left':
side_ = True
else:
side_ = False

cdef DPNPFuncType param1_type = dpnp_dtype_to_DPNPFuncType(arr.dtype)

cdef DPNPFuncData kernel_data = get_dpnp_function_ptr(DPNP_FN_SEARCHSORTED_EXT, param1_type, param1_type)

arr_obj = arr.get_array()

cdef utils.dpnp_descriptor result = utils_py.create_output_descriptor_py(v.shape,
dpnp.int64,
None,
device=arr_obj.sycl_device,
usm_type=arr_obj.usm_type,
sycl_queue=arr_obj.sycl_queue)

result_sycl_queue = result.get_array().sycl_queue

cdef c_dpctl.SyclQueue q = <c_dpctl.SyclQueue> result_sycl_queue
cdef c_dpctl.DPCTLSyclQueueRef q_ref = q.get_queue_ref()

cdef fptr_dpnp_searchsorted_t func = <fptr_dpnp_searchsorted_t > kernel_data.ptr

cdef c_dpctl.DPCTLSyclEventRef event_ref = func(q_ref,
arr.get_data(),
v.get_data(),
result.get_data(),
side_,
arr.size,
v.size,
NULL) # dep_events_ref

with nogil: c_dpctl.DPCTLEvent_WaitAndThrow(event_ref)
c_dpctl.DPCTLEvent_Delete(event_ref)

return result
12 changes: 11 additions & 1 deletion dpnp/dpnp_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -1121,7 +1121,17 @@ def round(self, decimals=0, out=None):

return dpnp.around(self, decimals, out)

# 'searchsorted',
def searchsorted(self, v, side="left", sorter=None):
"""
Find indices where elements of `v` should be inserted in `a`
to maintain order.

Refer to :obj:`dpnp.searchsorted` for full documentation

"""

return dpnp.searchsorted(self, v, side=side, sorter=sorter)

# 'setfield',
# 'setflags',

Expand Down
54 changes: 53 additions & 1 deletion dpnp/dpnp_iface_searching.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,9 +232,61 @@ def searchsorted(a, v, side="left", sorter=None):

For full documentation refer to :obj:`numpy.searchsorted`.

Parameters
----------
a : {dpnp.ndarray, usm_ndarray}
Input 1-D array. If `sorter` is ``None``, then it must be sorted in
ascending order, otherwise `sorter` must be an array of indices that
sort it.
v : {dpnp.ndarray, usm_ndarray, scalar}
Values to insert into `a`.
side : {'left', 'right'}, optional
If ``'left'``, the index of the first suitable location found is given.
If ``'right'``, return the last such index. If there is no suitable
index, return either 0 or N (where N is the length of `a`).
Default is ``'left'``.
sorter : {dpnp.ndarray, usm_ndarray}, optional
Optional 1-D array of integer indices that sort array a into ascending
order. They are typically the result of argsort.
Out of bound index values of `sorter` array are treated using `"wrap"`
mode documented in :py:func:`dpnp.take`.
Default is ``None``.

Returns
-------
indices : dpnp.ndarray
Array of insertion points with the same shape as `v`,
or 0-D array if `v` is a scalar.

See Also
--------
:obj:`dpnp.sort` : Return a sorted copy of an array.
:obj:`dpnp.histogram` : Produce histogram from 1-D data.

Examples
--------
>>> import dpnp as np
>>> a = np.array([11,12,13,14,15])
>>> np.searchsorted(a, 13)
array(2)
>>> np.searchsorted(a, 13, side='right')
array(3)
>>> v = np.array([-10, 20, 12, 13])
>>> np.searchsorted(a, v)
array([0, 5, 1, 2])

"""

return call_origin(numpy.where, a, v, side, sorter)
usm_a = dpnp.get_usm_ndarray(a)
if dpnp.isscalar(v):
usm_v = dpt.asarray(v, sycl_queue=a.sycl_queue, usm_type=a.usm_type)
else:
usm_v = dpnp.get_usm_ndarray(v)

usm_sorter = None if sorter is None else dpnp.get_usm_ndarray(sorter)
return dpnp_array._create_from_usm_ndarray(
dpt.searchsorted(usm_a, usm_v, side=side, sorter=usm_sorter)
)


def where(condition, x=None, y=None, /):
Expand Down
38 changes: 1 addition & 37 deletions dpnp/dpnp_iface_sorting.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,14 +47,13 @@
# pylint: disable=no-name-in-module
from .dpnp_algo import (
dpnp_partition,
dpnp_searchsorted,
)
from .dpnp_array import dpnp_array
from .dpnp_utils import (
call_origin,
)

__all__ = ["argsort", "partition", "searchsorted", "sort"]
__all__ = ["argsort", "partition", "sort"]


def argsort(a, axis=-1, kind=None, order=None):
Expand Down Expand Up @@ -189,41 +188,6 @@ def partition(x1, kth, axis=-1, kind="introselect", order=None):
return call_origin(numpy.partition, x1, kth, axis, kind, order)


def searchsorted(x1, x2, side="left", sorter=None):
"""
Find indices where elements should be inserted to maintain order.

For full documentation refer to :obj:`numpy.searchsorted`.

Limitations
-----------
Input arrays is supported as :obj:`dpnp.ndarray`.
Input array is supported only sorted.
Input side is supported only values ``left``, ``right``.
Parameter `sorter` is supported only with default values.

"""

x1_desc = dpnp.get_dpnp_descriptor(x1, copy_when_nondefault_queue=False)
x2_desc = dpnp.get_dpnp_descriptor(x2, copy_when_nondefault_queue=False)
# pylint: disable=condition-evals-to-constant
if 0 and x1_desc and x2_desc:
if x1_desc.ndim != 1:
pass
elif x1_desc.dtype != x2_desc.dtype:
pass
elif side not in ["left", "right"]:
pass
elif sorter is not None:
pass
elif x1_desc.size < 2:
pass
else:
return dpnp_searchsorted(x1_desc, x2_desc, side=side).get_pyobj()

return call_origin(numpy.searchsorted, x1, x2, side=side, sorter=sorter)


def sort(a, axis=-1, kind=None, order=None):
"""
Return a sorted copy of an array.
Expand Down
22 changes: 0 additions & 22 deletions tests/skipped_tests.tbl
Original file line number Diff line number Diff line change
Expand Up @@ -701,28 +701,6 @@ tests/third_party/cupy/random_tests/test_sample.py::TestRandomIntegers2::test_bo
tests/third_party/cupy/random_tests/test_sample.py::TestRandomIntegers2::test_goodness_of_fit
tests/third_party/cupy/random_tests/test_sample.py::TestRandomIntegers2::test_goodness_of_fit_2

tests/third_party/cupy/sorting_tests/test_search.py::TestArgMinMaxDtype::test_argminmax_dtype[_param_0_{func='argmin', is_module=True, shape=(3, 4)}]
tests/third_party/cupy/sorting_tests/test_search.py::TestArgMinMaxDtype::test_argminmax_dtype[_param_1_{func='argmin', is_module=True, shape=()}]
tests/third_party/cupy/sorting_tests/test_search.py::TestArgMinMaxDtype::test_argminmax_dtype[_param_2_{func='argmin', is_module=False, shape=(3, 4)}]
tests/third_party/cupy/sorting_tests/test_search.py::TestArgMinMaxDtype::test_argminmax_dtype[_param_3_{func='argmin', is_module=False, shape=()}]
tests/third_party/cupy/sorting_tests/test_search.py::TestArgMinMaxDtype::test_argminmax_dtype[_param_4_{func='argmax', is_module=True, shape=(3, 4)}]
tests/third_party/cupy/sorting_tests/test_search.py::TestArgMinMaxDtype::test_argminmax_dtype[_param_5_{func='argmax', is_module=True, shape=()}]
tests/third_party/cupy/sorting_tests/test_search.py::TestArgMinMaxDtype::test_argminmax_dtype[_param_6_{func='argmax', is_module=False, shape=(3, 4)}]
tests/third_party/cupy/sorting_tests/test_search.py::TestArgMinMaxDtype::test_argminmax_dtype[_param_7_{func='argmax', is_module=False, shape=()}]

tests/third_party/cupy/sorting_tests/test_search.py::TestArgwhere::test_argwhere[0]
tests/third_party/cupy/sorting_tests/test_search.py::TestArgwhere::test_argwhere[1]
tests/third_party/cupy/sorting_tests/test_search.py::TestArgwhere::test_argwhere[2]

tests/third_party/cupy/sorting_tests/test_search.py::TestFlatNonzero::test_flatnonzero[0]
tests/third_party/cupy/sorting_tests/test_search.py::TestFlatNonzero::test_flatnonzero[1]
tests/third_party/cupy/sorting_tests/test_search.py::TestFlatNonzero::test_flatnonzero[2]
tests/third_party/cupy/sorting_tests/test_search.py::TestFlatNonzero::test_flatnonzero[3]
tests/third_party/cupy/sorting_tests/test_search.py::TestFlatNonzero::test_flatnonzero[4]

tests/third_party/cupy/sorting_tests/test_search.py::TestNonzeroZeroDimension_param_0_{array=array(0)}::test_nonzero
tests/third_party/cupy/sorting_tests/test_search.py::TestNonzeroZeroDimension_param_1_{array=array(1)}::test_nonzero

tests/third_party/cupy/sorting_tests/test_sort.py::TestArgpartition_param_0_{external=False}::test_argpartition_axis
tests/third_party/cupy/sorting_tests/test_sort.py::TestArgpartition_param_0_{external=False}::test_argpartition_invalid_axis1
tests/third_party/cupy/sorting_tests/test_sort.py::TestArgpartition_param_0_{external=False}::test_argpartition_invalid_axis2
Expand Down
22 changes: 0 additions & 22 deletions tests/skipped_tests_gpu.tbl
Original file line number Diff line number Diff line change
Expand Up @@ -763,28 +763,6 @@ tests/third_party/cupy/random_tests/test_sample.py::TestRandomIntegers2::test_bo
tests/third_party/cupy/random_tests/test_sample.py::TestRandomIntegers2::test_goodness_of_fit
tests/third_party/cupy/random_tests/test_sample.py::TestRandomIntegers2::test_goodness_of_fit_2

tests/third_party/cupy/sorting_tests/test_search.py::TestArgMinMaxDtype::test_argminmax_dtype[_param_0_{func='argmin', is_module=True, shape=(3, 4)}]
tests/third_party/cupy/sorting_tests/test_search.py::TestArgMinMaxDtype::test_argminmax_dtype[_param_1_{func='argmin', is_module=True, shape=()}]
tests/third_party/cupy/sorting_tests/test_search.py::TestArgMinMaxDtype::test_argminmax_dtype[_param_2_{func='argmin', is_module=False, shape=(3, 4)}]
tests/third_party/cupy/sorting_tests/test_search.py::TestArgMinMaxDtype::test_argminmax_dtype[_param_3_{func='argmin', is_module=False, shape=()}]
tests/third_party/cupy/sorting_tests/test_search.py::TestArgMinMaxDtype::test_argminmax_dtype[_param_4_{func='argmax', is_module=True, shape=(3, 4)}]
tests/third_party/cupy/sorting_tests/test_search.py::TestArgMinMaxDtype::test_argminmax_dtype[_param_5_{func='argmax', is_module=True, shape=()}]
tests/third_party/cupy/sorting_tests/test_search.py::TestArgMinMaxDtype::test_argminmax_dtype[_param_6_{func='argmax', is_module=False, shape=(3, 4)}]
tests/third_party/cupy/sorting_tests/test_search.py::TestArgMinMaxDtype::test_argminmax_dtype[_param_7_{func='argmax', is_module=False, shape=()}]

tests/third_party/cupy/sorting_tests/test_search.py::TestArgwhere::test_argwhere[0]
tests/third_party/cupy/sorting_tests/test_search.py::TestArgwhere::test_argwhere[1]
tests/third_party/cupy/sorting_tests/test_search.py::TestArgwhere::test_argwhere[2]

tests/third_party/cupy/sorting_tests/test_search.py::TestFlatNonzero::test_flatnonzero[0]
tests/third_party/cupy/sorting_tests/test_search.py::TestFlatNonzero::test_flatnonzero[1]
tests/third_party/cupy/sorting_tests/test_search.py::TestFlatNonzero::test_flatnonzero[2]
tests/third_party/cupy/sorting_tests/test_search.py::TestFlatNonzero::test_flatnonzero[3]
tests/third_party/cupy/sorting_tests/test_search.py::TestFlatNonzero::test_flatnonzero[4]

tests/third_party/cupy/sorting_tests/test_search.py::TestNonzeroZeroDimension_param_0_{array=array(0)}::test_nonzero
tests/third_party/cupy/sorting_tests/test_search.py::TestNonzeroZeroDimension_param_1_{array=array(1)}::test_nonzero

tests/third_party/cupy/sorting_tests/test_sort.py::TestArgpartition_param_0_{external=False}::test_argpartition_axis
tests/third_party/cupy/sorting_tests/test_sort.py::TestArgpartition_param_0_{external=False}::test_argpartition_invalid_axis1
tests/third_party/cupy/sorting_tests/test_sort.py::TestArgpartition_param_0_{external=False}::test_argpartition_invalid_axis2
Expand Down
Loading
Loading