@@ -144,12 +144,12 @@ def _reduction_over_axis(
144
144
def sum (x , axis = None , dtype = None , keepdims = False ):
145
145
"""sum(x, axis=None, dtype=None, keepdims=False)
146
146
147
- Calculates the sum of the input array `x`.
147
+ Calculates the sum of elements in the input array `x`.
148
148
149
149
Args:
150
150
x (usm_ndarray):
151
151
input array.
152
- axis (Optional[int, Tuple[int,...]]):
152
+ axis (Optional[int, Tuple[int, ...]]):
153
153
axis or axes along which sums must be computed. If a tuple
154
154
of unique integers, sums are computed over multiple axes.
155
155
If `None`, the sum is computed over the entire array.
@@ -202,6 +202,67 @@ def sum(x, axis=None, dtype=None, keepdims=False):
202
202
)
203
203
204
204
205
+ def prod (x , axis = None , dtype = None , keepdims = False ):
206
+ """prod(x, axis=None, dtype=None, keepdims=False)
207
+
208
+ Calculates the product of elements in the input array `x`.
209
+
210
+ Args:
211
+ x (usm_ndarray):
212
+ input array.
213
+ axis (Optional[int, Tuple[int, ...]]):
214
+ axis or axes along which products must be computed. If a tuple
215
+ of unique integers, products are computed over multiple axes.
216
+ If `None`, the product is computed over the entire array.
217
+ Default: `None`.
218
+ dtype (Optional[dtype]):
219
+ data type of the returned array. If `None`, the default data
220
+ type is inferred from the "kind" of the input array data type.
221
+ * If `x` has a real-valued floating-point data type,
222
+ the returned array will have the default real-valued
223
+ floating-point data type for the device where input
224
+ array `x` is allocated.
225
+ * If x` has signed integral data type, the returned array
226
+ will have the default signed integral type for the device
227
+ where input array `x` is allocated.
228
+ * If `x` has unsigned integral data type, the returned array
229
+ will have the default unsigned integral type for the device
230
+ where input array `x` is allocated.
231
+ * If `x` has a complex-valued floating-point data typee,
232
+ the returned array will have the default complex-valued
233
+ floating-pointer data type for the device where input
234
+ array `x` is allocated.
235
+ * If `x` has a boolean data type, the returned array will
236
+ have the default signed integral type for the device
237
+ where input array `x` is allocated.
238
+ If the data type (either specified or resolved) differs from the
239
+ data type of `x`, the input array elements are cast to the
240
+ specified data type before computing the product. Default: `None`.
241
+ keepdims (Optional[bool]):
242
+ if `True`, the reduced axes (dimensions) are included in the result
243
+ as singleton dimensions, so that the returned array remains
244
+ compatible with the input arrays according to Array Broadcasting
245
+ rules. Otherwise, if `False`, the reduced axes are not included in
246
+ the returned array. Default: `False`.
247
+ Returns:
248
+ usm_ndarray:
249
+ an array containing the products. If the product was computed over
250
+ the entire array, a zero-dimensional array is returned. The returned
251
+ array has the data type as described in the `dtype` parameter
252
+ description above.
253
+ """
254
+ return _reduction_over_axis (
255
+ x ,
256
+ axis ,
257
+ dtype ,
258
+ keepdims ,
259
+ ti ._prod_over_axis ,
260
+ ti ._prod_over_axis_dtype_supported ,
261
+ _default_reduction_dtype ,
262
+ _identity = 1 ,
263
+ )
264
+
265
+
205
266
def _comparison_over_axis (x , axis , keepdims , _reduction_fn ):
206
267
if not isinstance (x , dpt .usm_ndarray ):
207
268
raise TypeError (f"Expected dpctl.tensor.usm_ndarray, got { type (x )} " )
@@ -253,7 +314,7 @@ def max(x, axis=None, keepdims=False):
253
314
Args:
254
315
x (usm_ndarray):
255
316
input array.
256
- axis (Optional[int, Tuple[int,...]]):
317
+ axis (Optional[int, Tuple[int, ...]]):
257
318
axis or axes along which maxima must be computed. If a tuple
258
319
of unique integers, the maxima are computed over multiple axes.
259
320
If `None`, the max is computed over the entire array.
@@ -281,7 +342,7 @@ def min(x, axis=None, keepdims=False):
281
342
Args:
282
343
x (usm_ndarray):
283
344
input array.
284
- axis (Optional[int, Tuple[int,...]]):
345
+ axis (Optional[int, Tuple[int, ...]]):
285
346
axis or axes along which minima must be computed. If a tuple
286
347
of unique integers, the minima are computed over multiple axes.
287
348
If `None`, the min is computed over the entire array.
0 commit comments