@@ -202,6 +202,57 @@ cdef int get_parent_device_ordinal_id(c_dpctl.SyclDevice dev) except *:
202
202
return dev.get_overall_ordinal()
203
203
204
204
205
+ cdef int get_array_dlpack_device_id(
206
+ usm_ndarray usm_ary
207
+ ) except * :
208
+ """ Finds ordinal number of the parent of device where array
209
+ was allocated.
210
+ """
211
+ cdef c_dpctl.SyclQueue ary_sycl_queue
212
+ cdef c_dpctl.SyclDevice ary_sycl_device
213
+ cdef DPCTLSyclDeviceRef pDRef = NULL
214
+ cdef DPCTLSyclDeviceRef tDRef = NULL
215
+ cdef int device_id = - 1
216
+
217
+ ary_sycl_queue = usm_ary.get_sycl_queue()
218
+ ary_sycl_device = ary_sycl_queue.get_sycl_device()
219
+
220
+ default_context = _get_default_context(ary_sycl_device)
221
+ if default_context is None :
222
+ # check that ary_sycl_device is a non-partitioned device
223
+ pDRef = DPCTLDevice_GetParentDevice(ary_sycl_device.get_device_ref())
224
+ if pDRef is not NULL :
225
+ DPCTLDevice_Delete(pDRef)
226
+ raise DLPackCreationError(
227
+ " to_dlpack_capsule: DLPack can only export arrays allocated "
228
+ " on non-partitioned SYCL devices on platforms where "
229
+ " default_context oneAPI extension is not supported."
230
+ )
231
+ else :
232
+ if not usm_ary.sycl_context == default_context:
233
+ raise DLPackCreationError(
234
+ " to_dlpack_capsule: DLPack can only export arrays based on USM "
235
+ " allocations bound to a default platform SYCL context"
236
+ )
237
+ # Find the unpartitioned parent of the allocation device
238
+ pDRef = DPCTLDevice_GetParentDevice(ary_sycl_device.get_device_ref())
239
+ if pDRef is not NULL :
240
+ tDRef = DPCTLDevice_GetParentDevice(pDRef)
241
+ while tDRef is not NULL :
242
+ DPCTLDevice_Delete(pDRef)
243
+ pDRef = tDRef
244
+ tDRef = DPCTLDevice_GetParentDevice(pDRef)
245
+ ary_sycl_device = c_dpctl.SyclDevice._create(pDRef)
246
+
247
+ device_id = ary_sycl_device.get_overall_ordinal()
248
+ if device_id < 0 :
249
+ raise DLPackCreationError(
250
+ " to_dlpack_capsule: failed to determine device_id"
251
+ )
252
+
253
+ return device_id
254
+
255
+
205
256
cpdef to_dlpack_capsule(usm_ndarray usm_ary):
206
257
"""
207
258
to_dlpack_capsule(usm_ary)
@@ -225,10 +276,6 @@ cpdef to_dlpack_capsule(usm_ndarray usm_ary):
225
276
ValueError: when array elements data type could not be represented
226
277
in ``DLManagedTensor``.
227
278
"""
228
- cdef c_dpctl.SyclQueue ary_sycl_queue
229
- cdef c_dpctl.SyclDevice ary_sycl_device
230
- cdef DPCTLSyclDeviceRef pDRef = NULL
231
- cdef DPCTLSyclDeviceRef tDRef = NULL
232
279
cdef DLManagedTensor * dlm_tensor = NULL
233
280
cdef DLTensor * dl_tensor = NULL
234
281
cdef int nd = usm_ary.get_ndim()
@@ -245,38 +292,9 @@ cpdef to_dlpack_capsule(usm_ndarray usm_ary):
245
292
cdef Py_ssize_t si = 1
246
293
247
294
ary_base = usm_ary.get_base()
248
- ary_sycl_queue = usm_ary.get_sycl_queue()
249
- ary_sycl_device = ary_sycl_queue.get_sycl_device()
250
295
251
- default_context = _get_default_context(ary_sycl_device)
252
- if default_context is None :
253
- # check that ary_sycl_device is a non-partitioned device
254
- pDRef = DPCTLDevice_GetParentDevice(ary_sycl_device.get_device_ref())
255
- if pDRef is not NULL :
256
- DPCTLDevice_Delete(pDRef)
257
- raise DLPackCreationError(
258
- " to_dlpack_capsule: DLPack can only export arrays allocated "
259
- " on non-partitioned SYCL devices on platforms where "
260
- " default_context oneAPI extension is not supported."
261
- )
262
- else :
263
- if not usm_ary.sycl_context == default_context:
264
- raise DLPackCreationError(
265
- " to_dlpack_capsule: DLPack can only export arrays based on USM "
266
- " allocations bound to a default platform SYCL context"
267
- )
268
- # Find the unpartitioned parent of the allocation device
269
- pDRef = DPCTLDevice_GetParentDevice(ary_sycl_device.get_device_ref())
270
- if pDRef is not NULL :
271
- tDRef = DPCTLDevice_GetParentDevice(pDRef)
272
- while tDRef is not NULL :
273
- DPCTLDevice_Delete(pDRef)
274
- pDRef = tDRef
275
- tDRef = DPCTLDevice_GetParentDevice(pDRef)
276
- ary_sycl_device = c_dpctl.SyclDevice._create(pDRef)
296
+ device_id = get_array_dlpack_device_id(usm_ary)
277
297
278
- # Find ordinal number of the parent device
279
- device_id = ary_sycl_device.get_overall_ordinal()
280
298
if device_id < 0 :
281
299
raise DLPackCreationError(
282
300
" to_dlpack_capsule: failed to determine device_id"
@@ -376,10 +394,6 @@ cpdef to_dlpack_versioned_capsule(usm_ndarray usm_ary, bint copied):
376
394
ValueError: when array elements data type could not be represented
377
395
in ``DLManagedTensorVersioned``.
378
396
"""
379
- cdef c_dpctl.SyclQueue ary_sycl_queue
380
- cdef c_dpctl.SyclDevice ary_sycl_device
381
- cdef DPCTLSyclDeviceRef pDRef = NULL
382
- cdef DPCTLSyclDeviceRef tDRef = NULL
383
397
cdef DLManagedTensorVersioned * dlmv_tensor = NULL
384
398
cdef DLTensor * dl_tensor = NULL
385
399
cdef uint32_t dlmv_flags = 0
@@ -397,43 +411,9 @@ cpdef to_dlpack_versioned_capsule(usm_ndarray usm_ary, bint copied):
397
411
cdef Py_ssize_t si = 1
398
412
399
413
ary_base = usm_ary.get_base()
400
- ary_sycl_queue = usm_ary.get_sycl_queue()
401
- ary_sycl_device = ary_sycl_queue.get_sycl_device()
402
-
403
- default_context = _get_default_context(ary_sycl_device)
404
- if default_context is None :
405
- # check that ary_sycl_device is a non-partitioned device
406
- pDRef = DPCTLDevice_GetParentDevice(ary_sycl_device.get_device_ref())
407
- if pDRef is not NULL :
408
- DPCTLDevice_Delete(pDRef)
409
- raise DLPackCreationError(
410
- " to_dlpack_versioned_capsule: DLPack can only export arrays "
411
- " allocated on non-partitioned SYCL devices on platforms "
412
- " where default_context oneAPI extension is not supported."
413
- )
414
- else :
415
- if not usm_ary.sycl_context == default_context:
416
- raise DLPackCreationError(
417
- " to_dlpack_versioned_capsule: DLPack can only export arrays "
418
- " based on USM allocations bound to a default platform SYCL "
419
- " context"
420
- )
421
- # Find the unpartitioned parent of the allocation device
422
- pDRef = DPCTLDevice_GetParentDevice(ary_sycl_device.get_device_ref())
423
- if pDRef is not NULL :
424
- tDRef = DPCTLDevice_GetParentDevice(pDRef)
425
- while tDRef is not NULL :
426
- DPCTLDevice_Delete(pDRef)
427
- pDRef = tDRef
428
- tDRef = DPCTLDevice_GetParentDevice(pDRef)
429
- ary_sycl_device = c_dpctl.SyclDevice._create(pDRef)
430
414
431
415
# Find ordinal number of the parent device
432
- device_id = ary_sycl_device.get_overall_ordinal()
433
- if device_id < 0 :
434
- raise DLPackCreationError(
435
- " to_dlpack_versioned_capsule: failed to determine device_id"
436
- )
416
+ device_id = get_array_dlpack_device_id(usm_ary)
437
417
438
418
dlmv_tensor = < DLManagedTensorVersioned * > stdlib.malloc(
439
419
sizeof(DLManagedTensorVersioned))
0 commit comments