Skip to content

Commit 04ec23b

Browse files
Merge pull request #1426 from IntelPython/implement-product
Implement product over axis
2 parents df1c22f + 60a8ad7 commit 04ec23b

File tree

7 files changed

+640
-109
lines changed

7 files changed

+640
-109
lines changed

dpctl/tensor/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@
160160
tanh,
161161
trunc,
162162
)
163-
from ._reduction import argmax, argmin, max, min, sum
163+
from ._reduction import argmax, argmin, max, min, prod, sum
164164
from ._testing import allclose
165165

166166
__all__ = [
@@ -313,4 +313,5 @@
313313
"min",
314314
"argmax",
315315
"argmin",
316+
"prod",
316317
]

dpctl/tensor/_reduction.py

Lines changed: 65 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -144,12 +144,12 @@ def _reduction_over_axis(
144144
def sum(x, axis=None, dtype=None, keepdims=False):
145145
"""sum(x, axis=None, dtype=None, keepdims=False)
146146
147-
Calculates the sum of the input array `x`.
147+
Calculates the sum of elements in the input array `x`.
148148
149149
Args:
150150
x (usm_ndarray):
151151
input array.
152-
axis (Optional[int, Tuple[int,...]]):
152+
axis (Optional[int, Tuple[int, ...]]):
153153
axis or axes along which sums must be computed. If a tuple
154154
of unique integers, sums are computed over multiple axes.
155155
If `None`, the sum is computed over the entire array.
@@ -202,6 +202,67 @@ def sum(x, axis=None, dtype=None, keepdims=False):
202202
)
203203

204204

205+
def prod(x, axis=None, dtype=None, keepdims=False):
206+
"""prod(x, axis=None, dtype=None, keepdims=False)
207+
208+
Calculates the product of elements in the input array `x`.
209+
210+
Args:
211+
x (usm_ndarray):
212+
input array.
213+
axis (Optional[int, Tuple[int, ...]]):
214+
axis or axes along which products must be computed. If a tuple
215+
of unique integers, products are computed over multiple axes.
216+
If `None`, the product is computed over the entire array.
217+
Default: `None`.
218+
dtype (Optional[dtype]):
219+
data type of the returned array. If `None`, the default data
220+
type is inferred from the "kind" of the input array data type.
221+
* If `x` has a real-valued floating-point data type,
222+
the returned array will have the default real-valued
223+
floating-point data type for the device where input
224+
array `x` is allocated.
225+
* If x` has signed integral data type, the returned array
226+
will have the default signed integral type for the device
227+
where input array `x` is allocated.
228+
* If `x` has unsigned integral data type, the returned array
229+
will have the default unsigned integral type for the device
230+
where input array `x` is allocated.
231+
* If `x` has a complex-valued floating-point data typee,
232+
the returned array will have the default complex-valued
233+
floating-pointer data type for the device where input
234+
array `x` is allocated.
235+
* If `x` has a boolean data type, the returned array will
236+
have the default signed integral type for the device
237+
where input array `x` is allocated.
238+
If the data type (either specified or resolved) differs from the
239+
data type of `x`, the input array elements are cast to the
240+
specified data type before computing the product. Default: `None`.
241+
keepdims (Optional[bool]):
242+
if `True`, the reduced axes (dimensions) are included in the result
243+
as singleton dimensions, so that the returned array remains
244+
compatible with the input arrays according to Array Broadcasting
245+
rules. Otherwise, if `False`, the reduced axes are not included in
246+
the returned array. Default: `False`.
247+
Returns:
248+
usm_ndarray:
249+
an array containing the products. If the product was computed over
250+
the entire array, a zero-dimensional array is returned. The returned
251+
array has the data type as described in the `dtype` parameter
252+
description above.
253+
"""
254+
return _reduction_over_axis(
255+
x,
256+
axis,
257+
dtype,
258+
keepdims,
259+
ti._prod_over_axis,
260+
ti._prod_over_axis_dtype_supported,
261+
_default_reduction_dtype,
262+
_identity=1,
263+
)
264+
265+
205266
def _comparison_over_axis(x, axis, keepdims, _reduction_fn):
206267
if not isinstance(x, dpt.usm_ndarray):
207268
raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x)}")
@@ -253,7 +314,7 @@ def max(x, axis=None, keepdims=False):
253314
Args:
254315
x (usm_ndarray):
255316
input array.
256-
axis (Optional[int, Tuple[int,...]]):
317+
axis (Optional[int, Tuple[int, ...]]):
257318
axis or axes along which maxima must be computed. If a tuple
258319
of unique integers, the maxima are computed over multiple axes.
259320
If `None`, the max is computed over the entire array.
@@ -281,7 +342,7 @@ def min(x, axis=None, keepdims=False):
281342
Args:
282343
x (usm_ndarray):
283344
input array.
284-
axis (Optional[int, Tuple[int,...]]):
345+
axis (Optional[int, Tuple[int, ...]]):
285346
axis or axes along which minima must be computed. If a tuple
286347
of unique integers, the minima are computed over multiple axes.
287348
If `None`, the min is computed over the entire array.

0 commit comments

Comments
 (0)