Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions verifiers/envs/env_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,3 +291,9 @@ def set_interleaved_rollouts(self, interleaved_rollouts: bool) -> None:
self.interleaved_rollouts = interleaved_rollouts
for env in self.envs:
env.set_interleaved_rollouts(interleaved_rollouts)

def set_score_rollouts(self, score_rollouts: bool) -> None:
    """Propagate the score-rollouts flag to this group and to each wrapped sub-environment.

    Stores the flag on the group itself, then forwards it so every member
    environment stays in sync with the group's scoring behavior.
    """
    self.score_rollouts = score_rollouts
    for sub_env in self.envs:
        sub_env.set_score_rollouts(score_rollouts)
11 changes: 10 additions & 1 deletion verifiers/envs/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ def __init__(
map_kwargs: dict = {},
max_seq_len: int | None = None,
interleaved_rollouts: bool = False,
score_rollouts: bool = True,
**kwargs,
):
self.logger = logging.getLogger(f"{__name__}.{self.__class__.__name__}")
Expand All @@ -100,6 +101,7 @@ def __init__(
self.max_seq_len = max_seq_len

self.set_interleaved_rollouts(interleaved_rollouts)
self.set_score_rollouts(score_rollouts)

if self.message_type == "chat":
if dataset is not None:
Expand Down Expand Up @@ -711,7 +713,10 @@ async def run_group(
for input in group_inputs
]
group_states = await asyncio.gather(*rollout_tasks)
await self.rubric.score_group(group_states, score_sem=score_sem)
if self.score_rollouts:
await self.rubric.score_group(group_states, score_sem=score_sem)
else:
await self.rubric.dummy_score_group(group_states)
return list(group_states)

def _prepare_rollout_results(
Expand Down Expand Up @@ -1062,6 +1067,10 @@ def set_interleaved_rollouts(self, interleaved_rollouts: bool) -> None:
f"{self.__class__.__name__} is configured to use interleaved rollouts. All model responses after the first turn will be pre-tokenized before being sent to the model. Currently, this is a hand-crafted feature for PRIME-RL's vLLM server extension."
)

def set_score_rollouts(self, score_rollouts: bool) -> None:
    """Enable or disable rubric scoring of rollouts for this environment.

    When disabled, callers are expected to substitute placeholder rewards
    instead of invoking the rubric.
    """
    setattr(self, "score_rollouts", score_rollouts)

make_dataset = make_dataset


Expand Down
8 changes: 8 additions & 0 deletions verifiers/rubrics/rubric.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,14 @@ async def score_rollout(self, state: State, score_sem: AsyncContextManager):
state["reward"] = rewards["reward"]
state["metrics"] = rewards["metrics"]

async def dummy_score_group(self, states: list[State]):
    """
    Fill every rollout state in the group with placeholder scores.

    Each state receives a reward of 0.0 and an empty metrics mapping,
    mirroring the shape produced by real scoring without calling any
    reward functions.
    """
    for rollout_state in states:
        rollout_state.update(reward=0.0, metrics={})

async def score_group(self, states: list[State], score_sem: AsyncContextManager):
"""
Score a group of rollouts together.
Expand Down
Loading