Skip to content

Commit aba58c4

Browse files
Add more dlpack tests (#1670)
* Fixture all_root_devices is to keep only 2 devices from each platform This is done to speed-up test suite execution on multi-GPU system, like Aurora. * Add tests for legacy/versions capsule import-export * Expand tests more to improve coverage Use F-contiguous array to both legacy and versioned capsule. Add tests for read-only array and for use of copy=True keyword for __dlpack__ call
1 parent 9afae01 commit aba58c4

File tree

1 file changed

+89
-4
lines changed

1 file changed

+89
-4
lines changed

dpctl/tests/test_usm_ndarray_dlpack.py

Lines changed: 89 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,15 @@
1414
# See the License for the specific language governing permissions and
1515
# limitations under the License.
1616

17+
import collections
1718
import ctypes
1819

1920
import pytest
2021
from helper import skip_if_dtype_not_supported
2122

2223
import dpctl
2324
import dpctl.tensor as dpt
25+
import dpctl.tensor._dlpack as _dlp
2426

2527
device_oneAPI = 14 # DLDeviceType.kDLOneAPI
2628

@@ -57,7 +59,18 @@ def typestr(request):
5759

5860
@pytest.fixture
5961
def all_root_devices():
60-
return dpctl.get_devices()
62+
"""
63+
Caches root devices. For the sake of speed
64+
of test suite execution, keep at most two
65+
devices from each platform
66+
"""
67+
devs = dpctl.get_devices()
68+
devs_per_platform = collections.defaultdict(list)
69+
for dev in devs:
70+
devs_per_platform[dev.sycl_platform].append(dev)
71+
72+
pruned = map(lambda li: li[:2], devs_per_platform.values())
73+
return sum(pruned, start=[])
6174

6275

6376
def test_dlpack_device(usm_type, all_root_devices):
@@ -216,9 +229,8 @@ def test_dlpack_from_subdevice():
216229
except dpctl.SyclSubDeviceCreationError:
217230
sdevs = None
218231
try:
219-
sdevs = (
220-
dev.create_sub_devices(partition=[1, 1]) if sdevs is None else sdevs
221-
)
232+
if sdevs is None:
233+
sdevs = dev.create_sub_devices(partition=[1, 1])
222234
except dpctl.SyclSubDeviceCreationError:
223235
pytest.skip("Default device can not be partitioned")
224236
assert isinstance(sdevs, list) and len(sdevs) > 0
@@ -234,3 +246,76 @@ def test_dlpack_from_subdevice():
234246
ar = dpt.arange(n, dtype=dpt.int32, sycl_queue=q)
235247
ar2 = dpt.from_dlpack(ar)
236248
assert ar2.sycl_device == sdevs[0]
249+
250+
251+
def test_legacy_dlpack_capsule():
252+
try:
253+
x = dpt.arange(100, dtype="i4")
254+
except dpctl.SyclDeviceCreationError:
255+
pytest.skip("No default device available")
256+
257+
legacy_ver = (0, 8)
258+
259+
cap = x.__dlpack__(max_version=legacy_ver)
260+
y = _dlp.from_dlpack_capsule(cap)
261+
del cap
262+
assert x._pointer == y._pointer
263+
264+
x2 = dpt.reshape(x, (10, 10)).mT
265+
cap = x2.__dlpack__(max_version=legacy_ver)
266+
y = _dlp.from_dlpack_capsule(cap)
267+
del cap
268+
assert x2._pointer == y._pointer
269+
del x2
270+
271+
x2 = dpt.asarray(dpt.reshape(x, (10, 10)), order="F")
272+
cap = x2.__dlpack__(max_version=legacy_ver)
273+
y = _dlp.from_dlpack_capsule(cap)
274+
del cap
275+
assert x2._pointer == y._pointer
276+
277+
x3 = x[::-2]
278+
cap = x3.__dlpack__(max_version=legacy_ver)
279+
y = _dlp.from_dlpack_capsule(cap)
280+
assert x3._pointer == y._pointer
281+
del x3, y, x
282+
del cap
283+
284+
285+
def test_versioned_dlpack_capsule():
286+
try:
287+
x = dpt.arange(100, dtype="i4")
288+
except dpctl.SyclDeviceCreationError:
289+
pytest.skip("No default device available")
290+
291+
max_supported_ver = _dlp.get_build_dlpack_version()
292+
cap = x.__dlpack__(max_version=max_supported_ver)
293+
y = _dlp.from_dlpack_versioned_capsule(cap)
294+
del cap
295+
assert x._pointer == y._pointer
296+
297+
x2 = dpt.asarray(dpt.reshape(x, (10, 10)), order="F")
298+
cap = x2.__dlpack__(max_version=max_supported_ver)
299+
y = _dlp.from_dlpack_versioned_capsule(cap)
300+
del cap
301+
assert x2._pointer == y._pointer
302+
del x2
303+
304+
x3 = x[::-2]
305+
cap = x3.__dlpack__(max_version=max_supported_ver)
306+
y = _dlp.from_dlpack_versioned_capsule(cap)
307+
assert x3._pointer == y._pointer
308+
del x3, y, x
309+
del cap
310+
311+
# read-only array
312+
x = dpt.arange(100, dtype="i4")
313+
x.flags["W"] = False
314+
cap = x.__dlpack__(max_version=max_supported_ver)
315+
y = _dlp.from_dlpack_versioned_capsule(cap)
316+
assert x._pointer == y._pointer
317+
318+
# read-only array, and copy
319+
cap = x.__dlpack__(max_version=max_supported_ver, copy=True)
320+
y = _dlp.from_dlpack_versioned_capsule(cap)
321+
assert x._pointer != y._pointer

0 commit comments

Comments
 (0)