Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 56 additions & 49 deletions tests/test_environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,38 +37,43 @@ async def rollout(
):
"""Simple test rollout implementation."""
state = await self.init_state(input, client=client, model=model)
state = await self.setup_state(state)
try:
state = await self.setup_state(state)

prompt_messages = state["prompt"]
response = await self.get_model_response(state, prompt_messages)
prompt_messages = state["prompt"]
response = await self.get_model_response(state, prompt_messages)

from verifiers.utils.response_utils import parse_response_messages
from verifiers.utils.response_utils import parse_response_messages

completion_messages = await parse_response_messages(response, self.message_type)
from verifiers.types import TrajectoryStep
from verifiers.utils.response_utils import parse_response_tokens

tokens = await parse_response_tokens(response, self.message_type)
trajectory_step = TrajectoryStep(
prompt=prompt_messages,
completion=completion_messages,
response=response,
tokens=tokens,
reward=None,
advantage=None,
is_truncated=False,
trajectory_id=state["trajectory_id"],
extras={},
)
state["trajectory"].append(trajectory_step)
state["is_completed"] = True
completion_messages = await parse_response_messages(
response, self.message_type
)
from verifiers.types import TrajectoryStep
from verifiers.utils.response_utils import parse_response_tokens

tokens = await parse_response_tokens(response, self.message_type)
trajectory_step = TrajectoryStep(
prompt=prompt_messages,
completion=completion_messages,
response=response,
tokens=tokens,
reward=None,
advantage=None,
is_truncated=False,
trajectory_id=state["trajectory_id"],
extras={},
)
state["trajectory"].append(trajectory_step)
state["is_completed"] = True

from verifiers.utils.message_utils import concat_messages
from verifiers.utils.message_utils import concat_messages

last_prompt = state["trajectory"][-1]["prompt"]
last_completion = state["trajectory"][-1]["completion"]
full_conversation = concat_messages([last_prompt, last_completion])
state["completion"] = full_conversation[len(state["prompt"]) :]
last_prompt = state["trajectory"][-1]["prompt"]
last_completion = state["trajectory"][-1]["completion"]
full_conversation = concat_messages([last_prompt, last_completion])
state["completion"] = full_conversation[len(state["prompt"]) :]
except vf.Error as e:
state["error"] = e

return state

Expand Down Expand Up @@ -578,8 +583,8 @@ class TestMaybeRetry:
"""Test cases for maybe_retry functionality in Environment.generate()."""

@pytest.mark.asyncio
async def test_retry_succeeds_after_transient_infra_error(self, mock_openai_client):
"""InfraError on first 2 attempts, succeeds on 3rd with max_retries=3."""
async def test_retry_after_retryable_error(self, mock_openai_client):
"""Retry after error on first 2 attempts, succeeds on 3rd with max_retries=3."""
dataset = Dataset.from_dict({"question": ["test"], "answer": ["test"]})
env = RetryCounterEnv(
fail_count=2, dataset=dataset, parser=Parser(), rubric=Rubric()
Expand All @@ -600,11 +605,15 @@ async def test_retry_succeeds_after_transient_infra_error(self, mock_openai_clie
assert env.call_counts[0] == 3

@pytest.mark.asyncio
async def test_retry_fails_after_max_retries_exhausted(self, mock_openai_client):
"""InfraError persists after all retries exhausted."""
async def test_no_retry_after_non_retryable_error(self, mock_openai_client):
"""Non-retryable error type is NOT retried even with max_retries > 0."""
dataset = Dataset.from_dict({"question": ["test"], "answer": ["test"]})
env = RetryCounterEnv(
fail_count=10, dataset=dataset, parser=Parser(), rubric=Rubric()
fail_count=10,
error_type=vf.ToolError,
dataset=dataset,
parser=Parser(),
rubric=Rubric(),
)

inputs = [
Expand All @@ -615,23 +624,20 @@ async def test_retry_fails_after_max_retries_exhausted(self, mock_openai_client)
)
]

with pytest.raises(vf.InfraError):
await env.generate(
inputs, client=mock_openai_client, model="test-model", max_retries=2
)
results = await env.generate(
inputs, client=mock_openai_client, model="test-model", max_retries=3
)

assert env.call_counts[0] == 3 # 1 initial + 2 retries
assert env.call_counts[0] == 1 # No retries for non-retryable error
assert results["state"][0].get("error") is not None
assert isinstance(results["state"][0]["error"], vf.ToolError)

@pytest.mark.asyncio
async def test_non_infra_error_not_retried(self, mock_openai_client):
"""ToolError is NOT retried even with max_retries > 0."""
async def test_error_in_state_after_max_retries_exhausted(self, mock_openai_client):
"""Error persists in state after all retries exhausted."""
dataset = Dataset.from_dict({"question": ["test"], "answer": ["test"]})
env = RetryCounterEnv(
fail_count=10,
error_type=vf.ToolError,
dataset=dataset,
parser=Parser(),
rubric=Rubric(),
fail_count=10, dataset=dataset, parser=Parser(), rubric=Rubric()
)

inputs = [
Expand All @@ -642,12 +648,13 @@ async def test_non_infra_error_not_retried(self, mock_openai_client):
)
]

with pytest.raises(vf.ToolError):
await env.generate(
inputs, client=mock_openai_client, model="test-model", max_retries=3
)
results = await env.generate(
inputs, client=mock_openai_client, model="test-model", max_retries=2
)

assert env.call_counts[0] == 1 # No retries for non-InfraError
assert env.call_counts[0] == 3 # 1 initial + 2 retries
assert results["state"][0].get("error") is not None
assert isinstance(results["state"][0]["error"], vf.InfraError)


class TestEmptyModelResponseErrors:
Expand Down
24 changes: 24 additions & 0 deletions verifiers/utils/async_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ def maybe_retry(
"""
Return retry-wrapped function if max_retries > 0, else return func unchanged.
Re-raises specified errors from state["error"] to trigger tenacity retry.
Returns result with error in state if retries are exhausted (does not crash).

Usage:
state = await maybe_retry(self.run_rollout, max_retries=3)(input, client, ...)
Expand Down Expand Up @@ -147,8 +148,30 @@ def log_retry(retry_state: tc.RetryCallState) -> None:
f"Caught {error_chain} in {caller}. Retrying in {print_time(next_action)} (retry {retry_state.attempt_number}/{max_retries})"
)

# Most recent value returned by func. wrapper overwrites it on every attempt,
# and return_last_result hands it back when retries are exhausted. It stays
# None if no attempt ever returned a value.
last_result = None

def return_last_result(retry_state: tc.RetryCallState):
    """Return the last result when retries are exhausted (instead of raising).

    Used as the retry policy's ``retry_error_callback``: it logs the final
    failure and returns whatever ``wrapper`` most recently stored in
    ``last_result``, so the caller gets a value back rather than an exception.

    Args:
        retry_state: Call state of the final attempt.

    Returns:
        The value last stored in ``last_result`` (``None`` if no attempt
        ever returned).
    """
    caller = retry_state.fn.__name__ if retry_state.fn else "unknown function"
    error_chain = (
        repr(
            ErrorChain(
                retry_state.outcome.exception() or Exception("Unknown exception")
            )
        )
        if retry_state.outcome
        else None
    )
    # Report the attempts actually made (initial call + retries). The stop
    # policy allows max_retries + 1 attempts, so printing max_retries here
    # would under-count by one.
    logger.error(
        f"Retries exhausted for {caller} after {retry_state.attempt_number} attempts. "
        f"Last error: {error_chain}. Continuing with error in state."
    )
    return last_result

async def wrapper(*args, **kwargs):
nonlocal last_result
result = await func(*args, **kwargs)
last_result = result # store result
reraise_error_from_state(result, error_types)
return result

Expand All @@ -160,5 +183,6 @@ async def wrapper(*args, **kwargs):
stop=tc.stop_after_attempt(max_retries + 1),
wait=tc.wait_exponential_jitter(initial=initial, max=max_wait),
before_sleep=log_retry,
retry_error_callback=return_last_result,
reraise=True,
).wraps(wrapper)
Loading