Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 76 additions & 0 deletions tests/test_rubric_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,3 +417,79 @@ def reward_func(completion, parser, answer, **_):

assert state["reward"] == 1.0
assert recorded_parsers == [xml_parser]

@pytest.mark.asyncio
async def test_rubric_group_score_rollout_timing(self):
"""Test that generation_ms + scoring_ms == total_ms after score_rollout."""

def func1(completion, **kwargs):
return 1.0

def func2(completion, **kwargs):
return 0.5

rubric1 = Rubric(funcs=[func1], weights=[1.0])
rubric2 = Rubric(funcs=[func2], weights=[1.0])

group = RubricGroup(rubrics=[rubric1, rubric2])

state = State(
input=RolloutInput(
prompt=[{"role": "user", "content": "test"}],
answer="test",
task="default",
example_id=0,
)
)
state["completion"] = [{"role": "assistant", "content": "test"}]
state["trajectory"] = []
state["timing"] = RolloutTiming(
generation_ms=100.0,
scoring_ms=0.0,
total_ms=100.0,
start_time=0.0,
)

await group.score_rollout(state)

assert state["timing"]["generation_ms"] == 100.0
assert state["timing"]["scoring_ms"] > 0.0
assert state["timing"]["total_ms"] == 100.0 + state["timing"]["scoring_ms"]

@pytest.mark.asyncio
async def test_rubric_group_score_group_timing(self):
"""Test that generation_ms + scoring_ms == total_ms after score_group."""

def func1(completion, **kwargs):
return 1.0

def func2(completion, **kwargs):
return 0.5

rubric1 = Rubric(funcs=[func1], weights=[1.0])
rubric2 = Rubric(funcs=[func2], weights=[1.0])

group = RubricGroup(rubrics=[rubric1, rubric2])

state = State(
input=RolloutInput(
prompt=[{"role": "user", "content": "test"}],
answer="test",
task="default",
example_id=0,
)
)
state["completion"] = [{"role": "assistant", "content": "test"}]
state["trajectory"] = []
state["timing"] = RolloutTiming(
generation_ms=100.0,
scoring_ms=0.0,
total_ms=100.0,
start_time=0.0,
)

await group.score_group([state])

assert state["timing"]["generation_ms"] == 100.0
assert state["timing"]["scoring_ms"] > 0.0
assert state["timing"]["total_ms"] == 100.0 + state["timing"]["scoring_ms"]
15 changes: 15 additions & 0 deletions verifiers/rubrics/rubric_group.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import time
from typing import Any

from verifiers.rubrics.rubric import Rubric
Expand Down Expand Up @@ -56,12 +57,14 @@ async def score_rollout(self, state: State):
"""
Evaluate all reward functions in-place for a single rollout.
"""
start_time = time.time()
total_reward = 0.0
aggregated_metrics: dict[str, float] = {}
original_reward = state.get("reward", 0.0)
original_metrics = (
state.get("metrics", {}).copy() if state.get("metrics") else {}
)
original_timing = state["timing"].copy()
for rubric in self.rubrics:
await rubric.score_rollout(state)
rubric_reward = state.get("reward", 0.0)
Expand All @@ -74,20 +77,27 @@ async def score_rollout(self, state: State):
# restore original values for next rubric
state["reward"] = original_reward
state["metrics"] = original_metrics.copy()
state["timing"] = original_timing.copy()
state["reward"] = total_reward
state["metrics"] = aggregated_metrics
end_time = time.time()
scoring_ms = (end_time - start_time) * 1000
state["timing"]["scoring_ms"] = scoring_ms
state["timing"]["total_ms"] += scoring_ms

async def score_group(self, states: list[State]):
"""
Evaluate all reward functions in-place for a group of rollouts.
"""
start_time = time.time()
aggregated_rewards = [0.0] * len(states)
aggregated_metrics: dict[str, list[float]] = {}
original_rewards = [state.get("reward", 0.0) for state in states]
original_metrics = [
state.get("metrics", {}).copy() if state.get("metrics") else {}
for state in states
]
original_timings = [state["timing"].copy() for state in states]
for rubric in self.rubrics:
await rubric.score_group(states)
for i, state in enumerate(states):
Expand All @@ -102,10 +112,15 @@ async def score_group(self, states: list[State]):
aggregated_metrics[key][i] += value
state["reward"] = original_rewards[i]
state["metrics"] = original_metrics[i].copy()
state["timing"] = original_timings[i].copy()
end_time = time.time()
scoring_ms = (end_time - start_time) * 1000
for i, state in enumerate(states):
state["reward"] = aggregated_rewards[i]
if aggregated_metrics:
if "metrics" not in state:
state["metrics"] = {}
for key, values in aggregated_metrics.items():
state["metrics"][key] = values[i]
state["timing"]["scoring_ms"] = scoring_ms
state["timing"]["total_ms"] += scoring_ms