Skip to content

Commit aea6f7f

Browse files
Merge pull request #1338 from IntelPython/fix-dlpack-tests
Fixed tests from_dlpack and _from_dlpack_strided for USM-host allocations
2 parents 6070755 + f02862b commit aea6f7f

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

dpctl/tests/test_usm_ndarray_dlpack.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -127,9 +127,11 @@ def test_from_dlpack(shape, typestr, usm_type):
127127
Y = dpt.from_dlpack(X)
128128
assert X.shape == Y.shape
129129
assert X.dtype == Y.dtype
130-
assert X.sycl_device == Y.sycl_device
131130
assert X.usm_type == Y.usm_type
132131
assert X._pointer == Y._pointer
132+
# we can only expect device to round-trip for USM-device and
133+
# USM-shared allocations, which are made for specific device
134+
assert (Y.usm_type == "host") or (X.sycl_device == Y.sycl_device)
133135
if Y.ndim:
134136
V = Y[::-1]
135137
W = dpt.from_dlpack(V)
@@ -149,9 +151,11 @@ def test_from_dlpack_strides(mod, typestr, usm_type):
149151
Y = dpt.from_dlpack(X)
150152
assert X.shape == Y.shape
151153
assert X.dtype == Y.dtype
152-
assert X.sycl_device == Y.sycl_device
153154
assert X.usm_type == Y.usm_type
154155
assert X._pointer == Y._pointer
156+
# we can only expect device to round-trip for USM-device and
157+
# USM-shared allocations, which are made for specific device
158+
assert (Y.usm_type == "host") or (X.sycl_device == Y.sycl_device)
155159
if Y.ndim:
156160
V = Y[::-1]
157161
W = dpt.from_dlpack(V)

0 commit comments

Comments
 (0)