Merged
1 change: 0 additions & 1 deletion pyproject.toml
@@ -161,7 +161,6 @@ asyncio_mode = "auto"
norecursedirs = [".git", ".tox", "dist", "build", "*.egg", "__pycache__"]

[tool.ty.rules]
unresolved-import = "warn"
unknown-argument = "warn"
redundant-cast = "ignore"

63 changes: 63 additions & 0 deletions tests/test_environment_extra.py
@@ -14,6 +14,7 @@
import asyncio
import json
from typing import Callable
from unittest.mock import AsyncMock, MagicMock

import pytest
from datasets import Dataset
@@ -31,6 +32,7 @@
)
from verifiers.utils.message_utils import sanitize_tool_calls
from verifiers.utils.save_utils import make_dataset as build_dataset
from verifiers.utils.save_utils import state_to_output


# Local simple concrete Environment for testing
@@ -133,6 +135,67 @@ async def test_get_model_response_chat_with_tools(
assert "tools" in kwargs and kwargs["tools"] == tools


@pytest.mark.asyncio
async def test_get_model_response_tracks_usage_on_state(
mock_openai_client, make_dummy_env, make_input
):
env = make_dummy_env(mock_openai_client)
prompt: vf.Messages = [{"role": "user", "content": "Track usage"}]
state = await env.init_state(
input=make_input(prompt=prompt),
client=mock_openai_client,
model="test-model",
)

resp1 = MagicMock()
resp1.choices = [MagicMock(message=MagicMock(content="ok", tool_calls=None))]
resp1.usage = {"prompt_tokens": 11, "completion_tokens": 7}

resp2 = MagicMock()
resp2.choices = [MagicMock(message=MagicMock(content="ok2", tool_calls=None))]
resp2.usage = {"input_tokens": 3, "output_tokens": 2}

mock_openai_client.chat.completions.create = AsyncMock(side_effect=[resp1, resp2])

await env.get_model_response(state=state, prompt=prompt)
await env.get_model_response(state=state, prompt=prompt)

usage = env.get_state_usage(state)
assert usage == {"input_tokens": 14.0, "output_tokens": 9.0}
assert state["usage"] == {"input_tokens": 14.0, "output_tokens": 9.0}
assert "usage_tracker" in state
with pytest.raises(TypeError):
state["usage"]["input_tokens"] = 999 # read-only view


@pytest.mark.asyncio
async def test_state_to_output_uses_state_usage_not_trajectory(
mock_openai_client, make_dummy_env, make_input
):
env = make_dummy_env(mock_openai_client)
prompt: vf.Messages = [{"role": "user", "content": "Track usage independently"}]
state = await env.init_state(
input=make_input(prompt=prompt),
client=mock_openai_client,
model="test-model",
)

resp = MagicMock()
resp.choices = [MagicMock(message=MagicMock(content="ok", tool_calls=None))]
resp.usage = {"prompt_tokens": 5, "completion_tokens": 4}
mock_openai_client.chat.completions.create = AsyncMock(return_value=resp)

await env.get_model_response(state=state, prompt=prompt)
# Simulate user clobbering visible usage and omitting response from trajectory.
state["usage"] = {"input_tokens": 0.0, "output_tokens": 0.0}
state["trajectory"] = []
state["metrics"] = {}
state["reward"] = 0.0

output = state_to_output(state, state_columns=[])
assert output["token_usage"] == {"input_tokens": 5.0, "output_tokens": 4.0}


@pytest.mark.asyncio
async def test_get_model_response_completion_rejects_tools(
mock_openai_client, make_dummy_env, make_input
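test_get_model_response_tracks_usage_on_state and test_state_to_output_uses_state_usage_not_trajectory above exercise StateUsageTracker from verifiers.utils.usage_utils, whose implementation is outside this diff. A minimal sketch consistent with the asserted behavior (a live read-only usage view that rejects writes with TypeError, acceptance of both prompt_tokens/completion_tokens and input_tokens/output_tokens key styles, float totals) might look like this; everything beyond the public names visible in the diff is an assumption.

from types import MappingProxyType


class StateUsageTracker:
    """Sketch of a per-state token-usage accumulator; internals are assumed."""

    def __init__(self) -> None:
        # Mutated in place so the read-only proxy below always reflects totals.
        self._usage = {"input_tokens": 0.0, "output_tokens": 0.0}
        # MappingProxyType raises TypeError on item assignment, matching the
        # read-only-view assertion in the test above.
        self.usage = MappingProxyType(self._usage)

    def increment(self, input_tokens: float = 0, output_tokens: float = 0) -> None:
        self._usage["input_tokens"] += float(input_tokens)
        self._usage["output_tokens"] += float(output_tokens)

    def increment_from_response(self, response: object) -> None:
        usage = getattr(response, "usage", None)
        if not isinstance(usage, dict):
            return
        # Normalize OpenAI-style prompt/completion keys onto input/output.
        raw_in = usage.get("input_tokens", usage.get("prompt_tokens", 0))
        raw_out = usage.get("output_tokens", usage.get("completion_tokens", 0))
        try:
            self.increment(float(raw_in), float(raw_out))
        except (TypeError, ValueError):
            pass  # malformed payloads are ignored rather than raised

    def snapshot(self) -> dict[str, float]:
        return dict(self._usage)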
20 changes: 20 additions & 0 deletions tests/test_eval_utils.py
@@ -116,3 +116,23 @@ def test_print_results_three_rollouts(capsys, make_metadata, make_state, make_input
assert "r2: [0.2, 0.5]" in captured.out
# r3 should have [0.3, 0.6] (third rollout of each example)
assert "r3: [0.3, 0.6]" in captured.out


def test_print_results_includes_usage(capsys, make_metadata, make_output):
from verifiers.utils.eval_utils import print_results

outputs = [
make_output(example_id=0, reward=1.0, metrics={"test_metric": 1.0}),
make_output(example_id=1, reward=0.0, metrics={"test_metric": 2.0}),
]
outputs[0]["token_usage"] = {"input_tokens": 10.0, "output_tokens": 4.0}
outputs[1]["token_usage"] = {"input_tokens": 6.0, "output_tokens": 2.0}
metadata = make_metadata(num_examples=2, rollouts_per_example=1, usage=None)

results = GenerateOutputs(outputs=outputs, metadata=metadata)
print_results(results)
captured = capsys.readouterr()

assert "Usage:" in captured.out
assert "input_tokens (avg): 8.000" in captured.out
assert "output_tokens (avg): 3.000" in captured.out
22 changes: 22 additions & 0 deletions tests/test_save_utils.py
@@ -25,6 +25,7 @@
states_to_outputs,
validate_resume_metadata,
)
from verifiers.utils.usage_utils import StateUsageTracker


# Test models for make_serializable tests
@@ -187,6 +188,27 @@ def test_extract_usage_tokens_input_output(self):
assert input_tokens == 8
assert output_tokens == 3

def test_extract_usage_tokens_invalid_values(self):
response = type(
"Response",
(),
{"usage": {"prompt_tokens": "bad", "completion_tokens": object()}},
)()
input_tokens, output_tokens = extract_usage_tokens(response)
assert input_tokens == 0
assert output_tokens == 0

def test_state_with_tracker_and_no_usage_does_not_emit_token_usage(
self, make_state
):
state = make_state()
tracker = StateUsageTracker()
state["usage_tracker"] = tracker
state["usage"] = tracker.usage
state["trajectory"] = []
output = states_to_outputs([state], state_columns=[])[0]
assert "token_usage" not in output

def test_states_to_outputs(self, make_state):
states = [
make_state(
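extract_usage_tokens is imported from verifiers.utils.save_utils but its body is collapsed in this diff. Based on the two tests above (input_tokens/output_tokens read directly, prompt_tokens/completion_tokens accepted as aliases, non-numeric values coerced to 0), a plausible sketch follows; the exact key precedence is an assumption.

from collections.abc import Mapping


def extract_usage_tokens(response: object) -> tuple[int, int]:
    # Sketch: pull (input_tokens, output_tokens) from a provider response.
    usage = getattr(response, "usage", None)
    if not isinstance(usage, Mapping):
        return 0, 0

    def to_int(value: object) -> int:
        # "bad" raises ValueError and object() raises TypeError; both become 0,
        # matching test_extract_usage_tokens_invalid_values.
        try:
            return int(value)  # type: ignore[call-overload]
        except (TypeError, ValueError):
            return 0

    input_tokens = to_int(usage.get("input_tokens", usage.get("prompt_tokens", 0)))
    output_tokens = to_int(usage.get("output_tokens", usage.get("completion_tokens", 0)))
    return input_tokens, output_tokens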
24 changes: 12 additions & 12 deletions verifiers/__init__.py
@@ -144,6 +144,8 @@ def __getattr__(name: str):


if TYPE_CHECKING:
from typing import Any

from .envs.experimental.cli_agent_env import CliAgentEnv # noqa: F401
from .envs.experimental.gym_env import GymEnv # noqa: F401
from .envs.experimental.harbor_env import HarborEnv # noqa: F401
@@ -155,15 +157,13 @@
from .envs.python_env import PythonEnv # noqa: F401
from .envs.sandbox_env import SandboxEnv # noqa: F401
from .rubrics.math_rubric import MathRubric # noqa: F401
from verifiers_rl.rl.trainer import ( # noqa: F401
GRPOConfig,
GRPOTrainer,
RLConfig,
RLTrainer,
grpo_defaults,
lora_defaults,
)
from verifiers_rl.rl.trainer.utils import ( # noqa: F401
get_model,
get_model_and_tokenizer,
)

# Optional verifiers-rl exports. Keep type-checking clean when extra is absent.
RLConfig: Any
RLTrainer: Any
GRPOTrainer: Any
GRPOConfig: Any
grpo_defaults: Any
lora_defaults: Any
get_model: Any
get_model_and_tokenizer: Any
58 changes: 58 additions & 0 deletions verifiers/envs/environment.py
@@ -10,6 +10,7 @@
import time
import uuid
from abc import ABC, abstractmethod
from collections.abc import Mapping
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
from copy import deepcopy
@@ -67,6 +68,7 @@
SamplingArgs,
StartCallback,
State,
TokenUsage,
)
from verifiers.utils.async_utils import (
maybe_retry,
@@ -92,6 +94,7 @@
get_prompt_ids,
prepare_sampling_args_for_token_prompts,
)
from verifiers.utils.usage_utils import StateUsageTracker
from verifiers.workers.client.env_client import EnvClient

if TYPE_CHECKING:
@@ -426,6 +429,58 @@ def get_eval_dataset(self, n: int = -1, seed: int | None = None) -> Dataset:
return self.eval_dataset.select(range(n))
return self.eval_dataset

@final
def _get_usage_tracker(
self, state: State, create_if_missing: bool = True
) -> StateUsageTracker | None:
tracker = state.get("usage_tracker")
if isinstance(tracker, StateUsageTracker):
return tracker
if not create_if_missing:
return None
tracker = StateUsageTracker()
state["usage_tracker"] = tracker
# Expose read-only usage in state for live inspection.
state["usage"] = tracker.usage
return tracker

@final
def increment_state_usage(
self,
state: State,
input_tokens: int | float = 0,
output_tokens: int | float = 0,
) -> None:
tracker = self._get_usage_tracker(state, create_if_missing=True)
assert tracker is not None
tracker.increment(input_tokens, output_tokens)

@final
def increment_state_usage_from_response(
self, state: State, response: object
) -> None:
tracker = self._get_usage_tracker(state, create_if_missing=True)
assert tracker is not None
tracker.increment_from_response(response)

@final
def get_state_usage(self, state: State) -> TokenUsage | None:
tracker = self._get_usage_tracker(state, create_if_missing=False)
if tracker is not None:
return tracker.snapshot()
usage = state.get("usage")
if isinstance(usage, Mapping):
try:
input_tokens = float(usage.get("input_tokens", 0.0))
output_tokens = float(usage.get("output_tokens", 0.0))
except (TypeError, ValueError):
return None
return {
"input_tokens": input_tokens,
"output_tokens": output_tokens,
}
return None

async def get_model_response(
self,
state: State,
@@ -623,6 +678,7 @@ async def get_model_response_with_tokens(
client, model, oai_tools, sampling_args, message_type = resolve_optional_args(
client, model, oai_tools, sampling_args, message_type
)
self._get_usage_tracker(state, create_if_missing=True)
sampling_args = normalize_sampling_args(sampling_args)
if self.interleaved_rollouts:
sampling_args = prepare_sampling_args_for_token_prompts(sampling_args)
@@ -651,6 +707,7 @@ async def get_model_response_with_tokens(
# Some providers (e.g. OpenRouter) may return None for response or response.choices
if response is None:
raise vf.EmptyModelResponseError("Model returned no response")
self.increment_state_usage_from_response(state, response)
if response.choices is None:
raise vf.EmptyModelResponseError("Model returned no response choices")
if not len(response.choices) == 1:
@@ -708,6 +765,7 @@ async def init_state(
else:
state["oai_tools"] = []
state["trajectory"] = []
self._get_usage_tracker(state, create_if_missing=True)
state["trajectory_id"] = uuid.uuid4().hex
state["reward"] = None
state["metrics"] = None
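get_model_response already feeds every provider response into the tracker via increment_state_usage_from_response, so most environments get usage accounting for free. For out-of-band calls a subclass can increment manually; in this illustrative sketch the backend client and its return shape are made up, while the env methods are the ones added above.

# Hypothetical example; `backend` and its generate() return shape are invented.
# In practice get_model_response() records usage automatically.
async def custom_step(env, state, backend):
    reply, n_input, n_output = await backend.generate(state["prompt"])
    env.increment_state_usage(state, input_tokens=n_input, output_tokens=n_output)
    usage = env.get_state_usage(state)  # e.g. {"input_tokens": 12.0, "output_tokens": 5.0}
    return reply, usage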
2 changes: 2 additions & 0 deletions verifiers/types.py
@@ -164,6 +164,8 @@ class State(dict):
metrics: dict[str, float] | None
timing: RolloutTiming | None
error: Error | None
usage: TokenUsage | None
usage_tracker: object
Review comment: Missing documentation for new State usage fields (Low Severity)

This PR adds new usage and usage_tracker fields to the State class, which is documented in docs/reference.md. The documentation lists State fields in tables under "Fields set during initialization" and "Fields set after scoring," but these new fields are not included. Per the review rules, PRs that modify core user-facing functionality described in docs must update the relevant documentation.

def __getitem__(self, key: str) -> Any:
# forward to input if exists
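TokenUsage itself is referenced above but not defined in this hunk; judging by how get_state_usage and print_usage consume it, it is plausibly a small TypedDict along these lines (a sketch, not the shipped definition).

from typing import TypedDict


class TokenUsage(TypedDict):
    input_tokens: float
    output_tokens: float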
31 changes: 31 additions & 0 deletions verifiers/utils/eval_utils.py
@@ -429,6 +429,35 @@ def print_timing(results: GenerateOutputs):
)


def print_usage(results: GenerateOutputs):
usage_count = 0
input_tokens_total = 0.0
output_tokens_total = 0.0
for output in results["outputs"]:
token_usage = output.get("token_usage")
if not isinstance(token_usage, Mapping):
continue
usage_count += 1
input_tokens_total += float(token_usage.get("input_tokens", 0.0))
output_tokens_total += float(token_usage.get("output_tokens", 0.0))

usage = None
if usage_count > 0:
usage = {
"input_tokens": input_tokens_total / usage_count,
"output_tokens": output_tokens_total / usage_count,
}
elif results["metadata"].get("usage") is not None:
usage = results["metadata"]["usage"]

if usage is None:
return

print("Usage:")
print(f"input_tokens (avg): {usage['input_tokens']:.3f}")
print(f"output_tokens (avg): {usage['output_tokens']:.3f}")


def print_results(results: GenerateOutputs, num_samples: int = 1):
assert results["metadata"] is not None
print("--- Evaluation ---")
@@ -458,6 +487,7 @@ def print_results(results: GenerateOutputs, num_samples: int = 1):
print_rewards(results)
print_info(results)
print_timing(results)
print_usage(results)

tasks = set([o["task"] for o in results["outputs"]])
if len(tasks) > 1:
@@ -467,6 +497,7 @@
print_rewards(task_results)
print_info(task_results)
print_timing(task_results)
print_usage(task_results)


@contextmanager
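print_usage averages per-output token_usage over the outputs that carry it, falling back to the metadata usage only when none do. With the two outputs from test_print_results_includes_usage (10/4 and 6/2 tokens), the printed block is:

Usage:
input_tokens (avg): 8.000
output_tokens (avg): 3.000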