Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 85 additions & 0 deletions tests/test_async_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
"""Tests for the async_utils module."""

from unittest.mock import MagicMock, patch

import pytest
from openai.types.chat.chat_completion import Choice
from openai.types.chat.chat_completion_message import ChatCompletionMessage
from openai.types.completion_usage import CompletionUsage

from verifiers.utils.async_utils import tqdm_gather_with_metrics


class MockChatCompletion:
    """Mock ChatCompletion carrying a usage block and one assistant choice."""

    def __init__(self, completion_tokens: int):
        # Usage object exposing a configurable completion-token count.
        usage = MagicMock(spec=CompletionUsage)
        usage.completion_tokens = completion_tokens
        self.usage = usage

        # A single assistant choice with fixed message content.
        choice = MagicMock(spec=Choice)
        choice.message = MagicMock(spec=ChatCompletionMessage)
        choice.message.content = "Hello"
        choice.message.role = "assistant"
        self.choices = [choice]


@pytest.mark.asyncio
async def test_tqdm_gather_with_metrics():
    """Progress bar shows running averages of reward and completion length."""
    first_state = {"responses": [MockChatCompletion(10)]}
    second_state = {"responses": [MockChatCompletion(14)]}

    async def first_task():
        return (0.8, first_state)

    async def second_task():
        return (0.9, second_state)

    with patch("tqdm.asyncio.tqdm") as tqdm_cls:
        pbar = MagicMock()
        tqdm_cls.return_value = pbar

        gathered = await tqdm_gather_with_metrics(
            [first_task(), second_task()], total=2, desc="Running eval"
        )

        # Results come back in submission order.
        assert len(gathered) == 2
        assert gathered[0] == (0.8, first_state)
        assert gathered[1] == (0.9, second_state)

        tqdm_cls.assert_called_once_with(total=2, desc="Running eval")
        assert pbar.update.call_count == 2
        pbar.close.assert_called_once()

        descriptions = [call[0][0] for call in pbar.set_description.call_args_list]
        assert len(descriptions) == 2

        # After the first task: reward 0.8, 10 completion tokens.
        assert "reward=0.800" in descriptions[0]
        assert "completion_length=10" in descriptions[0]

        # After both tasks: running means (0.8 + 0.9) / 2 and (10 + 14) / 2.
        assert "reward=0.850" in descriptions[1]
        assert "completion_length=12" in descriptions[1]


@pytest.mark.asyncio
async def test_tqdm_gather_with_metrics_missing_usage():
    """Responses without a usage field must not break metric tracking."""

    class MockResponseNoUsage:
        """Mock response without usage field."""

    state = {"responses": [MockResponseNoUsage()]}

    async def lone_task():
        return (0.7, state)

    with patch("tqdm.asyncio.tqdm") as tqdm_cls:
        pbar = MagicMock()
        tqdm_cls.return_value = pbar

        gathered = await tqdm_gather_with_metrics(
            [lone_task()], total=1, desc="No usage test"
        )

        # The result is returned untouched and progress still advances.
        assert len(gathered) == 1
        assert gathered[0] == (0.7, state)
        pbar.update.assert_called_once()
10 changes: 6 additions & 4 deletions verifiers/envs/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
get_overlong_prompt_dummy_response,
sanitize_tool_calls,
)
from verifiers.utils.async_utils import tqdm_gather_with_metrics

if TYPE_CHECKING:
from transformers.tokenization_utils_base import ( # type: ignore
Expand Down Expand Up @@ -474,7 +475,7 @@ async def a_generate(
else None
)

async def run_one(i: int) -> None:
async def run_one(i: int) -> tuple[float, State]:
prompt_i = results.prompt[i]
answer_i = results.answer[i]
task_i = results.task[i]
Expand Down Expand Up @@ -534,11 +535,12 @@ async def run_one(i: int) -> None:
metrics[k] = [0.0] * n
metrics[k][i] = v

return rs.reward, state_i

tasks = [run_one(i) for i in range(n)]
from tqdm.asyncio import tqdm_asyncio

await tqdm_asyncio.gather(
*tasks, total=n, desc=f"Running {n} rollouts (interleaved)"
await tqdm_gather_with_metrics(
tasks, total=n, desc=f"Running {n} rollouts (interleaved)"
)

results.completion = results_completion # type: ignore[assignment]
Expand Down
47 changes: 47 additions & 0 deletions verifiers/utils/async_utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,56 @@
import inspect
from typing import Callable

import asyncio


async def maybe_await(func: Callable, *args, **kwargs):
    """Call *func* and await the result only if it is awaitable.

    Lets callers treat sync and async callables uniformly: a plain
    function's return value passes straight through, while a coroutine
    (or any other awaitable) is awaited first.
    """
    outcome = func(*args, **kwargs)
    if not inspect.isawaitable(outcome):
        return outcome
    return await outcome


async def tqdm_gather_with_metrics(tasks: list, total: int, desc: str) -> list:
    """Gather async tasks with a tqdm progress bar showing running averages.

    Each task is expected to resolve to a ``(reward, state)`` tuple, where
    ``state`` is a dict that may carry OpenAI-style response objects under a
    ``"responses"`` key. As tasks complete, the bar's description is updated
    with the running mean reward and mean completion-token count.

    Args:
        tasks: Awaitables to run concurrently.
        total: Number of tasks (sizes both the progress bar and the results).
        desc: Base description for the progress bar.

    Returns:
        Task results in the same order as ``tasks``.
    """
    from tqdm.asyncio import tqdm as async_tqdm

    results: list = [None] * total
    progress_bar = async_tqdm(total=total, desc=desc)

    count = 0
    reward_count = 0  # tasks that yielded a reward, so an all-0.0 average still displays
    usage_count = 0  # responses that carried usage info
    sum_reward = 0.0
    sum_tokens = 0

    async def track(idx: int, task) -> None:
        nonlocal count, reward_count, usage_count, sum_reward, sum_tokens
        results[idx] = result = await task
        count += 1

        # We expect a (reward, state) tuple; anything else is counted but
        # contributes no metrics. The isinstance guard avoids a TypeError
        # from len() on an un-sized result object.
        if isinstance(result, tuple) and len(result) == 2:
            sum_reward += result[0]
            reward_count += 1

            # Extract token counts from the state using the OpenAI usage format.
            state = result[1]
            if isinstance(state, dict) and "responses" in state:
                for response in state["responses"]:
                    usage = getattr(response, "usage", None)
                    if usage:
                        sum_tokens += usage.completion_tokens or 0
                        usage_count += 1

        description_parts = [desc]
        if reward_count:
            description_parts.append(f"reward={sum_reward / count:.3f}")
        if usage_count:
            description_parts.append(f"completion_length={sum_tokens / count:.0f}")

        if len(description_parts) > 1:
            progress_bar.set_description(" | ".join(description_parts))
        progress_bar.update(1)

    try:
        await asyncio.gather(*[track(i, t) for i, t in enumerate(tasks)])
    finally:
        # Always close the bar so a failing task doesn't leave it dangling.
        progress_bar.close()
    return results
Loading