Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 89 additions & 0 deletions examples/vimgolf/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
# VimGolf Training Examples

This directory contains examples for training VimGolf agent models with the RLLM framework. The training pipeline uses 612 public VimGolf challenges and the VimGolf validator to check agent solutions.

Vim must be installed on your local machine.

Our examples use the following:

- **deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B** as the base model
- **VimGolf Public Challenges** for training data

## Dataset Preparation

First prepare the dataset:

```bash
cd examples/vimgolf
python prepare_vimgolf_data.py
```

This registers a dataset named `vimgolf-public-challenges` in the `DatasetRegistry`.


## Model Hosting

### Option 1: Using vLLM

Start a vLLM server with OpenAI-compatible API:

```bash
python -m vllm.entrypoints.openai.api_server \
--model "<model_saved_path>" \
--host 0.0.0.0 \
--port 30000 \
--dtype bfloat16
```

### Option 2: Using SGLang

```bash
python -m sglang_router.launch_server \
--model-path "<model_saved_path>" \
--dp-size 1 \
--dtype bfloat16
# increase dp_size to enable data-parallel processing on multi-GPU
```

The server should be accessible at `http://localhost:30000/v1`

## Training

Install dependencies:

```bash
pip install -r requirements.txt
```

Run training with the `vimgolf-public-challenges` dataset:

```bash
bash train_vimgolf_agent.sh
```

**Configuration Options:**
You can modify the training script parameters:
- `actor_rollout_ref.model.path`: Base model to train
- `trainer.total_epochs`: Number of training epochs
- `data.train_batch_size`: Total batch size across all GPUs
- `data.micro_batch_size_per_gpu`: Batch size per GPU
- `data.max_prompt_length`: Maximum prompt length
- `data.max_response_length`: Maximum response length

The training script will:
- Load the base model (deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B)
- Fine-tune on the `vimgolf-public-challenges` dataset
- Save checkpoints to `checkpoints/${trainer.project_name}/${trainer.experiment_name}`

## Evaluation

Before running the evaluation, host the trained model so that it is reachable at `http://localhost:30000/v1` (see "Model Hosting" above).

Evaluate the trained model using the saved checkpoint:

```bash
cd examples/vimgolf
python run_vimgolf.py --model_name "<model_saved_path>"
```

Replace `<model_saved_path>` with the actual path to your trained model checkpoint, usually at `checkpoints/${trainer.project_name}/${trainer.experiment_name}`.
127 changes: 127 additions & 0 deletions examples/vimgolf/lib.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
"""VimGolf agents and environments implementation."""

import copy
import json
from typing import Any

import vimgolf_gym
import vimgolf_gym.dataclasses

from rllm.agents.agent import Action, BaseAgent, Step, Trajectory
from rllm.environments import SingleTurnEnvironment
from rllm.rewards import RewardOutput


class VimGolfSingleTurnAgent(BaseAgent):
    """Agent that answers a VimGolf challenge in one model turn.

    The agent keeps a plain chat history (``messages``) and records one
    ``Step`` per model response in its trajectory.
    """

    def __init__(self, accumulate_thinking=True):
        """Create an agent with an empty chat history and trajectory.

        Args:
            accumulate_thinking: When false, ``chat_completions`` drops the
                text up to and including the first ``</think>`` marker from
                assistant messages, except for the last entry of the history.
        """
        self.accumulate_thinking = accumulate_thinking
        self.messages = []
        self._trajectory = Trajectory()

    def update_from_env(self, observation: Any, reward: float, done: bool, info: dict, **kwargs):
        """Append the next user message derived from the environment.

        ``reward``, ``done``, ``info`` and any extra keyword arguments are
        accepted but not used by this agent.
        """
        if self.trajectory.steps:
            # A step already exists, so this would be a follow-up turn. The
            # single-turn setup is not expected to reach this branch; it is a
            # placeholder for a future multi-turn agent.
            user_content = "Your previous answer may contain a mistake. Please review it carefully and answer again."
        else:
            # First observation of the episode: it must carry the challenge prompt.
            assert isinstance(observation, dict) and "question" in observation
            user_content = observation["question"]

        self.messages.append({"role": "user", "content": user_content})

    def update_from_model(self, response: str, **kwargs) -> Action:
        """Record the model response and return it unchanged as the action.

        The response is appended to the chat history and a new ``Step``
        holding a snapshot of the current chat completions is added to the
        trajectory.
        """
        self.messages.append({"role": "assistant", "content": response})
        snapshot = copy.deepcopy(self.chat_completions)
        self.trajectory.steps.append(Step(chat_completions=snapshot))
        return Action(action=response)

    def reset(self):
        """Discard the chat history and start a fresh trajectory."""
        self.messages = []
        self._trajectory = Trajectory()

    @property
    def chat_completions(self) -> list[dict[str, str]]:
        """Return a deep copy of the chat history to send to the model.

        With ``accumulate_thinking`` disabled, every assistant message except
        the final history entry has its content replaced by the text that
        follows the first ``</think>`` marker (messages without the marker are
        left untouched).
        """
        history = copy.deepcopy(self.messages)
        if self.accumulate_thinking:
            return history

        for entry in history[:-1]:
            if entry["role"] != "assistant":
                continue
            _, marker, remainder = entry["content"].partition("</think>")
            if marker:
                entry["content"] = remainder
        return history

    @property
    def trajectory(self) -> Trajectory:
        """Return the trajectory collected so far in this episode."""
        return self._trajectory

    def get_current_state(self) -> Step:
        """Return the most recently recorded step.

        Raises:
            AssertionError: If no step has been recorded yet.
        """
        steps = self._trajectory.steps
        assert steps, "Trajectory should not be empty when get_current_state is called."
        return steps[-1]


def vimgolf_reward_function(task_info: dict, action: str) -> RewardOutput:
    """Score a model response against a VimGolf challenge.

    The last non-empty line of ``action`` is taken as the keystroke solution
    and checked with ``run_vimgolf_local``.

    Args:
        task_info: Task dictionary whose ``"ground_truth"`` entry is a JSON
            string describing the challenge (input text, expected output and
            challenge identifier).
        action: Full text of the model response.

    Returns:
        ``RewardOutput`` with reward ``1.0`` and ``is_correct=True`` when the
        solution verifies, otherwise reward ``0.0`` and ``is_correct=False``.

    Raises:
        KeyError: If the ground truth lacks the input, the expected output or
            the challenge identifier under any accepted key.
    """
    task_data = json.loads(task_info.get("ground_truth"))

    # prepare_vimgolf_data.py serialises the challenge with the keys
    # "output" / "challenge_id", while the raw dataset rows use
    # "target" / "id". Accept both spellings so the reward works with the
    # registered dataset instead of failing with KeyError on every task.
    start_text = task_data["input"]
    expected_text = task_data["output"] if "output" in task_data else task_data["target"]
    challenge_id = task_data["challenge_id"] if "challenge_id" in task_data else task_data["id"]

    solution = get_last_non_empty_line(action)
    custom_challenge = vimgolf_gym.dataclasses.VimGolfCustomChallenge(
        input=start_text,
        output=expected_text,
        solution=solution,
        name=challenge_id,
    )

    is_correct = bool(run_vimgolf_local(custom_challenge))
    return RewardOutput(reward=1.0 if is_correct else 0.0, is_correct=is_correct, metadata={})


def run_vimgolf_local(custom_challenge: vimgolf_gym.dataclasses.VimGolfCustomChallenge):
    """Verify the challenge's solution keystrokes in a local VimGolf environment.

    A ``"vimgolf-custom"`` environment is created for the challenge. When the
    challenge carries a non-empty solution, the result of
    ``env.verify_keys`` is returned; otherwise the result is ``False``.
    """
    with vimgolf_gym.make("vimgolf-custom", custom_challenge=custom_challenge) as env:
        # An empty solution is never submitted for verification.
        outcome = env.verify_keys(custom_challenge.solution) if custom_challenge.solution else False
    return outcome


def get_last_non_empty_line(content: str):
    """Return the last line of ``content`` that is not blank, stripped.

    Lines consisting only of whitespace are ignored. An empty string is
    returned when ``content`` has no non-blank line.
    """
    # Walk backwards so the first hit is the last meaningful line.
    for raw_line in reversed(content.splitlines()):
        candidate = raw_line.strip()
        if candidate:
            return candidate
    return ""


class VimGolfSingleTurnEnv(SingleTurnEnvironment):
    """Single-turn environment scored by ``vimgolf_reward_function``."""

    def __init__(self, task=None, reward_fn=None, **kwargs):
        """Build the environment for ``task``.

        Args:
            task: Task passed through to ``SingleTurnEnvironment``.
            reward_fn: Accepted for signature compatibility but not used;
                ``vimgolf_reward_function`` is always passed to the base class.
            **kwargs: Forwarded unchanged to ``SingleTurnEnvironment``.
        """
        super().__init__(reward_fn=vimgolf_reward_function, task=task, **kwargs)
81 changes: 81 additions & 0 deletions examples/vimgolf/prepare_vimgolf_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
import json
import os

from datasets import Dataset

from rllm.data.dataset import DatasetRegistry

# Path of the JSON Lines file with the challenges, expected to sit next to this script.
_DATASET_PATH = os.path.join(os.path.dirname(__file__), "vimgolf_public_challenges.jsonl")


def prepare_vimgolf_data():
    """Build the VimGolf training dataset and register it.

    Every non-blank line of the JSON Lines file at ``_DATASET_PATH`` is turned
    into one record holding the rendered prompt (``question``), the challenge
    serialised as JSON (``ground_truth``) and the ``data_source`` tag. The
    records are registered as the ``train`` split of
    ``vimgolf-public-challenges``.

    Even if all the data is used for training, the game is not exhausted,
    since the model can always be asked to use fewer keystrokes.

    Returns:
        The value returned by ``DatasetRegistry.register_dataset`` for the
        registered split, so callers (such as the script entry point) can
        inspect it.
    """
    datalist = []
    # The challenges contain arbitrary text; read them as UTF-8 instead of
    # relying on the platform's default encoding.
    with open(_DATASET_PATH, encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                # Tolerate blank lines (e.g. a trailing newline at end of file).
                continue
            line_data = json.loads(line)
            input_text = line_data["input"]
            target = line_data["target"]
            details = line_data["metadata"]["detail"]
            challenge_data = dict(
                input=input_text,
                output=target,
                challenge_id=line_data["id"],
            )
            question_prompt = f"""
Vimgolf is a game where you try to transform text using the fewest number of keystrokes in Vim.

Your task is to solve the following Vimgolf challenge with details:

Details:

{details}

The input file wrapped in triple backticks:

```
{input_text}
```

The output file wrapped in triple backticks:

```
{target}
```

Your keystokes must be less than the length of output file. Do not naively copy and paste the output file. You must use Vim commands to transform the input file into the output file.

Here are some example solutions, for format demostration (all solutions shall be in one line):

iHello World<Esc>:wq<NL>

:%s/abcdef/defabc/g<NL>:wq<NL>

Your last line of response will be treated as solution. Do not wrap the solution around any marker (like triple backticks), just write it in plain style. Do not write it in multiline style. Do not write any comment or explanation. Do not write any other text. Just write the solution. If your solution contains multiple steps, you will concatenate these steps into one line, optionally using <NL> as separator, depending on the situation.

Example response:

I think the following solution is optimal:

iHello World<Esc>:s/World/Earth/g<NL>:wq<NL>

Please write your solution according to the rules and the example response:
"""
            datalist.append(
                {
                    "question": question_prompt,
                    "ground_truth": json.dumps(challenge_data),
                    "data_source": "vimgolf-public-challenges",
                }
            )

    train_dataset = Dataset.from_list(datalist)
    # Hand the registered dataset back; previously the result was dropped and
    # the function implicitly returned None.
    return DatasetRegistry.register_dataset(name="vimgolf-public-challenges", data=train_dataset, split="train")


if __name__ == "__main__":
    # Script entry point: build and register the dataset, then print
    # whatever prepare_vimgolf_data() returns.
    print(prepare_vimgolf_data())
3 changes: 3 additions & 0 deletions examples/vimgolf/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
vimgolf-gym==0.1.1
hydra-core
omegaconf
57 changes: 57 additions & 0 deletions examples/vimgolf/run_vimgolf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import asyncio

from lib import VimGolfSingleTurnAgent, VimGolfSingleTurnEnv
from transformers import AutoTokenizer

from rllm.data.dataset import DatasetRegistry
from rllm.engine.agent_execution_engine import AgentExecutionEngine
from rllm.utils import compute_pass_at_k


def load_vimgolf_data():
    """Return the data of the registered ``vimgolf-public-challenges`` train split.

    Raises:
        ValueError: If the dataset has not been registered yet.
    """
    if not DatasetRegistry.dataset_exists(name="vimgolf-public-challenges", split="train"):
        raise ValueError("vimgolf-public-challenges dataset not found. Please run `python prepare_vimgolf_data.py` to create the dataset.")

    registered = DatasetRegistry.load_dataset(name="vimgolf-public-challenges", split="train")
    return registered.get_data()


if __name__ == "__main__":
    import argparse
    import os

    # Set before the tokenizer is created below.
    os.environ["TOKENIZERS_PARALLELISM"] = "true"

    cli = argparse.ArgumentParser()
    cli.add_argument("--model_name", type=str, required=True, help="Model name to be used for evaluation")
    model_name = cli.parse_args().model_name  # supplied on the command line

    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # The model is expected to be served behind an OpenAI-compatible endpoint
    # on localhost:30000. The "None" API key is presumably a placeholder for a
    # server that does not check keys -- confirm for your deployment.
    rollout_engine_args = {
        "base_url": "http://localhost:30000/v1",
        "api_key": "None",
    }

    engine = AgentExecutionEngine(
        agent_class=VimGolfSingleTurnAgent,
        env_class=VimGolfSingleTurnEnv,
        agent_args={},
        env_args={},
        engine_name="openai",
        tokenizer=tokenizer,
        sampling_params={"temperature": 1, "model": model_name},
        rollout_engine_args=rollout_engine_args,
        n_parallel_agents=48,
        max_response_length=65536,
        max_prompt_length=4096,
    )

    # Run every registered challenge through the engine and report pass@k.
    results = asyncio.run(engine.execute_tasks(load_vimgolf_data()))
    compute_pass_at_k(results)
26 changes: 26 additions & 0 deletions examples/vimgolf/train_vimgolf_agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import hydra
from lib import VimGolfSingleTurnAgent, VimGolfSingleTurnEnv

from rllm.data import DatasetRegistry
from rllm.trainer.agent_trainer import AgentTrainer


@hydra.main(
    config_path="pkg://rllm.trainer.config",
    config_name="ppo_trainer",
    version_base=None,
)
def main(config):
    """Train the single-turn VimGolf agent on the registered challenge dataset.

    Args:
        config: Hydra configuration composed from ``ppo_trainer`` in
            ``rllm.trainer.config`` plus any command-line overrides.

    Raises:
        ValueError: If the ``vimgolf-public-challenges`` train split could not
            be loaded from the registry.
    """
    dataset = DatasetRegistry.load_dataset(name="vimgolf-public-challenges", split="train")
    if dataset is None:
        # Fail early with the same guidance the evaluation script gives,
        # rather than letting the trainer fail later on a missing dataset.
        raise ValueError("vimgolf-public-challenges dataset not found. Please run `python prepare_vimgolf_data.py` to create the dataset.")

    trainer = AgentTrainer(
        agent_class=VimGolfSingleTurnAgent,
        env_class=VimGolfSingleTurnEnv,
        config=config,
        train_dataset=dataset,
    )
    trainer.train()


if __name__ == "__main__":
    main()
Loading