
Commit 411e0d2

NickLucche and wallashss authored and committed

[Feature] [Spec decode]: Enable MLPSpeculator/Medusa and prompt_logprobs with ChunkedPrefill (vllm-project#10132)

Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: wallashss <wallashss@ibm.com>
Co-authored-by: wallashss <wallashss@ibm.com>

1 parent 45844a3 commit 411e0d2

16 files changed (+469, -166 lines)

tests/spec_decode/e2e/conftest.py

Lines changed: 16 additions & 3 deletions
@@ -2,6 +2,7 @@
 from typing import List, Optional, Sequence, Tuple, Union
 
 import pytest
+import torch
 
 from vllm import LLM, SamplingParams
 from vllm.distributed import cleanup_dist_env_and_memory
@@ -154,6 +155,8 @@ def _check_logprobs_when_output_disabled(
          spec_pos_logprob) = next(iter(spec_pos_logprobs.items()))
         assert spec_pos_logprob.rank == -1
         assert spec_pos_logprob.logprob == 0.0
+        if isinstance(spec_pos_logprob_token_id, torch.Tensor):
+            spec_pos_logprob_token_id = spec_pos_logprob_token_id.item()
         assert spec_pos_logprob_token_id in baseline_pos_logprobs
 
 
@@ -244,7 +247,8 @@ def run_equality_correctness_test_tp(model,
                                      batch_size: int,
                                      max_output_len: int,
                                      seed: int = 0,
-                                     temperature: float = 0.0):
+                                     temperature: float = 0.0,
+                                     logprobs: Optional[int] = None):
     """Helper method that compares the outputs of both the baseline LLM and
     the test LLM. It asserts greedy equality, e.g. that the outputs are exactly
     the same when temperature is zero.
@@ -257,7 +261,6 @@ def run_equality_correctness_test_tp(model,
     results = []
 
     prompts = [prompt for prompt, _ in zip(cycle(PROMPTS), range(batch_size))]
-
     for args, env in ((arg1, env1), (arg2, env2)):
         with RemoteOpenAIServer(model,
                                 args,
@@ -269,12 +272,14 @@ def run_equality_correctness_test_tp(model,
                prompt=prompts,
                max_tokens=max_output_len,
                seed=seed,
-               temperature=temperature)
+               temperature=temperature,
+               logprobs=logprobs)
 
             results.append({
                 "test":
                 "seeded_sampling",
                 "text": [choice.text for choice in completion.choices],
+                "logprobs": [choice.logprobs for choice in completion.choices],
                 "finish_reason":
                 [choice.finish_reason for choice in completion.choices],
                 "usage":
@@ -284,7 +289,15 @@ def run_equality_correctness_test_tp(model,
     n = len(results) // 2
     arg1_results = results[:n]
     arg2_results = results[n:]
+    # Separate logprobs to avoid asserting exact equality.
+    arg1_logprobs = [r.pop("logprobs") for r in arg1_results]
+    arg2_logprobs = [r.pop("logprobs") for r in arg2_results]
+
     for arg1_result, arg2_result in zip(arg1_results, arg2_results):
         assert arg1_result == arg2_result, (
             f"Results for {model=} are not the same with {arg1=} and {arg2=}. "
             f"{arg1_result=} != {arg2_result=}")
+    if logprobs:
+        for logs1, logs2 in zip(arg1_logprobs, arg2_logprobs):
+            for l1, l2 in zip(logs1, logs2):
+                assert l1.tokens == l2.tokens
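A quick, standalone illustration (not part of the commit) of why the new isinstance/.item() guard in _check_logprobs_when_output_disabled matters: when the speculative path hands back a sampled token id as a 0-d torch tensor, it will not match the plain-int keys of the baseline logprobs dict, so the membership check needs the unwrapped Python int. The token id and logprob values below are made up.

import torch

# Baseline logprobs are keyed by plain Python ints (values made up here).
baseline_pos_logprobs = {1234: -0.5}

# The spec-decode path may return the sampled token id as a 0-d tensor.
spec_pos_logprob_token_id = torch.tensor(1234)

# Unwrap the tensor to a plain int before the dict membership check,
# mirroring the guard added in the diff above.
if isinstance(spec_pos_logprob_token_id, torch.Tensor):
    spec_pos_logprob_token_id = spec_pos_logprob_token_id.item()

assert spec_pos_logprob_token_id in baseline_pos_logprobs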

tests/spec_decode/e2e/test_integration_dist_tp2.py

Lines changed: 9 additions & 1 deletion
@@ -2,6 +2,8 @@
 tensor parallelism.
 """
 
+from typing import Optional
+
 import pytest
 import torch
 
@@ -154,15 +156,20 @@ def test_draft_model_tp_lt_target_model_tp2(model, common_llm_kwargs,
         "--speculative-draft-tensor-parallel-size",
         "1",
     ])])
+@pytest.mark.parametrize("logprobs", [None, 2])
 @pytest.mark.parametrize("batch_size", [2])
 @pytest.mark.parametrize("seed", [1])
 def test_spec_decode_chunked_prefill_tp2(model, common_llm_kwargs,
                                          per_test_common_llm_kwargs,
                                          baseline_llm_kwargs, test_llm_kwargs,
+                                         logprobs: Optional[int],
                                          batch_size: int, seed: int):
     """Verify spec decode works well with same and different TP size for
     the draft model with chunked prefill.
     """
+    if logprobs:
+        test_llm_kwargs.extend(
+            ["--disable_logprobs_during_spec_decoding", "False"])
     run_equality_correctness_test_tp(model,
                                      common_llm_kwargs,
                                      per_test_common_llm_kwargs,
@@ -171,4 +178,5 @@ def test_spec_decode_chunked_prefill_tp2(model, common_llm_kwargs,
                                      batch_size,
                                      max_output_len=32,
                                      seed=seed,
-                                     temperature=0.0)
+                                     temperature=0.0,
+                                     logprobs=logprobs)
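For context, a rough client-side sketch of what this TP2 test now exercises end to end: requesting per-token logprobs from a vLLM OpenAI-compatible server that is running speculative decoding with chunked prefill. The server URL, model name, and prompt below are placeholders, not values taken from this test.

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

completion = client.completions.create(
    model="JackFram/llama-68m",    # placeholder model name
    prompt="The quick brown fox",  # placeholder prompt
    max_tokens=32,
    seed=1,
    temperature=0.0,
    logprobs=2,                    # ask for top-2 logprobs per generated token
)

# The test helper compares only the returned token strings across server
# configurations, not the numerically noisy logprob values themselves.
print(completion.choices[0].logprobs.tokens)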

tests/spec_decode/e2e/test_logprobs.py

Lines changed: 10 additions & 6 deletions
@@ -4,26 +4,27 @@
 
 from vllm import SamplingParams
 
+from ..utils import maybe_enable_chunked_prefill
 from .conftest import run_equality_correctness_test
 
 
 @pytest.mark.parametrize(
     "common_llm_kwargs",
     [{
-        "model_name": "JackFram/llama-68m",
+        "model_name": "JackFram/llama-160m",
 
         # Skip cuda graph recording for fast test.
-        "enforce_eager": True,
+        "enforce_eager": True
     }])
 @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
 @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
 @pytest.mark.parametrize("test_llm_kwargs",
                          [{
-                             "speculative_model": "JackFram/llama-160m",
+                             "speculative_model": "JackFram/llama-68m",
                              "num_speculative_tokens": 3,
                              "disable_logprobs_during_spec_decoding": False,
                          }, {
-                             "speculative_model": "JackFram/llama-160m",
+                             "speculative_model": "JackFram/llama-68m",
                              "num_speculative_tokens": 3,
                              "disable_logprobs_during_spec_decoding": True,
                          }])
@@ -36,12 +37,15 @@
 ])
 @pytest.mark.parametrize("seed", [1])
 @pytest.mark.parametrize("logprobs", [1, 6])
+@pytest.mark.parametrize("prefill_chunk_size", [-1, 4, 12])
 def test_logprobs_equality(vllm_runner, common_llm_kwargs,
                            per_test_common_llm_kwargs, baseline_llm_kwargs,
                            test_llm_kwargs, batch_size: int, output_len: int,
-                           seed: int, logprobs: int):
-    """Verify output logprobs are equal with and without speculative decoding.
+                           seed: int, logprobs: int, prefill_chunk_size: int):
+    """Verify output logprobs are equal with and without speculative decoding,
+    as well as with and without chunked prefill.
     """
+    maybe_enable_chunked_prefill(prefill_chunk_size, common_llm_kwargs)
    run_equality_correctness_test(vllm_runner,
                                  common_llm_kwargs,
                                  per_test_common_llm_kwargs,
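These tests import maybe_enable_chunked_prefill from tests/spec_decode/utils, whose body is not part of this diff. A plausible sketch of such a helper, assuming it mutates the kwargs dict in place and treats a non-positive chunk size as "leave chunked prefill disabled":

def maybe_enable_chunked_prefill(prefill_chunk_size: int, llm_kwargs: dict) -> None:
    """Sketch only: enable chunked prefill when a positive chunk size is given."""
    if prefill_chunk_size > 0:
        llm_kwargs.update(
            enable_chunked_prefill=True,
            # Cap the tokens scheduled per step so prompts are actually
            # split into chunks of roughly the requested size.
            max_num_batched_tokens=prefill_chunk_size,
            max_num_seqs=prefill_chunk_size,
        )
    else:
        llm_kwargs["enable_chunked_prefill"] = False

Under this reading, prefill_chunk_size=-1 in the parametrizations keeps chunked prefill off (the pre-existing behavior), while 4, 12, or 32 force prompts to be processed in small chunks alongside speculation.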

tests/spec_decode/e2e/test_medusa_correctness.py

Lines changed: 24 additions & 7 deletions
@@ -21,6 +21,7 @@
 
 import pytest
 
+from ..utils import maybe_enable_chunked_prefill
 from .conftest import run_equality_correctness_test
 
 # main model
@@ -67,12 +68,14 @@
 ])
 @pytest.mark.parametrize("batch_size", [1, 32])
 @pytest.mark.parametrize("seed", [1])
+@pytest.mark.parametrize("prefill_chunk_size", [-1, 32])
 def test_medusa_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
                                        per_test_common_llm_kwargs,
                                        baseline_llm_kwargs, test_llm_kwargs,
                                        batch_size: int, output_len: int,
-                                       seed: int):
+                                       seed: int, prefill_chunk_size: int):
     """Verify greedy equality with different batch size."""
+    maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
     run_equality_correctness_test(vllm_runner,
                                   common_llm_kwargs,
                                   per_test_common_llm_kwargs,
@@ -119,12 +122,15 @@ def test_medusa_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
 @pytest.mark.parametrize("batch_size", [8])
 @pytest.mark.parametrize("seed", [1])
 @pytest.mark.parametrize("logprobs", [1, 6])
+@pytest.mark.parametrize("prefill_chunk_size", [-1, 32])
 def test_medusa_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
                                     per_test_common_llm_kwargs,
                                     baseline_llm_kwargs, test_llm_kwargs,
                                     batch_size: int, output_len: int,
-                                    seed: int, logprobs: int):
+                                    seed: int, logprobs: int,
+                                    prefill_chunk_size: int):
     """Verify greedy equality with different batch size."""
+    maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
     run_equality_correctness_test(vllm_runner,
                                   common_llm_kwargs,
                                   per_test_common_llm_kwargs,
@@ -167,12 +173,14 @@ def test_medusa_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
 ])
 @pytest.mark.parametrize("batch_size", [1, 32])
 @pytest.mark.parametrize("seed", [1])
+@pytest.mark.parametrize("prefill_chunk_size", [-1, 32])
 def test_medusa_e2e_greedy_correctness_cuda_graph(
         vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
         baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
-        seed: int):
+        seed: int, prefill_chunk_size: int):
     """Verify greedy equality with cuda graph enabled and different
     batch sizes."""
+    maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
     run_equality_correctness_test(vllm_runner,
                                   common_llm_kwargs,
                                   per_test_common_llm_kwargs,
@@ -217,13 +225,15 @@ def test_medusa_e2e_greedy_correctness_cuda_graph(
 ])
 @pytest.mark.parametrize("batch_size", [4])
 @pytest.mark.parametrize("seed", [1])
+@pytest.mark.parametrize("prefill_chunk_size", [-1, 32])
 def test_medusa_e2e_greedy_correctness_with_preemption(
         vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
         baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
-        seed: int):
+        seed: int, prefill_chunk_size: int):
     """Verify greedy equality, even when some sequences are preempted mid-
     generation.
     """
+    maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
     run_equality_correctness_test(vllm_runner,
                                   common_llm_kwargs,
                                   per_test_common_llm_kwargs,
@@ -267,13 +277,15 @@ def test_medusa_e2e_greedy_correctness_with_preemption(
     32,
 ])
 @pytest.mark.parametrize("seed", [1])
+@pytest.mark.parametrize("prefill_chunk_size", [-1, 32])
 def test_medusa_different_k(vllm_runner, common_llm_kwargs,
                             per_test_common_llm_kwargs, baseline_llm_kwargs,
                             test_llm_kwargs, batch_size: int, output_len: int,
-                            seed: int):
+                            seed: int, prefill_chunk_size: int):
     """Verify that medusa speculative decoding produces exact equality
     to without spec decode with different values of num_speculative_tokens.
     """
+    maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
     run_equality_correctness_test(vllm_runner,
                                   common_llm_kwargs,
                                   per_test_common_llm_kwargs,
@@ -313,14 +325,17 @@ def test_medusa_different_k(vllm_runner, common_llm_kwargs,
     32,
 ])
 @pytest.mark.parametrize("seed", [1])
+@pytest.mark.parametrize("prefill_chunk_size", [-1, 32])
 def test_medusa_disable_queue(vllm_runner, common_llm_kwargs,
                               per_test_common_llm_kwargs, baseline_llm_kwargs,
                               test_llm_kwargs, batch_size: int,
-                              output_len: int, seed: int):
+                              output_len: int, seed: int,
+                              prefill_chunk_size: int):
     """Verify that medusa speculative decoding produces exact equality
     to without spec decode when speculation is disabled for large
     batch sizes.
     """
+    maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
     run_equality_correctness_test(vllm_runner,
                                   common_llm_kwargs,
                                   per_test_common_llm_kwargs,
@@ -361,12 +376,14 @@ def test_medusa_disable_queue(vllm_runner, common_llm_kwargs,
     32,
 ])
 @pytest.mark.parametrize("seed", [1])
+@pytest.mark.parametrize("prefill_chunk_size", [-1, 32])
 def test_mqa_scorer(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
                     baseline_llm_kwargs, test_llm_kwargs, batch_size: int,
-                    output_len: int, seed: int):
+                    output_len: int, seed: int, prefill_chunk_size: int):
     """Verify that speculative decoding generates the same output
     with batch expansion scorer and mqa scorer.
     """
+    maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
     run_equality_correctness_test(vllm_runner,
                                   common_llm_kwargs,
                                   per_test_common_llm_kwargs,
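To tie the Medusa tests back to user-facing configuration, here is a hypothetical offline setup combining a Medusa-style speculator with chunked prefill while keeping logprobs during spec decode. The model and speculator paths, the speculative token count, and the chunk size of 32 are placeholders echoing the test parametrization above, not values read from this diff.

from vllm import LLM, SamplingParams

llm = LLM(
    model="path/to/main-model",                    # placeholder
    speculative_model="path/to/medusa-heads",      # placeholder
    num_speculative_tokens=3,                      # placeholder
    disable_logprobs_during_spec_decoding=False,   # keep logprobs, as the tests do
    enable_chunked_prefill=True,
    max_num_batched_tokens=32,                     # mirrors prefill_chunk_size=32
    max_num_seqs=32,
    enforce_eager=True,
)

outputs = llm.generate(
    ["Speculative decoding with chunked prefill"],
    SamplingParams(temperature=0.0, max_tokens=32, logprobs=2),
)
print(outputs[0].outputs[0].text)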
