Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 4 additions & 53 deletions verifiers/utils/eval_display.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,10 @@
- TUI mode (screen=True): Alternate screen buffer with echo handling
"""

import json
import time
from collections.abc import Mapping
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Literal
from typing import Literal

from rich.columns import Columns
from rich.console import Group
Expand All @@ -22,6 +20,7 @@

from verifiers.types import EvalConfig, GenerateOutputs
from verifiers.utils.display_utils import BaseDisplay, make_aligned_row
from verifiers.utils.message_utils import format_messages


@dataclass
Expand Down Expand Up @@ -59,54 +58,6 @@ def elapsed_time(self) -> float:
return end - self.start_time


def _format_messages(messages: Any) -> Text:
    """Format messages for display (similar to print_prompt_completions_sample).

    Args:
        messages: Either a plain string, which is wrapped unchanged, or an
            iterable of chat-message dicts with ``role``, ``content`` and an
            optional ``tool_calls`` list.

    Returns:
        A ``Text`` holding one ``role: content`` entry per message, separated
        by blank lines, with every tool call appended as an indented JSON
        ``[tool call]`` block in the same style as its message.

    Raises:
        AssertionError: If an element of ``messages`` is not a ``dict``.
    """

    def _attr_or_key(obj: Any, key: str, default: Any = None) -> Any:
        # A non-None attribute wins; otherwise fall back to a mapping lookup.
        val = getattr(obj, key, None)
        if val is not None:
            return val
        if isinstance(obj, Mapping):
            return obj.get(key, default)
        return default

    def _normalize_tool_call(tc: Any) -> dict[str, str]:
        # A tool call serialized as a JSON string has neither attributes nor
        # keys to look up, so decode it first; otherwise it would render as
        # an empty name with "{}" arguments.
        if isinstance(tc, str):
            try:
                tc = json.loads(tc)
            except ValueError:
                # Not JSON (JSONDecodeError subclasses ValueError): show the
                # raw text instead of dropping it or aborting the display.
                return {"name": "", "args": tc}

        # Prefer the nested "function" object when present.
        src = _attr_or_key(tc, "function") or tc
        name = _attr_or_key(src, "name", "") or ""
        args = _attr_or_key(src, "arguments", {}) or {}
        if not isinstance(args, str):
            try:
                args = json.dumps(args)
            except Exception:
                # Best effort: a display helper should still show something
                # for arguments that cannot be serialized.
                args = str(args)
        return {"name": name, "args": args}

    if isinstance(messages, str):
        return Text(messages)

    out = Text()
    for idx, msg in enumerate(messages):
        if idx:
            out.append("\n\n")

        assert isinstance(msg, dict)
        role = msg.get("role", "")
        content = msg.get("content", "")
        style = "bright_cyan" if role == "assistant" else "bright_magenta"

        out.append(f"{role}: ", style="bold")
        # Falsy content (None, "") is rendered as an empty string.
        out.append(str(content) if content else "", style=style)

        for tc in msg.get("tool_calls") or []:  # treat None as an empty list
            payload = _normalize_tool_call(tc)
            out.append(
                "\n\n[tool call]\n" + json.dumps(payload, indent=2, ensure_ascii=False),
                style=style,
            )

    return out


def _make_histogram(values: list[float], bins: int = 10, width: int = 20) -> Text:
"""Create a simple text histogram of values."""
if not values:
Expand Down Expand Up @@ -577,14 +528,14 @@ def _make_env_detail(
# Prompt panel
items.append(
Panel(
_format_messages(prompt),
format_messages(prompt),
title="[dim]example 0 — prompt[/dim]",
border_style="dim",
)
)

# Completion panel (with error if any)
completion_text = _format_messages(completion)
completion_text = format_messages(completion)
if error_0 is not None:
completion_text.append("\n\nerror: ", style="bold red")
completion_text.append(error_0, style="bold red")
Expand Down
58 changes: 3 additions & 55 deletions verifiers/utils/logging_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
import json
import logging
import sys
from collections.abc import Mapping
from contextlib import contextmanager

from rich.console import Console
Expand All @@ -12,6 +10,7 @@
from verifiers.errors import Error
from verifiers.types import Messages
from verifiers.utils.error_utils import ErrorChain
from verifiers.utils.message_utils import format_messages

LOGGER_NAME = "verifiers"

Expand Down Expand Up @@ -81,57 +80,6 @@ def print_prompt_completions_sample(
step: int,
num_samples: int = 1,
) -> None:
def _attr_or_key(obj, key: str, default=None):
    """Look up ``key`` on ``obj`` as an attribute first, then as a mapping key.

    A non-None attribute value is returned as-is. Otherwise a Mapping is
    queried with ``default`` as its fallback, and any other object yields
    ``default``.
    """
    attribute = getattr(obj, key, None)
    if attribute is None:
        return obj.get(key, default) if isinstance(obj, Mapping) else default
    return attribute

def _normalize_tool_call(tc):
    """Return {"name": ..., "args": ...} from a dict, object, or JSON string.

    The name and arguments are read from the nested ``function`` entry when
    one is present, otherwise from ``tc`` itself. Arguments that are not
    already a string are serialized with ``json.dumps``, falling back to
    ``str`` when they cannot be serialized.
    """
    # A tool call serialized as a JSON string has neither attributes nor keys
    # to look up, so decode it first; otherwise it would render as an empty
    # name with "{}" arguments.
    if isinstance(tc, str):
        try:
            tc = json.loads(tc)
        except ValueError:
            # Not JSON (JSONDecodeError subclasses ValueError): keep the raw
            # text visible instead of dropping it.
            return {"name": "", "args": tc}

    src = _attr_or_key(tc, "function") or tc  # prefer nested function object if present
    name = _attr_or_key(src, "name", "") or ""
    args = _attr_or_key(src, "arguments", {}) or {}

    if not isinstance(args, str):
        try:
            args = json.dumps(args)
        except Exception:
            # Best effort: still show something for unserializable arguments.
            args = str(args)
    return {"name": name, "args": args}

def _format_messages(messages) -> Text:
    """Render a prompt or completion as styled ``Text``.

    A plain string is wrapped unchanged. Otherwise ``messages`` is iterated
    as chat-message dicts: each contributes a bold ``role: `` prefix followed
    by its content, messages are separated by a blank line, and every entry
    in ``tool_calls`` is appended as an indented JSON ``[tool call]`` block.

    Raises:
        AssertionError: If an element of ``messages`` is not a ``dict``.
    """
    if isinstance(messages, str):
        return Text(messages)

    out = Text()
    for idx, msg in enumerate(messages):
        if idx:
            out.append("\n\n")

        assert isinstance(msg, dict)
        role = msg.get("role", "")
        content = msg.get("content", "")
        style = "bright_cyan" if role == "assistant" else "bright_magenta"

        out.append(f"{role}: ", style="bold")
        # Text.append only accepts str or Text, so a message whose content is
        # None (or any other non-string value) must be coerced before it is
        # appended; falsy content renders as an empty string.
        out.append(str(content) if content else "", style=style)

        for tc in msg.get("tool_calls") or []:  # treat None as empty list
            payload = _normalize_tool_call(tc)
            out.append(
                "\n\n[tool call]\n"
                + json.dumps(payload, indent=2, ensure_ascii=False),
                style=style,
            )

    return out

def _format_error(error: str | BaseException) -> Text:
out = Text()
if isinstance(error, str):
Expand All @@ -158,8 +106,8 @@ def _format_error(error: str | BaseException) -> Text:
error = errors[i]
reward = reward_values[i]

formatted_prompt = _format_messages(prompt)
formatted_completion = _format_messages(completion)
formatted_prompt = format_messages(prompt)
formatted_completion = format_messages(completion)
if error is not None:
formatted_completion += Text("\n\n") + _format_error(error)

Expand Down
53 changes: 52 additions & 1 deletion verifiers/utils/message_utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import json
from typing import cast
from collections.abc import Mapping
from typing import Any, cast

from openai.types.chat import (
ChatCompletionAssistantMessageParam,
)
from rich.text import Text

from verifiers.types import ChatMessage, Messages

Expand Down Expand Up @@ -85,6 +87,55 @@ def messages_to_printable(messages: Messages) -> Messages:
return [message_to_printable(m) for m in messages or []]


def format_messages(messages: Any) -> Text:
    """Format a prompt or completion as styled rich ``Text`` for display.

    Args:
        messages: Either a plain string, which is wrapped unchanged, or an
            iterable of chat-message dicts with ``role``, ``content`` and an
            optional ``tool_calls`` list. Each tool call may be a mapping, an
            object exposing attributes, or a JSON-encoded string.

    Returns:
        A ``Text`` holding one ``role: content`` entry per message, separated
        by blank lines, with every tool call appended as an indented JSON
        ``[tool call]`` block in the same style as its message.

    Raises:
        AssertionError: If an element of ``messages`` is not a ``dict``.
    """

    def _attr_or_key(obj: Any, key: str, default: Any = None) -> Any:
        # A non-None attribute wins; otherwise fall back to a mapping lookup.
        val = getattr(obj, key, None)
        if val is not None:
            return val
        if isinstance(obj, Mapping):
            return obj.get(key, default)
        return default

    def _normalize_tool_call(tc: Any) -> dict[str, str]:
        if isinstance(tc, str):
            try:
                tc = json.loads(tc)
            except ValueError:
                # A string that is not valid JSON (JSONDecodeError subclasses
                # ValueError) must not abort rendering of the whole
                # conversation; show the raw text as the arguments instead.
                return {"name": "", "args": tc}

        # Prefer the nested "function" object when present.
        src = _attr_or_key(tc, "function") or tc
        name = _attr_or_key(src, "name", "") or ""
        args = _attr_or_key(src, "arguments", {}) or {}
        if not isinstance(args, str):
            try:
                args = json.dumps(args)
            except Exception:
                # Best effort: still show something for arguments that cannot
                # be serialized.
                args = str(args)
        return {"name": name, "args": args}

    if isinstance(messages, str):
        return Text(messages)

    out = Text()
    for idx, msg in enumerate(messages):
        if idx:
            out.append("\n\n")

        assert isinstance(msg, dict)
        role = msg.get("role", "")
        content = msg.get("content", "")
        style = "bright_cyan" if role == "assistant" else "bright_magenta"

        out.append(f"{role}: ", style="bold")
        # Falsy content (None, "") is rendered as an empty string.
        out.append(str(content) if content else "", style=style)

        for tc in msg.get("tool_calls") or []:  # treat None as an empty list
            payload = _normalize_tool_call(tc)
            out.append(
                "\n\n[tool call]\n" + json.dumps(payload, indent=2, ensure_ascii=False),
                style=style,
            )

    return out


def sanitize_tool_calls(messages: Messages):
"""
Sanitize tool calls from messages.
Expand Down
Loading