-
-
Notifications
You must be signed in to change notification settings - Fork 8.4k
[RLHF] use worker_extension_cls for compatibility with V0 and V1 #14185
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
16 commits
Select commit
Hold shift + click to select a range
7d0cae3
add adapter support
youkaichao c7cfb53
check duplicate
youkaichao 52c9f0e
update colocate
youkaichao f7bb5d7
add files
youkaichao 1525a3a
add in engine args
youkaichao 08416a6
add logging
youkaichao f5db641
add logging
youkaichao dd7b3c3
comments
youkaichao 722c9ad
Merge branch 'main' into worker_adapter
youkaichao 731d4f6
use mixin
youkaichao ae2d12f
use mixin
youkaichao 3661da3
use mixin
youkaichao 1b43146
use mixin
youkaichao 00c6adc
polish logging
youkaichao 86743da
rename to worker_extension_cls
youkaichao 6d9f76b
rename
youkaichao File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
# SPDX-License-Identifier: Apache-2.0 | ||
import torch | ||
|
||
|
||
def stateless_init_process_group(master_address, master_port, rank, world_size, | ||
device): | ||
""" | ||
vLLM provides `StatelessProcessGroup` to create a process group | ||
without considering the global process group in torch.distributed. | ||
It is recommended to create `StatelessProcessGroup`, and then initialize | ||
the data-plane communication (NCCL) between external (train processes) | ||
and vLLM workers. | ||
""" | ||
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator | ||
from vllm.distributed.utils import StatelessProcessGroup | ||
pg = StatelessProcessGroup.create(host=master_address, | ||
port=master_port, | ||
rank=rank, | ||
world_size=world_size) | ||
pynccl = PyNcclCommunicator(pg, device=device) | ||
return pynccl | ||
|
||
|
||
class WorkerExtension: | ||
""" | ||
The class for vLLM's worker to inherit from. | ||
By defining an extension class, the code can work no matter what is | ||
the underlying worker class. This way, the code can be compatible | ||
with both vLLM V0 and V1. | ||
NOTE: we define this class in a separate module, and the main module | ||
should pass the full qualified name as `worker_extension_cls` argument. | ||
""" | ||
|
||
def init_weight_update_group(self, master_address, master_port, | ||
rank_offset, world_size): | ||
from vllm.distributed.parallel_state import get_world_group | ||
rank = get_world_group().rank + rank_offset | ||
self.model_update_group = stateless_init_process_group( | ||
master_address, | ||
master_port, | ||
rank, | ||
world_size, | ||
self.device, | ||
) | ||
|
||
def update_weight(self, name, dtype, shape): | ||
weight = torch.empty(shape, dtype=dtype, device="cuda") | ||
self.model_update_group.broadcast(weight, | ||
src=0, | ||
stream=torch.cuda.current_stream()) | ||
|
||
self.model_runner.model.load_weights(weights=[(name, weight)]) | ||
|
||
del weight | ||
|
||
def check_weights_changed(self): | ||
""" | ||
Check if the weights are updated to 0. | ||
""" | ||
weights_updated = True | ||
for name, p in self.model_runner.model.named_parameters(): | ||
weights_updated = weights_updated and torch.allclose( | ||
p, torch.zeros_like(p)) | ||
return weights_updated | ||
|
||
|
||
class ColocateWorkerExtension: | ||
""" | ||
The class for vLLM's worker to inherit from, in the colocate setting. | ||
By defining an extension class, the code can work no matter what is | ||
the underlying worker class. This way, the code can be compatible | ||
with both vLLM V0 and V1. | ||
NOTE: we define this class in a separate module, and the main module | ||
should pass the full qualified name as `worker_extension_cls` argument. | ||
""" | ||
|
||
def report_device_id(self) -> str: | ||
from vllm.platforms import current_platform | ||
self.device_uuid = current_platform.get_device_uuid(self.device.index) | ||
return self.device_uuid | ||
|
||
def update_weights_from_ipc_handles(self, ipc_handles): | ||
handles = ipc_handles[self.device_uuid] | ||
device_id = self.device.index | ||
weights = [] | ||
for name, handle in handles.items(): | ||
func, args = handle | ||
list_args = list(args) | ||
# the key is to change device id to the current device id | ||
# in case two processes have different CUDA_VISIBLE_DEVICES | ||
list_args[6] = device_id | ||
tensor = func(*list_args) | ||
weights.append((name, tensor)) | ||
self.model_runner.model.load_weights(weights=weights) | ||
torch.cuda.synchronize() | ||
|
||
def check_weights_changed(self): | ||
""" | ||
Check if the weights are updated to 0. | ||
""" | ||
weights_updated = True | ||
for name, p in self.model_runner.model.named_parameters(): | ||
weights_updated = weights_updated and torch.allclose( | ||
p, torch.zeros_like(p)) | ||
return weights_updated |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How are
update_weight
andcheck_weights_changed
supposed to be used?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
update_weight
is called externally, bycollective_rpc("update_weight")