@@ -36,6 +36,7 @@ from ._usmarray cimport usm_ndarray
36
36
import numpy as np
37
37
38
38
import dpctl
39
+ import dpctl.memory as dpmem
39
40
40
41
41
42
cdef extern from ' ./include/dlpack/dlpack.h' nogil:
@@ -95,10 +96,6 @@ cdef void pycapsule_deleter(object dlt_capsule):
95
96
dlm_tensor = < DLManagedTensor* > cpython.PyCapsule_GetPointer(
96
97
dlt_capsule, ' dltensor' )
97
98
dlm_tensor.deleter(dlm_tensor)
98
- elif cpython.PyCapsule_IsValid(dlt_capsule, ' used_dltensor' ):
99
- dlm_tensor = < DLManagedTensor* > cpython.PyCapsule_GetPointer(
100
- dlt_capsule, ' used_dltensor' )
101
- dlm_tensor.deleter(dlm_tensor)
102
99
103
100
104
101
cdef void managed_tensor_deleter(DLManagedTensor * dlm_tensor) with gil:
@@ -133,7 +130,11 @@ cpdef to_dlpack_capsule(usm_ndarray usm_ary) except+:
133
130
cdef int64_t * shape_strides_ptr = NULL
134
131
cdef int i = 0
135
132
cdef int device_id = - 1
133
+ cdef char * base_ptr = NULL
134
+ cdef Py_ssize_t element_offset = 0
135
+ cdef Py_ssize_t byte_offset = 0
136
136
137
+ ary_base = usm_ary.get_base()
137
138
ary_sycl_queue = usm_ary.get_sycl_queue()
138
139
ary_sycl_device = ary_sycl_queue.get_sycl_device()
139
140
@@ -176,11 +177,13 @@ cpdef to_dlpack_capsule(usm_ndarray usm_ary) except+:
176
177
177
178
ary_dt = usm_ary.dtype
178
179
ary_dtk = ary_dt.kind
180
+ element_offset = usm_ary.get_offset()
181
+ byte_offset = element_offset * (< Py_ssize_t> ary_dt.itemsize)
179
182
180
183
dl_tensor = & dlm_tensor.dl_tensor
181
- dl_tensor.data = < void * > data_ptr
184
+ dl_tensor.data = < void * > ( data_ptr - byte_offset)
182
185
dl_tensor.ndim = nd
183
- dl_tensor.byte_offset = < uint64_t> 0
186
+ dl_tensor.byte_offset = < uint64_t> byte_offset
184
187
dl_tensor.shape = & shape_strides_ptr[0 ]
185
188
if strides_ptr is NULL :
186
189
dl_tensor.strides = NULL
@@ -212,6 +215,24 @@ cpdef to_dlpack_capsule(usm_ndarray usm_ary) except+:
212
215
return cpython.PyCapsule_New(dlm_tensor, ' dltensor' , pycapsule_deleter)
213
216
214
217
218
cdef class _DLManagedTensorOwner:
    """Helper class managing the lifetime of a DLManagedTensor struct.

    An instance holds a ``DLManagedTensor *`` handed to it through
    ``_create`` and invokes that struct's ``deleter`` (when one is set)
    at most once, at the time the instance is deallocated.
    """
    cdef DLManagedTensor * dlm_tensor

    def __cinit__(self):
        # Start out owning nothing, so that __dealloc__ is a no-op for an
        # instance that was never populated through _create.
        self.dlm_tensor = NULL

    def __dealloc__(self):
        if self.dlm_tensor is NULL:
            return
        # The DLPack header permits a producer to leave `deleter` NULL when
        # it has no reasonable destructor to offer; calling through a NULL
        # function pointer would crash the process, so check it first.
        if self.dlm_tensor.deleter is not NULL:
            self.dlm_tensor.deleter(self.dlm_tensor)
        # Forget the pointer so the struct can not be released twice.
        self.dlm_tensor = NULL

    @staticmethod
    cdef _DLManagedTensorOwner _create(DLManagedTensor * dlm_tensor_src):
        """Return a new owner object wrapping ``dlm_tensor_src``.

        The instance is built with ``__new__`` and the pointer is stored
        as-is, without copying or validating it.
        """
        cdef _DLManagedTensorOwner res = _DLManagedTensorOwner.__new__(
            _DLManagedTensorOwner
        )
        res.dlm_tensor = dlm_tensor_src
        return res
234
+
235
+
215
236
cpdef usm_ndarray from_dlpack_capsule(object py_caps) except + :
216
237
""" Reconstructs instance of usm_ndarray from named Python
217
238
capsule object referencing instance of `DLManagedTensor` without
@@ -221,6 +242,11 @@ cpdef usm_ndarray from_dlpack_capsule(object py_caps) except +:
221
242
cdef size_t sz = 1
222
243
cdef int i
223
244
cdef int element_bytesize = 0
245
+ cdef Py_ssize_t offset_min = 0
246
+ cdef Py_ssize_t offset_max = 0
247
+ cdef int64_t stride_i
248
+ cdef char * mem_ptr = NULL
249
+ cdef Py_ssize_t element_offset = 0
224
250
225
251
if not cpython.PyCapsule_IsValid(py_caps, ' dltensor' ):
226
252
if cpython.PyCapsule_IsValid(py_caps, ' used_dltensor' ):
@@ -237,9 +263,12 @@ cpdef usm_ndarray from_dlpack_capsule(object py_caps) except +:
237
263
# Verify that we can work with this device
238
264
if dlm_tensor.dl_tensor.device.device_type == kDLOneAPI:
239
265
q = dpctl.SyclQueue(str (< int > dlm_tensor.dl_tensor.device.device_id))
240
- usm_type = c_dpmem._Memory.get_pointer_type(
241
- < DPCTLSyclUSMRef> dlm_tensor.dl_tensor.data,
242
- < c_dpctl.SyclContext> q.sycl_context)
266
+ if dlm_tensor.dl_tensor.data is NULL :
267
+ usm_type = b" device"
268
+ else :
269
+ usm_type = c_dpmem._Memory.get_pointer_type(
270
+ < DPCTLSyclUSMRef> dlm_tensor.dl_tensor.data,
271
+ < c_dpctl.SyclContext> q.sycl_context)
243
272
if usm_type == b" unknown" :
244
273
raise ValueError (
245
274
f" Data pointer in DLPack is not bound to default sycl "
@@ -255,17 +284,45 @@ cpdef usm_ndarray from_dlpack_capsule(object py_caps) except +:
255
284
raise ValueError (
256
285
" Can not import DLPack tensor with lanes != 1"
257
286
)
258
- for i in range (dlm_tensor.dl_tensor.ndim):
259
- sz = sz * dlm_tensor.dl_tensor.shape[i]
287
+ if dlm_tensor.dl_tensor.strides is NULL :
288
+ for i in range (dlm_tensor.dl_tensor.ndim):
289
+ sz = sz * dlm_tensor.dl_tensor.shape[i]
290
+ else :
291
+ offset_min = 0
292
+ offset_max = 0
293
+ for i in range (dlm_tensor.dl_tensor.ndim):
294
+ stride_i = dlm_tensor.dl_tensor.strides[i]
295
+ if stride_i > 0 :
296
+ offset_max = offset_max + stride_i * (
297
+ dlm_tensor.dl_tensor.shape[i] - 1
298
+ )
299
+ else :
300
+ offset_min = offset_min + stride_i * (
301
+ dlm_tensor.dl_tensor.shape[i] - 1
302
+ )
303
+ sz = offset_max - offset_min + 1
304
+ if sz == 0 :
305
+ sz = 1
260
306
261
307
element_bytesize = (dlm_tensor.dl_tensor.dtype.bits // 8 )
262
308
sz = sz * element_bytesize
263
- usm_mem = c_dpmem._Memory.create_from_usm_pointer_size_qref(
264
- < DPCTLSyclUSMRef> dlm_tensor.dl_tensor.data,
265
- sz,
266
- (< c_dpctl.SyclQueue> q).get_queue_ref(),
267
- memory_owner = py_caps
268
- )
309
+ element_offset = dlm_tensor.dl_tensor.byte_offset // element_bytesize
310
+
311
+ # transfer dlm_tensor ownership
312
+ dlm_holder = _DLManagedTensorOwner._create(dlm_tensor)
313
+ cpython.PyCapsule_SetName(py_caps, ' used_dltensor' )
314
+
315
+ if dlm_tensor.dl_tensor.data is NULL :
316
+ usm_mem = dpmem.MemoryUSMDevice(sz, q)
317
+ else :
318
+ mem_ptr = < char * > dlm_tensor.dl_tensor.data + dlm_tensor.dl_tensor.byte_offset
319
+ mem_ptr = mem_ptr - (element_offset * element_bytesize)
320
+ usm_mem = c_dpmem._Memory.create_from_usm_pointer_size_qref(
321
+ < DPCTLSyclUSMRef> mem_ptr,
322
+ sz,
323
+ (< c_dpctl.SyclQueue> q).get_queue_ref(),
324
+ memory_owner = dlm_holder
325
+ )
269
326
py_shape = list ()
270
327
for i in range (dlm_tensor.dl_tensor.ndim):
271
328
py_shape.append(dlm_tensor.dl_tensor.shape[i])
@@ -293,9 +350,9 @@ cpdef usm_ndarray from_dlpack_capsule(object py_caps) except +:
293
350
py_shape,
294
351
dtype = ary_dt,
295
352
buffer = usm_mem,
296
- strides = py_strides
353
+ strides = py_strides,
354
+ offset = element_offset
297
355
)
298
- cpython.PyCapsule_SetName(py_caps, ' used_dltensor' )
299
356
return res_ary
300
357
else :
301
358
raise ValueError (
@@ -304,8 +361,16 @@ cpdef usm_ndarray from_dlpack_capsule(object py_caps) except +:
304
361
305
362
306
363
cpdef from_dlpack(array):
307
- """ Constructs `usm_ndarray` from a Python object that implements
308
- `__dlpack__` protocol.
364
+ """ dpctl.tensor.from_dlpack(obj)
365
+
366
+ Constructs :class:`dpctl.tensor.usm_ndarray` instance from a Python
367
+ object `obj` that implements `__dlpack__` protocol. The output
368
+ array is always a zero-copy view of the input.
369
+
370
+ Raises:
371
+ TypeError: if `obj` does not implement `__dlpack__` method.
372
+ ValueError: if zero copy view can not be constructed because
373
+ the input array resides on an unsupported device.
309
374
"""
310
375
if not hasattr (array, " __dlpack__" ):
311
376
raise TypeError (
0 commit comments