Skip to content

Commit

Permalink
Get rid of falling back on numpy in dpnp.put (#1838)
Browse files Browse the repository at this point in the history
* Get rid of call_origin in dpnp.put

* Extended tests for dpnp.put
  • Loading branch information
antonwolfy authored May 22, 2024
1 parent 069cad2 commit 614af33
Show file tree
Hide file tree
Showing 4 changed files with 229 additions and 188 deletions.
118 changes: 65 additions & 53 deletions dpnp/dpnp_iface_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -677,7 +677,7 @@ def nonzero(a):
[2, 1]])
A common use for ``nonzero`` is to find the indices of an array, where
a condition is ``True.`` Given an array `a`, the condition `a` > 3 is
a condition is ``True``. Given an array `a`, the condition `a` > 3 is
a boolean array and since ``False`` is interpreted as ``0``,
``np.nonzero(a > 3)`` yields the indices of the `a` where the condition is
true.
Expand Down Expand Up @@ -736,25 +736,33 @@ def place(x, mask, vals, /):
return call_origin(numpy.place, x, mask, vals, dpnp_inplace=True)


# pylint: disable=redefined-outer-name
def put(a, indices, vals, /, *, axis=None, mode="wrap"):
def put(a, ind, v, /, *, axis=None, mode="wrap"):
"""
Puts values of an array into another array along a given axis.
For full documentation refer to :obj:`numpy.put`.
Limitations
-----------
Parameters `a` and `indices` are supported either as :class:`dpnp.ndarray`
or :class:`dpctl.tensor.usm_ndarray`.
Parameter `indices` is supported as 1-D array of integer data type.
Parameter `vals` must be broadcastable to the shape of `indices`
and has the same data type as `a` if it is as :class:`dpnp.ndarray`
or :class:`dpctl.tensor.usm_ndarray`.
Parameter `mode` is supported with ``wrap``, the default, and ``clip``
values.
Parameter `axis` is supported as integer only.
Otherwise the function will be executed sequentially on CPU.
Parameters
----------
a : {dpnp.ndarray, usm_ndarray}
The array the values will be put into.
ind : {array_like}
Target indices, interpreted as integers.
v : {scalar, array_like}
Values to be put into `a`. Must be broadcastable to the result shape
``a.shape[:axis] + ind.shape + a.shape[axis+1:]``.
axis {None, int}, optional
The axis along which the values will be placed. If `a` is 1-D array,
this argument is optional.
Default: ``None``.
mode : {'wrap', 'clip'}, optional
Specifies how out-of-bounds indices will behave.
- 'wrap': clamps indices to (``-n <= i < n``), then wraps negative
indices.
- 'clip': clips indices to (``0 <= i < n``).
Default: ``'wrap'``.
See Also
--------
Expand All @@ -774,49 +782,53 @@ def put(a, indices, vals, /, *, axis=None, mode="wrap"):
Examples
--------
>>> import dpnp as np
>>> x = np.arange(5)
>>> indices = np.array([0, 1])
>>> np.put(x, indices, [-44, -55])
>>> x
array([-44, -55, 2, 3, 4])
>>> a = np.arange(5)
>>> np.put(a, [0, 2], [-44, -55])
>>> a
array([-44, 1, -55, 3, 4])
>>> x = np.arange(5)
>>> indices = np.array([22])
>>> np.put(x, indices, -5, mode='clip')
>>> x
>>> a = np.arange(5)
>>> np.put(a, 22, -5, mode='clip')
>>> a
array([ 0, 1, 2, 3, -5])
"""

if dpnp.is_supported_array_type(a) and dpnp.is_supported_array_type(
indices
):
if indices.ndim != 1 or not dpnp.issubdtype(
indices.dtype, dpnp.integer
):
pass
elif mode not in ("clip", "wrap"):
pass
elif axis is not None and not isinstance(axis, int):
raise TypeError(f"`axis` must be of integer type, got {type(axis)}")
# TODO: remove when #1382(dpctl) is solved
elif dpnp.is_supported_array_type(vals) and a.dtype != vals.dtype:
pass
else:
if axis is None and a.ndim > 1:
a = dpnp.reshape(a, -1)
dpt_array = dpnp.get_usm_ndarray(a)
dpt_indices = dpnp.get_usm_ndarray(indices)
dpt_vals = (
dpnp.get_usm_ndarray(vals)
if isinstance(vals, dpnp_array)
else vals
)
return dpt.put(
dpt_array, dpt_indices, dpt_vals, axis=axis, mode=mode
)
dpnp.check_supported_arrays_type(a)

if not dpnp.is_supported_array_type(ind):
ind = dpnp.asarray(
ind, dtype=dpnp.intp, sycl_queue=a.sycl_queue, usm_type=a.usm_type
)
elif not dpnp.issubdtype(ind.dtype, dpnp.integer):
ind = dpnp.astype(ind, dtype=dpnp.intp, casting="safe")
ind = dpnp.ravel(ind)

if not dpnp.is_supported_array_type(v):
v = dpnp.asarray(
v, dtype=a.dtype, sycl_queue=a.sycl_queue, usm_type=a.usm_type
)
if v.size == 0:
return

if not (axis is None or isinstance(axis, int)):
raise TypeError(f"`axis` must be of integer type, got {type(axis)}")

in_a = a
if axis is None and a.ndim > 1:
a = dpnp.ravel(in_a)

if mode not in ("wrap", "clip"):
raise ValueError(
f"clipmode must be one of 'clip' or 'wrap' (got '{mode}')"
)

return call_origin(numpy.put, a, indices, vals, mode, dpnp_inplace=True)
usm_a = dpnp.get_usm_ndarray(a)
usm_ind = dpnp.get_usm_ndarray(ind)
usm_v = dpnp.get_usm_ndarray(v)
dpt.put(usm_a, usm_ind, usm_v, axis=axis, mode=mode)
if in_a is not a:
in_a[:] = a.reshape(in_a.shape, copy=False)


# pylint: disable=redefined-outer-name
Expand Down Expand Up @@ -1194,7 +1206,7 @@ def triu_indices(n, k=0, m=None):
-------
inds : tuple, shape(2) of ndarrays, shape(`n`)
The indices for the triangle. The returned tuple contains two arrays,
each with the indices along one dimension of the array. Can be used
each with the indices along one dimension of the array. Can be used
to slice a ndarray of shape(`n`, `n`).
"""

Expand Down
8 changes: 0 additions & 8 deletions tests/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,14 +91,6 @@ def get_integer_dtypes():
return [dpnp.int32, dpnp.int64]


def get_integer_dtypes():
"""
Build a list of integer types supported by DPNP.
"""

return [dpnp.int32, dpnp.int64]


def get_complex_dtypes(device=None):
"""
Build a list of complex types supported by DPNP based on device capabilities.
Expand Down
Loading

0 comments on commit 614af33

Please sign in to comment.