
Commit ecd9f20

[Spec Decode] Make speculative decoding compatible with pipeline parallelism

Signed-off-by: Xin Yang <xyangx@amazon.com>

1 parent 374ee28 commit ecd9f20

16 files changed: +346 −95 lines
Lines changed: 125 additions & 0 deletions
@@ -0,0 +1,125 @@
# SPDX-License-Identifier: Apache-2.0
"""Tests which cover integration of the speculative decoding framework with
pipeline parallelism.
"""

from typing import Optional

import pytest
import torch

from vllm.platforms import current_platform

from .conftest import run_equality_correctness_test_tp


@pytest.mark.skipif(torch.cuda.device_count() < 2,
                    reason="Need at least 2 GPUs to run the test.")
@pytest.mark.parametrize(
    "common_llm_kwargs",
    [[
        # Skip cuda graph recording for fast test.
        "--enforce-eager",
        "--pipeline-parallel-size",
        "2",

        # precision
        "--dtype",
        "bfloat16",
    ]])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [[]])
@pytest.mark.parametrize("baseline_llm_kwargs", [[]])
@pytest.mark.parametrize("model, test_llm_kwargs",
                         [("JackFram/llama-68m", [
                             "--speculative-model",
                             "JackFram/llama-68m",
                             "--num-speculative-tokens",
                             "5",
                             "--speculative-draft-pipeline-parallel-size",
                             "1",
                         ]),
                          ("ibm-granite/granite-3b-code-instruct", [
                              "--speculative-model",
                              "ibm-granite/granite-3b-code-instruct",
                              "--num-speculative-tokens",
                              "5",
                              "--speculative-draft-pipeline-parallel-size",
                              "1",
                          ])])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize("seed", [1])
def test_draft_model_pp_lt_target_model_pp2(model, common_llm_kwargs,
                                            per_test_common_llm_kwargs,
                                            baseline_llm_kwargs,
                                            test_llm_kwargs, batch_size: int,
                                            seed: int):
    """Verify spec decode works well with smaller pp for draft models.
    """
    if current_platform.is_rocm():
        pytest.skip("hip is not well-supported yet")
    run_equality_correctness_test_tp(model,
                                     common_llm_kwargs,
                                     per_test_common_llm_kwargs,
                                     baseline_llm_kwargs,
                                     test_llm_kwargs,
                                     batch_size,
                                     max_output_len=32,
                                     seed=seed,
                                     temperature=0.0)


@pytest.mark.skipif(torch.cuda.device_count() < 2,
                    reason="Need at least 2 GPUs to run the test.")
@pytest.mark.parametrize(
    "common_llm_kwargs",
    [[
        # Skip cuda graph recording for fast test.
        "--enforce-eager",
        "--pipeline-parallel-size",
        "2",

        # precision
        "--dtype",
        "bfloat16",
    ]])
@pytest.mark.parametrize(
    "per_test_common_llm_kwargs",
    [["--enable-chunked-prefill", "False"],
     [
         "--enable-chunked-prefill", "True", "--max-num-batched-tokens", "4",
         "--max-num-seqs", "4"
     ]])
@pytest.mark.parametrize("baseline_llm_kwargs", [[]])
@pytest.mark.parametrize("model, test_llm_kwargs",
                         [("JackFram/llama-68m", [
                             "--speculative-model",
                             "JackFram/llama-68m",
                             "--num-speculative-tokens",
                             "3",
                             "--speculative-draft-pipeline-parallel-size",
                             "1",
                         ])])
@pytest.mark.parametrize("logprobs", [None, 2])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_chunked_prefill_pp2(model, common_llm_kwargs,
                                         per_test_common_llm_kwargs,
                                         baseline_llm_kwargs, test_llm_kwargs,
                                         logprobs: Optional[int],
                                         batch_size: int, seed: int):
    """Verify spec decode works well with same and different PP size for
    the draft model with chunked prefill.
    """
    if logprobs:
        test_llm_kwargs.extend(
            ["--disable-logprobs-during-spec-decoding", "False"])
    run_equality_correctness_test_tp(model,
                                     common_llm_kwargs,
                                     per_test_common_llm_kwargs,
                                     baseline_llm_kwargs,
                                     test_llm_kwargs,
                                     batch_size,
                                     max_output_len=32,
                                     seed=seed,
                                     temperature=0.0,
                                     logprobs=logprobs)
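For readers unfamiliar with the harness: run_equality_correctness_test_tp lives in the local conftest and is not part of this diff. A minimal sketch of what such an equality check is assumed to do — run the baseline and the speculative configuration on the same prompts with greedy sampling and require identical text — using vLLM's offline API (the helper below is hypothetical, not the real conftest code):

# Illustrative only: the real harness parses CLI-style kwargs and manages
# processes; this sketch shows just the equality contract being tested.
from vllm import LLM, SamplingParams

def equality_check_sketch(model: str, baseline_kwargs: dict,
                          test_kwargs: dict, prompts: list,
                          max_output_len: int = 32) -> None:
    greedy = SamplingParams(temperature=0.0, max_tokens=max_output_len)

    baseline = LLM(model=model, **baseline_kwargs)
    expected = [o.outputs[0].text for o in baseline.generate(prompts, greedy)]
    del baseline  # release GPU memory before building the second engine

    test = LLM(model=model, **test_kwargs)
    actual = [o.outputs[0].text for o in test.generate(prompts, greedy)]

    # Rejection-sampling spec decode is lossless under greedy sampling, so
    # the speculative engine must reproduce the baseline exactly.
    assert actual == expected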

tests/spec_decode/test_multi_step_worker.py

Lines changed: 13 additions & 12 deletions
@@ -12,7 +12,7 @@
 from vllm.model_executor.utils import set_random_seed
 from vllm.sequence import (ExecuteModelRequest, HiddenStates, Logprob,
                            get_all_seq_ids)
-from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
+from vllm.spec_decode.draft_model_runner import TP1PP1DraftModelRunner
 from vllm.spec_decode.multi_step_worker import MultiStepWorker
 from vllm.spec_decode.top1_proposer import Top1Proposer
 from vllm.worker.worker import Worker
@@ -91,7 +91,7 @@ def test_same_output_for_single_step():
         block_size,
         num_gpu_blocks,
         seed,
-        model_runner_cls=TP1DraftModelRunner,
+        model_runner_cls=TP1PP1DraftModelRunner,
     )
     worker = create_worker(
         Worker,
@@ -304,7 +304,7 @@ def test_multi_step_with_batch_expansion_correct_output():
         block_size,
         num_gpu_blocks,
         seed,
-        model_runner_cls=TP1DraftModelRunner,
+        model_runner_cls=TP1PP1DraftModelRunner,
     )
     multi_step_worker.set_include_gpu_probs_tensor()
     worker = create_worker(
@@ -399,7 +399,7 @@ def test_multi_step_with_batch_expansion_incorrect_output():
         block_size,
         num_gpu_blocks,
         seed,
-        model_runner_cls=TP1DraftModelRunner,
+        model_runner_cls=TP1PP1DraftModelRunner,
     )
     multi_step_worker.set_include_gpu_probs_tensor()
     worker = create_worker(
@@ -502,13 +502,14 @@ def test_multi_step_correct_kvcache(num_steps, attn_backend):

     with global_force_attn_backend_context_manager(attn_backend):
         dtype = 'float16' if attn_backend == _Backend.FLASH_ATTN else 'float32'
-        multi_step_worker = create_worker(MultiStepWorker,
-                                          model_name,
-                                          block_size,
-                                          num_gpu_blocks,
-                                          seed,
-                                          model_runner_cls=TP1DraftModelRunner,
-                                          dtype=dtype)
+        multi_step_worker = create_worker(
+            MultiStepWorker,
+            model_name,
+            block_size,
+            num_gpu_blocks,
+            seed,
+            model_runner_cls=TP1PP1DraftModelRunner,
+            dtype=dtype)
         multi_step_worker.set_include_gpu_probs_tensor()
         worker = create_worker(Worker,
                                model_name,
@@ -771,7 +772,7 @@ def test_use_draft_model_runner_advance_step():
         block_size,
         num_gpu_blocks,
         seed,
-        model_runner_cls=TP1DraftModelRunner,
+        model_runner_cls=TP1PP1DraftModelRunner,
     )

     # Mock "_gpu_advance_step" to raise an exception when called.
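(The rename from TP1DraftModelRunner to TP1PP1DraftModelRunner reflects that this optimized draft-model runner is only used when the draft model runs with both tensor parallel size 1 and, now, pipeline parallel size 1.)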

tests/spec_decode/test_spec_decode_worker.py

Lines changed: 2 additions & 2 deletions
@@ -12,7 +12,7 @@
 from vllm.model_executor.utils import set_random_seed
 from vllm.sequence import ExecuteModelRequest, SequenceOutput
 from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
-from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
+from vllm.spec_decode.draft_model_runner import TP1PP1DraftModelRunner
 from vllm.spec_decode.interfaces import SpeculativeProposals
 from vllm.spec_decode.metrics import (AsyncMetricsCollector,
                                       SpecDecodeWorkerMetrics)
@@ -929,7 +929,7 @@ def test_correctly_load_weight_for_eagle():
         block_size,
         num_gpu_blocks,
         seed,
-        model_runner_cls=TP1DraftModelRunner,
+        model_runner_cls=TP1PP1DraftModelRunner,
     )

     spec_decode_sampler = mock_spec_decode_sampler("rejection_sampler")

vllm/config.py

Lines changed: 34 additions & 3 deletions
@@ -1855,6 +1855,7 @@ def maybe_create_spec_config(
         speculative_model: Optional[str],
         speculative_model_quantization: Optional[str],
         speculative_draft_tensor_parallel_size: Optional[int],
+        speculative_draft_pipeline_parallel_size: Optional[int],
         num_speculative_tokens: Optional[int],
         speculative_disable_mqa_scorer: Optional[bool],
         speculative_max_model_len: Optional[int],
@@ -1887,6 +1888,8 @@ def maybe_create_spec_config(
                 None, we assume the model weights are not quantized.
             speculative_draft_tensor_parallel_size (Optional[int]): The degree
                 of the tensor parallelism for the draft model.
+            speculative_draft_pipeline_parallel_size (Optional[int]): The degree
+                of the pipeline parallelism for the draft model.
             num_speculative_tokens (Optional[int]): The number of speculative
                 tokens, if provided. Will default to the number in the draft
                 model config if present, otherwise is required.
@@ -2029,6 +2032,12 @@ def maybe_create_spec_config(
                     speculative_draft_tensor_parallel_size,
                     draft_hf_config
                 )
+            speculative_draft_pipeline_parallel_size = \
+                SpeculativeConfig._verify_and_get_draft_model_pipeline_parallel_size(
+                    target_parallel_config,
+                    speculative_draft_pipeline_parallel_size,
+                    draft_hf_config
+                )

             draft_model_config.max_model_len = (
                 SpeculativeConfig._maybe_override_draft_max_model_len(
@@ -2040,7 +2049,8 @@ def maybe_create_spec_config(
             draft_parallel_config = (
                 SpeculativeConfig.create_draft_parallel_config(
                     target_parallel_config,
-                    speculative_draft_tensor_parallel_size, draft_hf_config))
+                    speculative_draft_tensor_parallel_size,
+                    speculative_draft_pipeline_parallel_size, draft_hf_config))

             if num_speculative_tokens is None:
                 raise ValueError(
@@ -2136,19 +2146,40 @@ def _verify_and_get_draft_model_tensor_parallel_size(
                 f"other value than 1 or target model tensor_parallel_size")
         return speculative_draft_tensor_parallel_size

+    @staticmethod
+    def _verify_and_get_draft_model_pipeline_parallel_size(
+            target_parallel_config: ParallelConfig,
+            speculative_draft_pipeline_parallel_size: Optional[int],
+            draft_hf_config: PretrainedConfig) -> int:
+        """
+        Verifies and adjusts the pipeline parallel size for a draft model
+        specified using speculative_draft_pipeline_parallel_size.
+        """
+        # If speculative_draft_pipeline_parallel_size is unset then set it
+        # appropriately, else verify that it is set correctly.
+        if speculative_draft_pipeline_parallel_size is None:
+            speculative_draft_pipeline_parallel_size = \
+                target_parallel_config.pipeline_parallel_size
+        elif speculative_draft_pipeline_parallel_size not in (
+                1, target_parallel_config.pipeline_parallel_size):
+            raise ValueError(
+                f"{speculative_draft_pipeline_parallel_size=} cannot be "
+                f"other value than 1 or target model pipeline_parallel_size")
+        return speculative_draft_pipeline_parallel_size
+
     @staticmethod
     def create_draft_parallel_config(
         target_parallel_config: ParallelConfig,
         speculative_draft_tensor_parallel_size: int,
+        speculative_draft_pipeline_parallel_size: int,
         draft_hf_config: PretrainedConfig,
     ) -> ParallelConfig:
         """Create a parallel config for use by the draft worker.

         This is mostly a copy of the target parallel config, except the tp_size.
         """
         draft_parallel_config = ParallelConfig(
-            pipeline_parallel_size=target_parallel_config.
-            pipeline_parallel_size,
+            pipeline_parallel_size=speculative_draft_pipeline_parallel_size,
             tensor_parallel_size=speculative_draft_tensor_parallel_size,
             distributed_executor_backend=target_parallel_config.
             distributed_executor_backend,
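The validation rule added above is compact enough to restate in isolation: leaving the draft pp size unset inherits the target's pp size, while an explicit value must be either 1 or exactly the target's pp size. A standalone sketch of the same rule (hypothetical function name, plain ints instead of ParallelConfig):

from typing import Optional

def resolve_draft_pp_size(target_pp_size: int,
                          draft_pp_size: Optional[int]) -> int:
    """Mirror of the check in SpeculativeConfig (illustration only)."""
    if draft_pp_size is None:
        return target_pp_size  # unset: inherit the target's pp degree
    if draft_pp_size not in (1, target_pp_size):
        raise ValueError(
            f"{draft_pp_size=} cannot be other value than 1 or "
            f"target model pipeline_parallel_size")
    return draft_pp_size

assert resolve_draft_pp_size(2, None) == 2  # unset -> inherit
assert resolve_draft_pp_size(2, 1) == 1     # collapse draft onto one stage
assert resolve_draft_pp_size(2, 2) == 2     # match the target
# resolve_draft_pp_size(2, 3) raises ValueError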

vllm/distributed/parallel_state.py

Lines changed: 14 additions & 2 deletions
@@ -1024,11 +1024,13 @@ def model_parallel_is_initialized():


 _TP_STATE_PATCHED = False
+_PP_STATE_PATCHED = False


 @contextmanager
-def patch_tensor_parallel_group(tp_group: GroupCoordinator):
-    """Patch the tp group temporarily until this function ends.
+def patch_model_parallel_group(tp_group: GroupCoordinator,
+                               pp_group: GroupCoordinator):
+    """Patch the tp and pp group temporarily until this function ends.

     This method is for draft workers of speculative decoding to run draft model
     with different tp degree from that of target model workers.
@@ -1039,16 +1041,26 @@ def patch_tensor_parallel_group(tp_group: GroupCoordinator):
     global _TP_STATE_PATCHED
     assert not _TP_STATE_PATCHED, "Should not call when it's already patched"

+    global _PP_STATE_PATCHED
+    assert not _PP_STATE_PATCHED, "Should not call when it's already patched"
+
     _TP_STATE_PATCHED = True
     old_tp_group = get_tp_group()
     global _TP
     _TP = tp_group
+
+    _PP_STATE_PATCHED = True
+    old_pp_group = get_pp_group()
+    global _PP
+    _PP = pp_group
     try:
         yield
     finally:
         # restore the original state
         _TP_STATE_PATCHED = False
         _TP = old_tp_group
+        _PP_STATE_PATCHED = False
+        _PP = old_pp_group


 def get_tensor_model_parallel_world_size():
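A sketch of how a draft worker is assumed to use the widened context manager (the class and method names below are illustrative, not part of this commit): it builds GroupCoordinators spanning only the draft model's ranks once, then installs them just for the duration of each draft forward pass so collectives inside the draft model resolve to the smaller tp/pp world.

# Hypothetical sketch; real usage lives in vLLM's smaller-TP proposer worker.
from vllm.distributed.parallel_state import patch_model_parallel_group

class DraftWorkerSketch:

    def __init__(self, draft_tp_group, draft_pp_group):
        # GroupCoordinators covering only the ranks the draft model uses.
        self._tp_group = draft_tp_group
        self._pp_group = draft_pp_group

    def execute_model(self, execute_model_req):
        # Inside the context, get_tp_group()/get_pp_group() return the draft
        # groups; the target model's groups are restored on exit even if the
        # draft forward pass raises.
        with patch_model_parallel_group(self._tp_group, self._pp_group):
            return self._run_draft(execute_model_req)

    def _run_draft(self, execute_model_req):
        raise NotImplementedError  # stand-in for the draft forward pass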

vllm/engine/arg_utils.py

Lines changed: 10 additions & 0 deletions
@@ -181,6 +181,7 @@ class EngineArgs:
     speculative_model: Optional[str] = None
     speculative_model_quantization: Optional[str] = None
     speculative_draft_tensor_parallel_size: Optional[int] = None
+    speculative_draft_pipeline_parallel_size: Optional[int] = None
     num_speculative_tokens: Optional[int] = None
     speculative_disable_mqa_scorer: Optional[bool] = False
     speculative_max_model_len: Optional[int] = None
@@ -812,6 +813,13 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
             default=EngineArgs.speculative_draft_tensor_parallel_size,
             help='Number of tensor parallel replicas for '
             'the draft model in speculative decoding.')
+        parser.add_argument(
+            '--speculative-draft-pipeline-parallel-size',
+            '-spec-draft-pp',
+            type=int,
+            default=EngineArgs.speculative_draft_pipeline_parallel_size,
+            help='Number of pipeline parallel replicas for '
+            'the draft model in speculative decoding.')

         parser.add_argument(
             '--speculative-max-model-len',
@@ -1266,6 +1274,8 @@ def create_engine_config(
             self.speculative_model_quantization,
             speculative_draft_tensor_parallel_size = \
                 self.speculative_draft_tensor_parallel_size,
+            speculative_draft_pipeline_parallel_size = \
+                self.speculative_draft_pipeline_parallel_size,
             num_speculative_tokens=self.num_speculative_tokens,
             speculative_disable_mqa_scorer=self.speculative_disable_mqa_scorer,
             speculative_disable_by_batch_size=self.
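Since the offline LLM entry point forwards keyword arguments through EngineArgs, the new flag should also be reachable without the CLI. A sketch mirroring the test configuration above (assuming, as usual, that the kwarg names match the CLI flags with dashes replaced by underscores):

from vllm import LLM, SamplingParams

# Target model on 2 pipeline stages; draft model collapsed onto one stage
# (pp=1), matching --speculative-draft-pipeline-parallel-size 1 in the tests.
llm = LLM(
    model="JackFram/llama-68m",
    pipeline_parallel_size=2,
    speculative_model="JackFram/llama-68m",
    num_speculative_tokens=5,
    speculative_draft_pipeline_parallel_size=1,
    enforce_eager=True,
    dtype="bfloat16",
)

outputs = llm.generate(["The capital of France is"],
                       SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)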

vllm/model_executor/models/eagle.py

Lines changed: 2 additions & 1 deletion
@@ -16,6 +16,7 @@
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors

+from .interfaces import SupportsPP
 from .utils import maybe_prefix

 logger = init_logger(__name__)
@@ -41,7 +42,7 @@ def forward(self, x, residual):
         return x + residual, None


-class EAGLE(nn.Module):
+class EAGLE(nn.Module, SupportsPP):
     """This class implements the EAGLE draft model from the paper: https://arxiv.org/pdf/2401.15077
     Reference implementation: https://github.com/SafeAILab/EAGLE
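Declaring SupportsPP marks the EAGLE draft head as pipeline-parallel capable so the runner will route IntermediateTensors through it. Roughly, the contract that interface implies looks like the sketch below (illustrative shape only; the helper names are hypothetical and this is not the EAGLE forward from this file):

from vllm.sequence import IntermediateTensors

class PPCapableSketch:
    """Rough shape of a forward pass that satisfies SupportsPP."""

    def forward(self, input_ids, positions, intermediate_tensors=None):
        if intermediate_tensors is not None:
            # Non-first stages resume from the previous stage's activations.
            hidden = intermediate_tensors["hidden_states"]
        else:
            hidden = self.embed(input_ids)  # hypothetical helper
        hidden = self.run_layers(hidden, positions)  # hypothetical helper
        if not self.is_last_stage:  # hypothetical flag
            # Hand activations off to the next pipeline stage.
            return IntermediateTensors({"hidden_states": hidden})
        return hidden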