Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 7 additions & 20 deletions verifiers/envs/env_group.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import time
from typing import TYPE_CHECKING, AsyncContextManager, Mapping
from typing import TYPE_CHECKING, AsyncContextManager, Mapping, final

from datasets import Dataset, concatenate_datasets
from openai import AsyncOpenAI
Expand Down Expand Up @@ -197,7 +197,7 @@ def add_task(example):
f"Initialized EnvGroup with {len(envs)} environments: {self.env_names}"
)

def format_dataset(
def _format_dataset(
self,
dataset: Dataset,
system_prompt: str | None = None,
Expand Down Expand Up @@ -232,7 +232,7 @@ def add_example_id(example, i):
)
return dataset

def format_completion_dataset(
def _format_completion_dataset(
self, dataset: Dataset, map_kwargs: dict = {}
) -> Dataset:
"""
Expand All @@ -253,20 +253,7 @@ def add_example_id(example, i):
)
return dataset

async def init_state(
    self,
    input: RolloutInput,
    client: AsyncOpenAI,
    model: str,
    sampling_args: SamplingArgs | None = None,
) -> vf.State:
    """Initialize rollout state by delegating to the sub-environment that owns input's task.

    Args:
        input: Rollout input; its "task" key selects the sub-environment.
        client: OpenAI-compatible async client passed through to the sub-environment.
        model: Model name passed through to the sub-environment.
        sampling_args: Optional sampling arguments forwarded unchanged.

    Returns:
        The state produced by the selected sub-environment's init_state.
    """
    # Route by task name; get_env_for_task falls back to the first env for unknown tasks.
    env = self.get_env_for_task(input["task"])
    return await env.init_state(input, client, model, sampling_args)

async def setup_state(self, state: vf.State) -> vf.State:
    """Set up rollout state by delegating to the sub-environment that owns state's task.

    Args:
        state: Rollout state; its "task" key selects the sub-environment.

    Returns:
        The state returned by the selected sub-environment's setup_state.
    """
    env = self.get_env_for_task(state["task"])
    return await env.setup_state(state)

@final
async def rollout(
self,
input: RolloutInput,
Expand All @@ -281,19 +268,19 @@ def get_env_for_task(self, task: str) -> vf.Environment:
return self.env_map.get(task, self.envs[0])

def set_max_seq_len(self, max_seq_len: int | None) -> None:
"""Set the maximum sequence length for this environment group and all sub-environments."""
"""Set the max_seq_len value for this environment group and all sub-environments."""
self.max_seq_len = max_seq_len
for env in self.envs:
env.set_max_seq_len(max_seq_len)

def set_interleaved_rollouts(self, interleaved_rollouts: bool) -> None:
    """Set the interleaved_rollouts flag for this environment group and all sub-environments.

    Args:
        interleaved_rollouts: Whether rollouts should be interleaved.
    """
    self.interleaved_rollouts = interleaved_rollouts
    # Propagate so every sub-environment stays consistent with the group setting.
    for env in self.envs:
        env.set_interleaved_rollouts(interleaved_rollouts)

def set_score_rollouts(self, score_rollouts: bool) -> None:
    """Set the score_rollouts flag for this environment group and all sub-environments.

    Args:
        score_rollouts: Whether rollouts should be scored.
    """
    self.score_rollouts = score_rollouts
    # Propagate so every sub-environment stays consistent with the group setting.
    for env in self.envs:
        env.set_score_rollouts(score_rollouts)
48 changes: 17 additions & 31 deletions verifiers/envs/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
Literal,
TypeVar,
cast,
final,
)

from datasets import Dataset
Expand All @@ -47,7 +48,6 @@
from verifiers.utils.error_utils import ErrorChain
from verifiers.utils.eval_utils import make_dataset, save_rollout_results
from verifiers.utils.message_utils import (
concat_messages,
strip_nones_from_content,
)
from verifiers.utils.path_utils import get_results_path
Expand Down Expand Up @@ -106,13 +106,13 @@ def __init__(

if self.message_type == "chat":
if dataset is not None:
self.dataset = self.format_dataset(
self.dataset = self._format_dataset(
dataset, self.system_prompt, self.few_shot, map_kwargs=map_kwargs
)
else:
self.dataset = None
if eval_dataset is not None:
self.eval_dataset = self.format_dataset(
self.eval_dataset = self._format_dataset(
eval_dataset,
self.system_prompt,
self.few_shot,
Expand All @@ -128,13 +128,13 @@ def __init__(
'to contain a "prompt" column.'
)
if dataset is not None:
self.dataset = self.format_completion_dataset(
self.dataset = self._format_completion_dataset(
dataset, map_kwargs=map_kwargs
)
else:
self.dataset = None
if eval_dataset is not None:
self.eval_dataset = self.format_completion_dataset(
self.eval_dataset = self._format_completion_dataset(
eval_dataset, map_kwargs=map_kwargs
)
else:
Expand Down Expand Up @@ -280,7 +280,7 @@ def add_task(example):
dataset = dataset.map(add_task, **map_kwargs)
return dataset

def format_dataset(
def _format_dataset(
self,
dataset: Dataset,
system_prompt: str | None = None,
Expand All @@ -299,7 +299,7 @@ def format_dataset(
dataset = self._ensure_task(dataset, map_kwargs)
return dataset

def format_completion_dataset(
def _format_completion_dataset(
self, dataset: Dataset, map_kwargs: dict = {}
) -> Dataset:
"""
Expand All @@ -309,6 +309,7 @@ def format_completion_dataset(
dataset = self._ensure_task(dataset, map_kwargs)
return dataset

@final
def get_dataset(self, n: int = -1, seed: int | None = None) -> Dataset:
if self.dataset is None:
raise ValueError("dataset is not set")
Expand All @@ -320,6 +321,7 @@ def get_dataset(self, n: int = -1, seed: int | None = None) -> Dataset:
return self.dataset.select(range(n))
return self.dataset

@final
def get_eval_dataset(self, n: int = -1, seed: int | None = None) -> Dataset:
if self.eval_dataset is None:
self.logger.warning(
Expand Down Expand Up @@ -565,6 +567,7 @@ async def get_model_response_with_tokens(
)
return response

@final
async def init_state(
self,
input: RolloutInput,
Expand Down Expand Up @@ -611,13 +614,6 @@ async def init_state(
)
return state

@abstractmethod
async def setup_state(self, state: State) -> State:
"""
Setup the state.
"""
return state

@abstractmethod
async def rollout(
self,
Expand Down Expand Up @@ -665,29 +661,17 @@ async def _render_timing(self, state: State):
state["timing"]["generation_ms"] = (end_time - start_time) * 1000
state["timing"]["total_ms"] = (end_time - start_time) * 1000

async def _render_completion(self, state: State):
    """Populate state["completion"] from the final trajectory step.

    Joins the last step's prompt and completion (plus any final environment
    response) and stores everything past the original prompt prefix.
    """
    if len(state["trajectory"]) == 0:
        # No steps taken: nothing to render.
        state["completion"] = []
        return
    last_prompt = state["trajectory"][-1]["prompt"]
    last_completion = state["trajectory"][-1]["completion"]
    full_conversation = concat_messages([last_prompt, last_completion])
    if state.get("final_env_response"):
        full_conversation = concat_messages(
            [full_conversation, state["final_env_response"]]
        )
    # Completion is the suffix of the conversation after the original prompt.
    state["completion"] = full_conversation[len(state["prompt"]) :]

@final
async def is_completed(self, state: State, **kwargs) -> bool:
    """Check all stop conditions. Sets state.is_completed=True if any condition is met.

    Returns:
        True as soon as any registered stop condition matches (after rendering
        timing/completion and running cleanup); False if none match.
    """
    for condition in self._stop_conditions:
        if await self._render_stop(state, condition):
            # First matching condition finalizes the rollout.
            await self._render_timing(state)
            await self._render_completion(state)
            await self._cleanup(state)
            return True
    return False

@final
async def run_rollout(
self,
sem: AsyncContextManager,
Expand All @@ -708,6 +692,7 @@ async def run_rollout(
)
return state

@final
async def run_group(
self,
group_inputs: list[RolloutInput],
Expand Down Expand Up @@ -974,7 +959,7 @@ def generate_sync(
executor.shutdown(wait=False)

# evaluation
def get_eval_inputs(
def _get_eval_inputs(
self, num_examples: int = -1, rollouts_per_example: int = 1
) -> List[RolloutInput]:
if self.eval_dataset is None:
Expand Down Expand Up @@ -1007,7 +992,7 @@ async def evaluate(
"""
Evaluate model on the Environment evaluation dataset.
"""
inputs = self.get_eval_inputs(num_examples, rollouts_per_example)
inputs = self._get_eval_inputs(num_examples, rollouts_per_example)
return await self.generate(
inputs,
client=client,
Expand Down Expand Up @@ -1041,7 +1026,7 @@ def evaluate_sync(
"""
Evaluate model on the Environment evaluation dataset synchronously.
"""
inputs = self.get_eval_inputs(num_examples, rollouts_per_example)
inputs = self._get_eval_inputs(num_examples, rollouts_per_example)
return self.generate_sync(
inputs,
client=client,
Expand All @@ -1056,6 +1041,7 @@ def evaluate_sync(
save_every=save_every,
)

# setters for use by trainers
def set_kwargs(self, **kwargs) -> None:
"""
Set environment attributes, using setter methods when available.
Expand Down
29 changes: 29 additions & 0 deletions verifiers/envs/experimental/rlm_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
from verifiers.types import Messages, ModelResponse, State, TrajectoryStep
from verifiers.utils.async_utils import maybe_await
from verifiers.utils.data_utils import extract_boxed_answer
from verifiers.utils.message_utils import concat_messages
from verifiers.utils.response_utils import (
parse_is_truncated,
parse_response_messages,
Expand Down Expand Up @@ -1724,6 +1725,34 @@ async def cleanup_rlm_state(self, state: State):
if (tunnel_url := state.get("tunnel_url")) and self._tunnel_pool:
await self._tunnel_pool.release_tunnel(tunnel_url)

async def render_completion(self, state: State):
    """Render completion from main model steps only, ignoring sub-LLM steps.

    Scans the trajectory backwards for the last step whose trajectory_id
    matches the main rollout's trajectory_id, then stores the conversation
    suffix past the original prompt into state["completion"].
    """
    if len(state["trajectory"]) == 0:
        state["completion"] = []
        return

    # Find the last trajectory step from the main model (matching trajectory_id);
    # steps from sub-LLM calls carry a different trajectory_id and are skipped.
    main_trajectory_id = state["trajectory_id"]
    last_main_step = None
    for step in reversed(state["trajectory"]):
        if step.get("trajectory_id") == main_trajectory_id:
            last_main_step = step
            break

    if last_main_step is None:
        # Every step belonged to a sub-LLM: nothing attributable to the main model.
        state["completion"] = []
        return

    last_prompt = last_main_step["prompt"]
    last_completion = last_main_step["completion"]
    full_conversation = concat_messages([last_prompt, last_completion])
    if state.get("final_env_response"):
        full_conversation = concat_messages(
            [full_conversation, state["final_env_response"]]
        )
    # Completion is the suffix of the conversation after the original prompt.
    state["completion"] = full_conversation[len(state["prompt"]) :]

async def post_rollout(self, state: State):
"""Read final answer from sandbox if not already set."""
await self._ensure_final_answer(state)
Loading
Loading