Skip to content

Add usm_ndarray creation c-api #1050

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Feb 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 22 additions & 5 deletions dpctl/apis/include/dpctl4pybind11.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,18 @@ class dpctl_capi
int (*UsmNDArray_GetFlags_)(PyUSMArrayObject *);
DPCTLSyclQueueRef (*UsmNDArray_GetQueueRef_)(PyUSMArrayObject *);
py::ssize_t (*UsmNDArray_GetOffset_)(PyUSMArrayObject *);
void (*UsmNDArray_SetWritableFlag_)(PyUSMArrayObject *, int);
PyObject *(*UsmNDArray_MakeFromMemory_)(int,
const py::ssize_t *,
int,
Py_MemoryObject *,
py::ssize_t,
char);
PyObject *(*UsmNDArray_MakeFromPtr_)(size_t,
int,
DPCTLSyclUSMRef,
DPCTLSyclQueueRef,
PyObject *);

int USM_ARRAY_C_CONTIGUOUS_;
int USM_ARRAY_F_CONTIGUOUS_;
Expand Down Expand Up @@ -220,11 +232,13 @@ class dpctl_capi
UsmNDArray_GetShape_(nullptr), UsmNDArray_GetStrides_(nullptr),
UsmNDArray_GetTypenum_(nullptr), UsmNDArray_GetElementSize_(nullptr),
UsmNDArray_GetFlags_(nullptr), UsmNDArray_GetQueueRef_(nullptr),
UsmNDArray_GetOffset_(nullptr), USM_ARRAY_C_CONTIGUOUS_(0),
USM_ARRAY_F_CONTIGUOUS_(0), USM_ARRAY_WRITABLE_(0), UAR_BOOL_(-1),
UAR_SHORT_(-1), UAR_USHORT_(-1), UAR_INT_(-1), UAR_UINT_(-1),
UAR_LONG_(-1), UAR_ULONG_(-1), UAR_LONGLONG_(-1), UAR_ULONGLONG_(-1),
UAR_FLOAT_(-1), UAR_DOUBLE_(-1), UAR_CFLOAT_(-1), UAR_CDOUBLE_(-1),
UsmNDArray_GetOffset_(nullptr), UsmNDArray_SetWritableFlag_(nullptr),
UsmNDArray_MakeFromMemory_(nullptr), UsmNDArray_MakeFromPtr_(nullptr),
USM_ARRAY_C_CONTIGUOUS_(0), USM_ARRAY_F_CONTIGUOUS_(0),
USM_ARRAY_WRITABLE_(0), UAR_BOOL_(-1), UAR_SHORT_(-1),
UAR_USHORT_(-1), UAR_INT_(-1), UAR_UINT_(-1), UAR_LONG_(-1),
UAR_ULONG_(-1), UAR_LONGLONG_(-1), UAR_ULONGLONG_(-1), UAR_FLOAT_(-1),
UAR_DOUBLE_(-1), UAR_CFLOAT_(-1), UAR_CDOUBLE_(-1),
UAR_TYPE_SENTINEL_(-1), UAR_HALF_(-1), UAR_INT8_(-1), UAR_UINT8_(-1),
UAR_INT16_(-1), UAR_UINT16_(-1), UAR_INT32_(-1), UAR_UINT32_(-1),
UAR_INT64_(-1), UAR_UINT64_(-1), default_sycl_queue_{},
Expand Down Expand Up @@ -295,6 +309,9 @@ class dpctl_capi
this->UsmNDArray_GetFlags_ = UsmNDArray_GetFlags;
this->UsmNDArray_GetQueueRef_ = UsmNDArray_GetQueueRef;
this->UsmNDArray_GetOffset_ = UsmNDArray_GetOffset;
this->UsmNDArray_SetWritableFlag_ = UsmNDArray_SetWritableFlag;
this->UsmNDArray_MakeFromMemory_ = UsmNDArray_MakeFromMemory;
this->UsmNDArray_MakeFromPtr_ = UsmNDArray_MakeFromPtr;

// constants
this->USM_ARRAY_C_CONTIGUOUS_ = USM_ARRAY_C_CONTIGUOUS;
Expand Down
52 changes: 52 additions & 0 deletions dpctl/tensor/_usmarray.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -1308,3 +1308,55 @@ cdef api Py_ssize_t UsmNDArray_GetOffset(usm_ndarray arr):
"""Get offset of zero-index array element from the beginning of the USM
allocation"""
return arr.get_offset()

cdef api void UsmNDArray_SetWritableFlag(usm_ndarray arr, int flag):
    """Set or clear the USM_ARRAY_WRITABLE bit in the flags of `arr`.

    A non-zero `flag` turns the bit on, a zero `flag` turns it off.
    All other bits of `arr.flags_` are left as they were.
    """
    cdef int updated = arr.flags_
    # First bring the WRITABLE bit to a known (cleared) state ...
    if updated & USM_ARRAY_WRITABLE:
        updated ^= USM_ARRAY_WRITABLE
    # ... then raise it again only when the caller asked for it.
    if flag:
        updated |= USM_ARRAY_WRITABLE
    arr.flags_ = updated

cdef api object UsmNDArray_MakeFromMemory(
    int nd, const Py_ssize_t *shape, int typenum,
    c_dpmem._Memory mobj, Py_ssize_t offset, char order
):
    """Create a usm_ndarray backed by the USM memory object `mobj`.

    Equivalent to usm_ndarray(
        _make_int_tuple(nd, shape), dtype=_make_typestr(typenum),
        buffer=mobj, offset=offset, order=<bytes>(order))
    """
    # `shape` is declared const in the C-API signature; the helper takes a
    # non-const pointer, hence the cast.
    cdef object dims = _make_int_tuple(nd, <Py_ssize_t *>shape)
    cdef object typestr = _make_typestr(typenum)
    cdef usm_ndarray result = usm_ndarray(
        dims,
        dtype=typestr,
        buffer=mobj,
        offset=offset,
        order=<bytes>(order)
    )
    return result


cdef api object UsmNDArray_MakeFromPtr(
    size_t nelems,
    int typenum,
    c_dpctl.DPCTLSyclUSMRef ptr,
    c_dpctl.DPCTLSyclQueueRef QRef,
    object owner
):
    """Create a one-dimensional usm_ndarray from a USM pointer.

    The array has shape `(nelems,)` and the data type identified by
    `typenum`. It is backed by a memory object built from `ptr` and
    `QRef` that spans `nelems * itemsize` bytes.

    Argument owner=None implies transfer of USM allocation ownership
    to the created array object.

    Raises:
        ValueError: if no positive element size is known for `typenum`.
    """
    # Keep the element size signed until it has been validated: storing a
    # negative "unsupported type" result directly into a size_t would wrap
    # around and yield a bogus, enormous byte count.
    # NOTE(review): assumes type_bytesize() reports unsupported typenum
    # values with a result < 1 -- confirm against its definition.
    cdef int itemsize = type_bytesize(typenum)
    if itemsize < 1:
        raise ValueError(
            "dtype with typenum=" + str(typenum) + " is not supported."
        )
    cdef size_t nbytes = (<size_t> itemsize) * nelems
    cdef c_dpmem._Memory mobj = c_dpmem._Memory.create_from_usm_pointer_size_qref(
        ptr, nbytes, QRef, memory_owner=owner
    )
    cdef usm_ndarray arr = usm_ndarray(
        (nelems,),
        dtype=_make_typestr(typenum),
        buffer=mobj
    )
    return arr
22 changes: 22 additions & 0 deletions dpctl/tests/test_sycl_usm.py
Original file line number Diff line number Diff line change
Expand Up @@ -545,6 +545,7 @@ def test_cpython_api(memory_ctor):
mem_q_ref_fn_cap = mod.__pyx_capi__["Memory_GetQueueRef"]
mem_ctx_ref_fn_cap = mod.__pyx_capi__["Memory_GetContextRef"]
mem_nby_fn_cap = mod.__pyx_capi__["Memory_GetNumBytes"]
mem_make_fn_cap = mod.__pyx_capi__["Memory_Make"]
# construct Python callable to invoke functions
cap_ptr_fn = ctypes.pythonapi.PyCapsule_GetPointer
cap_ptr_fn.restype = ctypes.c_void_p
Expand All @@ -561,11 +562,23 @@ def test_cpython_api(memory_ctor):
mem_nby_fn_ptr = cap_ptr_fn(
mem_nby_fn_cap, b"size_t (struct Py_MemoryObject *)"
)
mem_make_fn_ptr = cap_ptr_fn(
mem_make_fn_cap,
b"PyObject *(DPCTLSyclUSMRef, size_t, DPCTLSyclQueueRef, PyObject *)",
)
callable_maker = ctypes.PYFUNCTYPE(ctypes.c_void_p, ctypes.py_object)
get_ptr_fn = callable_maker(mem_ptr_fn_ptr)
get_ctx_ref_fn = callable_maker(mem_ctx_ref_fn_ptr)
get_q_ref_fn = callable_maker(mem_q_ref_fn_ptr)
get_nby_fn = callable_maker(mem_nby_fn_ptr)
make_callable_maker = ctypes.PYFUNCTYPE(
ctypes.py_object,
ctypes.c_void_p,
ctypes.c_size_t,
ctypes.c_void_p,
ctypes.py_object,
)
make_fn = make_callable_maker(mem_make_fn_ptr)

capi_ptr = get_ptr_fn(mobj)
direct_ptr = mobj._pointer
Expand All @@ -580,6 +593,15 @@ def test_cpython_api(memory_ctor):
direct_nbytes = mobj.nbytes
assert capi_nbytes == direct_nbytes

mobj2 = make_fn(
mobj._pointer,
ctypes.c_size_t(mobj.nbytes),
mobj.sycl_queue.addressof_ref(),
mobj,
)
assert mobj2._pointer == mobj._pointer
assert mobj2.reference_obj is mobj


def test_memory_construction_from_other_memory_objects():
try:
Expand Down
128 changes: 126 additions & 2 deletions dpctl/tests/test_usm_ndarray_ctor.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,11 @@ def test_datapi_device():


def _pyx_capi_fnptr_to_callable(
X, pyx_capi_name, caps_name, fn_restype=ctypes.c_void_p
X,
pyx_capi_name,
caps_name,
fn_restype=ctypes.c_void_p,
fn_argtypes=(ctypes.py_object,),
):
import sys

Expand All @@ -388,7 +392,7 @@ def _pyx_capi_fnptr_to_callable(
cap_ptr_fn.restype = ctypes.c_void_p
cap_ptr_fn.argtypes = [ctypes.py_object, ctypes.c_char_p]
fn_ptr = cap_ptr_fn(cap, caps_name)
callable_maker_ptr = ctypes.PYFUNCTYPE(fn_restype, ctypes.py_object)
callable_maker_ptr = ctypes.PYFUNCTYPE(fn_restype, *fn_argtypes)
return callable_maker_ptr(fn_ptr)


Expand All @@ -399,6 +403,7 @@ def test_pyx_capi_get_data():
"UsmNDArray_GetData",
b"char *(struct PyUSMArrayObject *)",
fn_restype=ctypes.c_void_p,
fn_argtypes=(ctypes.py_object,),
)
r1 = get_data_fn(X)
sua_iface = X.__sycl_usm_array_interface__
Expand All @@ -412,6 +417,7 @@ def test_pyx_capi_get_shape():
"UsmNDArray_GetShape",
b"Py_ssize_t *(struct PyUSMArrayObject *)",
fn_restype=ctypes.c_void_p,
fn_argtypes=(ctypes.py_object,),
)
c_longlong_p = ctypes.POINTER(ctypes.c_longlong)
shape0 = ctypes.cast(get_shape_fn(X), c_longlong_p).contents.value
Expand All @@ -425,6 +431,7 @@ def test_pyx_capi_get_strides():
"UsmNDArray_GetStrides",
b"Py_ssize_t *(struct PyUSMArrayObject *)",
fn_restype=ctypes.c_void_p,
fn_argtypes=(ctypes.py_object,),
)
c_longlong_p = ctypes.POINTER(ctypes.c_longlong)
strides0_p = get_strides_fn(X)
Expand All @@ -441,6 +448,7 @@ def test_pyx_capi_get_ndim():
"UsmNDArray_GetNDim",
b"int (struct PyUSMArrayObject *)",
fn_restype=ctypes.c_int,
fn_argtypes=(ctypes.py_object,),
)
assert get_ndim_fn(X) == X.ndim

Expand All @@ -452,6 +460,7 @@ def test_pyx_capi_get_typenum():
"UsmNDArray_GetTypenum",
b"int (struct PyUSMArrayObject *)",
fn_restype=ctypes.c_int,
fn_argtypes=(ctypes.py_object,),
)
typenum = get_typenum_fn(X)
assert type(typenum) is int
Expand All @@ -465,6 +474,7 @@ def test_pyx_capi_get_elemsize():
"UsmNDArray_GetElementSize",
b"int (struct PyUSMArrayObject *)",
fn_restype=ctypes.c_int,
fn_argtypes=(ctypes.py_object,),
)
itemsize = get_elemsize_fn(X)
assert type(itemsize) is int
Expand All @@ -478,6 +488,7 @@ def test_pyx_capi_get_flags():
"UsmNDArray_GetFlags",
b"int (struct PyUSMArrayObject *)",
fn_restype=ctypes.c_int,
fn_argtypes=(ctypes.py_object,),
)
flags = get_flags_fn(X)
assert type(flags) is int and X.flags == flags
Expand All @@ -490,6 +501,7 @@ def test_pyx_capi_get_offset():
"UsmNDArray_GetOffset",
b"Py_ssize_t (struct PyUSMArrayObject *)",
fn_restype=ctypes.c_longlong,
fn_argtypes=(ctypes.py_object,),
)
offset = get_offset_fn(X)
assert type(offset) is int
Expand All @@ -503,11 +515,123 @@ def test_pyx_capi_get_queue_ref():
"UsmNDArray_GetQueueRef",
b"DPCTLSyclQueueRef (struct PyUSMArrayObject *)",
fn_restype=ctypes.c_void_p,
fn_argtypes=(ctypes.py_object,),
)
queue_ref = get_queue_ref_fn(X) # address of a copy, should be unequal
assert queue_ref != X.sycl_queue.addressof_ref()


def test_pyx_capi_make_from_memory():
    """Call UsmNDArray_MakeFromMemory through ctypes for both memory orders.

    The "C"-ordered result is checked while the source memory object is
    still referenced; the "F"-ordered result is checked after the memory
    object and the first array have been deleted.
    """
    q = get_queue_or_skip()
    rows, cols = 4, 6
    c_shape = (ctypes.c_ssize_t * 2)(rows, cols)
    mem = dpm.MemoryUSMShared(rows * cols * 4, queue=q)
    typenum = dpt.dtype("single").num
    any_usm_ndarray = dpt.empty(tuple(), dtype="i4", sycl_queue=q)
    capsule_signature = (
        b"PyObject *(int, Py_ssize_t const *, int, "
        b"struct Py_MemoryObject *, Py_ssize_t, char)"
    )
    c_argtypes = (
        ctypes.c_int,
        ctypes.POINTER(ctypes.c_ssize_t),
        ctypes.c_int,
        ctypes.py_object,
        ctypes.c_ssize_t,
        ctypes.c_char,
    )
    make_from_memory_fn = _pyx_capi_fnptr_to_callable(
        any_usm_ndarray,
        "UsmNDArray_MakeFromMemory",
        capsule_signature,
        fn_restype=ctypes.py_object,
        fn_argtypes=c_argtypes,
    )

    # C-contiguous array over the whole allocation.
    arr_c = make_from_memory_fn(
        ctypes.c_int(2),
        c_shape,
        ctypes.c_int(typenum),
        mem,
        ctypes.c_ssize_t(0),
        ctypes.c_char(b"C"),
    )
    assert isinstance(arr_c, dpt.usm_ndarray)
    assert arr_c.ndim == 2
    assert arr_c.shape == (rows, cols)
    assert arr_c._pointer == mem._pointer
    assert arr_c.usm_type == "shared"
    assert arr_c.sycl_queue == q
    assert arr_c.flags["C"]

    # F-contiguous array over the same allocation.
    arr_f = make_from_memory_fn(
        ctypes.c_int(2),
        c_shape,
        ctypes.c_int(typenum),
        mem,
        ctypes.c_ssize_t(0),
        ctypes.c_char(b"F"),
    )
    expected_ptr = mem._pointer
    # Drop the local references to the memory object and the first array
    # before inspecting the second array.
    del mem
    del arr_c
    assert isinstance(arr_f, dpt.usm_ndarray)
    assert arr_f._pointer == expected_ptr
    assert arr_f.usm_type == "shared"
    assert arr_f.sycl_queue == q
    assert arr_f.flags["F"]


def test_pyx_capi_set_writable_flag():
    """Toggle the WRITABLE flag through UsmNDArray_SetWritableFlag.

    A freshly created array must be writable; the C-API call is then used
    to clear the flag and to set it again, checking `flags["WRITABLE"]`
    after each call.
    """
    q = get_queue_or_skip()
    arr = dpt.empty((4, 5), dtype="i4", sycl_queue=q)
    assert isinstance(arr, dpt.usm_ndarray)
    assert arr.flags["WRITABLE"] is True
    set_writable = _pyx_capi_fnptr_to_callable(
        arr,
        "UsmNDArray_SetWritableFlag",
        b"void (struct PyUSMArrayObject *, int)",
        fn_restype=None,
        fn_argtypes=(ctypes.py_object, ctypes.c_int),
    )
    # Clear the flag first, then restore it.
    for c_flag, expected in ((0, False), (1, True)):
        set_writable(arr, ctypes.c_int(c_flag))
        assert isinstance(arr, dpt.usm_ndarray)
        assert arr.flags["WRITABLE"] is expected


def test_pyx_capi_make_from_ptr():
    """Call UsmNDArray_MakeFromPtr through ctypes with an explicit owner.

    A device USM allocation is wrapped into a 1-D int64 array, with the
    memory object passed as the owner. The array is then checked to remain
    usable after the local reference to the memory object is deleted.
    """
    q = get_queue_or_skip()
    any_arr = dpt.empty(tuple(), dtype="i4", sycl_queue=q)
    capsule_signature = (
        b"PyObject *(size_t, int, DPCTLSyclUSMRef, "
        b"DPCTLSyclQueueRef, PyObject *)"
    )
    c_argtypes = (
        ctypes.c_size_t,
        ctypes.c_int,
        ctypes.c_void_p,
        ctypes.c_void_p,
        ctypes.py_object,
    )
    make_from_ptr = _pyx_capi_fnptr_to_callable(
        any_arr,
        "UsmNDArray_MakeFromPtr",
        capsule_signature,
        fn_restype=ctypes.py_object,
        fn_argtypes=c_argtypes,
    )
    nelems = 10
    dt = dpt.int64
    mem = dpm.MemoryUSMDevice(nelems * dt.itemsize, queue=q)
    arr = make_from_ptr(
        ctypes.c_size_t(nelems),
        dt.num,
        mem._pointer,
        mem.sycl_queue.addressof_ref(),
        mem,
    )
    assert isinstance(arr, dpt.usm_ndarray)
    assert arr.shape == (nelems,)
    assert arr.dtype == dt
    assert arr.sycl_queue == q
    assert arr._pointer == mem._pointer
    # Drop the local reference to the memory object; the array must still
    # be able to produce its repr afterwards.
    del mem
    assert isinstance(arr.__repr__(), str)


def _pyx_capi_int(X, pyx_capi_name, caps_name=b"int", val_restype=ctypes.c_int):
import sys

Expand Down