Skip to content

Commit cc02941

Browse files
Return default floating point type of the device in dpctl.tensor.astype when newdtype=None (#1262)
* For None returns default floating point type supported by device. * For all other types returns dpt.dtype(dtype) --------- Co-authored-by: Oleksandr Pavlyk <oleksandr.pavlyk@intel.com>
1 parent 36a7cd7 commit cc02941

File tree

2 files changed

+9
-2
lines changed

2 files changed

+9
-2
lines changed

dpctl/tensor/_copy_utils.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import dpctl.tensor as dpt
2424
import dpctl.tensor._tensor_impl as ti
2525
import dpctl.utils
26+
from dpctl.tensor._ctors import _get_dtype
2627
from dpctl.tensor._device import normalize_queue_device
2728

2829
__doc__ = (
@@ -364,7 +365,8 @@ def astype(usm_ary, newdtype, order="K", casting="unsafe", copy=True):
364365
array (usm_ndarray):
365366
An input array.
366367
new_dtype (dtype):
367-
The data type of the resulting array.
368+
The data type of the resulting array. If `None`, gives default
369+
floating point type supported by device where `array` is allocated.
368370
order ({"C", "F", "A", "K"}, optional):
369371
Controls memory layout of the resulting array if a copy
370372
is returned.
@@ -392,7 +394,7 @@ def astype(usm_ary, newdtype, order="K", casting="unsafe", copy=True):
392394
"Recognized values are 'A', 'C', 'F', or 'K'"
393395
)
394396
ary_dtype = usm_ary.dtype
395-
target_dtype = dpt.dtype(newdtype)
397+
target_dtype = _get_dtype(newdtype, usm_ary.sycl_queue)
396398
if not dpt.can_cast(ary_dtype, target_dtype, casting=casting):
397399
raise TypeError(
398400
f"Can not cast from {ary_dtype} to {newdtype} "

dpctl/tests/test_usm_ndarray_ctor.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1196,6 +1196,11 @@ def test_astype():
11961196
assert np.allclose(dpt.to_numpy(Y), np.full(Y.shape, 7, dtype="f4"))
11971197
Y = dpt.astype(X[::2, ::-1], "i4", order="K", copy=False)
11981198
assert Y.usm_data is X.usm_data
1199+
Y = dpt.astype(X, None, order="K")
1200+
if X.sycl_queue.sycl_device.has_aspect_fp64:
1201+
assert Y.dtype is dpt.float64
1202+
else:
1203+
assert Y.dtype is dpt.float32
11991204

12001205

12011206
def test_astype_invalid_order():

0 commit comments

Comments
 (0)