[Speculative Decoding] Support draft model on different tensor-parallel size than target model #5414
Changes to vllm/distributed/parallel_state.py (init_distributed_environment and the new patch_tensor_parallel_group context manager):

@@ -550,6 +550,10 @@ def init_distributed_environment(
    global _WORLD
    if _WORLD is None:
        ranks = list(range(torch.distributed.get_world_size()))
        if world_size != -1:
            assert world_size == len(ranks), (
                "given world_size does not match with world_size of torch")
        _WORLD = GroupCoordinator(
            group_ranks=[ranks],
            local_rank=local_rank,

@@ -558,7 +562,7 @@ def init_distributed_environment(
            use_custom_allreduce=False,
        )
    else:
-       assert _WORLD.world_size == torch.distributed.get_world_size(), (
+       assert _WORLD.world_size == world_size, (
            "world group already initialized with a different world size")

@@ -673,6 +677,31 @@ def model_parallel_is_initialized():
    return (_TP is not None and _PP is not None)


OVERRIDE_TP_STATE = False


@contextlib.contextmanager
def patch_tensor_parallel_group(world_group, tp_group):
    """Patch the tp group temporarily until this function ends."""
    global OVERRIDE_TP_STATE
    if OVERRIDE_TP_STATE:
        return

    OVERRIDE_TP_STATE = True
    old_world_group = get_world_group()
    old_tp_group = get_tp_group()
    global _WORLD, _TP
    _WORLD = world_group
    _TP = tp_group
    try:
        yield
    finally:
        # restore the original state
        OVERRIDE_TP_STATE = False
        _WORLD = old_world_group
        _TP = old_tp_group


def get_tensor_model_parallel_world_size():
    """Return world size for the tensor model parallel group."""
    return get_tp_group().world_size

Review discussion on patch_tensor_parallel_group:

Review comment: Will this global variable patching potentially create problems? For example, is it possible that other workers will use this context unknowingly?

Reply: In the current design of speculative decoding, the draft and target workers execute sequentially. If the draft and target workers ever execute concurrently in the future, this code should be redesigned to prevent their states from being mixed with each other.
New file (213 added lines) implementing SmallerTpProposerWorker and DummyProposerWorker:

@@ -0,0 +1,213 @@

from typing import List, Optional, Set, Tuple, Union

import torch
import torch.distributed

from vllm.config import ParallelConfig
from vllm.distributed.parallel_state import (_ENABLE_CUSTOM_ALL_REDUCE,
                                             GroupCoordinator, get_tp_group,
                                             get_world_group,
                                             patch_tensor_parallel_group)
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.spec_decode.interfaces import SpeculativeProposals
from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
from vllm.worker.worker import Worker

logger = init_logger(__name__)


class SmallerTpProposerWorker(ProposerWorkerBase):
    """Class which allows a speculative draft model to run with a smaller
    tensor-parallel degree than the target model.
    This reduces the communication overhead of small draft models.

    This is implemented by temporarily changing vLLM's tensor parallel group
    to a group of smaller size during forward passes of the draft model.
    """

    @classmethod
    def maybe_wrap_worker(cls, worker, draft_parallel_config: ParallelConfig,
                          target_parallel_config: ParallelConfig, rank: int):
        """Wrap the worker in a SmallerTpProposerWorker if necessary."""
        draft_tp = draft_parallel_config.tensor_parallel_size
        target_tp = target_parallel_config.tensor_parallel_size

        if draft_tp == target_tp:
            return worker

        if draft_tp > target_tp:
            raise ValueError(
                f"{cls} only supports draft_tp smaller than target_tp. "
                f"{draft_tp=} {target_tp=}")

        # gpu ranks that will generate draft tokens together
        ranks = list(range(draft_tp))

        if rank in ranks:
            logger.info("Wrapping {%s} in {%s}", type(worker), cls)
            return cls(worker, ranks)
        else:
            # for workers not participating in the draft generation
            logger.info("Returning dummy worker")
            return DummyProposerWorker(worker)

    def __init__(self, worker: Union[Worker, ProposerWorkerBase],
                 ranks: List[int]):
        self._worker = worker
        self._ranks = ranks
        self._world_group = None
        self._tp_group = None

    def _patch_tensor_parallel_group(self):
        return patch_tensor_parallel_group(self._world_group, self._tp_group)

    def init_device(self):
        """Initialize the device.

        This also creates an additional tensor-parallel process group
        containing only a subset of all ranks.
        """
        local_rank = get_world_group().local_rank
        world_backend = torch.distributed.get_backend(
            get_world_group().device_group)
        tp_backend = torch.distributed.get_backend(get_tp_group().device_group)

        self._world_group = GroupCoordinator(
            group_ranks=[self._ranks],
            local_rank=local_rank,
            torch_distributed_backend=world_backend,
            use_pynccl=False,
            use_custom_allreduce=False,
        )
        self._tp_group = GroupCoordinator(
            group_ranks=[self._ranks],
            local_rank=local_rank,
            torch_distributed_backend=tp_backend,
            use_pynccl=True,
            use_custom_allreduce=_ENABLE_CUSTOM_ALL_REDUCE,
        )

        with self._patch_tensor_parallel_group():
            self._worker.init_device()

    def set_include_gpu_probs_tensor(self):
        self._worker.set_include_gpu_probs_tensor()

    def load_model(self):
        with self._patch_tensor_parallel_group():
            self._worker.load_model()

    def determine_num_available_blocks(self):
        with self._patch_tensor_parallel_group():
            return self._worker.determine_num_available_blocks()

    def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int):
        with self._patch_tensor_parallel_group():
            self._worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)

    @torch.inference_mode()
    def sampler_output(
        self,
        execute_model_req: ExecuteModelRequest,
        sample_len: int,
    ) -> Tuple[Optional[List[SamplerOutput]], bool]:
        # do not call _patch_tensor_parallel_group here, because this method
        # is always called after the tp group has already been overridden
        return self._worker.sampler_output(execute_model_req, sample_len)

    def get_spec_proposals(
        self,
        execute_model_req: ExecuteModelRequest,
    ) -> SpeculativeProposals:
        with self._patch_tensor_parallel_group():
            return self._worker.get_spec_proposals(execute_model_req)

    @torch.inference_mode()
    def execute_model(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None
    ) -> List[SamplerOutput]:
        with self._patch_tensor_parallel_group():
            return self._worker.execute_model(execute_model_req)

    def get_cache_block_size_bytes(self) -> int:
        return self._worker.get_cache_block_size_bytes()

    def add_lora(self, lora_request: LoRARequest) -> bool:
        raise NotImplementedError

    def remove_lora(self, lora_id: int) -> bool:
        raise NotImplementedError

    def list_loras(self) -> Set[int]:
        raise NotImplementedError

    @property
    def max_model_len(self) -> int:
        return self._worker.max_model_len

    @property
    def vocab_size(self) -> int:
        return self._worker.vocab_size


class DummyProposerWorker(ProposerWorkerBase):
    """Dummy proposer worker that does nothing.
    It is used for workers that do not participate in draft generation.
    """

    def __init__(
        self,
        worker: Union[Worker, ProposerWorkerBase],
    ):
        self._worker = worker

    def init_device(self):
        pass

    def load_model(self):
        pass

    def determine_num_available_blocks(self):
        pass

    def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int):
        pass

    def sampler_output(
        self,
        execute_model_req: ExecuteModelRequest,
        sample_len: int,
    ) -> Tuple[List[SamplerOutput], bool]:
        return [], True

    def get_spec_proposals(
        self,
        execute_model_req: ExecuteModelRequest,
    ) -> SpeculativeProposals:
        return SpeculativeProposals(None, None, None)

    def execute_model(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None
    ) -> List[SamplerOutput]:
        return []

    def get_cache_block_size_bytes(self) -> int:
        return 0

    def add_lora(self, lora_request: LoRARequest) -> bool:
        return False

    def remove_lora(self, lora_id: int) -> bool:
        return False

    def list_loras(self) -> Set[int]:
        return set()

    @property
    def vocab_size(self) -> int:
        return self._worker.vocab_size
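To make the per-rank dispatch in maybe_wrap_worker above concrete, here is a toy, self-contained sketch of which wrapper each rank ends up with when the target model uses TP=4 and the draft model uses TP=2. It is illustration only, not vLLM code; choose_proposer is a made-up helper that mirrors the branch logic of maybe_wrap_worker.

```python
# Toy illustration of the maybe_wrap_worker dispatch: with target TP=4 and
# draft TP=2, ranks 0-1 run the draft model through SmallerTpProposerWorker,
# while ranks 2-3 get a DummyProposerWorker that skips draft generation.
def choose_proposer(rank: int, draft_tp: int, target_tp: int) -> str:
    if draft_tp == target_tp:
        return "plain worker"          # no wrapping needed
    if draft_tp > target_tp:
        raise ValueError("draft_tp must not exceed target_tp")
    draft_ranks = list(range(draft_tp))  # ranks that generate draft tokens together
    return ("SmallerTpProposerWorker"
            if rank in draft_ranks else "DummyProposerWorker")


for rank in range(4):
    print(rank, choose_proposer(rank, draft_tp=2, target_tp=4))
# 0 SmallerTpProposerWorker
# 1 SmallerTpProposerWorker
# 2 DummyProposerWorker
# 3 DummyProposerWorker
```

Only the ranks inside the draft group pay the draft model's communication cost; the remaining ranks return empty proposals immediately via DummyProposerWorker.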
Review discussion:

Review comment: nit: if the user provides --speculative-tensor-parallel-size 0, this branch causes unexpected behavior. Can we explicitly guard against this?

Reply: Thanks for the catch! To prevent spec_tp from being set to target_tp when the given spec_tp is 0, I've changed the code as below. In addition, to prevent the tp value from being 0, I think we need a separate PR that adds a check in ParallelConfig._verify_args(), since the same issue exists for --tensor-parallel-size 0. What do you think?

Review comment: I'll look at your new changes; no need to completely fix this (just a nit).
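The follow-up check mentioned in the reply is not part of this PR. Purely as a hypothetical sketch of what such a guard in ParallelConfig._verify_args() could look like, the snippet below uses a made-up ToyParallelConfig stand-in rather than vLLM's actual class:

```python
from dataclasses import dataclass


@dataclass
class ToyParallelConfig:
    # Stand-in for vllm.config.ParallelConfig; only the relevant field is modeled.
    tensor_parallel_size: int = 1

    def _verify_args(self) -> None:
        # Hypothetical guard discussed above: reject a non-positive TP size at
        # config-validation time instead of letting it alter behavior downstream.
        if self.tensor_parallel_size < 1:
            raise ValueError(
                "tensor_parallel_size must be a positive integer, got "
                f"{self.tensor_parallel_size}")


try:
    ToyParallelConfig(tensor_parallel_size=0)._verify_args()
except ValueError as err:
    print(err)  # tensor_parallel_size must be a positive integer, got 0
```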