
Commit 31e3a89

ShangmingCai authored and erictang000 committed
[V1][Usage] Refactor speculative decoding configuration and tests (vllm-project#14434)
Signed-off-by: Shangming Cai <caishangming@linux.alibaba.com>
1 parent 70e1cdd commit 31e3a89

20 files changed: +1061 −808 lines

docs/source/features/spec_decode.md

Lines changed: 25 additions & 14 deletions
@@ -30,8 +30,10 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
 llm = LLM(
     model="facebook/opt-6.7b",
     tensor_parallel_size=1,
-    speculative_model="facebook/opt-125m",
-    num_speculative_tokens=5,
+    speculative_config={
+        "model": "facebook/opt-125m",
+        "num_speculative_tokens": 5,
+    },
 )
 outputs = llm.generate(prompts, sampling_params)
@@ -45,10 +47,14 @@ To perform the same with an online mode launch the server:

 ```bash
 python -m vllm.entrypoints.openai.api_server --host 0.0.0.0 --port 8000 --model facebook/opt-6.7b \
-    --seed 42 -tp 1 --speculative_model facebook/opt-125m \
-    --num_speculative_tokens 5 --gpu_memory_utilization 0.8
+    --seed 42 -tp 1 --gpu_memory_utilization 0.8 \
+    --speculative_config '{"model": "facebook/opt-125m", "num_speculative_tokens": 5}'
 ```

+:::{warning}
+Note: Please use `--speculative_config` to set all configurations related to speculative decoding. The previous method of specifying the model through `--speculative_model` and adding related parameters (e.g., `--num_speculative_tokens`) separately will be deprecated in the next release.
+:::
+
 Then use a client:

 ```python
@@ -101,9 +107,11 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
 llm = LLM(
     model="facebook/opt-6.7b",
     tensor_parallel_size=1,
-    speculative_model="[ngram]",
-    num_speculative_tokens=5,
-    ngram_prompt_lookup_max=4,
+    speculative_config={
+        "method": "ngram",
+        "num_speculative_tokens": 5,
+        "prompt_lookup_max": 4,
+    },
 )
 outputs = llm.generate(prompts, sampling_params)
@@ -131,8 +139,10 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
 llm = LLM(
     model="meta-llama/Meta-Llama-3.1-70B-Instruct",
     tensor_parallel_size=4,
-    speculative_model="ibm-ai-platform/llama3-70b-accelerator",
-    speculative_draft_tensor_parallel_size=1,
+    speculative_config={
+        "model": "ibm-ai-platform/llama3-70b-accelerator",
+        "draft_tensor_parallel_size": 1,
+    },
 )
 outputs = llm.generate(prompts, sampling_params)
@@ -175,8 +185,10 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
 llm = LLM(
     model="meta-llama/Meta-Llama-3-8B-Instruct",
     tensor_parallel_size=4,
-    speculative_model="yuhuili/EAGLE-LLaMA3-Instruct-8B",
-    speculative_draft_tensor_parallel_size=1,
+    speculative_config={
+        "model": "yuhuili/EAGLE-LLaMA3-Instruct-8B",
+        "draft_tensor_parallel_size": 1,
+    },
 )

 outputs = llm.generate(prompts, sampling_params)
@@ -194,11 +206,10 @@ A few important things to consider when using the EAGLE based draft models:
    be able to be loaded and used directly by vLLM after [PR 12304](https://github.com/vllm-project/vllm/pull/12304).
    If you are using vllm version before [PR 12304](https://github.com/vllm-project/vllm/pull/12304), please use the
    [script](https://gist.github.com/abhigoyal1997/1e7a4109ccb7704fbc67f625e86b2d6d) to convert the speculative model,
-   and specify `speculative_model="path/to/modified/eagle/model"`. If weight-loading problems still occur when using
-   the latest version of vLLM, please leave a comment or raise an issue.
+   and specify `"model": "path/to/modified/eagle/model"` in `speculative_config`. If weight-loading problems still occur when using the latest version of vLLM, please leave a comment or raise an issue.

 2. The EAGLE based draft models need to be run without tensor parallelism
-   (i.e. speculative_draft_tensor_parallel_size is set to 1), although
+   (i.e. draft_tensor_parallel_size is set to 1 in `speculative_config`), although
    it is possible to run the main model using tensor parallelism (see example above).

 3. When using EAGLE-based speculators with vLLM, the observed speedup is lower than what is
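Taken together, the doc changes above move the offline API from separate `speculative_model` / `num_speculative_tokens` arguments to a single `speculative_config` dict. A minimal end-to-end sketch of the new style follows; the prompts and the printed output handling are illustrative placeholders, while the `speculative_config` keys mirror the dict shown in the updated docs:

```python
from vllm import LLM, SamplingParams

# Illustrative prompts; any list of strings works.
prompts = ["The future of AI is"]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

# Draft-model speculative decoding configured via a single dict,
# replacing the old separate speculative_model / num_speculative_tokens args.
llm = LLM(
    model="facebook/opt-6.7b",
    tensor_parallel_size=1,
    speculative_config={
        "model": "facebook/opt-125m",
        "num_speculative_tokens": 5,
    },
)

outputs = llm.generate(prompts, sampling_params)
for output in outputs:
    print(output.outputs[0].text)
```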

examples/offline_inference/mlpspeculator.py

Lines changed: 3 additions & 1 deletion
@@ -50,7 +50,9 @@ def time_generation(llm: LLM, prompts: list[str],
 # Create an LLM with spec decoding
 llm = LLM(
     model="meta-llama/Llama-2-13b-chat-hf",
-    speculative_model="ibm-ai-platform/llama-13b-accelerator",
+    speculative_config={
+        "model": "ibm-ai-platform/llama-13b-accelerator",
+    },
 )

 print("With speculation")

tests/spec_decode/e2e/conftest.py

Lines changed: 1 addition & 1 deletion
@@ -56,7 +56,7 @@ def generate():
 def maybe_assert_ngram_worker(llm):
     # Verify the proposer worker is ngram if ngram is specified.
     if (llm.llm_engine.speculative_config is not None
-            and llm.llm_engine.speculative_config.ngram_prompt_lookup_max > 0):
+            and llm.llm_engine.speculative_config.method == "ngram"):
         from vllm.spec_decode.ngram_worker import NGramWorker
         assert isinstance(
             llm.llm_engine.model_executor.driver_worker.proposer_worker,
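For context, the complete helper after this change plausibly reads as below; the final `NGramWorker)` argument falls outside the hunk and is inferred from the import:

```python
def maybe_assert_ngram_worker(llm):
    # Verify the proposer worker is an NGramWorker when the ngram method
    # is selected via the refactored speculative_config.
    if (llm.llm_engine.speculative_config is not None
            and llm.llm_engine.speculative_config.method == "ngram"):
        from vllm.spec_decode.ngram_worker import NGramWorker
        assert isinstance(
            llm.llm_engine.model_executor.driver_worker.proposer_worker,
            NGramWorker)
```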

tests/spec_decode/e2e/test_compatibility.py

Lines changed: 20 additions & 9 deletions
@@ -7,28 +7,39 @@
 from .conftest import get_output_from_llm_generator


-@pytest.mark.parametrize("common_llm_kwargs", [{
-    "model": "meta-llama/Llama-3.2-1B-Instruct",
-    "speculative_model": "JackFram/llama-68m",
-    "num_speculative_tokens": 5,
-}])
+@pytest.mark.parametrize("common_llm_kwargs",
+                         [{
+                             "model": "meta-llama/Llama-3.2-1B-Instruct",
+                         }])
 @pytest.mark.parametrize(
     "per_test_common_llm_kwargs",
     [
         {
             # Speculative max model len > overridden max model len should raise.
+            "speculative_config": {
+                "model": "JackFram/llama-68m",
+                "num_speculative_tokens": 5,
+                "max_model_len": 129,
+            },
             "max_model_len": 128,
-            "speculative_max_model_len": 129,
         },
         {
             # Speculative max model len > draft max model len should raise.
             # https://huggingface.co/JackFram/llama-68m/blob/3b606af5198a0b26762d589a3ee3d26ee6fa6c85/config.json#L12
-            "speculative_max_model_len": 2048 + 1,
+            "speculative_config": {
+                "model": "JackFram/llama-68m",
+                "num_speculative_tokens": 5,
+                "max_model_len": 2048 + 1,
+            },
         },
         {
             # Speculative max model len > target max model len should raise.
-            # https://huggingface.co/meta-llama/Llama-3.2-1B-Instruct/blob/9213176726f574b556790deb65791e0c5aa438b6/config.json#L18
-            "speculative_max_model_len": 131072 + 1,
+            # https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct/blob/9213176726f574b556790deb65791e0c5aa438b6/config.json#L18
+            "speculative_config": {
+                "model": "JackFram/llama-68m",
+                "num_speculative_tokens": 5,
+                "max_model_len": 131072 + 1,
+            },
         },
     ])
 @pytest.mark.parametrize("test_llm_kwargs", [{}])
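Each parametrized case above nests an oversized `max_model_len` inside `speculative_config` so that engine construction should fail. The test body is not shown in this hunk; the sketch below only illustrates the pattern, and the kwarg-merging helper plus the `ValueError` expectation are assumptions, not the file's literal code:

```python
import pytest

from vllm import LLM


def build_llm(common_llm_kwargs: dict, per_test_common_llm_kwargs: dict,
              test_llm_kwargs: dict) -> LLM:
    # Hypothetical helper: merge the parametrized kwarg layers the way the
    # e2e fixtures combine them, then construct the engine.
    kwargs = {
        **common_llm_kwargs,
        **per_test_common_llm_kwargs,
        **test_llm_kwargs,
    }
    return LLM(**kwargs)


def test_spec_max_model_len_too_large():
    # One invalid combination from the parametrization above: the draft's
    # max_model_len (129) exceeds the overridden target max_model_len (128).
    per_test_common_llm_kwargs = {
        "speculative_config": {
            "model": "JackFram/llama-68m",
            "num_speculative_tokens": 5,
            "max_model_len": 129,
        },
        "max_model_len": 128,
    }
    with pytest.raises(ValueError):
        build_llm({"model": "meta-llama/Llama-3.2-1B-Instruct"},
                  per_test_common_llm_kwargs, {})
```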

tests/spec_decode/e2e/test_eagle_correctness.py

Lines changed: 57 additions & 40 deletions
@@ -57,8 +57,10 @@
 @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
 @pytest.mark.parametrize("test_llm_kwargs", [
     {
-        "speculative_model": SPEC_MODEL,
-        "num_speculative_tokens": MAX_SPEC_TOKENS,
+        "speculative_config": {
+            "model": SPEC_MODEL,
+            "num_speculative_tokens": MAX_SPEC_TOKENS,
+        },
     },
 ])
 @pytest.mark.parametrize("output_len", [
@@ -95,18 +97,19 @@ def test_eagle_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
 }])
 @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
 @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
-@pytest.mark.parametrize("test_llm_kwargs", [
-    {
-        "speculative_model": SPEC_MODEL,
+@pytest.mark.parametrize("test_llm_kwargs", [{
+    "speculative_config": {
+        "model": SPEC_MODEL,
         "num_speculative_tokens": MAX_SPEC_TOKENS,
-        "disable_logprobs_during_spec_decoding": False,
+        "disable_logprobs": False,
     },
-    {
-        "speculative_model": SPEC_MODEL,
+}, {
+    "speculative_config": {
+        "model": SPEC_MODEL,
         "num_speculative_tokens": MAX_SPEC_TOKENS,
-        "disable_logprobs_during_spec_decoding": True,
+        "disable_logprobs": True,
     },
-])
+}])
 @pytest.mark.parametrize("output_len", [
     128,
 ])
@@ -119,18 +122,19 @@ def test_eagle_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
                                    batch_size: int, output_len: int, seed: int,
                                    logprobs: int):

-    run_equality_correctness_test(vllm_runner,
-                                  common_llm_kwargs,
-                                  per_test_common_llm_kwargs,
-                                  baseline_llm_kwargs,
-                                  test_llm_kwargs,
-                                  batch_size,
-                                  output_len,
-                                  seed,
-                                  logprobs=logprobs,
-                                  prompt_logprobs=logprobs,
-                                  disable_logprobs=test_llm_kwargs[
-                                      'disable_logprobs_during_spec_decoding'])
+    run_equality_correctness_test(
+        vllm_runner,
+        common_llm_kwargs,
+        per_test_common_llm_kwargs,
+        baseline_llm_kwargs,
+        test_llm_kwargs,
+        batch_size,
+        output_len,
+        seed,
+        logprobs=logprobs,
+        prompt_logprobs=logprobs,
+        disable_logprobs=test_llm_kwargs["speculative_config"]
+        ["disable_logprobs"])


 @pytest.mark.parametrize(
@@ -151,8 +155,10 @@ def test_eagle_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
 @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
 @pytest.mark.parametrize("test_llm_kwargs", [
     {
-        "speculative_model": SPEC_MODEL,
-        "num_speculative_tokens": MAX_SPEC_TOKENS,
+        "speculative_config": {
+            "model": SPEC_MODEL,
+            "num_speculative_tokens": MAX_SPEC_TOKENS,
+        },
     },
 ])
 @pytest.mark.parametrize("output_len", [
@@ -193,8 +199,10 @@ def test_eagle_e2e_greedy_correctness_cuda_graph(
 @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
 @pytest.mark.parametrize("test_llm_kwargs", [
     {
-        "speculative_model": SPEC_MODEL,
-        "num_speculative_tokens": MAX_SPEC_TOKENS,
+        "speculative_config": {
+            "model": SPEC_MODEL,
+            "num_speculative_tokens": MAX_SPEC_TOKENS,
+        },
     },
 ])
 @pytest.mark.parametrize(
@@ -236,8 +244,10 @@ def test_eagle_e2e_greedy_correctness_with_preemption(
     "test_llm_kwargs",
     [
         {
-            "speculative_model": SPEC_MODEL,
-            "num_speculative_tokens": k,
+            "speculative_config": {
+                "model": SPEC_MODEL,
+                "num_speculative_tokens": k,
+            },
         }
         # Try a range of num. speculative tokens
         for k in range(1, 1 + MAX_SPEC_TOKENS)
@@ -277,12 +287,13 @@ def test_eagle_different_k(vllm_runner, common_llm_kwargs,
 }])
 @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
 @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
-@pytest.mark.parametrize("test_llm_kwargs",
-                         [{
-                             "speculative_model": SPEC_MODEL,
-                             "num_speculative_tokens": MAX_SPEC_TOKENS,
-                             "speculative_disable_by_batch_size": 4
-                         }])
+@pytest.mark.parametrize("test_llm_kwargs", [{
+    "speculative_config": {
+        "model": SPEC_MODEL,
+        "num_speculative_tokens": MAX_SPEC_TOKENS,
+        "disable_by_batch_size": 4,
+    },
+}])
 @pytest.mark.parametrize("batch_size", [1, 5])
 @pytest.mark.parametrize(
     "output_len",
@@ -324,8 +335,10 @@ def test_eagle_disable_queue(vllm_runner, common_llm_kwargs,
 @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
 @pytest.mark.parametrize("test_llm_kwargs", [
     {
-        "speculative_model": "yuhuili/EAGLE-llama2-chat-7B",
-        "num_speculative_tokens": MAX_SPEC_TOKENS,
+        "speculative_config": {
+            "model": "yuhuili/EAGLE-llama2-chat-7B",
+            "num_speculative_tokens": MAX_SPEC_TOKENS,
+        },
     },
 ])
 @pytest.mark.parametrize(
@@ -372,8 +385,10 @@ def test_llama2_eagle_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
 @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
 @pytest.mark.parametrize("test_llm_kwargs", [
     {
-        "speculative_model": "yuhuili/EAGLE-LLaMA3-Instruct-8B",
-        "num_speculative_tokens": MAX_SPEC_TOKENS,
+        "speculative_config": {
+            "model": "yuhuili/EAGLE-LLaMA3-Instruct-8B",
+            "num_speculative_tokens": MAX_SPEC_TOKENS,
+        },
     },
 ])
 @pytest.mark.parametrize(
@@ -420,8 +435,10 @@ def test_llama3_eagle_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
 @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
 @pytest.mark.parametrize("test_llm_kwargs", [
     {
-        "speculative_model": "yuhuili/EAGLE-Qwen2-7B-Instruct",
-        "num_speculative_tokens": MAX_SPEC_TOKENS,
+        "speculative_config": {
+            "model": "yuhuili/EAGLE-Qwen2-7B-Instruct",
+            "num_speculative_tokens": MAX_SPEC_TOKENS,
+        },
     },
 ])
 @pytest.mark.parametrize(
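Across these hunks the draft-model settings move under a nested `speculative_config` dict, so the logprobs test now reads the flag through `test_llm_kwargs["speculative_config"]["disable_logprobs"]`. A small before/after sketch of that access pattern (the dicts are illustrative):

```python
# Old layout: spec-decode options sat at the top level of test_llm_kwargs.
old_kwargs = {
    "speculative_model": "yuhuili/EAGLE-LLaMA3-Instruct-8B",
    "num_speculative_tokens": 5,
    "disable_logprobs_during_spec_decoding": False,
}
disable_logprobs = old_kwargs["disable_logprobs_during_spec_decoding"]

# New layout: everything speculative lives under "speculative_config",
# so the same flag is read through the nested dict.
new_kwargs = {
    "speculative_config": {
        "model": "yuhuili/EAGLE-LLaMA3-Instruct-8B",
        "num_speculative_tokens": 5,
        "disable_logprobs": False,
    },
}
disable_logprobs = new_kwargs["speculative_config"]["disable_logprobs"]
```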
