
Commit 8adaf38

wenlei03 authored and root committed
[Speculative decoding] Add ngram prompt lookup decoding
Algorithm details are described in this blog post: https://huggingface.co/blog/assisted-generation. The code closely follows transformers' current implementation: huggingface/transformers#27775. Since the draft tokens are taken directly from the prompt, no additional or modified model is needed to generate the proposal, which makes this the most convenient way to enjoy the speedup of speculation.
1 parent d6f4bd7 commit 8adaf38
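To make the idea concrete, below is a minimal, self-contained sketch of ngram prompt lookup drafting: match the most recent ngram of the context against earlier tokens and propose the tokens that followed the match as the draft. This is illustrative only and is not the code added by this commit; the helper name find_ngram_proposal and the toy token ids are made up for the example.

from typing import List


def find_ngram_proposal(token_ids: List[int],
                        ngram_max: int,
                        num_speculative_tokens: int) -> List[int]:
    """Propose draft tokens by matching the most recent ngram of the context
    against earlier positions and copying the tokens that followed the match."""
    for n in range(ngram_max, 0, -1):
        if len(token_ids) < n + 1:
            continue
        tail = token_ids[-n:]
        # Search earlier positions (the match must start before the tail).
        for start in range(len(token_ids) - n - 1, -1, -1):
            if token_ids[start:start + n] == tail:
                follow = token_ids[start + n:start + n + num_speculative_tokens]
                if follow:
                    return follow
    # No match found: propose nothing and fall back to normal decoding.
    return []


if __name__ == "__main__":
    # Toy context [the, cat, sat, on, the]: the trailing "the" matches position 0,
    # so the tokens after it ("cat", "sat") are proposed as the draft.
    context = [10, 11, 12, 13, 10]
    print(find_ngram_proposal(context, ngram_max=3, num_speculative_tokens=2))
    # -> [11, 12]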

File tree

14 files changed: +1004 −319 lines changed


tests/spec_decode/e2e/conftest.py

Lines changed: 58 additions & 0 deletions
@@ -1,4 +1,5 @@
 import asyncio
+from itertools import cycle
 from typing import List, Optional, Tuple, Union

 import pytest
@@ -185,3 +186,60 @@ def get_output_from_llm_generator(
     del llm

     return tokens, token_ids
+
+
+def run_greedy_equality_correctness_test(baseline_llm_generator,
+                                         test_llm_generator,
+                                         batch_size,
+                                         max_output_len,
+                                         force_output_len: bool,
+                                         print_tokens: bool = False):
+    """Helper method that compares the outputs of both the baseline LLM and
+    the test LLM. It asserts greedy equality, e.g. that the outputs are exactly
+    the same when temperature is zero.
+    """
+    temperature = 0.0
+
+    prompts = [
+        "Hello, my name is",
+        "The president of the United States is",
+        "The capital of France is",
+        "The future of AI is",
+        "San Francisco is know for its",
+        "Facebook was created in 2004 by",
+        "Curious George is a",
+        "Python 3.11 brings improvements to its",
+    ]
+
+    prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))]
+
+    # If the test requires that we generated max_output_len tokens, then set the
+    # sampling params to ignore eos token.
+    ignore_eos = force_output_len
+
+    sampling_params = SamplingParams(
+        max_tokens=max_output_len,
+        ignore_eos=ignore_eos,
+        temperature=temperature,
+    )
+
+    spec_batch_tokens, spec_batch_token_ids = get_output_from_llm_generator(
+        test_llm_generator, prompts, sampling_params)
+
+    (baseline_batch_tokens,
+     baseline_batch_token_ids) = get_output_from_llm_generator(
+         baseline_llm_generator, prompts, sampling_params)
+
+    assert len(baseline_batch_token_ids) == len(prompts)
+    assert len(spec_batch_token_ids) == len(prompts)
+
+    for i, (baseline_token_ids, baseline_tokens, spec_token_ids,
+            spec_tokens) in enumerate(
+                zip(baseline_batch_token_ids, baseline_batch_tokens,
+                    spec_batch_token_ids, spec_batch_tokens)):
+        if print_tokens:
+            print(f'{i=} {baseline_tokens=}')
+            print(f'{i=} {spec_tokens=}')
+            print(f'{i=} {baseline_token_ids=}')
+            print(f'{i=} {spec_token_ids=}')
+        assert baseline_token_ids == spec_token_ids

tests/spec_decode/e2e/test_correctness.py renamed to tests/spec_decode/e2e/test_multistep_correctness.py

Lines changed: 2 additions & 58 deletions
@@ -35,7 +35,8 @@

 from vllm import SamplingParams

-from .conftest import get_output_from_llm_generator
+from .conftest import (get_output_from_llm_generator,
+                       run_greedy_equality_correctness_test)


 @pytest.mark.parametrize(
@@ -545,60 +546,3 @@ def test_many_k(baseline_llm_generator, test_llm_generator, batch_size: int,
                                         batch_size,
                                         max_output_len=output_len,
                                         force_output_len=True)
-
-
-def run_greedy_equality_correctness_test(baseline_llm_generator,
-                                         test_llm_generator,
-                                         batch_size,
-                                         max_output_len,
-                                         force_output_len: bool,
-                                         print_tokens: bool = False):
-    """Helper method that compares the outputs of both the baseline LLM and
-    the test LLM. It asserts greedy equality, e.g. that the outputs are exactly
-    the same when temperature is zero.
-    """
-    temperature = 0.0
-
-    prompts = [
-        "Hello, my name is",
-        "The president of the United States is",
-        "The capital of France is",
-        "The future of AI is",
-        "San Francisco is know for its",
-        "Facebook was created in 2004 by",
-        "Curious George is a",
-        "Python 3.11 brings improvements to its",
-    ]
-
-    prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))]
-
-    # If the test requires that we generated max_output_len tokens, then set the
-    # sampling params to ignore eos token.
-    ignore_eos = force_output_len
-
-    sampling_params = SamplingParams(
-        max_tokens=max_output_len,
-        ignore_eos=ignore_eos,
-        temperature=temperature,
-    )
-
-    spec_batch_tokens, spec_batch_token_ids = get_output_from_llm_generator(
-        test_llm_generator, prompts, sampling_params)
-
-    (baseline_batch_tokens,
-     baseline_batch_token_ids) = get_output_from_llm_generator(
-         baseline_llm_generator, prompts, sampling_params)
-
-    assert len(baseline_batch_token_ids) == len(prompts)
-    assert len(spec_batch_token_ids) == len(prompts)
-
-    for i, (baseline_token_ids, baseline_tokens, spec_token_ids,
-            spec_tokens) in enumerate(
-                zip(baseline_batch_token_ids, baseline_batch_tokens,
-                    spec_batch_token_ids, spec_batch_tokens)):
-        if print_tokens:
-            print(f'{i=} {baseline_tokens=}')
-            print(f'{i=} {spec_tokens=}')
-            print(f'{i=} {baseline_token_ids=}')
-            print(f'{i=} {spec_token_ids=}')
-        assert baseline_token_ids == spec_token_ids
Lines changed: 172 additions & 0 deletions
@@ -0,0 +1,172 @@
+"""This docstring details important information on the testing methodology.
+
+Most of the tests rely on "greedy equality", where we expect the output of
+speculative decoding on a sequence to exactly match the output of normal non-
+speculative decoding.
+
+Since speculative decoding with rejection sampling guarantees that the output
+distribution matches the target model's output distribution (up to hardware
+numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy
+equality.
+
+The ngram lookup idea comes from https://github.com/apoorvumang/prompt-lookup-decoding
+and has been merged into the transformers code base: https://github.com/huggingface/transformers/pull/27775.
+Since no model is needed to generate the proposals, these test cases can be
+much simpler than the multi-step draft-model ones.
+
+However, we still need to verify that the following scenarios pass:
+* Batch size 1 greedy equality
+* Batch size >1 greedy equality
+* Test greedy equality under preemption
+* Test greedy equality under various ngram sizes / speculative sizes
+
+With those tests, we can say at least that ngram speculation does not break
+the correctness of the target model outputs.
+"""
+
+import pytest
+
+from .conftest import run_greedy_equality_correctness_test
+
+
+@pytest.mark.parametrize(
+    "common_llm_kwargs",
+    [{
+        # Skip cuda graph recording for fast test.
+        "enforce_eager": True,
+
+        # Required for spec decode.
+        "use_v2_block_manager": True,
+
+        # Print spec metrics.
+        "disable_log_stats": False,
+    }])
+@pytest.mark.parametrize("per_test_common_llm_kwargs", [
+    {
+        "model": "JackFram/llama-68m",
+    },
+])
+@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
+@pytest.mark.parametrize("test_llm_kwargs", [
+    {
+        "speculative_model": "[ngram]",
+        "num_speculative_tokens": 5,
+        "ngram_prompt_lookup_max": 3,
+    },
+])
+@pytest.mark.parametrize("output_len", [
+    256,
+])
+@pytest.mark.parametrize("batch_size", [1, 64])
+@pytest.mark.parametrize("seed", [1])
+def test_ngram_e2e_greedy_correctness(baseline_llm_generator,
+                                      test_llm_generator, batch_size: int,
+                                      output_len: int):
+    """Verify greedy equality on a tiny model with different batch sizes."""
+    run_greedy_equality_correctness_test(baseline_llm_generator,
+                                         test_llm_generator,
+                                         batch_size,
+                                         max_output_len=output_len,
+                                         force_output_len=True)
+
+
+@pytest.mark.parametrize(
+    "common_llm_kwargs",
+    [{
+        "block_size": 8,
+        # 2 for small prompt, 256//8 for generated.
+        "num_gpu_blocks_override": 2 + 256 // 8,
+        "max_model_len": (2 + 256 // 8) * 8,
+
+        # Skip cuda graph recording for fast test.
+        "enforce_eager": True,
+
+        # Required for spec decode.
+        "use_v2_block_manager": True
+    }])
+@pytest.mark.parametrize("per_test_common_llm_kwargs", [
+    {
+        "model": "JackFram/llama-160m",
+    },
+])
+@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
+@pytest.mark.parametrize("test_llm_kwargs", [
+    {
+        "speculative_model": "[ngram]",
+        "num_speculative_tokens": 5,
+        "ngram_prompt_lookup_max": 3,
+    },
+])
+@pytest.mark.parametrize(
+    "output_len",
+    [
+        # Use small output len for fast test.
+        256,
+    ])
+@pytest.mark.parametrize("batch_size", [4])
+@pytest.mark.parametrize("seed", [1])
+def test_ngram_e2e_greedy_correctness_with_preemption(baseline_llm_generator,
+                                                      test_llm_generator,
+                                                      batch_size: int,
+                                                      output_len: int):
+    """Verify greedy equality, even when some sequences are preempted mid-
+    generation.
+    """
+    run_greedy_equality_correctness_test(baseline_llm_generator,
+                                         test_llm_generator,
+                                         batch_size,
+                                         max_output_len=output_len,
+                                         force_output_len=True)
+
+
+@pytest.mark.parametrize(
+    "common_llm_kwargs",
+    [{
+        "model": "JackFram/llama-68m",
+
+        # Skip cuda graph recording for fast test.
+        "enforce_eager": True,
+
+        # Required for spec decode.
+        "use_v2_block_manager": True
+    }])
+@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
+@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
+@pytest.mark.parametrize(
+    "test_llm_kwargs",
+    [
+        {
+            "speculative_model": "[ngram]",
+            "num_speculative_tokens": k,
+            "ngram_prompt_lookup_max": 3,
+        }
+        # Try a range of common k, as well as large speculation.
+        for k in [1, 3, 5]
+    ] + [
+        {
+            "speculative_model": "[ngram]",
+            "num_speculative_tokens": k,
+            "ngram_prompt_lookup_max": 1,
+        }
+        # Try a range of common k, as well as large speculation.
+        for k in [1, 3, 5]
+    ])
+@pytest.mark.parametrize("batch_size", [2])
+@pytest.mark.parametrize(
+    "output_len",
+    [
+        # Use smaller output len for fast test.
+        32,
+    ])
+@pytest.mark.parametrize("seed", [1])
+def test_ngram_different_k(baseline_llm_generator, test_llm_generator,
+                           batch_size: int, output_len: int):
+    """Verify that ngram speculative decoding produces exactly the same output
+    as decoding without speculation, for many different values of k and
+    different ngram_prompt_lookup_max.
+    """
+    run_greedy_equality_correctness_test(baseline_llm_generator,
+                                         test_llm_generator,
+                                         batch_size,
+                                         max_output_len=output_len,
+                                         force_output_len=True)
