Skip to content

Commit f55a730

Browse files
Merge pull request #556 from IntelPython/improve_usm_ndarray_coverage
Improved coverage of _types.pxi
2 parents 832350d + cd27e22 commit f55a730

File tree

2 files changed

+17
-1
lines changed

2 files changed

+17
-1
lines changed

dpctl/tensor/_usmarray.pyx

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,12 +130,17 @@ cdef class usm_ndarray:
130130
cdef Py_ssize_t _offset = offset
131131
cdef Py_ssize_t ary_min_displacement = 0
132132
cdef Py_ssize_t ary_max_displacement = 0
133+
cdef Py_ssize_t tmp = 0
133134
cdef char * data_ptr = NULL
134135

135136
self._reset()
136137
if (not isinstance(shape, (list, tuple))
137138
and not hasattr(shape, 'tolist')):
138-
raise TypeError("Argument shape must be a list of a tuple.")
139+
try:
140+
tmp = <Py_ssize_t> shape
141+
shape = [shape, ]
142+
except Exception:
143+
raise TypeError("Argument shape must be a list or a tuple.")
139144
nd = len(shape)
140145
typenum = dtype_to_typenum(dtype)
141146
itemsize = type_bytesize(typenum)

dpctl/tests/test_usm_ndarray_ctor.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
(4, 5),
3939
(2, 5, 2),
4040
(2, 2, 2, 2, 2, 2, 2, 2),
41+
5,
4142
],
4243
)
4344
@pytest.mark.parametrize("usm_type", ["shared", "host", "device"])
@@ -74,13 +75,23 @@ def test_allocate_usm_ndarray(shape, usm_type):
7475
"f8",
7576
"c8",
7677
"c16",
78+
b"float32",
7779
np.dtype("d"),
7880
np.half,
7981
],
8082
)
8183
def test_dtypes(dtype):
8284
Xusm = dpt.usm_ndarray((1,), dtype=dtype)
8385
assert Xusm.itemsize == np.dtype(dtype).itemsize
86+
expected_fmt = (np.dtype(dtype).str)[1:]
87+
actual_fmt = Xusm.__sycl_usm_array_interface__["typestr"][1:]
88+
assert expected_fmt == actual_fmt
89+
90+
91+
@pytest.mark.parametrize("dtype", ["", ">f4", "invalid", 123])
92+
def test_dtypes_invalid(dtype):
93+
with pytest.raises((TypeError, ValueError)):
94+
dpt.usm_ndarray((1,), dtype=dtype)
8495

8596

8697
def test_properties():

0 commit comments

Comments
 (0)