Skip to content

Commit a148f04

Browse files
Extend usm_ndarray C-API
Added UsmNDArray_MakeFromMemory, UsmNDArray_SetWritableFlag, UsmNDArray_MakeFromPtr C-functions.
1 parent db8fe2f commit a148f04

File tree

1 file changed

+51
-0
lines changed

1 file changed

+51
-0
lines changed

dpctl/tensor/_usmarray.pyx

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1308,3 +1308,54 @@ cdef api Py_ssize_t UsmNDArray_GetOffset(usm_ndarray arr):
13081308
"""Get offset of zero-index array element from the beginning of the USM
13091309
allocation"""
13101310
return arr.get_offset()
1311+
1312+
cdef api void UsmNDArray_SetWritableFlag(usm_ndarray arr, int flag):
1313+
"""Set/unset USM_ARRAY_WRITABLE in the given array `arr`."""
1314+
cdef int arr_fl = arr.flags_
1315+
arr_fl ^= (arr_fl & USM_ARRAY_WRITABLE) # unset WRITABLE flag
1316+
arr_fl |= (USM_ARRAY_WRITABLE if flag else 0)
1317+
arr.flags_ = arr_fl
1318+
1319+
cdef api object UsmNDArray_MakeFromMemory(
1320+
int nd, const Py_ssize_t *shape, int typenum,
1321+
c_dpmem._Memory mobj, Py_ssize_t offset
1322+
):
1323+
"""Create usm_ndarray.
1324+
1325+
Equivalent to usm_ndarray(
1326+
_make_tuple(nd, shape), dtype=_make_dtype(typenum),
1327+
buffer=mobj, offset=offset)
1328+
"""
1329+
cdef object shape_tuple = _make_int_tuple(nd, <Py_ssize_t *>shape)
1330+
cdef usm_ndarray arr = usm_ndarray(
1331+
shape_tuple,
1332+
dtype=_make_typestr(typenum),
1333+
buffer=mobj,
1334+
offset=offset
1335+
)
1336+
return arr
1337+
1338+
1339+
cdef api object UsmNDArray_MakeFromPtr(
1340+
size_t nelems,
1341+
int typenum,
1342+
c_dpctl.DPCTLSyclUSMRef ptr,
1343+
c_dpctl.DPCTLSyclQueueRef QRef,
1344+
object owner
1345+
):
1346+
"""Create usm_ndarray from pointer.
1347+
1348+
Argument owner=None implies transert of USM allocation ownership
1349+
to create array object.
1350+
"""
1351+
cdef size_t itemsize = type_bytesize(typenum)
1352+
cdef size_t nbytes = itemsize * nelems
1353+
cdef c_dpmem._Memory mobj = c_dpmem._Memory.create_from_usm_pointer_size_qref(
1354+
ptr, nbytes, QRef, memory_owner=owner
1355+
)
1356+
cdef usm_ndarray arr = usm_ndarray(
1357+
(nelems,),
1358+
dtype=_make_typestr(typenum),
1359+
buffer=mobj
1360+
)
1361+
return arr

0 commit comments

Comments
 (0)