27
27
28
28
29
29
def take (x , indices , / , * , axis = None , mode = "clip" ):
30
+ """take(x, indices, axis=None, mode="clip")
31
+
32
+ Takes elements from array along a given axis.
33
+
34
+ Args:
35
+ x: usm_ndarray
36
+ The array that elements will be taken from.
37
+ indices: usm_ndarray
38
+ One-dimensional array of indices.
39
+ axis:
40
+ The axis over which the values will be selected.
41
+ If x is one-dimensional, this argument is optional.
42
+ mode:
43
+ How out-of-bounds indices will be handled.
44
+ "Clip" - clamps indices to (-n <= i < n), then wraps
45
+ negative indices.
46
+ "Wrap" - wraps both negative and positive indices.
47
+
48
+ Returns:
49
+ out: usm_ndarray
50
+ Array with shape x.shape[:axis] + indices.shape + x.shape[axis + 1:]
51
+ filled with elements .
52
+ """
30
53
if not isinstance (x , dpt .usm_ndarray ):
31
54
raise TypeError (
32
55
"Expected instance of `dpt.usm_ndarray`, got `{}`." .format (type (x ))
33
56
)
34
57
35
- if not isinstance (indices , list ) and not isinstance (indices , tuple ):
36
- indices = (indices ,)
37
-
38
- queues_ = [
39
- x .sycl_queue ,
40
- ]
41
- usm_types_ = [
42
- x .usm_type ,
43
- ]
44
-
45
- for i in indices :
46
- if not isinstance (i , dpt .usm_ndarray ):
47
- raise TypeError (
48
- "`indices` expected `dpt.usm_ndarray`, got `{}`." .format (
49
- type (i )
50
- )
58
+ if not isinstance (indices , dpt .usm_ndarray ):
59
+ raise TypeError (
60
+ "`indices` expected `dpt.usm_ndarray`, got `{}`." .format (
61
+ type (indices )
51
62
)
52
- if not np .issubdtype (i .dtype , np .integer ):
53
- raise IndexError (
54
- "`indices` expected integer data type, got `{}`" .format (i .dtype )
63
+ )
64
+ if not np .issubdtype (indices .dtype , np .integer ):
65
+ raise IndexError (
66
+ "`indices` expected integer data type, got `{}`" .format (
67
+ indices .dtype
55
68
)
56
- queues_ .append (i .sycl_queue )
57
- usm_types_ .append (i .usm_type )
58
- exec_q = dpctl .utils .get_execution_queue (queues_ )
59
- if exec_q is None :
60
- raise dpctl .utils .ExecutionPlacementError (
61
- "Can not automatically determine where to allocate the "
62
- "result or performance execution. "
63
- "Use `usm_ndarray.to_device` method to migrate data to "
64
- "be associated with the same queue."
65
69
)
66
- res_usm_type = dpctl .utils .get_coerced_usm_type (usm_types_ )
70
+ if indices .ndim != 1 :
71
+ raise ValueError (
72
+ "`indices` expected a 1D array, got `{}`" .format (indices .ndim )
73
+ )
74
+ exec_q = dpctl .utils .get_execution_queue ([x .sycl_queue , indices .sycl_queue ])
75
+ if exec_q is None :
76
+ raise dpctl .utils .ExecutionPlacementError
77
+ res_usm_type = dpctl .utils .get_coerced_usm_type (
78
+ [x .usm_type , indices .usm_type ]
79
+ )
67
80
68
81
modes = {"clip" : 0 , "wrap" : 1 }
69
82
try :
@@ -81,27 +94,47 @@ def take(x, indices, /, *, axis=None, mode="clip"):
81
94
)
82
95
axis = 0
83
96
84
- if len (indices ) > 1 :
85
- indices = dpt .broadcast_arrays (* indices )
86
97
if x_ndim > 0 :
87
98
axis = normalize_axis_index (operator .index (axis ), x_ndim )
88
- res_shape = (
89
- x .shape [:axis ] + indices [0 ].shape + x .shape [axis + len (indices ) :]
90
- )
99
+ res_shape = x .shape [:axis ] + indices .shape + x .shape [axis + 1 :]
91
100
else :
92
- res_shape = indices [0 ].shape
101
+ if axis != 0 :
102
+ raise ValueError ("`axis` must be 0 for an array of dimension 0." )
103
+ res_shape = indices .shape
93
104
94
105
res = dpt .empty (
95
106
res_shape , dtype = x .dtype , usm_type = res_usm_type , sycl_queue = exec_q
96
107
)
97
108
98
- hev , _ = ti ._take (x , indices , res , axis , mode , sycl_queue = exec_q )
109
+ hev , _ = ti ._take (x , ( indices ,) , res , axis , mode , sycl_queue = exec_q )
99
110
hev .wait ()
100
111
101
112
return res
102
113
103
114
104
115
def put (x , indices , vals , / , * , axis = None , mode = "clip" ):
116
+ """put(x, indices, vals, axis=None, mode="clip")
117
+
118
+ Puts values of an array into another array
119
+ along a given axis.
120
+
121
+ Args:
122
+ x: usm_ndarray
123
+ The array the values will be put into.
124
+ indices: usm_ndarray
125
+ One-dimensional array of indices.
126
+ vals:
127
+ Array of values to be put into `x`.
128
+ Must be broadcastable to the shape of `indices`.
129
+ axis:
130
+ The axis over which the values will be placed.
131
+ If x is one-dimensional, this argument is optional.
132
+ mode:
133
+ How out-of-bounds indices will be handled.
134
+ "Clip" - clamps indices to (-axis_size <= i < axis_size),
135
+ then wraps negative indices.
136
+ "Wrap" - wraps both negative and positive indices.
137
+ """
105
138
if not isinstance (x , dpt .usm_ndarray ):
106
139
raise TypeError (
107
140
"Expected instance of `dpt.usm_ndarray`, got `{}`." .format (type (x ))
@@ -116,66 +149,61 @@ def put(x, indices, vals, /, *, axis=None, mode="clip"):
116
149
usm_types_ = [
117
150
x .usm_type ,
118
151
]
119
-
120
- if not isinstance (indices , list ) and not isinstance (indices , tuple ):
121
- indices = (indices ,)
122
-
123
- for i in indices :
124
- if not isinstance (i , dpt .usm_ndarray ):
125
- raise TypeError (
126
- "`indices` expected `dpt.usm_ndarray`, got `{}`." .format (
127
- type (i )
128
- )
152
+ if not isinstance (indices , dpt .usm_ndarray ):
153
+ raise TypeError (
154
+ "`indices` expected `dpt.usm_ndarray`, got `{}`." .format (
155
+ type (indices )
129
156
)
130
- if not np .issubdtype (i .dtype , np .integer ):
131
- raise IndexError (
132
- "`indices` expected integer data type, got `{}`" .format (i .dtype )
157
+ )
158
+ if indices .ndim != 1 :
159
+ raise ValueError (
160
+ "`indices` expected a 1D array, got `{}`" .format (indices .ndim )
161
+ )
162
+ if not np .issubdtype (indices .dtype , np .integer ):
163
+ raise IndexError (
164
+ "`indices` expected integer data type, got `{}`" .format (
165
+ indices .dtype
133
166
)
134
- queues_ .append (i .sycl_queue )
135
- usm_types_ .append (i .usm_type )
167
+ )
168
+ queues_ .append (indices .sycl_queue )
169
+ usm_types_ .append (indices .usm_type )
136
170
exec_q = dpctl .utils .get_execution_queue (queues_ )
137
171
if exec_q is None :
138
- raise dpctl .utils .ExecutionPlacementError (
139
- "Can not automatically determine where to allocate the "
140
- "result or performance execution. "
141
- "Use `usm_ndarray.to_device` method to migrate data to "
142
- "be associated with the same queue."
143
- )
144
- val_usm_type = dpctl .utils .get_coerced_usm_type (usm_types_ )
145
-
172
+ raise dpctl .utils .ExecutionPlacementError
173
+ vals_usm_type = dpctl .utils .get_coerced_usm_type (usm_types_ )
146
174
modes = {"clip" : 0 , "wrap" : 1 }
147
175
try :
148
176
mode = modes [mode ]
149
177
except KeyError :
150
- raise ValueError ("`mode` must be `wrap`, or `clip `." )
178
+ raise ValueError ("`mode` must be `clip` or `wrap `." )
151
179
152
- # when axis is none, array is treated as 1D
153
- if axis is None :
154
- try :
155
- x = dpt .reshape (x , (x .size ,), copy = False )
156
- axis = 0
157
- except ValueError :
158
- raise ValueError ("Cannot create 1D view of input array" )
159
- if len (indices ) > 1 :
160
- indices = dpt .broadcast_arrays (* indices )
161
180
x_ndim = x .ndim
181
+ if axis is None :
182
+ if x_ndim > 1 :
183
+ raise ValueError (
184
+ "`axis` cannot be `None` for array of dimension `{}`" .format (
185
+ x_ndim
186
+ )
187
+ )
188
+ axis = 0
189
+
162
190
if x_ndim > 0 :
163
191
axis = normalize_axis_index (operator .index (axis ), x_ndim )
164
192
165
- val_shape = (
166
- x .shape [:axis ] + indices [0 ].shape + x .shape [axis + len (indices ) :]
167
- )
193
+ val_shape = x .shape [:axis ] + indices .shape + x .shape [axis + 1 :]
168
194
else :
169
- val_shape = indices [0 ].shape
195
+ if axis != 0 :
196
+ raise ValueError ("`axis` must be 0 for an array of dimension 0." )
197
+ val_shape = indices .shape
170
198
171
199
if not isinstance (vals , dpt .usm_ndarray ):
172
200
vals = dpt .asarray (
173
- vals , dtype = x .dtype , usm_type = val_usm_type , sycl_queue = exec_q
201
+ vals , dtype = x .dtype , usm_type = vals_usm_type , sycl_queue = exec_q
174
202
)
175
203
176
204
vals = dpt .broadcast_to (vals , val_shape )
177
205
178
- hev , _ = ti ._put (x , indices , vals , axis , mode , sycl_queue = exec_q )
206
+ hev , _ = ti ._put (x , ( indices ,) , vals , axis , mode , sycl_queue = exec_q )
179
207
hev .wait ()
180
208
181
209
@@ -192,14 +220,14 @@ def extract(condition, arr):
192
220
193
221
Args:
194
222
conditions: usm_ndarray
195
- An array whose non-zero or True entries indicate the element
196
- of `arr` to extract.
223
+ An array whose non-zero or True entries indicate the element
224
+ of `arr` to extract.
197
225
arr: usm_ndarray
198
- Input array of the same size as `condition`.
226
+ Input array of the same size as `condition`.
199
227
200
228
Returns:
201
- usm_ndarray
202
- Rank 1 array of values from `arr` where `condition` is True.
229
+ extract: usm_ndarray
230
+ Rank 1 array of values from `arr` where `condition` is True.
203
231
"""
204
232
if not isinstance (condition , dpt .usm_ndarray ):
205
233
raise TypeError (
@@ -231,16 +259,16 @@ def place(arr, mask, vals):
231
259
equivalent to ``arr[condition] = vals``.
232
260
233
261
Args:
234
- arr: usm_ndarray
235
- Array to put data into.
262
+ arr: usm_ndarray
263
+ Array to put data into.
236
264
mask: usm_ndarray
237
- Boolean mask array. Must have the same size as `arr`.
265
+ Boolean mask array. Must have the same size as `arr`.
238
266
vals: usm_ndarray
239
- Values to put into `arr`. Only the first N elements are
240
- used, where N is the number of True values in `mask`. If
241
- `vals` is smaller than N, it will be repeated, and if
242
- elements of `arr` are to be masked, this sequence must be
243
- non-empty. Array `vals` must be one dimensional.
267
+ Values to put into `arr`. Only the first N elements are
268
+ used, where N is the number of True values in `mask`. If
269
+ `vals` is smaller than N, it will be repeated, and if
270
+ elements of `arr` are to be masked, this sequence must be
271
+ non-empty. Array `vals` must be one dimensional.
244
272
"""
245
273
if not isinstance (arr , dpt .usm_ndarray ):
246
274
raise TypeError (
@@ -295,11 +323,11 @@ def nonzero(arr):
295
323
row-major, C-style order.
296
324
297
325
Args:
298
- arr: usm_ndarray
299
- Input array, which has non-zero array rank.
326
+ arr: usm_ndarray
327
+ Input array, which has non-zero array rank.
300
328
Returns:
301
- Tuple[usm_ndarray]
302
- Indices of non-zero array elements.
329
+ tuple_of_usm_ndarrays: tuple
330
+ Indices of non-zero array elements.
303
331
"""
304
332
if not isinstance (arr , dpt .usm_ndarray ):
305
333
raise TypeError (
0 commit comments