Skip to content

Commit 70fdc5e

Browse files
Added support for consuming DLPack allocated on a sub-device
USM allocation must be bound to the default context. When producing DLPack, device_id is populated with id of the ancestor root device. Code remains functional on systems where default_context extension support is not enabled (e.g. Windows), but DLPack sharing is limited to allocation made on root devices only.
1 parent 2e2e35d commit 70fdc5e

File tree

1 file changed

+59
-31
lines changed

1 file changed

+59
-31
lines changed

dpctl/tensor/_dlpack.pyx

Lines changed: 59 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,7 @@ cpdef to_dlpack_capsule(usm_ndarray usm_ary) except+:
140140
cdef c_dpctl.SyclQueue ary_sycl_queue
141141
cdef c_dpctl.SyclDevice ary_sycl_device
142142
cdef DPCTLSyclDeviceRef pDRef = NULL
143+
cdef DPCTLSyclDeviceRef tDRef = NULL
143144
cdef DLManagedTensor *dlm_tensor = NULL
144145
cdef DLTensor *dl_tensor = NULL
145146
cdef int nd = usm_ary.get_ndim()
@@ -157,19 +158,42 @@ cpdef to_dlpack_capsule(usm_ndarray usm_ary) except+:
157158
ary_sycl_queue = usm_ary.get_sycl_queue()
158159
ary_sycl_device = ary_sycl_queue.get_sycl_device()
159160

160-
# check that ary_sycl_device is a non-partitioned device
161-
pDRef = DPCTLDevice_GetParentDevice(ary_sycl_device.get_device_ref())
162-
if pDRef is not NULL:
163-
DPCTLDevice_Delete(pDRef)
164-
raise DLPackCreationError(
165-
"to_dlpack_capsule: DLPack can only export arrays allocated on "
166-
"non-partitioned SYCL devices."
167-
)
168-
default_context = dpctl.SyclQueue(ary_sycl_device).sycl_context
169-
if not usm_ary.sycl_context == default_context:
161+
try:
162+
default_context = ary_sycl_device.sycl_platform.default_context
163+
except RuntimeError:
164+
# RT does not support default_context, e.g. Windows
165+
default_context = None
166+
if default_context is None:
167+
# check that ary_sycl_device is a non-partitioned device
168+
pDRef = DPCTLDevice_GetParentDevice(ary_sycl_device.get_device_ref())
169+
if pDRef is not NULL:
170+
DPCTLDevice_Delete(pDRef)
171+
raise DLPackCreationError(
172+
"to_dlpack_capsule: DLPack can only export arrays allocated "
173+
"on non-partitioned SYCL devices on platforms where "
174+
"default_context oneAPI extension is not supported."
175+
)
176+
else:
177+
if not usm_ary.sycl_context == default_context:
178+
raise DLPackCreationError(
179+
"to_dlpack_capsule: DLPack can only export arrays based on USM "
180+
"allocations bound to a default platform SYCL context"
181+
)
182+
# Find the unpartitioned parent of the allocation device
183+
pDRef = DPCTLDevice_GetParentDevice(ary_sycl_device.get_device_ref())
184+
if pDRef is not NULL:
185+
tDRef = DPCTLDevice_GetParentDevice(pDRef)
186+
while tDRef is not NULL:
187+
DPCTLDevice_Delete(pDRef)
188+
pDRef = tDRef
189+
tDRef = DPCTLDevice_GetParentDevice(pDRef)
190+
ary_sycl_device = c_dpctl.SyclDevice._create(pDRef)
191+
192+
# Find ordinal number of the parent device
193+
device_id = ary_sycl_device.get_overall_ordinal()
194+
if device_id < 0:
170195
raise DLPackCreationError(
171-
"to_dlpack_capsule: DLPack can only export arrays based on USM "
172-
"allocations bound to a default platform SYCL context"
196+
"to_dlpack_capsule: failed to determine device_id"
173197
)
174198

175199
dlm_tensor = <DLManagedTensor *> stdlib.malloc(
@@ -192,14 +216,6 @@ cpdef to_dlpack_capsule(usm_ndarray usm_ary) except+:
192216
for i in range(nd):
193217
shape_strides_ptr[nd + i] = strides_ptr[i]
194218

195-
device_id = ary_sycl_device.get_overall_ordinal()
196-
if device_id < 0:
197-
stdlib.free(shape_strides_ptr)
198-
stdlib.free(dlm_tensor)
199-
raise DLPackCreationError(
200-
"to_dlpack_capsule: failed to determine device_id"
201-
)
202-
203219
ary_dt = usm_ary.dtype
204220
ary_dtk = ary_dt.kind
205221
element_offset = usm_ary.get_offset()
@@ -278,15 +294,16 @@ cpdef usm_ndarray from_dlpack_capsule(object py_caps) except +:
278294
success.
279295
Raises:
280296
TypeError: if argument is not a "dltensor" capsule.
281-
ValueError: if argument is "used_dltensor" capsule,
282-
if the USM pointer is not bound to the reconstructed
297+
ValueError: if argument is "used_dltensor" capsule
298+
BufferError: if the USM pointer is not bound to the reconstructed
283299
sycl context, or the DLPack's device_type is not supported
284300
by dpctl.
285301
"""
286302
cdef DLManagedTensor *dlm_tensor = NULL
287303
cdef bytes usm_type
288304
cdef size_t sz = 1
289305
cdef int i
306+
cdef int device_id = -1
290307
cdef int element_bytesize = 0
291308
cdef Py_ssize_t offset_min = 0
292309
cdef Py_ssize_t offset_max = 0
@@ -308,26 +325,37 @@ cpdef usm_ndarray from_dlpack_capsule(object py_caps) except +:
308325
py_caps, "dltensor")
309326
# Verify that we can work with this device
310327
if dlm_tensor.dl_tensor.device.device_type == kDLOneAPI:
311-
q = dpctl.SyclQueue(str(<int>dlm_tensor.dl_tensor.device.device_id))
328+
device_id = dlm_tensor.dl_tensor.device.device_id
329+
root_device = dpctl.SyclDevice(str(<int>device_id))
330+
try:
331+
default_context = root_device.sycl_platform.default_context
332+
except RuntimeError:
333+
default_context = dpctl.SyclQueue(root_device).sycl_context
312334
if dlm_tensor.dl_tensor.data is NULL:
313335
usm_type = b"device"
336+
q = dpctl.SyclQueue(default_context, root_device)
314337
else:
315338
usm_type = c_dpmem._Memory.get_pointer_type(
316339
<DPCTLSyclUSMRef> dlm_tensor.dl_tensor.data,
317-
<c_dpctl.SyclContext>q.sycl_context)
318-
if usm_type == b"unknown":
319-
raise ValueError(
320-
f"Data pointer in DLPack is not bound to default sycl "
321-
"context of device '{device_id}', translated to "
322-
"{q.sycl_device.filter_string}"
340+
<c_dpctl.SyclContext>default_context)
341+
if usm_type == b"unknown":
342+
raise BufferError(
343+
"Data pointer in DLPack is not bound to default sycl "
344+
f"context of device '{device_id}', translated to "
345+
f"{root_device.filter_string}"
346+
)
347+
alloc_device = c_dpmem._Memory.get_pointer_device(
348+
<DPCTLSyclUSMRef> dlm_tensor.dl_tensor.data,
349+
<c_dpctl.SyclContext>default_context
323350
)
351+
q = dpctl.SyclQueue(default_context, alloc_device)
324352
if dlm_tensor.dl_tensor.dtype.bits % 8:
325-
raise ValueError(
353+
raise BufferError(
326354
"Can not import DLPack tensor whose element's "
327355
"bitsize is not a multiple of 8"
328356
)
329357
if dlm_tensor.dl_tensor.dtype.lanes != 1:
330-
raise ValueError(
358+
raise BufferError(
331359
"Can not import DLPack tensor with lanes != 1"
332360
)
333361
if dlm_tensor.dl_tensor.strides is NULL:

0 commit comments

Comments
 (0)