Skip to content

Commit c1895f7

Browse files
Extended tests to improve coverage
1 parent cb846e8 commit c1895f7

File tree

2 files changed

+88
-1
lines changed

2 files changed

+88
-1
lines changed

dpctl/tests/test_sycl_context.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,3 +187,26 @@ def test_hashing_of_context():
187187
def test_context_repr():
188188
ctx = dpctl.SyclContext()
189189
assert type(ctx.__repr__()) is str
190+
191+
192+
def test_cpython_api():
193+
import ctypes
194+
import sys
195+
196+
ctx = dpctl.SyclContext()
197+
mod = sys.modules[ctx.__class__.__module__]
198+
# get capsule storign get_context_ref function ptr
199+
ctx_ref_fn_cap = mod.__pyx_capi__["get_context_ref"]
200+
# construct Python callable to invoke "get_context_ref"
201+
cap_ptr_fn = ctypes.pythonapi.PyCapsule_GetPointer
202+
cap_ptr_fn.restype = ctypes.c_void_p
203+
cap_ptr_fn.argtypes = [ctypes.py_object, ctypes.c_char_p]
204+
ctx_ref_fn_ptr = cap_ptr_fn(
205+
ctx_ref_fn_cap, b"DPCTLSyclContextRef (struct PySyclContextObject *)"
206+
)
207+
callable_maker = ctypes.PYFUNCTYPE(ctypes.c_void_p, ctypes.py_object)
208+
get_context_ref_fn = callable_maker(ctx_ref_fn_ptr)
209+
210+
r2 = ctx.addressof_ref()
211+
r1 = get_context_ref_fn(ctx)
212+
assert r1 == r2

dpctl/tests/test_sycl_queue.py

Lines changed: 65 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -433,16 +433,80 @@ def test_queue__repr__():
433433
r2 = q2.__repr__()
434434
q3 = dpctl.SyclQueue(property="enable_profiling")
435435
r3 = q3.__repr__()
436-
q4 = dpctl.SyclQueue(property=["in_order", "enable_profiling"])
436+
q4 = dpctl.SyclQueue(property="default")
437437
r4 = q4.__repr__()
438+
q5 = dpctl.SyclQueue(property=["in_order", "enable_profiling"])
439+
r5 = q5.__repr__()
438440
assert type(r1) is str
439441
assert type(r2) is str
440442
assert type(r3) is str
441443
assert type(r4) is str
444+
assert type(r5) is str
445+
446+
447+
def test_queue_invalid_property():
448+
with pytest.raises(ValueError):
449+
dpctl.SyclQueue(property=4.5)
450+
with pytest.raises(ValueError):
451+
dpctl.SyclQueue(property=["abc", tuple()])
442452

443453

444454
def test_queue_capsule():
445455
q = dpctl.SyclQueue()
446456
cap = q._get_capsule()
457+
cap2 = q._get_capsule()
447458
q2 = dpctl.SyclQueue(cap)
448459
assert q == q2
460+
del cap2 # call deleter on non-renamed capsule
461+
462+
463+
def test_cpython_api():
464+
import ctypes
465+
import sys
466+
467+
q = dpctl.SyclQueue()
468+
mod = sys.modules[q.__class__.__module__]
469+
# get capsule storign get_context_ref function ptr
470+
q_ref_fn_cap = mod.__pyx_capi__["get_queue_ref"]
471+
# construct Python callable to invoke "get_queue_ref"
472+
cap_ptr_fn = ctypes.pythonapi.PyCapsule_GetPointer
473+
cap_ptr_fn.restype = ctypes.c_void_p
474+
cap_ptr_fn.argtypes = [ctypes.py_object, ctypes.c_char_p]
475+
q_ref_fn_ptr = cap_ptr_fn(
476+
q_ref_fn_cap, b"DPCTLSyclQueueRef (struct PySyclQueueObject *)"
477+
)
478+
callable_maker = ctypes.PYFUNCTYPE(ctypes.c_void_p, ctypes.py_object)
479+
get_queue_ref_fn = callable_maker(q_ref_fn_ptr)
480+
481+
r2 = q.addressof_ref()
482+
r1 = get_queue_ref_fn(q)
483+
assert r1 == r2
484+
485+
486+
def test_constructor_many_arg():
487+
with pytest.raises(TypeError):
488+
dpctl.SyclQueue(None, None, None, None)
489+
with pytest.raises(TypeError):
490+
dpctl.SyclQueue(None, None)
491+
492+
493+
def test_queue_wait():
494+
try:
495+
q = dpctl.SyclQueue()
496+
except dpctl.SyclQueueCreationError:
497+
pytest.skip("Failed to create device with supported filter")
498+
q.wait()
499+
500+
501+
def test_queue_memops():
502+
try:
503+
q = dpctl.SyclQueue()
504+
except dpctl.SyclQueueCreationError:
505+
pytest.skip("Failed to create device with supported filter")
506+
from dpctl.memory import MemoryUSMDevice
507+
508+
m1 = MemoryUSMDevice(512, queue=q)
509+
m2 = MemoryUSMDevice(512, queue=q)
510+
q.memcpy(m1, m2, 512)
511+
q.prefetch(m1, 512)
512+
q.mem_advise(m1, 512, 0)

0 commit comments

Comments
 (0)