@@ -140,6 +140,7 @@ cpdef to_dlpack_capsule(usm_ndarray usm_ary) except+:
140
140
cdef c_dpctl.SyclQueue ary_sycl_queue
141
141
cdef c_dpctl.SyclDevice ary_sycl_device
142
142
cdef DPCTLSyclDeviceRef pDRef = NULL
143
+ cdef DPCTLSyclDeviceRef tDRef = NULL
143
144
cdef DLManagedTensor * dlm_tensor = NULL
144
145
cdef DLTensor * dl_tensor = NULL
145
146
cdef int nd = usm_ary.get_ndim()
@@ -157,20 +158,28 @@ cpdef to_dlpack_capsule(usm_ndarray usm_ary) except+:
157
158
ary_sycl_queue = usm_ary.get_sycl_queue()
158
159
ary_sycl_device = ary_sycl_queue.get_sycl_device()
159
160
160
- # check that ary_sycl_device is a non-partitioned device
161
- pDRef = DPCTLDevice_GetParentDevice(ary_sycl_device.get_device_ref())
162
- if pDRef is not NULL :
163
- DPCTLDevice_Delete(pDRef)
164
- raise DLPackCreationError(
165
- " to_dlpack_capsule: DLPack can only export arrays allocated on "
166
- " non-partitioned SYCL devices."
167
- )
168
- default_context = dpctl.SyclQueue(ary_sycl_device).sycl_context
161
+ default_context = ary_sycl_device.sycl_platform.default_context
169
162
if not usm_ary.sycl_context == default_context:
170
163
raise DLPackCreationError(
171
164
" to_dlpack_capsule: DLPack can only export arrays based on USM "
172
165
" allocations bound to a default platform SYCL context"
173
166
)
167
+ # Find the unpartitioned parent of the allocation device
168
+ pDRef = DPCTLDevice_GetParentDevice(ary_sycl_device.get_device_ref())
169
+ if pDRef is not NULL :
170
+ tDRef = DPCTLDevice_GetParentDevice(pDRef)
171
+ while tDRef is not NULL :
172
+ DPCTLDevice_Delete(pDRef)
173
+ pDRef = tDRef
174
+ tDRef = DPCTLDevice_GetParentDevice(pDRef)
175
+ ary_sycl_device = c_dpctl.SyclDevice._create(pDRef)
176
+
177
+ # Find ordinal number of the parent device
178
+ device_id = ary_sycl_device.get_overall_ordinal()
179
+ if device_id < 0 :
180
+ raise DLPackCreationError(
181
+ " to_dlpack_capsule: failed to determine device_id"
182
+ )
174
183
175
184
dlm_tensor = < DLManagedTensor * > stdlib.malloc(
176
185
sizeof(DLManagedTensor))
@@ -192,14 +201,6 @@ cpdef to_dlpack_capsule(usm_ndarray usm_ary) except+:
192
201
for i in range (nd):
193
202
shape_strides_ptr[nd + i] = strides_ptr[i]
194
203
195
- device_id = ary_sycl_device.get_overall_ordinal()
196
- if device_id < 0 :
197
- stdlib.free(shape_strides_ptr)
198
- stdlib.free(dlm_tensor)
199
- raise DLPackCreationError(
200
- " to_dlpack_capsule: failed to determine device_id"
201
- )
202
-
203
204
ary_dt = usm_ary.dtype
204
205
ary_dtk = ary_dt.kind
205
206
element_offset = usm_ary.get_offset()
@@ -278,15 +279,16 @@ cpdef usm_ndarray from_dlpack_capsule(object py_caps) except +:
278
279
success.
279
280
Raises:
280
281
TypeError: if argument is not a "dltensor" capsule.
281
- ValueError: if argument is "used_dltensor" capsule,
282
- if the USM pointer is not bound to the reconstructed
282
+ ValueError: if argument is "used_dltensor" capsule
283
+ BufferError: if the USM pointer is not bound to the reconstructed
283
284
sycl context, or the DLPack's device_type is not supported
284
285
by dpctl.
285
286
"""
286
287
cdef DLManagedTensor * dlm_tensor = NULL
287
288
cdef bytes usm_type
288
289
cdef size_t sz = 1
289
290
cdef int i
291
+ cdef int device_id = - 1
290
292
cdef int element_bytesize = 0
291
293
cdef Py_ssize_t offset_min = 0
292
294
cdef Py_ssize_t offset_max = 0
@@ -308,26 +310,34 @@ cpdef usm_ndarray from_dlpack_capsule(object py_caps) except +:
308
310
py_caps, " dltensor" )
309
311
# Verify that we can work with this device
310
312
if dlm_tensor.dl_tensor.device.device_type == kDLOneAPI:
311
- q = dpctl.SyclQueue(str (< int > dlm_tensor.dl_tensor.device.device_id))
313
+ device_id = dlm_tensor.dl_tensor.device.device_id
314
+ root_device = dpctl.SyclDevice(str (< int > device_id))
315
+ default_context = root_device.sycl_platform.default_context
312
316
if dlm_tensor.dl_tensor.data is NULL :
313
317
usm_type = b" device"
318
+ q = dpctl.SyclQueue(default_context, root_device)
314
319
else :
315
320
usm_type = c_dpmem._Memory.get_pointer_type(
316
321
< DPCTLSyclUSMRef> dlm_tensor.dl_tensor.data,
317
- < c_dpctl.SyclContext> q.sycl_context)
318
- if usm_type == b" unknown" :
319
- raise ValueError (
320
- f" Data pointer in DLPack is not bound to default sycl "
321
- " context of device '{device_id}', translated to "
322
- " {q.sycl_device.filter_string}"
322
+ < c_dpctl.SyclContext> default_context)
323
+ if usm_type == b" unknown" :
324
+ raise BufferError(
325
+ " Data pointer in DLPack is not bound to default sycl "
326
+ f" context of device '{device_id}', translated to "
327
+ f" {root_device.filter_string}"
328
+ )
329
+ alloc_device = c_dpmem._Memory.get_pointer_device(
330
+ < DPCTLSyclUSMRef> dlm_tensor.dl_tensor.data,
331
+ < c_dpctl.SyclContext> default_context
323
332
)
333
+ q = dpctl.SyclQueue(default_context, alloc_device)
324
334
if dlm_tensor.dl_tensor.dtype.bits % 8 :
325
- raise ValueError (
335
+ raise BufferError (
326
336
" Can not import DLPack tensor whose element's "
327
337
" bitsize is not a multiple of 8"
328
338
)
329
339
if dlm_tensor.dl_tensor.dtype.lanes != 1 :
330
- raise ValueError (
340
+ raise BufferError (
331
341
" Can not import DLPack tensor with lanes != 1"
332
342
)
333
343
if dlm_tensor.dl_tensor.strides is NULL :
0 commit comments