Skip to content

Commit ceb3961

Browse files
committed
Adds device keyword argument to astype
1 parent 8f82fe1 commit ceb3961

File tree

3 files changed

+50
-5
lines changed

3 files changed

+50
-5
lines changed

dpctl/tensor/_copy_utils.py

Lines changed: 35 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import dpctl.utils
2727
from dpctl.tensor._data_types import _get_dtype
2828
from dpctl.tensor._device import normalize_queue_device
29+
from dpctl.tensor._type_utils import _dtype_supported_by_device_impl
2930

3031
__doc__ = (
3132
"Implementation module for copy- and cast- operations on "
@@ -121,7 +122,7 @@ def from_numpy(np_ary, device=None, usm_type="device", sycl_queue=None):
121122
output array is created. Device can be specified by a
122123
a filter selector string, an instance of
123124
:class:`dpctl.SyclDevice`, an instance of
124-
:class:`dpctl.SyclQueue`, an instance of
125+
:class:`dpctl.SyclQueue`, or an instance of
125126
:class:`dpctl.tensor.Device`. If the value is `None`,
126127
returned array is created on the default-selected device.
127128
Default: `None`.
@@ -564,9 +565,11 @@ def copy(usm_ary, order="K"):
564565
return R
565566

566567

567-
def astype(usm_ary, newdtype, order="K", casting="unsafe", copy=True):
568+
def astype(
569+
usm_ary, newdtype, /, order="K", casting="unsafe", *, copy=True, device=None
570+
):
568571
""" astype(array, new_dtype, order="K", casting="unsafe", \
569-
copy=True)
572+
copy=True, device=None)
570573
571574
Returns a copy of the :class:`dpctl.tensor.usm_ndarray`, cast to a
572575
specified type.
@@ -576,7 +579,8 @@ def astype(usm_ary, newdtype, order="K", casting="unsafe", copy=True):
576579
An input array.
577580
new_dtype (dtype):
578581
The data type of the resulting array. If `None`, gives default
579-
floating point type supported by device where `array` is allocated.
582+
floating point type supported by device where the resulting array
583+
will be located.
580584
order ({"C", "F", "A", "K"}, optional):
581585
Controls memory layout of the resulting array if a copy
582586
is returned.
@@ -587,6 +591,14 @@ def astype(usm_ary, newdtype, order="K", casting="unsafe", copy=True):
587591
By default, `astype` always returns a newly allocated array.
588592
If this keyword is set to `False`, a view of the input array
589593
may be returned when possible.
594+
device (object): array API specification of device where the
595+
output array is created. Device can be specified by a
596+
a filter selector string, an instance of
597+
:class:`dpctl.SyclDevice`, an instance of
598+
:class:`dpctl.SyclQueue`, or an instance of
599+
:class:`dpctl.tensor.Device`. If the value is `None`,
600+
returned array is created on the same device as `array`.
601+
Default: `None`.
590602
591603
Returns:
592604
usm_ndarray:
@@ -604,7 +616,25 @@ def astype(usm_ary, newdtype, order="K", casting="unsafe", copy=True):
604616
)
605617
order = order[0].upper()
606618
ary_dtype = usm_ary.dtype
607-
target_dtype = _get_dtype(newdtype, usm_ary.sycl_queue)
619+
if device is not None:
620+
if not isinstance(device, dpctl.SyclQueue):
621+
if isinstance(device, dpt.Device):
622+
device = device.sycl_queue
623+
else:
624+
device = dpt.Device.create_device(device).sycl_queue
625+
d = device.sycl_device
626+
target_dtype = _get_dtype(newdtype, device)
627+
if not _dtype_supported_by_device_impl(
628+
target_dtype, d.has_aspect_fp16, d.has_aspect_fp64
629+
):
630+
raise ValueError(
631+
f"Requested dtype `{target_dtype}` is not supported by the "
632+
"target device"
633+
)
634+
usm_ary = usm_ary.to_device(device)
635+
else:
636+
target_dtype = _get_dtype(newdtype, usm_ary.sycl_queue)
637+
608638
if not dpt.can_cast(ary_dtype, target_dtype, casting=casting):
609639
raise TypeError(
610640
f"Can not cast from {ary_dtype} to {newdtype} "

dpctl/tensor/_type_utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1414
# See the License for the specific language governing permissions and
1515
# limitations under the License.
16+
from __future__ import annotations
1617

1718
import numpy as np
1819

dpctl/tests/test_usm_ndarray_ctor.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1313,6 +1313,20 @@ def test_astype_invalid_order():
13131313
dpt.astype(X, "i4", order="WRONG")
13141314

13151315

1316+
def test_astype_device():
1317+
get_queue_or_skip()
1318+
q1 = dpctl.SyclQueue()
1319+
q2 = dpctl.SyclQueue()
1320+
1321+
x = dpt.arange(5, dtype="i4", sycl_queue=q1)
1322+
r = dpt.astype(x, "f4")
1323+
assert r.sycl_queue == x.sycl_queue
1324+
assert r.sycl_device == x.sycl_device
1325+
1326+
r = dpt.astype(x, "f4", device=q2)
1327+
assert r.sycl_queue == q2
1328+
1329+
13161330
def test_copy():
13171331
try:
13181332
X = dpt.usm_ndarray((5, 5), "i4")[2:4, 1:4]

0 commit comments

Comments
 (0)