Skip to content

Add support for nested py_capture_output() calls. #1564

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Mar 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# reticulate (development version)

- Fixed an issue where nested `py_capture_output()` calls lost the reference
to the original `sys.stdout` and `sys.stderr`, resulting in no further visible output
from Python, and possibly a segfault. (#1564)

- Fixed an issue where printing a delayed module (`import("foo", delay_load = TRUE)`)
would output `<pointer: 0x0>`.

Expand Down
19 changes: 14 additions & 5 deletions R/python.R
Original file line number Diff line number Diff line change
Expand Up @@ -1046,6 +1046,8 @@ register_class_filter <- function(filter) {
py_capture_output <- function(expr, type = c("stdout", "stderr")) {

# initialize python if necessary
# without expressing an implicit venv preference
# via an internal import() call
ensure_python_initialized()

# resolve type argument
Expand All @@ -1057,14 +1059,21 @@ py_capture_output <- function(expr, type = c("stdout", "stderr")) {
# scope output capture
capture_stdout <- "stdout" %in% type
capture_stderr <- "stderr" %in% type
output_tools$start_capture(capture_stdout, capture_stderr)
on.exit(output_tools$end_capture(capture_stdout, capture_stderr), add = TRUE)

# evaluate the expression
force(expr)
context_manager <- output_tools$CaptureOutputStreams(
capture_stdout, capture_stderr
)

context_manager$`__enter__`()
tryCatch(
force(expr),
finally = {
context_manager$`__exit__`()
}
)

# collect output
output_tools$collect_output()
context_manager$collect_output()

}

Expand Down
142 changes: 56 additions & 86 deletions inst/python/rpytools/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,14 @@
else:
from io import StringIO

_capture_stdout = StringIO()
_capture_stderr = StringIO()
_stdout = None
_stderr = None

def _setStream(handler, stream):
setStream = getattr(handler, "setStream", None)
if setStream is not None:
return setStream(stream)
old_stream = handler.stream
handler.stream = stream
return old_stream


def _override_logger_streams(
Expand All @@ -32,108 +36,74 @@ def _override_logger_streams(
stream = getattr(handler, "stream", None)
if stream is None:
continue

if capture_stdout and stream is old_stdout:
handler.stream = new_stdout

if capture_stderr and stream is old_stderr:
handler.stream = new_stderr

_setStream(handler, new_stdout)
elif capture_stderr and stream is old_stderr:
_setStream(handler, new_stderr)
# capture loggers registered with the default manager
loggers = getattr(logging.Logger.manager, "loggerDict", {})
for logger in loggers.values():
handlers = getattr(logger, "handlers", [])
for handler in handlers:

stream = getattr(handler, "stream", None)
if stream is None:
continue
if capture_stdout and stream is old_stdout:
_setStream(handler, new_stdout)
elif capture_stderr and stream is old_stderr:
_setStream(handler, new_stderr)

if capture_stdout and handler.stream is old_stdout:
handler.stream = new_stdout

if capture_stderr and handler.stream is old_stderr:
handler.stream = new_stderr


def start_capture(capture_stdout, capture_stderr):

global _stdout
global _stderr
class CaptureOutputStreams:
def __init__(self, capture_stdout, capture_stderr):
self._capture_stdout = capture_stdout
self._capture_stderr = capture_stderr

if capture_stdout:
_stdout = sys.stdout
sys.stdout = _capture_stdout
def __enter__(self):
self._capturing_stream = StringIO()
if self._capture_stdout:
self._prev_stdout = sys.stdout
sys.stdout = self._capturing_stream

if capture_stderr:
_stderr = sys.stderr
sys.stderr = _capture_stderr
if self._capture_stderr:
self._prev_stderr = sys.stderr
sys.stderr = self._capturing_stream

try:
_override_logger_streams(
capture_stdout,
sys.__stdout__,
_capture_stdout,
capture_stderr,
sys.__stderr__,
_capture_stderr,
capture_stdout=self._capture_stdout,
new_stdout=sys.stdout if self._capture_stdout else None,
old_stdout=self._prev_stdout if self._capture_stdout else None,
capture_stderr=self._capture_stderr,
new_stderr=sys.stderr if self._capture_stderr else None,
old_stderr=self._prev_stderr if self._capture_stderr else None,
)
except:
pass

self._active = True

def end_capture(capture_stdout, capture_stderr):

global _stdout
global _stderr

if capture_stdout:
_capture_stdout.seek(0)
_capture_stdout.truncate()
sys.stdout = _stdout
_stdout = None

if capture_stderr:
_capture_stderr.seek(0)
_capture_stderr.truncate()
sys.stderr = _stderr
_stderr = None
def __exit__(self, *args):
self._capturing_stream.flush()
if self._capture_stdout:
sys.stdout = self._prev_stdout
if self._capture_stderr:
sys.stderr = self._prev_stderr

try:
_override_logger_streams(
capture_stdout,
_capture_stdout,
sys.__stdout__,
capture_stderr,
_capture_stderr,
sys.__stderr__,
capture_stdout=self._capture_stdout,
new_stdout=sys.stdout,
old_stdout=self._prev_stdout if self._capture_stdout else None,
capture_stderr=self._capture_stderr,
new_stderr=sys.stderr,
old_stderr=self._prev_stderr if self._capture_stderr else None,
)
except:
pass


def collect_output():

global _stdout
global _stderr

# collect outputs into array
outputs = []
if _stdout is not None:
stdout = _capture_stdout.getvalue()
if stdout:
outputs.append(stdout)

if _stderr is not None:
stderr = _capture_stderr.getvalue()
if stderr:
outputs.append(stderr)

# ensure trailing newline
outputs.append("")

# join outputs
return "\n".join(outputs)
self._active = False

def collect_output(self):
if self._active:
raise Exception(
"Must exit capturing context before collecting output"
)
output = self._capturing_stream.getvalue()
self._capturing_stream.close()
return output


class OutputRemap(object):
Expand Down
78 changes: 68 additions & 10 deletions tests/testthat/test-python-output.R
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
context("output")

capture_test_output <- function(type) {
sys <- import("sys")
py_capture_output(type = type, {
if ("stdout" %in% type)
sys$stdout$write("out");
sys$stdout$write("out\n");
if ("stderr" %in% type)
sys$stderr$write("err");
sys$stderr$write("err\n");
})
}

Expand All @@ -25,25 +26,82 @@ test_that("Python stderr stream can be captured", {
})

test_that("Python loggers work with py_capture_output", {

skip_if(py_version() < "3.2")
skip_on_os("windows")

output <- py_capture_output({
logging <- import("logging")
l <- logging$getLogger("test.logger")
l$addHandler(logging$StreamHandler())
l$setLevel("INFO")
l$info("info")
})
expect_equal(output, "info\n\n")

expect_equal(output, "info\n")

l <- logging$getLogger("test.logger2")
l$addHandler(logging$StreamHandler())
l$setLevel("INFO")
output <- py_capture_output(l$info("info"))

expect_equal(output, "info\n\n")


expect_equal(output, "info\n")

})


test_that("nested py_capture_output() calls work", {

  # Record the identity of every Python stream object before any capture
  # starts; the same ids must be observed once all captures have unwound.
  sys <- import("sys")
  ids_before <- list(
    stdout = py_id(sys$stdout),
    stderr = py_id(sys$stderr),
    dunder_stdout = py_id(sys$`__stdout__`),
    dunder_stderr = py_id(sys$`__stderr__`)
  )

  # Level 1 (outermost): captures both stdout and stderr.
  outer_output <- py_capture_output({

    py_run_string("print('Start outer')")

    # Level 2 (middle): captures stdout only, so anything written to
    # stderr at this level is expected to land in the outer capture.
    middle_output <- py_capture_output(type = "stdout", {

      py_run_string("print('Start middle')")

      # Level 3 (innermost): captures both stdout and stderr again.
      inner_output <- py_capture_output({
        py_run_string("print('Start inner')")
        py_run_string("import sys; print('Innermost error', file=sys.stderr)")
        py_run_string("print('End inner')")
      })

      # stderr written by the middle level; should show up in level 1.
      py_run_string("import sys; print('Middle level error', file=sys.stderr)")
      py_run_string("print('End middle')")

    })

    py_run_string("print('End outer')")
  })

  # Outer capture holds its own stdout plus the stderr that the
  # stdout-only middle level let through.
  expect_equal(outer_output, "Start outer\nMiddle level error\nEnd outer\n")

  # Middle capture holds stdout only; its stderr line is absent.
  expect_equal(middle_output, "Start middle\nEnd middle\n")

  # Inner capture holds both of its streams, in write order.
  expect_equal(inner_output, "Start inner\nInnermost error\nEnd inner\n")

  # After all captures have exited, every stream must be the very same
  # Python object that was installed before the test began.
  sys <- import("sys")
  expect_identical(ids_before[["stdout"]], py_id(sys$stdout))
  expect_identical(ids_before[["stderr"]], py_id(sys$stderr))
  expect_identical(ids_before[["dunder_stdout"]], py_id(sys$`__stdout__`))
  expect_identical(ids_before[["dunder_stderr"]], py_id(sys$`__stderr__`))

})