Skip to content

Commit d62ab9e

Browse files
Added more checks to that dtype passed to constructor can be natively supported by device
1 parent d094443 commit d62ab9e

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

dpctl/tensor/_ctors.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,7 @@ def _asarray_from_usm_ndarray(
153153
if order == "K" and fc_contig:
154154
order = "C" if c_contig else "F"
155155
if order == "K":
156+
_ensure_native_dtype_device_support(dtype, copy_q.sycl_device)
156157
# new USM allocation
157158
res = dpt.usm_ndarray(
158159
usm_ndary.shape,
@@ -176,6 +177,7 @@ def _asarray_from_usm_ndarray(
176177
strides=new_strides,
177178
)
178179
else:
180+
_ensure_native_dtype_device_support(dtype, copy_q.sycl_device)
179181
res = dpt.usm_ndarray(
180182
usm_ndary.shape,
181183
dtype=dtype,
@@ -242,6 +244,7 @@ def _asarray_from_numpy_ndarray(
242244
order = "C" if c_contig else "F"
243245
if order == "K":
244246
# new USM allocation
247+
_ensure_native_dtype_device_support(dtype, copy_q.sycl_device)
245248
res = dpt.usm_ndarray(
246249
ary.shape,
247250
dtype=dtype,
@@ -261,6 +264,7 @@ def _asarray_from_numpy_ndarray(
261264
res.shape, dtype=res.dtype, buffer=res.usm_data, strides=new_strides
262265
)
263266
else:
267+
_ensure_native_dtype_device_support(dtype, copy_q.sycl_device)
264268
res = dpt.usm_ndarray(
265269
ary.shape,
266270
dtype=dtype,
@@ -870,6 +874,7 @@ def empty_like(
870874
sycl_queue = normalize_queue_device(sycl_queue=sycl_queue, device=device)
871875
sh = x.shape
872876
dtype = dpt.dtype(dtype)
877+
_ensure_native_dtype_device_support(dtype, sycl_queue.sycl_device)
873878
res = dpt.usm_ndarray(
874879
sh,
875880
dtype=dtype,
@@ -1202,6 +1207,7 @@ def eye(
12021207
dpctl.utils.validate_usm_type(usm_type, allow_none=False)
12031208
sycl_queue = normalize_queue_device(sycl_queue=sycl_queue, device=device)
12041209
dtype = _get_dtype(dtype, sycl_queue)
1210+
_ensure_native_dtype_device_support(dtype, sycl_queue.sycl_device)
12051211
res = dpt.usm_ndarray(
12061212
(n_rows, n_cols),
12071213
dtype=dtype,

0 commit comments

Comments
 (0)