Skip to content

Commit 9d5a687

Browse files
Add test case for pre-existing buffer, and default dtype
1 parent ea40d71 commit 9d5a687

File tree

1 file changed

+7
-3
lines changed

1 file changed

+7
-3
lines changed

dpctl/tests/test_usm_ndarray_ctor.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -462,12 +462,16 @@ def test_ctor_buffer_kwarg():
462462
with pytest.raises(ValueError):
463463
dpt.usm_ndarray(10, buffer="invalid_param")
464464
Xusm = dpt.usm_ndarray((10, 5), dtype="c8")
465+
Xusm[...] = 1
465466
X2 = dpt.usm_ndarray(Xusm.shape, buffer=Xusm, dtype=Xusm.dtype)
466-
assert np.array_equal(
467-
Xusm.usm_data.copy_to_host(), X2.usm_data.copy_to_host()
468-
)
467+
Horig_copy = Xusm.usm_data.copy_to_host()
468+
H2_copy = X2.usm_data.copy_to_host()
469+
assert np.array_equal(Horig_copy, H2_copy)
469470
with pytest.raises(ValueError):
470471
dpt.usm_ndarray(10, dtype="i4", buffer=dict())
472+
# use device-specific default fp data type
473+
X3 = dpt.usm_ndarray(Xusm.shape, buffer=Xusm)
474+
assert np.array_equal(Horig_copy, X3.usm_data.copy_to_host())
471475

472476

473477
def test_usm_ndarray_props():

0 commit comments

Comments
 (0)