Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

implement dpnp.argmin and dpnp.argmax using dpctl.tensor #1610

Merged
merged 17 commits into from
Nov 23, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions dpnp/dpnp_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,7 +486,7 @@ def any(self, axis=None, out=None, keepdims=False, *, where=True):
self, axis=axis, out=out, keepdims=keepdims, where=where
)

def argmax(self, axis=None, out=None):
def argmax(self, axis=None, out=None, *, keepdims=False):
"""
Returns array of indices of the maximum values along the given axis.

Expand All @@ -495,7 +495,7 @@ def argmax(self, axis=None, out=None):
"""
return dpnp.argmax(self, axis, out)
vtavana marked this conversation as resolved.
Show resolved Hide resolved

def argmin(self, axis=None, out=None):
def argmin(self, axis=None, out=None, *, keepdims=False):
"""
Return array of indices to the minimum values along the given axis.

Expand Down
44 changes: 44 additions & 0 deletions dpnp/dpnp_iface.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@
"get_usm_ndarray_or_scalar",
"is_supported_array_or_scalar",
"is_supported_array_type",
"_copyto",
vtavana marked this conversation as resolved.
Show resolved Hide resolved
]

from dpnp import float64, isscalar
Expand Down Expand Up @@ -516,3 +517,46 @@ def is_supported_array_type(a):
"""

return isinstance(a, (dpnp_array, dpt.usm_ndarray))


def _copyto(a, out=None):
"""
If `out` is provided, `a` will be inserted into this array.
Otherwise, `a` is returned.
vtavana marked this conversation as resolved.
Show resolved Hide resolved

Parameters
----------
a : {dpnp_array}
An input array.

out : {dpnp_array, usm_ndarray}
If provided, the input will be inserted into this array.
It should be of the appropriate shape.

Returns
-------
out : {dpnp_array}
Return `out` if provided, otherwise return `a`.

"""

if out is None:
return a
else:
if out.shape != a.shape:
raise ValueError(
f"Output array of shape {a.shape} is needed, got {out.shape}."
)
elif not isinstance(out, dpnp_array):
if isinstance(out, dpt.usm_ndarray):
out = dpnp_array._create_from_usm_ndarray(out)
else:
raise TypeError(
"Output array must be any of supported type, but got {}".format(
type(out)
)
)

dpnp.copyto(out, a, casting="safe")

return out
21 changes: 1 addition & 20 deletions dpnp/dpnp_iface_mathematical.py
Original file line number Diff line number Diff line change
Expand Up @@ -2158,26 +2158,7 @@ def prod(
dpt.prod(dpt_array, axis=axis, dtype=dtype, keepdims=keepdims)
)

if out is None:
return result
else:
if out.shape != result.shape:
raise ValueError(
f"Output array of shape {result.shape} is needed, got {out.shape}."
)
elif not isinstance(out, dpnp_array):
if isinstance(out, dpt.usm_ndarray):
out = dpnp_array._create_from_usm_ndarray(out)
else:
raise TypeError(
"Output array must be any of supported type, but got {}".format(
type(out)
)
)

dpnp.copyto(out, result, casting="safe")

return out
return dpnp._copyto(result, out)


def proj(
Expand Down
42 changes: 2 additions & 40 deletions dpnp/dpnp_iface_searching.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,26 +126,7 @@ def argmax(a, axis=None, out=None, *, keepdims=False):
dpt.argmax(dpt_array, axis=axis, keepdims=keepdims)
)

if out is None:
return result
else:
if out.shape != result.shape:
raise ValueError(
f"Output array of shape {result.shape} is needed, got {out.shape}."
)
elif not isinstance(out, dpnp_array):
if isinstance(out, dpt.usm_ndarray):
out = dpnp_array._create_from_usm_ndarray(out)
else:
raise TypeError(
"Output array must be any of supported type, but got {}".format(
type(out)
)
)

dpnp.copyto(out, result, casting="safe")

return out
return dpnp._copyto(result, out)


def argmin(a, axis=None, out=None, *, keepdims=False):
Expand Down Expand Up @@ -225,26 +206,7 @@ def argmin(a, axis=None, out=None, *, keepdims=False):
dpt.argmin(dpt_array, axis=axis, keepdims=keepdims)
)

if out is None:
return result
else:
if out.shape != result.shape:
raise ValueError(
f"Output array of shape {result.shape} is needed, got {out.shape}."
)
elif not isinstance(out, dpnp_array):
if isinstance(out, dpt.usm_ndarray):
out = dpnp_array._create_from_usm_ndarray(out)
else:
raise TypeError(
"Output array must be any of supported type, but got {}".format(
type(out)
)
)

dpnp.copyto(out, result, casting="safe")

return out
return dpnp._copyto(result, out)


def searchsorted(a, v, side="left", sorter=None):
Expand Down
44 changes: 3 additions & 41 deletions dpnp/dpnp_iface_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,26 +414,7 @@ def max(a, axis=None, out=None, keepdims=False, initial=None, where=True):
dpt.max(dpt_array, axis=axis, keepdims=keepdims)
)

if out is None:
return result
else:
if out.shape != result.shape:
raise ValueError(
f"Output array of shape {result.shape} is needed, got {out.shape}."
)
elif not isinstance(out, dpnp_array):
if isinstance(out, dpt.usm_ndarray):
out = dpnp_array._create_from_usm_ndarray(out)
else:
raise TypeError(
"Output array must be any of supported type, but got {}".format(
type(out)
)
)

dpnp.copyto(out, result, casting="safe")

return out
return dpnp._copyto(result, out)


def mean(x, /, *, axis=None, dtype=None, keepdims=False, out=None, where=True):
Expand Down Expand Up @@ -638,34 +619,15 @@ def min(a, axis=None, out=None, keepdims=False, initial=None, where=True):
)
elif where is not True:
raise NotImplementedError(
"where keyword argument is only supported by its default values."
"where keyword argument is only supported by its default value."
)
else:
dpt_array = dpnp.get_usm_ndarray(a)
result = dpnp_array._create_from_usm_ndarray(
dpt.min(dpt_array, axis=axis, keepdims=keepdims)
)

if out is None:
return result
else:
if out.shape != result.shape:
raise ValueError(
f"Output array of shape {result.shape} is needed, got {out.shape}."
)
elif not isinstance(out, dpnp_array):
if isinstance(out, dpt.usm_ndarray):
out = dpnp_array._create_from_usm_ndarray(out)
else:
raise TypeError(
"Output array must be any of supported type, but got {}".format(
type(out)
)
)

dpnp.copyto(out, result, casting="safe")

return out
return dpnp._copyto(result, out)


def ptp(
Expand Down
Loading