Skip to content

Changes to integer indexing modes #1132

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Mar 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 25 additions & 22 deletions dpctl/tensor/_indexing_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,18 @@
from ._copy_utils import _extract_impl, _nonzero_impl


def take(x, indices, /, *, axis=None, mode="clip"):
"""take(x, indices, axis=None, mode="clip")
def _get_indexing_mode(name):
modes = {"wrap": 0, "clip": 1}
try:
return modes[name]
except KeyError:
raise ValueError(
"`mode` must be `wrap` or `clip`." "Got `{}`.".format(name)
)


def take(x, indices, /, *, axis=None, mode="wrap"):
"""take(x, indices, axis=None, mode="wrap")

Takes elements from array along a given axis.

Expand All @@ -42,15 +52,15 @@ def take(x, indices, /, *, axis=None, mode="clip"):
Default: `None`.
mode:
How out-of-bounds indices will be handled.
"clip" - clamps indices to (-n <= i < n), then wraps
"wrap" - clamps indices to (-n <= i < n), then wraps
negative indices.
"wrap" - wraps both negative and positive indices.
Default: `"clip"`.
"clip" - clips indices to (0 <= i < n)
Default: `"wrap"`.

Returns:
out: usm_ndarray
Array with shape x.shape[:axis] + indices.shape + x.shape[axis + 1:]
filled with elements .
filled with elements from x.
"""
if not isinstance(x, dpt.usm_ndarray):
raise TypeError(
Expand Down Expand Up @@ -80,11 +90,7 @@ def take(x, indices, /, *, axis=None, mode="clip"):
[x.usm_type, indices.usm_type]
)

modes = {"clip": 0, "wrap": 1}
try:
mode = modes[mode]
except KeyError:
raise ValueError("`mode` must be `clip` or `wrap`.")
mode = _get_indexing_mode(mode)

x_ndim = x.ndim
if axis is None:
Expand Down Expand Up @@ -114,8 +120,8 @@ def take(x, indices, /, *, axis=None, mode="clip"):
return res


def put(x, indices, vals, /, *, axis=None, mode="clip"):
"""put(x, indices, vals, axis=None, mode="clip")
def put(x, indices, vals, /, *, axis=None, mode="wrap"):
"""put(x, indices, vals, axis=None, mode="wrap")

Puts values of an array into another array
along a given axis.
Expand All @@ -134,10 +140,10 @@ def put(x, indices, vals, /, *, axis=None, mode="clip"):
Default: `None`.
mode:
How out-of-bounds indices will be handled.
"clip" - clamps indices to (-axis_size <= i < axis_size),
then wraps negative indices.
"wrap" - wraps both negative and positive indices.
Default: `"clip"`.
"wrap" - clamps indices to (-n <= i < n), then wraps
negative indices.
"clip" - clips indices to (0 <= i < n)
Default: `"wrap"`.
"""
if not isinstance(x, dpt.usm_ndarray):
raise TypeError(
Expand Down Expand Up @@ -175,11 +181,8 @@ def put(x, indices, vals, /, *, axis=None, mode="clip"):
if exec_q is None:
raise dpctl.utils.ExecutionPlacementError
vals_usm_type = dpctl.utils.get_coerced_usm_type(usm_types_)
modes = {"clip": 0, "wrap": 1}
try:
mode = modes[mode]
except KeyError:
raise ValueError("`mode` must be `clip` or `wrap`.")

mode = _get_indexing_mode(mode)

x_ndim = x.ndim
if axis is None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,10 @@ namespace py = pybind11;
template <typename ProjectorT, typename Ty, typename indT> class take_kernel;
template <typename ProjectorT, typename Ty, typename indT> class put_kernel;

class ClipIndex
class WrapIndex
{
public:
ClipIndex() = default;
WrapIndex() = default;

void operator()(py::ssize_t max_item, py::ssize_t &ind) const
{
Expand All @@ -60,16 +60,15 @@ class ClipIndex
}
};

class WrapIndex
class ClipIndex
{
public:
WrapIndex() = default;
ClipIndex() = default;

void operator()(py::ssize_t max_item, py::ssize_t &ind) const
{
max_item = std::max<py::ssize_t>(max_item, 1);
ind = (ind < 0) ? (ind + max_item * ((-ind / max_item) + 1)) % max_item
: ind % max_item;
ind = std::clamp<py::ssize_t>(ind, 0, max_item - 1);
return;
}
};
Expand Down
12 changes: 6 additions & 6 deletions dpctl/tensor/libtensor/source/integer_advanced_indexing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@
#include "integer_advanced_indexing.hpp"

#define INDEXING_MODES 2
#define CLIP_MODE 0
#define WRAP_MODE 1
#define WRAP_MODE 0
#define CLIP_MODE 1

namespace dpctl
{
Expand Down Expand Up @@ -252,8 +252,8 @@ usm_ndarray_take(dpctl::tensor::usm_ndarray src,
throw py::value_error("Axis cannot be negative.");
}

if (mode != 0 && mode != 1) {
throw py::value_error("Mode must be 0 or 1.");
if (mode != 0 && mode != 1 && mode != 2) {
throw py::value_error("Mode must be 0, 1, or 2.");
}

const dpctl::tensor::usm_ndarray ind_rep = ind[0];
Expand Down Expand Up @@ -575,8 +575,8 @@ usm_ndarray_put(dpctl::tensor::usm_ndarray dst,
throw py::value_error("Axis cannot be negative.");
}

if (mode != 0 && mode != 1) {
throw py::value_error("Mode must be 0 or 1.");
if (mode != 0 && mode != 1 && mode != 2) {
throw py::value_error("Mode must be 0, 1, or 2.");
}

if (!dst.is_writable()) {
Expand Down
64 changes: 57 additions & 7 deletions dpctl/tests/test_usm_ndarray_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from helper import get_queue_or_skip, skip_if_dtype_not_supported
from numpy.testing import assert_array_equal

import dpctl
import dpctl.tensor as dpt
from dpctl.utils import ExecutionPlacementError

Expand Down Expand Up @@ -895,20 +896,21 @@ def test_integer_indexing_modes():
q = get_queue_or_skip()

x = dpt.arange(5, sycl_queue=q)
x_np = dpt.asnumpy(x)

# wrapping negative indices
ind = dpt.asarray([-4, -3, 0, 2, 4], dtype=np.intp, sycl_queue=q)

# wrapping
ind = dpt.asarray([-6, -3, 0, 2, 6], dtype=np.intp, sycl_queue=q)
res = dpt.take(x, ind, mode="wrap")
expected_arr = np.take(dpt.asnumpy(x), dpt.asnumpy(ind), mode="wrap")
expected_arr = np.take(x_np, dpt.asnumpy(ind), mode="raise")

assert (dpt.asnumpy(res) == expected_arr).all()

# clipping to -n<=i<n,
# where n is the axis length
ind = dpt.asarray([-4, -3, 0, 2, 4], dtype=np.intp, sycl_queue=q)
# clipping to 0 (disabling negative indices)
ind = dpt.asarray([-6, -3, 0, 2, 6], dtype=np.intp, sycl_queue=q)

res = dpt.take(x, ind, mode="clip")
expected_arr = np.take(dpt.asnumpy(x), dpt.asnumpy(ind), mode="raise")
expected_arr = np.take(x_np, dpt.asnumpy(ind), mode="clip")

assert (dpt.asnumpy(res) == expected_arr).all()

Expand Down Expand Up @@ -939,6 +941,10 @@ def test_take_arg_validation():
dpt.take(dpt.reshape(x, (2, 2)), ind0, axis=None)
with pytest.raises(ValueError):
dpt.take(x, dpt.reshape(ind0, (2, 2)))
with pytest.raises(ValueError):
dpt.take(x[0], ind0, axis=2)
with pytest.raises(ValueError):
dpt.take(x[:, dpt.newaxis, dpt.newaxis], ind0, axis=None)


def test_put_arg_validation():
Expand Down Expand Up @@ -968,6 +974,10 @@ def test_put_arg_validation():
dpt.put(x, ind0, val, mode=0)
with pytest.raises(ValueError):
dpt.put(x, dpt.reshape(ind0, (2, 2)), val)
with pytest.raises(ValueError):
dpt.put(x[0], ind0, val, axis=2)
with pytest.raises(ValueError):
dpt.put(x[:, dpt.newaxis, dpt.newaxis], ind0, val, axis=None)


def test_advanced_indexing_compute_follows_data():
Expand Down Expand Up @@ -1269,3 +1279,43 @@ def test_nonzero_large():

m = dpt.full((30, 60, 80), True)
assert m[m].size == m.size


def test_extract_arg_validation():
get_queue_or_skip()
with pytest.raises(TypeError):
dpt.extract(None, None)
cond = dpt.ones(10, dtype="?")
with pytest.raises(TypeError):
dpt.extract(cond, None)
q1 = dpctl.SyclQueue()
with pytest.raises(ExecutionPlacementError):
dpt.extract(cond.to_device(q1), dpt.zeros_like(cond, dtype="u1"))
with pytest.raises(ValueError):
dpt.extract(dpt.ones((2, 3), dtype="?"), dpt.ones((3, 2), dtype="i1"))


def test_place_arg_validation():
get_queue_or_skip()
with pytest.raises(TypeError):
dpt.place(None, None, None)
arr = dpt.zeros(8, dtype="i1")
with pytest.raises(TypeError):
dpt.place(arr, None, None)
cond = dpt.ones(8, dtype="?")
with pytest.raises(TypeError):
dpt.place(arr, cond, None)
vals = dpt.ones_like(arr)
q1 = dpctl.SyclQueue()
with pytest.raises(ExecutionPlacementError):
dpt.place(arr.to_device(q1), cond, vals)
with pytest.raises(ValueError):
dpt.place(dpt.reshape(arr, (2, 2, 2)), cond, vals)


def test_nonzero_arg_validation():
get_queue_or_skip()
with pytest.raises(TypeError):
dpt.nonzero(list())
with pytest.raises(ValueError):
dpt.nonzero(dpt.asarray(1))