Skip to content

Commit 71fc06a

Browse files
Fixed dtype validation in _USMBufferData auxiliary class
np.number comprises np.integer and np.inexact, so replace issubdtype(dt.type, np.inexact) or issubdtype(dt.type, np.number) with more efficient issubdtype(dt.type, np.number). Also allowed dtype.type to be np.bool_, to accommodate dtype="|b1" Replaced single quotation marks with double quotation marks per linter's flavor.
1 parent 5637ff2 commit 71fc06a

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

dpctl/memory/_sycl_usm_array_interface_utils.pxi

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
cdef bint _valid_usm_ptr_and_context(DPCTLSyclUSMRef ptr, SyclContext ctx):
44
usm_type = _Memory.get_pointer_type(ptr, ctx)
5-
return usm_type in (b'shared', b'device', b'host')
5+
return usm_type in (b"shared", b"device", b"host")
66

77

88
cdef DPCTLSyclQueueRef _queue_ref_copy_from_SyclQueue(
@@ -49,7 +49,7 @@ cdef DPCTLSyclQueueRef get_queue_ref_from_ptr_and_syclobj(
4949
elif pycapsule.PyCapsule_IsValid(syclobj, "SyclContextRef"):
5050
ctx = <SyclContext>SyclContext(syclobj)
5151
return _queue_ref_copy_from_USMRef_and_SyclContext(ptr, ctx)
52-
elif hasattr(syclobj, '_get_capsule'):
52+
elif hasattr(syclobj, "_get_capsule"):
5353
cap = syclobj._get_capsule()
5454
if pycapsule.PyCapsule_IsValid(cap, "SyclQueueRef"):
5555
q = SyclQueue(cap)
@@ -166,8 +166,8 @@ cdef class _USMBufferData:
166166
nd = len(ary_shape)
167167
try:
168168
dt = np.dtype(ary_typestr)
169-
if (dt.hasobject or not (np.issubdtype(dt.type, np.integer) or
170-
np.issubdtype(dt.type, np.inexact))):
169+
if (dt.hasobject or not (np.issubdtype(dt.type, np.number) or
170+
dt.type is np.bool_)):
171171
DPCTLQueue_Delete(QRef)
172172
raise TypeError("Only integer types, floating and complex "
173173
"floating types are supported.")

0 commit comments

Comments
 (0)