Skip to content

Commit 6e58a41

Browse files
Merge pull request #1073 from IntelPython/fix-gh-1071-from-dlpack
Fix gh 1071 from dlpack
2 parents 8e63098 + 84b0232 commit 6e58a41

File tree

4 files changed

+69
-25
lines changed

4 files changed

+69
-25
lines changed

dpctl/tensor/_dlpack.pyx

Lines changed: 38 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ cdef extern from 'dlpack/dlpack.h' nogil:
7171
kDLFloat
7272
kDLBfloat
7373
kDLComplex
74+
kDLBool
7475

7576
ctypedef struct DLDataType:
7677
uint8_t code
@@ -244,7 +245,7 @@ cpdef to_dlpack_capsule(usm_ndarray usm_ary) except+:
244245
dl_tensor.dtype.lanes = <uint16_t>1
245246
dl_tensor.dtype.bits = <uint8_t>(ary_dt.itemsize * 8)
246247
if (ary_dtk == "b"):
247-
dl_tensor.dtype.code = <uint8_t>kDLUInt
248+
dl_tensor.dtype.code = <uint8_t>kDLBool
248249
elif (ary_dtk == "u"):
249250
dl_tensor.dtype.code = <uint8_t>kDLUInt
250251
elif (ary_dtk == "i"):
@@ -311,14 +312,17 @@ cpdef usm_ndarray from_dlpack_capsule(object py_caps) except +:
311312
cdef DLManagedTensor *dlm_tensor = NULL
312313
cdef bytes usm_type
313314
cdef size_t sz = 1
315+
cdef size_t alloc_sz = 1
314316
cdef int i
315317
cdef int device_id = -1
316318
cdef int element_bytesize = 0
317319
cdef Py_ssize_t offset_min = 0
318320
cdef Py_ssize_t offset_max = 0
319-
cdef int64_t stride_i
320321
cdef char *mem_ptr = NULL
322+
cdef Py_ssize_t mem_ptr_delta = 0
321323
cdef Py_ssize_t element_offset = 0
324+
cdef int64_t stride_i = -1
325+
cdef int64_t shape_i = -1
322326

323327
if not cpython.PyCapsule_IsValid(py_caps, 'dltensor'):
324328
if cpython.PyCapsule_IsValid(py_caps, 'used_dltensor'):
@@ -370,22 +374,22 @@ cpdef usm_ndarray from_dlpack_capsule(object py_caps) except +:
370374
raise BufferError(
371375
"Can not import DLPack tensor with lanes != 1"
372376
)
377+
offset_min = 0
373378
if dlm_tensor.dl_tensor.strides is NULL:
374379
for i in range(dlm_tensor.dl_tensor.ndim):
375380
sz = sz * dlm_tensor.dl_tensor.shape[i]
381+
offset_max = sz - 1
376382
else:
377-
offset_min = 0
378383
offset_max = 0
379384
for i in range(dlm_tensor.dl_tensor.ndim):
380385
stride_i = dlm_tensor.dl_tensor.strides[i]
381-
if stride_i > 0:
382-
offset_max = offset_max + stride_i * (
383-
dlm_tensor.dl_tensor.shape[i] - 1
384-
)
385-
else:
386-
offset_min = offset_min + stride_i * (
387-
dlm_tensor.dl_tensor.shape[i] - 1
388-
)
386+
shape_i = dlm_tensor.dl_tensor.shape[i]
387+
if shape_i > 1:
388+
shape_i -= 1
389+
if stride_i > 0:
390+
offset_max = offset_max + stride_i * shape_i
391+
else:
392+
offset_min = offset_min + stride_i * shape_i
389393
sz = offset_max - offset_min + 1
390394
if sz == 0:
391395
sz = 1
@@ -401,14 +405,29 @@ cpdef usm_ndarray from_dlpack_capsule(object py_caps) except +:
401405
if dlm_tensor.dl_tensor.data is NULL:
402406
usm_mem = dpmem.MemoryUSMDevice(sz, q)
403407
else:
404-
mem_ptr = <char *>dlm_tensor.dl_tensor.data + dlm_tensor.dl_tensor.byte_offset
405-
mem_ptr = mem_ptr - (element_offset * element_bytesize)
406-
usm_mem = c_dpmem._Memory.create_from_usm_pointer_size_qref(
408+
mem_ptr_delta = dlm_tensor.dl_tensor.byte_offset - (
409+
element_offset * element_bytesize
410+
)
411+
mem_ptr = <char *>dlm_tensor.dl_tensor.data
412+
alloc_sz = dlm_tensor.dl_tensor.byte_offset + <uint64_t>(
413+
(offset_max + 1) * element_bytesize)
414+
tmp = c_dpmem._Memory.create_from_usm_pointer_size_qref(
407415
<DPCTLSyclUSMRef> mem_ptr,
408-
sz,
416+
max(alloc_sz, <uint64_t>element_bytesize),
409417
(<c_dpctl.SyclQueue>q).get_queue_ref(),
410418
memory_owner=dlm_holder
411419
)
420+
if mem_ptr_delta == 0:
421+
usm_mem = tmp
422+
else:
423+
alloc_sz = dlm_tensor.dl_tensor.byte_offset + <uint64_t>(
424+
(offset_max * element_bytesize + mem_ptr_delta))
425+
usm_mem = c_dpmem._Memory.create_from_usm_pointer_size_qref(
426+
<DPCTLSyclUSMRef> (mem_ptr + (element_bytesize - mem_ptr_delta)),
427+
max(alloc_sz, <uint64_t>element_bytesize),
428+
(<c_dpctl.SyclQueue>q).get_queue_ref(),
429+
memory_owner=tmp
430+
)
412431
py_shape = list()
413432
for i in range(dlm_tensor.dl_tensor.ndim):
414433
py_shape.append(dlm_tensor.dl_tensor.shape[i])
@@ -426,8 +445,10 @@ cpdef usm_ndarray from_dlpack_capsule(object py_caps) except +:
426445
ary_dt = np.dtype("f" + str(element_bytesize))
427446
elif (dlm_tensor.dl_tensor.dtype.code == kDLComplex):
428447
ary_dt = np.dtype("c" + str(element_bytesize))
448+
elif (dlm_tensor.dl_tensor.dtype.code == kDLBool):
449+
ary_dt = np.dtype("?")
429450
else:
430-
raise ValueError(
451+
raise BufferError(
431452
"Can not import DLPack tensor with type code {}.".format(
432453
<object>dlm_tensor.dl_tensor.dtype.code
433454
)
@@ -441,7 +462,7 @@ cpdef usm_ndarray from_dlpack_capsule(object py_caps) except +:
441462
)
442463
return res_ary
443464
else:
444-
raise ValueError(
465+
raise BufferError(
445466
"The DLPack tensor resides on unsupported device."
446467
)
447468

dpctl/tensor/include/dlpack/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# DLPack header
22

3-
The header `dlpack.h` downloaded from `https://github.com/dmlc/dlpack.git` remote at tag 0.7 commit [`e2bdd3bee8`](https://github.com/dmlc/dlpack/commit/e2bdd3bee8cb6501558042633fa59144cc8b7f5f).
3+
The header `dlpack.h` downloaded from `https://github.com/dmlc/dlpack.git` remote at tag v0.8 commit [`365b823`](https://github.com/dmlc/dlpack/commit/365b823cedb281cd0240ca601aba9b78771f91a3).
44

55
The file can also be viewed using github web interface at https://github.com/dmlc/dlpack/blob/e2bdd3bee8cb6501558042633fa59144cc8b7f5f/include/dlpack/dlpack.h
66

dpctl/tensor/include/dlpack/dlpack.h

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
#endif
1717

1818
/*! \brief The current version of dlpack */
19-
#define DLPACK_VERSION 70
19+
#define DLPACK_VERSION 80
2020

2121
/*! \brief The current ABI version of dlpack */
2222
#define DLPACK_ABI_VERSION 1
@@ -126,6 +126,8 @@ typedef enum {
126126
* (C/C++/Python layout: compact struct per complex number)
127127
*/
128128
kDLComplex = 5U,
129+
/*! \brief boolean */
130+
kDLBool = 6U,
129131
} DLDataTypeCode;
130132

131133
/*!
@@ -134,10 +136,11 @@ typedef enum {
134136
* export an array with non-native endianness
135137
*
136138
* Examples
137-
* - float: type_code = 2, bits = 32, lanes=1
138-
* - float4(vectorized 4 float): type_code = 2, bits = 32, lanes=4
139-
* - int8: type_code = 0, bits = 8, lanes=1
139+
* - float: type_code = 2, bits = 32, lanes = 1
140+
* - float4(vectorized 4 float): type_code = 2, bits = 32, lanes = 4
141+
* - int8: type_code = 0, bits = 8, lanes = 1
140142
* - std::complex<float>: type_code = 5, bits = 64, lanes = 1
143+
* - bool: type_code = 6, bits = 8, lanes = 1 (as per common array library convention, the underlying storage size of bool is 8 bits)
141144
*/
142145
typedef struct {
143146
/*!

dpctl/tests/test_usm_ndarray_dlpack.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -101,9 +101,7 @@ def test_from_dlpack(shape, typestr, usm_type):
101101
X = dpt.empty(shape, dtype=typestr, usm_type=usm_type, device=sycl_dev)
102102
Y = dpt.from_dlpack(X)
103103
assert X.shape == Y.shape
104-
assert X.dtype == Y.dtype or (
105-
str(X.dtype) == "bool" and str(Y.dtype) == "uint8"
106-
)
104+
assert X.dtype == Y.dtype
107105
assert X.sycl_device == Y.sycl_device
108106
assert X.usm_type == Y.usm_type
109107
assert X._pointer == Y._pointer
@@ -113,6 +111,28 @@ def test_from_dlpack(shape, typestr, usm_type):
113111
assert V.strides == W.strides
114112

115113

114+
@pytest.mark.parametrize("mod", [2, 5])
115+
def test_from_dlpack_strides(mod, typestr, usm_type):
116+
all_root_devices = dpctl.get_devices()
117+
for sycl_dev in all_root_devices:
118+
skip_if_dtype_not_supported(typestr, sycl_dev)
119+
X0 = dpt.empty(
120+
3 * mod, dtype=typestr, usm_type=usm_type, device=sycl_dev
121+
)
122+
for start in range(mod):
123+
X = X0[slice(-start - 1, None, -mod)]
124+
Y = dpt.from_dlpack(X)
125+
assert X.shape == Y.shape
126+
assert X.dtype == Y.dtype
127+
assert X.sycl_device == Y.sycl_device
128+
assert X.usm_type == Y.usm_type
129+
assert X._pointer == Y._pointer
130+
if Y.ndim:
131+
V = Y[::-1]
132+
W = dpt.from_dlpack(V)
133+
assert V.strides == W.strides
134+
135+
116136
def test_from_dlpack_input_validation():
117137
vstr = dpt._dlpack.get_build_dlpack_version()
118138
assert type(vstr) is str

0 commit comments

Comments
 (0)