Skip to content

[RLHF] use worker_extension_cls for compatibility with V0 and V1 #14185

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 16 commits into from
Mar 6, 2025
6 changes: 4 additions & 2 deletions .buildkite/test-pipeline.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -145,8 +145,10 @@ steps:
- pytest -v -s spec_decode/e2e/test_integration_dist_tp4.py
# TODO: create a dedicated test section for multi-GPU example tests
# when we have multiple distributed example tests
- python3 ../examples/offline_inference/rlhf.py
- RAY_DEDUP_LOGS=0 python3 ../examples/offline_inference/rlhf_colocate.py
- pushd ../examples/offline_inference
- python3 rlhf.py
- RAY_DEDUP_LOGS=0 python3 rlhf_colocate.py
- popd

- label: Metrics, Tracing Test # 10min
num_gpus: 2
Expand Down
66 changes: 3 additions & 63 deletions examples/offline_inference/rlhf.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,72 +18,11 @@
import torch
from ray.util.placement_group import placement_group
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
from rlhf_utils import stateless_init_process_group
from transformers import AutoModelForCausalLM

from vllm import LLM, SamplingParams
from vllm.utils import get_ip, get_open_port
from vllm.worker.worker import Worker


def stateless_init_process_group(master_address, master_port, rank, world_size,
device):
"""
vLLM provides `StatelessProcessGroup` to create a process group
without considering the global process group in torch.distributed.
It is recommended to create `StatelessProcessGroup`, and then initialize
the data-plane communication (NCCL) between external (train processes)
and vLLM workers.
"""
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
from vllm.distributed.utils import StatelessProcessGroup
pg = StatelessProcessGroup.create(host=master_address,
port=master_port,
rank=rank,
world_size=world_size)
pynccl = PyNcclCommunicator(pg, device=device)
return pynccl


class MyWorker(Worker):
"""
The `MyWorker` class inherits from `Worker` to provide custom functions.
For simplicity, we define the `MyWorker` class in this self-contained
script. Normally, we should define the `MyWorker` class in a separate
file and pass the qualified name of the class to the `worker_cls`
parameter.
"""

def init_weight_update_group(self, master_address, master_port,
rank_offset, world_size):
from vllm.distributed.parallel_state import get_world_group
rank = get_world_group().rank + rank_offset
self.model_update_group = stateless_init_process_group(
master_address,
master_port,
rank,
world_size,
self.device,
)

def update_weight(self, name, dtype, shape):
weight = torch.empty(shape, dtype=dtype, device="cuda")
self.model_update_group.broadcast(weight,
src=0,
stream=torch.cuda.current_stream())

self.model_runner.model.load_weights(weights=[(name, weight)])

del weight

def check_weights_changed(self):
"""
Check if the weights are updated to 0.
"""
weights_updated = True
for name, p in self.model_runner.model.named_parameters():
weights_updated = weights_updated and torch.allclose(
p, torch.zeros_like(p))
return weights_updated


class MyLLM(LLM):
Expand Down Expand Up @@ -129,7 +68,7 @@ def __init__(self, *args, **kwargs):
)(MyLLM).remote(
model="facebook/opt-125m",
enforce_eager=True,
worker_cls=MyWorker,
worker_extension_cls="rlhf_utils.WorkerExtension",
tensor_parallel_size=2,
distributed_executor_backend="ray",
)
Expand Down Expand Up @@ -159,6 +98,7 @@ def __init__(self, *args, **kwargs):

handle = llm.collective_rpc.remote("init_weight_update_group",
args=(master_address, master_port, 1, 3))

model_update_group = stateless_init_process_group(master_address, master_port,
0, 3, torch.device("cuda:0"))
ray.get(handle)
Expand Down
36 changes: 1 addition & 35 deletions examples/offline_inference/rlhf_colocate.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,40 +17,6 @@
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

from vllm import LLM
from vllm.worker.worker import Worker


class MyWorker(Worker):

def report_device_id(self) -> str:
from vllm.platforms import current_platform
self.device_uuid = current_platform.get_device_uuid(self.device.index)
return self.device_uuid

def update_weights_from_ipc_handles(self, ipc_handles):
handles = ipc_handles[self.device_uuid]
device_id = self.device.index
weights = []
for name, handle in handles.items():
func, args = handle
list_args = list(args)
# the key is to change device id to the current device id
# in case two processes have different CUDA_VISIBLE_DEVICES
list_args[6] = device_id
tensor = func(*list_args)
weights.append((name, tensor))
self.model_runner.model.load_weights(weights=weights)
torch.cuda.synchronize()

def check_weights_changed(self):
"""
Check if the weights are updated to 0.
"""
weights_updated = True
for name, p in self.model_runner.model.named_parameters():
weights_updated = weights_updated and torch.allclose(
p, torch.zeros_like(p))
return weights_updated


class MyLLM(LLM):
Expand Down Expand Up @@ -150,7 +116,7 @@ def get_weight_ipc_handles(self):
)(MyLLM).remote(
model="facebook/opt-125m",
enforce_eager=True,
worker_cls=MyWorker,
worker_extension_cls="rlhf_utils.ColocateWorkerExtension",
tensor_parallel_size=2,
distributed_executor_backend="ray",
gpu_memory_utilization=0.4,
Expand Down
105 changes: 105 additions & 0 deletions examples/offline_inference/rlhf_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
# SPDX-License-Identifier: Apache-2.0
import torch


def stateless_init_process_group(master_address, master_port, rank, world_size,
device):
"""
vLLM provides `StatelessProcessGroup` to create a process group
without considering the global process group in torch.distributed.
It is recommended to create `StatelessProcessGroup`, and then initialize
the data-plane communication (NCCL) between external (train processes)
and vLLM workers.
"""
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
from vllm.distributed.utils import StatelessProcessGroup
pg = StatelessProcessGroup.create(host=master_address,
port=master_port,
rank=rank,
world_size=world_size)
pynccl = PyNcclCommunicator(pg, device=device)
return pynccl


class WorkerExtension:
"""
The class for vLLM's worker to inherit from.
By defining an extension class, the code can work no matter what is
the underlying worker class. This way, the code can be compatible
with both vLLM V0 and V1.
NOTE: we define this class in a separate module, and the main module
should pass the full qualified name as `worker_extension_cls` argument.
"""

def init_weight_update_group(self, master_address, master_port,
rank_offset, world_size):
from vllm.distributed.parallel_state import get_world_group
rank = get_world_group().rank + rank_offset
self.model_update_group = stateless_init_process_group(
master_address,
master_port,
rank,
world_size,
self.device,
)

def update_weight(self, name, dtype, shape):
weight = torch.empty(shape, dtype=dtype, device="cuda")
self.model_update_group.broadcast(weight,
src=0,
stream=torch.cuda.current_stream())

self.model_runner.model.load_weights(weights=[(name, weight)])

del weight
Comment on lines +46 to +54
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How are update_weightand check_weights_changed supposed to be used?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

update_weight is called externally, by collective_rpc("update_weight")


def check_weights_changed(self):
"""
Check if the weights are updated to 0.
"""
weights_updated = True
for name, p in self.model_runner.model.named_parameters():
weights_updated = weights_updated and torch.allclose(
p, torch.zeros_like(p))
return weights_updated


class ColocateWorkerExtension:
"""
The class for vLLM's worker to inherit from, in the colocate setting.
By defining an extension class, the code can work no matter what is
the underlying worker class. This way, the code can be compatible
with both vLLM V0 and V1.
NOTE: we define this class in a separate module, and the main module
should pass the full qualified name as `worker_extension_cls` argument.
"""

def report_device_id(self) -> str:
from vllm.platforms import current_platform
self.device_uuid = current_platform.get_device_uuid(self.device.index)
return self.device_uuid

def update_weights_from_ipc_handles(self, ipc_handles):
handles = ipc_handles[self.device_uuid]
device_id = self.device.index
weights = []
for name, handle in handles.items():
func, args = handle
list_args = list(args)
# the key is to change device id to the current device id
# in case two processes have different CUDA_VISIBLE_DEVICES
list_args[6] = device_id
tensor = func(*list_args)
weights.append((name, tensor))
self.model_runner.model.load_weights(weights=weights)
torch.cuda.synchronize()

def check_weights_changed(self):
"""
Check if the weights are updated to 0.
"""
weights_updated = True
for name, p in self.model_runner.model.named_parameters():
weights_updated = weights_updated and torch.allclose(
p, torch.zeros_like(p))
return weights_updated
4 changes: 4 additions & 0 deletions vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1366,6 +1366,7 @@ class ParallelConfig:
# will be determined based on the platform.
worker_cls: str = "auto"
sd_worker_cls: str = "auto"
worker_extension_cls: str = ""

# world_size is TPxPP, it affects the number of workers we create.
world_size: int = field(init=False)
Expand Down Expand Up @@ -1523,6 +1524,9 @@ def _verify_args(self) -> None:
raise ValueError("Unable to use nsight profiling unless workers "
"run with Ray.")

assert isinstance(self.worker_extension_cls, str), (
"worker_extension_cls must be a string (qualified class name).")


@dataclass
class SchedulerConfig:
Expand Down
9 changes: 9 additions & 0 deletions vllm/engine/arg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,7 @@ class EngineArgs:
override_pooler_config: Optional[PoolerConfig] = None
compilation_config: Optional[CompilationConfig] = None
worker_cls: str = "auto"
worker_extension_cls: str = ""

kv_transfer_config: Optional[KVTransferConfig] = None

Expand Down Expand Up @@ -1015,6 +1016,13 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
type=str,
default="auto",
help='The worker class to use for distributed execution.')
parser.add_argument(
'--worker-extension-cls',
type=str,
default="",
help='The worker extension class on top of the worker cls, '
'it is useful if you just want to add new functions to the worker '
'class without changing the existing functions.')
parser.add_argument(
"--generation-config",
type=nullable_str,
Expand Down Expand Up @@ -1209,6 +1217,7 @@ def create_engine_config(self,
ray_workers_use_nsight=self.ray_workers_use_nsight,
distributed_executor_backend=self.distributed_executor_backend,
worker_cls=self.worker_cls,
worker_extension_cls=self.worker_extension_cls,
)

max_model_len = model_config.max_model_len
Expand Down
27 changes: 27 additions & 0 deletions vllm/worker/worker_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -558,10 +558,37 @@ def init_worker(self, all_kwargs: List[Dict[str, Any]]) -> None:
worker_class = resolve_obj_by_qualname(
self.vllm_config.parallel_config.worker_cls)
else:
logger.warning(
"passing worker_cls as a class object is strongly deprecated,"
" as the serialization of class objects can be tricky and"
" error-prone. To be safe, please keep the class in a separate"
" module and pass the qualified name of the class as a string."
)
assert isinstance(self.vllm_config.parallel_config.worker_cls,
bytes)
worker_class = cloudpickle.loads(
self.vllm_config.parallel_config.worker_cls)
if self.vllm_config.parallel_config.worker_extension_cls:
worker_extension_cls = resolve_obj_by_qualname(
self.vllm_config.parallel_config.worker_extension_cls)
extended_calls = []
if worker_extension_cls not in worker_class.__bases__:
# check any conflicts between worker and worker_extension_cls
for attr in dir(worker_extension_cls):
if attr.startswith("__"):
continue
assert not hasattr(worker_class, attr), (
f"Worker class {worker_class} already has an attribute"
f" {attr}, which conflicts with the worker"
f" extension class {worker_extension_cls}.")
if callable(getattr(worker_extension_cls, attr)):
extended_calls.append(attr)
# dynamically inherit the worker extension class
worker_class.__bases__ = worker_class.__bases__ + (
worker_extension_cls, )
logger.info(
"Injected %s into %s for extended collective_rpc calls %s",
worker_extension_cls, worker_class, extended_calls)
with set_current_vllm_config(self.vllm_config):
# To make vLLM config available during worker initialization
self.worker = worker_class(**kwargs)
Expand Down