Skip to content

dpt.take and dpt.put changes #1099

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
178 changes: 103 additions & 75 deletions dpctl/tensor/_indexing_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,43 +27,56 @@


def take(x, indices, /, *, axis=None, mode="clip"):
"""take(x, indices, axis=None, mode="clip")

Takes elements from array along a given axis.

Args:
x: usm_ndarray
The array that elements will be taken from.
indices: usm_ndarray
One-dimensional array of indices.
axis:
The axis over which the values will be selected.
If x is one-dimensional, this argument is optional.
mode:
How out-of-bounds indices will be handled.
"Clip" - clamps indices to (-n <= i < n), then wraps
negative indices.
"Wrap" - wraps both negative and positive indices.

Returns:
out: usm_ndarray
Array with shape x.shape[:axis] + indices.shape + x.shape[axis + 1:]
filled with elements .
"""
if not isinstance(x, dpt.usm_ndarray):
raise TypeError(
"Expected instance of `dpt.usm_ndarray`, got `{}`.".format(type(x))
)

if not isinstance(indices, list) and not isinstance(indices, tuple):
indices = (indices,)

queues_ = [
x.sycl_queue,
]
usm_types_ = [
x.usm_type,
]

for i in indices:
if not isinstance(i, dpt.usm_ndarray):
raise TypeError(
"`indices` expected `dpt.usm_ndarray`, got `{}`.".format(
type(i)
)
if not isinstance(indices, dpt.usm_ndarray):
raise TypeError(
"`indices` expected `dpt.usm_ndarray`, got `{}`.".format(
type(indices)
)
if not np.issubdtype(i.dtype, np.integer):
raise IndexError(
"`indices` expected integer data type, got `{}`".format(i.dtype)
)
if not np.issubdtype(indices.dtype, np.integer):
raise IndexError(
"`indices` expected integer data type, got `{}`".format(
indices.dtype
)
queues_.append(i.sycl_queue)
usm_types_.append(i.usm_type)
exec_q = dpctl.utils.get_execution_queue(queues_)
if exec_q is None:
raise dpctl.utils.ExecutionPlacementError(
"Can not automatically determine where to allocate the "
"result or performance execution. "
"Use `usm_ndarray.to_device` method to migrate data to "
"be associated with the same queue."
)
res_usm_type = dpctl.utils.get_coerced_usm_type(usm_types_)
if indices.ndim != 1:
raise ValueError(
"`indices` expected a 1D array, got `{}`".format(indices.ndim)
)
exec_q = dpctl.utils.get_execution_queue([x.sycl_queue, indices.sycl_queue])
if exec_q is None:
raise dpctl.utils.ExecutionPlacementError
res_usm_type = dpctl.utils.get_coerced_usm_type(
[x.usm_type, indices.usm_type]
)

modes = {"clip": 0, "wrap": 1}
try:
Expand All @@ -81,27 +94,47 @@ def take(x, indices, /, *, axis=None, mode="clip"):
)
axis = 0

if len(indices) > 1:
indices = dpt.broadcast_arrays(*indices)
if x_ndim > 0:
axis = normalize_axis_index(operator.index(axis), x_ndim)
res_shape = (
x.shape[:axis] + indices[0].shape + x.shape[axis + len(indices) :]
)
res_shape = x.shape[:axis] + indices.shape + x.shape[axis + 1 :]
else:
res_shape = indices[0].shape
if axis != 0:
raise ValueError("`axis` must be 0 for an array of dimension 0.")
res_shape = indices.shape

res = dpt.empty(
res_shape, dtype=x.dtype, usm_type=res_usm_type, sycl_queue=exec_q
)

hev, _ = ti._take(x, indices, res, axis, mode, sycl_queue=exec_q)
hev, _ = ti._take(x, (indices,), res, axis, mode, sycl_queue=exec_q)
hev.wait()

return res


def put(x, indices, vals, /, *, axis=None, mode="clip"):
"""put(x, indices, vals, axis=None, mode="clip")

Puts values of an array into another array
along a given axis.

Args:
x: usm_ndarray
The array the values will be put into.
indices: usm_ndarray
One-dimensional array of indices.
vals:
Array of values to be put into `x`.
Must be broadcastable to the shape of `indices`.
axis:
The axis over which the values will be placed.
If x is one-dimensional, this argument is optional.
mode:
How out-of-bounds indices will be handled.
"Clip" - clamps indices to (-axis_size <= i < axis_size),
then wraps negative indices.
"Wrap" - wraps both negative and positive indices.
"""
if not isinstance(x, dpt.usm_ndarray):
raise TypeError(
"Expected instance of `dpt.usm_ndarray`, got `{}`.".format(type(x))
Expand All @@ -116,66 +149,61 @@ def put(x, indices, vals, /, *, axis=None, mode="clip"):
usm_types_ = [
x.usm_type,
]

if not isinstance(indices, list) and not isinstance(indices, tuple):
indices = (indices,)

for i in indices:
if not isinstance(i, dpt.usm_ndarray):
raise TypeError(
"`indices` expected `dpt.usm_ndarray`, got `{}`.".format(
type(i)
)
if not isinstance(indices, dpt.usm_ndarray):
raise TypeError(
"`indices` expected `dpt.usm_ndarray`, got `{}`.".format(
type(indices)
)
if not np.issubdtype(i.dtype, np.integer):
raise IndexError(
"`indices` expected integer data type, got `{}`".format(i.dtype)
)
if indices.ndim != 1:
raise ValueError(
"`indices` expected a 1D array, got `{}`".format(indices.ndim)
)
if not np.issubdtype(indices.dtype, np.integer):
raise IndexError(
"`indices` expected integer data type, got `{}`".format(
indices.dtype
)
queues_.append(i.sycl_queue)
usm_types_.append(i.usm_type)
)
queues_.append(indices.sycl_queue)
usm_types_.append(indices.usm_type)
exec_q = dpctl.utils.get_execution_queue(queues_)
if exec_q is None:
raise dpctl.utils.ExecutionPlacementError(
"Can not automatically determine where to allocate the "
"result or performance execution. "
"Use `usm_ndarray.to_device` method to migrate data to "
"be associated with the same queue."
)
val_usm_type = dpctl.utils.get_coerced_usm_type(usm_types_)

raise dpctl.utils.ExecutionPlacementError
vals_usm_type = dpctl.utils.get_coerced_usm_type(usm_types_)
modes = {"clip": 0, "wrap": 1}
try:
mode = modes[mode]
except KeyError:
raise ValueError("`mode` must be `wrap`, or `clip`.")
raise ValueError("`mode` must be `clip` or `wrap`.")

# when axis is none, array is treated as 1D
if axis is None:
try:
x = dpt.reshape(x, (x.size,), copy=False)
axis = 0
except ValueError:
raise ValueError("Cannot create 1D view of input array")
if len(indices) > 1:
indices = dpt.broadcast_arrays(*indices)
x_ndim = x.ndim
if axis is None:
if x_ndim > 1:
raise ValueError(
"`axis` cannot be `None` for array of dimension `{}`".format(
x_ndim
)
)
axis = 0

if x_ndim > 0:
axis = normalize_axis_index(operator.index(axis), x_ndim)

val_shape = (
x.shape[:axis] + indices[0].shape + x.shape[axis + len(indices) :]
)
val_shape = x.shape[:axis] + indices.shape + x.shape[axis + 1 :]
else:
val_shape = indices[0].shape
if axis != 0:
raise ValueError("`axis` must be 0 for an array of dimension 0.")
val_shape = indices.shape

if not isinstance(vals, dpt.usm_ndarray):
vals = dpt.asarray(
vals, dtype=x.dtype, usm_type=val_usm_type, sycl_queue=exec_q
vals, dtype=x.dtype, usm_type=vals_usm_type, sycl_queue=exec_q
)

vals = dpt.broadcast_to(vals, val_shape)

hev, _ = ti._put(x, indices, vals, axis, mode, sycl_queue=exec_q)
hev, _ = ti._put(x, (indices,), vals, axis, mode, sycl_queue=exec_q)
hev.wait()


Expand Down
28 changes: 14 additions & 14 deletions dpctl/tests/test_usm_ndarray_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -542,11 +542,11 @@ def test_put_0d_val(data_dt):

x = dpt.arange(5, dtype=data_dt, sycl_queue=q)
ind = dpt.asarray([0], dtype=np.intp, sycl_queue=q)
x[ind] = 2
val = dpt.asarray(2, dtype=x.dtype, sycl_queue=q)
x[ind] = val
assert_array_equal(np.asarray(2, dtype=data_dt), dpt.asnumpy(x[0]))

x = dpt.asarray(5, dtype=data_dt, sycl_queue=q)
val = 2
dpt.put(x, ind, val)
assert_array_equal(np.asarray(2, dtype=data_dt), dpt.asnumpy(x))

Expand Down Expand Up @@ -592,13 +592,13 @@ def test_put_0d_data(data_dt):
"ind_dt",
_all_int_dtypes,
)
def test_take_0d_ind(ind_dt):
def test_indexing_0d_ind(ind_dt):
q = get_queue_or_skip()

x = dpt.arange(5, dtype="i4", sycl_queue=q)
ind = dpt.asarray(3, dtype=ind_dt, sycl_queue=q)

y = dpt.take(x, ind)
y = x[ind]
assert dpt.asnumpy(x[3]) == dpt.asnumpy(y)


Expand All @@ -613,7 +613,7 @@ def test_put_0d_ind(ind_dt):
ind = dpt.asarray(3, dtype=ind_dt, sycl_queue=q)
val = dpt.asarray(5, dtype=x.dtype, sycl_queue=q)

dpt.put(x, ind, val, axis=0)
x[ind] = val
assert dpt.asnumpy(x[3]) == dpt.asnumpy(val)


Expand Down Expand Up @@ -684,10 +684,6 @@ def test_take_strided(data_dt, order):
np.take(xs_np, ind_np, axis=1),
dpt.asnumpy(dpt.take(xs, ind, axis=1)),
)
assert_array_equal(
xs_np[ind_np, ind_np],
dpt.asnumpy(dpt.take(xs, [ind, ind], axis=0)),
)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -751,7 +747,7 @@ def test_take_strided_indices(ind_dt, order):
inds_np = ind_np[s, ::sgn]
assert_array_equal(
np.take(x_np, inds_np, axis=0),
dpt.asnumpy(dpt.take(x, inds, axis=0)),
dpt.asnumpy(x[inds]),
)


Expand Down Expand Up @@ -828,7 +824,7 @@ def test_put_strided_destination(data_dt, order):
x_np1[ind_np, ind_np] = val_np

x1 = dpt.copy(xs)
dpt.put(x1, [ind, ind], val, axis=0)
x1[ind, ind] = val
assert_array_equal(x_np1, dpt.asnumpy(x1))


Expand Down Expand Up @@ -887,7 +883,7 @@ def test_put_strided_indices(ind_dt, order):
inds_np = ind_np[s, ::sgn]

x_copy = dpt.copy(x)
dpt.put(x_copy, inds, val, axis=0)
x_copy[inds] = val

x_np_copy = x_np.copy()
x_np_copy[inds_np] = val_np
Expand All @@ -899,7 +895,7 @@ def test_take_arg_validation():
q = get_queue_or_skip()

x = dpt.arange(4, dtype="i4", sycl_queue=q)
ind0 = dpt.arange(2, dtype=np.intp, sycl_queue=q)
ind0 = dpt.arange(4, dtype=np.intp, sycl_queue=q)
ind1 = dpt.arange(2.0, dtype="f", sycl_queue=q)

with pytest.raises(TypeError):
Expand All @@ -919,13 +915,15 @@ def test_take_arg_validation():
dpt.take(x, ind0, mode=0)
with pytest.raises(ValueError):
dpt.take(dpt.reshape(x, (2, 2)), ind0, axis=None)
with pytest.raises(ValueError):
dpt.take(x, dpt.reshape(ind0, (2, 2)))


def test_put_arg_validation():
q = get_queue_or_skip()

x = dpt.arange(4, dtype="i4", sycl_queue=q)
ind0 = dpt.arange(2, dtype=np.intp, sycl_queue=q)
ind0 = dpt.arange(4, dtype=np.intp, sycl_queue=q)
ind1 = dpt.arange(2.0, dtype="f", sycl_queue=q)
val = dpt.asarray(2, x.dtype, sycl_queue=q)

Expand All @@ -946,6 +944,8 @@ def test_put_arg_validation():

with pytest.raises(ValueError):
dpt.put(x, ind0, val, mode=0)
with pytest.raises(ValueError):
dpt.put(x, dpt.reshape(ind0, (2, 2)), val)


def test_advanced_indexing_compute_follows_data():
Expand Down