Commit 78f283f

[Spec Decode] Make speculative decoding compatible with pipeline parallelism
Signed-off-by: Xin Yang <xyangx@amazon.com>
1 parent 038bede commit 78f283f

20 files changed: +383 -110 lines
Lines changed: 124 additions & 0 deletions
@@ -0,0 +1,124 @@
# SPDX-License-Identifier: Apache-2.0
"""Tests which cover integration of the speculative decoding framework with
pipeline parallelism.
"""

from typing import Optional

import pytest
import torch

from vllm.platforms import current_platform

from .conftest import run_equality_correctness_test_tp


@pytest.mark.skipif(torch.cuda.device_count() < 2,
                    reason="Need at least 2 GPUs to run the test.")
@pytest.mark.parametrize(
    "common_llm_kwargs",
    [[
        # Skip cuda graph recording for fast test.
        "--enforce-eager",
        "--pipeline-parallel-size",
        "2",

        # precision
        "--dtype",
        "bfloat16",
    ]])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [[]])
@pytest.mark.parametrize("baseline_llm_kwargs", [[]])
@pytest.mark.parametrize("model, test_llm_kwargs",
                         [("JackFram/llama-68m", [
                             "--speculative-model",
                             "JackFram/llama-68m",
                             "--num-speculative-tokens",
                             "5",
                             "--speculative-draft-pipeline-parallel-size",
                             "1",
                         ]),
                          ("ibm-granite/granite-3b-code-instruct", [
                              "--speculative-model",
                              "ibm-granite/granite-3b-code-instruct",
                              "--num-speculative-tokens",
                              "5",
                              "--speculative-draft-pipeline-parallel-size",
                              "1",
                          ])])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize("seed", [1])
def test_draft_model_pp_lt_target_model_pp2(model, common_llm_kwargs,
                                            per_test_common_llm_kwargs,
                                            baseline_llm_kwargs,
                                            test_llm_kwargs, batch_size: int,
                                            seed: int):
    """Verify spec decode works well with smaller pp for draft models.
    """
    if current_platform.is_rocm():
        pytest.skip("hip is not well-supported yet")
    run_equality_correctness_test_tp(model,
                                     common_llm_kwargs,
                                     per_test_common_llm_kwargs,
                                     baseline_llm_kwargs,
                                     test_llm_kwargs,
                                     batch_size,
                                     max_output_len=32,
                                     seed=seed,
                                     temperature=0.0)


@pytest.mark.skipif(torch.cuda.device_count() < 2,
                    reason="Need at least 2 GPUs to run the test.")
@pytest.mark.parametrize(
    "common_llm_kwargs",
    [[
        # Skip cuda graph recording for fast test.
        "--enforce-eager",
        "--pipeline-parallel-size",
        "2",

        # precision
        "--dtype",
        "bfloat16",
    ]])
@pytest.mark.parametrize(
    "per_test_common_llm_kwargs",
    [["--enable-chunked-prefill", "False"],
     [
         "--enable-chunked-prefill", "True", "--max-num-batched-tokens", "4",
         "--max-num-seqs", "4"
     ]])
@pytest.mark.parametrize("baseline_llm_kwargs", [[]])
@pytest.mark.parametrize("model, test_llm_kwargs", [("JackFram/llama-68m", [
    "--speculative-model",
    "JackFram/llama-68m",
    "--num-speculative-tokens",
    "3",
    "--speculative-draft-pipeline-parallel-size",
    "1",
])])
@pytest.mark.parametrize("logprobs", [None, 2])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_chunked_prefill_pp2(model, common_llm_kwargs,
                                         per_test_common_llm_kwargs,
                                         baseline_llm_kwargs, test_llm_kwargs,
                                         logprobs: Optional[int],
                                         batch_size: int, seed: int):
    """Verify spec decode works well with same and different PP size for
    the draft model with chunked prefill.
    """
    if logprobs:
        test_llm_kwargs.extend(
            ["--disable-logprobs-during-spec-decoding", "False"])
    run_equality_correctness_test_tp(model,
                                     common_llm_kwargs,
                                     per_test_common_llm_kwargs,
                                     baseline_llm_kwargs,
                                     test_llm_kwargs,
                                     batch_size,
                                     max_output_len=32,
                                     seed=seed,
                                     temperature=0.0,
                                     logprobs=logprobs)

tests/spec_decode/test_multi_step_worker.py

Lines changed: 13 additions & 12 deletions
@@ -12,7 +12,7 @@
 from vllm.model_executor.utils import set_random_seed
 from vllm.sequence import (ExecuteModelRequest, HiddenStates, Logprob,
                            get_all_seq_ids)
-from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
+from vllm.spec_decode.draft_model_runner import TP1PP1DraftModelRunner
 from vllm.spec_decode.multi_step_worker import MultiStepWorker
 from vllm.spec_decode.top1_proposer import Top1Proposer
 from vllm.worker.worker import Worker
@@ -91,7 +91,7 @@ def test_same_output_for_single_step():
         block_size,
         num_gpu_blocks,
         seed,
-        model_runner_cls=TP1DraftModelRunner,
+        model_runner_cls=TP1PP1DraftModelRunner,
     )
     worker = create_worker(
         Worker,
@@ -304,7 +304,7 @@ def test_multi_step_with_batch_expansion_correct_output():
         block_size,
         num_gpu_blocks,
         seed,
-        model_runner_cls=TP1DraftModelRunner,
+        model_runner_cls=TP1PP1DraftModelRunner,
     )
     multi_step_worker.set_include_gpu_probs_tensor()
     worker = create_worker(
@@ -399,7 +399,7 @@ def test_multi_step_with_batch_expansion_incorrect_output():
         block_size,
         num_gpu_blocks,
         seed,
-        model_runner_cls=TP1DraftModelRunner,
+        model_runner_cls=TP1PP1DraftModelRunner,
     )
     multi_step_worker.set_include_gpu_probs_tensor()
     worker = create_worker(
@@ -502,13 +502,14 @@ def test_multi_step_correct_kvcache(num_steps, attn_backend):
 
     with global_force_attn_backend_context_manager(attn_backend):
         dtype = 'float16' if attn_backend == _Backend.FLASH_ATTN else 'float32'
-        multi_step_worker = create_worker(MultiStepWorker,
-                                          model_name,
-                                          block_size,
-                                          num_gpu_blocks,
-                                          seed,
-                                          model_runner_cls=TP1DraftModelRunner,
-                                          dtype=dtype)
+        multi_step_worker = create_worker(
+            MultiStepWorker,
+            model_name,
+            block_size,
+            num_gpu_blocks,
+            seed,
+            model_runner_cls=TP1PP1DraftModelRunner,
+            dtype=dtype)
         multi_step_worker.set_include_gpu_probs_tensor()
         worker = create_worker(Worker,
                                model_name,
@@ -771,7 +772,7 @@ def test_use_draft_model_runner_advance_step():
         block_size,
         num_gpu_blocks,
         seed,
-        model_runner_cls=TP1DraftModelRunner,
+        model_runner_cls=TP1PP1DraftModelRunner,
     )
 
     # Mock "_gpu_advance_step" to raise an exception when called.

tests/spec_decode/test_spec_decode_worker.py

Lines changed: 2 additions & 2 deletions
@@ -12,7 +12,7 @@
 from vllm.model_executor.utils import set_random_seed
 from vllm.sequence import ExecuteModelRequest, SequenceOutput
 from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
-from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
+from vllm.spec_decode.draft_model_runner import TP1PP1DraftModelRunner
 from vllm.spec_decode.interfaces import SpeculativeProposals
 from vllm.spec_decode.metrics import (AsyncMetricsCollector,
                                       SpecDecodeWorkerMetrics)
@@ -929,7 +929,7 @@ def test_correctly_load_weight_for_eagle():
         block_size,
         num_gpu_blocks,
         seed,
-        model_runner_cls=TP1DraftModelRunner,
+        model_runner_cls=TP1PP1DraftModelRunner,
     )
 
     spec_decode_sampler = mock_spec_decode_sampler("rejection_sampler")

vllm/attention/layer.py

Lines changed: 2 additions & 2 deletions
@@ -152,8 +152,8 @@ def __init__(
         # by bind_kv_cache
         # this variable will not be accessed if use_direct_call is True
         self.kv_cache = [
-            torch.tensor([]) for _ in range(get_current_vllm_config(
-            ).parallel_config.pipeline_parallel_size)
+            torch.tensor([]) for _ in range(
+                get_current_vllm_config().parallel_config.virtual_engine_size)
         ]
 
         self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT, dtype=torch.float32)

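The hunk above sizes the per-layer KV-cache placeholder list by virtual_engine_size instead of pipeline_parallel_size: the draft model may now run with a smaller PP degree than the target, but its parallel config copies the target's virtual_engine_size (see vllm/config.py below), so every layer still keeps one placeholder per target virtual engine. A minimal sketch of that sizing rule, using a toy config class rather than vLLM's ParallelConfig:

# Minimal sketch (toy config, not vLLM's ParallelConfig): one empty KV-cache
# placeholder per virtual engine, later filled in by bind_kv_cache.
from dataclasses import dataclass

import torch


@dataclass
class ToyParallelConfig:
    pipeline_parallel_size: int = 2  # PP degree of this worker's model
    virtual_engine_size: int = 2     # scheduler slots, taken from the target


def make_kv_cache_placeholders(parallel_config: ToyParallelConfig) -> list:
    # Mirrors the new sizing in vllm/attention/layer.py.
    return [
        torch.tensor([]) for _ in range(parallel_config.virtual_engine_size)
    ]


# A PP=1 draft model still allocates two placeholders when the target runs PP=2.
draft_config = ToyParallelConfig(pipeline_parallel_size=1, virtual_engine_size=2)
assert len(make_kv_cache_placeholders(draft_config)) == 2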
vllm/config.py

Lines changed: 35 additions & 3 deletions
@@ -1398,6 +1398,7 @@ class ParallelConfig:
     data_parallel_master_ip: str = "127.0.0.1"
     data_parallel_master_port: int = 29500  # Port of the data parallel master.
     enable_expert_parallel: bool = False  # Use EP instead of TP for MoE layers.
+    virtual_engine_size: int = 1  # Number of virtual engines.
 
     # Maximum number of multiple batches
     # when load model sequentially. To avoid RAM OOM when using tensor
@@ -1960,6 +1961,7 @@ class SpeculativeConfig:
     method: Optional[str] = None
     acceptance_method: str = "rejection_sampler"
     draft_tensor_parallel_size: Optional[int] = None
+    draft_pipeline_parallel_size: Optional[int] = None
     disable_logprobs: bool = True
 
     model: Optional[str] = None
@@ -2173,6 +2175,13 @@ def __post_init__(self):
                     self.draft_model_config.hf_config
                 )
 
+            self.draft_pipeline_parallel_size = \
+                SpeculativeConfig._verify_and_get_draft_pp(
+                    self.target_parallel_config,
+                    self.draft_pipeline_parallel_size,
+                    self.draft_model_config.hf_config
+                )
+
             self.draft_model_config.max_model_len = (
                 SpeculativeConfig._maybe_override_draft_max_model_len(
                     self.max_model_len,
@@ -2183,7 +2192,8 @@ def __post_init__(self):
             self.draft_parallel_config = (
                 SpeculativeConfig.create_draft_parallel_config(
                     self.target_parallel_config,
-                    self.draft_tensor_parallel_size))
+                    self.draft_tensor_parallel_size,
+                    self.draft_pipeline_parallel_size))
 
         if self.acceptance_method == "typical_acceptance_sampler":
             if self.posterior_threshold is None:
@@ -2257,19 +2267,41 @@ def _verify_and_get_draft_tp(
                 f"other value than 1 or target model tensor_parallel_size")
         return speculative_draft_tensor_parallel_size
 
+    @staticmethod
+    def _verify_and_get_draft_pp(
+            target_parallel_config: ParallelConfig,
+            speculative_draft_pipeline_parallel_size: Optional[int],
+            draft_hf_config: PretrainedConfig) -> int:
+        """
+        Verifies and adjusts the pipeline parallel size for a draft model
+        specified using speculative_draft_pipeline_parallel_size.
+        """
+        # If speculative_draft_pipeline_parallel_size is unset, default it to
+        # the target model's pipeline parallel size; otherwise verify that it
+        # is set to an allowed value.
+        if speculative_draft_pipeline_parallel_size is None:
+            speculative_draft_pipeline_parallel_size = \
+                target_parallel_config.pipeline_parallel_size
+        elif speculative_draft_pipeline_parallel_size not in (
+                1, target_parallel_config.pipeline_parallel_size):
+            raise ValueError(
+                f"{speculative_draft_pipeline_parallel_size=} cannot be "
+                f"other value than 1 or target model pipeline_parallel_size")
+        return speculative_draft_pipeline_parallel_size
+
     @staticmethod
     def create_draft_parallel_config(
         target_parallel_config: ParallelConfig,
         speculative_draft_tensor_parallel_size: int,
+        speculative_draft_pipeline_parallel_size: int,
     ) -> ParallelConfig:
         """Create a parallel config for use by the draft worker.
 
         This is mostly a copy of the target parallel config, except the tp_size.
         """
         draft_parallel_config = ParallelConfig(
-            pipeline_parallel_size=target_parallel_config.
-            pipeline_parallel_size,
+            pipeline_parallel_size=speculative_draft_pipeline_parallel_size,
             tensor_parallel_size=speculative_draft_tensor_parallel_size,
+            virtual_engine_size=target_parallel_config.virtual_engine_size,
             distributed_executor_backend=target_parallel_config.
             distributed_executor_backend,
             max_parallel_loading_workers=target_parallel_config.

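The new _verify_and_get_draft_pp mirrors the existing draft-TP check: leaving --speculative-draft-pipeline-parallel-size unset inherits the target's pipeline-parallel size, and any explicit value other than 1 or the target's PP size is rejected. A compressed, standalone sketch of that rule (simplified, hypothetical helper; the hf_config argument is accepted but not consulted in the method above):

# Standalone sketch of the draft-PP validation rule, not vLLM code.
from typing import Optional


def verify_and_get_draft_pp(target_pp_size: int,
                            draft_pp_size: Optional[int]) -> int:
    if draft_pp_size is None:
        # Unset: the draft model inherits the target's pipeline parallel size.
        return target_pp_size
    if draft_pp_size not in (1, target_pp_size):
        # Only PP=1 or the target's own PP size is supported for the draft.
        raise ValueError(f"{draft_pp_size=} cannot be other value than 1 or "
                         "target model pipeline_parallel_size")
    return draft_pp_size


assert verify_and_get_draft_pp(2, None) == 2  # defaults to the target's PP
assert verify_and_get_draft_pp(2, 1) == 1     # draft may run with PP=1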
vllm/distributed/parallel_state.py

Lines changed: 14 additions & 2 deletions
@@ -1011,11 +1011,13 @@ def model_parallel_is_initialized():
 
 
 _TP_STATE_PATCHED = False
+_PP_STATE_PATCHED = False
 
 
 @contextmanager
-def patch_tensor_parallel_group(tp_group: GroupCoordinator):
-    """Patch the tp group temporarily until this function ends.
+def patch_model_parallel_group(tp_group: GroupCoordinator,
+                               pp_group: GroupCoordinator):
+    """Patch the tp and pp groups temporarily until this function ends.
 
     This method is for draft workers of speculative decoding to run draft model
     with different tp degree from that of target model workers.
@@ -1026,16 +1028,26 @@ def patch_tensor_parallel_group(tp_group: GroupCoordinator):
     global _TP_STATE_PATCHED
     assert not _TP_STATE_PATCHED, "Should not call when it's already patched"
 
+    global _PP_STATE_PATCHED
+    assert not _PP_STATE_PATCHED, "Should not call when it's already patched"
+
     _TP_STATE_PATCHED = True
     old_tp_group = get_tp_group()
     global _TP
     _TP = tp_group
+
+    _PP_STATE_PATCHED = True
+    old_pp_group = get_pp_group()
+    global _PP
+    _PP = pp_group
     try:
         yield
     finally:
         # restore the original state
         _TP_STATE_PATCHED = False
         _TP = old_tp_group
+        _PP_STATE_PATCHED = False
+        _PP = old_pp_group
 
 
 def get_tensor_model_parallel_world_size():

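patch_model_parallel_group extends the earlier TP-only patching so that a PP=1 draft worker can temporarily run with its own groups inside a PP=N target engine. The mechanism is simply a context manager that swaps module-level group handles and restores them on exit, even if the draft step raises; a self-contained toy version of the same pattern (plain strings stand in for GroupCoordinator objects, not vLLM code):

# Toy version of the patch_model_parallel_group pattern.
from contextlib import contextmanager

_TP = "target_tp_group"
_PP = "target_pp_group"


@contextmanager
def patch_model_parallel_group(tp_group, pp_group):
    global _TP, _PP
    old_tp, old_pp = _TP, _PP
    _TP, _PP = tp_group, pp_group
    try:
        yield
    finally:
        # Restore the target model's groups even if the draft step raised.
        _TP, _PP = old_tp, old_pp


with patch_model_parallel_group("draft_tp_group", "draft_pp_group"):
    assert (_TP, _PP) == ("draft_tp_group", "draft_pp_group")
assert (_TP, _PP) == ("target_tp_group", "target_pp_group")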
vllm/engine/arg_utils.py

Lines changed: 11 additions & 0 deletions
@@ -187,6 +187,7 @@ class EngineArgs:
     speculative_model: Optional[str] = None
     speculative_model_quantization: Optional[str] = None
     speculative_draft_tensor_parallel_size: Optional[int] = None
+    speculative_draft_pipeline_parallel_size: Optional[int] = None
     num_speculative_tokens: Optional[int] = None
     speculative_disable_mqa_scorer: Optional[bool] = False
     speculative_max_model_len: Optional[int] = None
@@ -843,6 +844,13 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
             default=EngineArgs.speculative_draft_tensor_parallel_size,
             help='Number of tensor parallel replicas for '
             'the draft model in speculative decoding.')
+        parser.add_argument(
+            '--speculative-draft-pipeline-parallel-size',
+            '-spec-draft-pp',
+            type=int,
+            default=EngineArgs.speculative_draft_pipeline_parallel_size,
+            help='Number of pipeline parallel stages for '
+            'the draft model in speculative decoding.')
 
         parser.add_argument(
             '--speculative-max-model-len',
@@ -1257,6 +1265,8 @@ def create_speculative_config(
             "max_model_len": self.speculative_max_model_len,
             "draft_tensor_parallel_size":
             self.speculative_draft_tensor_parallel_size,
+            "draft_pipeline_parallel_size":
+            self.speculative_draft_pipeline_parallel_size,
             "num_speculative_tokens": self.num_speculative_tokens,
             "disable_mqa_scorer": self.speculative_disable_mqa_scorer,
             "disable_by_batch_size":
@@ -1369,6 +1379,7 @@ def create_engine_config(
         parallel_config = ParallelConfig(
             pipeline_parallel_size=self.pipeline_parallel_size,
             tensor_parallel_size=self.tensor_parallel_size,
+            virtual_engine_size=self.pipeline_parallel_size,
             data_parallel_size=self.data_parallel_size,
             enable_expert_parallel=self.enable_expert_parallel,
             max_parallel_loading_workers=self.max_parallel_loading_workers,

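End to end, the new flag composes with the existing speculative-decoding arguments, e.g. --pipeline-parallel-size 2 together with --speculative-draft-pipeline-parallel-size 1 as exercised in the new test file. A hedged offline-API sketch, assuming the EngineArgs fields above are also accepted as LLM keyword arguments (vLLM's usual EngineArgs-to-LLM mapping; this diff only shows the CLI path):

# Hedged usage sketch: keyword names mirror the EngineArgs fields added in this
# commit; forwarding them through LLM(**kwargs) is an assumption, not shown here.
from vllm import LLM, SamplingParams

llm = LLM(
    model="ibm-granite/granite-3b-code-instruct",
    pipeline_parallel_size=2,                    # target model runs with PP=2
    speculative_model="JackFram/llama-68m",
    num_speculative_tokens=5,
    speculative_draft_pipeline_parallel_size=1,  # draft model runs with PP=1
    enforce_eager=True,
    dtype="bfloat16",
)
outputs = llm.generate(["The capital of France is"],
                       SamplingParams(temperature=0.0, max_tokens=32))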