14
14
# See the License for the specific language governing permissions and
15
15
# limitations under the License.
16
16
17
+ import collections
17
18
import ctypes
18
19
19
20
import pytest
20
21
from helper import skip_if_dtype_not_supported
21
22
22
23
import dpctl
23
24
import dpctl .tensor as dpt
25
+ import dpctl .tensor ._dlpack as _dlp
24
26
25
27
device_oneAPI = 14 # DLDeviceType.kDLOneAPI
26
28
@@ -57,7 +59,18 @@ def typestr(request):
57
59
58
60
@pytest .fixture
59
61
def all_root_devices ():
60
- return dpctl .get_devices ()
62
+ """
63
+ Caches root devices. For the sake of speed
64
+ of test suite execution, keep at most two
65
+ devices from each platform
66
+ """
67
+ devs = dpctl .get_devices ()
68
+ devs_per_platform = collections .defaultdict (list )
69
+ for dev in devs :
70
+ devs_per_platform [dev .sycl_platform ].append (dev )
71
+
72
+ pruned = map (lambda li : li [:2 ], devs_per_platform .values ())
73
+ return sum (pruned , start = [])
61
74
62
75
63
76
def test_dlpack_device (usm_type , all_root_devices ):
@@ -216,9 +229,8 @@ def test_dlpack_from_subdevice():
216
229
except dpctl .SyclSubDeviceCreationError :
217
230
sdevs = None
218
231
try :
219
- sdevs = (
220
- dev .create_sub_devices (partition = [1 , 1 ]) if sdevs is None else sdevs
221
- )
232
+ if sdevs is None :
233
+ sdevs = dev .create_sub_devices (partition = [1 , 1 ])
222
234
except dpctl .SyclSubDeviceCreationError :
223
235
pytest .skip ("Default device can not be partitioned" )
224
236
assert isinstance (sdevs , list ) and len (sdevs ) > 0
@@ -234,3 +246,76 @@ def test_dlpack_from_subdevice():
234
246
ar = dpt .arange (n , dtype = dpt .int32 , sycl_queue = q )
235
247
ar2 = dpt .from_dlpack (ar )
236
248
assert ar2 .sycl_device == sdevs [0 ]
249
+
250
+
251
+ def test_legacy_dlpack_capsule ():
252
+ try :
253
+ x = dpt .arange (100 , dtype = "i4" )
254
+ except dpctl .SyclDeviceCreationError :
255
+ pytest .skip ("No default device available" )
256
+
257
+ legacy_ver = (0 , 8 )
258
+
259
+ cap = x .__dlpack__ (max_version = legacy_ver )
260
+ y = _dlp .from_dlpack_capsule (cap )
261
+ del cap
262
+ assert x ._pointer == y ._pointer
263
+
264
+ x2 = dpt .reshape (x , (10 , 10 )).mT
265
+ cap = x2 .__dlpack__ (max_version = legacy_ver )
266
+ y = _dlp .from_dlpack_capsule (cap )
267
+ del cap
268
+ assert x2 ._pointer == y ._pointer
269
+ del x2
270
+
271
+ x2 = dpt .asarray (dpt .reshape (x , (10 , 10 )), order = "F" )
272
+ cap = x2 .__dlpack__ (max_version = legacy_ver )
273
+ y = _dlp .from_dlpack_capsule (cap )
274
+ del cap
275
+ assert x2 ._pointer == y ._pointer
276
+
277
+ x3 = x [::- 2 ]
278
+ cap = x3 .__dlpack__ (max_version = legacy_ver )
279
+ y = _dlp .from_dlpack_capsule (cap )
280
+ assert x3 ._pointer == y ._pointer
281
+ del x3 , y , x
282
+ del cap
283
+
284
+
285
+ def test_versioned_dlpack_capsule ():
286
+ try :
287
+ x = dpt .arange (100 , dtype = "i4" )
288
+ except dpctl .SyclDeviceCreationError :
289
+ pytest .skip ("No default device available" )
290
+
291
+ max_supported_ver = _dlp .get_build_dlpack_version ()
292
+ cap = x .__dlpack__ (max_version = max_supported_ver )
293
+ y = _dlp .from_dlpack_versioned_capsule (cap )
294
+ del cap
295
+ assert x ._pointer == y ._pointer
296
+
297
+ x2 = dpt .asarray (dpt .reshape (x , (10 , 10 )), order = "F" )
298
+ cap = x2 .__dlpack__ (max_version = max_supported_ver )
299
+ y = _dlp .from_dlpack_versioned_capsule (cap )
300
+ del cap
301
+ assert x2 ._pointer == y ._pointer
302
+ del x2
303
+
304
+ x3 = x [::- 2 ]
305
+ cap = x3 .__dlpack__ (max_version = max_supported_ver )
306
+ y = _dlp .from_dlpack_versioned_capsule (cap )
307
+ assert x3 ._pointer == y ._pointer
308
+ del x3 , y , x
309
+ del cap
310
+
311
+ # read-only array
312
+ x = dpt .arange (100 , dtype = "i4" )
313
+ x .flags ["W" ] = False
314
+ cap = x .__dlpack__ (max_version = max_supported_ver )
315
+ y = _dlp .from_dlpack_versioned_capsule (cap )
316
+ assert x ._pointer == y ._pointer
317
+
318
+ # read-only array, and copy
319
+ cap = x .__dlpack__ (max_version = max_supported_ver , copy = True )
320
+ y = _dlp .from_dlpack_versioned_capsule (cap )
321
+ assert x ._pointer != y ._pointer
0 commit comments