Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions verifiers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
from .rubrics.rubric import Rubric # noqa # isort: skip
from .envs.environment import Environment # noqa # isort: skip
from .envs.multiturn_env import MultiTurnEnv # noqa # isort: skip
from .envs.agent import Agent # noqa # isort: skip
from .envs.protocol import Protocol, RoundRobinProtocol # noqa # isort: skip
from .envs.multiagent_env import MultiAgentEnv # noqa # isort: skip
from .envs.tool_env import ToolEnv # noqa # isort: skip
from .clients.client import Client # noqa # isort: skip
from .clients.anthropic_messages_client import AnthropicMessagesClient # noqa # isort: skip
Expand Down Expand Up @@ -71,8 +74,12 @@
"MCPEnv",
"BrowserEnv",
"OpenEnvEnv",
"Agent",
"Protocol",
"RoundRobinProtocol",
"Environment",
"MultiTurnEnv",
"MultiAgentEnv",
"SingleTurnEnv",
"PythonEnv",
"SandboxEnv",
Expand Down
31 changes: 30 additions & 1 deletion verifiers/clients/openai_chat_completions_token_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,13 +91,42 @@ async def get_prompt_ids(
that this method is called *before* making the model response from
prompt_messages, i.e. the previous turn's prompt and completion do not yet
include the environment response and next turn's model response.

For multi-agent environments, we find the previous turn for the CURRENT agent
(not just the global last turn) to correctly compute incremental tokens.
"""
prev_turn_tokens = state["trajectory"][-1]["tokens"]
# For multi-agent: find the last trajectory step for THIS agent
current_agent_id = state.get("extras", {}).get("current_agent_id")
prev_turn = None

if current_agent_id:
# Multi-agent: find last step for this specific agent
for step in reversed(state["trajectory"]):
if step.get("extras", {}).get("agent_id") == current_agent_id:
prev_turn = step
break

# Fallback to last step (single-agent or agent's first turn)
if prev_turn is None:
prev_turn = state["trajectory"][-1]

prev_turn_tokens = prev_turn["tokens"]
assert prev_turn_tokens is not None
prev_turn_prompt_ids = prev_turn_tokens["prompt_ids"]
prev_turn_completion_ids = prev_turn_tokens["completion_ids"]
prev_turn_ids = prev_turn_prompt_ids + prev_turn_completion_ids

# For multi-agent: if this agent has no previous turn, use full tokenization
if (
current_agent_id
and prev_turn.get("extras", {}).get("agent_id") != current_agent_id
):
return await self.tokenize(
messages=prompt_messages,
tools=oai_tools,
model=state["model"],
)

def compute_suffix_ids(lst: list[int], value: int) -> list[int]:
"""Returns all tokens after the last occurrence of `value` in `lst`, if any."""

Expand Down
35 changes: 35 additions & 0 deletions verifiers/envs/agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
"""
Agent: A participant in multi-agent environments.

Contains agent metadata (id, system prompt, trainability).
"""

from dataclasses import dataclass


@dataclass(eq=False)
class Agent:
    """
    An agent in a multi-agent environment.

    Identity is determined solely by ``id``: two Agent instances with the
    same id (even with different prompts or trainability) compare equal and
    hash identically, so they are interchangeable as dict keys/set members.

    Note: ``eq=False`` is set explicitly because equality/hashing are
    hand-written below; the dataclass default (eq=True) would only work by
    relying on @dataclass silently declining to overwrite dunders already
    defined in the class body.

    Fields:
        id: Unique identifier for this agent (e.g., "player_0", "guesser")
        system_prompt: The agent's specific instructions
        is_trainable: Whether to compute gradients for this agent's actions
    """

    id: str
    system_prompt: str = ""
    is_trainable: bool = True

    def __hash__(self) -> int:
        # Hash by id only, consistent with __eq__. Since `id` is mutable,
        # do not change it after the agent is used as a dict/set key.
        return hash(self.id)

    def __eq__(self, other: object) -> bool:
        # Equal iff the other object is an Agent with the same id;
        # system_prompt and is_trainable are intentionally ignored.
        if isinstance(other, Agent):
            return self.id == other.id
        return False

    def __repr__(self) -> str:
        # Compact debug form, e.g. Agent(id='player_0', trainable).
        trainable_str = "trainable" if self.is_trainable else "frozen"
        return f"Agent(id={self.id!r}, {trainable_str})"
235 changes: 235 additions & 0 deletions verifiers/envs/multiagent_env.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,235 @@
"""
Multi-agent environment for turn-based games.

This module provides the base class for multi-agent RL environments, extending
MultiTurnEnv with support for:
- Multiple agents with distinct system prompts
- Turn order management via Protocol
- Per-agent trajectory tagging for credit assignment

Key concepts:
- Agent: A participant with its own identity/prompt (defined in agent.py)
- Protocol: Defines turn order and interaction patterns (defined in protocol.py)

Environment Implementation:
- Subclasses must provide a Protocol for turn management
- Subclasses implement these main hooks:
- build_agent_prompt(agent_id, state): Build fresh prompt for this agent
- on_turn_complete(state): Update game state after each turn
"""

from abc import abstractmethod

import verifiers as vf
from verifiers.envs.agent import Agent
from verifiers.envs.multiturn_env import MultiTurnEnv
from verifiers.envs.protocol import Protocol
from verifiers.types import Messages, State, TrajectoryStep


class MultiAgentEnv(MultiTurnEnv):
    """
    Base class for multi-agent environments.

    Turn order is managed by a Protocol, which must be provided at init.
    This keeps turn logic reusable and separate from environment logic.

    Subclasses must implement:
    - build_agent_prompt(): Build prompt for current agent

    Subclasses may optionally override:
    - on_turn_complete(): Game logic after each turn
    """

    # Default list of agent IDs this environment uses
    # (e.g., ["player_0", "player_1"]). Subclasses should override this or
    # set in __init__. It is copied to a per-instance list in __init__ so
    # that register_agent() never mutates state shared across instances.
    agents: list[str] = []

    def __init__(self, protocol: Protocol, **kwargs):
        """
        Initialize multi-agent environment.

        Args:
            protocol: Protocol for turn order management.
            **kwargs: Passed to MultiTurnEnv
        """
        super().__init__(**kwargs)
        self._protocol = protocol
        self._agent_registry: dict[str, Agent] = {}
        # BUGFIX: `agents` defaults to a MUTABLE class attribute. Appending
        # to it in register_agent() would mutate the class-level list,
        # leaking agent IDs across every instance (and subclass) that
        # shares it. Shadow it with a per-instance copy so mutation stays
        # local to this environment.
        self.agents = list(self.agents)

    def register_agent(self, agent: Agent) -> None:
        """Register an Agent for lookup by get_agent()."""
        self._agent_registry[agent.id] = agent
        # Keep the ordered ID list in sync with the registry (no dupes).
        if agent.id not in self.agents:
            self.agents.append(agent.id)

    def get_agent(self, agent_id: str) -> Agent:
        """
        Get an agent by ID.

        Raises:
            KeyError: If the agent was never registered via register_agent().
        """
        if agent_id not in self._agent_registry:
            raise KeyError(
                f"Agent '{agent_id}' not found. Did you call register_agent()?"
            )
        return self._agent_registry[agent_id]

    # -------------------------------------------------------------------------
    # Turn Management (delegated to Protocol)
    # -------------------------------------------------------------------------

    def get_initial_agent(self, state: State) -> str:
        """Return the agent ID that starts the rollout."""
        return self._protocol.get_initial_agent(state)

    def get_next_agent(self, state: State) -> str:
        """Return the agent ID for the next turn."""
        return self._protocol.get_next_agent(state)

    # -------------------------------------------------------------------------
    # Agent Prompt Building (Subclasses Implement This)
    # -------------------------------------------------------------------------

    @abstractmethod
    async def build_agent_prompt(self, agent_id: str, state: State) -> Messages:
        """
        Build the prompt for the given agent's turn.

        This is called BEFORE the model generates a response.
        Build a fresh prompt with whatever context this agent needs.

        Args:
            agent_id: The agent who will respond (e.g., "player_0")
            state: Current game state with trajectory and extras

        Returns:
            Messages list with system prompt and user content
        """
        pass

    # -------------------------------------------------------------------------
    # Game Logic Hook
    # -------------------------------------------------------------------------

    async def on_turn_complete(self, state: State) -> None:
        """
        Update game state after a turn completes.

        This is called AFTER the model response is stored in trajectory.
        Use this for game logic:
        - Update scores, counters, flags
        - Check win conditions
        - Parse and validate actions

        The last turn's info is in state["trajectory"][-1]:
        - ["completion"][-1]["content"]: The model's response text
        - ["extras"]["agent_id"]: Which agent just responded

        Args:
            state: Current game state (mutate extras as needed)
        """
        pass

    # -------------------------------------------------------------------------
    # State Setup
    # -------------------------------------------------------------------------

    async def setup_state(self, state: State) -> State:
        """Initialize multi-agent state fields."""
        state = await super().setup_state(state)
        # Ensure "extras" exists even if the parent didn't create it;
        # current_agent_id is filled in by rollout() before the first turn.
        state["extras"] = state.get("extras", {})
        state["extras"]["current_agent_id"] = None
        return state

    # -------------------------------------------------------------------------
    # Parent Class Requirement (env_response)
    # -------------------------------------------------------------------------

    async def env_response(
        self, messages: Messages, state: State, **kwargs
    ) -> Messages:
        """
        Satisfy MultiTurnEnv's abstract requirement.

        MultiAgentEnv uses on_turn_complete() instead, which is called
        explicitly in our rollout() after storing the response.
        """
        return []

    # -------------------------------------------------------------------------
    # Trajectory Management
    # -------------------------------------------------------------------------

    async def add_trajectory_step(
        self, state: State, trajectory_step: TrajectoryStep
    ) -> None:
        """Tag trajectory step with agent_id and the agent's trainability."""
        current_agent_id = state["extras"].get("current_agent_id")
        if current_agent_id:
            trajectory_step["extras"]["agent_id"] = current_agent_id
            # Copy trainability from Agent to step so downstream credit
            # assignment can skip frozen agents. Raises KeyError if the
            # agent was never registered (fail fast on misconfiguration).
            agent = self.get_agent(current_agent_id)
            trajectory_step["extras"]["is_trainable"] = agent.is_trainable
        await super().add_trajectory_step(state, trajectory_step)

    # -------------------------------------------------------------------------
    # Main Rollout Loop
    # -------------------------------------------------------------------------

    async def rollout(
        self,
        input,
        client,
        model,
        sampling_args=None,
    ) -> State:
        """
        Run a multi-agent episode.

        Flow:
        1. Setup state
        2. Loop until game ends:
            a. Determine current agent
            b. Build prompt via build_agent_prompt()
            c. Get model response
            d. Store in trajectory
            e. Process via on_turn_complete()
        3. Return final state

        Returns:
            Final state; on failure, state["error"] holds the vf.Error and
            the partial trajectory is preserved.
        """
        state = await self.init_state(input, client, model, sampling_args)
        try:
            state = await self.setup_state(state)
        except vf.Error as e:
            # Setup failures abort the episode before any turn runs.
            state["error"] = e
            return state

        # Determine first agent
        state["extras"]["current_agent_id"] = self.get_initial_agent(state)

        while not await self.is_completed(state):
            agent_id = state["extras"]["current_agent_id"]

            try:
                # 1. Build prompt for this agent
                prompt_messages = await self.build_agent_prompt(agent_id, state)

                # 2. Get model response
                response = await self.get_model_response(state, prompt_messages)

                # 3. Store in trajectory (tags with agent_id)
                await self.add_model_response(state, prompt_messages, response)

                # 4. Process turn (game logic)
                await self.on_turn_complete(state)

                # 5. Determine next agent (if game continues)
                if not await self.is_completed(state):
                    state["extras"]["current_agent_id"] = self.get_next_agent(state)

            except vf.OverlongPromptError:
                # Prompt exceeded the model's context window: mark and
                # end the episode as truncated rather than erroring.
                state["prompt_too_long"] = True
                state["is_truncated"] = True
                break
            except vf.Error as e:
                state["error"] = e
                break

        await self.render_completion(state)
        return state
Loading