
Commit 2266a5e

leiwen83 and wenlei03 authored and committed
[Speculative decoding] Add ngram prompt lookup decoding (vllm-project#4237)
Co-authored-by: Lei Wen <wenlei03@qiyi.com>
1 parent 0662b71 commit 2266a5e

File tree

14 files changed (+1004, -319 lines)


tests/spec_decode/e2e/conftest.py

Lines changed: 58 additions & 0 deletions
@@ -1,4 +1,5 @@
 import asyncio
+from itertools import cycle
 from typing import List, Optional, Tuple, Union
 
 import pytest
@@ -185,3 +186,60 @@ def get_output_from_llm_generator(
     del llm
 
     return tokens, token_ids
+
+
+def run_greedy_equality_correctness_test(baseline_llm_generator,
+                                         test_llm_generator,
+                                         batch_size,
+                                         max_output_len,
+                                         force_output_len: bool,
+                                         print_tokens: bool = False):
+    """Helper method that compares the outputs of both the baseline LLM and
+    the test LLM. It asserts greedy equality, e.g. that the outputs are exactly
+    the same when temperature is zero.
+    """
+    temperature = 0.0
+
+    prompts = [
+        "Hello, my name is",
+        "The president of the United States is",
+        "The capital of France is",
+        "The future of AI is",
+        "San Francisco is know for its",
+        "Facebook was created in 2004 by",
+        "Curious George is a",
+        "Python 3.11 brings improvements to its",
+    ]
+
+    prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))]
+
+    # If the test requires that we generated max_output_len tokens, then set the
+    # sampling params to ignore eos token.
+    ignore_eos = force_output_len
+
+    sampling_params = SamplingParams(
+        max_tokens=max_output_len,
+        ignore_eos=ignore_eos,
+        temperature=temperature,
+    )
+
+    spec_batch_tokens, spec_batch_token_ids = get_output_from_llm_generator(
+        test_llm_generator, prompts, sampling_params)
+
+    (baseline_batch_tokens,
+     baseline_batch_token_ids) = get_output_from_llm_generator(
+         baseline_llm_generator, prompts, sampling_params)
+
+    assert len(baseline_batch_token_ids) == len(prompts)
+    assert len(spec_batch_token_ids) == len(prompts)
+
+    for i, (baseline_token_ids, baseline_tokens, spec_token_ids,
+            spec_tokens) in enumerate(
+                zip(baseline_batch_token_ids, baseline_batch_tokens,
+                    spec_batch_token_ids, spec_batch_tokens)):
+        if print_tokens:
+            print(f'{i=} {baseline_tokens=}')
+            print(f'{i=} {spec_tokens=}')
+            print(f'{i=} {baseline_token_ids=}')
+            print(f'{i=} {spec_token_ids=}')
+        assert baseline_token_ids == spec_token_ids

tests/spec_decode/e2e/test_correctness.py renamed to tests/spec_decode/e2e/test_multistep_correctness.py

Lines changed: 2 additions & 58 deletions
@@ -35,7 +35,8 @@
 
 from vllm import SamplingParams
 
-from .conftest import get_output_from_llm_generator
+from .conftest import (get_output_from_llm_generator,
+                       run_greedy_equality_correctness_test)
 
 
 @pytest.mark.parametrize(
@@ -545,60 +546,3 @@ def test_many_k(baseline_llm_generator, test_llm_generator, batch_size: int,
                                          batch_size,
                                          max_output_len=output_len,
                                          force_output_len=True)
-
-
-def run_greedy_equality_correctness_test(baseline_llm_generator,
-                                         test_llm_generator,
-                                         batch_size,
-                                         max_output_len,
-                                         force_output_len: bool,
-                                         print_tokens: bool = False):
-    """Helper method that compares the outputs of both the baseline LLM and
-    the test LLM. It asserts greedy equality, e.g. that the outputs are exactly
-    the same when temperature is zero.
-    """
-    temperature = 0.0
-
-    prompts = [
-        "Hello, my name is",
-        "The president of the United States is",
-        "The capital of France is",
-        "The future of AI is",
-        "San Francisco is know for its",
-        "Facebook was created in 2004 by",
-        "Curious George is a",
-        "Python 3.11 brings improvements to its",
-    ]
-
-    prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))]
-
-    # If the test requires that we generated max_output_len tokens, then set the
-    # sampling params to ignore eos token.
-    ignore_eos = force_output_len
-
-    sampling_params = SamplingParams(
-        max_tokens=max_output_len,
-        ignore_eos=ignore_eos,
-        temperature=temperature,
-    )
-
-    spec_batch_tokens, spec_batch_token_ids = get_output_from_llm_generator(
-        test_llm_generator, prompts, sampling_params)
-
-    (baseline_batch_tokens,
-     baseline_batch_token_ids) = get_output_from_llm_generator(
-         baseline_llm_generator, prompts, sampling_params)
-
-    assert len(baseline_batch_token_ids) == len(prompts)
-    assert len(spec_batch_token_ids) == len(prompts)
-
-    for i, (baseline_token_ids, baseline_tokens, spec_token_ids,
-            spec_tokens) in enumerate(
-                zip(baseline_batch_token_ids, baseline_batch_tokens,
-                    spec_batch_token_ids, spec_batch_tokens)):
-        if print_tokens:
-            print(f'{i=} {baseline_tokens=}')
-            print(f'{i=} {spec_tokens=}')
-            print(f'{i=} {baseline_token_ids=}')
-            print(f'{i=} {spec_token_ids=}')
-        assert baseline_token_ids == spec_token_ids

tests/spec_decode/e2e/test_ngram_correctness.py

Lines changed: 172 additions & 0 deletions
@@ -0,0 +1,172 @@
+"""This docstring details important information on the testing methodology.
+
+Most of the tests rely on "greedy equality", where we expect the output of
+speculative decoding on a sequence to exactly match the output of normal
+non-speculative decoding.
+
+Since speculative decoding with rejection sampling guarantees that the output
+distribution matches the target model's output distribution (up to hardware
+numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy
+equality.
+
+The ngram lookup idea comes from https://github.com/apoorvumang/prompt-lookup-decoding
+and was merged into the transformers code base in
+https://github.com/huggingface/transformers/pull/27775. Since no model is
+needed to generate the proposals, the test cases can be much simpler than the
+multi-step drafter ones.
+
+However, we still need to verify that the following scenarios pass:
+* Batch size 1 greedy equality
+* Batch size >1 greedy equality
+* Greedy equality under preemption
+* Greedy equality under various ngram sizes / speculative sizes
+
+With those tests, we can say that, at a minimum, ngram speculation does not
+break the correctness of the target model outputs.
+"""
+
+import pytest
+
+from .conftest import run_greedy_equality_correctness_test
+
+
+@pytest.mark.parametrize(
+    "common_llm_kwargs",
+    [{
+        # Skip cuda graph recording for fast test.
+        "enforce_eager": True,
+
+        # Required for spec decode.
+        "use_v2_block_manager": True,
+
+        # Print spec metrics.
+        "disable_log_stats": False,
+    }])
+@pytest.mark.parametrize("per_test_common_llm_kwargs", [
+    {
+        "model": "JackFram/llama-68m",
+    },
+])
+@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
+@pytest.mark.parametrize("test_llm_kwargs", [
+    {
+        "speculative_model": "[ngram]",
+        "num_speculative_tokens": 5,
+        "ngram_prompt_lookup_max": 3,
+    },
+])
+@pytest.mark.parametrize("output_len", [
+    256,
+])
+@pytest.mark.parametrize("batch_size", [1, 64])
+@pytest.mark.parametrize("seed", [1])
+def test_ngram_e2e_greedy_correctness(baseline_llm_generator,
+                                      test_llm_generator, batch_size: int,
+                                      output_len: int):
+    """Verify greedy equality on a tiny model with different batch size."""
+    run_greedy_equality_correctness_test(baseline_llm_generator,
+                                         test_llm_generator,
+                                         batch_size,
+                                         max_output_len=output_len,
+                                         force_output_len=True)
+
+
+@pytest.mark.parametrize(
+    "common_llm_kwargs",
+    [{
+        "block_size": 8,
+        # 2 for small prompt, 256//8 for generated.
+        "num_gpu_blocks_override": 2 + 256 // 8,
+        "max_model_len": (2 + 256 // 8) * 8,
+
+        # Skip cuda graph recording for fast test.
+        "enforce_eager": True,
+
+        # Required for spec decode.
+        "use_v2_block_manager": True
+    }])
+@pytest.mark.parametrize("per_test_common_llm_kwargs", [
+    {
+        "model": "JackFram/llama-160m",
+    },
+])
+@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
+@pytest.mark.parametrize("test_llm_kwargs", [
+    {
+        "speculative_model": "[ngram]",
+        "num_speculative_tokens": 5,
+        "ngram_prompt_lookup_max": 3,
+    },
+])
+@pytest.mark.parametrize(
+    "output_len",
+    [
+        # Use small output len for fast test.
+        256,
+    ])
+@pytest.mark.parametrize("batch_size", [4])
+@pytest.mark.parametrize("seed", [1])
+def test_ngram_e2e_greedy_correctness_with_preemption(baseline_llm_generator,
+                                                      test_llm_generator,
+                                                      batch_size: int,
+                                                      output_len: int):
+    """Verify greedy equality, even when some sequences are preempted mid-
+    generation.
+    """
+    run_greedy_equality_correctness_test(baseline_llm_generator,
+                                         test_llm_generator,
+                                         batch_size,
+                                         max_output_len=output_len,
+                                         force_output_len=True)
+
+
+@pytest.mark.parametrize(
+    "common_llm_kwargs",
+    [{
+        "model": "JackFram/llama-68m",
+
+        # Skip cuda graph recording for fast test.
+        "enforce_eager": True,
+
+        # Required for spec decode.
+        "use_v2_block_manager": True
+    }])
+@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
+@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
+@pytest.mark.parametrize(
+    "test_llm_kwargs",
+    [
+        {
+            "speculative_model": "[ngram]",
+            "num_speculative_tokens": k,
+            "ngram_prompt_lookup_max": 3,
+        }
+        # Try a range of common k, as well as large speculation.
+        for k in [1, 3, 5]
+    ] + [
+        {
+            "speculative_model": "[ngram]",
+            "num_speculative_tokens": k,
+            "ngram_prompt_lookup_max": 1,
+        }
+        # Try a range of common k, as well as large speculation.
+        for k in [1, 3, 5]
+    ])
+@pytest.mark.parametrize("batch_size", [2])
+@pytest.mark.parametrize(
+    "output_len",
+    [
+        # Use smaller output len for fast test.
+        32,
+    ])
+@pytest.mark.parametrize("seed", [1])
+def test_ngram_different_k(baseline_llm_generator, test_llm_generator,
+                           batch_size: int, output_len: int):
+    """Verify that ngram speculative decoding exactly matches non-speculative
+    decoding for many different values of k and ngram_prompt_lookup_max.
+    """
+    run_greedy_equality_correctness_test(baseline_llm_generator,
+                                         test_llm_generator,
+                                         batch_size,
+                                         max_output_len=output_len,
+                                         force_output_len=True)
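
For reference, the prompt-lookup idea that "[ngram]" speculation builds on (per https://github.com/apoorvumang/prompt-lookup-decoding, cited in the docstring above) fits in a few lines. The sketch below is illustrative only, not vLLM's implementation; the function name propose_ngram_tokens and its signature are hypothetical. It searches the existing context for an earlier occurrence of the last n tokens (longest n first, up to ngram_prompt_lookup_max) and proposes the tokens that followed that occurrence:

from typing import List


def propose_ngram_tokens(token_ids: List[int], ngram_max: int,
                         num_speculative_tokens: int) -> List[int]:
    # Hypothetical sketch of prompt-lookup (ngram) proposal generation.
    # Prefer the longest ngram, scanning from the most recent earlier match.
    for n in range(ngram_max, 0, -1):
        if len(token_ids) <= n:
            continue
        suffix = token_ids[-n:]
        for start in range(len(token_ids) - n - 1, -1, -1):
            if token_ids[start:start + n] == suffix:
                follow = token_ids[start + n:start + n + num_speculative_tokens]
                if follow:
                    return follow
    return []  # No match: fall back to normal decoding for this step.


# The bigram (3, 4) occurs earlier in the context, so the tokens that
# followed it are proposed for the target model to verify.
print(propose_ngram_tokens([1, 2, 3, 4, 9, 8, 3, 4], ngram_max=2,
                           num_speculative_tokens=3))  # [9, 8, 3]

Because the proposer only does list matching, no draft model runs at all, which is why the docstring notes these tests can be simpler than the multi-step drafter ones.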

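The greedy-equality rationale in the docstring rests on the rejection rule from the cited paper (https://arxiv.org/pdf/2302.01318.pdf). Below is a minimal sketch of one verification step, assuming dense probability vectors p (target) and q (draft); the name verify_draft_token is hypothetical and this is not vLLM's rejection sampler. At temperature zero both distributions collapse to one-hot argmax vectors, so a draft token is accepted exactly when it equals the target model's greedy token, which is why exact greedy equality is a sound pass/fail signal:

import numpy as np


def verify_draft_token(draft_token: int, p: np.ndarray, q: np.ndarray,
                       rng: np.random.Generator) -> int:
    # Accept the draft token with probability min(1, p/q); on rejection,
    # resample from the residual distribution max(0, p - q), renormalized.
    # This recovers exact samples from the target distribution p.
    if rng.random() < min(1.0, p[draft_token] / q[draft_token]):
        return draft_token
    residual = np.maximum(p - q, 0.0)
    return int(rng.choice(len(p), p=residual / residual.sum()))


rng = np.random.default_rng(seed=1)
p = np.array([0.7, 0.2, 0.1])  # target-model next-token probabilities
q = np.array([0.2, 0.7, 0.1])  # draft-model next-token probabilities
print(verify_draft_token(1, p, q, rng))  # kept with prob 2/7, else resampled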
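Finally, a quick check of the arithmetic behind the preemption test's cache sizing above (a sketch of the reasoning, not code from the commit): the override grants 2 blocks for the short prompts plus 256 // 8 = 32 blocks for generated tokens, and max_model_len is sized to match. With batch_size 4, four sequences share a KV-cache budget sized for roughly one full-length sequence, which plausibly forces the scheduler to preempt some of them mid-generation, the condition the test exercises.

# Illustrative check of the preemption test's KV-cache budget.
block_size = 8
num_gpu_blocks_override = 2 + 256 // 8       # 2 prompt blocks + 32 generation blocks
max_model_len = num_gpu_blocks_override * block_size

assert num_gpu_blocks_override == 34
assert max_model_len == 272                  # 34 blocks * 8 token slots per block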