|
| 1 | +"""The tests in this file verify end-to-end speculative decoding correctness. |
| 2 | +
|
| 3 | +This docstring details important information on the testing methodology. |
| 4 | +
|
| 5 | +Most of the tests rely on "greedy equality", where we expect the output of |
| 6 | +speculative decoding on a sequence to exactly match the output of normal non- |
| 7 | +speculative decoding. |
| 8 | +
|
| 9 | +Since speculative decoding with rejection sampling guarantees that the output |
| 10 | +distribution matches the target model's output distribution (up to hardware |
| 11 | +numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy |
| 12 | +equality. |
| 13 | +
|
| 14 | +For ngram lookup, its idea comes from https://github.com/apoorvumang/prompt-lookup-decoding, |
| 15 | +and is merged into transform code base: https://github.com/huggingface/transformers/pull/27775. |
| 16 | +Since there is no model is needed for generate the proposal, we could make |
| 17 | +the testcase much simplier than drafter multi-step one. |
| 18 | +
|
| 19 | +However, we still need to verify below scenario could be passed: |
| 20 | + * Batch size 1 greedy equality |
| 21 | + * Batch size >1 greedy equality |
| 22 | + * Test greedy equality under preemption |
| 23 | + * Test greedy equality under various ngram sizes / speculative sizes |
| 24 | +
|
| 25 | +With those tests, we can say at least, ngram spec would not break the correctess |
| 26 | +for the target model outputs. |
| 27 | +""" |
| 28 | + |
| 29 | +from itertools import cycle |
| 30 | + |
| 31 | +import pytest |
| 32 | + |
| 33 | +from vllm import SamplingParams |
| 34 | + |
| 35 | +from .conftest import get_output_from_llm_generator |
| 36 | + |
| 37 | + |
| 38 | +@pytest.mark.parametrize( |
| 39 | + "common_llm_kwargs", |
| 40 | + [{ |
| 41 | + # Skip cuda graph recording for fast test. |
| 42 | + "enforce_eager": True, |
| 43 | +
|
| 44 | + # Required for spec decode. |
| 45 | + "use_v2_block_manager": True, |
| 46 | +
|
| 47 | + # Print spec metrics. |
| 48 | + "disable_log_stats": False, |
| 49 | + }]) |
| 50 | +@pytest.mark.parametrize("per_test_common_llm_kwargs", [ |
| 51 | + { |
| 52 | + "model": "JackFram/llama-68m", |
| 53 | + }, |
| 54 | +]) |
| 55 | +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) |
| 56 | +@pytest.mark.parametrize("test_llm_kwargs", [ |
| 57 | + { |
| 58 | + "speculative_model": "[ngram]", |
| 59 | + "num_speculative_tokens": 5, |
| 60 | + "ngram_prompt_lookup_max": 3, |
| 61 | + }, |
| 62 | +]) |
| 63 | +@pytest.mark.parametrize( |
| 64 | + "output_len", |
| 65 | + [ |
| 66 | + # Use long output len for the small model test. |
| 67 | + 1536, |
| 68 | + ]) |
| 69 | +@pytest.mark.parametrize("batch_size", [1]) |
| 70 | +@pytest.mark.parametrize("seed", [1]) |
| 71 | +def test_spec_decode_e2e_greedy_correctness_tiny_model_bs1( |
| 72 | + baseline_llm_generator, test_llm_generator, batch_size: int, |
| 73 | + output_len: int): |
| 74 | + """Verify greedy equality on a tiny model with batch size of one. |
| 75 | +
|
| 76 | + Since this test is cheaper than other e2e correctness tests, we generate |
| 77 | + with a higher output_len. |
| 78 | + """ |
| 79 | + run_greedy_equality_correctness_test(baseline_llm_generator, |
| 80 | + test_llm_generator, |
| 81 | + batch_size, |
| 82 | + max_output_len=output_len, |
| 83 | + force_output_len=True) |
| 84 | + |
| 85 | + |
| 86 | +@pytest.mark.parametrize( |
| 87 | + "common_llm_kwargs", |
| 88 | + [{ |
| 89 | + # Skip cuda graph recording for fast test. |
| 90 | + "enforce_eager": True, |
| 91 | +
|
| 92 | + # Required for spec decode. |
| 93 | + "use_v2_block_manager": True, |
| 94 | +
|
| 95 | + # Print spec metrics. |
| 96 | + "disable_log_stats": False, |
| 97 | + }]) |
| 98 | +@pytest.mark.parametrize("per_test_common_llm_kwargs", [ |
| 99 | + { |
| 100 | + "model": "JackFram/llama-68m", |
| 101 | + }, |
| 102 | +]) |
| 103 | +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) |
| 104 | +@pytest.mark.parametrize("test_llm_kwargs", [ |
| 105 | + { |
| 106 | + "speculative_model": "[ngram]", |
| 107 | + "num_speculative_tokens": 5, |
| 108 | + "ngram_prompt_lookup_max": 3, |
| 109 | + }, |
| 110 | +]) |
| 111 | +@pytest.mark.parametrize( |
| 112 | + "output_len", |
| 113 | + [ |
| 114 | + # Use small output len for fast test. |
| 115 | + 256, |
| 116 | + ]) |
| 117 | +@pytest.mark.parametrize("batch_size", [64]) |
| 118 | +@pytest.mark.parametrize("seed", [1]) |
| 119 | +def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs( |
| 120 | + baseline_llm_generator, test_llm_generator, batch_size: int, |
| 121 | + output_len: int): |
| 122 | + """Verify greedy equality on a tiny model and large batch size. |
| 123 | + """ |
| 124 | + run_greedy_equality_correctness_test(baseline_llm_generator, |
| 125 | + test_llm_generator, |
| 126 | + batch_size, |
| 127 | + max_output_len=output_len, |
| 128 | + force_output_len=True) |
| 129 | + |
| 130 | + |
| 131 | +@pytest.mark.parametrize( |
| 132 | + "common_llm_kwargs", |
| 133 | + [{ |
| 134 | + "block_size": 8, |
| 135 | + # 2 for small prompt, 256//8 for generated. |
| 136 | + "num_gpu_blocks_override": 2 + 256 // 8, |
| 137 | + "max_model_len": (2 + 256 // 8) * 8, |
| 138 | +
|
| 139 | + # Skip cuda graph recording for fast test. |
| 140 | + "enforce_eager": True, |
| 141 | +
|
| 142 | + # Required for spec decode. |
| 143 | + "use_v2_block_manager": True |
| 144 | + }]) |
| 145 | +@pytest.mark.parametrize("per_test_common_llm_kwargs", [ |
| 146 | + { |
| 147 | + "model": "JackFram/llama-160m", |
| 148 | + }, |
| 149 | +]) |
| 150 | +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) |
| 151 | +@pytest.mark.parametrize("test_llm_kwargs", [ |
| 152 | + { |
| 153 | + "speculative_model": "[ngram]", |
| 154 | + "num_speculative_tokens": 5, |
| 155 | + "ngram_prompt_lookup_max": 3, |
| 156 | + }, |
| 157 | +]) |
| 158 | +@pytest.mark.parametrize( |
| 159 | + "output_len", |
| 160 | + [ |
| 161 | + # Use small output len for fast test. |
| 162 | + 256, |
| 163 | + ]) |
| 164 | +@pytest.mark.parametrize("batch_size", [4]) |
| 165 | +@pytest.mark.parametrize("seed", [1]) |
| 166 | +def test_spec_decode_e2e_greedy_correctness_with_preemption( |
| 167 | + baseline_llm_generator, test_llm_generator, batch_size: int, |
| 168 | + output_len: int): |
| 169 | + """Verify greedy equality, even when some sequences are preempted mid- |
| 170 | + generation. |
| 171 | + """ |
| 172 | + run_greedy_equality_correctness_test(baseline_llm_generator, |
| 173 | + test_llm_generator, |
| 174 | + batch_size, |
| 175 | + max_output_len=output_len, |
| 176 | + force_output_len=True) |
| 177 | + |
| 178 | + |
| 179 | +@pytest.mark.parametrize( |
| 180 | + "common_llm_kwargs", |
| 181 | + [{ |
| 182 | + "model": "JackFram/llama-68m", |
| 183 | +
|
| 184 | + # Skip cuda graph recording for fast test. |
| 185 | + "enforce_eager": True, |
| 186 | +
|
| 187 | + # Required for spec decode. |
| 188 | + "use_v2_block_manager": True |
| 189 | + }]) |
| 190 | +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) |
| 191 | +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) |
| 192 | +@pytest.mark.parametrize( |
| 193 | + "test_llm_kwargs", |
| 194 | + [ |
| 195 | + { |
| 196 | + "speculative_model": "[ngram]", |
| 197 | + "num_speculative_tokens": k, |
| 198 | + "ngram_prompt_lookup_max": 3, |
| 199 | + } |
| 200 | + # Try a range of common k, as well as large speculation. |
| 201 | + for k in [1, 2, 3, 4, 5, 6, 7, 8, 9, 63] |
| 202 | + ] + [ |
| 203 | + { |
| 204 | + "speculative_model": "[ngram]", |
| 205 | + "num_speculative_tokens": k, |
| 206 | + "ngram_prompt_lookup_max": 1, |
| 207 | + } |
| 208 | + # Try a range of common k, as well as large speculation. |
| 209 | + for k in [1, 2, 3, 4, 5, 6, 7, 8, 9, 63] |
| 210 | + ]) |
| 211 | +@pytest.mark.parametrize("batch_size", [2]) |
| 212 | +@pytest.mark.parametrize( |
| 213 | + "output_len", |
| 214 | + [ |
| 215 | + # Use smaller output len for fast test. |
| 216 | + 32, |
| 217 | + ]) |
| 218 | +@pytest.mark.parametrize("seed", [1]) |
| 219 | +def test_many_k(baseline_llm_generator, test_llm_generator, batch_size: int, |
| 220 | + output_len: int): |
| 221 | + """Verify that speculative decoding produces exact equality to without spec |
| 222 | + decode with many different values of k. |
| 223 | + """ |
| 224 | + run_greedy_equality_correctness_test(baseline_llm_generator, |
| 225 | + test_llm_generator, |
| 226 | + batch_size, |
| 227 | + max_output_len=output_len, |
| 228 | + force_output_len=True) |
| 229 | + |
| 230 | + |
| 231 | +def run_greedy_equality_correctness_test(baseline_llm_generator, |
| 232 | + test_llm_generator, |
| 233 | + batch_size, |
| 234 | + max_output_len, |
| 235 | + force_output_len: bool, |
| 236 | + print_tokens: bool = False): |
| 237 | + """Helper method that compares the outputs of both the baseline LLM and |
| 238 | + the test LLM. It asserts greedy equality, e.g. that the outputs are exactly |
| 239 | + the same when temperature is zero. |
| 240 | + """ |
| 241 | + temperature = 0.0 |
| 242 | + |
| 243 | + prompts = [ |
| 244 | + "Hello, my name is", |
| 245 | + "The president of the United States is", |
| 246 | + "The capital of France is", |
| 247 | + "The future of AI is", |
| 248 | + "San Francisco is know for its", |
| 249 | + "Facebook was created in 2004 by", |
| 250 | + "Curious George is a", |
| 251 | + "Python 3.11 brings improvements to its", |
| 252 | + ] |
| 253 | + |
| 254 | + prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))] |
| 255 | + |
| 256 | + # If the test requires that we generated max_output_len tokens, then set the |
| 257 | + # sampling params to ignore eos token. |
| 258 | + ignore_eos = force_output_len |
| 259 | + |
| 260 | + sampling_params = SamplingParams( |
| 261 | + max_tokens=max_output_len, |
| 262 | + ignore_eos=ignore_eos, |
| 263 | + temperature=temperature, |
| 264 | + ) |
| 265 | + |
| 266 | + spec_batch_tokens, spec_batch_token_ids = get_output_from_llm_generator( |
| 267 | + test_llm_generator, prompts, sampling_params) |
| 268 | + |
| 269 | + (baseline_batch_tokens, |
| 270 | + baseline_batch_token_ids) = get_output_from_llm_generator( |
| 271 | + baseline_llm_generator, prompts, sampling_params) |
| 272 | + |
| 273 | + assert len(baseline_batch_token_ids) == len(prompts) |
| 274 | + assert len(spec_batch_token_ids) == len(prompts) |
| 275 | + |
| 276 | + for i, (baseline_token_ids, baseline_tokens, spec_token_ids, |
| 277 | + spec_tokens) in enumerate( |
| 278 | + zip(baseline_batch_token_ids, baseline_batch_tokens, |
| 279 | + spec_batch_token_ids, spec_batch_tokens)): |
| 280 | + if print_tokens: |
| 281 | + print(f'{i=} {baseline_tokens=}') |
| 282 | + print(f'{i=} {spec_tokens=}') |
| 283 | + print(f'{i=} {baseline_token_ids=}') |
| 284 | + print(f'{i=} {spec_token_ids=}') |
| 285 | + assert baseline_token_ids == spec_token_ids |
0 commit comments