Skip to content

Commit 1205b0b

Browse files
Merge pull request #1671 from IntelPython/factor-out-device-id-computation
2 parents 85f12d0 + 470bb7c commit 1205b0b

File tree

1 file changed

+53
-73
lines changed

1 file changed

+53
-73
lines changed

dpctl/tensor/_dlpack.pyx

Lines changed: 53 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,57 @@ cdef int get_parent_device_ordinal_id(c_dpctl.SyclDevice dev) except *:
202202
return dev.get_overall_ordinal()
203203

204204

205+
cdef int get_array_dlpack_device_id(
206+
usm_ndarray usm_ary
207+
) except *:
208+
"""Finds ordinal number of the parent of device where array
209+
was allocated.
210+
"""
211+
cdef c_dpctl.SyclQueue ary_sycl_queue
212+
cdef c_dpctl.SyclDevice ary_sycl_device
213+
cdef DPCTLSyclDeviceRef pDRef = NULL
214+
cdef DPCTLSyclDeviceRef tDRef = NULL
215+
cdef int device_id = -1
216+
217+
ary_sycl_queue = usm_ary.get_sycl_queue()
218+
ary_sycl_device = ary_sycl_queue.get_sycl_device()
219+
220+
default_context = _get_default_context(ary_sycl_device)
221+
if default_context is None:
222+
# check that ary_sycl_device is a non-partitioned device
223+
pDRef = DPCTLDevice_GetParentDevice(ary_sycl_device.get_device_ref())
224+
if pDRef is not NULL:
225+
DPCTLDevice_Delete(pDRef)
226+
raise DLPackCreationError(
227+
"to_dlpack_capsule: DLPack can only export arrays allocated "
228+
"on non-partitioned SYCL devices on platforms where "
229+
"default_context oneAPI extension is not supported."
230+
)
231+
else:
232+
if not usm_ary.sycl_context == default_context:
233+
raise DLPackCreationError(
234+
"to_dlpack_capsule: DLPack can only export arrays based on USM "
235+
"allocations bound to a default platform SYCL context"
236+
)
237+
# Find the unpartitioned parent of the allocation device
238+
pDRef = DPCTLDevice_GetParentDevice(ary_sycl_device.get_device_ref())
239+
if pDRef is not NULL:
240+
tDRef = DPCTLDevice_GetParentDevice(pDRef)
241+
while tDRef is not NULL:
242+
DPCTLDevice_Delete(pDRef)
243+
pDRef = tDRef
244+
tDRef = DPCTLDevice_GetParentDevice(pDRef)
245+
ary_sycl_device = c_dpctl.SyclDevice._create(pDRef)
246+
247+
device_id = ary_sycl_device.get_overall_ordinal()
248+
if device_id < 0:
249+
raise DLPackCreationError(
250+
"to_dlpack_capsule: failed to determine device_id"
251+
)
252+
253+
return device_id
254+
255+
205256
cpdef to_dlpack_capsule(usm_ndarray usm_ary):
206257
"""
207258
to_dlpack_capsule(usm_ary)
@@ -225,10 +276,6 @@ cpdef to_dlpack_capsule(usm_ndarray usm_ary):
225276
ValueError: when array elements data type could not be represented
226277
in ``DLManagedTensor``.
227278
"""
228-
cdef c_dpctl.SyclQueue ary_sycl_queue
229-
cdef c_dpctl.SyclDevice ary_sycl_device
230-
cdef DPCTLSyclDeviceRef pDRef = NULL
231-
cdef DPCTLSyclDeviceRef tDRef = NULL
232279
cdef DLManagedTensor *dlm_tensor = NULL
233280
cdef DLTensor *dl_tensor = NULL
234281
cdef int nd = usm_ary.get_ndim()
@@ -245,38 +292,9 @@ cpdef to_dlpack_capsule(usm_ndarray usm_ary):
245292
cdef Py_ssize_t si = 1
246293

247294
ary_base = usm_ary.get_base()
248-
ary_sycl_queue = usm_ary.get_sycl_queue()
249-
ary_sycl_device = ary_sycl_queue.get_sycl_device()
250295

251-
default_context = _get_default_context(ary_sycl_device)
252-
if default_context is None:
253-
# check that ary_sycl_device is a non-partitioned device
254-
pDRef = DPCTLDevice_GetParentDevice(ary_sycl_device.get_device_ref())
255-
if pDRef is not NULL:
256-
DPCTLDevice_Delete(pDRef)
257-
raise DLPackCreationError(
258-
"to_dlpack_capsule: DLPack can only export arrays allocated "
259-
"on non-partitioned SYCL devices on platforms where "
260-
"default_context oneAPI extension is not supported."
261-
)
262-
else:
263-
if not usm_ary.sycl_context == default_context:
264-
raise DLPackCreationError(
265-
"to_dlpack_capsule: DLPack can only export arrays based on USM "
266-
"allocations bound to a default platform SYCL context"
267-
)
268-
# Find the unpartitioned parent of the allocation device
269-
pDRef = DPCTLDevice_GetParentDevice(ary_sycl_device.get_device_ref())
270-
if pDRef is not NULL:
271-
tDRef = DPCTLDevice_GetParentDevice(pDRef)
272-
while tDRef is not NULL:
273-
DPCTLDevice_Delete(pDRef)
274-
pDRef = tDRef
275-
tDRef = DPCTLDevice_GetParentDevice(pDRef)
276-
ary_sycl_device = c_dpctl.SyclDevice._create(pDRef)
296+
device_id = get_array_dlpack_device_id(usm_ary)
277297

278-
# Find ordinal number of the parent device
279-
device_id = ary_sycl_device.get_overall_ordinal()
280298
if device_id < 0:
281299
raise DLPackCreationError(
282300
"to_dlpack_capsule: failed to determine device_id"
@@ -376,10 +394,6 @@ cpdef to_dlpack_versioned_capsule(usm_ndarray usm_ary, bint copied):
376394
ValueError: when array elements data type could not be represented
377395
in ``DLManagedTensorVersioned``.
378396
"""
379-
cdef c_dpctl.SyclQueue ary_sycl_queue
380-
cdef c_dpctl.SyclDevice ary_sycl_device
381-
cdef DPCTLSyclDeviceRef pDRef = NULL
382-
cdef DPCTLSyclDeviceRef tDRef = NULL
383397
cdef DLManagedTensorVersioned *dlmv_tensor = NULL
384398
cdef DLTensor *dl_tensor = NULL
385399
cdef uint32_t dlmv_flags = 0
@@ -397,43 +411,9 @@ cpdef to_dlpack_versioned_capsule(usm_ndarray usm_ary, bint copied):
397411
cdef Py_ssize_t si = 1
398412

399413
ary_base = usm_ary.get_base()
400-
ary_sycl_queue = usm_ary.get_sycl_queue()
401-
ary_sycl_device = ary_sycl_queue.get_sycl_device()
402-
403-
default_context = _get_default_context(ary_sycl_device)
404-
if default_context is None:
405-
# check that ary_sycl_device is a non-partitioned device
406-
pDRef = DPCTLDevice_GetParentDevice(ary_sycl_device.get_device_ref())
407-
if pDRef is not NULL:
408-
DPCTLDevice_Delete(pDRef)
409-
raise DLPackCreationError(
410-
"to_dlpack_versioned_capsule: DLPack can only export arrays "
411-
"allocated on non-partitioned SYCL devices on platforms "
412-
"where default_context oneAPI extension is not supported."
413-
)
414-
else:
415-
if not usm_ary.sycl_context == default_context:
416-
raise DLPackCreationError(
417-
"to_dlpack_versioned_capsule: DLPack can only export arrays "
418-
"based on USM allocations bound to a default platform SYCL "
419-
"context"
420-
)
421-
# Find the unpartitioned parent of the allocation device
422-
pDRef = DPCTLDevice_GetParentDevice(ary_sycl_device.get_device_ref())
423-
if pDRef is not NULL:
424-
tDRef = DPCTLDevice_GetParentDevice(pDRef)
425-
while tDRef is not NULL:
426-
DPCTLDevice_Delete(pDRef)
427-
pDRef = tDRef
428-
tDRef = DPCTLDevice_GetParentDevice(pDRef)
429-
ary_sycl_device = c_dpctl.SyclDevice._create(pDRef)
430414

431415
# Find ordinal number of the parent device
432-
device_id = ary_sycl_device.get_overall_ordinal()
433-
if device_id < 0:
434-
raise DLPackCreationError(
435-
"to_dlpack_versioned_capsule: failed to determine device_id"
436-
)
416+
device_id = get_array_dlpack_device_id(usm_ary)
437417

438418
dlmv_tensor = <DLManagedTensorVersioned *> stdlib.malloc(
439419
sizeof(DLManagedTensorVersioned))

0 commit comments

Comments
 (0)