Skip to content

Commit 0fa30ee

Browse files
Added support for consuming DLPack allocated on a sub-device
The USM allocation must be bound to the default context. When producing DLPack, device_id is populated with the id of the ancestor root device.
1 parent 2e2e35d commit 0fa30ee

File tree

1 file changed

+38
-28
lines changed

1 file changed

+38
-28
lines changed

dpctl/tensor/_dlpack.pyx

Lines changed: 38 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,7 @@ cpdef to_dlpack_capsule(usm_ndarray usm_ary) except+:
140140
cdef c_dpctl.SyclQueue ary_sycl_queue
141141
cdef c_dpctl.SyclDevice ary_sycl_device
142142
cdef DPCTLSyclDeviceRef pDRef = NULL
143+
cdef DPCTLSyclDeviceRef tDRef = NULL
143144
cdef DLManagedTensor *dlm_tensor = NULL
144145
cdef DLTensor *dl_tensor = NULL
145146
cdef int nd = usm_ary.get_ndim()
@@ -157,20 +158,28 @@ cpdef to_dlpack_capsule(usm_ndarray usm_ary) except+:
157158
ary_sycl_queue = usm_ary.get_sycl_queue()
158159
ary_sycl_device = ary_sycl_queue.get_sycl_device()
159160

160-
# check that ary_sycl_device is a non-partitioned device
161-
pDRef = DPCTLDevice_GetParentDevice(ary_sycl_device.get_device_ref())
162-
if pDRef is not NULL:
163-
DPCTLDevice_Delete(pDRef)
164-
raise DLPackCreationError(
165-
"to_dlpack_capsule: DLPack can only export arrays allocated on "
166-
"non-partitioned SYCL devices."
167-
)
168-
default_context = dpctl.SyclQueue(ary_sycl_device).sycl_context
161+
default_context = ary_sycl_device.sycl_platform.default_context
169162
if not usm_ary.sycl_context == default_context:
170163
raise DLPackCreationError(
171164
"to_dlpack_capsule: DLPack can only export arrays based on USM "
172165
"allocations bound to a default platform SYCL context"
173166
)
167+
# Find the unpartitioned parent of the allocation device
168+
pDRef = DPCTLDevice_GetParentDevice(ary_sycl_device.get_device_ref())
169+
if pDRef is not NULL:
170+
tDRef = DPCTLDevice_GetParentDevice(pDRef)
171+
while tDRef is not NULL:
172+
DPCTLDevice_Delete(pDRef)
173+
pDRef = tDRef
174+
tDRef = DPCTLDevice_GetParentDevice(pDRef)
175+
ary_sycl_device = c_dpctl.SyclDevice._create(pDRef)
176+
177+
# Find ordinal number of the parent device
178+
device_id = ary_sycl_device.get_overall_ordinal()
179+
if device_id < 0:
180+
raise DLPackCreationError(
181+
"to_dlpack_capsule: failed to determine device_id"
182+
)
174183

175184
dlm_tensor = <DLManagedTensor *> stdlib.malloc(
176185
sizeof(DLManagedTensor))
@@ -192,14 +201,6 @@ cpdef to_dlpack_capsule(usm_ndarray usm_ary) except+:
192201
for i in range(nd):
193202
shape_strides_ptr[nd + i] = strides_ptr[i]
194203

195-
device_id = ary_sycl_device.get_overall_ordinal()
196-
if device_id < 0:
197-
stdlib.free(shape_strides_ptr)
198-
stdlib.free(dlm_tensor)
199-
raise DLPackCreationError(
200-
"to_dlpack_capsule: failed to determine device_id"
201-
)
202-
203204
ary_dt = usm_ary.dtype
204205
ary_dtk = ary_dt.kind
205206
element_offset = usm_ary.get_offset()
@@ -278,15 +279,16 @@ cpdef usm_ndarray from_dlpack_capsule(object py_caps) except +:
278279
success.
279280
Raises:
280281
TypeError: if argument is not a "dltensor" capsule.
281-
ValueError: if argument is "used_dltensor" capsule,
282-
if the USM pointer is not bound to the reconstructed
282+
ValueError: if argument is "used_dltensor" capsule
283+
BufferError: if the USM pointer is not bound to the reconstructed
283284
sycl context, or the DLPack's device_type is not supported
284285
by dpctl.
285286
"""
286287
cdef DLManagedTensor *dlm_tensor = NULL
287288
cdef bytes usm_type
288289
cdef size_t sz = 1
289290
cdef int i
291+
cdef int device_id = -1
290292
cdef int element_bytesize = 0
291293
cdef Py_ssize_t offset_min = 0
292294
cdef Py_ssize_t offset_max = 0
@@ -308,26 +310,34 @@ cpdef usm_ndarray from_dlpack_capsule(object py_caps) except +:
308310
py_caps, "dltensor")
309311
# Verify that we can work with this device
310312
if dlm_tensor.dl_tensor.device.device_type == kDLOneAPI:
311-
q = dpctl.SyclQueue(str(<int>dlm_tensor.dl_tensor.device.device_id))
313+
device_id = dlm_tensor.dl_tensor.device.device_id
314+
root_device = dpctl.SyclDevice(str(<int>device_id))
315+
default_context = root_device.sycl_platform.default_context
312316
if dlm_tensor.dl_tensor.data is NULL:
313317
usm_type = b"device"
318+
q = dpctl.SyclQueue(default_context, root_device)
314319
else:
315320
usm_type = c_dpmem._Memory.get_pointer_type(
316321
<DPCTLSyclUSMRef> dlm_tensor.dl_tensor.data,
317-
<c_dpctl.SyclContext>q.sycl_context)
318-
if usm_type == b"unknown":
319-
raise ValueError(
320-
f"Data pointer in DLPack is not bound to default sycl "
321-
"context of device '{device_id}', translated to "
322-
"{q.sycl_device.filter_string}"
322+
<c_dpctl.SyclContext>default_context)
323+
if usm_type == b"unknown":
324+
raise BufferError(
325+
"Data pointer in DLPack is not bound to default sycl "
326+
f"context of device '{device_id}', translated to "
327+
f"{root_device.filter_string}"
328+
)
329+
alloc_device = c_dpmem._Memory.get_pointer_device(
330+
<DPCTLSyclUSMRef> dlm_tensor.dl_tensor.data,
331+
<c_dpctl.SyclContext>default_context
323332
)
333+
q = dpctl.SyclQueue(default_context, alloc_device)
324334
if dlm_tensor.dl_tensor.dtype.bits % 8:
325-
raise ValueError(
335+
raise BufferError(
326336
"Can not import DLPack tensor whose element's "
327337
"bitsize is not a multiple of 8"
328338
)
329339
if dlm_tensor.dl_tensor.dtype.lanes != 1:
330-
raise ValueError(
340+
raise BufferError(
331341
"Can not import DLPack tensor with lanes != 1"
332342
)
333343
if dlm_tensor.dl_tensor.strides is NULL:

0 commit comments

Comments
 (0)