Skip to content

Commit 49f2378

Browse files
Used OutputGrabber to try to capture std::cout outputs of DPCTLSyclInterface
The stream `std::cout` that the library writes to in a function like `DPCTLDeviceMgr_PrintDeviceInfo` is not the same as Python's sys.stdout It makes it difficult to verify that printed output is what should be expected. The changes in this commit allow the test to pass, while ensuring that the captured output is non-empty, but pytest must be envoked with --capture=no option
1 parent d469c84 commit 49f2378

File tree

2 files changed

+108
-7
lines changed

2 files changed

+108
-7
lines changed

dpctl/tests/_redirector.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
# Data Parallel Control (dpctl)
2+
#
3+
# Copyright 2020-2021 Intel Corporation
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
""" Defines class to capture std::cout/std::cerr outputs of loaded
18+
library (DPCTLSyclInterface) from Python.
19+
"""
20+
21+
import os
22+
import sys
23+
24+
# Source: https://stackoverflow.com/a/29834357/594376
25+
# License: https://creativecommons.org/licenses/by-sa/4.0/
26+
27+
28+
class OutputGrabber(object):
29+
"""
30+
Class used to grab standard output.
31+
"""
32+
33+
escape_char = "\b"
34+
35+
def __init__(self, stream=None):
36+
self.origstream = stream
37+
if self.origstream is None:
38+
self.origstream = sys.stdout
39+
self.origstreamfd = self.origstream.fileno()
40+
self.capturedtext = ""
41+
# Create a pipe so the stream can be captured:
42+
self.pipe_out, self.pipe_in = os.pipe()
43+
44+
def __enter__(self):
45+
self.start()
46+
return self
47+
48+
def __exit__(self, type, value, traceback):
49+
self.stop()
50+
51+
def start(self):
52+
"""
53+
Start capturing the stream data.
54+
"""
55+
self.capturedtext = ""
56+
# Save a copy of the stream:
57+
self.streamfd = os.dup(self.origstreamfd)
58+
# Replace the original stream with our write pipe:
59+
os.dup2(self.pipe_in, self.origstreamfd)
60+
61+
def stop(self):
62+
"""
63+
Stop capturing the stream data and save the text in `capturedtext`.
64+
"""
65+
# Print the escape character to make the readOutput method stop:
66+
self.origstream.write(self.escape_char)
67+
# Flush the stream to make sure all our data goes in before
68+
# the escape character:
69+
self.origstream.flush()
70+
self.readOutput()
71+
# Close the pipe:
72+
os.close(self.pipe_in)
73+
os.close(self.pipe_out)
74+
# Restore the original stream:
75+
os.dup2(self.streamfd, self.origstreamfd)
76+
# Close the duplicate stream:
77+
os.close(self.streamfd)
78+
79+
def readOutput(self):
80+
"""
81+
Read the stream data (one byte at a time)
82+
and save the text in `capturedtext`.
83+
"""
84+
while True:
85+
char = os.read(self.pipe_out, 1).decode(self.origstream.encoding)
86+
if not char or self.escape_char in char:
87+
print(char)
88+
break
89+
self.capturedtext += char

dpctl/tests/test_sycl_queue.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,15 @@
1717
""" Defines unit test cases for the SyclQueue class.
1818
"""
1919

20+
import ctypes
21+
import ctypes.util
22+
2023
import pytest
2124

2225
import dpctl
2326

2427
from ._helper import create_invalid_capsule
28+
from ._redirector import OutputGrabber
2529

2630
list_of_standard_selectors = [
2731
dpctl.select_accelerator_device,
@@ -401,16 +405,24 @@ def test_channeling_device_properties():
401405
dev = q.sycl_device
402406
except dpctl.SyclQueueCreationError:
403407
pytest.fail("Failed to create device from default selector")
404-
import io
405-
from contextlib import redirect_stdout
406408

407-
f1 = io.StringIO()
408-
with redirect_stdout(f1):
409+
libc = ctypes.cdll.LoadLibrary(ctypes.util.find_library("c"))
410+
411+
libc.puts(b"fff")
412+
out = OutputGrabber()
413+
with out:
414+
libc.puts(b"fff")
415+
416+
q_fh = OutputGrabber()
417+
with q_fh:
409418
q.print_device_info() # should execute without raising
410-
f2 = io.StringIO()
411-
with redirect_stdout(f2):
419+
d_fh = OutputGrabber()
420+
with d_fh:
412421
dev.print_device_info()
413-
assert f1.getvalue() == f2.getvalue(), "Mismatch in print_device_info"
422+
q_output = q_fh.capturedtext
423+
d_output = d_fh.capturedtext
424+
assert q_output, "No output captured"
425+
assert q_output == d_output, "Mismatch in print_device_info"
414426
for pr in ["backend", "name", "driver_version"]:
415427
assert getattr(q, pr) == getattr(
416428
dev, pr

0 commit comments

Comments
 (0)