
[RFC]: Add runtime weight update API #5723

Closed

Description

@lyuqin-scale

Motivation.

In online RL training, vLLM can significantly accelerate the rollout stage. To achieve this, we need to sync weights from the main training process to the vLLM worker processes, and then call vLLM's existing API to update the weights via

model_runner.model.load_weights

An example of such an implementation can be found in OpenRLHF: https://github.com/OpenLLMAI/OpenRLHF/blob/main/openrlhf/trainer/ray/vllm_worker_wrap.py
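In the OpenRLHF wrapper linked above, once a tensor has been received on the worker, the update itself is roughly a single call into the model (a minimal sketch; name and weight are the parameter name and the tensor just received from the trainer):

# Inside the vLLM worker process: apply one received tensor to the live model
self.model_runner.model.load_weights(weights=[(name, weight)])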

However, users currently have to monkey-patch the vLLM worker to introduce this behavior. It would be great if vLLM natively supported weight sync at runtime.

Proposed Change.

  1. Add an NCCL-based weight-sync process group during vLLM initialization, so that the main training process can later dist.broadcast weights to the vLLM worker processes
  2. Expose a weight-sync API on the worker (a sketch of a possible implementation follows this list), for example:
    def update_weight(self, name, dtype, shape, empty_cache=False)
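
A minimal sketch of what the worker side could look like, modeled on the OpenRLHF wrapper linked above. Everything here is illustrative, not an existing vLLM API: in particular, new_group_via_tcp stands in for a helper that creates a second NCCL process group against the trainer's TCP store, since the worker's default process group is already used for tensor parallelism.

import torch

def init_weight_update_group(self, master_address, master_port, rank_offset, world_size):
    # The trainer is rank 0 of the update group; this worker takes
    # rank_offset plus its local rank.
    rank = torch.distributed.get_rank() + rank_offset
    # Hypothetical helper: builds an extra NCCL group over tcp://master_address:master_port
    self._model_update_group = new_group_via_tcp(
        backend="nccl",
        init_method=f"tcp://{master_address}:{master_port}",
        world_size=world_size,
        rank=rank,
    )

def update_weight(self, name, dtype, shape, empty_cache=False):
    # Allocate a receive buffer, fill it from the trainer's broadcast,
    # then hand it to vLLM's existing load_weights entry point.
    weight = torch.empty(shape, dtype=dtype, device="cuda")
    torch.distributed.broadcast(weight, 0, group=self._model_update_group)
    self.model_runner.model.load_weights(weights=[(name, weight)])
    del weight
    if empty_cache:
        torch.cuda.empty_cache()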

Then, in the master process, the user can perform the weight sync as follows (adapted from OpenRLHF):

num_params = len(list(model.named_parameters()))
for count, (name, param) in enumerate(model.named_parameters(), start=1):
    # Fire all vLLM engines for broadcast of this parameter
    if torch.distributed.get_rank() == 0:
        # Under DeepSpeed ZeRO-3 the local tensor is sharded; report the full shape
        shape = param.shape if self.strategy.args.zero_stage != 3 else param.ds_shape
        refs = [
            # Free the vLLM CUDA cache only after the last parameter
            engine.update_weight.remote(name, dtype=param.dtype, shape=shape, empty_cache=count == num_params)
            for engine in self.vllm_engines
        ]

    # For ZeRO-3, gather the sharded parameter before broadcasting the full tensor
    with deepspeed.zero.GatheredParameters([param], enabled=self.strategy.args.zero_stage == 3):
        if torch.distributed.get_rank() == 0:
            torch.distributed.broadcast(param.data, 0, group=self._model_update_group)
            ray.get(refs)
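
For completeness, the trainer would need a matching one-time setup before this loop, along these lines (again a sketch: init_weight_update_group and new_group_via_tcp are the hypothetical pieces from the worker-side sketch above, and one worker per engine, i.e. tensor parallel size 1, is assumed):

world_size = len(self.vllm_engines) + 1  # trainer (rank 0) + one worker per engine
refs = [
    engine.init_weight_update_group.remote(
        master_address, master_port, rank_offset=i + 1, world_size=world_size
    )
    for i, engine in enumerate(self.vllm_engines)
]
# The trainer joins the same NCCL group as rank 0 while the workers connect
self._model_update_group = new_group_via_tcp(
    backend="nccl",
    init_method=f"tcp://{master_address}:{master_port}",
    world_size=world_size,
    rank=0,
)
ray.get(refs)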

Feedback Period.

No response

CC List.

No response

Any Other Things.

No response
