Skip to content

Commit c3d1f62

Browse files
committed
Sum now uses a generic Python API
1 parent 257bc03 commit c3d1f62

File tree

1 file changed

+89
-60
lines changed

1 file changed

+89
-60
lines changed

dpctl/tensor/_reduction.py

Lines changed: 89 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -52,55 +52,16 @@ def _default_reduction_dtype(inp_dt, q):
5252
return res_dt
5353

5454

55-
def sum(x, axis=None, dtype=None, keepdims=False):
56-
"""sum(x, axis=None, dtype=None, keepdims=False)
57-
58-
Calculates the sum of the input array `x`.
59-
60-
Args:
61-
x (usm_ndarray):
62-
input array.
63-
axis (Optional[int, Tuple[int,...]]):
64-
axis or axes along which sums must be computed. If a tuple
65-
of unique integers, sums are computed over multiple axes.
66-
If `None`, the sum if computed over the entire array.
67-
Default: `None`.
68-
dtype (Optional[dtype]):
69-
data type of the returned array. If `None`, the default data
70-
type is inferred from the "kind" of the input array data type.
71-
* If `x` has a real-valued floating-point data type,
72-
the returned array will have the default real-valued
73-
floating-point data type for the device where input
74-
array `x` is allocated.
75-
* If x` has signed integral data type, the returned array
76-
will have the default signed integral type for the device
77-
where input array `x` is allocated.
78-
* If `x` has unsigned integral data type, the returned array
79-
will have the default unsigned integral type for the device
80-
where input array `x` is allocated.
81-
* If `x` has a complex-valued floating-point data typee,
82-
the returned array will have the default complex-valued
83-
floating-pointer data type for the device where input
84-
array `x` is allocated.
85-
* If `x` has a boolean data type, the returned array will
86-
have the default signed integral type for the device
87-
where input array `x` is allocated.
88-
If the data type (either specified or resolved) differs from the
89-
data type of `x`, the input array elements are cast to the
90-
specified data type before computing the sum. Default: `None`.
91-
keepdims (Optional[bool]):
92-
if `True`, the reduced axes (dimensions) are included in the result
93-
as singleton dimensions, so that the returned array remains
94-
compatible with the input arrays according to Array Broadcasting
95-
rules. Otherwise, if `False`, the reduced axes are not included in
96-
the returned array. Default: `False`.
97-
Returns:
98-
usm_ndarray:
99-
an array containing the sums. If the sum was computed over the
100-
entire array, a zero-dimensional array is returned. The returned
101-
array has the data type as described in the `dtype` parameter
102-
description above.
103-
"""
55+
def _reduction_over_axis(
56+
x,
57+
axis,
58+
dtype,
59+
keepdims,
60+
_reduction_fn,
61+
_dtype_supported,
62+
_default_reduction_type_fn,
63+
_identity=None,
64+
):
10465
if not isinstance(x, dpt.usm_ndarray):
10566
raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x)}")
10667
nd = x.ndim
@@ -116,29 +77,36 @@ def sum(x, axis=None, dtype=None, keepdims=False):
11677
q = x.sycl_queue
11778
inp_dt = x.dtype
11879
if dtype is None:
119-
res_dt = _default_reduction_dtype(inp_dt, q)
80+
res_dt = _default_reduction_type_fn(inp_dt, q)
12081
else:
12182
res_dt = dpt.dtype(dtype)
12283
res_dt = _to_device_supported_dtype(res_dt, q.sycl_device)
12384

12485
res_usm_type = x.usm_type
12586
if x.size == 0:
126-
if keepdims:
127-
res_shape = res_shape + (1,) * red_nd
128-
inv_perm = sorted(range(nd), key=lambda d: perm[d])
129-
res_shape = tuple(res_shape[i] for i in inv_perm)
130-
return dpt.zeros(
131-
res_shape, dtype=res_dt, usm_type=res_usm_type, sycl_queue=q
132-
)
87+
if _identity is None:
88+
raise ValueError("reduction does not support zero-size arrays")
89+
else:
90+
if keepdims:
91+
res_shape = res_shape + (1,) * red_nd
92+
inv_perm = sorted(range(nd), key=lambda d: perm[d])
93+
res_shape = tuple(res_shape[i] for i in inv_perm)
94+
return dpt.full(
95+
res_shape,
96+
_identity,
97+
dtype=res_dt,
98+
usm_type=res_usm_type,
99+
sycl_queue=q,
100+
)
133101
if red_nd == 0:
134102
return dpt.astype(x, res_dt, copy=False)
135103

136104
host_tasks_list = []
137-
if ti._sum_over_axis_dtype_supported(inp_dt, res_dt, res_usm_type, q):
105+
if _dtype_supported(inp_dt, res_dt, res_usm_type, q):
138106
res = dpt.empty(
139107
res_shape, dtype=res_dt, usm_type=res_usm_type, sycl_queue=q
140108
)
141-
ht_e, _ = ti._sum_over_axis(
109+
ht_e, _ = ti._reduction_fn(
142110
src=arr2, trailing_dims_to_reduce=red_nd, dst=res, sycl_queue=q
143111
)
144112
host_tasks_list.append(ht_e)
@@ -152,7 +120,7 @@ def sum(x, axis=None, dtype=None, keepdims=False):
152120
tmp = dpt.empty(
153121
res_shape, dtype=tmp_dt, usm_type=res_usm_type, sycl_queue=q
154122
)
155-
ht_e_tmp, r_e = ti._sum_over_axis(
123+
ht_e_tmp, r_e = ti._reduction_fn(
156124
src=arr2, trailing_dims_to_reduce=red_nd, dst=tmp, sycl_queue=q
157125
)
158126
host_tasks_list.append(ht_e_tmp)
@@ -173,6 +141,67 @@ def sum(x, axis=None, dtype=None, keepdims=False):
173141
return res
174142

175143

144+
def sum(x, axis=None, dtype=None, keepdims=False):
    """sum(x, axis=None, dtype=None, keepdims=False)

    Computes the sum of the elements of the input array `x`.

    Args:
        x (usm_ndarray):
            input array.
        axis (Optional[int, Tuple[int,...]]):
            axis or axes over which the sums are taken. A tuple of unique
            integers requests a sum over several axes at once. If `None`,
            the sum is taken over the entire array. Default: `None`.
        dtype (Optional[dtype]):
            data type of the returned array. If `None`, the data type is
            chosen from the "kind" of the data type of `x`:

                * real-valued floating-point input gives the default
                  real-valued floating-point data type for the device
                  where `x` is allocated.
                * signed integral input gives the default signed integral
                  data type for the device where `x` is allocated.
                * unsigned integral input gives the default unsigned
                  integral data type for the device where `x` is
                  allocated.
                * complex-valued floating-point input gives the default
                  complex-valued floating-point data type for the device
                  where `x` is allocated.
                * boolean input gives the default signed integral data
                  type for the device where `x` is allocated.

            When the data type (whether specified or resolved) is not the
            data type of `x`, the elements of `x` are cast to that data
            type before the sum is computed. Default: `None`.
        keepdims (Optional[bool]):
            if `True`, the reduced axes are kept in the result as
            dimensions of size one, so that the result remains compatible
            with the input array under Array Broadcasting rules. If
            `False`, the reduced axes are dropped from the result.
            Default: `False`.

    Returns:
        usm_ndarray:
            an array holding the sums. A sum over the entire array yields
            a zero-dimensional array. The data type of the result follows
            the rules given for the `dtype` parameter above.
    """
    # All of the work is delegated to the generic reduction driver; this
    # wrapper only selects the sum-specific kernel, the matching dtype
    # support predicate, the default result dtype rule, and the identity
    # element (0) used to fill the result when `x` has no elements.
    return _reduction_over_axis(
        x,
        axis,
        dtype,
        keepdims,
        _reduction_fn=ti._sum_over_axis,
        _dtype_supported=ti._sum_over_axis_dtype_supported,
        _default_reduction_type_fn=_default_reduction_dtype,
        _identity=0,
    )
203+
204+
176205
def _comparison_over_axis(x, axis, keepdims, _reduction_fn):
177206
if not isinstance(x, dpt.usm_ndarray):
178207
raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x)}")

0 commit comments

Comments
 (0)