Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 122 additions & 1 deletion verifiers/utils/display_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,11 @@
"""

import asyncio
import io
import logging
import os
import sys
import threading
from collections import deque
from typing import Any

Expand Down Expand Up @@ -61,12 +63,54 @@ def __init__(self, max_lines: int = 3) -> None:

def emit(self, record: logging.LogRecord) -> None:
try:
msg = self.format(record)
if record.name.endswith(".stdout") or record.name.endswith(".stderr"):
msg = record.getMessage()
else:
msg = self.format(record)
self.logs.append(msg)
except Exception:
pass


class _FDToLogger(threading.Thread):
"""Background reader that forwards a file descriptor's output to a logger."""

def __init__(
self, fd: int, logger: logging.Logger, level: int, encoding: str | None
) -> None:
super().__init__(daemon=True)
self._fd = fd
self._logger = logger
self._level = level
self._encoding = encoding or "utf-8"
self._buffer = ""

def run(self) -> None:
try:
while True:
try:
data = os.read(self._fd, 1024)
except OSError:
break
if not data:
break
text = data.decode(self._encoding, errors="replace").replace("\r", "\n")
combined = f"{self._buffer}{text}"
lines = combined.split("\n")
self._buffer = lines.pop() if lines else ""
for line in lines:
if line:
self._logger.log(self._level, line)
finally:
if self._buffer:
self._logger.log(self._level, self._buffer)
self._buffer = ""
try:
os.close(self._fd)
except OSError:
pass


class BaseDisplay:
"""
Base class for Rich-based terminal displays.
Expand All @@ -90,6 +134,13 @@ def __init__(self, screen: bool = False, refresh_per_second: int = 4) -> None:
self._log_handler = DisplayLogHandler(max_lines=3)
self._old_handler_levels: dict[logging.Handler, int] = {}
self._old_datasets_level: int | None = None
self._old_stdout = None
self._old_stderr = None
self._old_stdout_fd: int | None = None
self._old_stderr_fd: int | None = None
self._console_file: io.TextIOWrapper | None = None
self._stdout_thread: _FDToLogger | None = None
self._stderr_thread: _FDToLogger | None = None

def _render(self) -> Any:
"""
Expand Down Expand Up @@ -131,6 +182,18 @@ def start(self) -> None:

# Suppress console output from existing handlers but capture logs for display
logger = logging.getLogger("verifiers")

# Preserve original streams for Rich rendering before capturing stdout/stderr
self._old_stdout = sys.stdout
self._old_stderr = sys.stderr
self._old_stdout_fd = os.dup(1)
self._old_stderr_fd = os.dup(2)
self._console_file = io.TextIOWrapper(
os.fdopen(self._old_stdout_fd, "wb", closefd=False),
encoding=getattr(self._old_stdout, "encoding", "utf-8"),
write_through=True,
)
self.console = Console(file=self._console_file, force_terminal=True)
for handler in logger.handlers:
self._old_handler_levels[handler] = handler.level
handler.setLevel(logging.CRITICAL)
Expand All @@ -144,6 +207,28 @@ def start(self) -> None:
self._log_handler.setLevel(logging.INFO)
logger.addHandler(self._log_handler)

# Capture stdout/stderr at the FD level so stray prints don't corrupt the live display
stdout_r, stdout_w = os.pipe()
stderr_r, stderr_w = os.pipe()
os.dup2(stdout_w, 1)
os.close(stdout_w)
os.dup2(stderr_w, 2)
os.close(stderr_w)
self._stdout_thread = _FDToLogger(
stdout_r,
logger.getChild("stdout"),
logging.INFO,
getattr(self._old_stdout, "encoding", None),
)
self._stderr_thread = _FDToLogger(
stderr_r,
logger.getChild("stderr"),
logging.ERROR,
getattr(self._old_stderr, "encoding", None),
)
self._stdout_thread.start()
self._stderr_thread.start()

# Disable terminal echo in screen mode to prevent scroll/arrow keys from displaying
if self.screen and HAS_TERMINAL_CONTROL and sys.stdin.isatty():
import termios
Expand Down Expand Up @@ -171,6 +256,20 @@ def stop(self) -> None:
self._live.stop()
self._live = None

# Restore stdout/stderr file descriptors (ends pipe, unblocks readers)
if self._old_stdout_fd is not None:
os.dup2(self._old_stdout_fd, 1)
if self._old_stderr_fd is not None:
os.dup2(self._old_stderr_fd, 2)

# Join reader threads
if self._stdout_thread is not None:
self._stdout_thread.join(timeout=0.5)
self._stdout_thread = None
if self._stderr_thread is not None:
self._stderr_thread.join(timeout=0.5)
self._stderr_thread = None

# Restore datasets progress bar
from datasets import enable_progress_bar

Expand All @@ -189,6 +288,28 @@ def stop(self) -> None:
datasets_logger.setLevel(self._old_datasets_level)
self._old_datasets_level = None

# Restore stdout/stderr
if self._old_stdout is not None:
sys.stdout = self._old_stdout # type: ignore[assignment]
self._old_stdout = None
if self._old_stderr is not None:
sys.stderr = self._old_stderr # type: ignore[assignment]
self._old_stderr = None
if self._console_file is not None:
# Redirect console back to original stdout before closing temp stream
self.console = Console(file=sys.stdout, force_terminal=sys.stdout.isatty())
try:
self._console_file.flush()
self._console_file.close()
finally:
self._console_file = None
if self._old_stdout_fd is not None:
os.close(self._old_stdout_fd)
self._old_stdout_fd = None
if self._old_stderr_fd is not None:
os.close(self._old_stderr_fd)
self._old_stderr_fd = None

# Restore terminal settings
if self._old_terminal_settings is not None:
import termios
Expand Down
Loading