2
2
from collections import defaultdict
3
3
from collections .abc import Mapping
4
4
from functools import lru_cache
5
- from typing import Any , DefaultDict , NamedTuple , Sequence , Tuple , Union
5
+ from typing import Any , DefaultDict , Dict , List , NamedTuple , Sequence , Tuple , Union
6
6
from warnings import warn
7
7
8
- from . import _array_module as xp
9
8
from . import api_version
10
- from ._array_module import _UndefinedStub
11
- from ._array_module import mod as _xp
9
+ from ._array_module import mod as xp
12
10
from .stubs import name_to_func
13
11
from .typing import DataType , ScalarType
14
12
15
13
__all__ = [
14
+ "uint_names" ,
15
+ "int_names" ,
16
+ "all_int_names" ,
17
+ "float_names" ,
18
+ "real_names" ,
19
+ "complex_names" ,
20
+ "numeric_names" ,
21
+ "dtype_names" ,
16
22
"int_dtypes" ,
17
23
"uint_dtypes" ,
18
24
"all_int_dtypes" ,
@@ -90,27 +96,43 @@ def __repr__(self):
90
96
return f"EqualityMapping({ self } )"
91
97
92
98
93
- def _filter_stubs (* args ):
94
- for a in args :
95
- if not isinstance (a , _UndefinedStub ):
96
- yield a
99
+ uint_names = ("uint8" , "uint16" , "uint32" , "uint64" )
100
+ int_names = ("int8" , "int16" , "int32" , "int64" )
101
+ all_int_names = uint_names + int_names
102
+ float_names = ("float32" , "float64" )
103
+ real_names = uint_names + int_names + float_names
104
+ complex_names = ("complex64" , "complex128" )
105
+ numeric_names = real_names + complex_names
106
+ dtype_names = ("bool" ,) + numeric_names
97
107
98
108
99
- _uint_names = ("uint8" , "uint16" , "uint32" , "uint64" )
100
- _int_names = ("int8" , "int16" , "int32" , "int64" )
101
- _float_names = ("float32" , "float64" )
102
- _real_names = _uint_names + _int_names + _float_names
103
- _complex_names = ("complex64" , "complex128" )
104
- _numeric_names = _real_names + _complex_names
105
- _dtype_names = ("bool" ,) + _numeric_names
109
+ _name_to_dtype = {}
110
+ for name in dtype_names :
111
+ try :
112
+ dtype = getattr (xp , name )
113
+ except AttributeError :
114
+ continue
115
+ _name_to_dtype [name ] = dtype
116
+ dtype_to_name = EqualityMapping ([(d , n ) for n , d in _name_to_dtype .items ()])
106
117
107
118
108
- uint_dtypes = tuple (getattr (xp , name ) for name in _uint_names )
109
- int_dtypes = tuple (getattr (xp , name ) for name in _int_names )
110
- float_dtypes = tuple (getattr (xp , name ) for name in _float_names )
119
+ def _make_dtype_tuple_from_names (names : List [str ]) -> Tuple [DataType ]:
120
+ dtypes = []
121
+ for name in names :
122
+ try :
123
+ dtype = _name_to_dtype [name ]
124
+ except KeyError :
125
+ continue
126
+ dtypes .append (dtype )
127
+ return tuple (dtypes )
128
+
129
+
130
+ uint_dtypes = _make_dtype_tuple_from_names (uint_names )
131
+ int_dtypes = _make_dtype_tuple_from_names (int_names )
132
+ float_dtypes = _make_dtype_tuple_from_names (float_names )
111
133
all_int_dtypes = uint_dtypes + int_dtypes
112
134
real_dtypes = all_int_dtypes + float_dtypes
113
- complex_dtypes = tuple ( getattr ( xp , name ) for name in _complex_names )
135
+ complex_dtypes = _make_dtype_tuple_from_names ( complex_names )
114
136
numeric_dtypes = real_dtypes
115
137
if api_version > "2021.12" :
116
138
numeric_dtypes += complex_dtypes
@@ -121,16 +143,6 @@ def _filter_stubs(*args):
121
143
bool_and_all_int_dtypes = (xp .bool ,) + all_int_dtypes
122
144
123
145
124
- _dtype_name_pairs = []
125
- for name in _dtype_names :
126
- try :
127
- dtype = getattr (_xp , name )
128
- except AttributeError :
129
- continue
130
- _dtype_name_pairs .append ((dtype , name ))
131
- dtype_to_name = EqualityMapping (_dtype_name_pairs )
132
-
133
-
134
146
dtype_to_scalars = EqualityMapping (
135
147
[
136
148
(xp .bool , [bool ]),
@@ -179,47 +191,59 @@ def get_scalar_type(dtype: DataType) -> ScalarType:
179
191
return bool
180
192
181
193
194
+ def _make_dtype_mapping_from_names (mapping : Dict [str , Any ]) -> EqualityMapping :
195
+ dtype_value_pairs = []
196
+ for name , value in mapping .items ():
197
+ assert isinstance (name , str ) and name in dtype_names # sanity check
198
+ try :
199
+ dtype = getattr (xp , name )
200
+ except AttributeError :
201
+ continue
202
+ dtype_value_pairs .append ((dtype , value ))
203
+ return EqualityMapping (dtype_value_pairs )
204
+
205
+
182
206
class MinMax (NamedTuple ):
183
207
min : Union [int , float ]
184
208
max : Union [int , float ]
185
209
186
210
187
- dtype_ranges = EqualityMapping (
188
- [
189
- ( xp . int8 , MinMax (- 128 , + 127 ) ),
190
- ( xp . int16 , MinMax (- 32_768 , + 32_767 ) ),
191
- ( xp . int32 , MinMax (- 2_147_483_648 , + 2_147_483_647 ) ),
192
- ( xp . int64 , MinMax (- 9_223_372_036_854_775_808 , + 9_223_372_036_854_775_807 ) ),
193
- ( xp . uint8 , MinMax (0 , + 255 ) ),
194
- ( xp . uint16 , MinMax (0 , + 65_535 ) ),
195
- ( xp . uint32 , MinMax (0 , + 4_294_967_295 ) ),
196
- ( xp . uint64 , MinMax (0 , + 18_446_744_073_709_551_615 ) ),
197
- ( xp . float32 , MinMax (- 3.4028234663852886e38 , 3.4028234663852886e38 ) ),
198
- ( xp . float64 , MinMax (- 1.7976931348623157e308 , 1.7976931348623157e308 ) ),
199
- ]
211
+ dtype_ranges = _make_dtype_mapping_from_names (
212
+ {
213
+ " int8" : MinMax (- 128 , + 127 ),
214
+ " int16" : MinMax (- 32_768 , + 32_767 ),
215
+ " int32" : MinMax (- 2_147_483_648 , + 2_147_483_647 ),
216
+ " int64" : MinMax (- 9_223_372_036_854_775_808 , + 9_223_372_036_854_775_807 ),
217
+ " uint8" : MinMax (0 , + 255 ),
218
+ " uint16" : MinMax (0 , + 65_535 ),
219
+ " uint32" : MinMax (0 , + 4_294_967_295 ),
220
+ " uint64" : MinMax (0 , + 18_446_744_073_709_551_615 ),
221
+ " float32" : MinMax (- 3.4028234663852886e38 , 3.4028234663852886e38 ),
222
+ " float64" : MinMax (- 1.7976931348623157e308 , 1.7976931348623157e308 ),
223
+ }
200
224
)
201
225
202
226
203
- dtype_nbits = EqualityMapping (
204
- [( d , 8 ) for d in _filter_stubs ( xp . int8 , xp . uint8 )]
205
- + [( d , 16 ) for d in _filter_stubs ( xp . int16 , xp . uint16 )]
206
- + [( d , 32 ) for d in _filter_stubs ( xp . int32 , xp . uint32 , xp . float32 )]
207
- + [( d , 64 ) for d in _filter_stubs ( xp . int64 , xp . uint64 , xp . float64 , xp . complex64 )]
208
- + [( d , 128 ) for d in _filter_stubs ( xp . complex128 )]
209
- )
227
+ r_nbits = re . compile ( r"[a-z]+([0-9]+)" )
228
+ _dtype_nbits : Dict [ str , int ] = {}
229
+ for name in numeric_names :
230
+ m = r_nbits . fullmatch ( name )
231
+ assert m is not None # sanity check / for mypy
232
+ _dtype_nbits [ name ] = int ( m . group ( 1 ))
233
+ dtype_nbits = _make_dtype_mapping_from_names ( _dtype_nbits )
210
234
211
235
212
- dtype_signed = EqualityMapping (
213
- [( d , True ) for d in int_dtypes ] + [( d , False ) for d in uint_dtypes ]
236
+ dtype_signed = _make_dtype_mapping_from_names (
237
+ { ** { name : True for name in int_names }, ** { name : False for name in uint_names }}
214
238
)
215
239
216
240
217
- dtype_components = EqualityMapping (
218
- [( xp . complex64 , xp .float32 ), ( xp . complex128 , xp .float64 )]
241
+ dtype_components = _make_dtype_mapping_from_names (
242
+ { " complex64" : xp .float32 , " complex128" : xp .float64 }
219
243
)
220
244
221
245
222
- if isinstance (xp . asarray , _UndefinedStub ):
246
+ if not hasattr (xp , "asarray" ):
223
247
default_int = xp .int32
224
248
default_float = xp .float32
225
249
warn (
@@ -243,60 +267,73 @@ class MinMax(NamedTuple):
243
267
else :
244
268
default_complex = None
245
269
if dtype_nbits [default_int ] == 32 :
246
- default_uint = xp . uint32
270
+ default_uint = getattr ( xp , " uint32" , None )
247
271
else :
248
- default_uint = xp .uint64
249
-
272
+ default_uint = getattr (xp , "uint64" , None )
250
273
251
- _numeric_promotions = [
274
+ _promotion_table : Dict [Tuple [str , str ], str ] = {
275
+ ("bool" , "bool" ): "bool" ,
252
276
# ints
253
- (( xp . int8 , xp . int8 ), xp . int8 ) ,
254
- (( xp . int8 , xp . int16 ), xp . int16 ) ,
255
- (( xp . int8 , xp . int32 ), xp . int32 ) ,
256
- (( xp . int8 , xp . int64 ), xp . int64 ) ,
257
- (( xp . int16 , xp . int16 ), xp . int16 ) ,
258
- (( xp . int16 , xp . int32 ), xp . int32 ) ,
259
- (( xp . int16 , xp . int64 ), xp . int64 ) ,
260
- (( xp . int32 , xp . int32 ), xp . int32 ) ,
261
- (( xp . int32 , xp . int64 ), xp . int64 ) ,
262
- (( xp . int64 , xp . int64 ), xp . int64 ) ,
277
+ (" int8" , " int8" ): " int8" ,
278
+ (" int8" , " int16" ): " int16" ,
279
+ (" int8" , " int32" ): " int32" ,
280
+ (" int8" , " int64" ): " int64" ,
281
+ (" int16" , " int16" ): " int16" ,
282
+ (" int16" , " int32" ): " int32" ,
283
+ (" int16" , " int64" ): " int64" ,
284
+ (" int32" , " int32" ): " int32" ,
285
+ (" int32" , " int64" ): " int64" ,
286
+ (" int64" , " int64" ): " int64" ,
263
287
# uints
264
- (( xp . uint8 , xp . uint8 ), xp . uint8 ) ,
265
- (( xp . uint8 , xp . uint16 ), xp . uint16 ) ,
266
- (( xp . uint8 , xp . uint32 ), xp . uint32 ) ,
267
- (( xp . uint8 , xp . uint64 ), xp . uint64 ) ,
268
- (( xp . uint16 , xp . uint16 ), xp . uint16 ) ,
269
- (( xp . uint16 , xp . uint32 ), xp . uint32 ) ,
270
- (( xp . uint16 , xp . uint64 ), xp . uint64 ) ,
271
- (( xp . uint32 , xp . uint32 ), xp . uint32 ) ,
272
- (( xp . uint32 , xp . uint64 ), xp . uint64 ) ,
273
- (( xp . uint64 , xp . uint64 ), xp . uint64 ) ,
288
+ (" uint8" , " uint8" ): " uint8" ,
289
+ (" uint8" , " uint16" ): " uint16" ,
290
+ (" uint8" , " uint32" ): " uint32" ,
291
+ (" uint8" , " uint64" ): " uint64" ,
292
+ (" uint16" , " uint16" ): " uint16" ,
293
+ (" uint16" , " uint32" ): " uint32" ,
294
+ (" uint16" , " uint64" ): " uint64" ,
295
+ (" uint32" , " uint32" ): " uint32" ,
296
+ (" uint32" , " uint64" ): " uint64" ,
297
+ (" uint64" , " uint64" ): " uint64" ,
274
298
# ints and uints (mixed sign)
275
- (( xp . int8 , xp . uint8 ), xp . int16 ) ,
276
- (( xp . int8 , xp . uint16 ), xp . int32 ) ,
277
- (( xp . int8 , xp . uint32 ), xp . int64 ) ,
278
- (( xp . int16 , xp . uint8 ), xp . int16 ) ,
279
- (( xp . int16 , xp . uint16 ), xp . int32 ) ,
280
- (( xp . int16 , xp . uint32 ), xp . int64 ) ,
281
- (( xp . int32 , xp . uint8 ), xp . int32 ) ,
282
- (( xp . int32 , xp . uint16 ), xp . int32 ) ,
283
- (( xp . int32 , xp . uint32 ), xp . int64 ) ,
284
- (( xp . int64 , xp . uint8 ), xp . int64 ) ,
285
- (( xp . int64 , xp . uint16 ), xp . int64 ) ,
286
- (( xp . int64 , xp . uint32 ), xp . int64 ) ,
299
+ (" int8" , " uint8" ): " int16" ,
300
+ (" int8" , " uint16" ): " int32" ,
301
+ (" int8" , " uint32" ): " int64" ,
302
+ (" int16" , " uint8" ): " int16" ,
303
+ (" int16" , " uint16" ): " int32" ,
304
+ (" int16" , " uint32" ): " int64" ,
305
+ (" int32" , " uint8" ): " int32" ,
306
+ (" int32" , " uint16" ): " int32" ,
307
+ (" int32" , " uint32" ): " int64" ,
308
+ (" int64" , " uint8" ): " int64" ,
309
+ (" int64" , " uint16" ): " int64" ,
310
+ (" int64" , " uint32" ): " int64" ,
287
311
# floats
288
- (( xp . float32 , xp . float32 ), xp . float32 ) ,
289
- (( xp . float32 , xp . float64 ), xp . float64 ) ,
290
- (( xp . float64 , xp . float64 ), xp . float64 ) ,
312
+ (" float32" , " float32" ): " float32" ,
313
+ (" float32" , " float64" ): " float64" ,
314
+ (" float64" , " float64" ): " float64" ,
291
315
# complex
292
- ((xp .complex64 , xp .complex64 ), xp .complex64 ),
293
- ((xp .complex64 , xp .complex128 ), xp .complex128 ),
294
- ((xp .complex128 , xp .complex128 ), xp .complex128 ),
295
- ]
296
- _numeric_promotions += [((d2 , d1 ), res ) for (d1 , d2 ), res in _numeric_promotions ]
297
- _promotion_table = list (set (_numeric_promotions ))
298
- _promotion_table .insert (0 , ((xp .bool , xp .bool ), xp .bool ))
299
- promotion_table = EqualityMapping (_promotion_table )
316
+ ("complex64" , "complex64" ): "complex64" ,
317
+ ("complex64" , "complex128" ): "complex128" ,
318
+ ("complex128" , "complex128" ): "complex128" ,
319
+ }
320
+ _promotion_table .update ({(d2 , d1 ): res for (d1 , d2 ), res in _promotion_table .items ()})
321
+ _promotion_table_pairs : List [Tuple [Tuple [DataType , DataType ], DataType ]] = []
322
+ for (in_name1 , in_name2 ), res_name in _promotion_table .items ():
323
+ try :
324
+ in_dtype1 = getattr (xp , in_name1 )
325
+ except AttributeError :
326
+ continue
327
+ try :
328
+ in_dtype2 = getattr (xp , in_name2 )
329
+ except AttributeError :
330
+ continue
331
+ try :
332
+ res_dtype = getattr (xp , res_name )
333
+ except AttributeError :
334
+ continue
335
+ _promotion_table_pairs .append (((in_dtype1 , in_dtype2 ), res_dtype ))
336
+ promotion_table = EqualityMapping (_promotion_table_pairs )
300
337
301
338
302
339
def result_type (* dtypes : DataType ):
@@ -325,6 +362,7 @@ def result_type(*dtypes: DataType):
325
362
}
326
363
func_in_dtypes : DefaultDict [str , Tuple [DataType , ...]] = defaultdict (lambda : all_dtypes )
327
364
for name , func in name_to_func .items ():
365
+ assert func .__doc__ is not None # for mypy
328
366
if m := r_in_dtypes .search (func .__doc__ ):
329
367
dtype_category = m .group (1 )
330
368
if dtype_category == "numeric" and r_int_note .search (func .__doc__ ):
@@ -457,11 +495,10 @@ def result_type(*dtypes: DataType):
457
495
}
458
496
459
497
498
+ # Construct func_in_dtypes and func_returns bool
460
499
for op , elwise_func in op_to_func .items ():
461
500
func_in_dtypes [op ] = func_in_dtypes [elwise_func ]
462
501
func_returns_bool [op ] = func_returns_bool [elwise_func ]
463
-
464
-
465
502
inplace_op_to_symbol = {}
466
503
for op , symbol in binary_op_to_symbol .items ():
467
504
if op == "__matmul__" or func_returns_bool [op ]:
@@ -470,8 +507,6 @@ def result_type(*dtypes: DataType):
470
507
inplace_op_to_symbol [iop ] = f"{ symbol } ="
471
508
func_in_dtypes [iop ] = func_in_dtypes [op ]
472
509
func_returns_bool [iop ] = func_returns_bool [op ]
473
-
474
-
475
510
func_in_dtypes ["__bool__" ] = (xp .bool ,)
476
511
func_in_dtypes ["__int__" ] = all_int_dtypes
477
512
func_in_dtypes ["__index__" ] = all_int_dtypes
0 commit comments