@@ -52,55 +52,16 @@ def _default_reduction_dtype(inp_dt, q):
52
52
return res_dt
53
53
54
54
55
- def sum (x , axis = None , dtype = None , keepdims = False ):
56
- """sum(x, axis=None, dtype=None, keepdims=False)
57
-
58
- Calculates the sum of the input array `x`.
59
-
60
- Args:
61
- x (usm_ndarray):
62
- input array.
63
- axis (Optional[int, Tuple[int,...]]):
64
- axis or axes along which sums must be computed. If a tuple
65
- of unique integers, sums are computed over multiple axes.
66
- If `None`, the sum if computed over the entire array.
67
- Default: `None`.
68
- dtype (Optional[dtype]):
69
- data type of the returned array. If `None`, the default data
70
- type is inferred from the "kind" of the input array data type.
71
- * If `x` has a real-valued floating-point data type,
72
- the returned array will have the default real-valued
73
- floating-point data type for the device where input
74
- array `x` is allocated.
75
- * If x` has signed integral data type, the returned array
76
- will have the default signed integral type for the device
77
- where input array `x` is allocated.
78
- * If `x` has unsigned integral data type, the returned array
79
- will have the default unsigned integral type for the device
80
- where input array `x` is allocated.
81
- * If `x` has a complex-valued floating-point data typee,
82
- the returned array will have the default complex-valued
83
- floating-pointer data type for the device where input
84
- array `x` is allocated.
85
- * If `x` has a boolean data type, the returned array will
86
- have the default signed integral type for the device
87
- where input array `x` is allocated.
88
- If the data type (either specified or resolved) differs from the
89
- data type of `x`, the input array elements are cast to the
90
- specified data type before computing the sum. Default: `None`.
91
- keepdims (Optional[bool]):
92
- if `True`, the reduced axes (dimensions) are included in the result
93
- as singleton dimensions, so that the returned array remains
94
- compatible with the input arrays according to Array Broadcasting
95
- rules. Otherwise, if `False`, the reduced axes are not included in
96
- the returned array. Default: `False`.
97
- Returns:
98
- usm_ndarray:
99
- an array containing the sums. If the sum was computed over the
100
- entire array, a zero-dimensional array is returned. The returned
101
- array has the data type as described in the `dtype` parameter
102
- description above.
103
- """
55
+ def _reduction_over_axis (
56
+ x ,
57
+ axis ,
58
+ dtype ,
59
+ keepdims ,
60
+ _reduction_fn ,
61
+ _dtype_supported ,
62
+ _default_reduction_type_fn ,
63
+ _identity = None ,
64
+ ):
104
65
if not isinstance (x , dpt .usm_ndarray ):
105
66
raise TypeError (f"Expected dpctl.tensor.usm_ndarray, got { type (x )} " )
106
67
nd = x .ndim
@@ -116,29 +77,36 @@ def sum(x, axis=None, dtype=None, keepdims=False):
116
77
q = x .sycl_queue
117
78
inp_dt = x .dtype
118
79
if dtype is None :
119
- res_dt = _default_reduction_dtype (inp_dt , q )
80
+ res_dt = _default_reduction_type_fn (inp_dt , q )
120
81
else :
121
82
res_dt = dpt .dtype (dtype )
122
83
res_dt = _to_device_supported_dtype (res_dt , q .sycl_device )
123
84
124
85
res_usm_type = x .usm_type
125
86
if x .size == 0 :
126
- if keepdims :
127
- res_shape = res_shape + (1 ,) * red_nd
128
- inv_perm = sorted (range (nd ), key = lambda d : perm [d ])
129
- res_shape = tuple (res_shape [i ] for i in inv_perm )
130
- return dpt .zeros (
131
- res_shape , dtype = res_dt , usm_type = res_usm_type , sycl_queue = q
132
- )
87
+ if _identity is None :
88
+ raise ValueError ("reduction does not support zero-size arrays" )
89
+ else :
90
+ if keepdims :
91
+ res_shape = res_shape + (1 ,) * red_nd
92
+ inv_perm = sorted (range (nd ), key = lambda d : perm [d ])
93
+ res_shape = tuple (res_shape [i ] for i in inv_perm )
94
+ return dpt .full (
95
+ res_shape ,
96
+ _identity ,
97
+ dtype = res_dt ,
98
+ usm_type = res_usm_type ,
99
+ sycl_queue = q ,
100
+ )
133
101
if red_nd == 0 :
134
102
return dpt .astype (x , res_dt , copy = False )
135
103
136
104
host_tasks_list = []
137
- if ti . _sum_over_axis_dtype_supported (inp_dt , res_dt , res_usm_type , q ):
105
+ if _dtype_supported (inp_dt , res_dt , res_usm_type , q ):
138
106
res = dpt .empty (
139
107
res_shape , dtype = res_dt , usm_type = res_usm_type , sycl_queue = q
140
108
)
141
- ht_e , _ = ti ._sum_over_axis (
109
+ ht_e , _ = ti ._reduction_fn (
142
110
src = arr2 , trailing_dims_to_reduce = red_nd , dst = res , sycl_queue = q
143
111
)
144
112
host_tasks_list .append (ht_e )
@@ -152,7 +120,7 @@ def sum(x, axis=None, dtype=None, keepdims=False):
152
120
tmp = dpt .empty (
153
121
res_shape , dtype = tmp_dt , usm_type = res_usm_type , sycl_queue = q
154
122
)
155
- ht_e_tmp , r_e = ti ._sum_over_axis (
123
+ ht_e_tmp , r_e = ti ._reduction_fn (
156
124
src = arr2 , trailing_dims_to_reduce = red_nd , dst = tmp , sycl_queue = q
157
125
)
158
126
host_tasks_list .append (ht_e_tmp )
@@ -173,6 +141,67 @@ def sum(x, axis=None, dtype=None, keepdims=False):
173
141
return res
174
142
175
143
144
+ def sum (x , axis = None , dtype = None , keepdims = False ):
145
+ """sum(x, axis=None, dtype=None, keepdims=False)
146
+
147
+ Calculates the sum of the input array `x`.
148
+
149
+ Args:
150
+ x (usm_ndarray):
151
+ input array.
152
+ axis (Optional[int, Tuple[int,...]]):
153
+ axis or axes along which sums must be computed. If a tuple
154
+ of unique integers, sums are computed over multiple axes.
155
+ If `None`, the sum is computed over the entire array.
156
+ Default: `None`.
157
+ dtype (Optional[dtype]):
158
+ data type of the returned array. If `None`, the default data
159
+ type is inferred from the "kind" of the input array data type.
160
+ * If `x` has a real-valued floating-point data type,
161
+ the returned array will have the default real-valued
162
+ floating-point data type for the device where input
163
+ array `x` is allocated.
164
+ * If `x` has signed integral data type, the returned array
165
+ will have the default signed integral type for the device
166
+ where input array `x` is allocated.
167
+ * If `x` has unsigned integral data type, the returned array
168
+ will have the default unsigned integral type for the device
169
+ where input array `x` is allocated.
170
+ * If `x` has a complex-valued floating-point data type,
171
+ the returned array will have the default complex-valued
172
+ floating-point data type for the device where input
173
+ array `x` is allocated.
174
+ * If `x` has a boolean data type, the returned array will
175
+ have the default signed integral type for the device
176
+ where input array `x` is allocated.
177
+ If the data type (either specified or resolved) differs from the
178
+ data type of `x`, the input array elements are cast to the
179
+ specified data type before computing the sum. Default: `None`.
180
+ keepdims (Optional[bool]):
181
+ if `True`, the reduced axes (dimensions) are included in the result
182
+ as singleton dimensions, so that the returned array remains
183
+ compatible with the input arrays according to Array Broadcasting
184
+ rules. Otherwise, if `False`, the reduced axes are not included in
185
+ the returned array. Default: `False`.
186
+ Returns:
187
+ usm_ndarray:
188
+ an array containing the sums. If the sum was computed over the
189
+ entire array, a zero-dimensional array is returned. The returned
190
+ array has the data type as described in the `dtype` parameter
191
+ description above.
192
+ """
193
+ return _reduction_over_axis (
194
+ x ,
195
+ axis ,
196
+ dtype ,
197
+ keepdims ,
198
+ ti ._sum_over_axis ,
199
+ ti ._sum_over_axis_dtype_supported ,
200
+ _default_reduction_dtype ,
201
+ _identity = 0 ,
202
+ )
203
+
204
+
176
205
def _comparison_over_axis (x , axis , keepdims , _reduction_fn ):
177
206
if not isinstance (x , dpt .usm_ndarray ):
178
207
raise TypeError (f"Expected dpctl.tensor.usm_ndarray, got { type (x )} " )
0 commit comments