Skip to content

Commit 6562aab

Browse files
Merge pull request #205 from JasonWei05/v0.2_tb_simplified
Terminal Bench Integration into rLLM (Simplified)
2 parents f51403f + 5e3e47f commit 6562aab

File tree

5 files changed

+461
-0
lines changed

5 files changed

+461
-0
lines changed

examples/terminal/README.md

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
### Terminal-Bench example
2+
3+
- Requires Python >= 3.12
4+
- Set the `OPENAI_API_KEY` environment variable
5+
- Install Terminal-Bench:
6+
7+
```bash
8+
pip install terminal-bench
9+
```
10+
11+
After installing, you can run the sample script in this folder to evaluate o4-mini on the terminal-bench-core v0.1.1 dataset with Terminal-Bench's Terminus 1 agent.
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
from __future__ import annotations
2+
3+
from pathlib import Path
4+
from typing import Any
5+
6+
import yaml
7+
from terminal_bench.dataset.dataset import Dataset
8+
9+
10+
def load_terminal_bench_dataset(
    dataset_name: str,
    dataset_version: str = "head",
    task_ids: list[str] | None = None,
    n_tasks: int | None = None,
    local_registry_path: Path | None = None,
    registry_url: str | None = None,
    exclude_task_ids: list[str] | None = None,
) -> list[dict[str, Any]]:
    """Load a Terminal-Bench dataset and convert it to minimal rLLM task dicts.

    Args:
        dataset_name: Dataset registry name.
        dataset_version: Concrete version or "head".
        task_ids: Optional subset of task IDs (supports glob patterns).
        n_tasks: Optional cap on number of tasks.
        local_registry_path: Optional path to a local registry file.
        registry_url: Optional registry URL.
        exclude_task_ids: Optional list of task IDs (glob patterns) to exclude.

    Returns:
        List[Dict[str, Any]]: Each dict includes ``task_path``, ``task_id``, and ``instruction``.
    """
    dataset = Dataset(
        name=dataset_name,
        version=dataset_version,
        task_ids=task_ids,
        n_tasks=n_tasks,
        exclude_task_ids=exclude_task_ids or [],
        local_registry_path=local_registry_path,
        registry_url=registry_url,
    )

    # Iterating the dataset yields one Path per task directory; each task's
    # instruction comes from its task.yaml.
    return [
        {
            "task_path": str(task_path),
            "task_id": task_path.name,
            "instruction": load_task_config(task_path)["instruction"],
        }
        for task_path in dataset
    ]
55+
56+
57+
def load_task_config(task_path: Path) -> dict[str, Any]:
    """Load and validate the task configuration from a task.yaml file.

    Args:
        task_path: Path to a Terminal-Bench task directory.

    Returns:
        Dict[str, Any]: Parsed YAML mapping.

    Raises:
        FileNotFoundError: If ``task.yaml`` does not exist under ``task_path``.
        ValueError: If the YAML document is not a mapping (e.g. empty file)
            or a required field is missing.
    """
    task_yaml_path = task_path / "task.yaml"

    if not task_yaml_path.exists():
        raise FileNotFoundError(f"task.yaml not found at {task_yaml_path}")

    # Explicit encoding so parsing does not depend on the host locale.
    with open(task_yaml_path, encoding="utf-8") as f:
        config = yaml.safe_load(f)

    # safe_load returns None for an empty file and a scalar for non-mapping
    # documents; fail with a clear message instead of a TypeError on the
    # membership test below.
    if not isinstance(config, dict):
        raise ValueError(f"Expected a YAML mapping in {task_yaml_path}, got {type(config).__name__}")

    required_fields = ["instruction"]
    for field in required_fields:
        if field not in config:
            raise ValueError(f"Missing required field '{field}' in {task_yaml_path}")

    return config

examples/terminal/run_terminus.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
import asyncio
2+
import os
3+
4+
from prepare_terminal_data import load_terminal_bench_dataset
5+
from terminus_workflow import TerminalTerminusWorkflow
6+
7+
from rllm.engine.agent_workflow_engine import AgentWorkflowEngine
8+
from rllm.engine.rollout.openai_engine import OpenAIEngine
9+
10+
11+
async def main():
    """Evaluate o4-mini on terminal-bench-core v0.1.1 with the Terminus 1 agent.

    Loads the dataset, runs every task through an AgentWorkflowEngine backed by
    an OpenAI rollout engine, and prints the overall accuracy.
    """
    os.environ["TOKENIZERS_PARALLELISM"] = "true"

    # Dataset selection (matches v0.2_tb style)
    dataset_name = "terminal-bench-core"
    dataset_version = "0.1.1"

    model_name = "o4-mini"
    rollout_engine = OpenAIEngine(model=model_name)

    max_steps = 50
    global_agent_timeout_sec = 600.0
    workflow_engine = AgentWorkflowEngine(
        workflow_cls=TerminalTerminusWorkflow,
        workflow_args={
            "model_name": model_name,
            "env_args": {
                "cleanup": True,
            },
            "max_steps": max_steps,
            "global_agent_timeout_sec": global_agent_timeout_sec,
        },
        rollout_engine=rollout_engine,
        n_parallel_tasks=1,
        retry_limit=1,  # TB already retries inside the agent loop
    )

    await workflow_engine.initialize_pool()

    # Load dataset
    tasks = load_terminal_bench_dataset(
        dataset_name=dataset_name,
        dataset_version=dataset_version,
    )
    print(f"Loaded {len(tasks)} tasks from {dataset_name} {dataset_version}")

    # Execute all tasks
    episodes = await workflow_engine.execute_tasks(tasks=tasks)

    total = len(episodes)
    correct = sum(ep.is_correct for ep in episodes)
    # Guard against an empty episode list to avoid ZeroDivisionError.
    accuracy = correct / total if total else 0.0
    print(f"Accuracy: {correct}/{total} = {accuracy:.3f}")


if __name__ == "__main__":
    asyncio.run(main())
Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,158 @@
1+
from pathlib import Path
2+
3+
from terminal_bench.handlers.trial_handler import TrialHandler
4+
from terminal_bench.parsers.base_parser import UnitTestStatus
5+
from terminal_bench.parsers.parser_factory import ParserFactory
6+
from terminal_bench.terminal.docker_compose_manager import DockerComposeManager
7+
from terminal_bench.terminal.terminal import Terminal
8+
9+
from rllm.agents.agent import Episode
10+
from rllm.integrations.terminal_terminus_1 import RLLMModel
11+
from rllm.workflows.workflow import TerminationEvent, TerminationReason, Workflow
12+
13+
14+
class TerminalTerminusWorkflow(Workflow):
    """Run Terminus 1 with a generic rollout engine and return an Episode.

    Each ``run()`` call spins up the task's Docker containers, drives the
    Terminus agent loop against a tmux session, runs the task's unit tests to
    compute a binary reward, and tears the containers down again.
    """

    def __init__(
        self,
        rollout_engine,
        model_name: str,
        env_args: dict | None = None,
        max_steps: int = 50,
        global_agent_timeout_sec: float | None = 600.0,
        **kwargs,
    ):
        """Store configuration; no containers are started until ``run()``.

        Args:
            rollout_engine: Engine used by the agent for LLM completions.
            model_name: Model identifier passed through to ``RLLMModel``.
            env_args: Optional terminal options (``cleanup``, ``no_rebuild``,
                ``api_base``).
            max_steps: Maximum number of agent episodes/steps.
            global_agent_timeout_sec: Overall wall-clock budget for the agent
                loop, or ``None`` for no limit.
            **kwargs: Forwarded to the ``Workflow`` base class.
        """
        super().__init__(rollout_engine=rollout_engine, **kwargs)
        self.model_name = model_name
        # Copy so later caller-side mutation cannot change our configuration.
        self.env_args = dict(env_args) if env_args is not None else {}
        self.max_steps = max_steps
        self.global_agent_timeout_sec = global_agent_timeout_sec

        # Per-task state, (re)populated by _reset_env() on each run.
        self.trial_handler: TrialHandler | None = None
        self.terminal: Terminal | None = None
        self.session = None
        self.parser = None
        self.terminus: RLLMModel | None = None

    async def run(self, task: dict, uid: str, **kwargs) -> Episode:
        """Reset, run Terminus to completion, evaluate, and package an Episode."""
        observation, info = await self.run_in_executor(self._reset_env, task=task, uid=uid)

        prompt = observation["prompt"]
        assert self.session is not None and self.terminus is not None

        trajectory, termination_reason = await self.terminus.run_agent_loop_with_engine(
            initial_prompt=prompt,
            session=self.session,
        )

        # Always tear the containers down, even if evaluation raises.
        try:
            reward = await self.run_in_executor(self._evaluate_completion_sync)
        finally:
            await self.run_in_executor(self._close_env)

        episode = Episode(id=uid, task=task, is_correct=bool(reward > 0), trajectories=[("terminus", trajectory)])
        episode.termination_reason = termination_reason
        return episode

    async def _eval_and_terminate(self) -> None:
        """Evaluate the current state, clean up, and signal env termination."""
        try:
            await self.run_in_executor(self._evaluate_completion_sync)
        finally:
            await self.run_in_executor(self._close_env)
        raise TerminationEvent(TerminationReason.ENV_DONE)

    # ------------------------------ Sync helpers ------------------------------
    def _reset_env(self, task: dict, uid: str):
        """Create trial, start containers and session, and build initial prompt.

        Args:
            task: Task dict with ``task_path``, ``instruction``, ``task_id``.
            uid: Unique identifier used to name this trial.

        Returns:
            Tuple of (observation dict with the initial prompt, info dict).
        """
        output_path = Path("/tmp/rllm_terminal_bench_output")
        output_path.mkdir(parents=True, exist_ok=True)

        task_path = Path(task.get("task_path"))
        instruction = task.get("instruction")
        task_id = task.get("task_id", "unknown")

        self.trial_handler = TrialHandler(
            trial_name=f"{task_id}.{uid}.rllm-run",
            input_path=task_path,
            output_path=output_path,
        )

        task_config = self.trial_handler.task
        self.parser = ParserFactory.get_parser(task_config.parser_name)

        self.terminal = Terminal(
            client_container_name=self.trial_handler.client_container_name,
            client_image_name=self.trial_handler.client_image_name,
            docker_compose_path=self.trial_handler.task_paths.docker_compose_path,
            docker_image_name_prefix=self.trial_handler.docker_image_name_prefix,
            sessions_logs_path=self.trial_handler.trial_paths.sessions_path,
            agent_logs_path=self.trial_handler.trial_paths.agent_logging_dir,
            no_rebuild=self.env_args.get("no_rebuild", False),
            cleanup=self.env_args.get("cleanup", True),
        )
        self.terminal.start()
        self.session = self.terminal.create_session("agent", is_active_stream=False, as_configured_user=True)

        self.terminus = RLLMModel(
            rollout_engine=self.rollout_engine,
            model_name=self.model_name,
            max_episodes=self.max_steps,
            global_agent_timeout_sec=self.global_agent_timeout_sec,
            api_base=self.env_args.get("api_base"),
        )

        # Seed the prompt with the task instruction plus the current pane state.
        initial_prompt = self.terminus.build_initial_prompt(instruction=instruction, terminal_state=self.session.capture_pane())

        observation = {"prompt": initial_prompt, "type": "initial"}
        info = {
            "task_id": task_id,
            "episode": 0,
            "max_steps": self.max_steps,
            "instruction": instruction,
        }
        return observation, info

    def _evaluate_completion_sync(self) -> float:
        """Copy tests, run them, parse output, and return a binary reward."""
        assert self.trial_handler is not None and self.terminal is not None

        # Copy tests into the container
        paths = [self.trial_handler.task_paths.run_tests_path]
        if self.trial_handler.task_paths.test_dir.exists():
            paths.append(self.trial_handler.task_paths.test_dir)
        self.terminal.copy_to_container(
            paths=paths,
            container_dir=str(DockerComposeManager.CONTAINER_TEST_DIR),
        )

        # Choose session per config: some tasks require the tests to observe
        # the agent's shell state, others run in a fresh root session.
        if self.trial_handler.task.run_tests_in_same_shell:
            test_session = self.session
        else:
            test_session = self.terminal.create_session("tests", is_active_stream=False, as_configured_user=False)

        # Execute tests
        test_script_path = str(DockerComposeManager.CONTAINER_TEST_DIR / "run-tests.sh")
        try:
            test_session.send_keys(
                [f"bash {test_script_path}", "Enter"],
                block=True,
                max_timeout_sec=self.trial_handler.task.max_test_timeout_sec,
            )
            test_output = test_session.capture_pane(capture_entire=True)
            parser_results = self.parser.parse(test_output)

            # Reward 1.0 only if the parser found results and all tests passed.
            all_passed = parser_results and all(status == UnitTestStatus.PASSED for status in parser_results.values())
        except Exception:
            # Best-effort: any failure to run or parse the tests counts as 0.
            all_passed = False

        return 1.0 if all_passed else 0.0

    def _close_env(self):
        """Stop/cleanup terminal containers if present."""
        if self.terminal:
            self.terminal.stop()

0 commit comments

Comments
 (0)