Skip to content

Commit 4530026

Browse files
Allow shape to be an integer in the constructor, e.g. usm_ndarray(10, 'd')
``` In [1]: import dpctl.tensor as dpt In [2]: dpt.usm_ndarray(10, 'd').shape Out[2]: (10,) ```
1 parent 0ee8e5f commit 4530026

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

dpctl/tensor/_usmarray.pyx

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,12 +130,17 @@ cdef class usm_ndarray:
130130
cdef Py_ssize_t _offset = offset
131131
cdef Py_ssize_t ary_min_displacement = 0
132132
cdef Py_ssize_t ary_max_displacement = 0
133+
cdef Py_ssize_t tmp = 0
133134
cdef char * data_ptr = NULL
134135

135136
self._reset()
136137
if (not isinstance(shape, (list, tuple))
137138
and not hasattr(shape, 'tolist')):
138-
raise TypeError("Argument shape must be a list of a tuple.")
139+
try:
140+
tmp = <Py_ssize_t> shape
141+
shape = [shape, ]
142+
except Exception:
143+
raise TypeError("Argument shape must be a list or a tuple.")
139144
nd = len(shape)
140145
typenum = dtype_to_typenum(dtype)
141146
itemsize = type_bytesize(typenum)

0 commit comments

Comments
 (0)