Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions tests/template/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ model:
max_response_tokens: 2048
max_model_len: 4096
cluster: # 2 for explorer, 2 for trainer
node_num: 2
gpu_per_node: 2
node_num: 1
gpu_per_node: 4
buffer:
total_epochs: 1
batch_size: 4
Expand Down
145 changes: 144 additions & 1 deletion tests/trainer/trainer_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Tests for trainer."""

import asyncio
import multiprocessing
import os
import shutil
Expand All @@ -9,6 +10,7 @@
from datetime import datetime
from unittest import mock

import httpx
import ray
from parameterized import parameterized_class

Expand All @@ -22,7 +24,8 @@
get_unittest_dataset_config,
get_vision_language_model_path,
)
from trinity.cli.launcher import bench, both, explore, run, train
from trinity.buffer import get_buffer_reader
from trinity.cli.launcher import bench, both, explore, run, serve, train
from trinity.common.config import (
AlgorithmConfig,
BufferConfig,
Expand All @@ -40,6 +43,7 @@
SyncStyle,
)
from trinity.common.models.utils import get_checkpoint_dir_with_step_num
from trinity.explorer.explorer_client import ExplorerClient
from trinity.manager.state_manager import StateManager


Expand Down Expand Up @@ -475,6 +479,19 @@ def run_both(config: Config) -> None:
both(config)


def run_serve(config: Config) -> None:
    """Process entry point: connect to Ray and start the explorer in serve mode.

    Ray is initialised in the namespace given by ``config`` and the logging
    directory / level are passed to Ray through the runtime environment
    variables before ``serve`` is invoked.

    Args:
        config (Config): The configuration handed to ``serve``.
    """
    logging_env = {
        LOG_DIR_ENV_VAR: config.log.save_dir,
        LOG_LEVEL_ENV_VAR: "INFO",
    }
    ray.init(
        namespace=config.ray_namespace,
        runtime_env={"env_vars": logging_env},
    )
    serve(config)


@parameterized_class(
("use_priority_queue", "strategy"),
[(False, "fsdp"), (True, "fsdp"), (True, "megatron")],
Expand Down Expand Up @@ -841,6 +858,132 @@ def tearDown(self):
shutil.rmtree(self.config.checkpoint_job_dir)


async def run_math_workflow(serve_url: str, task: dict):
    """Answer one math task through the served explorer and report its reward.

    The task's question is sent as a chat completion request to the first
    model listed by the service, the reply is scored against the task's
    answer with ``MathRewardFn`` and the summed reward is sent back through
    the explorer client.

    Args:
        serve_url (str): URL passed to ``ExplorerClient``.
        task (dict): Raw task with a ``"question"`` and an ``"answer"`` entry.
    """
    from trinity.common.rewards.math_reward import MathRewardFn

    client = ExplorerClient(serve_url)
    chat_api = client.get_openai_async_client()

    question, ground_truth = task["question"], task["answer"]
    score_fn = MathRewardFn()

    system_prompt = """A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e.,
<think> reasoning process here </think>
<answer> answer here </answer>.
"""
    conversation = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": question},
    ]

    # Use whichever model the service lists first.
    available = await chat_api.models.list()
    model_id = available.data[0].id

    completion = await chat_api.chat.completions.create(
        model=model_id,
        messages=conversation,
    )
    reply = completion.choices[0].message.content

    # The reward function returns a mapping of reward components; report their sum.
    scores = score_fn(response=reply, truth=ground_truth, prompt=question)
    total_reward = sum(scores.values())
    await client.feedback_async(total_reward)


class TestServeWithTrainer(unittest.IsolatedAsyncioTestCase):
    """Run a trainer process next to an explorer in ``serve`` mode and feed it data.

    The test starts a trainer and a serving explorer in separate processes,
    waits until the explorer publishes its server URL, then drives math tasks
    through the service so that the generated experiences reach the trainer.
    """

    def setUp(self):
        # Child processes are created with the "spawn" start method.
        if multiprocessing.get_start_method(allow_none=True) != "spawn":
            multiprocessing.set_start_method("spawn", force=True)
        checkpoint_path = get_checkpoint_path()
        # Start from a clean "unittest" project directory.
        shutil.rmtree(os.path.join(checkpoint_path, "unittest"), ignore_errors=True)

    @staticmethod
    def _stop_process(process, timeout: float = 30) -> None:
        """Terminate a started child process and reap it, killing it if it lingers."""
        if process.is_alive():
            process.terminate()
        process.join(timeout=timeout)
        if process.is_alive():
            process.kill()
            process.join()

    async def test_serve_with_trainer(self):
        config = get_template_config()
        config.project = "unittest"
        config.name = f"serve_with_trainer_{datetime.now().strftime('%Y%m%d%H%M%S')}"
        config.checkpoint_root_dir = get_checkpoint_path()
        config.model.model_path = get_model_path()
        config.buffer.batch_size = 4
        config.algorithm.algorithm_type = "ppo"
        config.algorithm.repeat_times = 1
        config.buffer.trainer_input.experience_buffer = StorageConfig(
            name="exp_buffer",
            storage_type=StorageType.SQL,
        )
        config.buffer.explorer_input.taskset = get_unittest_dataset_config("gsm8k")
        config.buffer.train_batch_size = 4
        config.trainer.total_steps = 4
        config.trainer.save_interval = 4
        config.synchronizer.sync_interval = 2
        config.synchronizer.sync_method = SyncMethod.CHECKPOINT
        config.explorer.rollout_model.engine_num = 2
        config.explorer.rollout_model.enable_openai_api = True
        config.explorer.rollout_model.tensor_parallel_size = 1
        config.explorer.service_status_check_interval = 10

        trainer_config = deepcopy(config)
        trainer_config.mode = "train"
        trainer_config.check_and_update()

        trainer_process = multiprocessing.Process(target=run_trainer, args=(trainer_config,))
        trainer_process.start()
        # Make sure the child never outlives the test, whatever happens below.
        self.addCleanup(self._stop_process, trainer_process)

        await asyncio.sleep(10)
        serve_config = deepcopy(config)
        serve_config.mode = "serve"
        serve_config.check_and_update()
        serve_process = multiprocessing.Process(target=run_serve, args=(serve_config,))
        serve_process.start()
        self.addCleanup(self._stop_process, serve_process)

        ray.init(ignore_reinit_error=True)
        while True:
            try:
                ray.get_actor("sql-exp_buffer", namespace=trainer_config.ray_namespace)
                break
            except ValueError:
                # Without this check a crashed trainer would make the loop spin forever.
                if not trainer_process.is_alive():
                    raise RuntimeError(
                        "Trainer process exited before the experience buffer was created."
                    )
                print("waiting for trainer to start.")
                await asyncio.sleep(5)

        state_manager = StateManager(
            path=serve_config.checkpoint_job_dir,
            explorer_name=serve_config.explorer.name,
        )

        # wait for explorer initialization
        server_url = None
        for _ in range(30):
            try:
                server_url = state_manager.load_explorer_server_url()
            except Exception:
                # The URL may not be available yet; retry on the next iteration.
                server_url = None
            if server_url:
                break
            await asyncio.sleep(3)
        if not server_url:
            raise RuntimeError("Explorer server URL not found.")
        # wait for server setup (best effort: the requests below fail if it never comes up)
        for _ in range(10):
            try:
                async with httpx.AsyncClient() as client:
                    response = await client.get(f"{server_url}/health")
                    if response.status_code == 200:
                        break
            except Exception:
                # Connection errors are expected while the server is still starting.
                pass
            await asyncio.sleep(2)

        reader = get_buffer_reader(serve_config.buffer.explorer_input.taskset, serve_config.buffer)

        for _ in range(2):
            # generate data for 2 trainer steps
            tasks = reader.read(batch_size=8)
            await asyncio.gather(*(run_math_workflow(server_url, task.raw_task) for task in tasks))

        # wait for synchronizer started
        await asyncio.sleep(config.explorer.service_status_check_interval)


class TestMultiModalGRPO(BaseTrainerCase):
@unittest.skip("Require specific vllm/transformers version")
def test_trainer(self):
Expand Down
3 changes: 3 additions & 0 deletions trinity/common/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,9 @@ def create_inference_models(
allocator = _BundleAllocator(node_bundle_map)
namespace = ray.get_runtime_context().namespace
# create rollout models
# in 'serve' mode, we always enable openai api for rollout model
if config.mode == "serve":
config.explorer.rollout_model.enable_openai_api = True
for i in range(config.explorer.rollout_model.engine_num):
bundles_for_engine = allocator.allocate(config.explorer.rollout_model.tensor_parallel_size)
config.explorer.rollout_model.bundle_indices = ",".join(
Expand Down
28 changes: 28 additions & 0 deletions trinity/common/models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,34 @@ def get_checkpoint_dir_with_step_num(
raise NotImplementedError(f"Unsupported trainer type {trainer_type}")


def get_latest_state_dict(
    checkpoint_root_path: str,
    trainer_type: str = "verl",
) -> Tuple[str, int]:
    """Get the latest state dict from a root checkpoint directory.

    The iteration is read from ``latest_state_dict_iteration.txt`` inside
    ``checkpoint_root_path``. The returned path is built from that iteration
    and is not checked for existence.

    Args:
        checkpoint_root_path (str): The root checkpoint directory.
        trainer_type (str): The trainer type; only ``"verl"`` is supported.

    Returns:
        Tuple[str, int]: The state dict path and the iteration of the state dict.
            If the iteration file is missing or empty, return (None, 0).

    Raises:
        NotImplementedError: If ``trainer_type`` is not ``"verl"``.
        ValueError: If the iteration file contains a non-integer value.
    """
    if trainer_type != "verl":
        raise NotImplementedError(f"Unsupported trainer type {trainer_type}")
    latest_state_dict_iteration_path = os.path.join(
        checkpoint_root_path, "latest_state_dict_iteration.txt"
    )
    # Open directly instead of checking existence first, so the file cannot
    # disappear between the check and the read.
    try:
        with open(latest_state_dict_iteration_path, "r", encoding="utf-8") as f:
            iteration = f.read().strip()
    except (FileNotFoundError, NotADirectoryError):
        return None, 0  # type: ignore
    if not iteration:
        # An empty file (e.g. one that is being rewritten) carries no iteration yet.
        return None, 0  # type: ignore
    # Validate the iteration before building the path from it.
    step = int(iteration)
    state_dict_path = os.path.join(checkpoint_root_path, f"global_step_{iteration}", "actor")
    return state_dict_path, step


def load_state_dict(checkpoint_dir: str, config: TrainerConfig) -> Union[dict, Tuple[str, str]]:
"""Load state dict from a checkpoint dir.

Expand Down
35 changes: 32 additions & 3 deletions trinity/explorer/api/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,21 @@ async def chat_completions(request: Request):
content=f"Error forwarding request to model at {url}: {traceback.format_exc()}",
)
resp_data = resp.json()
await request.app.state.service.record_experience(resp_data)
await request.app.state.service.record_experience(
resp_data, session_id=body.get("session_id", None)
)
return JSONResponse(content=resp_data)


@app.get("/v1/models")
async def show_available_models(request: Request):
    """Return the model list reported by a backing rollout model.

    The list is fetched from the ``/v1/models`` endpoint of a model URL
    provided by the service and, once fetched successfully, stored on
    ``app.state.models`` and served from there on later calls.
    """
    state = request.app.state
    if hasattr(state, "models"):
        return JSONResponse(content=state.models)
    # Only look up a model URL; do not count this as a served request.
    url = await state.service.allocate_model(increase_count=False)
    async with httpx.AsyncClient() as client:
        print(f"Fetching models from {url}/v1/models")
        # A GET request carries no body, so nothing from the incoming request is forwarded.
        resp = await client.get(f"{url}/v1/models")
    models = resp.json()
    if resp.status_code == 200:
        # Cache successful answers only, so a transient upstream error is not served forever.
        state.models = models
    return JSONResponse(status_code=resp.status_code, content=models)


Expand All @@ -52,6 +57,30 @@ async def metrics(request: Request):
return JSONResponse(content=metrics)


@app.get("/allocate")
async def allocate(request: Request):
    """Allocate a new session and return its id as ``{"session_id": ...}``."""
    service = request.app.state.service
    payload = {"session_id": service.allocate_session()}
    return JSONResponse(content=payload)


@app.post("/feedback")
async def feedback(request: Request):
    """Receive feedback for the current session.

    Expects a JSON object with an integer ``session_id`` and a numeric
    ``reward``. Invalid requests are answered with status 400; valid ones are
    passed to the service's ``record_feedback`` and acknowledged with
    ``{"status": "success"}``.
    """
    try:
        body = await request.json()
    except ValueError:
        # json.JSONDecodeError and UnicodeDecodeError are both ValueError subclasses.
        return JSONResponse(status_code=400, content={"error": "request body must be valid JSON"})
    if not isinstance(body, dict):
        return JSONResponse(
            status_code=400, content={"error": "request body must be a JSON object"}
        )
    session_id = body.get("session_id", None)
    reward = body.get("reward", None)
    if session_id is None or reward is None:
        return JSONResponse(
            status_code=400, content={"error": "session_id and reward are required"}
        )
    # bool is a subclass of int, so a JSON true/false must be rejected explicitly
    # as a session id.
    valid_session_id = isinstance(session_id, int) and not isinstance(session_id, bool)
    if not valid_session_id or not isinstance(reward, (int, float)):
        return JSONResponse(
            status_code=400, content={"error": "session_id must be int and reward must be float"}
        )
    await request.app.state.service.record_feedback(session_id, reward)
    return JSONResponse(content={"status": "success"})


async def serve_http(app: FastAPI, host: str, port: int = None):
config = uvicorn.Config(app, host=host, port=port)
server = uvicorn.Server(config)
Expand Down
Loading