@@ -71,6 +71,7 @@ cdef extern from 'dlpack/dlpack.h' nogil:
71
71
kDLFloat
72
72
kDLBfloat
73
73
kDLComplex
74
+ kDLBool
74
75
75
76
ctypedef struct DLDataType:
76
77
uint8_t code
@@ -244,7 +245,7 @@ cpdef to_dlpack_capsule(usm_ndarray usm_ary) except+:
244
245
dl_tensor.dtype.lanes = < uint16_t> 1
245
246
dl_tensor.dtype.bits = < uint8_t> (ary_dt.itemsize * 8 )
246
247
if (ary_dtk == " b" ):
247
- dl_tensor.dtype.code = < uint8_t> kDLUInt
248
+ dl_tensor.dtype.code = < uint8_t> kDLBool
248
249
elif (ary_dtk == " u" ):
249
250
dl_tensor.dtype.code = < uint8_t> kDLUInt
250
251
elif (ary_dtk == " i" ):
@@ -311,14 +312,17 @@ cpdef usm_ndarray from_dlpack_capsule(object py_caps) except +:
311
312
cdef DLManagedTensor * dlm_tensor = NULL
312
313
cdef bytes usm_type
313
314
cdef size_t sz = 1
315
+ cdef size_t alloc_sz = 1
314
316
cdef int i
315
317
cdef int device_id = - 1
316
318
cdef int element_bytesize = 0
317
319
cdef Py_ssize_t offset_min = 0
318
320
cdef Py_ssize_t offset_max = 0
319
- cdef int64_t stride_i
320
321
cdef char * mem_ptr = NULL
322
+ cdef Py_ssize_t mem_ptr_delta = 0
321
323
cdef Py_ssize_t element_offset = 0
324
+ cdef int64_t stride_i = - 1
325
+ cdef int64_t shape_i = - 1
322
326
323
327
if not cpython.PyCapsule_IsValid(py_caps, ' dltensor' ):
324
328
if cpython.PyCapsule_IsValid(py_caps, ' used_dltensor' ):
@@ -370,22 +374,22 @@ cpdef usm_ndarray from_dlpack_capsule(object py_caps) except +:
370
374
raise BufferError(
371
375
" Can not import DLPack tensor with lanes != 1"
372
376
)
377
+ offset_min = 0
373
378
if dlm_tensor.dl_tensor.strides is NULL :
374
379
for i in range (dlm_tensor.dl_tensor.ndim):
375
380
sz = sz * dlm_tensor.dl_tensor.shape[i]
381
+ offset_max = sz - 1
376
382
else :
377
- offset_min = 0
378
383
offset_max = 0
379
384
for i in range (dlm_tensor.dl_tensor.ndim):
380
385
stride_i = dlm_tensor.dl_tensor.strides[i]
381
- if stride_i > 0 :
382
- offset_max = offset_max + stride_i * (
383
- dlm_tensor.dl_tensor.shape[i] - 1
384
- )
385
- else :
386
- offset_min = offset_min + stride_i * (
387
- dlm_tensor.dl_tensor.shape[i] - 1
388
- )
386
+ shape_i = dlm_tensor.dl_tensor.shape[i]
387
+ if shape_i > 1 :
388
+ shape_i -= 1
389
+ if stride_i > 0 :
390
+ offset_max = offset_max + stride_i * shape_i
391
+ else :
392
+ offset_min = offset_min + stride_i * shape_i
389
393
sz = offset_max - offset_min + 1
390
394
if sz == 0 :
391
395
sz = 1
@@ -401,14 +405,29 @@ cpdef usm_ndarray from_dlpack_capsule(object py_caps) except +:
401
405
if dlm_tensor.dl_tensor.data is NULL :
402
406
usm_mem = dpmem.MemoryUSMDevice(sz, q)
403
407
else :
404
- mem_ptr = < char * > dlm_tensor.dl_tensor.data + dlm_tensor.dl_tensor.byte_offset
405
- mem_ptr = mem_ptr - (element_offset * element_bytesize)
406
- usm_mem = c_dpmem._Memory.create_from_usm_pointer_size_qref(
408
+ mem_ptr_delta = dlm_tensor.dl_tensor.byte_offset - (
409
+ element_offset * element_bytesize
410
+ )
411
+ mem_ptr = < char * > dlm_tensor.dl_tensor.data
412
+ alloc_sz = dlm_tensor.dl_tensor.byte_offset + < uint64_t> (
413
+ (offset_max + 1 ) * element_bytesize)
414
+ tmp = c_dpmem._Memory.create_from_usm_pointer_size_qref(
407
415
< DPCTLSyclUSMRef> mem_ptr,
408
- sz ,
416
+ max (alloc_sz, < uint64_t > element_bytesize) ,
409
417
(< c_dpctl.SyclQueue> q).get_queue_ref(),
410
418
memory_owner = dlm_holder
411
419
)
420
+ if mem_ptr_delta == 0 :
421
+ usm_mem = tmp
422
+ else :
423
+ alloc_sz = dlm_tensor.dl_tensor.byte_offset + < uint64_t> (
424
+ (offset_max * element_bytesize + mem_ptr_delta))
425
+ usm_mem = c_dpmem._Memory.create_from_usm_pointer_size_qref(
426
+ < DPCTLSyclUSMRef> (mem_ptr + (element_bytesize - mem_ptr_delta)),
427
+ max (alloc_sz, < uint64_t> element_bytesize),
428
+ (< c_dpctl.SyclQueue> q).get_queue_ref(),
429
+ memory_owner = tmp
430
+ )
412
431
py_shape = list ()
413
432
for i in range (dlm_tensor.dl_tensor.ndim):
414
433
py_shape.append(dlm_tensor.dl_tensor.shape[i])
@@ -426,8 +445,10 @@ cpdef usm_ndarray from_dlpack_capsule(object py_caps) except +:
426
445
ary_dt = np.dtype(" f" + str (element_bytesize))
427
446
elif (dlm_tensor.dl_tensor.dtype.code == kDLComplex):
428
447
ary_dt = np.dtype(" c" + str (element_bytesize))
448
+ elif (dlm_tensor.dl_tensor.dtype.code == kDLBool):
449
+ ary_dt = np.dtype(" ?" )
429
450
else :
430
- raise ValueError (
451
+ raise BufferError (
431
452
" Can not import DLPack tensor with type code {}." .format(
432
453
< object > dlm_tensor.dl_tensor.dtype.code
433
454
)
@@ -441,7 +462,7 @@ cpdef usm_ndarray from_dlpack_capsule(object py_caps) except +:
441
462
)
442
463
return res_ary
443
464
else :
444
- raise ValueError (
465
+ raise BufferError (
445
466
" The DLPack tensor resides on unsupported device."
446
467
)
447
468
0 commit comments