3 changes: 3 additions & 0 deletions docs/evaluation.md
@@ -117,9 +117,12 @@ Multiple rollouts per example enable metrics like pass@k and help measure variance
| `--max-concurrent-generation` | — | same as `-c` | Concurrent generation requests |
| `--max-concurrent-scoring` | — | same as `-c` | Concurrent scoring requests |
| `--no-interleave-scoring` | `-N` | false | Disable interleaved scoring |
| `--max-retries` | — | 0 | Retries per rollout on transient `InfraError` |

By default, scoring runs interleaved with generation. Use `--no-interleave-scoring` to score all rollouts after generation completes.

The `--max-retries` flag enables automatic retries with jittered exponential backoff (waits start at 1 second and are capped at 60 seconds) when a rollout fails with a transient infrastructure error, e.g. a sandbox timeout or an API failure. Only `InfraError` failures are retried: with `--max-retries 2`, a rollout runs at most three times (one initial attempt plus two retries).
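
The same knob is available programmatically. A minimal sketch, assuming the usual `vf.load_environment` entry point (the environment id and model name are placeholders, and the `evaluate_sync` signature is abbreviated):

```python
import verifiers as vf
from openai import AsyncOpenAI

env = vf.load_environment("my-env")  # placeholder environment id
results = env.evaluate_sync(
    client=AsyncOpenAI(),
    model="gpt-4.1-mini",  # placeholder model name
    max_retries=3,  # one initial attempt plus up to 3 retries on InfraError
)
```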

### Output and Saving

| Flag | Short | Default | Description |
2 changes: 2 additions & 0 deletions docs/reference.md
@@ -541,7 +541,9 @@ class EvalConfig(BaseModel):
max_concurrent: int
max_concurrent_generation: int | None = None
max_concurrent_scoring: int | None = None
independent_scoring: bool = False
extra_env_kwargs: dict = {}
max_retries: int = 0
print_results: bool = False
verbose: bool = False
state_columns: list[str] | None = None
98 changes: 98 additions & 0 deletions tests/test_environment.py
@@ -550,3 +550,101 @@ async def has_error(state):
assert "RuntimeError" in call_args
assert "caused by" not in call_args
assert "something went wrong" in call_args


class RetryCounterEnv(SimpleEnvironment):
"""Environment that fails first N times with configurable error type."""

def __init__(self, fail_count: int, error_type: type = vf.InfraError, **kwargs):
super().__init__(**kwargs)
self.fail_count = fail_count
self.error_type = error_type
self.call_counts: dict[int, int] = {}

async def setup_state(self, state, **kwargs):
example_id = state["example_id"]
self.call_counts.setdefault(example_id, 0)
self.call_counts[example_id] += 1

if self.call_counts[example_id] <= self.fail_count:
raise self.error_type(
f"Simulated failure {self.call_counts[example_id]}/{self.fail_count}"
)

return state


class TestMaybeRetry:
"""Test cases for maybe_retry functionality in Environment.generate()."""

@pytest.mark.asyncio
async def test_retry_succeeds_after_transient_infra_error(self, mock_openai_client):
"""InfraError on first 2 attempts, succeeds on 3rd with max_retries=3."""
dataset = Dataset.from_dict({"question": ["test"], "answer": ["test"]})
env = RetryCounterEnv(
fail_count=2, dataset=dataset, parser=Parser(), rubric=Rubric()
)

inputs = [
RolloutInput(
prompt=[{"role": "user", "content": "test"}],
answer="test",
example_id=0,
)
]
results = await env.generate(
inputs, client=mock_openai_client, model="test-model", max_retries=3
)

assert results["state"][0].get("error") is None
assert env.call_counts[0] == 3

@pytest.mark.asyncio
async def test_retry_fails_after_max_retries_exhausted(self, mock_openai_client):
"""InfraError persists after all retries exhausted."""
dataset = Dataset.from_dict({"question": ["test"], "answer": ["test"]})
env = RetryCounterEnv(
fail_count=10, dataset=dataset, parser=Parser(), rubric=Rubric()
)

inputs = [
RolloutInput(
prompt=[{"role": "user", "content": "test"}],
answer="test",
example_id=0,
)
]

with pytest.raises(vf.InfraError):
await env.generate(
inputs, client=mock_openai_client, model="test-model", max_retries=2
)

assert env.call_counts[0] == 3 # 1 initial + 2 retries

@pytest.mark.asyncio
async def test_non_infra_error_not_retried(self, mock_openai_client):
"""ToolError is NOT retried even with max_retries > 0."""
dataset = Dataset.from_dict({"question": ["test"], "answer": ["test"]})
env = RetryCounterEnv(
fail_count=10,
error_type=vf.ToolError,
dataset=dataset,
parser=Parser(),
rubric=Rubric(),
)

inputs = [
RolloutInput(
prompt=[{"role": "user", "content": "test"}],
answer="test",
example_id=0,
)
]

with pytest.raises(vf.ToolError):
await env.generate(
inputs, client=mock_openai_client, model="test-model", max_retries=3
)

assert env.call_counts[0] == 1 # No retries for non-InfraError
1 change: 1 addition & 0 deletions tests/test_eval_cli.py
@@ -52,6 +52,7 @@ def _run_cli(monkeypatch, overrides):
"save_to_hf_hub": False,
"hf_hub_dataset_name": "",
"extra_env_kwargs": {},
"max_retries": 0,
}
base_args.update(overrides)
args_namespace = SimpleNamespace(**base_args)
11 changes: 8 additions & 3 deletions verifiers/envs/environment.py
@@ -45,7 +45,7 @@
SamplingArgs,
State,
)
from verifiers.utils.async_utils import maybe_semaphore
from verifiers.utils.async_utils import maybe_retry, maybe_semaphore
from verifiers.utils.error_utils import ErrorChain
from verifiers.utils.eval_utils import make_dataset, save_rollout_results
from verifiers.utils.message_utils import (
@@ -865,6 +865,7 @@ async def generate(
save_every: int = -1,
use_tqdm: bool = True,
independent_scoring: bool = False,
max_retries: int = 0,
) -> GenerateOutputs:
"""
Generate rollouts for a set of inputs.
@@ -898,7 +899,7 @@
if independent_scoring:
for i, input_item in enumerate(inputs_list):
task = asyncio.create_task(
self.run_rollout(
maybe_retry(self.run_rollout, max_retries=max_retries)(
input_item,
client,
model,
@@ -922,7 +923,7 @@

for i, group in enumerate(group_list):
task = asyncio.create_task(
self.run_group(
maybe_retry(self.run_group, max_retries=max_retries)(
group,
client,
model,
@@ -1070,6 +1071,7 @@ async def evaluate(
save_results: bool = False,
save_every: int = -1,
independent_scoring: bool = False,
max_retries: int = 0,
**kwargs,
) -> GenerateOutputs:
"""
@@ -1089,6 +1091,7 @@
save_results=save_results,
save_every=save_every,
independent_scoring=independent_scoring,
max_retries=max_retries,
**kwargs,
)

@@ -1107,6 +1110,7 @@ def evaluate_sync(
save_results: bool = False,
save_every: int = -1,
independent_scoring: bool = False,
max_retries: int = 0,
) -> GenerateOutputs:
"""
Evaluate model on the Environment evaluation dataset synchronously.
@@ -1125,6 +1129,7 @@
save_results=save_results,
save_every=save_every,
independent_scoring=independent_scoring,
max_retries=max_retries,
)

# setters for use by trainers
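At both call sites above, `maybe_retry` wraps the coroutine function and the result is invoked immediately, so the retry budget applies per rollout (or per group). A sketch of the pattern under those assumptions (arguments abbreviated; `input_item`, `client`, and `model` stand in for the real locals):

```python
# maybe_retry(func, max_retries=0) returns func unchanged when
# max_retries <= 0, so the default path adds no overhead.
wrapped = maybe_retry(self.run_rollout, max_retries=2)

# With max_retries=2 the rollout runs at most 3 times:
# one initial attempt plus two retries, triggered only by InfraError.
state = await wrapped(input_item, client, model)
```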
7 changes: 7 additions & 0 deletions verifiers/scripts/eval.py
@@ -240,6 +240,12 @@ def main():
default={},
help='Extra environment kwargs as a JSON object (e.g., \'{"key": "value", "num": 42}\'). Passed to the environment constructor.',
)
parser.add_argument(
"--max-retries",
type=int,
default=0,
help="Max retries for transient infrastructure errors (default: 0)",
)
args = parser.parse_args()

setup_logging("DEBUG" if args.verbose else os.getenv("VF_LOG_LEVEL", "INFO"))
@@ -343,6 +349,7 @@ def main():
max_concurrent=args.max_concurrent,
max_concurrent_generation=args.max_concurrent_generation,
max_concurrent_scoring=args.max_concurrent_scoring,
max_retries=args.max_retries,
# logging
print_results=True,
verbose=args.verbose,
1 change: 1 addition & 0 deletions verifiers/types.py
@@ -238,6 +238,7 @@ class EvalConfig(BaseModel):
max_concurrent_scoring: int | None = None
independent_scoring: bool = False
extra_env_kwargs: dict = {}
max_retries: int = 0
# logging
print_results: bool = False
verbose: bool = False
66 changes: 65 additions & 1 deletion verifiers/utils/async_utils.py
@@ -2,7 +2,16 @@
import inspect
import logging
from time import perf_counter
from typing import Any, AsyncContextManager, Callable, Optional
from collections.abc import Coroutine
from typing import Any, AsyncContextManager, Callable, Optional, TypeVar

import tenacity as tc

import verifiers as vf

logger = logging.getLogger(__name__)

T = TypeVar("T")


async def maybe_await(func: Callable, *args, **kwargs):
@@ -85,3 +94,58 @@ async def run(self):
def run_in_background(self):
"""Run the event loop lag monitor as a background task."""
return asyncio.create_task(self.run())


def _raise_error_from_state(result) -> None:
    """Re-raise InfraError from state(s) to trigger retry."""
    if isinstance(result, dict):
        err = result.get("error")
        if isinstance(err, vf.InfraError):
            raise err
    elif isinstance(result, list):
        for state in result:
            err = state.get("error")
            if isinstance(err, vf.InfraError):
                raise err


def maybe_retry(
func: Callable[..., Coroutine[Any, Any, T]],
max_retries: int = 0,
initial: float = 1.0,
max_wait: float = 60.0,
) -> Callable[..., Coroutine[Any, Any, T]]:
"""
Return retry-wrapped function if max_retries > 0, else return func unchanged.
Re-raises vf.InfraError from state["error"] to trigger retry.

Usage:
state = await maybe_retry(self.run_rollout, max_retries=3)(input, client, ...)
"""
if max_retries <= 0:
return func

def log_retry(retry_state):
logger.warning(
"Retrying %s (attempt %s/%s): %s",
retry_state.fn.__name__,
retry_state.attempt_number + 1,
max_retries + 1,
retry_state.outcome.exception(),
)

async def wrapper(*args, **kwargs):
result = await func(*args, **kwargs)
_raise_error_from_state(result)
return result

wrapper.__name__ = getattr(func, "__name__", "unknown")
wrapper.__qualname__ = getattr(func, "__qualname__", "unknown")

return tc.AsyncRetrying(
retry=tc.retry_if_exception_type(vf.InfraError),
stop=tc.stop_after_attempt(max_retries + 1),
wait=tc.wait_exponential_jitter(initial=initial, max=max_wait),
before_sleep=log_retry,
reraise=True,
).wraps(wrapper)
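
For reference, a minimal self-contained sketch of the two retry triggers handled above: an `InfraError` raised directly, and one recorded in `state["error"]` and re-raised by the wrapper. The flaky coroutine and counter are illustrative only, not part of the library:

```python
import asyncio

import verifiers as vf
from verifiers.utils.async_utils import maybe_retry

attempts = {"n": 0}

async def flaky_rollout() -> dict:
    attempts["n"] += 1
    if attempts["n"] == 1:
        raise vf.InfraError("raised directly")  # caught and retried
    if attempts["n"] == 2:
        # Stored in state rather than raised; the wrapper re-raises it.
        return {"error": vf.InfraError("recorded in state")}
    return {"error": None, "reward": 1.0}

# initial=0.01 keeps backoff short for the sketch; defaults are 1 s / 60 s.
state = asyncio.run(maybe_retry(flaky_rollout, max_retries=3, initial=0.01)())
assert attempts["n"] == 3 and state["reward"] == 1.0
```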
1 change: 1 addition & 0 deletions verifiers/utils/eval_utils.py
@@ -252,6 +252,7 @@ async def run_evaluation(config: EvalConfig) -> GenerateOutputs:
save_results=config.save_results,
save_every=config.save_every,
independent_scoring=config.independent_scoring,
max_retries=config.max_retries,
)
end_time = time.time()
logger.info(f"Evaluation completed in {end_time - start_time:.2f} seconds")