Skip to content

Commit dedab4e

Browse files
improved coverage of usm_ndarray.__cinit__
1 parent 63693e9 commit dedab4e

File tree

1 file changed

+18
-0
lines changed

1 file changed

+18
-0
lines changed

dpctl/tests/test_usm_ndarray_ctor.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -251,3 +251,21 @@ def test_slicing_basic():
251251
Xusm[:, -128]
252252
with pytest.raises(TypeError):
253253
Xusm[{1, 2, 3, 4, 5, 6, 7}]
254+
255+
256+
def test_ctor_invalid_shape():
257+
with pytest.raises(TypeError):
258+
dpt.usm_ndarray(dict())
259+
260+
261+
def test_ctor_buffer_kwarg():
262+
dpt.usm_ndarray(10, buffer=b"device")
263+
with pytest.raises(ValueError):
264+
dpt.usm_ndarray(10, buffer="invalid_param")
265+
Xusm = dpt.usm_ndarray((10, 5), dtype="c16")
266+
X2 = dpt.usm_ndarray(Xusm.shape, buffer=Xusm, dtype=Xusm.dtype)
267+
assert np.array_equal(
268+
Xusm.usm_data.copy_to_host(), X2.usm_data.copy_to_host()
269+
)
270+
with pytest.raises(ValueError):
271+
dpt.usm_ndarray(10, buffer=dict())

0 commit comments

Comments
 (0)