Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions docs/reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,11 @@ class GenerateMetadata(TypedDict):
time_ms: float
avg_reward: float
avg_metrics: dict[str, float]
avg_error: float
pass_at_k: dict[str, float]
pass_all_k: dict[str, float]
pass_threshold: float
usage: TokenUsage | None
version_info: VersionInfo
state_columns: list[str]
path_to_save: Path
Expand Down Expand Up @@ -268,6 +273,8 @@ class Environment(ABC):
env_id: str | None = None,
env_args: dict | None = None,
max_seq_len: int | None = None,
score_rollouts: bool = True,
pass_threshold: float = 0.5,
**kwargs,
): ...
```
Expand Down
6 changes: 6 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -554,6 +554,9 @@ def _make_metadata(
time_ms: float = 0.0,
avg_reward: float = 0.0,
avg_metrics: dict[str, float] = {},
pass_at_k: dict[str, float] = {},
pass_all_k: dict[str, float] = {},
pass_threshold: float = 0.5,
usage: dict[str, float] | None = None,
version_info: dict | None = None,
state_columns: list[str] = ["foo"],
Expand All @@ -579,6 +582,9 @@ def _make_metadata(
time_ms=time_ms,
avg_reward=avg_reward,
avg_metrics=avg_metrics,
pass_at_k=pass_at_k,
pass_all_k=pass_all_k,
pass_threshold=pass_threshold,
usage=usage,
version_info=version_info,
state_columns=state_columns,
Expand Down
205 changes: 205 additions & 0 deletions tests/test_save_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from pydantic import BaseModel

from verifiers.types import ClientConfig
from verifiers.utils.metric_utils import compute_pass_at_k
from verifiers.utils.save_utils import (
GenerateOutputsBuilder,
extract_usage_tokens,
Expand Down Expand Up @@ -425,3 +426,207 @@ def test_validate_resume_metadata_raises_on_mismatch(self, tmp_path: Path):
num_examples=3,
rollouts_per_example=2,
)


class TestComputePassAtK:
    """Tests for ``compute_pass_at_k``.

    The function groups rollout outputs by ``example_id`` and, for each
    complete group of ``rollouts_per_example`` outputs, computes for every
    power-of-two k in [1, n]:

    - pass@k: the unbiased estimator ``1 - C(n - c, k) / C(n, k)``
      (probability that at least one of k sampled rollouts is correct),
    - pass^k: ``C(c, k) / C(n, k)``
      (probability that all k sampled rollouts are correct),

    where n is the group size and c the number of rollouts whose reward
    meets the correctness threshold (default 0.5). Per-example values are
    averaged across examples, and keys are the string form of k.
    """

    @staticmethod
    def _make_output(example_id: int, reward: float) -> dict:
        """Build a minimal rollout-output dict (only the keys the metric reads)."""
        return {"example_id": example_id, "reward": reward}

    def test_single_rollout_returns_empty(self):
        """rollouts_per_example=1 should return empty dicts."""
        # With n=1 there is no k > 1 to report and pass@1 would equal
        # avg_reward-as-boolean, so the metric is skipped entirely.
        outputs = [self._make_output(0, 1.0)]
        pass_at_k, pass_all_k = compute_pass_at_k(outputs, rollouts_per_example=1)
        assert pass_at_k == {}
        assert pass_all_k == {}

    def test_all_correct(self):
        """All rollouts correct → pass@k = 1.0 and pass^k = 1.0 for all k."""
        outputs = [self._make_output(0, 1.0) for _ in range(8)]
        pass_at_k, pass_all_k = compute_pass_at_k(outputs, rollouts_per_example=8)
        assert set(pass_at_k.keys()) == {"1", "2", "4", "8"}
        for k in pass_at_k:
            assert pass_at_k[k] == pytest.approx(1.0)
            assert pass_all_k[k] == pytest.approx(1.0)

    def test_none_correct(self):
        """No rollouts correct → pass@k = 0.0 and pass^k = 0.0 for all k."""
        outputs = [self._make_output(0, 0.0) for _ in range(8)]
        pass_at_k, pass_all_k = compute_pass_at_k(outputs, rollouts_per_example=8)
        assert set(pass_at_k.keys()) == {"1", "2", "4", "8"}
        for k in pass_at_k:
            assert pass_at_k[k] == pytest.approx(0.0)
            assert pass_all_k[k] == pytest.approx(0.0)

    def test_partial_correctness(self):
        """Partial correctness: 2 correct out of 4 rollouts."""
        outputs = [
            self._make_output(0, 1.0),
            self._make_output(0, 1.0),
            self._make_output(0, 0.0),
            self._make_output(0, 0.0),
        ]
        pass_at_k, pass_all_k = compute_pass_at_k(outputs, rollouts_per_example=4)
        # k values: 1, 2, 4
        assert set(pass_at_k.keys()) == {"1", "2", "4"}
        # n=4, c=2: pass@1 = 1 - C(2,1)/C(4,1) = 1 - 2/4 = 0.5
        assert pass_at_k["1"] == pytest.approx(0.5)
        # n=4, c=2: pass@2 = 1 - C(2,2)/C(4,2) = 1 - 1/6
        assert pass_at_k["2"] == pytest.approx(1.0 - 1.0 / 6.0)
        # n=4, c=2: pass@4 = 1 - C(2,4)/C(4,4) = 1 - 0/1 = 1.0 (n-c < k)
        assert pass_at_k["4"] == pytest.approx(1.0)
        # pass^k: C(c,k)/C(n,k)
        # n=4, c=2: pass^1 = C(2,1)/C(4,1) = 2/4 = 0.5
        assert pass_all_k["1"] == pytest.approx(0.5)
        # n=4, c=2: pass^2 = C(2,2)/C(4,2) = 1/6
        assert pass_all_k["2"] == pytest.approx(1.0 / 6.0)
        # n=4, c=2: pass^4 = C(2,4)/C(4,4) = 0/1 = 0.0
        assert pass_all_k["4"] == pytest.approx(0.0)

    def test_multiple_examples_averaged(self):
        """pass@k and pass^k are averaged across multiple examples."""
        outputs = [
            # Example 0: all correct
            self._make_output(0, 1.0),
            self._make_output(0, 1.0),
            self._make_output(0, 1.0),
            self._make_output(0, 1.0),
            # Example 1: none correct
            self._make_output(1, 0.0),
            self._make_output(1, 0.0),
            self._make_output(1, 0.0),
            self._make_output(1, 0.0),
        ]
        pass_at_k, pass_all_k = compute_pass_at_k(outputs, rollouts_per_example=4)
        assert set(pass_at_k.keys()) == {"1", "2", "4"}
        # Per-example values are 1.0 (example 0) and 0.0 (example 1) at every k,
        # so the average is 0.5 throughout.
        # pass@1: (1.0 + 0.0) / 2 = 0.5
        assert pass_at_k["1"] == pytest.approx(0.5)
        # pass@2: (1.0 + 0.0) / 2 = 0.5
        assert pass_at_k["2"] == pytest.approx(0.5)
        # pass@4: (1.0 + 0.0) / 2 = 0.5
        assert pass_at_k["4"] == pytest.approx(0.5)
        # pass^1: (1.0 + 0.0) / 2 = 0.5
        assert pass_all_k["1"] == pytest.approx(0.5)
        # pass^4: (1.0 + 0.0) / 2 = 0.5
        assert pass_all_k["4"] == pytest.approx(0.5)

    def test_powers_of_two_k_selection(self):
        """k values are powers of 2 in [1, n]."""
        outputs = [self._make_output(0, 1.0) for _ in range(16)]
        pass_at_k, _ = compute_pass_at_k(outputs, rollouts_per_example=16)
        assert set(pass_at_k.keys()) == {"1", "2", "4", "8", "16"}

    def test_n3_k_values(self):
        """n=3 should give k=1,2."""
        # 4 > n, so the largest reported power of two is 2.
        outputs = [self._make_output(0, 1.0) for _ in range(3)]
        pass_at_k, _ = compute_pass_at_k(outputs, rollouts_per_example=3)
        assert set(pass_at_k.keys()) == {"1", "2"}

    def test_correctness_threshold(self):
        """Only reward >= 0.5 counts as correct by default."""
        # The boundary value 0.5 itself is inclusive (>=, not >).
        outputs = [
            self._make_output(0, 0.49),  # not correct
            self._make_output(0, 0.5),  # correct
            self._make_output(0, 1.0),  # correct
            self._make_output(0, 0.0),  # not correct
        ]
        pass_at_k, _ = compute_pass_at_k(outputs, rollouts_per_example=4)
        # n=4, c=2
        assert pass_at_k["1"] == pytest.approx(0.5)

    def test_custom_threshold(self):
        """Custom threshold changes which rollouts count as correct."""
        outputs = [
            self._make_output(0, 0.4),  # not correct at 0.7
            self._make_output(0, 0.6),  # not correct at 0.7
            self._make_output(0, 0.8),  # correct at 0.7
            self._make_output(0, 0.3),  # not correct at 0.7
        ]
        pass_at_k, _ = compute_pass_at_k(outputs, rollouts_per_example=4, threshold=0.7)
        # n=4, c=1: pass@1 = 1 - C(3,1)/C(4,1) = 1 - 3/4 = 0.25
        assert pass_at_k["1"] == pytest.approx(0.25)
        # n=4, c=1: pass@2 = 1 - C(3,2)/C(4,2) = 1 - 3/6 = 0.5
        assert pass_at_k["2"] == pytest.approx(0.5)

    def test_builder_includes_pass_at_k(self):
        """GenerateOutputsBuilder.build_metadata() includes pass_at_k and pass_all_k."""
        builder = GenerateOutputsBuilder(
            env_id="test-env",
            env_args={},
            model="test-model",
            client=ClientConfig(api_base_url="http://localhost:8000/v1"),
            num_examples=1,
            rollouts_per_example=4,
            state_columns=[],
            sampling_args={},
            results_path=Path("/tmp/test-results"),
        )
        builder.add_outputs(
            [
                {"example_id": 0, "reward": 1.0, "metrics": {}},
                {"example_id": 0, "reward": 0.0, "metrics": {}},
                {"example_id": 0, "reward": 1.0, "metrics": {}},
                {"example_id": 0, "reward": 0.0, "metrics": {}},
            ]
        )
        metadata = builder.build_metadata()
        # Builder should expose both metric dicts and the default threshold.
        assert set(metadata["pass_at_k"].keys()) == {"1", "2", "4"}
        assert set(metadata["pass_all_k"].keys()) == {"1", "2", "4"}
        assert metadata["pass_threshold"] == 0.5

    def test_builder_uses_custom_threshold(self):
        """GenerateOutputsBuilder respects pass_threshold."""
        builder = GenerateOutputsBuilder(
            env_id="test-env",
            env_args={},
            model="test-model",
            client=ClientConfig(api_base_url="http://localhost:8000/v1"),
            num_examples=1,
            rollouts_per_example=4,
            state_columns=[],
            sampling_args={},
            results_path=Path("/tmp/test-results"),
            pass_threshold=0.7,
        )
        builder.add_outputs(
            [
                {"example_id": 0, "reward": 0.4, "metrics": {}},
                {"example_id": 0, "reward": 0.6, "metrics": {}},
                {"example_id": 0, "reward": 0.8, "metrics": {}},
                {"example_id": 0, "reward": 0.3, "metrics": {}},
            ]
        )
        metadata = builder.build_metadata()
        assert metadata["pass_threshold"] == 0.7
        # 1 of 4 correct at threshold=0.7: pass@1 = 1 - C(3,1)/C(4,1) = 0.25
        assert metadata["pass_at_k"]["1"] == pytest.approx(0.25)
        # 1 of 4 correct at threshold=0.7: pass^1 = C(1,1)/C(4,1) = 0.25
        assert metadata["pass_all_k"]["1"] == pytest.approx(0.25)

    def test_incomplete_groups_excluded(self):
        """Examples with fewer outputs than rollouts_per_example are excluded."""
        outputs = [
            # Example 0: complete group (4 rollouts), all correct
            self._make_output(0, 1.0),
            self._make_output(0, 1.0),
            self._make_output(0, 1.0),
            self._make_output(0, 1.0),
            # Example 1: incomplete group (only 2 of 4 rollouts)
            self._make_output(1, 0.0),
            self._make_output(1, 0.0),
        ]
        pass_at_k, pass_all_k = compute_pass_at_k(outputs, rollouts_per_example=4)
        # Only example 0 contributes; example 1 is excluded entirely
        assert pass_at_k["1"] == pytest.approx(1.0)
        assert pass_all_k["1"] == pytest.approx(1.0)

    def test_all_groups_incomplete_returns_empty(self):
        """If no example has a complete group, return empty dicts."""
        outputs = [
            self._make_output(0, 1.0),
            self._make_output(0, 1.0),
            self._make_output(1, 1.0),
        ]
        pass_at_k, pass_all_k = compute_pass_at_k(outputs, rollouts_per_example=4)
        assert pass_at_k == {}
        assert pass_all_k == {}
3 changes: 3 additions & 0 deletions verifiers/envs/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ def __init__(
map_kwargs: dict = {},
max_seq_len: int | None = None,
score_rollouts: bool = True,
pass_threshold: float = 0.5,
**kwargs,
):
if message_type is _MESSAGE_TYPE_UNSET:
Expand Down Expand Up @@ -145,6 +146,7 @@ def __init__(
self.map_kwargs = map_kwargs

self.set_score_rollouts(score_rollouts)
self.pass_threshold = pass_threshold

self.env_client: EnvClient | None = None
self.env_server_process: BaseProcess | None = None
Expand Down Expand Up @@ -932,6 +934,7 @@ def default_on_progress(*a, **kw):
state_columns=state_columns,
sampling_args=sampling_args,
results_path=results_path,
pass_threshold=self.pass_threshold,
)

single_client: Client | None = None
Expand Down
3 changes: 3 additions & 0 deletions verifiers/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,9 @@ class GenerateMetadata(TypedDict):
avg_reward: float
avg_metrics: dict[str, float]
avg_error: float
pass_at_k: dict[str, float]
pass_all_k: dict[str, float]
pass_threshold: float
usage: TokenUsage | None
version_info: VersionInfo
state_columns: list[str]
Expand Down
25 changes: 24 additions & 1 deletion verifiers/utils/eval_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from verifiers.utils.async_utils import EventLoopLagMonitor
from verifiers.utils.import_utils import load_toml
from verifiers.utils.logging_utils import print_prompt_completions_sample, print_time
from verifiers.utils.metric_utils import compute_pass_at_k
from verifiers.utils.path_utils import get_eval_results_path

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -423,6 +424,21 @@ def print_rewards(results: GenerateOutputs):
out = f"r{i + 1}: {trials}"
print(out)

threshold = results["metadata"].get("pass_threshold", 0.5)
pass_at_k, pass_all_k = compute_pass_at_k(results["outputs"], r, threshold)
if pass_at_k:
parts = [
f"{k}={v:.3f}"
for k, v in sorted(pass_at_k.items(), key=lambda x: int(x[0]))
]
print(f"pass@k: {', '.join(parts)}")
if pass_all_k:
parts = [
f"{k}={v:.3f}"
for k, v in sorted(pass_all_k.items(), key=lambda x: int(x[0]))
]
print(f"pass^k: {', '.join(parts)}")

metrics = [o["metrics"] for o in results["outputs"]]
metrics_col = to_col_order(metrics)
for k in metrics_col.keys():
Expand Down Expand Up @@ -739,11 +755,18 @@ def on_display_progress(
new_outputs: list[RolloutOutput],
metadata: GenerateMetadata,
) -> None:
metrics = dict(metadata.get("avg_metrics") or {})
pass_at_k = metadata.get("pass_at_k") or {}
for k, v in pass_at_k.items():
metrics[f"pass@{k}"] = v
pass_all_k = metadata.get("pass_all_k") or {}
for k, v in pass_all_k.items():
metrics[f"pass^{k}"] = v
display.update_env_state(
env_idx,
progress=len(all_outputs),
reward=metadata.get("avg_reward"),
metrics=metadata.get("avg_metrics"),
metrics=metrics,
error_rate=metadata.get("avg_error"),
usage=metadata.get("usage"),
)
Expand Down
Loading
Loading