Description
Posting on behalf of @zw0610
Contemporary Train Pattern
In verl.trainer.ppo.ray_trainer.RayTrainer, the fit(self) method follows this pattern:
- Fetch a batch (from a parquet file): ray_trainer.py#L826
- Generate responses as actions: ray_trainer.py#L849
- Compute advantages
- Critic loss + critic update
- Actor loss + actor update
- Update rollout
This train pattern is incompatible with Agent/Env RL training, for the reasons illustrated below.
Agent/Env
An agent can be defined in many ways. Here, we take an agent to be a proxy with certain capabilities that, with help from an LLM inference server, interacts with some environment (game, GitHub code issue, etc.) to complete tasks.
Introducing agents brings prompt generation in RLHF closer to traditional RL training and expands the range of cases that LLMs can handle.
Train with Agent/Env
There are two main conflicts between training with Agent and the contemporary train pattern:
- Separation of train process (the driver process in our single-controller paradigm) and prompt data
Unlike prompts fetched directly from offline data files, prompt data is now generated in Agents, which usually run as separate processes.
To address this issue, we introduce the class MockedServer, which collects prompts and works as a special Dataset, requiring minimal change to the existing train pattern.
- Dynamic prompt/train batch generation
If the Agents/Envs require multiple rounds to get the final score, a traditional Dataset with fixed content and length cannot be used if we wish to overlap the generation and train stages. (The long-tail issue becomes more severe with multi-round agents/envs.)
To address this issue, we introduce the class BufferDataset, which serves as a pipe between the generation and train stages while still maintaining a single-controller paradigm in the driver script.
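As a rough illustration of the "pipe" idea, a buffer-backed dataset can be pictured as an iterable view over a thread-safe queue. The sketch below is an assumption about the general shape, not the implementation in the PR: it uses only the standard library, and the names `QueueBackedDataset`, `put`, `close`, and `batched` are hypothetical. A real version would presumably subclass a PyTorch iterable dataset instead.

```python
import queue
import threading

_SENTINEL = object()  # hypothetical end-of-stream marker


class QueueBackedDataset:
    """Iterable dataset whose items arrive at runtime through a queue."""

    def __init__(self, buffer=None):
        self.buffer = buffer if buffer is not None else queue.Queue()

    def put(self, item):
        self.buffer.put(item)

    def close(self):
        # Signal consumers that no more items will arrive.
        self.buffer.put(_SENTINEL)

    def __iter__(self):
        while True:
            item = self.buffer.get()  # blocks until a producer enqueues
            if item is _SENTINEL:
                return
            yield item


def batched(dataset, batch_size):
    """Group items into lists of batch_size; the last batch may be short."""
    batch = []
    for item in dataset:
        batch.append(item)
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch:
        yield batch


def produce(dataset, n):
    # Stands in for agents posting prompts from another thread or process.
    for i in range(n):
        dataset.put(f"prompt-{i}")
    dataset.close()


dataset = QueueBackedDataset()
producer = threading.Thread(target=produce, args=(dataset, 5))
producer.start()
batches = list(batched(dataset, batch_size=2))
producer.join()
```

Because the consumer blocks on the queue rather than indexing a fixed-length dataset, the driver loop can keep its "next batch" shape while the content is produced dynamically.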
Function Call vs. Mocked Server
There was discussion on whether to make the Agent a server that waits for function calls from the training process, or to attach a mocked LLM inference server that collects prompts posted by Agents.
We decided to adopt the mocked server solution given:
- The high diversity of agents now and in the future: this releases engineers from the endless work of transforming an agent into a server every time a new agent emerges;
- The consistency and code reuse between generation stage and evaluation stage.
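To make the mocked server option concrete, here is a minimal in-process sketch. It assumes the agent only needs something that looks like a `generate` call; the class `MockedLLMServer`, the `fake_rollout` function, and the use of a `Future` to hand the response back are all hypothetical choices for illustration, not the design in the PR. The agent code stays unchanged, while the trainer side pulls prompts from the queue and supplies responses.

```python
import queue
import threading
from concurrent.futures import Future


class MockedLLMServer:
    """Looks like an inference endpoint to the agent, but only queues prompts."""

    def __init__(self):
        self.prompt_queue = queue.Queue()

    def generate(self, prompt, timeout=5.0):
        # Called by the agent; blocks until the trainer fills in a response.
        future = Future()
        self.prompt_queue.put((prompt, future))
        return future.result(timeout=timeout)


def agent(server, task, transcript):
    # Ordinary agent logic: it calls generate() as if talking to a real server.
    reply = server.generate(f"Solve: {task}")
    transcript.append((task, reply))


def fake_rollout(prompt):
    # Stand-in for the real rollout worker.
    return prompt.upper()


server = MockedLLMServer()
transcript = []
agent_thread = threading.Thread(target=agent, args=(server, "2+2", transcript))
agent_thread.start()

# Trainer side: collect the posted prompt, generate, and hand the result back.
prompt, future = server.prompt_queue.get(timeout=5.0)
future.set_result(fake_rollout(prompt))
agent_thread.join()
```

The same server object could serve both the generation and evaluation stages, which is the code-reuse point made above.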
Design
Architecture
Flow
Here we explain an async flow. A synchronous training flow can be achieved by placing a barrier between the step that fetches tasks from the parquet file and the update rollout step.
We assume the tasks for agents are stored in a parquet file with a similar format to regular data files.
- Fetch a batch of tasks (from a parquet file)
- Enqueue the tasks into ServerWithTask
  - This step may run in a separate thread
- Sample a batch of prompts from BufferDataset of prompt
- Generate responses as actions
- Throw the unioned batch (prompt + response) into MultiRoundBufferDataset
- Sample a batch of prompt-response-score from MultiRoundBufferDataset
  - Skip this sampling step and the following steps if MultiRoundBufferDataset holds fewer samples than the batch size
- Compute advantages
- Critic loss + critic update
- Actor loss + actor update
- Update rollout
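The flow above can be sketched end to end with toy stand-ins. Everything here is an assumption for illustration: plain queues replace BufferDataset and MultiRoundBufferDataset, `toy_agent` replaces a real agent, the response is a formatted string instead of a model rollout, and the PPO steps are reduced to a comment. Only the control flow (enqueue tasks in a thread, serve prompts, buffer turns until a score arrives, train once a full batch is ready) mirrors the steps listed.

```python
import queue
import threading

NUM_TASKS = 4
BATCH_SIZE = 2
ROUNDS_PER_TASK = 2  # each toy agent needs two turns before it has a score

prompt_queue = queue.Queue()  # stands in for the BufferDataset of prompts
ready = queue.Queue()         # scored trajectories, ready for training
pending = {}                  # task_id -> turns still waiting for a score
pending_lock = threading.Lock()


def notify_score(task_id, score):
    # Move a finished trajectory from "pending" to "ready".
    with pending_lock:
        turns = pending.pop(task_id)
    ready.put({"task": task_id, "turns": turns, "score": score})


def toy_agent(task_id):
    responses = []
    for round_idx in range(ROUNDS_PER_TASK):
        reply_box = queue.Queue(maxsize=1)
        prompt_queue.put((task_id, f"task{task_id}-round{round_idx}", reply_box))
        responses.append(reply_box.get())  # wait for the generated response
    notify_score(task_id, score=float(len(responses)))  # toy score


def enqueue_tasks(task_ids):
    # One agent per task; a real system would read tasks from a parquet file.
    for task_id in task_ids:
        threading.Thread(target=toy_agent, args=(task_id,), daemon=True).start()


threading.Thread(target=enqueue_tasks, args=(range(NUM_TASKS),)).start()

trained_batches = []
while sum(len(b) for b in trained_batches) < NUM_TASKS:
    try:
        task_id, prompt, reply_box = prompt_queue.get(timeout=0.05)
        response = f"response-to-{prompt}"  # "generate responses as actions"
        with pending_lock:
            pending.setdefault(task_id, []).append((prompt, response))
        reply_box.put(response)  # reply only after the turn is recorded
    except queue.Empty:
        pass
    # Skip training while fewer than BATCH_SIZE scored samples are buffered.
    if ready.qsize() >= BATCH_SIZE:
        batch = [ready.get() for _ in range(BATCH_SIZE)]
        # advantages, critic update, actor update, and rollout update go here
        trained_batches.append(batch)
```

Recording each turn before replying to the agent avoids a race in which the agent reports its score before the final turn has been stored.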
Demo
import threading

from torch.utils.data import DataLoader  # assumed: the PyTorch DataLoader

train_dataset = MultiRoundBufferDataset()
train_dataloader_iter = iter(DataLoader(train_dataset))
mocked_server, server_thread = create_and_launch_task_pool(
    notify_score_fn=train_dataset.notify_score)
prompt_dataset = BufferDataset(buffer=mocked_server.prompt_queue)
prompt_dataloader_iter = iter(DataLoader(prompt_dataset))

# enqueue tasks into the task server from a background thread
threading.Thread(target=enqueue_task, args=("xxx.parquet",)).start()

# sample a prompt batch and generate sequences
batch = next(prompt_dataloader_iter)
responses = rollout.generate_sequences(batch)
batch = batch.union(responses)
train_dataset.put_batch(batch)

# sample a train batch once enough scored samples are buffered
if train_dataset.buffer.qsize() >= batch_size:
    train_batch = next(train_dataloader_iter)
    # do ppo
ongoing PR #808