Skip to content

Commit d1d8d4e

Browse files
dtype keyword of usm_ndarray now supports np.double and other types (#526)
A test added to check for this.
1 parent e5c50ea commit d1d8d4e

File tree

2 files changed

+23
-11
lines changed

2 files changed

+23
-11
lines changed

dpctl/tensor/_types.pxi

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -119,22 +119,33 @@ cdef int typenum_from_format(str s) except *:
119119
raise TypeError("Format '" + s + "' can only have native byteorder.")
120120
return dt.num
121121

122+
cdef int descr_to_typenum(object dtype):
123+
"Returns typenum for argumentd dtype that has attribute descr, assumed numpy.dtype"
124+
obj = getattr(dtype, 'descr')
125+
if (not isinstance(obj, list) or len(obj) != 1):
126+
return -1
127+
obj = obj[0]
128+
if (not isinstance(obj, tuple) or len(obj) != 2 or obj[0]):
129+
return -1
130+
obj = obj[1]
131+
if not isinstance(obj, str):
132+
return -1
133+
return typenum_from_format(obj)
134+
122135

123136
cdef int dtype_to_typenum(dtype) except *:
124137
if isinstance(dtype, str):
125138
return typenum_from_format(dtype)
126139
elif isinstance(dtype, bytes):
127140
return typenum_from_format(dtype.decode("UTF-8"))
128141
elif hasattr(dtype, 'descr'):
129-
obj = getattr(dtype, 'descr')
130-
if (not isinstance(obj, list) or len(obj) != 1):
131-
return -1
132-
obj = obj[0]
133-
if (not isinstance(obj, tuple) or len(obj) != 2 or obj[0]):
134-
return -1
135-
obj = obj[1]
136-
if not isinstance(obj, str):
137-
return -1
138-
return typenum_from_format(obj)
142+
return descr_to_typenum(dtype)
139143
else:
140-
return -1
144+
try:
145+
dt = np.dtype(dtype)
146+
if hasattr(dt, 'descr'):
147+
return descr_to_typenum(dt)
148+
else:
149+
return -1
150+
except Exception:
151+
return -1

dpctl/tests/test_usm_ndarray_ctor.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ def test_allocate_usm_ndarray(shape, usm_type):
7474
"c8",
7575
"c16",
7676
np.dtype("d"),
77+
np.half,
7778
],
7879
)
7980
def test_dtypes(dtype):

0 commit comments

Comments
 (0)