Skip to content

Commit 98110cb

Browse files
committed
Remove use of _UndefinedStub in dtype_helpers.py
1 parent 1a73804 commit 98110cb

File tree

3 files changed

+168
-121
lines changed

3 files changed

+168
-121
lines changed

array_api_tests/dtype_helpers.py

Lines changed: 140 additions & 105 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,23 @@
22
from collections import defaultdict
33
from collections.abc import Mapping
44
from functools import lru_cache
5-
from typing import Any, DefaultDict, NamedTuple, Sequence, Tuple, Union
5+
from typing import Any, DefaultDict, Dict, List, NamedTuple, Sequence, Tuple, Union
66
from warnings import warn
77

8-
from . import _array_module as xp
98
from . import api_version
10-
from ._array_module import _UndefinedStub
11-
from ._array_module import mod as _xp
9+
from ._array_module import mod as xp
1210
from .stubs import name_to_func
1311
from .typing import DataType, ScalarType
1412

1513
__all__ = [
14+
"uint_names",
15+
"int_names",
16+
"all_int_names",
17+
"float_names",
18+
"real_names",
19+
"complex_names",
20+
"numeric_names",
21+
"dtype_names",
1622
"int_dtypes",
1723
"uint_dtypes",
1824
"all_int_dtypes",
@@ -90,27 +96,43 @@ def __repr__(self):
9096
return f"EqualityMapping({self})"
9197

9298

93-
def _filter_stubs(*args):
94-
for a in args:
95-
if not isinstance(a, _UndefinedStub):
96-
yield a
99+
uint_names = ("uint8", "uint16", "uint32", "uint64")
100+
int_names = ("int8", "int16", "int32", "int64")
101+
all_int_names = uint_names + int_names
102+
float_names = ("float32", "float64")
103+
real_names = uint_names + int_names + float_names
104+
complex_names = ("complex64", "complex128")
105+
numeric_names = real_names + complex_names
106+
dtype_names = ("bool",) + numeric_names
97107

98108

99-
_uint_names = ("uint8", "uint16", "uint32", "uint64")
100-
_int_names = ("int8", "int16", "int32", "int64")
101-
_float_names = ("float32", "float64")
102-
_real_names = _uint_names + _int_names + _float_names
103-
_complex_names = ("complex64", "complex128")
104-
_numeric_names = _real_names + _complex_names
105-
_dtype_names = ("bool",) + _numeric_names
109+
_name_to_dtype = {}
110+
for name in dtype_names:
111+
try:
112+
dtype = getattr(xp, name)
113+
except AttributeError:
114+
continue
115+
_name_to_dtype[name] = dtype
116+
dtype_to_name = EqualityMapping([(d, n) for n, d in _name_to_dtype.items()])
106117

107118

108-
uint_dtypes = tuple(getattr(xp, name) for name in _uint_names)
109-
int_dtypes = tuple(getattr(xp, name) for name in _int_names)
110-
float_dtypes = tuple(getattr(xp, name) for name in _float_names)
119+
def _make_dtype_tuple_from_names(names: List[str]) -> Tuple[DataType]:
120+
dtypes = []
121+
for name in names:
122+
try:
123+
dtype = _name_to_dtype[name]
124+
except KeyError:
125+
continue
126+
dtypes.append(dtype)
127+
return tuple(dtypes)
128+
129+
130+
uint_dtypes = _make_dtype_tuple_from_names(uint_names)
131+
int_dtypes = _make_dtype_tuple_from_names(int_names)
132+
float_dtypes = _make_dtype_tuple_from_names(float_names)
111133
all_int_dtypes = uint_dtypes + int_dtypes
112134
real_dtypes = all_int_dtypes + float_dtypes
113-
complex_dtypes = tuple(getattr(xp, name) for name in _complex_names)
135+
complex_dtypes = _make_dtype_tuple_from_names(complex_names)
114136
numeric_dtypes = real_dtypes
115137
if api_version > "2021.12":
116138
numeric_dtypes += complex_dtypes
@@ -121,16 +143,6 @@ def _filter_stubs(*args):
121143
bool_and_all_int_dtypes = (xp.bool,) + all_int_dtypes
122144

123145

124-
_dtype_name_pairs = []
125-
for name in _dtype_names:
126-
try:
127-
dtype = getattr(_xp, name)
128-
except AttributeError:
129-
continue
130-
_dtype_name_pairs.append((dtype, name))
131-
dtype_to_name = EqualityMapping(_dtype_name_pairs)
132-
133-
134146
dtype_to_scalars = EqualityMapping(
135147
[
136148
(xp.bool, [bool]),
@@ -179,47 +191,59 @@ def get_scalar_type(dtype: DataType) -> ScalarType:
179191
return bool
180192

181193

194+
def _make_dtype_mapping_from_names(mapping: Dict[str, Any]) -> EqualityMapping:
195+
dtype_value_pairs = []
196+
for name, value in mapping.items():
197+
assert isinstance(name, str) and name in dtype_names # sanity check
198+
try:
199+
dtype = getattr(xp, name)
200+
except AttributeError:
201+
continue
202+
dtype_value_pairs.append((dtype, value))
203+
return EqualityMapping(dtype_value_pairs)
204+
205+
182206
class MinMax(NamedTuple):
183207
min: Union[int, float]
184208
max: Union[int, float]
185209

186210

187-
dtype_ranges = EqualityMapping(
188-
[
189-
(xp.int8, MinMax(-128, +127)),
190-
(xp.int16, MinMax(-32_768, +32_767)),
191-
(xp.int32, MinMax(-2_147_483_648, +2_147_483_647)),
192-
(xp.int64, MinMax(-9_223_372_036_854_775_808, +9_223_372_036_854_775_807)),
193-
(xp.uint8, MinMax(0, +255)),
194-
(xp.uint16, MinMax(0, +65_535)),
195-
(xp.uint32, MinMax(0, +4_294_967_295)),
196-
(xp.uint64, MinMax(0, +18_446_744_073_709_551_615)),
197-
(xp.float32, MinMax(-3.4028234663852886e38, 3.4028234663852886e38)),
198-
(xp.float64, MinMax(-1.7976931348623157e308, 1.7976931348623157e308)),
199-
]
211+
dtype_ranges = _make_dtype_mapping_from_names(
212+
{
213+
"int8": MinMax(-128, +127),
214+
"int16": MinMax(-32_768, +32_767),
215+
"int32": MinMax(-2_147_483_648, +2_147_483_647),
216+
"int64": MinMax(-9_223_372_036_854_775_808, +9_223_372_036_854_775_807),
217+
"uint8": MinMax(0, +255),
218+
"uint16": MinMax(0, +65_535),
219+
"uint32": MinMax(0, +4_294_967_295),
220+
"uint64": MinMax(0, +18_446_744_073_709_551_615),
221+
"float32": MinMax(-3.4028234663852886e38, 3.4028234663852886e38),
222+
"float64": MinMax(-1.7976931348623157e308, 1.7976931348623157e308),
223+
}
200224
)
201225

202226

203-
dtype_nbits = EqualityMapping(
204-
[(d, 8) for d in _filter_stubs(xp.int8, xp.uint8)]
205-
+ [(d, 16) for d in _filter_stubs(xp.int16, xp.uint16)]
206-
+ [(d, 32) for d in _filter_stubs(xp.int32, xp.uint32, xp.float32)]
207-
+ [(d, 64) for d in _filter_stubs(xp.int64, xp.uint64, xp.float64, xp.complex64)]
208-
+ [(d, 128) for d in _filter_stubs(xp.complex128)]
209-
)
227+
r_nbits = re.compile(r"[a-z]+([0-9]+)")
228+
_dtype_nbits: Dict[str, int] = {}
229+
for name in numeric_names:
230+
m = r_nbits.fullmatch(name)
231+
assert m is not None # sanity check / for mypy
232+
_dtype_nbits[name] = int(m.group(1))
233+
dtype_nbits = _make_dtype_mapping_from_names(_dtype_nbits)
210234

211235

212-
dtype_signed = EqualityMapping(
213-
[(d, True) for d in int_dtypes] + [(d, False) for d in uint_dtypes]
236+
dtype_signed = _make_dtype_mapping_from_names(
237+
{**{name: True for name in int_names}, **{name: False for name in uint_names}}
214238
)
215239

216240

217-
dtype_components = EqualityMapping(
218-
[(xp.complex64, xp.float32), (xp.complex128, xp.float64)]
241+
dtype_components = _make_dtype_mapping_from_names(
242+
{"complex64": xp.float32, "complex128": xp.float64}
219243
)
220244

221245

222-
if isinstance(xp.asarray, _UndefinedStub):
246+
if not hasattr(xp, "asarray"):
223247
default_int = xp.int32
224248
default_float = xp.float32
225249
warn(
@@ -243,60 +267,73 @@ class MinMax(NamedTuple):
243267
else:
244268
default_complex = None
245269
if dtype_nbits[default_int] == 32:
246-
default_uint = xp.uint32
270+
default_uint = getattr(xp, "uint32", None)
247271
else:
248-
default_uint = xp.uint64
249-
272+
default_uint = getattr(xp, "uint64", None)
250273

251-
_numeric_promotions = [
274+
_promotion_table: Dict[Tuple[str, str], str] = {
275+
("bool", "bool"): "bool",
252276
# ints
253-
((xp.int8, xp.int8), xp.int8),
254-
((xp.int8, xp.int16), xp.int16),
255-
((xp.int8, xp.int32), xp.int32),
256-
((xp.int8, xp.int64), xp.int64),
257-
((xp.int16, xp.int16), xp.int16),
258-
((xp.int16, xp.int32), xp.int32),
259-
((xp.int16, xp.int64), xp.int64),
260-
((xp.int32, xp.int32), xp.int32),
261-
((xp.int32, xp.int64), xp.int64),
262-
((xp.int64, xp.int64), xp.int64),
277+
("int8", "int8"): "int8",
278+
("int8", "int16"): "int16",
279+
("int8", "int32"): "int32",
280+
("int8", "int64"): "int64",
281+
("int16", "int16"): "int16",
282+
("int16", "int32"): "int32",
283+
("int16", "int64"): "int64",
284+
("int32", "int32"): "int32",
285+
("int32", "int64"): "int64",
286+
("int64", "int64"): "int64",
263287
# uints
264-
((xp.uint8, xp.uint8), xp.uint8),
265-
((xp.uint8, xp.uint16), xp.uint16),
266-
((xp.uint8, xp.uint32), xp.uint32),
267-
((xp.uint8, xp.uint64), xp.uint64),
268-
((xp.uint16, xp.uint16), xp.uint16),
269-
((xp.uint16, xp.uint32), xp.uint32),
270-
((xp.uint16, xp.uint64), xp.uint64),
271-
((xp.uint32, xp.uint32), xp.uint32),
272-
((xp.uint32, xp.uint64), xp.uint64),
273-
((xp.uint64, xp.uint64), xp.uint64),
288+
("uint8", "uint8"): "uint8",
289+
("uint8", "uint16"): "uint16",
290+
("uint8", "uint32"): "uint32",
291+
("uint8", "uint64"): "uint64",
292+
("uint16", "uint16"): "uint16",
293+
("uint16", "uint32"): "uint32",
294+
("uint16", "uint64"): "uint64",
295+
("uint32", "uint32"): "uint32",
296+
("uint32", "uint64"): "uint64",
297+
("uint64", "uint64"): "uint64",
274298
# ints and uints (mixed sign)
275-
((xp.int8, xp.uint8), xp.int16),
276-
((xp.int8, xp.uint16), xp.int32),
277-
((xp.int8, xp.uint32), xp.int64),
278-
((xp.int16, xp.uint8), xp.int16),
279-
((xp.int16, xp.uint16), xp.int32),
280-
((xp.int16, xp.uint32), xp.int64),
281-
((xp.int32, xp.uint8), xp.int32),
282-
((xp.int32, xp.uint16), xp.int32),
283-
((xp.int32, xp.uint32), xp.int64),
284-
((xp.int64, xp.uint8), xp.int64),
285-
((xp.int64, xp.uint16), xp.int64),
286-
((xp.int64, xp.uint32), xp.int64),
299+
("int8", "uint8"): "int16",
300+
("int8", "uint16"): "int32",
301+
("int8", "uint32"): "int64",
302+
("int16", "uint8"): "int16",
303+
("int16", "uint16"): "int32",
304+
("int16", "uint32"): "int64",
305+
("int32", "uint8"): "int32",
306+
("int32", "uint16"): "int32",
307+
("int32", "uint32"): "int64",
308+
("int64", "uint8"): "int64",
309+
("int64", "uint16"): "int64",
310+
("int64", "uint32"): "int64",
287311
# floats
288-
((xp.float32, xp.float32), xp.float32),
289-
((xp.float32, xp.float64), xp.float64),
290-
((xp.float64, xp.float64), xp.float64),
312+
("float32", "float32"): "float32",
313+
("float32", "float64"): "float64",
314+
("float64", "float64"): "float64",
291315
# complex
292-
((xp.complex64, xp.complex64), xp.complex64),
293-
((xp.complex64, xp.complex128), xp.complex128),
294-
((xp.complex128, xp.complex128), xp.complex128),
295-
]
296-
_numeric_promotions += [((d2, d1), res) for (d1, d2), res in _numeric_promotions]
297-
_promotion_table = list(set(_numeric_promotions))
298-
_promotion_table.insert(0, ((xp.bool, xp.bool), xp.bool))
299-
promotion_table = EqualityMapping(_promotion_table)
316+
("complex64", "complex64"): "complex64",
317+
("complex64", "complex128"): "complex128",
318+
("complex128", "complex128"): "complex128",
319+
}
320+
_promotion_table.update({(d2, d1): res for (d1, d2), res in _promotion_table.items()})
321+
_promotion_table_pairs: List[Tuple[Tuple[DataType, DataType], DataType]] = []
322+
for (in_name1, in_name2), res_name in _promotion_table.items():
323+
try:
324+
in_dtype1 = getattr(xp, in_name1)
325+
except AttributeError:
326+
continue
327+
try:
328+
in_dtype2 = getattr(xp, in_name2)
329+
except AttributeError:
330+
continue
331+
try:
332+
res_dtype = getattr(xp, res_name)
333+
except AttributeError:
334+
continue
335+
_promotion_table_pairs.append(((in_dtype1, in_dtype2), res_dtype))
336+
promotion_table = EqualityMapping(_promotion_table_pairs)
300337

301338

302339
def result_type(*dtypes: DataType):
@@ -325,6 +362,7 @@ def result_type(*dtypes: DataType):
325362
}
326363
func_in_dtypes: DefaultDict[str, Tuple[DataType, ...]] = defaultdict(lambda: all_dtypes)
327364
for name, func in name_to_func.items():
365+
assert func.__doc__ is not None # for mypy
328366
if m := r_in_dtypes.search(func.__doc__):
329367
dtype_category = m.group(1)
330368
if dtype_category == "numeric" and r_int_note.search(func.__doc__):
@@ -457,11 +495,10 @@ def result_type(*dtypes: DataType):
457495
}
458496

459497

498+
# Construct func_in_dtypes and func_returns bool
460499
for op, elwise_func in op_to_func.items():
461500
func_in_dtypes[op] = func_in_dtypes[elwise_func]
462501
func_returns_bool[op] = func_returns_bool[elwise_func]
463-
464-
465502
inplace_op_to_symbol = {}
466503
for op, symbol in binary_op_to_symbol.items():
467504
if op == "__matmul__" or func_returns_bool[op]:
@@ -470,8 +507,6 @@ def result_type(*dtypes: DataType):
470507
inplace_op_to_symbol[iop] = f"{symbol}="
471508
func_in_dtypes[iop] = func_in_dtypes[op]
472509
func_returns_bool[iop] = func_returns_bool[op]
473-
474-
475510
func_in_dtypes["__bool__"] = (xp.bool,)
476511
func_in_dtypes["__int__"] = all_int_dtypes
477512
func_in_dtypes["__index__"] = all_int_dtypes

array_api_tests/test_array_object.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from . import pytest_helpers as ph
1414
from . import shape_helpers as sh
1515
from . import xps
16+
from ._array_module import mod as _xp
1617
from .typing import DataType, Index, Param, Scalar, ScalarType, Shape
1718

1819
pytestmark = pytest.mark.ci
@@ -241,21 +242,27 @@ def test_setitem_masking(shape, data):
241242
)
242243

243244

244-
def make_param(method_name: str, dtype: DataType, stype: ScalarType) -> Param:
245+
def make_scalar_casting_param(
246+
method_name: str, dtype_name: DataType, stype: ScalarType
247+
) -> Param:
245248
return pytest.param(
246-
method_name, dtype, stype, id=f"{method_name}({dh.dtype_to_name[dtype]})"
249+
method_name, dtype_name, stype, id=f"{method_name}({dtype_name})"
247250
)
248251

249252

250253
@pytest.mark.parametrize(
251-
"method_name, dtype, stype",
252-
[make_param("__bool__", xp.bool, bool)]
253-
+ [make_param("__int__", d, int) for d in dh._filter_stubs(*dh.all_int_dtypes)]
254-
+ [make_param("__index__", d, int) for d in dh._filter_stubs(*dh.all_int_dtypes)]
255-
+ [make_param("__float__", d, float) for d in dh.float_dtypes],
254+
"method_name, dtype_name, stype",
255+
[make_scalar_casting_param("__bool__", "bool", bool)]
256+
+ [make_scalar_casting_param("__int__", n, int) for n in dh.all_int_names]
257+
+ [make_scalar_casting_param("__index__", n, int) for n in dh.all_int_names]
258+
+ [make_scalar_casting_param("__float__", n, float) for n in dh.float_names],
256259
)
257260
@given(data=st.data())
258-
def test_scalar_casting(method_name, dtype, stype, data):
261+
def test_scalar_casting(method_name, dtype_name, stype, data):
262+
try:
263+
dtype = getattr(_xp, dtype_name)
264+
except AttributeError as e:
265+
pytest.skip(str(e))
259266
x = data.draw(xps.arrays(dtype, shape=()), label="x")
260267
method = getattr(x, method_name)
261268
out = method()

0 commit comments

Comments
 (0)