Skip to content

Commit c5585c2

Browse files
Class dpctl_capi made aware of new UsmNDArray_* C-API functions
1 parent 3df30fa commit c5585c2

File tree

1 file changed

+21
-5
lines changed

1 file changed

+21
-5
lines changed

dpctl/apis/include/dpctl4pybind11.hpp

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,17 @@ class dpctl_capi
114114
int (*UsmNDArray_GetFlags_)(PyUSMArrayObject *);
115115
DPCTLSyclQueueRef (*UsmNDArray_GetQueueRef_)(PyUSMArrayObject *);
116116
py::ssize_t (*UsmNDArray_GetOffset_)(PyUSMArrayObject *);
117+
void (*UsmNDArray_SetWritableFlag_)(PyUSMArrayObject *, int);
118+
PyObject *(*UsmNDArray_MakeFromMemory_)(int,
119+
const py::ssize_t *,
120+
int,
121+
Py_MemoryObject *,
122+
py::ssize_t);
123+
PyObject *(*UsmNDArray_MakeFromPtr_)(size_t,
124+
int,
125+
DPCTLSyclUSMRef,
126+
DPCTLSyclQueueRef,
127+
PyObject *);
117128

118129
int USM_ARRAY_C_CONTIGUOUS_;
119130
int USM_ARRAY_F_CONTIGUOUS_;
@@ -220,11 +231,13 @@ class dpctl_capi
220231
UsmNDArray_GetShape_(nullptr), UsmNDArray_GetStrides_(nullptr),
221232
UsmNDArray_GetTypenum_(nullptr), UsmNDArray_GetElementSize_(nullptr),
222233
UsmNDArray_GetFlags_(nullptr), UsmNDArray_GetQueueRef_(nullptr),
223-
UsmNDArray_GetOffset_(nullptr), USM_ARRAY_C_CONTIGUOUS_(0),
224-
USM_ARRAY_F_CONTIGUOUS_(0), USM_ARRAY_WRITABLE_(0), UAR_BOOL_(-1),
225-
UAR_SHORT_(-1), UAR_USHORT_(-1), UAR_INT_(-1), UAR_UINT_(-1),
226-
UAR_LONG_(-1), UAR_ULONG_(-1), UAR_LONGLONG_(-1), UAR_ULONGLONG_(-1),
227-
UAR_FLOAT_(-1), UAR_DOUBLE_(-1), UAR_CFLOAT_(-1), UAR_CDOUBLE_(-1),
234+
UsmNDArray_GetOffset_(nullptr), UsmNDArray_SetWritableFlag_(nullptr),
235+
UsmNDArray_MakeFromMemory_(nullptr), UsmNDArray_MakeFromPtr_(nullptr),
236+
USM_ARRAY_C_CONTIGUOUS_(0), USM_ARRAY_F_CONTIGUOUS_(0),
237+
USM_ARRAY_WRITABLE_(0), UAR_BOOL_(-1), UAR_SHORT_(-1),
238+
UAR_USHORT_(-1), UAR_INT_(-1), UAR_UINT_(-1), UAR_LONG_(-1),
239+
UAR_ULONG_(-1), UAR_LONGLONG_(-1), UAR_ULONGLONG_(-1), UAR_FLOAT_(-1),
240+
UAR_DOUBLE_(-1), UAR_CFLOAT_(-1), UAR_CDOUBLE_(-1),
228241
UAR_TYPE_SENTINEL_(-1), UAR_HALF_(-1), UAR_INT8_(-1), UAR_UINT8_(-1),
229242
UAR_INT16_(-1), UAR_UINT16_(-1), UAR_INT32_(-1), UAR_UINT32_(-1),
230243
UAR_INT64_(-1), UAR_UINT64_(-1), default_sycl_queue_{},
@@ -295,6 +308,9 @@ class dpctl_capi
295308
this->UsmNDArray_GetFlags_ = UsmNDArray_GetFlags;
296309
this->UsmNDArray_GetQueueRef_ = UsmNDArray_GetQueueRef;
297310
this->UsmNDArray_GetOffset_ = UsmNDArray_GetOffset;
311+
this->UsmNDArray_SetWritableFlag_ = UsmNDArray_SetWritableFlag;
312+
this->UsmNDArray_MakeFromMemory_ = UsmNDArray_MakeFromMemory;
313+
this->UsmNDArray_MakeFromPtr_ = UsmNDArray_MakeFromPtr;
298314

299315
// constants
300316
this->USM_ARRAY_C_CONTIGUOUS_ = USM_ARRAY_C_CONTIGUOUS;

0 commit comments

Comments
 (0)