
[RFC] verl with Agent/Env #1172

@eric-haibin-lin

Posting on behalf of @zw0610

Current Training Pattern

In verl.trainer.ppo.ray_trainer.RayTrainer, the fit(self) method follows this pattern:

  1. Fetch a batch (from a parquet file): ray_trainer.py#L826
  2. Generate responses as actions: ray_trainer.py#L849
  3. Compute advantages
  4. Critic loss + critic update
  5. Actor loss + actor update
  6. Update rollout

This training pattern (a minimal sketch is given below) is not compatible with Agent/Env RL training, as illustrated in the following sections.
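
A minimal sketch of the loop; method and attribute names are paraphrased from the trainer and may not match the exact verl API:

# Minimal sketch of the current fit() loop (paraphrased names, not the exact verl API).
def fit_one_epoch(trainer):
    for batch_dict in trainer.train_dataloader:                       # 1. fetch a batch from parquet
        batch = trainer.to_data_proto(batch_dict)                     # wrap the raw dict into a batch object
        gen_out = trainer.actor_rollout_wg.generate_sequences(batch)  # 2. generate responses as actions
        batch = batch.union(gen_out)
        batch = trainer.compute_advantage(batch)                      # 3. compute advantages
        trainer.critic_wg.update_critic(batch)                        # 4. critic loss + critic update
        trainer.actor_rollout_wg.update_actor(batch)                  # 5. actor loss + actor update
        # 6. rollout weights are refreshed from the updated actor before the next batch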

Agent/Env

There are many possible definitions of what an Agent is. Here we assume an agent is a proxy that, with the help of an LLM inference server, interacts with some environment (a game, a GitHub code issue, etc.) to complete certain tasks.
The introduction of Agents brings prompt generation in RLHF closer to traditional RL training and expands the range of cases that LLMs can work with.

Train with Agent/Env

There are two main conflicts between training with Agents and the current training pattern:

  1. Separation of the train process (the driver process in our single-controller paradigm) and prompt data
    Unlike fetching directly from offline data files, prompt data is now generated by Agents, which usually run as separate processes.
    To address this issue, we introduce the class MockedServer to collect prompts, which works as a special Dataset so that the existing train pattern needs only minimal changes.
  2. Dynamic prompt/train batch generation
    Agents/Envs may require multiple rounds to obtain the final score, so a traditional Dataset with predetermined content and length cannot be used if we wish to overlap the generation and train stages. (The long-tail issue becomes more severe with multi-round agents/envs.)
    To address this issue, we introduce the class BufferDataset to serve as a pipe between the generation and train stages while keeping the single-controller paradigm in the driver script (a sketch follows this list).
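
As a rough illustration (a sketch only, not the verl implementation), BufferDataset can be pictured as an iterable dataset draining a shared queue, so the prompt DataLoader simply blocks until Agents have produced prompts:

import queue

from torch.utils.data import IterableDataset

class BufferDataset(IterableDataset):
    """Sketch: a dataset backed by a queue that another thread/process keeps filling."""

    def __init__(self, buffer: queue.Queue):
        self.buffer = buffer

    def __iter__(self):
        while True:
            # Blocks until an Agent (via the mocked server) has pushed a prompt.
            yield self.buffer.get()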

Function Call vs. Mocked Server

There has been discussion on whether to make the Agent a server waiting for function calls from the training process, or to attach a mocked LLM inference server that collects the prompts posted by Agents.
We decided to adopt the mocked-server solution given:

  1. The high diversity of agents, now and in the future: it releases engineers from the endless work of turning every new agent into a server;
  2. The consistency and code reuse between the generation stage and the evaluation stage.
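
To make the mocked-server idea concrete, the sketch below shows one possible shape of such a server: it looks like an inference endpoint to the Agent, but only enqueues the prompt and blocks until the trainer's rollout supplies the response. Class and method names here are assumptions for illustration, not verl code, and the HTTP layer the Agent actually talks to is omitted:

import queue
import threading

class MockedServer:
    """Sketch: pretends to be an LLM inference server, but only collects prompts."""

    def __init__(self):
        self.prompt_queue = queue.Queue()  # consumed by BufferDataset / the generation stage
        self._pending = {}                 # request_id -> (event, response slot)

    def handle_request(self, request_id, prompt):
        # Called from the request handler the Agent talks to (HTTP plumbing omitted).
        event, slot = threading.Event(), [None]
        self._pending[request_id] = (event, slot)
        self.prompt_queue.put((request_id, prompt))
        event.wait()                       # block until the rollout produces the response
        del self._pending[request_id]
        return slot[0]

    def post_response(self, request_id, response):
        # Called by the trainer after generate_sequences() finishes.
        event, slot = self._pending[request_id]
        slot[0] = response
        event.set()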

Design

Architecture

[Architecture diagram]

Flow

Here we explain an async flow. A synchronized training flow can be achieved by adding a barrier between the step that fetches tasks from the parquet file and the update-rollout step.
We assume the tasks for agents are stored in a parquet file with a similar format to regular data files.

  1. Fetch a batch of tasks (from a parquet file)
  2. Enqueue the tasks into ServerWithTask
  3. (Steps 1 and 2 may run in a separate thread)
  4. Sample a batch of prompts from BufferDataset of prompt
  5. Generate responses as actions
  6. Throw the unioned batch (prompt + response) into MultiRoundBufferDataset
  7. Sample a batch of prompt-response-score from MultiRoundBufferDataset
  8. Skip this sample action and the following steps if the length of MultiRoundBufferDataset is shorter than the batch size
  9. Compute advantages
  10. Critic loss + critic update
  11. Actor loss + actor update
  12. Update rollout
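
One way to picture the MultiRoundBufferDataset used in steps 6-8 (again a sketch, not the verl implementation) is a buffer that holds prompt-response samples until the Agent/Env reports a final score via notify_score, and only then exposes them to the train DataLoader:

import queue

from torch.utils.data import IterableDataset

class MultiRoundBufferDataset(IterableDataset):
    """Sketch: buffers prompt-response samples until their final score arrives."""

    def __init__(self):
        self.buffer = queue.Queue()  # scored samples, ready for training
        self._waiting = {}           # task_id -> samples still waiting for a score

    def put_batch(self, batch):
        # Step 6: hold generated sequences until the Agent/Env finishes the task.
        for sample in batch:
            self._waiting.setdefault(sample["task_id"], []).append(sample)

    def notify_score(self, task_id, score):
        # Callback passed to the task pool; releases all rounds of the task.
        for sample in self._waiting.pop(task_id, []):
            sample["score"] = score
            self.buffer.put(sample)

    def __iter__(self):
        # Steps 7/8: the train DataLoader reads scored samples from here.
        while True:
            yield self.buffer.get()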

Demo

import threading

from torch.utils.data import DataLoader

# Buffer holding prompt-response samples until their scores arrive.
train_dataset = MultiRoundBufferDataset()
train_dataloader_iter = iter(DataLoader(train_dataset))

# Mocked inference server that collects prompts from Agents and reports
# final scores back to the train buffer.
mocked_server, server_thread = create_and_launch_task_pool(
                                   notify_score_fn=train_dataset.notify_score)

# Prompt dataset backed by the queue the mocked server fills.
prompt_dataset = BufferDataset(buffer=mocked_server.prompt_queue)
prompt_dataloader_iter = iter(DataLoader(prompt_dataset))

# enqueue tasks into the task server in the background
threading.Thread(target=enqueue_task, args=("xxx.parquet",)).start()

# sample a prompt batch and generate sequences
batch = next(prompt_dataloader_iter)
responses = rollout.generate_sequences(batch)
batch = batch.union(responses)
train_dataset.put_batch(batch)

# sample a train batch once enough scored samples have accumulated
if train_dataset.buffer.qsize() >= batch_size:
    train_batch = next(train_dataloader_iter)
    # do ppo
Ongoing PR: #808
