Skip to content

Commit c8d469c

Browse files
Adding test_usm_ndarray_dlpack
Test should anticipate that dlpack roundtripping changes bool dtype to uint8 Adding test for `from_dlpack` input validation
1 parent 288ac71 commit c8d469c

File tree

1 file changed

+120
-0
lines changed

1 file changed

+120
-0
lines changed
Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
# Data Parallel Control (dpctl)
2+
#
3+
# Copyright 2020-2021 Intel Corporation
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
import ctypes
18+
19+
import pytest
20+
21+
import dpctl
22+
import dpctl.tensor as dpt
23+
24+
device_oneAPI = 14 # DLDeviceType.kDLOneAPI
25+
26+
_usm_types_list = ["shared", "device", "host"]
27+
28+
29+
@pytest.fixture(params=_usm_types_list)
30+
def usm_type(request):
31+
return request.param
32+
33+
34+
_typestrs_list = [
35+
"b1",
36+
"u1",
37+
"i1",
38+
"u2",
39+
"i2",
40+
"u4",
41+
"i4",
42+
"u8",
43+
"i8",
44+
"f2",
45+
"f4",
46+
"f8",
47+
"c8",
48+
"c16",
49+
]
50+
51+
52+
@pytest.fixture(params=_typestrs_list)
53+
def typestr(request):
54+
return request.param
55+
56+
57+
def test_dlpack_device(usm_type):
58+
all_root_devices = dpctl.get_devices()
59+
for sycl_dev in all_root_devices:
60+
X = dpt.empty((64,), dtype="u1", usm_type=usm_type, device=sycl_dev)
61+
dev = X.__dlpack_device__()
62+
assert type(dev) is tuple
63+
assert len(dev) == 2
64+
assert dev[0] == device_oneAPI
65+
assert sycl_dev == all_root_devices[dev[1]]
66+
67+
68+
def test_dlpack_exporter(typestr, usm_type):
69+
caps_fn = ctypes.pythonapi.PyCapsule_IsValid
70+
caps_fn.restype = bool
71+
caps_fn.argtypes = [ctypes.py_object, ctypes.c_char_p]
72+
all_root_devices = dpctl.get_devices()
73+
for sycl_dev in all_root_devices:
74+
X = dpt.empty((64,), dtype=typestr, usm_type=usm_type, device=sycl_dev)
75+
caps = X.__dlpack__()
76+
assert caps_fn(caps, b"dltensor")
77+
Y = X[::2]
78+
caps2 = Y.__dlpack__()
79+
assert caps_fn(caps2, b"dltensor")
80+
81+
82+
@pytest.mark.parametrize("shape", [tuple(), (2,), (3, 0, 1), (2, 2, 2)])
83+
def test_from_dlpack(shape, typestr, usm_type):
84+
all_root_devices = dpctl.get_devices()
85+
for sycl_dev in all_root_devices:
86+
X = dpt.empty(shape, dtype=typestr, usm_type=usm_type, device=sycl_dev)
87+
Y = dpt.from_dlpack(X)
88+
assert X.shape == Y.shape
89+
assert X.dtype == Y.dtype or (
90+
str(X.dtype) == "bool" and str(Y.dtype) == "uint8"
91+
)
92+
assert X.sycl_device == Y.sycl_device
93+
assert X.usm_type == Y.usm_type
94+
assert X._pointer == Y._pointer
95+
if Y.ndim:
96+
V = Y[::-1]
97+
W = dpt.from_dlpack(V)
98+
assert V.strides == W.strides
99+
100+
101+
def test_from_dlpack_input_validation():
102+
vstr = dpt._dlpack.get_build_dlpack_version()
103+
assert type(vstr) is str
104+
with pytest.raises(TypeError):
105+
dpt.from_dlpack(None)
106+
107+
class DummyWithProperty:
108+
@property
109+
def __dlpack__(self):
110+
return None
111+
112+
with pytest.raises(TypeError):
113+
dpt.from_dlpack(DummyWithProperty())
114+
115+
class DummyWithMethod:
116+
def __dlpack__(self):
117+
return None
118+
119+
with pytest.raises(TypeError):
120+
dpt.from_dlpack(DummyWithMethod())

0 commit comments

Comments
 (0)