
Commit fa8cfa0

[Spec Decode] Make speculative decoding compatible with pipeline parallelism
Signed-off-by: Xin Yang <xyangx@amazon.com>
1 parent b8015ca commit fa8cfa0

11 files changed: 106 additions, 45 deletions

vllm/attention/layer.py

Lines changed: 1 addition & 1 deletion
@@ -153,7 +153,7 @@ def __init__(
         # this variable will not be accessed if use_direct_call is True
         self.kv_cache = [
             torch.tensor([]) for _ in range(
-                get_current_vllm_config().parallel_config.virtual_engine_size)
+                get_current_vllm_config().parallel_config.num_virtual_engine)
         ]

         self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT, dtype=torch.float32)
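The rename makes clear that the field counts virtual engines (one per pipeline stage) rather than describing a size, and the attention layer keeps one KV-cache placeholder per virtual engine. A minimal self-contained sketch of that allocation pattern; the class and method names here are illustrative, not vLLM's actual ones:

import torch

class AttentionLayerSketch:
    """Illustrative only: one placeholder KV-cache entry per virtual engine."""

    def __init__(self, num_virtual_engine: int) -> None:
        # One empty tensor per virtual engine; the real cache for engine `ve`
        # is bound later and looked up by index.
        self.kv_cache = [torch.tensor([]) for _ in range(num_virtual_engine)]

    def kv_cache_for(self, virtual_engine: int) -> torch.Tensor:
        return self.kv_cache[virtual_engine]

layer = AttentionLayerSketch(num_virtual_engine=2)
assert len(layer.kv_cache) == 2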

vllm/config.py

Lines changed: 5 additions & 2 deletions
@@ -1417,7 +1417,7 @@ class ParallelConfig:
     data_parallel_master_ip: str = "127.0.0.1"
     data_parallel_master_port: int = 29500  # Port of the data parallel master.
     enable_expert_parallel: bool = False  # Use EP instead of TP for MoE layers.
-    virtual_engine_size: int = 1  # Number of virtual engine.
+    num_virtual_engine: int = 1  # Number of virtual engines.

     # Maximum number of multiple batches
     # when load model sequentially. To avoid RAM OOM when using tensor
@@ -1927,6 +1927,9 @@ class SpeculativeConfig:
     - draft_tensor_parallel_size (Optional[int]): The degree of the tensor
       parallelism for the draft model. Can only be 1 or the same as the
       target model's tensor parallel size.
+    - draft_pipeline_parallel_size (Optional[int]): The degree of the
+      pipeline parallelism for the draft model. Can only be 1 or the
+      same as the target model's pipeline parallel size.
     - disable_logprobs (bool): If set to True, token log probabilities are
       not returned during speculative decoding. If set to False, token
       log probabilities are returned according to the log probability
@@ -2321,7 +2324,7 @@ def create_draft_parallel_config(
     draft_parallel_config = ParallelConfig(
         pipeline_parallel_size=speculative_draft_pipeline_parallel_size,
         tensor_parallel_size=speculative_draft_tensor_parallel_size,
-        virtual_engine_size=target_parallel_config.virtual_engine_size,
+        num_virtual_engine=target_parallel_config.num_virtual_engine,
         distributed_executor_backend=target_parallel_config.
         distributed_executor_backend,
         max_parallel_loading_workers=target_parallel_config.
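The draft model may use its own tensor/pipeline parallel degrees, but it must drive the same number of virtual engines as the target so the two stay scheduled in lockstep. A simplified stand-in for ParallelConfig and create_draft_parallel_config illustrating just that relationship (field set reduced to what the diff touches; not the real vLLM classes):

from dataclasses import dataclass

@dataclass
class ParallelConfigSketch:
    # Reduced stand-in for vllm.config.ParallelConfig.
    pipeline_parallel_size: int = 1
    tensor_parallel_size: int = 1
    num_virtual_engine: int = 1  # Number of virtual engines.

def create_draft_parallel_config_sketch(
        target: ParallelConfigSketch,
        draft_pipeline_parallel_size: int,
        draft_tensor_parallel_size: int) -> ParallelConfigSketch:
    # The draft config copies the target's virtual-engine count even when
    # its own PP/TP degrees are 1.
    return ParallelConfigSketch(
        pipeline_parallel_size=draft_pipeline_parallel_size,
        tensor_parallel_size=draft_tensor_parallel_size,
        num_virtual_engine=target.num_virtual_engine)

target = ParallelConfigSketch(pipeline_parallel_size=2, num_virtual_engine=2)
draft = create_draft_parallel_config_sketch(target, 1, 1)
assert draft.num_virtual_engine == target.num_virtual_engine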

vllm/engine/arg_utils.py

Lines changed: 2 additions & 2 deletions
@@ -181,7 +181,7 @@ class EngineArgs:
     guided_decoding_backend: str = 'xgrammar'
     logits_processor_pattern: Optional[str] = None

-    speculative_config: Optional[Union[str, Dict[str, Any]]] = None
+    speculative_config: Optional[Dict[str, Any]] = None

     qlora_adapter_name_or_path: Optional[str] = None
     show_hidden_metrics_for_version: Optional[str] = None
@@ -1189,7 +1189,7 @@ def create_engine_config(
     parallel_config = ParallelConfig(
         pipeline_parallel_size=self.pipeline_parallel_size,
         tensor_parallel_size=self.tensor_parallel_size,
-        virtual_engine_size=self.pipeline_parallel_size,
+        num_virtual_engine=self.pipeline_parallel_size,
         data_parallel_size=self.data_parallel_size,
         enable_expert_parallel=self.enable_expert_parallel,
         max_parallel_loading_workers=self.max_parallel_loading_workers,
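The engine therefore creates one virtual engine per pipeline stage: num_virtual_engine is simply set to pipeline_parallel_size. A hypothetical invocation combining speculative decoding with pipeline parallelism under this change; the model names and the exact keys accepted inside speculative_config are illustrative and version-dependent:

from vllm.engine.arg_utils import EngineArgs

args = EngineArgs(
    model="facebook/opt-6.7b",                 # illustrative target model
    pipeline_parallel_size=2,                  # also becomes num_virtual_engine
    speculative_config={
        "model": "facebook/opt-125m",          # illustrative draft model
        "num_speculative_tokens": 5,
    },
)
config = args.create_engine_config()
assert config.parallel_config.num_virtual_engine == 2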

vllm/engine/async_llm_engine.py

Lines changed: 6 additions & 6 deletions
@@ -812,9 +812,9 @@ async def run_engine_loop(engine_ref: ReferenceType):
         if not engine:
             return

-        virtual_engine_size = \
-            engine.engine.parallel_config.virtual_engine_size
-        has_requests_in_progress = [False] * virtual_engine_size
+        num_virtual_engine = \
+            engine.engine.parallel_config.num_virtual_engine
+        has_requests_in_progress = [False] * num_virtual_engine
         while True:
             if not any(has_requests_in_progress):
                 logger.debug("Waiting for new requests...")
@@ -839,9 +839,9 @@ async def run_engine_loop(engine_ref: ReferenceType):
                 logger.debug("Got new requests!")
                 requests_in_progress = [
                     asyncio.create_task(engine.engine_step(ve))
-                    for ve in range(virtual_engine_size)
+                    for ve in range(num_virtual_engine)
                 ]
-                has_requests_in_progress = [True] * virtual_engine_size
+                has_requests_in_progress = [True] * num_virtual_engine

             # Abort if iteration takes too long due to unrecoverable errors
             # (eg. NCCL timeouts).
@@ -850,7 +850,7 @@ async def run_engine_loop(engine_ref: ReferenceType):
                 done, _ = await asyncio.wait(
                     requests_in_progress,
                     return_when=asyncio.FIRST_COMPLETED)
-                for _ in range(virtual_engine_size):
+                for _ in range(num_virtual_engine)
                     await asyncio.sleep(0)
                 for task in done:
                     result = task.result()
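The loop drives one engine_step task per virtual engine and tracks completion per engine, waking up as soon as any of them finishes an iteration. A minimal self-contained asyncio sketch of that pattern; it uses no vLLM code and the names are illustrative:

import asyncio

async def engine_step(virtual_engine: int) -> bool:
    # Stand-in for AsyncLLMEngine.engine_step(): returns whether the
    # virtual engine still has requests in progress after this iteration.
    await asyncio.sleep(0.01 * (virtual_engine + 1))
    return False

async def run_engine_loop(num_virtual_engine: int) -> None:
    # One status flag and one step task per virtual engine, as in the diff.
    has_requests_in_progress = [False] * num_virtual_engine
    requests_in_progress = [
        asyncio.create_task(engine_step(ve))
        for ve in range(num_virtual_engine)
    ]
    done, _ = await asyncio.wait(requests_in_progress,
                                 return_when=asyncio.FIRST_COMPLETED)
    # Yield control once per virtual engine so other coroutines can run.
    for _ in range(num_virtual_engine):
        await asyncio.sleep(0)
    # The real loop keeps unfinished tasks for the next iteration; this
    # sketch simply drains them so it exits cleanly.
    for ve, task in enumerate(requests_in_progress):
        has_requests_in_progress[ve] = await task

asyncio.run(run_engine_loop(num_virtual_engine=2))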

vllm/engine/metrics.py

Lines changed: 1 addition & 1 deletion
@@ -266,7 +266,7 @@ def __init__(self, labelnames: List[str], vllm_config: VllmConfig):
         # Speculative decoding stats
         self.gauge_spec_decode_draft_acceptance_rate = self._gauge_cls(
             name="vllm:spec_decode_draft_acceptance_rate",
-            documentation="Speulative token acceptance rate.",
+            documentation="Speculative token acceptance rate.",
             labelnames=labelnames,
             multiprocess_mode="sum")
         self.gauge_spec_decode_efficiency = self._gauge_cls(

vllm/sequence.py

Lines changed: 20 additions & 3 deletions
@@ -1246,6 +1246,8 @@ def update(self,
            decode steps"""
         assert len(seq_group_metadata_list) == len(hidden_states)
         self._seq_ids.extend(get_all_seq_ids(seq_group_metadata_list))
+        if self.seq_group_metadata_list is not None:
+            self.seq_group_metadata_list.extend(seq_group_metadata_list)
         self.hidden_states = torch.cat([self.hidden_states, hidden_states])

         if self.second_last_token_hidden_states is not None:
@@ -1270,6 +1272,10 @@ def prune(self,
             # Batch contents changed - prune removed sequences.
             index = [self._seq_ids.index(seq_id) for seq_id in seq_ids]
             self.hidden_states = self.hidden_states[index]
+            if self.seq_group_metadata_list is not None:
+                self.seq_group_metadata_list = [
+                    self.seq_group_metadata_list[i] for i in index
+                ]
             if self.second_last_token_hidden_states is not None:
                 self.second_last_token_hidden_states = self\
                     .second_last_token_hidden_states[index]
@@ -1284,12 +1290,23 @@ def expand_with_bonus_tokens(
             return

         index = []
-        for seq_id in self._seq_ids:
-            i = self._seq_ids.index(seq_id)
+        expanded_seq_ids = []
+        expanded_seq_group_metadata_list = []
+        for i, seq_id in enumerate(self._seq_ids):
             if seq_id in seq_with_bonus_token_in_last_step:
                 index.append(i + len(self._seq_ids))
+                expanded_seq_ids.append(seq_id)
+                if self.seq_group_metadata_list is not None:
+                    expanded_seq_group_metadata_list.append(
+                        self.seq_group_metadata_list[i])
             index.append(i)
+            expanded_seq_ids.append(seq_id)
+            if self.seq_group_metadata_list is not None:
+                expanded_seq_group_metadata_list.append(
+                    self.seq_group_metadata_list[i])

+        self._seq_ids = expanded_seq_ids
+        self.seq_group_metadata_list = expanded_seq_group_metadata_list
         self.hidden_states = torch.cat(
             [self.hidden_states, self.second_last_token_hidden_states])[index]

@@ -1370,7 +1387,7 @@ def clone(
             virtual_engine=self.virtual_engine,
             num_lookahead_slots=self.num_lookahead_slots,
             running_queue_size=self.running_queue_size,
-            previous_hidden_states=self.previous_hidden_states,
+            previous_hidden_states=copy.copy(self.previous_hidden_states),
             num_steps=self.num_steps,
             finished_requests_ids=self.finished_requests_ids,
             last_sampled_token_ids=self.last_sampled_token_ids.clone()
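These HiddenStates changes keep an optional per-sequence metadata list aligned with _seq_ids and the hidden-state rows through update, prune, and bonus-token expansion. A stripped-down sketch of the same bookkeeping on plain Python lists (no torch, names simplified, not the real HiddenStates class):

class HiddenStatesSketch:
    # Simplified: hidden states are a list of rows instead of a tensor.
    def __init__(self, seq_ids, rows, metadata=None):
        self._seq_ids = list(seq_ids)
        self.rows = list(rows)
        self.metadata = list(metadata) if metadata is not None else None

    def prune(self, kept_seq_ids):
        # Keep only the rows (and metadata) of surviving sequences.
        index = [self._seq_ids.index(s) for s in kept_seq_ids]
        self.rows = [self.rows[i] for i in index]
        if self.metadata is not None:
            self.metadata = [self.metadata[i] for i in index]
        self._seq_ids = list(kept_seq_ids)

    def expand_with_bonus(self, seq_with_bonus):
        # Sequences that got a bonus token contribute an extra row; metadata
        # is duplicated in lockstep so all lists stay the same length.
        new_ids, new_rows, new_meta = [], [], []
        for i, seq_id in enumerate(self._seq_ids):
            if seq_id in seq_with_bonus:
                new_ids.append(seq_id)
                new_rows.append(self.rows[i])
                if self.metadata is not None:
                    new_meta.append(self.metadata[i])
            new_ids.append(seq_id)
            new_rows.append(self.rows[i])
            if self.metadata is not None:
                new_meta.append(self.metadata[i])
        self._seq_ids, self.rows = new_ids, new_rows
        if self.metadata is not None:
            self.metadata = new_meta

hs = HiddenStatesSketch([0, 1, 2], ["h0", "h1", "h2"], ["m0", "m1", "m2"])
hs.prune([0, 2])
hs.expand_with_bonus({2})
assert hs._seq_ids == [0, 2, 2] and hs.metadata == ["m0", "m2", "m2"]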

vllm/spec_decode/batch_expansion.py

Lines changed: 1 addition & 1 deletion
@@ -43,7 +43,7 @@ def score_proposals(
         self,
         execute_model_req: ExecuteModelRequest,
         proposals: SpeculativeProposals,
-    ) -> Optional[SpeculativeScores]:
+    ) -> SpeculativeScores:
         """Score the proposed tokens via the scorer model.

         This converts each input sequence to a set of k+1 target sequences. The

vllm/spec_decode/interfaces.py

Lines changed: 1 addition & 1 deletion
@@ -94,5 +94,5 @@ def score_proposals(
         self,
         execute_model_req: ExecuteModelRequest,
         proposals: SpeculativeProposals,
-    ) -> Optional[SpeculativeScores]:
+    ) -> SpeculativeScores:
         raise NotImplementedError

vllm/spec_decode/mqa_scorer.py

Lines changed: 37 additions & 5 deletions
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0

-from typing import Optional
-
+from vllm.distributed import get_pp_group
+from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.sequence import (ExecuteModelRequest, SequenceData,
                            SequenceGroupMetadata, get_all_seq_ids)
 from vllm.spec_decode.interfaces import (SpeculativeProposals,
@@ -17,7 +17,7 @@ def score_proposals(
         self,
         execute_model_req: ExecuteModelRequest,
         proposals: SpeculativeProposals,
-    ) -> Optional[SpeculativeScores]:
+    ) -> SpeculativeScores:
         target_seq_group_metadata_list = []
         target_seq_id_start = max(
             get_all_seq_ids(execute_model_req.seq_group_metadata_list)) + 1
@@ -70,8 +70,40 @@ def score_proposals(
                 seq_group_metadata_list=target_seq_group_metadata_list))

         target_sampler_output = target_sampler_output[0]
-        if target_sampler_output is None:
-            return None
+        if get_pp_group().is_last_rank:
+            assert len(
+                target_sampler_output) == 1, "expected single-step output"
+            target_sampler_output = target_sampler_output[0]
+            # Store hidden states from target model execution, BxD.
+            sampled_token_probs = target_sampler_output.sampled_token_probs
+            logprobs = target_sampler_output.logprobs
+            sampled_token_ids = target_sampler_output.sampled_token_ids
+            hidden_states = target_sampler_output.hidden_states
+            prefill_hidden_states = target_sampler_output.prefill_hidden_states
+            tensors = {
+                "sampled_token_probs": sampled_token_probs,
+                "logprobs": logprobs,
+                "sampled_token_ids": sampled_token_ids,
+                "hidden_states": hidden_states,
+                "prefill_hidden_states": prefill_hidden_states
+            }
+            get_pp_group().broadcast_tensor_dict(
+                tensors, src=get_pp_group().world_size - 1)
+        else:
+            tensors = get_pp_group().broadcast_tensor_dict(
+                src=get_pp_group().world_size - 1)
+            sampled_token_probs = tensors["sampled_token_probs"]
+            logprobs = tensors["logprobs"]
+            sampled_token_ids = tensors["sampled_token_ids"]
+            hidden_states = tensors["hidden_states"]
+            prefill_hidden_states = tensors["prefill_hidden_states"]
+            target_sampler_output = SamplerOutput(
+                outputs=None,
+                sampled_token_probs=sampled_token_probs,
+                logprobs=logprobs,
+                sampled_token_ids=sampled_token_ids,
+                hidden_states=hidden_states,
+                prefill_hidden_states=prefill_hidden_states)

         k = execute_model_req.num_lookahead_slots
         bs = len(execute_model_req.seq_group_metadata_list)
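With pipeline parallelism only the last PP rank actually samples, so the scorer packs the sampler output tensors into a dict, broadcasts it from the last rank, and the other ranks rebuild an equivalent SamplerOutput from the received tensors. A single-process sketch of that pack/rebuild step; the broadcast itself is stubbed out (in the diff it is get_pp_group().broadcast_tensor_dict) and the dataclass is a stand-in, not vLLM's SamplerOutput:

import torch
from dataclasses import dataclass
from typing import Dict, Optional

@dataclass
class SamplerOutputSketch:
    # Minimal stand-in for the SamplerOutput fields used by the scorer.
    sampled_token_probs: torch.Tensor
    logprobs: torch.Tensor
    sampled_token_ids: torch.Tensor
    hidden_states: Optional[torch.Tensor] = None
    prefill_hidden_states: Optional[torch.Tensor] = None

def pack(output: SamplerOutputSketch) -> Dict[str, Optional[torch.Tensor]]:
    # Last PP rank: flatten the fields into a tensor dict for broadcast.
    return {
        "sampled_token_probs": output.sampled_token_probs,
        "logprobs": output.logprobs,
        "sampled_token_ids": output.sampled_token_ids,
        "hidden_states": output.hidden_states,
        "prefill_hidden_states": output.prefill_hidden_states,
    }

def unpack(tensors: Dict[str, Optional[torch.Tensor]]) -> SamplerOutputSketch:
    # Non-last PP ranks: rebuild an equivalent output from the received dict.
    return SamplerOutputSketch(**tensors)

last_rank_output = SamplerOutputSketch(
    sampled_token_probs=torch.rand(2, 32000),
    logprobs=torch.rand(2, 32000),
    sampled_token_ids=torch.tensor([[1], [2]]))
received = pack(last_rank_output)   # stands in for the broadcast step
rebuilt = unpack(received)
assert torch.equal(rebuilt.sampled_token_ids, last_rank_output.sampled_token_ids)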

vllm/spec_decode/spec_decode_worker.py

Lines changed: 30 additions & 21 deletions
@@ -337,7 +337,7 @@ def __init__(

         # Hidden states from target model to pass to proposer
         # in the subsequent step.
-        self.previous_hidden_states: Optional[HiddenStates] = None
+        self.previous_hidden_states: Dict[int, Optional[HiddenStates]] = {}
         self._disable_logprobs = disable_logprobs
         self._disable_log_stats = disable_log_stats
         self._num_spec_prefill_steps = num_spec_prefill_steps
@@ -374,11 +374,13 @@ def init_device(self) -> None:
             self.proposer_worker.maybe_load_lm_head_weight(
                 target_lm_head_weight)

-        self._metrics.init_tensors(self.rank, device_type=self.device)
         if model_parallel_is_initialized():
+            self._metrics.init_tensors(get_tp_group().rank_in_group,
+                                       device_type=self.device)
             self.spec_decode_sampler.init_tensors(get_tp_group().local_rank,
                                                   device_type=self.device)
         else:
+            self._metrics.init_tensors(self.rank, device_type=self.device)
             self.spec_decode_sampler.init_tensors(self.rank,
                                                   device_type=self.device)

@@ -467,7 +469,9 @@ def execute_model(
     ) -> List[SamplerOutput]:
         """Perform speculative decoding on the input batch.
         """
-        if self.rank % self.tensor_parallel_size != self._driver_rank:
+        rank = get_tp_group().rank_in_group if model_parallel_is_initialized(
+        ) else self.rank
+        if rank != self._driver_rank:
             self._run_non_driver_rank()
             return []

@@ -721,14 +725,19 @@ def _run_no_spec(self, execute_model_req: ExecuteModelRequest,
                 hidden_states = hidden_states[
                     torch.where(sampler_output.sampled_token_ids -
                                 VLLM_INVALID_TOKEN_ID)[0]]
-            if self.previous_hidden_states is None and len(
-                    seq_group_meta_with_hidden):
-                self.previous_hidden_states = HiddenStates(
-                    hidden_states, seq_group_meta_with_hidden)
-            elif self.previous_hidden_states and len(
-                    seq_group_meta_with_hidden):
-                self.previous_hidden_states.update(hidden_states,
-                                                   seq_group_meta_with_hidden)
+            if execute_model_req.virtual_engine not in \
+                    self.previous_hidden_states and \
+                    len(seq_group_meta_with_hidden):
+                self.previous_hidden_states[
+                    execute_model_req.virtual_engine] = HiddenStates(
+                        hidden_states, seq_group_meta_with_hidden)
+            elif execute_model_req.virtual_engine in \
+                    self.previous_hidden_states and \
+                    len(seq_group_meta_with_hidden):
+                previous_hidden_states: HiddenStates = \
+                    self.previous_hidden_states[execute_model_req.virtual_engine]
+                previous_hidden_states.update(hidden_states,
+                                              seq_group_meta_with_hidden)

         if not skip_proposer:
             # We prepare the prefill hidden states here so that there no
@@ -804,17 +813,15 @@ def _run_speculative_decoding_step(
         Returns a list of SamplerOutput, each containing a single token per
         sequence.
         """
-        if self.previous_hidden_states is not None:
-            self.previous_hidden_states.seq_group_metadata_list = execute_model_req.seq_group_metadata_list
         if get_pp_group().is_first_rank:
             # With prefill chunking, expect requests to have prompts first
             # so that backend gets prefill|decode.
             assert num_lookahead_slots == execute_model_req.num_lookahead_slots

             # Pass last hidden states from target model to proposer
             execute_model_req.previous_hidden_states = \
-                self.previous_hidden_states
-            self.previous_hidden_states = None
+                self.previous_hidden_states[execute_model_req.virtual_engine]
+            self.previous_hidden_states.pop(execute_model_req.virtual_engine)

         with Timer() as proposal_timer:
             # Generate proposals using draft worker.
@@ -883,8 +890,8 @@ def _run_speculative_decoding_step(

         with Timer() as verification_timer:
             accepted_token_ids, target_logprobs = self._verify_tokens(
-                execute_model_req.seq_group_metadata_list, proposal_scores,
-                proposals, execute_model_req.num_lookahead_slots)
+                execute_model_req, proposal_scores, proposals,
+                execute_model_req.num_lookahead_slots)

         stage_times = (proposal_execute_time, scoring_timer.elapsed_time_ms,
                        verification_timer.elapsed_time_ms)
@@ -901,7 +908,7 @@ def _run_speculative_decoding_step(
     @nvtx_range("spec_decode_worker._verify_tokens")
     def _verify_tokens(
         self,
-        seq_group_metadata_list: List[SequenceGroupMetadata],
+        execute_model_req: ExecuteModelRequest,
         proposal_scores: SpeculativeScores,
         proposals: SpeculativeProposals,
         max_proposal_len: int,
@@ -912,6 +919,7 @@ def _verify_tokens(
         Returns a tuple of Tensors, one for the accepted token ids and one for
         the logprobs according to the scoring model.
         """
+        seq_group_metadata_list = execute_model_req.seq_group_metadata_list
         proposal_lens_list = proposals.proposal_lens.tolist()

         # vLLM currently only supports proposal lens equal to zero or the batch
@@ -991,9 +999,10 @@ def _verify_tokens(
             second_last_token_hidden_states = hidden_states[:, -2]  # b x d
             hidden_states = hidden_states.gather(1, index).squeeze(1)  # b x d
             # Store hidden states from target model for subsequent decode step
-            self.previous_hidden_states = HiddenStates(
-                hidden_states, terminal_metadata,
-                second_last_token_hidden_states)
+            self.previous_hidden_states[
+                execute_model_req.virtual_engine] = HiddenStates(
+                    hidden_states, terminal_metadata,
+                    second_last_token_hidden_states)
         return accepted_token_ids, logprobs

     def _create_output_sampler_list(
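The worker previously kept a single previous_hidden_states attribute, which pipeline-parallel execution would clobber across virtual engines; it is now a dict keyed by execute_model_req.virtual_engine, written after verification and consumed before the next proposal. A stripped-down sketch of that store/consume cycle; the class and method names are illustrative and the sketch uses pop with a default rather than the index-then-pop in the diff:

from typing import Any, Dict

class SpecDecodeWorkerSketch:
    def __init__(self) -> None:
        # One entry per virtual engine instead of a single shared attribute.
        self.previous_hidden_states: Dict[int, Any] = {}

    def after_verify(self, virtual_engine: int, hidden_states: Any) -> None:
        # Store the target-model hidden states for this virtual engine only.
        self.previous_hidden_states[virtual_engine] = hidden_states

    def before_propose(self, virtual_engine: int) -> Any:
        # Hand the stored hidden states to the proposer and clear the slot.
        return self.previous_hidden_states.pop(virtual_engine, None)

worker = SpecDecodeWorkerSketch()
worker.after_verify(virtual_engine=0, hidden_states="h0")
worker.after_verify(virtual_engine=1, hidden_states="h1")
assert worker.before_propose(0) == "h0"   # engine 1's state is untouched
assert worker.before_propose(1) == "h1"
assert worker.before_propose(0) is None   # already consumed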
