Skip to content

Commit e652595

Browse files
committed
dpt.take and dpt.put changes
- Improved conformity to array API standard - Added docstrings
1 parent cb32c6f commit e652595

File tree

2 files changed

+134
-106
lines changed

2 files changed

+134
-106
lines changed

dpctl/tensor/_indexing_functions.py

Lines changed: 120 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -27,43 +27,56 @@
2727

2828

2929
def take(x, indices, /, *, axis=None, mode="clip"):
30+
"""take(x, indices, axis=None, mode="clip")
31+
32+
Takes elements from array along a given axis.
33+
34+
Args:
35+
x: usm_ndarray
36+
The array that elements will be taken from.
37+
indices: usm_ndarray
38+
One-dimensional array of indices.
39+
axis:
40+
The axis over which the values will be selected.
41+
If x is one-dimensional, this argument is optional.
42+
mode:
43+
How out-of-bounds indices will be handled.
44+
"Clip" - clamps indices to (-n <= i < n), then wraps
45+
negative indices.
46+
"Wrap" - wraps both negative and positive indices.
47+
48+
Returns:
49+
out: usm_ndarray
50+
Array with shape x.shape[:axis] + indices.shape + x.shape[axis + 1:]
51+
filled with elements .
52+
"""
3053
if not isinstance(x, dpt.usm_ndarray):
3154
raise TypeError(
3255
"Expected instance of `dpt.usm_ndarray`, got `{}`.".format(type(x))
3356
)
3457

35-
if not isinstance(indices, list) and not isinstance(indices, tuple):
36-
indices = (indices,)
37-
38-
queues_ = [
39-
x.sycl_queue,
40-
]
41-
usm_types_ = [
42-
x.usm_type,
43-
]
44-
45-
for i in indices:
46-
if not isinstance(i, dpt.usm_ndarray):
47-
raise TypeError(
48-
"`indices` expected `dpt.usm_ndarray`, got `{}`.".format(
49-
type(i)
50-
)
58+
if not isinstance(indices, dpt.usm_ndarray):
59+
raise TypeError(
60+
"`indices` expected `dpt.usm_ndarray`, got `{}`.".format(
61+
type(indices)
5162
)
52-
if not np.issubdtype(i.dtype, np.integer):
53-
raise IndexError(
54-
"`indices` expected integer data type, got `{}`".format(i.dtype)
63+
)
64+
if not np.issubdtype(indices.dtype, np.integer):
65+
raise IndexError(
66+
"`indices` expected integer data type, got `{}`".format(
67+
indices.dtype
5568
)
56-
queues_.append(i.sycl_queue)
57-
usm_types_.append(i.usm_type)
58-
exec_q = dpctl.utils.get_execution_queue(queues_)
59-
if exec_q is None:
60-
raise dpctl.utils.ExecutionPlacementError(
61-
"Can not automatically determine where to allocate the "
62-
"result or performance execution. "
63-
"Use `usm_ndarray.to_device` method to migrate data to "
64-
"be associated with the same queue."
6569
)
66-
res_usm_type = dpctl.utils.get_coerced_usm_type(usm_types_)
70+
if indices.ndim != 1:
71+
raise ValueError(
72+
"`indices` expected a 1D array, got `{}`".format(indices.ndim)
73+
)
74+
exec_q = dpctl.utils.get_execution_queue([x.sycl_queue, indices.sycl_queue])
75+
if exec_q is None:
76+
raise dpctl.utils.ExecutionPlacementError
77+
res_usm_type = dpctl.utils.get_coerced_usm_type(
78+
[x.usm_type, indices.usm_type]
79+
)
6780

6881
modes = {"clip": 0, "wrap": 1}
6982
try:
@@ -81,27 +94,47 @@ def take(x, indices, /, *, axis=None, mode="clip"):
8194
)
8295
axis = 0
8396

84-
if len(indices) > 1:
85-
indices = dpt.broadcast_arrays(*indices)
8697
if x_ndim > 0:
8798
axis = normalize_axis_index(operator.index(axis), x_ndim)
88-
res_shape = (
89-
x.shape[:axis] + indices[0].shape + x.shape[axis + len(indices) :]
90-
)
99+
res_shape = x.shape[:axis] + indices.shape + x.shape[axis + 1 :]
91100
else:
92-
res_shape = indices[0].shape
101+
if axis != 0:
102+
raise ValueError("`axis` must be 0 for an array of dimension 0.")
103+
res_shape = indices.shape
93104

94105
res = dpt.empty(
95106
res_shape, dtype=x.dtype, usm_type=res_usm_type, sycl_queue=exec_q
96107
)
97108

98-
hev, _ = ti._take(x, indices, res, axis, mode, sycl_queue=exec_q)
109+
hev, _ = ti._take(x, (indices,), res, axis, mode, sycl_queue=exec_q)
99110
hev.wait()
100111

101112
return res
102113

103114

104115
def put(x, indices, vals, /, *, axis=None, mode="clip"):
116+
"""put(x, indices, vals, axis=None, mode="clip")
117+
118+
Puts values of an array into another array
119+
along a given axis.
120+
121+
Args:
122+
x: usm_ndarray
123+
The array the values will be put into.
124+
indices: usm_ndarray
125+
One-dimensional array of indices.
126+
vals:
127+
Array of values to be put into `x`.
128+
Must be broadcastable to the shape of `indices`.
129+
axis:
130+
The axis over which the values will be placed.
131+
If x is one-dimensional, this argument is optional.
132+
mode:
133+
How out-of-bounds indices will be handled.
134+
"Clip" - clamps indices to (-axis_size <= i < axis_size),
135+
then wraps negative indices.
136+
"Wrap" - wraps both negative and positive indices.
137+
"""
105138
if not isinstance(x, dpt.usm_ndarray):
106139
raise TypeError(
107140
"Expected instance of `dpt.usm_ndarray`, got `{}`.".format(type(x))
@@ -116,66 +149,61 @@ def put(x, indices, vals, /, *, axis=None, mode="clip"):
116149
usm_types_ = [
117150
x.usm_type,
118151
]
119-
120-
if not isinstance(indices, list) and not isinstance(indices, tuple):
121-
indices = (indices,)
122-
123-
for i in indices:
124-
if not isinstance(i, dpt.usm_ndarray):
125-
raise TypeError(
126-
"`indices` expected `dpt.usm_ndarray`, got `{}`.".format(
127-
type(i)
128-
)
152+
if not isinstance(indices, dpt.usm_ndarray):
153+
raise TypeError(
154+
"`indices` expected `dpt.usm_ndarray`, got `{}`.".format(
155+
type(indices)
129156
)
130-
if not np.issubdtype(i.dtype, np.integer):
131-
raise IndexError(
132-
"`indices` expected integer data type, got `{}`".format(i.dtype)
157+
)
158+
if indices.ndim != 1:
159+
raise ValueError(
160+
"`indices` expected a 1D array, got `{}`".format(indices.ndim)
161+
)
162+
if not np.issubdtype(indices.dtype, np.integer):
163+
raise IndexError(
164+
"`indices` expected integer data type, got `{}`".format(
165+
indices.dtype
133166
)
134-
queues_.append(i.sycl_queue)
135-
usm_types_.append(i.usm_type)
167+
)
168+
queues_.append(indices.sycl_queue)
169+
usm_types_.append(indices.usm_type)
136170
exec_q = dpctl.utils.get_execution_queue(queues_)
137171
if exec_q is None:
138-
raise dpctl.utils.ExecutionPlacementError(
139-
"Can not automatically determine where to allocate the "
140-
"result or performance execution. "
141-
"Use `usm_ndarray.to_device` method to migrate data to "
142-
"be associated with the same queue."
143-
)
144-
val_usm_type = dpctl.utils.get_coerced_usm_type(usm_types_)
145-
172+
raise dpctl.utils.ExecutionPlacementError
173+
vals_usm_type = dpctl.utils.get_coerced_usm_type(usm_types_)
146174
modes = {"clip": 0, "wrap": 1}
147175
try:
148176
mode = modes[mode]
149177
except KeyError:
150-
raise ValueError("`mode` must be `wrap`, or `clip`.")
178+
raise ValueError("`mode` must be `clip` or `wrap`.")
151179

152-
# when axis is none, array is treated as 1D
153-
if axis is None:
154-
try:
155-
x = dpt.reshape(x, (x.size,), copy=False)
156-
axis = 0
157-
except ValueError:
158-
raise ValueError("Cannot create 1D view of input array")
159-
if len(indices) > 1:
160-
indices = dpt.broadcast_arrays(*indices)
161180
x_ndim = x.ndim
181+
if axis is None:
182+
if x_ndim > 1:
183+
raise ValueError(
184+
"`axis` cannot be `None` for array of dimension `{}`".format(
185+
x_ndim
186+
)
187+
)
188+
axis = 0
189+
162190
if x_ndim > 0:
163191
axis = normalize_axis_index(operator.index(axis), x_ndim)
164192

165-
val_shape = (
166-
x.shape[:axis] + indices[0].shape + x.shape[axis + len(indices) :]
167-
)
193+
val_shape = x.shape[:axis] + indices.shape + x.shape[axis + 1 :]
168194
else:
169-
val_shape = indices[0].shape
195+
if axis != 0:
196+
raise ValueError("`axis` must be 0 for an array of dimension 0.")
197+
val_shape = indices.shape
170198

171199
if not isinstance(vals, dpt.usm_ndarray):
172200
vals = dpt.asarray(
173-
vals, dtype=x.dtype, usm_type=val_usm_type, sycl_queue=exec_q
201+
vals, dtype=x.dtype, usm_type=vals_usm_type, sycl_queue=exec_q
174202
)
175203

176204
vals = dpt.broadcast_to(vals, val_shape)
177205

178-
hev, _ = ti._put(x, indices, vals, axis, mode, sycl_queue=exec_q)
206+
hev, _ = ti._put(x, (indices,), vals, axis, mode, sycl_queue=exec_q)
179207
hev.wait()
180208

181209

@@ -192,14 +220,14 @@ def extract(condition, arr):
192220
193221
Args:
194222
conditions: usm_ndarray
195-
An array whose non-zero or True entries indicate the element
196-
of `arr` to extract.
223+
An array whose non-zero or True entries indicate the element
224+
of `arr` to extract.
197225
arr: usm_ndarray
198-
Input array of the same size as `condition`.
226+
Input array of the same size as `condition`.
199227
200228
Returns:
201-
usm_ndarray
202-
Rank 1 array of values from `arr` where `condition` is True.
229+
extract: usm_ndarray
230+
Rank 1 array of values from `arr` where `condition` is True.
203231
"""
204232
if not isinstance(condition, dpt.usm_ndarray):
205233
raise TypeError(
@@ -231,16 +259,16 @@ def place(arr, mask, vals):
231259
equivalent to ``arr[condition] = vals``.
232260
233261
Args:
234-
arr: usm_ndarray
235-
Array to put data into.
262+
arr: usm_ndarray
263+
Array to put data into.
236264
mask: usm_ndarray
237-
Boolean mask array. Must have the same size as `arr`.
265+
Boolean mask array. Must have the same size as `arr`.
238266
vals: usm_ndarray
239-
Values to put into `arr`. Only the first N elements are
240-
used, where N is the number of True values in `mask`. If
241-
`vals` is smaller than N, it will be repeated, and if
242-
elements of `arr` are to be masked, this sequence must be
243-
non-empty. Array `vals` must be one dimensional.
267+
Values to put into `arr`. Only the first N elements are
268+
used, where N is the number of True values in `mask`. If
269+
`vals` is smaller than N, it will be repeated, and if
270+
elements of `arr` are to be masked, this sequence must be
271+
non-empty. Array `vals` must be one dimensional.
244272
"""
245273
if not isinstance(arr, dpt.usm_ndarray):
246274
raise TypeError(
@@ -295,11 +323,11 @@ def nonzero(arr):
295323
row-major, C-style order.
296324
297325
Args:
298-
arr: usm_ndarray
299-
Input array, which has non-zero array rank.
326+
arr: usm_ndarray
327+
Input array, which has non-zero array rank.
300328
Returns:
301-
Tuple[usm_ndarray]
302-
Indices of non-zero array elements.
329+
tuple_of_usm_ndarrays: tuple
330+
Indices of non-zero array elements.
303331
"""
304332
if not isinstance(arr, dpt.usm_ndarray):
305333
raise TypeError(

0 commit comments

Comments
 (0)