Skip to content

Add more dlpack tests #1670

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
May 10, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 89 additions & 4 deletions dpctl/tests/test_usm_ndarray_dlpack.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import collections
import ctypes

import pytest
from helper import skip_if_dtype_not_supported

import dpctl
import dpctl.tensor as dpt
import dpctl.tensor._dlpack as _dlp

device_oneAPI = 14 # DLDeviceType.kDLOneAPI

Expand Down Expand Up @@ -57,7 +59,18 @@ def typestr(request):

@pytest.fixture
def all_root_devices():
return dpctl.get_devices()
"""
Caches root devices. For the sake of speed
of test suite execution, keep at most two
devices from each platform
"""
devs = dpctl.get_devices()
devs_per_platform = collections.defaultdict(list)
for dev in devs:
devs_per_platform[dev.sycl_platform].append(dev)

pruned = map(lambda li: li[:2], devs_per_platform.values())
return sum(pruned, start=[])


def test_dlpack_device(usm_type, all_root_devices):
Expand Down Expand Up @@ -216,9 +229,8 @@ def test_dlpack_from_subdevice():
except dpctl.SyclSubDeviceCreationError:
sdevs = None
try:
sdevs = (
dev.create_sub_devices(partition=[1, 1]) if sdevs is None else sdevs
)
if sdevs is None:
sdevs = dev.create_sub_devices(partition=[1, 1])
except dpctl.SyclSubDeviceCreationError:
pytest.skip("Default device can not be partitioned")
assert isinstance(sdevs, list) and len(sdevs) > 0
Expand All @@ -234,3 +246,76 @@ def test_dlpack_from_subdevice():
ar = dpt.arange(n, dtype=dpt.int32, sycl_queue=q)
ar2 = dpt.from_dlpack(ar)
assert ar2.sycl_device == sdevs[0]


def test_legacy_dlpack_capsule():
try:
x = dpt.arange(100, dtype="i4")
except dpctl.SyclDeviceCreationError:
pytest.skip("No default device available")

legacy_ver = (0, 8)

cap = x.__dlpack__(max_version=legacy_ver)
y = _dlp.from_dlpack_capsule(cap)
del cap
assert x._pointer == y._pointer

x2 = dpt.reshape(x, (10, 10)).mT
cap = x2.__dlpack__(max_version=legacy_ver)
y = _dlp.from_dlpack_capsule(cap)
del cap
assert x2._pointer == y._pointer
del x2

x2 = dpt.asarray(dpt.reshape(x, (10, 10)), order="F")
cap = x2.__dlpack__(max_version=legacy_ver)
y = _dlp.from_dlpack_capsule(cap)
del cap
assert x2._pointer == y._pointer

x3 = x[::-2]
cap = x3.__dlpack__(max_version=legacy_ver)
y = _dlp.from_dlpack_capsule(cap)
assert x3._pointer == y._pointer
del x3, y, x
del cap


def test_versioned_dlpack_capsule():
try:
x = dpt.arange(100, dtype="i4")
except dpctl.SyclDeviceCreationError:
pytest.skip("No default device available")

max_supported_ver = _dlp.get_build_dlpack_version()
cap = x.__dlpack__(max_version=max_supported_ver)
y = _dlp.from_dlpack_versioned_capsule(cap)
del cap
assert x._pointer == y._pointer

x2 = dpt.asarray(dpt.reshape(x, (10, 10)), order="F")
cap = x2.__dlpack__(max_version=max_supported_ver)
y = _dlp.from_dlpack_versioned_capsule(cap)
del cap
assert x2._pointer == y._pointer
del x2

x3 = x[::-2]
cap = x3.__dlpack__(max_version=max_supported_ver)
y = _dlp.from_dlpack_versioned_capsule(cap)
assert x3._pointer == y._pointer
del x3, y, x
del cap

# read-only array
x = dpt.arange(100, dtype="i4")
x.flags["W"] = False
cap = x.__dlpack__(max_version=max_supported_ver)
y = _dlp.from_dlpack_versioned_capsule(cap)
assert x._pointer == y._pointer

# read-only array, and copy
cap = x.__dlpack__(max_version=max_supported_ver, copy=True)
y = _dlp.from_dlpack_versioned_capsule(cap)
assert x._pointer != y._pointer