Skip to content

Commit 288ac71

Browse files
DLManagedTensor lifetime management implemented per array-API specs
1. The pycapsule destructor only calls DLManagedTensor.deleter if the name is "dltensor". 2. Code consuming the DLPack capsule renames the capsule (to avoid the destructor calling the deleter) and instead creates an internal object to do that, and uses that internal object as the base of the _Memory object. The `from_dlpack_capsule` function should handle a NULL data field. For zero-element arrays in DLPack, allocate 1 element. Proper support for strides added. Expanded docstring of `dpctl.tensor.from_dlpack`.
1 parent eb94dd4 commit 288ac71

File tree

1 file changed

+86
-21
lines changed

1 file changed

+86
-21
lines changed

dpctl/tensor/_dlpack.pyx

Lines changed: 86 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ from ._usmarray cimport usm_ndarray
3636
import numpy as np
3737

3838
import dpctl
39+
import dpctl.memory as dpmem
3940

4041

4142
cdef extern from './include/dlpack/dlpack.h' nogil:
@@ -95,10 +96,6 @@ cdef void pycapsule_deleter(object dlt_capsule):
9596
dlm_tensor = <DLManagedTensor*>cpython.PyCapsule_GetPointer(
9697
dlt_capsule, 'dltensor')
9798
dlm_tensor.deleter(dlm_tensor)
98-
elif cpython.PyCapsule_IsValid(dlt_capsule, 'used_dltensor'):
99-
dlm_tensor = <DLManagedTensor*>cpython.PyCapsule_GetPointer(
100-
dlt_capsule, 'used_dltensor')
101-
dlm_tensor.deleter(dlm_tensor)
10299

103100

104101
cdef void managed_tensor_deleter(DLManagedTensor *dlm_tensor) with gil:
@@ -133,7 +130,11 @@ cpdef to_dlpack_capsule(usm_ndarray usm_ary) except+:
133130
cdef int64_t *shape_strides_ptr = NULL
134131
cdef int i = 0
135132
cdef int device_id = -1
133+
cdef char* base_ptr = NULL
134+
cdef Py_ssize_t element_offset = 0
135+
cdef Py_ssize_t byte_offset = 0
136136

137+
ary_base = usm_ary.get_base()
137138
ary_sycl_queue = usm_ary.get_sycl_queue()
138139
ary_sycl_device = ary_sycl_queue.get_sycl_device()
139140

@@ -176,11 +177,13 @@ cpdef to_dlpack_capsule(usm_ndarray usm_ary) except+:
176177

177178
ary_dt = usm_ary.dtype
178179
ary_dtk = ary_dt.kind
180+
element_offset = usm_ary.get_offset()
181+
byte_offset = element_offset * (<Py_ssize_t>ary_dt.itemsize)
179182

180183
dl_tensor = &dlm_tensor.dl_tensor
181-
dl_tensor.data = <void*>data_ptr
184+
dl_tensor.data = <void*>(data_ptr - byte_offset)
182185
dl_tensor.ndim = nd
183-
dl_tensor.byte_offset = <uint64_t>0
186+
dl_tensor.byte_offset = <uint64_t>byte_offset
184187
dl_tensor.shape = &shape_strides_ptr[0]
185188
if strides_ptr is NULL:
186189
dl_tensor.strides = NULL
@@ -212,6 +215,24 @@ cpdef to_dlpack_capsule(usm_ndarray usm_ary) except+:
212215
return cpython.PyCapsule_New(dlm_tensor, 'dltensor', pycapsule_deleter)
213216

214217

218+
cdef class _DLManagedTensorOwner:
219+
"""Helper class managing lifetimes of the DLManagedTensor struct"""
220+
cdef DLManagedTensor *dlm_tensor
221+
222+
def __cinit__(self):
223+
self.dlm_tensor = NULL
224+
225+
def __dealloc__(self):
226+
if self.dlm_tensor:
227+
self.dlm_tensor.deleter(self.dlm_tensor)
228+
229+
@staticmethod
230+
cdef _DLManagedTensorOwner _create(DLManagedTensor *dlm_tensor_src):
231+
cdef _DLManagedTensorOwner res = _DLManagedTensorOwner.__new__(_DLManagedTensorOwner)
232+
res.dlm_tensor = dlm_tensor_src
233+
return res
234+
235+
215236
cpdef usm_ndarray from_dlpack_capsule(object py_caps) except +:
216237
"""Reconstructs instance of usm_ndarray from named Python
217238
capsule object referencing instance of `DLManagedTensor` without
@@ -221,6 +242,11 @@ cpdef usm_ndarray from_dlpack_capsule(object py_caps) except +:
221242
cdef size_t sz = 1
222243
cdef int i
223244
cdef int element_bytesize = 0
245+
cdef Py_ssize_t offset_min = 0
246+
cdef Py_ssize_t offset_max = 0
247+
cdef int64_t stride_i
248+
cdef char* mem_ptr = NULL
249+
cdef Py_ssize_t element_offset = 0
224250

225251
if not cpython.PyCapsule_IsValid(py_caps, 'dltensor'):
226252
if cpython.PyCapsule_IsValid(py_caps, 'used_dltensor'):
@@ -237,9 +263,12 @@ cpdef usm_ndarray from_dlpack_capsule(object py_caps) except +:
237263
# Verify that we can work with this device
238264
if dlm_tensor.dl_tensor.device.device_type == kDLOneAPI:
239265
q = dpctl.SyclQueue(str(<int>dlm_tensor.dl_tensor.device.device_id))
240-
usm_type = c_dpmem._Memory.get_pointer_type(
241-
<DPCTLSyclUSMRef> dlm_tensor.dl_tensor.data,
242-
<c_dpctl.SyclContext>q.sycl_context)
266+
if dlm_tensor.dl_tensor.data is NULL:
267+
usm_type = b"device"
268+
else:
269+
usm_type = c_dpmem._Memory.get_pointer_type(
270+
<DPCTLSyclUSMRef> dlm_tensor.dl_tensor.data,
271+
<c_dpctl.SyclContext>q.sycl_context)
243272
if usm_type == b"unknown":
244273
raise ValueError(
245274
f"Data pointer in DLPack is not bound to default sycl "
@@ -255,17 +284,45 @@ cpdef usm_ndarray from_dlpack_capsule(object py_caps) except +:
255284
raise ValueError(
256285
"Can not import DLPack tensor with lanes != 1"
257286
)
258-
for i in range(dlm_tensor.dl_tensor.ndim):
259-
sz = sz * dlm_tensor.dl_tensor.shape[i]
287+
if dlm_tensor.dl_tensor.strides is NULL:
288+
for i in range(dlm_tensor.dl_tensor.ndim):
289+
sz = sz * dlm_tensor.dl_tensor.shape[i]
290+
else:
291+
offset_min = 0
292+
offset_max = 0
293+
for i in range(dlm_tensor.dl_tensor.ndim):
294+
stride_i = dlm_tensor.dl_tensor.strides[i]
295+
if stride_i > 0:
296+
offset_max = offset_max + stride_i * (
297+
dlm_tensor.dl_tensor.shape[i] - 1
298+
)
299+
else:
300+
offset_min = offset_min + stride_i * (
301+
dlm_tensor.dl_tensor.shape[i] - 1
302+
)
303+
sz = offset_max - offset_min + 1
304+
if sz == 0:
305+
sz = 1
260306

261307
element_bytesize = (dlm_tensor.dl_tensor.dtype.bits // 8)
262308
sz = sz * element_bytesize
263-
usm_mem = c_dpmem._Memory.create_from_usm_pointer_size_qref(
264-
<DPCTLSyclUSMRef> dlm_tensor.dl_tensor.data,
265-
sz,
266-
(<c_dpctl.SyclQueue>q).get_queue_ref(),
267-
memory_owner=py_caps
268-
)
309+
element_offset = dlm_tensor.dl_tensor.byte_offset // element_bytesize
310+
311+
# transfer dlm_tensor ownership
312+
dlm_holder = _DLManagedTensorOwner._create(dlm_tensor)
313+
cpython.PyCapsule_SetName(py_caps, 'used_dltensor')
314+
315+
if dlm_tensor.dl_tensor.data is NULL:
316+
usm_mem = dpmem.MemoryUSMDevice(sz, q)
317+
else:
318+
mem_ptr = <char *>dlm_tensor.dl_tensor.data + dlm_tensor.dl_tensor.byte_offset
319+
mem_ptr = mem_ptr - (element_offset * element_bytesize)
320+
usm_mem = c_dpmem._Memory.create_from_usm_pointer_size_qref(
321+
<DPCTLSyclUSMRef> mem_ptr,
322+
sz,
323+
(<c_dpctl.SyclQueue>q).get_queue_ref(),
324+
memory_owner=dlm_holder
325+
)
269326
py_shape = list()
270327
for i in range(dlm_tensor.dl_tensor.ndim):
271328
py_shape.append(dlm_tensor.dl_tensor.shape[i])
@@ -293,9 +350,9 @@ cpdef usm_ndarray from_dlpack_capsule(object py_caps) except +:
293350
py_shape,
294351
dtype=ary_dt,
295352
buffer=usm_mem,
296-
strides=py_strides
353+
strides=py_strides,
354+
offset=element_offset
297355
)
298-
cpython.PyCapsule_SetName(py_caps, 'used_dltensor')
299356
return res_ary
300357
else:
301358
raise ValueError(
@@ -304,8 +361,16 @@ cpdef usm_ndarray from_dlpack_capsule(object py_caps) except +:
304361

305362

306363
cpdef from_dlpack(array):
307-
"""Constructs `usm_ndarray` from a Python object that implements
308-
`__dlpack__` protocol.
364+
"""dpctl.tensor.from_dlpack(obj)
365+
366+
Constructs :class:`dpctl.tensor.usm_ndarray` instance from a Python
367+
object `obj` that implements `__dlpack__` protocol. The output
368+
array is always a zero-copy view of the input.
369+
370+
Raises:
371+
TypeError: if `obj` does not implement `__dlpack__` method.
372+
ValueError: if zero copy view can not be constructed because
373+
the input array resides on an unsupported device.
309374
"""
310375
if not hasattr(array, "__dlpack__"):
311376
raise TypeError(

0 commit comments

Comments
 (0)