Skip to content

Commit 0a7ea0c

Browse files
committed
dpt.take and dpt.put changes
- Improved conformity to array API standard - Added docstrings
1 parent cb32c6f commit 0a7ea0c

File tree

2 files changed

+117
-89
lines changed

2 files changed

+117
-89
lines changed

dpctl/tensor/_indexing_functions.py

Lines changed: 103 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -27,43 +27,56 @@
2727

2828

2929
def take(x, indices, /, *, axis=None, mode="clip"):
30+
"""take(x, indices, axis=None, mode="clip")
31+
32+
Takes elements from array along a given axis.
33+
34+
Args:
35+
x: usm_ndarray
36+
The array that elements will be taken from.
37+
indices: usm_ndarray
38+
One-dimensional array of indices.
39+
axis:
40+
The axis over which the values will be selected.
41+
If x is one-dimensional, this argument is optional.
42+
mode:
43+
How out-of-bounds indices will be handled.
44+
"Clip" - clamps indices to (-n <= i < n), then wraps
45+
negative indices.
46+
"Wrap" - wraps both negative and positive indices.
47+
48+
Returns:
49+
out: usm_ndarray
50+
Array with shape x.shape[:axis] + indices.shape + x.shape[axis + 1:]
51+
filled with elements .
52+
"""
3053
if not isinstance(x, dpt.usm_ndarray):
3154
raise TypeError(
3255
"Expected instance of `dpt.usm_ndarray`, got `{}`.".format(type(x))
3356
)
3457

35-
if not isinstance(indices, list) and not isinstance(indices, tuple):
36-
indices = (indices,)
37-
38-
queues_ = [
39-
x.sycl_queue,
40-
]
41-
usm_types_ = [
42-
x.usm_type,
43-
]
44-
45-
for i in indices:
46-
if not isinstance(i, dpt.usm_ndarray):
47-
raise TypeError(
48-
"`indices` expected `dpt.usm_ndarray`, got `{}`.".format(
49-
type(i)
50-
)
58+
if not isinstance(indices, dpt.usm_ndarray):
59+
raise TypeError(
60+
"`indices` expected `dpt.usm_ndarray`, got `{}`.".format(
61+
type(indices)
5162
)
52-
if not np.issubdtype(i.dtype, np.integer):
53-
raise IndexError(
54-
"`indices` expected integer data type, got `{}`".format(i.dtype)
63+
)
64+
if not np.issubdtype(indices.dtype, np.integer):
65+
raise IndexError(
66+
"`indices` expected integer data type, got `{}`".format(
67+
indices.dtype
5568
)
56-
queues_.append(i.sycl_queue)
57-
usm_types_.append(i.usm_type)
58-
exec_q = dpctl.utils.get_execution_queue(queues_)
59-
if exec_q is None:
60-
raise dpctl.utils.ExecutionPlacementError(
61-
"Can not automatically determine where to allocate the "
62-
"result or performance execution. "
63-
"Use `usm_ndarray.to_device` method to migrate data to "
64-
"be associated with the same queue."
6569
)
66-
res_usm_type = dpctl.utils.get_coerced_usm_type(usm_types_)
70+
if indices.ndim != 1:
71+
raise ValueError(
72+
"`indices` expected a 1D array, got `{}`".format(indices.ndim)
73+
)
74+
exec_q = dpctl.utils.get_execution_queue([x.sycl_queue, indices.sycl_queue])
75+
if exec_q is None:
76+
raise dpctl.utils.ExecutionPlacementError
77+
res_usm_type = dpctl.utils.get_coerced_usm_type(
78+
[x.usm_type, indices.usm_type]
79+
)
6780

6881
modes = {"clip": 0, "wrap": 1}
6982
try:
@@ -81,27 +94,47 @@ def take(x, indices, /, *, axis=None, mode="clip"):
8194
)
8295
axis = 0
8396

84-
if len(indices) > 1:
85-
indices = dpt.broadcast_arrays(*indices)
8697
if x_ndim > 0:
8798
axis = normalize_axis_index(operator.index(axis), x_ndim)
88-
res_shape = (
89-
x.shape[:axis] + indices[0].shape + x.shape[axis + len(indices) :]
90-
)
99+
res_shape = x.shape[:axis] + indices.shape + x.shape[axis + 1 :]
91100
else:
92-
res_shape = indices[0].shape
101+
if axis != 0:
102+
raise ValueError("`axis` must be 0 for an array of dimension 0.")
103+
res_shape = indices.shape
93104

94105
res = dpt.empty(
95106
res_shape, dtype=x.dtype, usm_type=res_usm_type, sycl_queue=exec_q
96107
)
97108

98-
hev, _ = ti._take(x, indices, res, axis, mode, sycl_queue=exec_q)
109+
hev, _ = ti._take(x, (indices,), res, axis, mode, sycl_queue=exec_q)
99110
hev.wait()
100111

101112
return res
102113

103114

104115
def put(x, indices, vals, /, *, axis=None, mode="clip"):
116+
"""put(x, indices, vals, axis=None, mode="clip")
117+
118+
Puts values of an array into another array
119+
along a given axis.
120+
121+
Args:
122+
x: usm_ndarray
123+
The array the values will be put into.
124+
indices: usm_ndarray
125+
One-dimensional array of indices.
126+
vals:
127+
Array of values to be put into `x`.
128+
Must be broadcastable to the shape of `indices`.
129+
axis:
130+
The axis over which the values will be placed.
131+
If x is one-dimensional, this argument is optional.
132+
mode:
133+
How out-of-bounds indices will be handled.
134+
"Clip" - clamps indices to (-axis_size <= i < axis_size),
135+
then wraps negative indices.
136+
"Wrap" - wraps both negative and positive indices.
137+
"""
105138
if not isinstance(x, dpt.usm_ndarray):
106139
raise TypeError(
107140
"Expected instance of `dpt.usm_ndarray`, got `{}`.".format(type(x))
@@ -116,66 +149,61 @@ def put(x, indices, vals, /, *, axis=None, mode="clip"):
116149
usm_types_ = [
117150
x.usm_type,
118151
]
119-
120-
if not isinstance(indices, list) and not isinstance(indices, tuple):
121-
indices = (indices,)
122-
123-
for i in indices:
124-
if not isinstance(i, dpt.usm_ndarray):
125-
raise TypeError(
126-
"`indices` expected `dpt.usm_ndarray`, got `{}`.".format(
127-
type(i)
128-
)
152+
if not isinstance(indices, dpt.usm_ndarray):
153+
raise TypeError(
154+
"`indices` expected `dpt.usm_ndarray`, got `{}`.".format(
155+
type(indices)
129156
)
130-
if not np.issubdtype(i.dtype, np.integer):
131-
raise IndexError(
132-
"`indices` expected integer data type, got `{}`".format(i.dtype)
157+
)
158+
if indices.ndim != 1:
159+
raise ValueError(
160+
"`indices` expected a 1D array, got `{}`".format(indices.ndim)
161+
)
162+
if not np.issubdtype(indices.dtype, np.integer):
163+
raise IndexError(
164+
"`indices` expected integer data type, got `{}`".format(
165+
indices.dtype
133166
)
134-
queues_.append(i.sycl_queue)
135-
usm_types_.append(i.usm_type)
167+
)
168+
queues_.append(indices.sycl_queue)
169+
usm_types_.append(indices.usm_type)
136170
exec_q = dpctl.utils.get_execution_queue(queues_)
137171
if exec_q is None:
138-
raise dpctl.utils.ExecutionPlacementError(
139-
"Can not automatically determine where to allocate the "
140-
"result or performance execution. "
141-
"Use `usm_ndarray.to_device` method to migrate data to "
142-
"be associated with the same queue."
143-
)
144-
val_usm_type = dpctl.utils.get_coerced_usm_type(usm_types_)
145-
172+
raise dpctl.utils.ExecutionPlacementError
173+
vals_usm_type = dpctl.utils.get_coerced_usm_type(usm_types_)
146174
modes = {"clip": 0, "wrap": 1}
147175
try:
148176
mode = modes[mode]
149177
except KeyError:
150-
raise ValueError("`mode` must be `wrap`, or `clip`.")
178+
raise ValueError("`mode` must be `clip` or `wrap`.")
151179

152-
# when axis is none, array is treated as 1D
153-
if axis is None:
154-
try:
155-
x = dpt.reshape(x, (x.size,), copy=False)
156-
axis = 0
157-
except ValueError:
158-
raise ValueError("Cannot create 1D view of input array")
159-
if len(indices) > 1:
160-
indices = dpt.broadcast_arrays(*indices)
161180
x_ndim = x.ndim
181+
if axis is None:
182+
if x_ndim > 1:
183+
raise ValueError(
184+
"`axis` cannot be `None` for array of dimension `{}`".format(
185+
x_ndim
186+
)
187+
)
188+
axis = 0
189+
162190
if x_ndim > 0:
163191
axis = normalize_axis_index(operator.index(axis), x_ndim)
164192

165-
val_shape = (
166-
x.shape[:axis] + indices[0].shape + x.shape[axis + len(indices) :]
167-
)
193+
val_shape = x.shape[:axis] + indices.shape + x.shape[axis + 1 :]
168194
else:
169-
val_shape = indices[0].shape
195+
if axis != 0:
196+
raise ValueError("`axis` must be 0 for an array of dimension 0.")
197+
val_shape = indices.shape
170198

171199
if not isinstance(vals, dpt.usm_ndarray):
172200
vals = dpt.asarray(
173-
vals, dtype=x.dtype, usm_type=val_usm_type, sycl_queue=exec_q
201+
vals, dtype=x.dtype, usm_type=vals_usm_type, sycl_queue=exec_q
174202
)
175203

176204
vals = dpt.broadcast_to(vals, val_shape)
177205

178-
hev, _ = ti._put(x, indices, vals, axis, mode, sycl_queue=exec_q)
206+
hev, _ = ti._put(x, (indices,), vals, axis, mode, sycl_queue=exec_q)
179207
hev.wait()
180208

181209

dpctl/tests/test_usm_ndarray_indexing.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -542,11 +542,11 @@ def test_put_0d_val(data_dt):
542542

543543
x = dpt.arange(5, dtype=data_dt, sycl_queue=q)
544544
ind = dpt.asarray([0], dtype=np.intp, sycl_queue=q)
545-
x[ind] = 2
545+
val = dpt.asarray(2, dtype=x.dtype, sycl_queue=q)
546+
x[ind] = val
546547
assert_array_equal(np.asarray(2, dtype=data_dt), dpt.asnumpy(x[0]))
547548

548549
x = dpt.asarray(5, dtype=data_dt, sycl_queue=q)
549-
val = 2
550550
dpt.put(x, ind, val)
551551
assert_array_equal(np.asarray(2, dtype=data_dt), dpt.asnumpy(x))
552552

@@ -592,13 +592,13 @@ def test_put_0d_data(data_dt):
592592
"ind_dt",
593593
_all_int_dtypes,
594594
)
595-
def test_take_0d_ind(ind_dt):
595+
def test_indexing_0d_ind(ind_dt):
596596
q = get_queue_or_skip()
597597

598598
x = dpt.arange(5, dtype="i4", sycl_queue=q)
599599
ind = dpt.asarray(3, dtype=ind_dt, sycl_queue=q)
600600

601-
y = dpt.take(x, ind)
601+
y = x[ind]
602602
assert dpt.asnumpy(x[3]) == dpt.asnumpy(y)
603603

604604

@@ -613,7 +613,7 @@ def test_put_0d_ind(ind_dt):
613613
ind = dpt.asarray(3, dtype=ind_dt, sycl_queue=q)
614614
val = dpt.asarray(5, dtype=x.dtype, sycl_queue=q)
615615

616-
dpt.put(x, ind, val, axis=0)
616+
x[ind] = val
617617
assert dpt.asnumpy(x[3]) == dpt.asnumpy(val)
618618

619619

@@ -684,10 +684,6 @@ def test_take_strided(data_dt, order):
684684
np.take(xs_np, ind_np, axis=1),
685685
dpt.asnumpy(dpt.take(xs, ind, axis=1)),
686686
)
687-
assert_array_equal(
688-
xs_np[ind_np, ind_np],
689-
dpt.asnumpy(dpt.take(xs, [ind, ind], axis=0)),
690-
)
691687

692688

693689
@pytest.mark.parametrize(
@@ -751,7 +747,7 @@ def test_take_strided_indices(ind_dt, order):
751747
inds_np = ind_np[s, ::sgn]
752748
assert_array_equal(
753749
np.take(x_np, inds_np, axis=0),
754-
dpt.asnumpy(dpt.take(x, inds, axis=0)),
750+
dpt.asnumpy(x[inds]),
755751
)
756752

757753

@@ -828,7 +824,7 @@ def test_put_strided_destination(data_dt, order):
828824
x_np1[ind_np, ind_np] = val_np
829825

830826
x1 = dpt.copy(xs)
831-
dpt.put(x1, [ind, ind], val, axis=0)
827+
x1[ind, ind] = val
832828
assert_array_equal(x_np1, dpt.asnumpy(x1))
833829

834830

@@ -887,7 +883,7 @@ def test_put_strided_indices(ind_dt, order):
887883
inds_np = ind_np[s, ::sgn]
888884

889885
x_copy = dpt.copy(x)
890-
dpt.put(x_copy, inds, val, axis=0)
886+
x_copy[inds] = val
891887

892888
x_np_copy = x_np.copy()
893889
x_np_copy[inds_np] = val_np
@@ -899,7 +895,7 @@ def test_take_arg_validation():
899895
q = get_queue_or_skip()
900896

901897
x = dpt.arange(4, dtype="i4", sycl_queue=q)
902-
ind0 = dpt.arange(2, dtype=np.intp, sycl_queue=q)
898+
ind0 = dpt.arange(4, dtype=np.intp, sycl_queue=q)
903899
ind1 = dpt.arange(2.0, dtype="f", sycl_queue=q)
904900

905901
with pytest.raises(TypeError):
@@ -919,13 +915,15 @@ def test_take_arg_validation():
919915
dpt.take(x, ind0, mode=0)
920916
with pytest.raises(ValueError):
921917
dpt.take(dpt.reshape(x, (2, 2)), ind0, axis=None)
918+
with pytest.raises(ValueError):
919+
dpt.take(x, dpt.reshape(ind0, (2, 2)))
922920

923921

924922
def test_put_arg_validation():
925923
q = get_queue_or_skip()
926924

927925
x = dpt.arange(4, dtype="i4", sycl_queue=q)
928-
ind0 = dpt.arange(2, dtype=np.intp, sycl_queue=q)
926+
ind0 = dpt.arange(4, dtype=np.intp, sycl_queue=q)
929927
ind1 = dpt.arange(2.0, dtype="f", sycl_queue=q)
930928
val = dpt.asarray(2, x.dtype, sycl_queue=q)
931929

@@ -946,6 +944,8 @@ def test_put_arg_validation():
946944

947945
with pytest.raises(ValueError):
948946
dpt.put(x, ind0, val, mode=0)
947+
with pytest.raises(ValueError):
948+
dpt.put(x, dpt.reshape(ind0, (2, 2)), val)
949949

950950

951951
def test_advanced_indexing_compute_follows_data():

0 commit comments

Comments
 (0)