Motivation.
In online RL training, vLLM can significantly accelerate the rollout stage. To achieve this, we need to sync weights from the main training process to the vLLM worker processes, and then call the existing API in vLLM to update the weights:
model_runner.model.load_weights
An example of such an implementation can be found in OpenRLHF: https://github.com/OpenLLMAI/OpenRLHF/blob/main/openrlhf/trainer/ray/vllm_worker_wrap.py
However, the user currently has to monkey-patch the vLLM worker to introduce this behavior (a sketch of that workaround is shown below). It would be great if vLLM natively supported weight sync at runtime.
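For illustration, here is a condensed sketch of that monkey patch, closely following the OpenRLHF wrapper linked above. It assumes the patched worker subclass is injected in place of vLLM's Worker (the import path may differ across vLLM versions), that the worker exposes model_runner, and that the model's load_weights accepts an iterable of (name, tensor) pairs; the method names and the _model_update_group attribute come from the OpenRLHF example, not from vLLM itself:
import torch
import torch.distributed
from vllm.worker.worker import Worker


class WorkerWrap(Worker):
    def init_weight_sync_group(self, master_address, master_port, rank, world_size):
        # Join an extra NCCL group shared with the trainer's rank 0.
        # (Group construction is omitted here; OpenRLHF builds it with its own
        # init_process_group helper so the default training group is untouched.)
        self._model_update_group = ...

    def update_weight(self, name, dtype, shape, empty_cache=False):
        # Receive one tensor broadcast by the trainer (rank 0 of the update
        # group) and hand it to vLLM's existing in-place weight loader.
        weight = torch.empty(shape, dtype=dtype, device="cuda")
        torch.distributed.broadcast(weight, 0, group=self._model_update_group)
        self.model_runner.model.load_weights(weights=[(name, weight)])
        del weight
        if empty_cache:
            torch.cuda.empty_cache()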
Proposed Change.
- Add an NCCL-based weight-sync process group during vLLM initialization, so that the main process can dist.broadcast weights to the vLLM worker processes later
- Expose a weight sync API, for example:
def update_weight(self, name, dtype, shape)
Then, in the master process, the user can achieve weight sync via the following (modified from OpenRLHF; a sketch of the one-time group setup follows the loop):
count, num_params = 0, len(list(model.named_parameters()))
for name, param in model.named_parameters():
    count += 1
    # Fire all vLLM engines for broadcast; the last parameter also asks the
    # workers to empty their CUDA cache.
    if torch.distributed.get_rank() == 0:
        shape = param.shape if self.strategy.args.zero_stage != 3 else param.ds_shape
        refs = [
            engine.update_weight.remote(name, dtype=param.dtype, shape=shape, empty_cache=count == num_params)
            for engine in self.vllm_engines
        ]
        # Rank 0 broadcasts the parameter over the weight-sync group, then
        # waits for every engine to finish loading it.
        torch.distributed.broadcast(param.data, 0, group=self._model_update_group)
        ray.get(refs)
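For completeness, a rough sketch of the one-time setup that would precede the loop above, modeled on how OpenRLHF wires this up today. Everything here is illustrative: init_weight_update_group is the engine-side API assumed by this proposal, init_process_group stands for a helper that creates a separate named NCCL group without touching the default training group (OpenRLHF ships such a utility), and the address, port, and tensor_parallel_size == 1 are simplifying assumptions:
import ray
import torch.distributed

# Run once before training: the trainer's rank 0 plus every vLLM engine join
# one extra NCCL group used only for weight broadcasts (tensor_parallel_size
# is assumed to be 1, so each engine contributes a single worker/rank).
master_address, master_port = "10.0.0.1", 29510  # placeholders
world_size = len(self.vllm_engines) + 1          # trainer rank 0 + N engines

if torch.distributed.get_rank() == 0:
    # Ask every vLLM engine to join the weight-sync group (hypothetical API).
    refs = [
        engine.init_weight_update_group.remote(
            master_address, master_port, rank=i + 1, world_size=world_size
        )
        for i, engine in enumerate(self.vllm_engines)
    ]
    # The trainer joins as rank 0 via a helper that builds a second NCCL
    # group; torch.distributed.init_process_group cannot simply be called
    # again on top of the existing training group.
    self._model_update_group = init_process_group(
        backend="nccl",
        init_method=f"tcp://{master_address}:{master_port}",
        world_size=world_size,
        rank=0,
        group_name="vllm_weight_sync",
    )
    ray.get(refs)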
Feedback Period.
No response
CC List.
No response
Any Other Things.
No response