Skip to content

Commit

Permalink
address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
vtavana committed Oct 27, 2023
1 parent b10aaed commit 24fe4c8
Show file tree
Hide file tree
Showing 6 changed files with 30 additions and 35 deletions.
1 change: 0 additions & 1 deletion dpnp/dpnp_algo/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ set(dpnp_algo_pyx_deps
${CMAKE_CURRENT_SOURCE_DIR}/dpnp_algo_sorting.pxi
${CMAKE_CURRENT_SOURCE_DIR}/dpnp_algo_arraycreation.pxi
${CMAKE_CURRENT_SOURCE_DIR}/dpnp_algo_mathematical.pxi
${CMAKE_CURRENT_SOURCE_DIR}/dpnp_algo_searching.pxi
${CMAKE_CURRENT_SOURCE_DIR}/dpnp_algo_indexing.pxi
${CMAKE_CURRENT_SOURCE_DIR}/dpnp_algo_logic.pxi
${CMAKE_CURRENT_SOURCE_DIR}/dpnp_algo_special.pxi
Expand Down
1 change: 0 additions & 1 deletion dpnp/dpnp_algo/dpnp_algo.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,6 @@ include "dpnp_algo_linearalgebra.pxi"
include "dpnp_algo_logic.pxi"
include "dpnp_algo_manipulation.pxi"
include "dpnp_algo_mathematical.pxi"
include "dpnp_algo_searching.pxi"
include "dpnp_algo_sorting.pxi"
include "dpnp_algo_special.pxi"
include "dpnp_algo_statistics.pxi"
Expand Down
11 changes: 0 additions & 11 deletions dpnp/dpnp_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,17 +492,6 @@ def argmax(self, axis=None, out=None):
Refer to :obj:`dpnp.argmax` for full documentation.
Examples
--------
>>> import dpnp as np
>>> a = np.arange(6).reshape(2,3)
>>> a.argmax()
array(5)
>>> a.argmax(0)
array([1, 1, 1])
>>> a.argmax(1)
array([2, 2])
"""
return dpnp.argmax(self, axis, out)

Expand Down
12 changes: 8 additions & 4 deletions dpnp/dpnp_iface_searching.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
__all__ = ["argmax", "argmin", "searchsorted", "where"]


def argmax(a, axis=None, out=None, keepdims=False):
def argmax(a, axis=None, out=None, *, keepdims=False):
"""
Returns the indices of the maximum values along an axis.
Expand All @@ -62,11 +62,13 @@ def argmax(a, axis=None, out=None, keepdims=False):
Limitations
-----------
Input array is only supported as either :class:`dpnp.ndarray` or :class:`dpctl.tensor.usm_ndarray`.
Input and output arrays are only supported as either :class:`dpnp.ndarray`
or :class:`dpctl.tensor.usm_ndarray`.
Input array data types are limited by supported DPNP :ref:`Data types`.
See Also
--------
:obj:`dpnp.ndarray.argmax` : Equivalent function.
:obj:`dpnp.argmin` : Returns the indices of the minimum values along an axis.
:obj:`dpnp.amax` : The maximum value along a given axis.
:obj:`dpnp.unravel_index` : Convert a flat index into an index tuple.
Expand Down Expand Up @@ -147,7 +149,7 @@ def argmax(a, axis=None, out=None, keepdims=False):
return out


def argmin(a, axis=None, out=None, keepdims=False):
def argmin(a, axis=None, out=None, *, keepdims=False):
"""
Returns the indices of the minimum values along an axis.
Expand All @@ -160,11 +162,13 @@ def argmin(a, axis=None, out=None, keepdims=False):
Limitations
-----------
Input array is only supported as either :class:`dpnp.ndarray` or :class:`dpctl.tensor.usm_ndarray`.
Input and output arrays are only supported as either :class:`dpnp.ndarray`
or :class:`dpctl.tensor.usm_ndarray`.
Input array data types are limited by supported DPNP :ref:`Data types`.
See Also
--------
:obj:`dpnp.ndarray.argmin` : Equivalent function.
:obj:`dpnp.argmax` : Returns the indices of the maximum values along an axis.
:obj:`dpnp.amin` : The minimum value along a given axis.
:obj:`dpnp.unravel_index` : Convert a flat index into an index tuple.
Expand Down
36 changes: 20 additions & 16 deletions tests/third_party/cupy/core_tests/test_ndarray_reduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,97 +229,101 @@ def test_ptp_nan_imag(self, xp, dtype):
@testing.for_all_dtypes()
@testing.numpy_cupy_allclose(contiguous_check=False)
def test_argmax_all(self, xp, dtype):
a = testing.shaped_random((2, 3), xp, dtype)
a = testing.shaped_random((2, 3), xp, dtype, order=self.order)
return a.argmax()

@testing.for_all_dtypes()
@testing.numpy_cupy_allclose(contiguous_check=False)
def test_argmax_axis_large(self, xp, dtype):
a = testing.shaped_random((3, 1000), xp, dtype)
a = testing.shaped_random((3, 1000), xp, dtype, order=self.order)
return a.argmax(axis=0)

@testing.for_all_dtypes()
@testing.numpy_cupy_allclose(contiguous_check=False)
def test_argmax_axis0(self, xp, dtype):
a = testing.shaped_random((2, 3, 4), xp, dtype)
a = testing.shaped_random((2, 3, 4), xp, dtype, order=self.order)
return a.argmax(axis=0)

@testing.for_all_dtypes()
@testing.numpy_cupy_allclose(contiguous_check=False)
def test_argmax_axis1(self, xp, dtype):
a = testing.shaped_random((2, 3, 4), xp, dtype)
a = testing.shaped_random((2, 3, 4), xp, dtype, order=self.order)
return a.argmax(axis=1)

@testing.for_all_dtypes()
@testing.numpy_cupy_allclose(contiguous_check=False)
def test_argmax_axis2(self, xp, dtype):
a = testing.shaped_random((2, 3, 4), xp, dtype)
a = testing.shaped_random((2, 3, 4), xp, dtype, order=self.order)
return a.argmax(axis=2)

@testing.for_float_dtypes()
@testing.numpy_cupy_allclose(contiguous_check=False)
def test_argmax_nan(self, xp, dtype):
a = xp.array([float("nan"), 1, -1], dtype)
a = xp.array([float("nan"), 1, -1], dtype, order=self.order)
return a.argmax()

@testing.for_complex_dtypes()
@testing.numpy_cupy_allclose(contiguous_check=False)
def test_argmax_nan_real(self, xp, dtype):
a = xp.array([float("nan"), 1, -1], dtype)
a = xp.array([float("nan"), 1, -1], dtype, order=self.order)
return a.argmax()

@testing.for_complex_dtypes()
@testing.numpy_cupy_allclose(contiguous_check=False)
def test_argmax_nan_imag(self, xp, dtype):
a = xp.array([float("nan") * 1.0j, 1.0j, -1.0j], dtype)
a = xp.array(
[float("nan") * 1.0j, 1.0j, -1.0j], dtype, order=self.order
)
return a.argmax()

@testing.for_all_dtypes()
@testing.numpy_cupy_allclose(contiguous_check=False)
def test_argmin_all(self, xp, dtype):
a = testing.shaped_random((2, 3), xp, dtype)
a = testing.shaped_random((2, 3), xp, dtype, order=self.order)
return a.argmin()

@testing.for_all_dtypes()
@testing.numpy_cupy_allclose(contiguous_check=False)
def test_argmin_axis_large(self, xp, dtype):
a = testing.shaped_random((3, 1000), xp, dtype)
a = testing.shaped_random((3, 1000), xp, dtype, order=self.order)
return a.argmin(axis=0)

@testing.for_all_dtypes()
@testing.numpy_cupy_allclose(contiguous_check=False)
def test_argmin_axis0(self, xp, dtype):
a = testing.shaped_random((2, 3, 4), xp, dtype)
a = testing.shaped_random((2, 3, 4), xp, dtype, order=self.order)
return a.argmin(axis=0)

@testing.for_all_dtypes()
@testing.numpy_cupy_allclose(contiguous_check=False)
def test_argmin_axis1(self, xp, dtype):
a = testing.shaped_random((2, 3, 4), xp, dtype)
a = testing.shaped_random((2, 3, 4), xp, dtype, order=self.order)
return a.argmin(axis=1)

@testing.for_all_dtypes()
@testing.numpy_cupy_allclose(contiguous_check=False)
def test_argmin_axis2(self, xp, dtype):
a = testing.shaped_random((2, 3, 4), xp, dtype)
a = testing.shaped_random((2, 3, 4), xp, dtype, order=self.order)
return a.argmin(axis=2)

@testing.for_float_dtypes()
@testing.numpy_cupy_allclose(contiguous_check=False)
def test_argmin_nan(self, xp, dtype):
a = xp.array([float("nan"), 1, -1], dtype)
a = xp.array([float("nan"), 1, -1], dtype, order=self.order)
return a.argmin()

@testing.for_complex_dtypes()
@testing.numpy_cupy_allclose(contiguous_check=False)
def test_argmin_nan_real(self, xp, dtype):
a = xp.array([float("nan"), 1, -1], dtype)
a = xp.array([float("nan"), 1, -1], dtype, order=self.order)
return a.argmin()

@testing.for_complex_dtypes()
@testing.numpy_cupy_allclose(contiguous_check=False)
def test_argmin_nan_imag(self, xp, dtype):
a = xp.array([float("nan") * 1.0j, 1.0j, -1.0j], dtype)
a = xp.array(
[float("nan") * 1.0j, 1.0j, -1.0j], dtype, order=self.order
)
return a.argmin()


Expand Down
4 changes: 2 additions & 2 deletions tests/third_party/cupy/sorting_tests/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,8 @@ def test_argmin_all(self, xp, dtype):
a = testing.shaped_random((2, 3), xp, dtype)
return a.argmin()

@testing.for_all_dtypes(no_complex=True)
@testing.numpy_cupy_allclose(accept_error=ValueError)
@testing.for_float_dtypes()
@testing.numpy_cupy_allclose()
def test_argmin_nan(self, xp, dtype):
a = xp.array([float("nan"), -1, 1], dtype)
return a.argmin()
Expand Down

0 comments on commit 24fe4c8

Please sign in to comment.