26
26
import dpctl .utils
27
27
from dpctl .tensor ._data_types import _get_dtype
28
28
from dpctl .tensor ._device import normalize_queue_device
29
+ from dpctl .tensor ._type_utils import _dtype_supported_by_device_impl
29
30
30
31
__doc__ = (
31
32
"Implementation module for copy- and cast- operations on "
@@ -121,7 +122,7 @@ def from_numpy(np_ary, device=None, usm_type="device", sycl_queue=None):
121
122
output array is created. Device can be specified by a
122
123
a filter selector string, an instance of
123
124
:class:`dpctl.SyclDevice`, an instance of
124
- :class:`dpctl.SyclQueue`, an instance of
125
+ :class:`dpctl.SyclQueue`, or an instance of
125
126
:class:`dpctl.tensor.Device`. If the value is `None`,
126
127
returned array is created on the default-selected device.
127
128
Default: `None`.
@@ -564,9 +565,11 @@ def copy(usm_ary, order="K"):
564
565
return R
565
566
566
567
567
- def astype (usm_ary , newdtype , order = "K" , casting = "unsafe" , copy = True ):
568
+ def astype (
569
+ usm_ary , newdtype , / , order = "K" , casting = "unsafe" , * , copy = True , device = None
570
+ ):
568
571
""" astype(array, new_dtype, order="K", casting="unsafe", \
569
- copy=True)
572
+ copy=True, device=None )
570
573
571
574
Returns a copy of the :class:`dpctl.tensor.usm_ndarray`, cast to a
572
575
specified type.
@@ -576,7 +579,8 @@ def astype(usm_ary, newdtype, order="K", casting="unsafe", copy=True):
576
579
An input array.
577
580
new_dtype (dtype):
578
581
The data type of the resulting array. If `None`, gives default
579
- floating point type supported by device where `array` is allocated.
582
+ floating point type supported by device where the resulting array
583
+ will be located.
580
584
order ({"C", "F", "A", "K"}, optional):
581
585
Controls memory layout of the resulting array if a copy
582
586
is returned.
@@ -587,6 +591,14 @@ def astype(usm_ary, newdtype, order="K", casting="unsafe", copy=True):
587
591
By default, `astype` always returns a newly allocated array.
588
592
If this keyword is set to `False`, a view of the input array
589
593
may be returned when possible.
594
+ device (object): array API specification of device where the
595
+ output array is created. Device can be specified by a
596
+ a filter selector string, an instance of
597
+ :class:`dpctl.SyclDevice`, an instance of
598
+ :class:`dpctl.SyclQueue`, or an instance of
599
+ :class:`dpctl.tensor.Device`. If the value is `None`,
600
+ returned array is created on the same device as `array`.
601
+ Default: `None`.
590
602
591
603
Returns:
592
604
usm_ndarray:
@@ -604,7 +616,25 @@ def astype(usm_ary, newdtype, order="K", casting="unsafe", copy=True):
604
616
)
605
617
order = order [0 ].upper ()
606
618
ary_dtype = usm_ary .dtype
607
- target_dtype = _get_dtype (newdtype , usm_ary .sycl_queue )
619
+ if device is not None :
620
+ if not isinstance (device , dpctl .SyclQueue ):
621
+ if isinstance (device , dpt .Device ):
622
+ device = device .sycl_queue
623
+ else :
624
+ device = dpt .Device .create_device (device ).sycl_queue
625
+ d = device .sycl_device
626
+ target_dtype = _get_dtype (newdtype , device )
627
+ if not _dtype_supported_by_device_impl (
628
+ target_dtype , d .has_aspect_fp16 , d .has_aspect_fp64
629
+ ):
630
+ raise ValueError (
631
+ f"Requested dtype `{ target_dtype } ` is not supported by the "
632
+ "target device"
633
+ )
634
+ usm_ary = usm_ary .to_device (device )
635
+ else :
636
+ target_dtype = _get_dtype (newdtype , usm_ary .sycl_queue )
637
+
608
638
if not dpt .can_cast (ary_dtype , target_dtype , casting = casting ):
609
639
raise TypeError (
610
640
f"Can not cast from { ary_dtype } to { newdtype } "
0 commit comments