
Commit e870757

wenlei03 authored and root committed
[Speculative decoding] Add ngram prompt lookup decoding
Algorithm details are described in this blog post: https://huggingface.co/blog/assisted-generation. The code directly follows transformers' current implementation: huggingface/transformers#27775. Since the draft is taken directly from the prompt, no additional or modified model is needed to generate the proposal, which makes this the most convenient way to get the speedup of speculation.
1 parent a395a63 commit e870757

14 files changed: +1063 -278 lines
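
Before the per-file diffs, a short illustration of the idea the commit message references may help. This is only a sketch of prompt lookup decoding as described in the linked blog post and transformers PR, not the implementation added by this commit; the function name propose_ngram_draft is made up for illustration, while the parameter names mirror the new vLLM options.

from typing import List


def propose_ngram_draft(token_ids: List[int], ngram_prompt_lookup_max: int,
                        num_speculative_tokens: int) -> List[int]:
    """Illustrative prompt-lookup proposer: no draft model is involved."""
    # Prefer the longest trailing n-gram that also occurs earlier in the
    # sequence (prompt plus tokens generated so far).
    for n in range(ngram_prompt_lookup_max, 0, -1):
        if len(token_ids) <= n:
            continue
        tail = token_ids[-n:]
        # Scan backwards so the most recent earlier match wins.
        for start in range(len(token_ids) - n - 1, -1, -1):
            if token_ids[start:start + n] == tail:
                # Propose the tokens that followed the matched n-gram.
                draft = token_ids[start + n:start + n + num_speculative_tokens]
                if draft:
                    return draft
    # No match found: propose nothing and fall back to normal decoding.
    return []

The target model then verifies the proposed tokens in a single forward pass and keeps only the prefix it would have generated itself, so correctness is preserved while fewer decoding steps are needed.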

tests/spec_decode/e2e/test_compatibility.py

Lines changed: 2 additions & 2 deletions
@@ -90,7 +90,7 @@ def test_spec_decode_xfail_chunked_prefill(test_llm_generator):
 @pytest.mark.parametrize(
     "common_llm_kwargs",
     [{
-        "model": "meta-llama/Llama-2-7b-chat-hf",
+        "model": "NousResearch/Llama-2-7b-chat-hf",
         "speculative_model": "JackFram/llama-68m",
         "num_speculative_tokens": 5,
@@ -112,7 +112,7 @@ def test_spec_decode_xfail_chunked_prefill(test_llm_generator):
     },
     {
         # Speculative max model len > target max model len should raise.
-        # https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/blob/f5db02db724555f92da89c216ac04704f23d4590/config.json#L12
+        # https://huggingface.co/NousResearch/Llama-2-7b-chat-hf/blob/37892f30c23786c0d5367d80481fa0d9fba93cf8/config.json#L11
         "speculative_max_model_len": 4096 + 1,
     },
 ])

tests/spec_decode/e2e/test_correctness.py renamed to tests/spec_decode/e2e/test_multistep_correctness.py

Lines changed: 2 additions & 2 deletions
@@ -264,7 +264,7 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs_diff_output_len(
     "common_llm_kwargs",
     [{
         # A "real" model (not tiny).
-        "model": "meta-llama/Llama-2-7b-chat-hf",
+        "model": "NousResearch/Llama-2-7b-chat-hf",

         # Skip cuda graph recording for fast test.
         "enforce_eager": True,
@@ -308,7 +308,7 @@ def test_spec_decode_e2e_greedy_correctness_real_model_bs1(
     "common_llm_kwargs",
     [{
         # A "real" model (not tiny).
-        "model": "meta-llama/Llama-2-7b-chat-hf",
+        "model": "NousResearch/Llama-2-7b-chat-hf",

         # Skip cuda graph recording for fast test.
         "enforce_eager": True,
tests/spec_decode/e2e/test_ngram_correctness.py

Lines changed: 285 additions & 0 deletions
@@ -0,0 +1,285 @@
+"""The tests in this file verify end-to-end speculative decoding correctness.
+
+This docstring details important information on the testing methodology.
+
+Most of the tests rely on "greedy equality", where we expect the output of
+speculative decoding on a sequence to exactly match the output of normal non-
+speculative decoding.
+
+Since speculative decoding with rejection sampling guarantees that the output
+distribution matches the target model's output distribution (up to hardware
+numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy
+equality.
+
+The ngram lookup idea comes from https://github.com/apoorvumang/prompt-lookup-decoding,
+and was merged into the transformers code base: https://github.com/huggingface/transformers/pull/27775.
+Since no model is needed to generate the proposal, the test cases can be
+much simpler than the multi-step draft-model ones.
+
+However, we still need to verify that the following scenarios pass:
+    * Batch size 1 greedy equality
+    * Batch size >1 greedy equality
+    * Test greedy equality under preemption
+    * Test greedy equality under various ngram sizes / speculative sizes
+
+With those tests, we can say that, at a minimum, ngram speculation does not
+break the correctness of the target model outputs.
+"""
+
+from itertools import cycle
+
+import pytest
+
+from vllm import SamplingParams
+
+from .conftest import get_output_from_llm_generator
+
+
+@pytest.mark.parametrize(
+    "common_llm_kwargs",
+    [{
+        # Skip cuda graph recording for fast test.
+        "enforce_eager": True,
+
+        # Required for spec decode.
+        "use_v2_block_manager": True,
+
+        # Print spec metrics.
+        "disable_log_stats": False,
+    }])
+@pytest.mark.parametrize("per_test_common_llm_kwargs", [
+    {
+        "model": "JackFram/llama-68m",
+    },
+])
+@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
+@pytest.mark.parametrize("test_llm_kwargs", [
+    {
+        "speculative_model": "[ngram]",
+        "num_speculative_tokens": 5,
+        "ngram_prompt_lookup_max": 3,
+    },
+])
+@pytest.mark.parametrize(
+    "output_len",
+    [
+        # Use long output len for the small model test.
+        1536,
+    ])
+@pytest.mark.parametrize("batch_size", [1])
+@pytest.mark.parametrize("seed", [1])
+def test_spec_decode_e2e_greedy_correctness_tiny_model_bs1(
+        baseline_llm_generator, test_llm_generator, batch_size: int,
+        output_len: int):
+    """Verify greedy equality on a tiny model with batch size of one.
+
+    Since this test is cheaper than other e2e correctness tests, we generate
+    with a higher output_len.
+    """
+    run_greedy_equality_correctness_test(baseline_llm_generator,
+                                         test_llm_generator,
+                                         batch_size,
+                                         max_output_len=output_len,
+                                         force_output_len=True)
+
+
+@pytest.mark.parametrize(
+    "common_llm_kwargs",
+    [{
+        # Skip cuda graph recording for fast test.
+        "enforce_eager": True,
+
+        # Required for spec decode.
+        "use_v2_block_manager": True,
+
+        # Print spec metrics.
+        "disable_log_stats": False,
+    }])
+@pytest.mark.parametrize("per_test_common_llm_kwargs", [
+    {
+        "model": "JackFram/llama-68m",
+    },
+])
+@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
+@pytest.mark.parametrize("test_llm_kwargs", [
+    {
+        "speculative_model": "[ngram]",
+        "num_speculative_tokens": 5,
+        "ngram_prompt_lookup_max": 3,
+    },
+])
+@pytest.mark.parametrize(
+    "output_len",
+    [
+        # Use small output len for fast test.
+        256,
+    ])
+@pytest.mark.parametrize("batch_size", [64])
+@pytest.mark.parametrize("seed", [1])
+def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs(
+        baseline_llm_generator, test_llm_generator, batch_size: int,
+        output_len: int):
+    """Verify greedy equality on a tiny model and large batch size.
+    """
+    run_greedy_equality_correctness_test(baseline_llm_generator,
+                                         test_llm_generator,
+                                         batch_size,
+                                         max_output_len=output_len,
+                                         force_output_len=True)
+
+
+@pytest.mark.parametrize(
+    "common_llm_kwargs",
+    [{
+        "block_size": 8,
+        # 2 for small prompt, 256//8 for generated.
+        "num_gpu_blocks_override": 2 + 256 // 8,
+        "max_model_len": (2 + 256 // 8) * 8,
+
+        # Skip cuda graph recording for fast test.
+        "enforce_eager": True,
+
+        # Required for spec decode.
+        "use_v2_block_manager": True
+    }])
+@pytest.mark.parametrize("per_test_common_llm_kwargs", [
+    {
+        "model": "JackFram/llama-160m",
+    },
+])
+@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
+@pytest.mark.parametrize("test_llm_kwargs", [
+    {
+        "speculative_model": "[ngram]",
+        "num_speculative_tokens": 5,
+        "ngram_prompt_lookup_max": 3,
+    },
+])
+@pytest.mark.parametrize(
+    "output_len",
+    [
+        # Use small output len for fast test.
+        256,
+    ])
+@pytest.mark.parametrize("batch_size", [4])
+@pytest.mark.parametrize("seed", [1])
+def test_spec_decode_e2e_greedy_correctness_with_preemption(
+        baseline_llm_generator, test_llm_generator, batch_size: int,
+        output_len: int):
+    """Verify greedy equality, even when some sequences are preempted mid-
+    generation.
+    """
+    run_greedy_equality_correctness_test(baseline_llm_generator,
+                                         test_llm_generator,
+                                         batch_size,
+                                         max_output_len=output_len,
+                                         force_output_len=True)
+
+
+@pytest.mark.parametrize(
+    "common_llm_kwargs",
+    [{
+        "model": "JackFram/llama-68m",
+
+        # Skip cuda graph recording for fast test.
+        "enforce_eager": True,
+
+        # Required for spec decode.
+        "use_v2_block_manager": True
+    }])
+@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
+@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
+@pytest.mark.parametrize(
+    "test_llm_kwargs",
+    [
+        {
+            "speculative_model": "[ngram]",
+            "num_speculative_tokens": k,
+            "ngram_prompt_lookup_max": 3,
+        }
+        # Try a range of common k, as well as large speculation.
+        for k in [1, 2, 3, 4, 5, 6, 7, 8, 9, 63]
+    ] + [
+        {
+            "speculative_model": "[ngram]",
+            "num_speculative_tokens": k,
+            "ngram_prompt_lookup_max": 1,
+        }
+        # Try a range of common k, as well as large speculation.
+        for k in [1, 2, 3, 4, 5, 6, 7, 8, 9, 63]
+    ])
+@pytest.mark.parametrize("batch_size", [2])
+@pytest.mark.parametrize(
+    "output_len",
+    [
+        # Use smaller output len for fast test.
+        32,
+    ])
+@pytest.mark.parametrize("seed", [1])
+def test_many_k(baseline_llm_generator, test_llm_generator, batch_size: int,
+                output_len: int):
+    """Verify that speculative decoding produces output exactly equal to
+    non-speculative decoding for many different values of k.
+    """
+    run_greedy_equality_correctness_test(baseline_llm_generator,
+                                         test_llm_generator,
+                                         batch_size,
+                                         max_output_len=output_len,
+                                         force_output_len=True)
+
+
+def run_greedy_equality_correctness_test(baseline_llm_generator,
+                                         test_llm_generator,
+                                         batch_size,
+                                         max_output_len,
+                                         force_output_len: bool,
+                                         print_tokens: bool = False):
+    """Helper method that compares the outputs of both the baseline LLM and
+    the test LLM. It asserts greedy equality, i.e. that the outputs are exactly
+    the same when temperature is zero.
+    """
+    temperature = 0.0
+
+    prompts = [
+        "Hello, my name is",
+        "The president of the United States is",
+        "The capital of France is",
+        "The future of AI is",
+        "San Francisco is known for its",
+        "Facebook was created in 2004 by",
+        "Curious George is a",
+        "Python 3.11 brings improvements to its",
+    ]
+
+    prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))]
+
+    # If the test requires that we generate max_output_len tokens, then set
+    # the sampling params to ignore the eos token.
+    ignore_eos = force_output_len
+
+    sampling_params = SamplingParams(
+        max_tokens=max_output_len,
+        ignore_eos=ignore_eos,
+        temperature=temperature,
+    )
+
+    spec_batch_tokens, spec_batch_token_ids = get_output_from_llm_generator(
+        test_llm_generator, prompts, sampling_params)
+
+    (baseline_batch_tokens,
+     baseline_batch_token_ids) = get_output_from_llm_generator(
+         baseline_llm_generator, prompts, sampling_params)
+
+    assert len(baseline_batch_token_ids) == len(prompts)
+    assert len(spec_batch_token_ids) == len(prompts)
+
+    for i, (baseline_token_ids, baseline_tokens, spec_token_ids,
+            spec_tokens) in enumerate(
+                zip(baseline_batch_token_ids, baseline_batch_tokens,
+                    spec_batch_token_ids, spec_batch_tokens)):
+        if print_tokens:
+            print(f'{i=} {baseline_tokens=}')
+            print(f'{i=} {spec_tokens=}')
+        print(f'{i=} {baseline_token_ids=}')
+        print(f'{i=} {spec_token_ids=}')
+        assert baseline_token_ids == spec_token_ids
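
The test_llm_kwargs used throughout this file also show how the feature is enabled from user code: pass "[ngram]" as the speculative model along with num_speculative_tokens and ngram_prompt_lookup_max. A minimal usage sketch follows, assuming the vLLM LLM constructor accepts the same engine keyword arguments that these tests pass to their generators; the chosen model and prompt are only examples.

from vllm import LLM, SamplingParams

# Minimal sketch: ngram prompt lookup speculation, no separate draft model.
# The kwargs mirror the test configuration above; adjust the model as needed.
llm = LLM(
    model="JackFram/llama-68m",
    speculative_model="[ngram]",   # draft comes from the prompt itself
    num_speculative_tokens=5,      # tokens proposed per step (k)
    ngram_prompt_lookup_max=3,     # largest n-gram used for the lookup
    use_v2_block_manager=True,     # required for spec decode
    enforce_eager=True,
)

sampling_params = SamplingParams(temperature=0.0, max_tokens=32)
outputs = llm.generate(["The capital of France is"], sampling_params)
print(outputs[0].outputs[0].text)

With temperature 0.0, the output should exactly match non-speculative decoding, which is precisely what run_greedy_equality_correctness_test asserts.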
