Skip to content

Commit 6ca4bbb

Browse files
Merge pull request #1042 from IntelPython/fix-gh-1038-empty-zero-check-device-aspects
Fix gh 1038 empty zero check device aspects
2 parents a43326d + d62ab9e commit 6ca4bbb

File tree

4 files changed

+99
-36
lines changed

4 files changed

+99
-36
lines changed

dpctl/tensor/_ctors.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,7 @@ def _asarray_from_usm_ndarray(
153153
if order == "K" and fc_contig:
154154
order = "C" if c_contig else "F"
155155
if order == "K":
156+
_ensure_native_dtype_device_support(dtype, copy_q.sycl_device)
156157
# new USM allocation
157158
res = dpt.usm_ndarray(
158159
usm_ndary.shape,
@@ -176,6 +177,7 @@ def _asarray_from_usm_ndarray(
176177
strides=new_strides,
177178
)
178179
else:
180+
_ensure_native_dtype_device_support(dtype, copy_q.sycl_device)
179181
res = dpt.usm_ndarray(
180182
usm_ndary.shape,
181183
dtype=dtype,
@@ -242,6 +244,7 @@ def _asarray_from_numpy_ndarray(
242244
order = "C" if c_contig else "F"
243245
if order == "K":
244246
# new USM allocation
247+
_ensure_native_dtype_device_support(dtype, copy_q.sycl_device)
245248
res = dpt.usm_ndarray(
246249
ary.shape,
247250
dtype=dtype,
@@ -261,6 +264,7 @@ def _asarray_from_numpy_ndarray(
261264
res.shape, dtype=res.dtype, buffer=res.usm_data, strides=new_strides
262265
)
263266
else:
267+
_ensure_native_dtype_device_support(dtype, copy_q.sycl_device)
264268
res = dpt.usm_ndarray(
265269
ary.shape,
266270
dtype=dtype,
@@ -283,6 +287,35 @@ def _is_object_with_buffer_protocol(obj):
283287
return False
284288

285289

290+
def _ensure_native_dtype_device_support(dtype, dev) -> None:
291+
"""Check that dtype is natively supported by device.
292+
293+
Arg:
294+
dtype: elemental data-type
295+
dev: :class:`dpctl.SyclDevice`
296+
Return:
297+
None
298+
Raise:
299+
ValueError is device does not natively support this dtype.
300+
"""
301+
if dtype in [dpt.float64, dpt.complex128] and not dev.has_aspect_fp64:
302+
raise ValueError(
303+
f"Device {dev.name} does not provide native support "
304+
"for double-precision floating point type."
305+
)
306+
if (
307+
dtype
308+
in [
309+
dpt.float16,
310+
]
311+
and not dev.has_aspect_fp16
312+
):
313+
raise ValueError(
314+
f"Device {dev.name} does not provide native support "
315+
"for half-precision floating point type."
316+
)
317+
318+
286319
def asarray(
287320
obj,
288321
dtype=None,
@@ -474,6 +507,7 @@ def empty(
474507
dpctl.utils.validate_usm_type(usm_type, allow_none=False)
475508
sycl_queue = normalize_queue_device(sycl_queue=sycl_queue, device=device)
476509
dtype = _get_dtype(dtype, sycl_queue)
510+
_ensure_native_dtype_device_support(dtype, sycl_queue.sycl_device)
477511
res = dpt.usm_ndarray(
478512
sh,
479513
dtype=dtype,
@@ -651,6 +685,7 @@ def zeros(
651685
dpctl.utils.validate_usm_type(usm_type, allow_none=False)
652686
sycl_queue = normalize_queue_device(sycl_queue=sycl_queue, device=device)
653687
dtype = _get_dtype(dtype, sycl_queue)
688+
_ensure_native_dtype_device_support(dtype, sycl_queue.sycl_device)
654689
res = dpt.usm_ndarray(
655690
sh,
656691
dtype=dtype,
@@ -839,6 +874,7 @@ def empty_like(
839874
sycl_queue = normalize_queue_device(sycl_queue=sycl_queue, device=device)
840875
sh = x.shape
841876
dtype = dpt.dtype(dtype)
877+
_ensure_native_dtype_device_support(dtype, sycl_queue.sycl_device)
842878
res = dpt.usm_ndarray(
843879
sh,
844880
dtype=dtype,
@@ -1171,6 +1207,7 @@ def eye(
11711207
dpctl.utils.validate_usm_type(usm_type, allow_none=False)
11721208
sycl_queue = normalize_queue_device(sycl_queue=sycl_queue, device=device)
11731209
dtype = _get_dtype(dtype, sycl_queue)
1210+
_ensure_native_dtype_device_support(dtype, sycl_queue.sycl_device)
11741211
res = dpt.usm_ndarray(
11751212
(n_rows, n_cols),
11761213
dtype=dtype,

dpctl/tensor/_usmarray.pyx

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,8 @@ cdef class usm_ndarray:
178178
cdef Py_ssize_t _offset = offset
179179
cdef Py_ssize_t ary_min_displacement = 0
180180
cdef Py_ssize_t ary_max_displacement = 0
181+
cdef bint is_fp64 = False
182+
cdef bint is_fp16 = False
181183

182184
self._reset()
183185
if (not isinstance(shape, (list, tuple))
@@ -253,6 +255,16 @@ cdef class usm_ndarray:
253255
self._cleanup()
254256
raise ValueError(("buffer='{}' can not accomodate "
255257
"the requested array.").format(buffer))
258+
is_fp64 = (typenum == UAR_DOUBLE or typenum == UAR_CDOUBLE)
259+
is_fp16 = (typenum == UAR_HALF)
260+
if (is_fp64 or is_fp16):
261+
if ((is_fp64 and not _buffer.sycl_device.has_aspect_fp64) or
262+
(is_fp16 and not _buffer.sycl_device.has_aspect_fp16)
263+
):
264+
raise ValueError(
265+
f"Device {_buffer.sycl_device.name} does"
266+
f" not support {dtype} natively."
267+
)
256268
self.base_ = _buffer
257269
self.data_ = (<char *> (<size_t> _buffer._pointer)) + itemsize * _offset
258270
self.shape_ = shape_ptr

0 commit comments

Comments
 (0)