Skip to content

Commit 1538b7d

Browse files
Fixes #998
This change fixes formation of TypeError message string. Also improved dtype argument validation.
1 parent 7690952 commit 1538b7d

File tree

2 files changed

+20
-11
lines changed

2 files changed

+20
-11
lines changed

dpctl/tensor/_types.pxi

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -102,28 +102,29 @@ cdef str _make_typestr(int typenum):
102102
return type_to_str[typenum] + str(type_bytesize(typenum))
103103

104104

105-
cdef int typenum_from_format(str s) except *:
105+
cdef int typenum_from_format(str s):
106106
"""
107107
Internal utility to convert string describing type format
108108
109109
Format is [<|=>][biufc]#
110110
Shortcuts for formats are i, u, d, D
111111
"""
112112
if not s:
113-
raise TypeError("Format string '" + s + "' cannot be empty.")
113+
return -1
114114
try:
115115
dt = np.dtype(s)
116-
except Exception as e:
117-
raise TypeError("Format '" + s + "' is not understood.") from e
116+
except Exception:
117+
return -1
118118
if (dt.byteorder == ">"):
119-
raise TypeError("Format '" + s + "' can only have native byteorder.")
119+
return -2
120120
return dt.num
121121

122+
122123
cdef int descr_to_typenum(object dtype):
123124
"Returns typenum for argumentd dtype that has attribute descr, assumed numpy.dtype"
124125
obj = getattr(dtype, 'descr')
125126
if (not isinstance(obj, list) or len(obj) != 1):
126-
return -1
127+
return -1 # token for ValueError
127128
obj = obj[0]
128129
if (not isinstance(obj, tuple) or len(obj) != 2 or obj[0]):
129130
return -1
@@ -143,9 +144,11 @@ cdef int dtype_to_typenum(dtype) except *:
143144
else:
144145
try:
145146
dt = np.dtype(dtype)
146-
if hasattr(dt, 'descr'):
147-
return descr_to_typenum(dt)
148-
else:
149-
return -1
147+
except TypeError:
148+
return -3
150149
except Exception:
151150
return -1
151+
if hasattr(dt, 'descr'):
152+
return descr_to_typenum(dt)
153+
else:
154+
return -3 # token for TypeError

dpctl/tensor/_usmarray.pyx

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -188,9 +188,15 @@ cdef class usm_ndarray:
188188
raise TypeError("Argument shape must be a list or a tuple.")
189189
nd = len(shape)
190190
typenum = dtype_to_typenum(dtype)
191+
if (typenum < 0):
192+
if typenum == -2:
193+
raise ValueError("Data type '" + str(dtype) + "' can only have native byteorder.")
194+
elif typenum == -1:
195+
raise ValueError("Data type '" + str(dtype) + "' is not understood.")
196+
raise TypeError(f"Expected string or a dtype object, got {type(dtype)}")
191197
itemsize = type_bytesize(typenum)
192198
if (itemsize < 1):
193-
raise TypeError("dtype=" + dtype + " is not supported.")
199+
raise TypeError("dtype=" + np.dtype(dtype).name + " is not supported.")
194200
# allocate host C-arrays for shape, strides
195201
err = _from_input_shape_strides(
196202
nd, shape, strides, itemsize, <char> ord(order),

0 commit comments

Comments
 (0)