Skip to content

Commit 7274050

Browse files
rgommerskgryte
andauthored
feat!: change default promotion behavior in summation APIs
This commit modifies type promotion behavior in `sum`, `prod`, `cumulative_sum`, and `linalg.trace` when the input array has a floating-point data type. Previously, the specification required that conforming implementations upcast to the default floating-point data type when the input array data type was of a lower precision. This commit revises that guidance to require conforming libraries return an array having the same data type as the input array. This revision stems from feedback from implementing libraries, where the current status quo matches the changes in this commit, with little desire to change. As such, the specification is amended to match this reality. Closes: data-apis#731 PR-URL: data-apis#744 Co-authored-by: Athan Reines <kgryte@gmail.com> Reviewed-by: Athan Reines <kgryte@gmail.com>
1 parent 2404c99 commit 7274050

File tree

2 files changed

+19
-46
lines changed

2 files changed

+19
-46
lines changed

src/array_api_stubs/_draft/linalg.py

+4-11
Original file line numberDiff line numberDiff line change
@@ -742,19 +742,12 @@ def trace(x: array, /, *, offset: int = 0, dtype: Optional[dtype] = None) -> arr
742742
743743
Default: ``0``.
744744
dtype: Optional[dtype]
745-
data type of the returned array. If ``None``,
745+
data type of the returned array. If ``None``, the returned array must have the same data type as ``x``, unless ``x`` has an integer data type supporting a smaller range of values than the default integer data type (e.g., ``x`` has an ``int16`` or ``uint32`` data type and the default integer data type is ``int64``). In those latter cases:
746746
747-
- if the default data type corresponding to the data type "kind" (integer, real-valued floating-point, or complex floating-point) of ``x`` has a smaller range of values than the data type of ``x`` (e.g., ``x`` has data type ``int64`` and the default data type is ``int32``, or ``x`` has data type ``uint64`` and the default data type is ``int64``), the returned array must have the same data type as ``x``.
748-
- if the default data type corresponding to the data type "kind" of ``x`` has the same or a larger range of values than the data type of ``x``,
749-
- if ``x`` has a real-valued floating-point data type, the returned array must have the default real-valued floating-point data type.
750-
- if ``x`` has a complex floating-point data type, the returned array must have the default complex floating-point data type.
751-
- if ``x`` has a signed integer data type (e.g., ``int16``), the returned array must have the default integer data type.
752-
- if ``x`` has an unsigned integer data type (e.g., ``uint16``), the returned array must have an unsigned integer data type having the same number of bits as the default integer data type (e.g., if the default integer data type is ``int32``, the returned array must have a ``uint32`` data type).
747+
- if ``x`` has a signed integer data type (e.g., ``int16``), the returned array must have the default integer data type.
748+
- if ``x`` has an unsigned integer data type (e.g., ``uint16``), the returned array must have an unsigned integer data type having the same number of bits as the default integer data type (e.g., if the default integer data type is ``int32``, the returned array must have a ``uint32`` data type).
753749
754-
If the data type (either specified or resolved) differs from the data type of ``x``, the input array should be cast to the specified data type before computing the sum. Default: ``None``.
755-
756-
.. note::
757-
keyword argument is intended to help prevent data type overflows.
750+
If the data type (either specified or resolved) differs from the data type of ``x``, the input array should be cast to the specified data type before computing the sum (rationale: the ``dtype`` keyword argument is intended to help prevent overflows). Default: ``None``.
758751
759752
Returns
760753
-------

src/array_api_stubs/_draft/statistical_functions.py

+15-35
Original file line numberDiff line numberDiff line change
@@ -23,22 +23,14 @@ def cumulative_sum(
2323
axis along which a cumulative sum must be computed. If ``axis`` is negative, the function must determine the axis along which to compute a cumulative sum by counting from the last dimension.
2424
2525
If ``x`` is a one-dimensional array, providing an ``axis`` is optional; however, if ``x`` has more than one dimension, providing an ``axis`` is required.
26-
dtype: Optional[dtype]
27-
data type of the returned array. If ``None``,
28-
29-
- if the default data type corresponding to the data type "kind" (integer, real-valued floating-point, or complex floating-point) of ``x`` has a smaller range of values than the data type of ``x`` (e.g., ``x`` has data type ``int64`` and the default data type is ``int32``, or ``x`` has data type ``uint64`` and the default data type is ``int64``), the returned array must have the same data type as ``x``.
30-
31-
- if the default data type corresponding to the data type "kind" of ``x`` has the same or a larger range of values than the data type of ``x``,
3226
33-
- if ``x`` has a real-valued floating-point data type, the returned array must have the default real-valued floating-point data type.
34-
- if ``x`` has a complex floating-point data type, the returned array must have the default complex floating-point data type.
35-
- if ``x`` has a signed integer data type (e.g., ``int16``), the returned array must have the default integer data type.
36-
- if ``x`` has an unsigned integer data type (e.g., ``uint16``), the returned array must have an unsigned integer data type having the same number of bits as the default integer data type (e.g., if the default integer data type is ``int32``, the returned array must have a ``uint32`` data type).
27+
dtype: Optional[dtype]
28+
data type of the returned array. If ``None``, the returned array must have the same data type as ``x``, unless ``x`` has an integer data type supporting a smaller range of values than the default integer data type (e.g., ``x`` has an ``int16`` or ``uint32`` data type and the default integer data type is ``int64``). In those latter cases:
3729
38-
If the data type (either specified or resolved) differs from the data type of ``x``, the input array should be cast to the specified data type before computing the sum. Default: ``None``.
30+
- if ``x`` has a signed integer data type (e.g., ``int16``), the returned array must have the default integer data type.
31+
- if ``x`` has an unsigned integer data type (e.g., ``uint16``), the returned array must have an unsigned integer data type having the same number of bits as the default integer data type (e.g., if the default integer data type is ``int32``, the returned array must have a ``uint32`` data type).
3932
40-
.. note::
41-
keyword argument is intended to help prevent data type overflows.
33+
If the data type (either specified or resolved) differs from the data type of ``x``, the input array should be cast to the specified data type before computing the sum (rationale: the ``dtype`` keyword argument is intended to help prevent overflows). Default: ``None``.
4234
4335
include_initial: bool
4436
boolean indicating whether to include the initial value as the first value in the output. By convention, the initial value must be the additive identity (i.e., zero). Default: ``False``.
@@ -200,20 +192,14 @@ def prod(
200192
input array. Should have a numeric data type.
201193
axis: Optional[Union[int, Tuple[int, ...]]]
202194
axis or axes along which products must be computed. By default, the product must be computed over the entire array. If a tuple of integers, products must be computed over multiple axes. Default: ``None``.
203-
dtype: Optional[dtype]
204-
data type of the returned array. If ``None``,
205195
206-
- if the default data type corresponding to the data type "kind" (integer, real-valued floating-point, or complex floating-point) of ``x`` has a smaller range of values than the data type of ``x`` (e.g., ``x`` has data type ``int64`` and the default data type is ``int32``, or ``x`` has data type ``uint64`` and the default data type is ``int64``), the returned array must have the same data type as ``x``.
207-
- if the default data type corresponding to the data type "kind" of ``x`` has the same or a larger range of values than the data type of ``x``,
208-
- if ``x`` has a real-valued floating-point data type, the returned array must have the default real-valued floating-point data type.
209-
- if ``x`` has a complex floating-point data type, the returned array must have the default complex floating-point data type.
210-
- if ``x`` has a signed integer data type (e.g., ``int16``), the returned array must have the default integer data type.
211-
- if ``x`` has an unsigned integer data type (e.g., ``uint16``), the returned array must have an unsigned integer data type having the same number of bits as the default integer data type (e.g., if the default integer data type is ``int32``, the returned array must have a ``uint32`` data type).
196+
dtype: Optional[dtype]
197+
data type of the returned array. If ``None``, the returned array must have the same data type as ``x``, unless ``x`` has an integer data type supporting a smaller range of values than the default integer data type (e.g., ``x`` has an ``int16`` or ``uint32`` data type and the default integer data type is ``int64``). In those latter cases:
212198
213-
If the data type (either specified or resolved) differs from the data type of ``x``, the input array should be cast to the specified data type before computing the product. Default: ``None``.
199+
- if ``x`` has a signed integer data type (e.g., ``int16``), the returned array must have the default integer data type.
200+
- if ``x`` has an unsigned integer data type (e.g., ``uint16``), the returned array must have an unsigned integer data type having the same number of bits as the default integer data type (e.g., if the default integer data type is ``int32``, the returned array must have a ``uint32`` data type).
214201
215-
.. note::
216-
This keyword argument is intended to help prevent data type overflows.
202+
If the data type (either specified or resolved) differs from the data type of ``x``, the input array should be cast to the specified data type before computing the sum (rationale: the ``dtype`` keyword argument is intended to help prevent overflows). Default: ``None``.
217203
218204
keepdims: bool
219205
if ``True``, the reduced axes (dimensions) must be included in the result as singleton dimensions, and, accordingly, the result must be compatible with the input array (see :ref:`broadcasting`). Otherwise, if ``False``, the reduced axes (dimensions) must not be included in the result. Default: ``False``.
@@ -298,20 +284,14 @@ def sum(
298284
input array. Should have a numeric data type.
299285
axis: Optional[Union[int, Tuple[int, ...]]]
300286
axis or axes along which sums must be computed. By default, the sum must be computed over the entire array. If a tuple of integers, sums must be computed over multiple axes. Default: ``None``.
301-
dtype: Optional[dtype]
302-
data type of the returned array. If ``None``,
303287
304-
- if the default data type corresponding to the data type "kind" (integer, real-valued floating-point, or complex floating-point) of ``x`` has a smaller range of values than the data type of ``x`` (e.g., ``x`` has data type ``int64`` and the default data type is ``int32``, or ``x`` has data type ``uint64`` and the default data type is ``int64``), the returned array must have the same data type as ``x``.
305-
- if the default data type corresponding to the data type "kind" of ``x`` has the same or a larger range of values than the data type of ``x``,
306-
- if ``x`` has a real-valued floating-point data type, the returned array must have the default real-valued floating-point data type.
307-
- if ``x`` has a complex floating-point data type, the returned array must have the default complex floating-point data type.
308-
- if ``x`` has a signed integer data type (e.g., ``int16``), the returned array must have the default integer data type.
309-
- if ``x`` has an unsigned integer data type (e.g., ``uint16``), the returned array must have an unsigned integer data type having the same number of bits as the default integer data type (e.g., if the default integer data type is ``int32``, the returned array must have a ``uint32`` data type).
288+
dtype: Optional[dtype]
289+
data type of the returned array. If ``None``, the returned array must have the same data type as ``x``, unless ``x`` has an integer data type supporting a smaller range of values than the default integer data type (e.g., ``x`` has an ``int16`` or ``uint32`` data type and the default integer data type is ``int64``). In those latter cases:
310290
311-
If the data type (either specified or resolved) differs from the data type of ``x``, the input array should be cast to the specified data type before computing the sum. Default: ``None``.
291+
- if ``x`` has a signed integer data type (e.g., ``int16``), the returned array must have the default integer data type.
292+
- if ``x`` has an unsigned integer data type (e.g., ``uint16``), the returned array must have an unsigned integer data type having the same number of bits as the default integer data type (e.g., if the default integer data type is ``int32``, the returned array must have a ``uint32`` data type).
312293
313-
.. note::
314-
keyword argument is intended to help prevent data type overflows.
294+
If the data type (either specified or resolved) differs from the data type of ``x``, the input array should be cast to the specified data type before computing the sum (rationale: the ``dtype`` keyword argument is intended to help prevent overflows). Default: ``None``.
315295
316296
keepdims: bool
317297
if ``True``, the reduced axes (dimensions) must be included in the result as singleton dimensions, and, accordingly, the result must be compatible with the input array (see :ref:`broadcasting`). Otherwise, if ``False``, the reduced axes (dimensions) must not be included in the result. Default: ``False``.

0 commit comments

Comments
 (0)