@@ -422,19 +422,23 @@ def stack(arrays, axis=0):
422
422
return res
423
423
424
424
425
- def can_cast (array_and_dtype_from , dtype_to , casting = "safe" ):
425
+ def can_cast (from_ , to , casting = "safe" ):
426
426
"""
427
427
can_cast(from: usm_ndarray or dtype, to: dtype) -> bool
428
428
429
429
Determines if one data type can be cast to another data type according \
430
430
Type Promotion Rules rules.
431
431
"""
432
- if not isinstance (dtype_to , dpt .dtype ):
432
+ if isinstance (to , dpt .usm_ndarray ):
433
433
raise TypeError ("Expected dtype type." )
434
434
435
- dtype_from = dpt .dtype (array_and_dtype_from )
435
+ dtype_to = dpt .dtype (to )
436
436
437
- _supported_dtype ([dtype_to , dtype_from ])
437
+ dtype_from = (
438
+ from_ .dtype if isinstance (from_ , dpt .usm_ndarray ) else dpt .dtype (from_ )
439
+ )
440
+
441
+ _supported_dtype ([dtype_from , dtype_to ])
438
442
439
443
return np .can_cast (dtype_from , dtype_to , casting )
440
444
@@ -447,7 +451,10 @@ def result_type(*arrays_and_dtypes):
447
451
Returns the dtype that results from applying the Type Promotion Rules to \
448
452
the arguments.
449
453
"""
450
- dtypes = [dpt .dtype (X ) for X in arrays_and_dtypes ]
454
+ dtypes = [
455
+ X .dtype if isinstance (X , dpt .usm_ndarray ) else dpt .dtype (X )
456
+ for X in arrays_and_dtypes
457
+ ]
451
458
452
459
_supported_dtype (dtypes )
453
460
@@ -460,6 +467,8 @@ def iinfo(type):
460
467
461
468
Returns machine limits for integer data types.
462
469
"""
470
+ if isinstance (type , dpt .usm_ndarray ):
471
+ raise TypeError ("Expected dtype type, get {to}." )
463
472
_supported_dtype ([dpt .dtype (type )])
464
473
return np .iinfo (type )
465
474
@@ -470,11 +479,14 @@ def finfo(type):
470
479
471
480
Returns machine limits for float data types.
472
481
"""
482
+ if isinstance (type , dpt .usm_ndarray ):
483
+ raise TypeError ("Expected dtype type, get {to}." )
473
484
_supported_dtype ([dpt .dtype (type )])
474
485
return np .finfo (type )
475
486
476
487
477
488
def _supported_dtype (dtypes ):
478
- if not all (dtype .char in "?bBhHiIlLqQefdFD" for dtype in dtypes ):
479
- raise ValueError ("Unsupported dtype encountered." )
489
+ for dtype in dtypes :
490
+ if dtype .char not in "?bBhHiIlLqQefdFD" :
491
+ raise ValueError (f"Dpctl doesn't support dtype { dtype } ." )
480
492
return True
0 commit comments