@@ -162,9 +162,18 @@ def get(self):
162
162
return self .o_
163
163
164
164
165
- class WeakInexactType :
166
- """Python type representing type of Python real- or
167
- complex-valued floating point objects"""
165
+ class WeakFloatingType :
166
+ """Python type representing type of Python floating point objects"""
167
+
168
+ def __init__ (self , o ):
169
+ self .o_ = o
170
+
171
+ def get (self ):
172
+ return self .o_
173
+
174
+
175
+ class WeakComplexType :
176
+ """Python type representing type of Python complex floating point objects"""
168
177
169
178
def __init__ (self , o ):
170
179
self .o_ = o
@@ -189,14 +198,17 @@ def _get_dtype(o, dev):
189
198
return WeakBooleanType (o )
190
199
if isinstance (o , int ):
191
200
return WeakIntegralType (o )
192
- if isinstance (o , (float , complex )):
193
- return WeakInexactType (o )
201
+ if isinstance (o , float ):
202
+ return WeakFloatingType (o )
203
+ if isinstance (o , complex ):
204
+ return WeakComplexType (o )
194
205
return np .object_
195
206
196
207
197
208
def _validate_dtype (dt ) -> bool :
198
209
return isinstance (
199
- dt , (WeakBooleanType , WeakInexactType , WeakIntegralType )
210
+ dt ,
211
+ (WeakBooleanType , WeakIntegralType , WeakFloatingType , WeakComplexType ),
200
212
) or (
201
213
isinstance (dt , dpt .dtype )
202
214
and dt
@@ -220,22 +232,24 @@ def _validate_dtype(dt) -> bool:
220
232
221
233
222
234
def _weak_type_num_kind (o ):
223
- _map = {"?" : 0 , "i" : 1 , "f" : 2 }
235
+ _map = {"?" : 0 , "i" : 1 , "f" : 2 , "c" : 3 }
224
236
if isinstance (o , WeakBooleanType ):
225
237
return _map ["?" ]
226
238
if isinstance (o , WeakIntegralType ):
227
239
return _map ["i" ]
228
- if isinstance (o , WeakInexactType ):
240
+ if isinstance (o , WeakFloatingType ):
229
241
return _map ["f" ]
242
+ if isinstance (o , WeakComplexType ):
243
+ return _map ["c" ]
230
244
raise TypeError (
231
245
f"Unexpected type { o } while expecting "
232
- "`WeakBooleanType`, `WeakIntegralType`, or "
233
- "`WeakInexactType `."
246
+ "`WeakBooleanType`, `WeakIntegralType`,"
247
+ "`WeakFloatingType`, or `WeakComplexType `."
234
248
)
235
249
236
250
237
251
def _strong_dtype_num_kind (o ):
238
- _map = {"b" : 0 , "i" : 1 , "u" : 1 , "f" : 2 , "c" : 2 }
252
+ _map = {"b" : 0 , "i" : 1 , "u" : 1 , "f" : 2 , "c" : 3 }
239
253
if not isinstance (o , dpt .dtype ):
240
254
raise TypeError
241
255
k = o .kind
@@ -247,20 +261,29 @@ def _strong_dtype_num_kind(o):
247
261
def _resolve_weak_types (o1_dtype , o2_dtype , dev ):
248
262
"Resolves weak data type per NEP-0050"
249
263
if isinstance (
250
- o1_dtype , (WeakBooleanType , WeakInexactType , WeakIntegralType )
264
+ o1_dtype ,
265
+ (WeakBooleanType , WeakIntegralType , WeakFloatingType , WeakComplexType ),
251
266
):
252
267
if isinstance (
253
- o2_dtype , (WeakBooleanType , WeakInexactType , WeakIntegralType )
268
+ o2_dtype ,
269
+ (
270
+ WeakBooleanType ,
271
+ WeakIntegralType ,
272
+ WeakFloatingType ,
273
+ WeakComplexType ,
274
+ ),
254
275
):
255
276
raise ValueError
256
277
o1_kind_num = _weak_type_num_kind (o1_dtype )
257
278
o2_kind_num = _strong_dtype_num_kind (o2_dtype )
258
- if o1_kind_num > o2_kind_num or o1_kind_num == 2 :
279
+ if o1_kind_num > o2_kind_num :
259
280
if isinstance (o1_dtype , WeakBooleanType ):
260
281
return dpt .bool , o2_dtype
261
282
if isinstance (o1_dtype , WeakIntegralType ):
262
283
return dpt .int64 , o2_dtype
263
- if isinstance (o1_dtype .get (), complex ):
284
+ if isinstance (o1_dtype , WeakComplexType ):
285
+ if o2_dtype is dpt .float16 or o2_dtype is dpt .float32 :
286
+ return dpt .complex64 , o2_dtype
264
287
return (
265
288
_to_device_supported_dtype (dpt .complex128 , dev ),
266
289
o2_dtype ,
@@ -269,16 +292,19 @@ def _resolve_weak_types(o1_dtype, o2_dtype, dev):
269
292
else :
270
293
return o2_dtype , o2_dtype
271
294
elif isinstance (
272
- o2_dtype , (WeakBooleanType , WeakInexactType , WeakIntegralType )
295
+ o2_dtype ,
296
+ (WeakBooleanType , WeakIntegralType , WeakFloatingType , WeakComplexType ),
273
297
):
274
298
o1_kind_num = _strong_dtype_num_kind (o1_dtype )
275
299
o2_kind_num = _weak_type_num_kind (o2_dtype )
276
- if o2_kind_num > o1_kind_num or o2_kind_num == 2 :
300
+ if o2_kind_num > o1_kind_num :
277
301
if isinstance (o2_dtype , WeakBooleanType ):
278
302
return o1_dtype , dpt .bool
279
303
if isinstance (o2_dtype , WeakIntegralType ):
280
304
return o1_dtype , dpt .int64
281
- if isinstance (o2_dtype .get (), complex ):
305
+ if isinstance (o2_dtype , WeakComplexType ):
306
+ if o1_dtype is dpt .float16 or o1_dtype is dpt .float32 :
307
+ return o1_dtype , dpt .complex64
282
308
return o1_dtype , _to_device_supported_dtype (dpt .complex128 , dev )
283
309
return (
284
310
o1_dtype ,
0 commit comments