@@ -153,6 +153,7 @@ def _asarray_from_usm_ndarray(
153
153
if order == "K" and fc_contig :
154
154
order = "C" if c_contig else "F"
155
155
if order == "K" :
156
+ _ensure_native_dtype_device_support (dtype , copy_q .sycl_device )
156
157
# new USM allocation
157
158
res = dpt .usm_ndarray (
158
159
usm_ndary .shape ,
@@ -176,6 +177,7 @@ def _asarray_from_usm_ndarray(
176
177
strides = new_strides ,
177
178
)
178
179
else :
180
+ _ensure_native_dtype_device_support (dtype , copy_q .sycl_device )
179
181
res = dpt .usm_ndarray (
180
182
usm_ndary .shape ,
181
183
dtype = dtype ,
@@ -242,6 +244,7 @@ def _asarray_from_numpy_ndarray(
242
244
order = "C" if c_contig else "F"
243
245
if order == "K" :
244
246
# new USM allocation
247
+ _ensure_native_dtype_device_support (dtype , copy_q .sycl_device )
245
248
res = dpt .usm_ndarray (
246
249
ary .shape ,
247
250
dtype = dtype ,
@@ -261,6 +264,7 @@ def _asarray_from_numpy_ndarray(
261
264
res .shape , dtype = res .dtype , buffer = res .usm_data , strides = new_strides
262
265
)
263
266
else :
267
+ _ensure_native_dtype_device_support (dtype , copy_q .sycl_device )
264
268
res = dpt .usm_ndarray (
265
269
ary .shape ,
266
270
dtype = dtype ,
@@ -283,6 +287,35 @@ def _is_object_with_buffer_protocol(obj):
283
287
return False
284
288
285
289
290
+ def _ensure_native_dtype_device_support (dtype , dev ) -> None :
291
+ """Check that dtype is natively supported by device.
292
+
293
+ Arg:
294
+ dtype: elemental data-type
295
+ dev: :class:`dpctl.SyclDevice`
296
+ Return:
297
+ None
298
+ Raise:
299
+ ValueError is device does not natively support this dtype.
300
+ """
301
+ if dtype in [dpt .float64 , dpt .complex128 ] and not dev .has_aspect_fp64 :
302
+ raise ValueError (
303
+ f"Device { dev .name } does not provide native support "
304
+ "for double-precision floating point type."
305
+ )
306
+ if (
307
+ dtype
308
+ in [
309
+ dpt .float16 ,
310
+ ]
311
+ and not dev .has_aspect_fp16
312
+ ):
313
+ raise ValueError (
314
+ f"Device { dev .name } does not provide native support "
315
+ "for half-precision floating point type."
316
+ )
317
+
318
+
286
319
def asarray (
287
320
obj ,
288
321
dtype = None ,
@@ -474,6 +507,7 @@ def empty(
474
507
dpctl .utils .validate_usm_type (usm_type , allow_none = False )
475
508
sycl_queue = normalize_queue_device (sycl_queue = sycl_queue , device = device )
476
509
dtype = _get_dtype (dtype , sycl_queue )
510
+ _ensure_native_dtype_device_support (dtype , sycl_queue .sycl_device )
477
511
res = dpt .usm_ndarray (
478
512
sh ,
479
513
dtype = dtype ,
@@ -651,6 +685,7 @@ def zeros(
651
685
dpctl .utils .validate_usm_type (usm_type , allow_none = False )
652
686
sycl_queue = normalize_queue_device (sycl_queue = sycl_queue , device = device )
653
687
dtype = _get_dtype (dtype , sycl_queue )
688
+ _ensure_native_dtype_device_support (dtype , sycl_queue .sycl_device )
654
689
res = dpt .usm_ndarray (
655
690
sh ,
656
691
dtype = dtype ,
@@ -839,6 +874,7 @@ def empty_like(
839
874
sycl_queue = normalize_queue_device (sycl_queue = sycl_queue , device = device )
840
875
sh = x .shape
841
876
dtype = dpt .dtype (dtype )
877
+ _ensure_native_dtype_device_support (dtype , sycl_queue .sycl_device )
842
878
res = dpt .usm_ndarray (
843
879
sh ,
844
880
dtype = dtype ,
@@ -1171,6 +1207,7 @@ def eye(
1171
1207
dpctl .utils .validate_usm_type (usm_type , allow_none = False )
1172
1208
sycl_queue = normalize_queue_device (sycl_queue = sycl_queue , device = device )
1173
1209
dtype = _get_dtype (dtype , sycl_queue )
1210
+ _ensure_native_dtype_device_support (dtype , sycl_queue .sycl_device )
1174
1211
res = dpt .usm_ndarray (
1175
1212
(n_rows , n_cols ),
1176
1213
dtype = dtype ,
0 commit comments