Skip to content

Commit 8b14af6

Browse files
usm_ndarray used to be declared public api, not just api
public keyword amounted to Cython exposing private struct members to Python (such usm_ndarray.typenum_, usm_ndarray.nd_, etc) Keeping only api (no public), these members are not longer exposed. Added __pyx_capi__ function to access these private members. Added tests to test __pyx_capi__ functions
1 parent 8d704db commit 8b14af6

File tree

3 files changed

+132
-6
lines changed

3 files changed

+132
-6
lines changed

dpctl/tensor/_usmarray.pxd

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,34 @@ cdef public int USM_ARRAY_C_CONTIGUOUS
88
cdef public int USM_ARRAY_F_CONTIGUOUS
99
cdef public int USM_ARRAY_WRITEABLE
1010

11-
12-
cdef public api class usm_ndarray [object PyUSMArrayObject, type PyUSMArrayType]:
11+
cdef public int UAR_BOOL
12+
cdef public int UAR_BYTE
13+
cdef public int UAR_UBYTE
14+
cdef public int UAR_SHORT
15+
cdef public int UAR_USHORT
16+
cdef public int UAR_INT
17+
cdef public int UAR_UINT
18+
cdef public int UAR_LONG
19+
cdef public int UAR_ULONG
20+
cdef public int UAR_LONGLONG
21+
cdef public int UAR_ULONGLONG
22+
cdef public int UAR_FLOAT
23+
cdef public int UAR_DOUBLE
24+
cdef public int UAR_CFLOAT
25+
cdef public int UAR_CDOUBLE
26+
cdef public int UAR_TYPE_SENTINEL
27+
cdef public int UAR_HALF
28+
29+
30+
cdef api class usm_ndarray [object PyUSMArrayObject, type PyUSMArrayType]:
1331
# data fields
1432
cdef char* data_
15-
cdef readonly int nd_
33+
cdef int nd_
1634
cdef Py_ssize_t *shape_
1735
cdef Py_ssize_t *strides_
18-
cdef readonly int typenum_
19-
cdef readonly int flags_
20-
cdef readonly object base_
36+
cdef int typenum_
37+
cdef int flags_
38+
cdef object base_
2139
# make usm_ndarray weak-referenceable
2240
cdef object __weakref__
2341

dpctl/tensor/_usmarray.pyx

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -635,3 +635,39 @@ cdef usm_ndarray _zero_like(usm_ndarray ary):
635635
)
636636
# TODO: call function to set array elements to zero
637637
return r
638+
639+
640+
cdef api char* usm_ndarray_get_data(usm_ndarray arr):
641+
"""
642+
"""
643+
return arr.get_data()
644+
645+
646+
cdef api int usm_ndarray_get_ndim(usm_ndarray arr):
647+
""""""
648+
return arr.get_ndim()
649+
650+
651+
cdef api Py_ssize_t* usm_ndarray_get_shape(usm_ndarray arr):
652+
""" """
653+
return arr.get_shape()
654+
655+
656+
cdef api Py_ssize_t* usm_ndarray_get_strides(usm_ndarray arr):
657+
""" """
658+
return arr.get_strides()
659+
660+
661+
cdef api int usm_ndarray_get_typenum(usm_ndarray arr):
662+
""" """
663+
return arr.get_typenum()
664+
665+
666+
cdef api int usm_ndarray_get_flags(usm_ndarray arr):
667+
""" """
668+
return arr.get_flags()
669+
670+
671+
cdef api c_dpctl.DPCTLSyclQueueRef usm_ndarray_get_queue_ref(usm_ndarray arr):
672+
""" """
673+
return arr.get_queue_ref()

dpctl/tests/test_usm_ndarray_ctor.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -373,3 +373,75 @@ def test_datapi_device():
373373
X.device.sycl_queue
374374
X.device.sycl_device
375375
repr(X.device)
376+
377+
378+
def test_pyx_capi():
379+
import ctypes
380+
import sys
381+
382+
X = dpt.usm_ndarray(17)[1::2]
383+
mod = sys.modules[X.__class__.__module__]
384+
# get capsule storign get_context_ref function ptr
385+
arr_data_fn_cap = mod.__pyx_capi__["usm_ndarray_get_data"]
386+
arr_ndim_fn_cap = mod.__pyx_capi__["usm_ndarray_get_ndim"]
387+
arr_shape_fn_cap = mod.__pyx_capi__["usm_ndarray_get_shape"]
388+
arr_strides_fn_cap = mod.__pyx_capi__["usm_ndarray_get_strides"]
389+
arr_typenum_fn_cap = mod.__pyx_capi__["usm_ndarray_get_typenum"]
390+
arr_flags_fn_cap = mod.__pyx_capi__["usm_ndarray_get_flags"]
391+
arr_queue_ref_fn_cap = mod.__pyx_capi__["usm_ndarray_get_queue_ref"]
392+
# construct Python callable to invoke these functions
393+
cap_ptr_fn = ctypes.pythonapi.PyCapsule_GetPointer
394+
cap_ptr_fn.restype = ctypes.c_void_p
395+
cap_ptr_fn.argtypes = [ctypes.py_object, ctypes.c_char_p]
396+
callable_maker_ptr = ctypes.PYFUNCTYPE(ctypes.c_void_p, ctypes.py_object)
397+
callable_maker_int = ctypes.PYFUNCTYPE(ctypes.c_int, ctypes.py_object)
398+
arr_data_fn_ptr = cap_ptr_fn(
399+
arr_data_fn_cap, b"char *(struct PyUSMArrayObject *)"
400+
)
401+
get_data_fn = callable_maker_ptr(arr_data_fn_ptr)
402+
403+
arr_ndim_fn_ptr = cap_ptr_fn(
404+
arr_ndim_fn_cap, b"int (struct PyUSMArrayObject *)"
405+
)
406+
get_ndim_fn = callable_maker_int(arr_ndim_fn_ptr)
407+
408+
arr_shape_fn_ptr = cap_ptr_fn(
409+
arr_shape_fn_cap, b"Py_ssize_t *(struct PyUSMArrayObject *)"
410+
)
411+
get_shape_fn = callable_maker_ptr(arr_shape_fn_ptr)
412+
413+
arr_strides_fn_ptr = cap_ptr_fn(
414+
arr_strides_fn_cap, b"Py_ssize_t *(struct PyUSMArrayObject *)"
415+
)
416+
get_strides_fn = callable_maker_ptr(arr_strides_fn_ptr)
417+
arr_typenum_fn_ptr = cap_ptr_fn(
418+
arr_typenum_fn_cap, b"int (struct PyUSMArrayObject *)"
419+
)
420+
get_typenum_fn = callable_maker_int(arr_typenum_fn_ptr)
421+
arr_flags_fn_ptr = cap_ptr_fn(
422+
arr_flags_fn_cap, b"int (struct PyUSMArrayObject *)"
423+
)
424+
get_flags_fn = callable_maker_int(arr_flags_fn_ptr)
425+
arr_queue_ref_fn_ptr = cap_ptr_fn(
426+
arr_queue_ref_fn_cap, b"DPCTLSyclQueueRef (struct PyUSMArrayObject *)"
427+
)
428+
get_queue_ref_fn = callable_maker_ptr(arr_queue_ref_fn_ptr)
429+
430+
r1 = get_data_fn(X)
431+
sua_iface = X.__sycl_usm_array_interface__
432+
assert r1 == sua_iface["data"][0] + sua_iface.get("offset") * X.itemsize
433+
assert get_ndim_fn(X) == X.ndim
434+
c_longlong_p = ctypes.POINTER(ctypes.c_longlong)
435+
shape0 = ctypes.cast(get_shape_fn(X), c_longlong_p).contents.value
436+
assert shape0 == X.shape[0]
437+
strides0_p = get_strides_fn(X)
438+
if strides0_p:
439+
strides0_p = ctypes.cast(strides0_p, c_longlong_p).contents
440+
strides0_p = strides0_p.value
441+
assert strides0_p == 0 or strides0_p == X.strides[0]
442+
typenum = get_typenum_fn(X)
443+
assert type(typenum) is int
444+
flags = get_flags_fn(X)
445+
assert type(flags) is int and flags == X.flags
446+
queue_ref = get_queue_ref_fn(X) # address of a copy, should be unequal
447+
assert queue_ref != X.sycl_queue.addressof_ref()

0 commit comments

Comments
 (0)