|
17 | 17 | """ Defines unit test cases for the SyclQueue class.
|
18 | 18 | """
|
19 | 19 |
|
| 20 | +import ctypes |
| 21 | +import sys |
| 22 | + |
20 | 23 | import pytest
|
21 | 24 |
|
22 | 25 | import dpctl
|
@@ -395,22 +398,22 @@ def test_hashing_of_queue():
|
395 | 398 | assert queue_dict
|
396 | 399 |
|
397 | 400 |
|
398 |
| -def test_channeling_device_properties(): |
| 401 | +def test_channeling_device_properties(capsys): |
399 | 402 | try:
|
400 | 403 | q = dpctl.SyclQueue()
|
401 | 404 | dev = q.sycl_device
|
402 | 405 | except dpctl.SyclQueueCreationError:
|
403 | 406 | pytest.fail("Failed to create device from default selector")
|
404 |
| - import io |
405 |
| - from contextlib import redirect_stdout |
406 |
| - |
407 |
| - f1 = io.StringIO() |
408 |
| - with redirect_stdout(f1): |
409 |
| - q.print_device_info() # should execute without raising |
410 |
| - f2 = io.StringIO() |
411 |
| - with redirect_stdout(f2): |
412 |
| - dev.print_device_info() |
413 |
| - assert f1.getvalue() == f2.getvalue(), "Mismatch in print_device_info" |
| 407 | + |
| 408 | + q.print_device_info() # should execute without raising |
| 409 | + q_captured = capsys.readouterr() |
| 410 | + q_output = q_captured.out |
| 411 | + dev.print_device_info() |
| 412 | + d_captured = capsys.readouterr() |
| 413 | + d_output = d_captured.out |
| 414 | + assert q_output, "No output captured" |
| 415 | + assert q_output == d_output, "Mismatch in print_device_info" |
| 416 | + assert q_captured.err == "" and d_captured.err == "" |
414 | 417 | for pr in ["backend", "name", "driver_version"]:
|
415 | 418 | assert getattr(q, pr) == getattr(
|
416 | 419 | dev, pr
|
@@ -468,9 +471,6 @@ def test_queue_capsule():
|
468 | 471 |
|
469 | 472 |
|
470 | 473 | def test_cpython_api():
|
471 |
| - import ctypes |
472 |
| - import sys |
473 |
| - |
474 | 474 | q = dpctl.SyclQueue()
|
475 | 475 | mod = sys.modules[q.__class__.__module__]
|
476 | 476 | # get capsule storign get_context_ref function ptr
|
|
0 commit comments