Skip to content

Commit 67cf506

Browse files
Added some tests to involve capsule
Added test to sycl context testing file to create a context from a sub-device
1 parent 33031a2 commit 67cf506

File tree

3 files changed

+44
-0
lines changed

3 files changed

+44
-0
lines changed

dpctl/tests/test_sycl_context.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,16 @@ def test_context_multi_device():
159159
shmem_1 = dpmem.MemoryUSMShared(256, queue=q1)
160160
shmem_2 = dpmem.MemoryUSMDevice(256, queue=q2)
161161
shmem_2.copy_from_device(shmem_1)
162+
# create context for single sub-device
163+
ctx1 = dpctl.SyclContext(d1)
164+
q1 = dpctl.SyclQueue(ctx1, d1)
165+
shmem_1 = dpmem.MemoryUSMShared(256, queue=q1)
166+
cap = ctx1._get_capsule()
167+
del ctx1
168+
ctx2 = dpctl.SyclContext(cap)
169+
q2 = dpctl.SyclQueue(ctx2, d1)
170+
shmem_2 = dpmem.MemoryUSMDevice(256, queue=q2)
171+
shmem_2.copy_from_device(shmem_1)
162172

163173

164174
def test_hashing_of_context():
@@ -169,3 +179,8 @@ def test_hashing_of_context():
169179
"""
170180
ctx_dict = {dpctl.SyclContext(): "default_context"}
171181
assert ctx_dict
182+
183+
184+
def test_context_repr():
185+
ctx = dpctl.SyclContext()
186+
assert type(ctx.__repr__()) is str

dpctl/tests/test_sycl_event.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,3 +180,10 @@ def test_sycl_timer():
180180
timer(queue=q_no_profiling)
181181
with pytest.raises(TypeError):
182182
timer(queue=None)
183+
184+
185+
def test_event_capsule():
186+
ev = dpctl.SyclEvent()
187+
cap = ev._get_capsule()
188+
ev2 = dpctl.SyclEvent(cap)
189+
assert type(ev2) == type(ev)

dpctl/tests/test_sycl_queue.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -424,3 +424,25 @@ def test_queue_submit_barrier(valid_filter):
424424
ev3.wait()
425425
ev1.wait()
426426
ev2.wait()
427+
428+
429+
def test_queue__repr__():
430+
q1 = dpctl.SyclQueue()
431+
r1 = q1.__repr__()
432+
q2 = dpctl.SyclQueue(property="in_order")
433+
r2 = q2.__repr__()
434+
q3 = dpctl.SyclQueue(property="enable_profiling")
435+
r3 = q3.__repr__()
436+
q4 = dpctl.SyclQueue(property=["in_order", "enable_profiling"])
437+
r4 = q4.__repr__()
438+
assert type(r1) is str
439+
assert type(r2) is str
440+
assert type(r3) is str
441+
assert type(r4) is str
442+
443+
444+
def test_queue_capsule():
445+
q = dpctl.SyclQueue()
446+
cap = q._get_capsule()
447+
q2 = dpctl.SyclQueue(cap)
448+
assert q == q2

0 commit comments

Comments
 (0)