@@ -140,6 +140,7 @@ cpdef to_dlpack_capsule(usm_ndarray usm_ary) except+:
140
140
cdef c_dpctl.SyclQueue ary_sycl_queue
141
141
cdef c_dpctl.SyclDevice ary_sycl_device
142
142
cdef DPCTLSyclDeviceRef pDRef = NULL
143
+ cdef DPCTLSyclDeviceRef tDRef = NULL
143
144
cdef DLManagedTensor * dlm_tensor = NULL
144
145
cdef DLTensor * dl_tensor = NULL
145
146
cdef int nd = usm_ary.get_ndim()
@@ -157,19 +158,42 @@ cpdef to_dlpack_capsule(usm_ndarray usm_ary) except+:
157
158
ary_sycl_queue = usm_ary.get_sycl_queue()
158
159
ary_sycl_device = ary_sycl_queue.get_sycl_device()
159
160
160
- # check that ary_sycl_device is a non-partitioned device
161
- pDRef = DPCTLDevice_GetParentDevice(ary_sycl_device.get_device_ref())
162
- if pDRef is not NULL :
163
- DPCTLDevice_Delete(pDRef)
164
- raise DLPackCreationError(
165
- " to_dlpack_capsule: DLPack can only export arrays allocated on "
166
- " non-partitioned SYCL devices."
167
- )
168
- default_context = dpctl.SyclQueue(ary_sycl_device).sycl_context
169
- if not usm_ary.sycl_context == default_context:
161
+ try :
162
+ default_context = ary_sycl_device.sycl_platform.default_context
163
+ except RuntimeError :
164
+ # RT does not support default_context, e.g. Windows
165
+ default_context = None
166
+ if default_context is None :
167
+ # check that ary_sycl_device is a non-partitioned device
168
+ pDRef = DPCTLDevice_GetParentDevice(ary_sycl_device.get_device_ref())
169
+ if pDRef is not NULL :
170
+ DPCTLDevice_Delete(pDRef)
171
+ raise DLPackCreationError(
172
+ " to_dlpack_capsule: DLPack can only export arrays allocated "
173
+ " on non-partitioned SYCL devices on platforms where "
174
+ " default_context oneAPI extension is not supported."
175
+ )
176
+ else :
177
+ if not usm_ary.sycl_context == default_context:
178
+ raise DLPackCreationError(
179
+ " to_dlpack_capsule: DLPack can only export arrays based on USM "
180
+ " allocations bound to a default platform SYCL context"
181
+ )
182
+ # Find the unpartitioned parent of the allocation device
183
+ pDRef = DPCTLDevice_GetParentDevice(ary_sycl_device.get_device_ref())
184
+ if pDRef is not NULL :
185
+ tDRef = DPCTLDevice_GetParentDevice(pDRef)
186
+ while tDRef is not NULL :
187
+ DPCTLDevice_Delete(pDRef)
188
+ pDRef = tDRef
189
+ tDRef = DPCTLDevice_GetParentDevice(pDRef)
190
+ ary_sycl_device = c_dpctl.SyclDevice._create(pDRef)
191
+
192
+ # Find ordinal number of the parent device
193
+ device_id = ary_sycl_device.get_overall_ordinal()
194
+ if device_id < 0 :
170
195
raise DLPackCreationError(
171
- " to_dlpack_capsule: DLPack can only export arrays based on USM "
172
- " allocations bound to a default platform SYCL context"
196
+ " to_dlpack_capsule: failed to determine device_id"
173
197
)
174
198
175
199
dlm_tensor = < DLManagedTensor * > stdlib.malloc(
@@ -192,14 +216,6 @@ cpdef to_dlpack_capsule(usm_ndarray usm_ary) except+:
192
216
for i in range (nd):
193
217
shape_strides_ptr[nd + i] = strides_ptr[i]
194
218
195
- device_id = ary_sycl_device.get_overall_ordinal()
196
- if device_id < 0 :
197
- stdlib.free(shape_strides_ptr)
198
- stdlib.free(dlm_tensor)
199
- raise DLPackCreationError(
200
- " to_dlpack_capsule: failed to determine device_id"
201
- )
202
-
203
219
ary_dt = usm_ary.dtype
204
220
ary_dtk = ary_dt.kind
205
221
element_offset = usm_ary.get_offset()
@@ -278,15 +294,16 @@ cpdef usm_ndarray from_dlpack_capsule(object py_caps) except +:
278
294
success.
279
295
Raises:
280
296
TypeError: if argument is not a "dltensor" capsule.
281
- ValueError: if argument is "used_dltensor" capsule,
282
- if the USM pointer is not bound to the reconstructed
297
+ ValueError: if argument is "used_dltensor" capsule
298
+ BufferError: if the USM pointer is not bound to the reconstructed
283
299
sycl context, or the DLPack's device_type is not supported
284
300
by dpctl.
285
301
"""
286
302
cdef DLManagedTensor * dlm_tensor = NULL
287
303
cdef bytes usm_type
288
304
cdef size_t sz = 1
289
305
cdef int i
306
+ cdef int device_id = - 1
290
307
cdef int element_bytesize = 0
291
308
cdef Py_ssize_t offset_min = 0
292
309
cdef Py_ssize_t offset_max = 0
@@ -308,26 +325,37 @@ cpdef usm_ndarray from_dlpack_capsule(object py_caps) except +:
308
325
py_caps, " dltensor" )
309
326
# Verify that we can work with this device
310
327
if dlm_tensor.dl_tensor.device.device_type == kDLOneAPI:
311
- q = dpctl.SyclQueue(str (< int > dlm_tensor.dl_tensor.device.device_id))
328
+ device_id = dlm_tensor.dl_tensor.device.device_id
329
+ root_device = dpctl.SyclDevice(str (< int > device_id))
330
+ try :
331
+ default_context = root_device.sycl_platform.default_context
332
+ except RuntimeError :
333
+ default_context = dpctl.SyclQueue(root_device).sycl_context
312
334
if dlm_tensor.dl_tensor.data is NULL :
313
335
usm_type = b" device"
336
+ q = dpctl.SyclQueue(default_context, root_device)
314
337
else :
315
338
usm_type = c_dpmem._Memory.get_pointer_type(
316
339
< DPCTLSyclUSMRef> dlm_tensor.dl_tensor.data,
317
- < c_dpctl.SyclContext> q.sycl_context)
318
- if usm_type == b" unknown" :
319
- raise ValueError (
320
- f" Data pointer in DLPack is not bound to default sycl "
321
- " context of device '{device_id}', translated to "
322
- " {q.sycl_device.filter_string}"
340
+ < c_dpctl.SyclContext> default_context)
341
+ if usm_type == b" unknown" :
342
+ raise BufferError(
343
+ " Data pointer in DLPack is not bound to default sycl "
344
+ f" context of device '{device_id}', translated to "
345
+ f" {root_device.filter_string}"
346
+ )
347
+ alloc_device = c_dpmem._Memory.get_pointer_device(
348
+ < DPCTLSyclUSMRef> dlm_tensor.dl_tensor.data,
349
+ < c_dpctl.SyclContext> default_context
323
350
)
351
+ q = dpctl.SyclQueue(default_context, alloc_device)
324
352
if dlm_tensor.dl_tensor.dtype.bits % 8 :
325
- raise ValueError (
353
+ raise BufferError (
326
354
" Can not import DLPack tensor whose element's "
327
355
" bitsize is not a multiple of 8"
328
356
)
329
357
if dlm_tensor.dl_tensor.dtype.lanes != 1 :
330
- raise ValueError (
358
+ raise BufferError (
331
359
" Can not import DLPack tensor with lanes != 1"
332
360
)
333
361
if dlm_tensor.dl_tensor.strides is NULL :
0 commit comments