Skip to content

Commit 1851c69

Browse files
Added tests for new C-API
Used ctypes to call new C-API for dpctl.tensor.usm_ndarray
1 parent 6d3c08f commit 1851c69

File tree

1 file changed

+107
-2
lines changed

1 file changed

+107
-2
lines changed

dpctl/tests/test_usm_ndarray_ctor.py

Lines changed: 107 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -373,7 +373,11 @@ def test_datapi_device():
373373

374374

375375
def _pyx_capi_fnptr_to_callable(
376-
X, pyx_capi_name, caps_name, fn_restype=ctypes.c_void_p
376+
X,
377+
pyx_capi_name,
378+
caps_name,
379+
fn_restype=ctypes.c_void_p,
380+
fn_argtypes=(ctypes.py_object,),
377381
):
378382
import sys
379383

@@ -388,7 +392,7 @@ def _pyx_capi_fnptr_to_callable(
388392
cap_ptr_fn.restype = ctypes.c_void_p
389393
cap_ptr_fn.argtypes = [ctypes.py_object, ctypes.c_char_p]
390394
fn_ptr = cap_ptr_fn(cap, caps_name)
391-
callable_maker_ptr = ctypes.PYFUNCTYPE(fn_restype, ctypes.py_object)
395+
callable_maker_ptr = ctypes.PYFUNCTYPE(fn_restype, *fn_argtypes)
392396
return callable_maker_ptr(fn_ptr)
393397

394398

@@ -399,6 +403,7 @@ def test_pyx_capi_get_data():
399403
"UsmNDArray_GetData",
400404
b"char *(struct PyUSMArrayObject *)",
401405
fn_restype=ctypes.c_void_p,
406+
fn_argtypes=(ctypes.py_object,),
402407
)
403408
r1 = get_data_fn(X)
404409
sua_iface = X.__sycl_usm_array_interface__
@@ -412,6 +417,7 @@ def test_pyx_capi_get_shape():
412417
"UsmNDArray_GetShape",
413418
b"Py_ssize_t *(struct PyUSMArrayObject *)",
414419
fn_restype=ctypes.c_void_p,
420+
fn_argtypes=(ctypes.py_object,),
415421
)
416422
c_longlong_p = ctypes.POINTER(ctypes.c_longlong)
417423
shape0 = ctypes.cast(get_shape_fn(X), c_longlong_p).contents.value
@@ -425,6 +431,7 @@ def test_pyx_capi_get_strides():
425431
"UsmNDArray_GetStrides",
426432
b"Py_ssize_t *(struct PyUSMArrayObject *)",
427433
fn_restype=ctypes.c_void_p,
434+
fn_argtypes=(ctypes.py_object,),
428435
)
429436
c_longlong_p = ctypes.POINTER(ctypes.c_longlong)
430437
strides0_p = get_strides_fn(X)
@@ -441,6 +448,7 @@ def test_pyx_capi_get_ndim():
441448
"UsmNDArray_GetNDim",
442449
b"int (struct PyUSMArrayObject *)",
443450
fn_restype=ctypes.c_int,
451+
fn_argtypes=(ctypes.py_object,),
444452
)
445453
assert get_ndim_fn(X) == X.ndim
446454

@@ -452,6 +460,7 @@ def test_pyx_capi_get_typenum():
452460
"UsmNDArray_GetTypenum",
453461
b"int (struct PyUSMArrayObject *)",
454462
fn_restype=ctypes.c_int,
463+
fn_argtypes=(ctypes.py_object,),
455464
)
456465
typenum = get_typenum_fn(X)
457466
assert type(typenum) is int
@@ -465,6 +474,7 @@ def test_pyx_capi_get_elemsize():
465474
"UsmNDArray_GetElementSize",
466475
b"int (struct PyUSMArrayObject *)",
467476
fn_restype=ctypes.c_int,
477+
fn_argtypes=(ctypes.py_object,),
468478
)
469479
itemsize = get_elemsize_fn(X)
470480
assert type(itemsize) is int
@@ -478,6 +488,7 @@ def test_pyx_capi_get_flags():
478488
"UsmNDArray_GetFlags",
479489
b"int (struct PyUSMArrayObject *)",
480490
fn_restype=ctypes.c_int,
491+
fn_argtypes=(ctypes.py_object,),
481492
)
482493
flags = get_flags_fn(X)
483494
assert type(flags) is int and X.flags == flags
@@ -490,6 +501,7 @@ def test_pyx_capi_get_offset():
490501
"UsmNDArray_GetOffset",
491502
b"Py_ssize_t (struct PyUSMArrayObject *)",
492503
fn_restype=ctypes.c_longlong,
504+
fn_argtypes=(ctypes.py_object,),
493505
)
494506
offset = get_offset_fn(X)
495507
assert type(offset) is int
@@ -503,11 +515,104 @@ def test_pyx_capi_get_queue_ref():
503515
"UsmNDArray_GetQueueRef",
504516
b"DPCTLSyclQueueRef (struct PyUSMArrayObject *)",
505517
fn_restype=ctypes.c_void_p,
518+
fn_argtypes=(ctypes.py_object,),
506519
)
507520
queue_ref = get_queue_ref_fn(X) # address of a copy, should be unequal
508521
assert queue_ref != X.sycl_queue.addressof_ref()
509522

510523

524+
def test_pyx_capi_make_from_memory():
525+
q = get_queue_or_skip()
526+
n0, n1 = 4, 6
527+
c_tuple = (ctypes.c_ssize_t * 2)(n0, n1)
528+
mem = dpm.MemoryUSMShared(n0 * n1 * 4, queue=q)
529+
typenum = dpt.dtype("single").num
530+
any_usm_ndarray = dpt.empty(tuple(), dtype="i4", sycl_queue=q)
531+
make_from_memory_fn = _pyx_capi_fnptr_to_callable(
532+
any_usm_ndarray,
533+
"UsmNDArray_MakeFromMemory",
534+
b"PyObject *(int, Py_ssize_t const *, int, "
535+
b"struct Py_MemoryObject *, Py_ssize_t)",
536+
fn_restype=ctypes.py_object,
537+
fn_argtypes=(
538+
ctypes.c_int,
539+
ctypes.POINTER(ctypes.c_ssize_t),
540+
ctypes.c_int,
541+
ctypes.py_object,
542+
ctypes.c_ssize_t,
543+
),
544+
)
545+
r = make_from_memory_fn(
546+
ctypes.c_int(2),
547+
c_tuple,
548+
ctypes.c_int(typenum),
549+
mem,
550+
ctypes.c_ssize_t(0),
551+
)
552+
assert isinstance(r, dpt.usm_ndarray)
553+
assert r.ndim == 2
554+
assert r.shape == (n0, n1)
555+
assert r._pointer == mem._pointer
556+
assert r.usm_type == "shared"
557+
assert r.sycl_queue == q
558+
559+
560+
def test_pyx_capi_set_writable_flag():
561+
q = get_queue_or_skip()
562+
usm_ndarray = dpt.empty((4, 5), dtype="i4", sycl_queue=q)
563+
assert isinstance(usm_ndarray, dpt.usm_ndarray)
564+
assert usm_ndarray.flags["WRITABLE"] is True
565+
set_writable = _pyx_capi_fnptr_to_callable(
566+
usm_ndarray,
567+
"UsmNDArray_SetWritableFlag",
568+
b"void (struct PyUSMArrayObject *, int)",
569+
fn_restype=None,
570+
fn_argtypes=(ctypes.py_object, ctypes.c_int),
571+
)
572+
set_writable(usm_ndarray, ctypes.c_int(0))
573+
assert isinstance(usm_ndarray, dpt.usm_ndarray)
574+
assert usm_ndarray.flags["WRITABLE"] is False
575+
set_writable(usm_ndarray, ctypes.c_int(1))
576+
assert isinstance(usm_ndarray, dpt.usm_ndarray)
577+
assert usm_ndarray.flags["WRITABLE"] is True
578+
579+
580+
def test_pyx_capi_make_from_ptr():
581+
q = get_queue_or_skip()
582+
usm_ndarray = dpt.empty(tuple(), dtype="i4", sycl_queue=q)
583+
make_from_ptr = _pyx_capi_fnptr_to_callable(
584+
usm_ndarray,
585+
"UsmNDArray_MakeFromPtr",
586+
b"PyObject *(size_t, int, DPCTLSyclUSMRef, "
587+
b"DPCTLSyclQueueRef, PyObject *)",
588+
fn_restype=ctypes.py_object,
589+
fn_argtypes=(
590+
ctypes.c_size_t,
591+
ctypes.c_int,
592+
ctypes.c_void_p,
593+
ctypes.c_void_p,
594+
ctypes.py_object,
595+
),
596+
)
597+
nelems = 10
598+
dt = dpt.int64
599+
mem = dpm.MemoryUSMDevice(nelems * dt.itemsize, queue=q)
600+
arr = make_from_ptr(
601+
ctypes.c_size_t(nelems),
602+
dt.num,
603+
mem._pointer,
604+
mem.sycl_queue.addressof_ref(),
605+
mem,
606+
)
607+
assert isinstance(arr, dpt.usm_ndarray)
608+
assert arr.shape == (nelems,)
609+
assert arr.dtype == dt
610+
assert arr.sycl_queue == q
611+
assert arr._pointer == mem._pointer
612+
del mem
613+
assert isinstance(arr.__repr__(), str)
614+
615+
511616
def _pyx_capi_int(X, pyx_capi_name, caps_name=b"int", val_restype=ctypes.c_int):
512617
import sys
513618

0 commit comments

Comments
 (0)