Skip to content

Commit

Permalink
implement sort and argsort (IntelPython#1660)
Browse files Browse the repository at this point in the history
* implement sort and argsort

* add more tests

* update for zero dimensional arrays

* address comments

* fix typo
  • Loading branch information
vtavana authored Jan 22, 2024
1 parent b401ae9 commit 8072622
Show file tree
Hide file tree
Showing 18 changed files with 365 additions and 273 deletions.
6 changes: 1 addition & 5 deletions dpnp/backend/include/dpnp_iface_fptr.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,6 @@ enum class DPNPFuncName : size_t
DPNP_FN_ARGMAX, /**< Used in numpy.argmax() impl */
DPNP_FN_ARGMIN, /**< Used in numpy.argmin() impl */
DPNP_FN_ARGSORT, /**< Used in numpy.argsort() impl */
DPNP_FN_ARGSORT_EXT, /**< Used in numpy.argsort() impl, requires extra
parameters */
DPNP_FN_AROUND, /**< Used in numpy.around() impl */
DPNP_FN_ASTYPE, /**< Used in numpy.astype() impl */
DPNP_FN_BITWISE_AND, /**< Used in numpy.bitwise_and() impl */
Expand Down Expand Up @@ -357,9 +355,7 @@ enum class DPNPFuncName : size_t
DPNP_FN_SIN, /**< Used in numpy.sin() impl */
DPNP_FN_SINH, /**< Used in numpy.sinh() impl */
DPNP_FN_SORT, /**< Used in numpy.sort() impl */
DPNP_FN_SORT_EXT, /**< Used in numpy.sort() impl, requires extra parameters
*/
DPNP_FN_SQRT, /**< Used in numpy.sqrt() impl */
DPNP_FN_SQRT, /**< Used in numpy.sqrt() impl */
DPNP_FN_SQRT_EXT, /**< Used in numpy.sqrt() impl, requires extra parameters
*/
DPNP_FN_SQUARE, /**< Used in numpy.square() impl */
Expand Down
34 changes: 0 additions & 34 deletions dpnp/backend/kernels/dpnp_krnl_sorting.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,14 +97,6 @@ template <typename _DataType, typename _idx_DataType>
void (*dpnp_argsort_default_c)(void *, void *, size_t) =
dpnp_argsort_c<_DataType, _idx_DataType>;

template <typename _DataType, typename _idx_DataType>
DPCTLSyclEventRef (*dpnp_argsort_ext_c)(DPCTLSyclQueueRef,
void *,
void *,
size_t,
const DPCTLEventVectorRef) =
dpnp_argsort_c<_DataType, _idx_DataType>;

// template void dpnp_argsort_c<double, long>(void* array1_in, void* result1,
// size_t size); template void dpnp_argsort_c<float, long>(void* array1_in,
// void* result1, size_t size); template void dpnp_argsort_c<long, long>(void*
Expand Down Expand Up @@ -471,14 +463,6 @@ void dpnp_sort_c(void *array1_in, void *result1, size_t size)
template <typename _DataType>
void (*dpnp_sort_default_c)(void *, void *, size_t) = dpnp_sort_c<_DataType>;

template <typename _DataType>
DPCTLSyclEventRef (*dpnp_sort_ext_c)(DPCTLSyclQueueRef,
void *,
void *,
size_t,
const DPCTLEventVectorRef) =
dpnp_sort_c<_DataType>;

void func_map_init_sorting(func_map_t &fmap)
{
fmap[DPNPFuncName::DPNP_FN_ARGSORT][eft_INT][eft_INT] = {
Expand All @@ -490,15 +474,6 @@ void func_map_init_sorting(func_map_t &fmap)
fmap[DPNPFuncName::DPNP_FN_ARGSORT][eft_DBL][eft_DBL] = {
eft_LNG, (void *)dpnp_argsort_default_c<double, int64_t>};

fmap[DPNPFuncName::DPNP_FN_ARGSORT_EXT][eft_INT][eft_INT] = {
eft_LNG, (void *)dpnp_argsort_ext_c<int32_t, int64_t>};
fmap[DPNPFuncName::DPNP_FN_ARGSORT_EXT][eft_LNG][eft_LNG] = {
eft_LNG, (void *)dpnp_argsort_ext_c<int64_t, int64_t>};
fmap[DPNPFuncName::DPNP_FN_ARGSORT_EXT][eft_FLT][eft_FLT] = {
eft_LNG, (void *)dpnp_argsort_ext_c<float, int64_t>};
fmap[DPNPFuncName::DPNP_FN_ARGSORT_EXT][eft_DBL][eft_DBL] = {
eft_LNG, (void *)dpnp_argsort_ext_c<double, int64_t>};

fmap[DPNPFuncName::DPNP_FN_PARTITION][eft_INT][eft_INT] = {
eft_INT, (void *)dpnp_partition_default_c<int32_t>};
fmap[DPNPFuncName::DPNP_FN_PARTITION][eft_LNG][eft_LNG] = {
Expand Down Expand Up @@ -550,14 +525,5 @@ void func_map_init_sorting(func_map_t &fmap)
fmap[DPNPFuncName::DPNP_FN_SORT][eft_DBL][eft_DBL] = {
eft_DBL, (void *)dpnp_sort_default_c<double>};

fmap[DPNPFuncName::DPNP_FN_SORT_EXT][eft_INT][eft_INT] = {
eft_INT, (void *)dpnp_sort_ext_c<int32_t>};
fmap[DPNPFuncName::DPNP_FN_SORT_EXT][eft_LNG][eft_LNG] = {
eft_LNG, (void *)dpnp_sort_ext_c<int64_t>};
fmap[DPNPFuncName::DPNP_FN_SORT_EXT][eft_FLT][eft_FLT] = {
eft_FLT, (void *)dpnp_sort_ext_c<float>};
fmap[DPNPFuncName::DPNP_FN_SORT_EXT][eft_DBL][eft_DBL] = {
eft_DBL, (void *)dpnp_sort_ext_c<double>};

return;
}
10 changes: 0 additions & 10 deletions dpnp/dpnp_algo/dpnp_algo.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,6 @@ cdef extern from "dpnp_iface_fptr.hpp" namespace "DPNPFuncName": # need this na
DPNP_FN_ALLCLOSE
DPNP_FN_ALLCLOSE_EXT
DPNP_FN_ARANGE
DPNP_FN_ARGSORT
DPNP_FN_ARGSORT_EXT
DPNP_FN_CHOOSE
DPNP_FN_CHOOSE_EXT
DPNP_FN_COPY
Expand Down Expand Up @@ -175,8 +173,6 @@ cdef extern from "dpnp_iface_fptr.hpp" namespace "DPNPFuncName": # need this na
DPNP_FN_RNG_ZIPF_EXT
DPNP_FN_SEARCHSORTED
DPNP_FN_SEARCHSORTED_EXT
DPNP_FN_SORT
DPNP_FN_SORT_EXT
DPNP_FN_SVD
DPNP_FN_SVD_EXT
DPNP_FN_TRACE
Expand Down Expand Up @@ -309,12 +305,6 @@ cpdef dpnp_descriptor dpnp_fmin(dpnp_descriptor x1_obj, dpnp_descriptor x2_obj,
dpnp_descriptor out=*, object where=*)


"""
Sorting functions
"""
cpdef dpnp_descriptor dpnp_argsort(dpnp_descriptor array1)
cpdef dpnp_descriptor dpnp_sort(dpnp_descriptor array1)

"""
Trigonometric functions
"""
Expand Down
13 changes: 0 additions & 13 deletions dpnp/dpnp_algo/dpnp_algo_sorting.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,8 @@ and the rest of the library
# NO IMPORTs here. All imports must be placed into main "dpnp_algo.pyx" file

__all__ += [
"dpnp_argsort",
"dpnp_partition",
"dpnp_searchsorted",
"dpnp_sort"
]


Expand All @@ -61,13 +59,6 @@ ctypedef c_dpctl.DPCTLSyclEventRef(*fptr_dpnp_searchsorted_t)(c_dpctl.DPCTLSyclQ
const c_dpctl.DPCTLEventVectorRef)


cpdef utils.dpnp_descriptor dpnp_argsort(utils.dpnp_descriptor x1):
cdef shape_type_c result_shape = x1.shape
if result_shape == ():
result_shape = (1,)
return call_fptr_1in_1out(DPNP_FN_ARGSORT_EXT, x1, result_shape)


cpdef utils.dpnp_descriptor dpnp_partition(utils.dpnp_descriptor arr, int kth, axis=-1, kind='introselect', order=None):
cdef shape_type_c shape1 = arr.shape

Expand Down Expand Up @@ -148,7 +139,3 @@ cpdef utils.dpnp_descriptor dpnp_searchsorted(utils.dpnp_descriptor arr, utils.d
c_dpctl.DPCTLEvent_Delete(event_ref)

return result


cpdef utils.dpnp_descriptor dpnp_sort(utils.dpnp_descriptor x1):
return call_fptr_1in_1out(DPNP_FN_SORT_EXT, x1, x1.shape)
70 changes: 34 additions & 36 deletions dpnp/dpnp_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,39 +510,7 @@ def argsort(self, axis=-1, kind=None, order=None):
"""
Return an ndarray of indices that sort the array along the specified axis.
Parameters
----------
axis : int, optional
Axis along which to sort. If None, the default, the flattened array
is used.
.. versionchanged:: 1.13.0
Previously, the default was documented to be -1, but that was
in error. At some future date, the default will change to -1, as
originally intended.
Until then, the axis should be given explicitly when
``arr.ndim > 1``, to avoid a FutureWarning.
kind : {'quicksort', 'mergesort', 'heapsort', 'stable'}, optional
The sorting algorithm used.
order : list, optional
When `a` is an array with fields defined, this argument specifies
which fields to compare first, second, etc. Not all fields need be
specified.
Returns
-------
index_array : ndarray, int
Array of indices that sort `a` along the specified axis.
In other words, ``a[index_array]`` yields a sorted `a`.
See Also
--------
MaskedArray.sort : Describes sorting algorithms used.
:obj:`dpnp.lexsort` : Indirect stable sort with multiple keys.
:obj:`numpy.ndarray.sort` : Inplace sort.
Notes
-----
See `sort` for notes on the different sorting algorithms.
Refer to :obj:`dpnp.argsort` for full documentation.
"""
return dpnp.argsort(self, axis, kind, order)
Expand Down Expand Up @@ -1163,14 +1131,44 @@ def size(self):

return self._array_obj.size

# 'sort',
def sort(self, axis=-1, kind=None, order=None):
"""
Sort an array in-place.
Refer to :obj:`dpnp.sort` for full documentation.
Note
----
`axis` in :obj:`dpnp.sort` could be integr or ``None``. If ``None``,
the array is flattened before sorting. However, `axis` in :obj:`dpnp.ndarray.sort`
can only be integer since it sorts an array in-place.
Examples
--------
>>> import dpnp as np
>>> a = np.array([[1,4],[3,1]])
>>> a.sort(axis=1)
>>> a
array([[1, 4],
[1, 3]])
>>> a.sort(axis=0)
>>> a
array([[1, 1],
[3, 4]])
"""

if axis is None:
raise TypeError(
"'NoneType' object cannot be interpreted as an integer"
)
self[...] = dpnp.sort(self, axis=axis, kind=kind, order=order)

def squeeze(self, axis=None):
"""
Remove single-dimensional entries from the shape of an array.
.. seealso::
:obj:`dpnp.squeeze` for full documentation
Refer to :obj:`dpnp.squeeze` for full documentation
"""

Expand Down
1 change: 1 addition & 0 deletions dpnp/dpnp_iface_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -851,6 +851,7 @@ def take_along_axis(a, indices, axis):
--------
:obj:`dpnp.take` : Take along an axis, using the same indices for every 1d slice.
:obj:`dpnp.put_along_axis` : Put values into the destination array by matching 1d index and data slices.
:obj:`dpnp.argsort` : Return the indices that would sort an array.
Examples
--------
Expand Down
4 changes: 2 additions & 2 deletions dpnp/dpnp_iface_mathematical.py
Original file line number Diff line number Diff line change
Expand Up @@ -2709,7 +2709,7 @@ def sum(
Parameters
----------
a : {dpnp.ndarray, usm_ndarray}:
a : {dpnp.ndarray, usm_ndarray}
Input array.
axis : int or tuple of ints, optional
Axis or axes along which sums must be computed. If a tuple
Expand Down Expand Up @@ -2762,7 +2762,7 @@ def sum(
Limitations
-----------
Parameters `initial` and `where` are supported with their default values.
Parameters `initial` and `where` are only supported with their default values.
Otherwise ``NotImplementedError`` exception will be raised.
See Also
Expand Down
10 changes: 5 additions & 5 deletions dpnp/dpnp_iface_nanfunctions.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
# *****************************************************************************
# Copyright (c) 2016-2024, Intel Corporation
# Copyright (c) 2023-2024, Intel Corporation
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
Expand Down Expand Up @@ -415,7 +415,7 @@ def nanmean(a, axis=None, dtype=None, out=None, keepdims=False, *, where=True):
Parameters
----------
a : {dpnp.ndarray, usm_ndarray}:
a : {dpnp.ndarray, usm_ndarray}
Input array.
axis : int or tuple of ints, optional
Axis or axes along which the arithmetic means must be computed. If
Expand Down Expand Up @@ -696,7 +696,7 @@ def nansum(
Parameters
----------
a : {dpnp.ndarray, usm_ndarray}:
a : {dpnp.ndarray, usm_ndarray}
Input array.
axis : int or tuple of ints, optional
Axis or axes along which sums must be computed. If a tuple
Expand Down Expand Up @@ -806,7 +806,7 @@ def nanstd(
Parameters
----------
a : {dpnp.ndarray, usm_ndarray}:
a : {dpnp.ndarray, usm_ndarray}
Input array.
axis : int or tuple of ints, optional
Axis or axes along which the standard deviations must be computed.
Expand Down Expand Up @@ -908,7 +908,7 @@ def nanvar(
Parameters
----------
a : {dpnp.ndarray, usm_ndarray}:
a : {dpnp_array, usm_ndarray}
Input array.
axis : int or tuple of ints, optional
axis or axes along which the variances must be computed. If a tuple
Expand Down
Loading

0 comments on commit 8072622

Please sign in to comment.