@@ -21,6 +21,7 @@
 
 import pytest
 
+from ..utils import maybe_enable_chunked_prefill
 from .conftest import run_equality_correctness_test
 
 # main model
@@ -67,12 +68,14 @@
 ])
 @pytest.mark.parametrize("batch_size", [1, 32])
 @pytest.mark.parametrize("seed", [1])
+@pytest.mark.parametrize("prefill_chunk_size", [-1, 32])
 def test_medusa_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
                                        per_test_common_llm_kwargs,
                                        baseline_llm_kwargs, test_llm_kwargs,
                                        batch_size: int, output_len: int,
-                                       seed: int):
+                                       seed: int, prefill_chunk_size: int):
     """Verify greedy equality with different batch size."""
+    maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
     run_equality_correctness_test(vllm_runner,
                                   common_llm_kwargs,
                                   per_test_common_llm_kwargs,
@@ -119,12 +122,15 @@ def test_medusa_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
 @pytest.mark.parametrize("batch_size", [8])
 @pytest.mark.parametrize("seed", [1])
 @pytest.mark.parametrize("logprobs", [1, 6])
+@pytest.mark.parametrize("prefill_chunk_size", [-1, 32])
 def test_medusa_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
                                     per_test_common_llm_kwargs,
                                     baseline_llm_kwargs, test_llm_kwargs,
                                     batch_size: int, output_len: int,
-                                    seed: int, logprobs: int):
+                                    seed: int, logprobs: int,
+                                    prefill_chunk_size: int):
     """Verify greedy equality with different batch size."""
+    maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
     run_equality_correctness_test(vllm_runner,
                                   common_llm_kwargs,
                                   per_test_common_llm_kwargs,
@@ -167,12 +173,14 @@ def test_medusa_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
 ])
 @pytest.mark.parametrize("batch_size", [1, 32])
 @pytest.mark.parametrize("seed", [1])
+@pytest.mark.parametrize("prefill_chunk_size", [-1, 32])
 def test_medusa_e2e_greedy_correctness_cuda_graph(
         vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
         baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
-        seed: int):
+        seed: int, prefill_chunk_size: int):
     """Verify greedy equality with cuda graph enabled and different
     batch sizes."""
+    maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
     run_equality_correctness_test(vllm_runner,
                                   common_llm_kwargs,
                                   per_test_common_llm_kwargs,
@@ -217,13 +225,15 @@ def test_medusa_e2e_greedy_correctness_cuda_graph(
 ])
 @pytest.mark.parametrize("batch_size", [4])
 @pytest.mark.parametrize("seed", [1])
+@pytest.mark.parametrize("prefill_chunk_size", [-1, 32])
 def test_medusa_e2e_greedy_correctness_with_preemption(
         vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
         baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
-        seed: int):
+        seed: int, prefill_chunk_size: int):
     """Verify greedy equality, even when some sequences are preempted mid-
     generation.
     """
+    maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
     run_equality_correctness_test(vllm_runner,
                                   common_llm_kwargs,
                                   per_test_common_llm_kwargs,
@@ -267,13 +277,15 @@ def test_medusa_e2e_greedy_correctness_with_preemption(
     32,
 ])
 @pytest.mark.parametrize("seed", [1])
+@pytest.mark.parametrize("prefill_chunk_size", [-1, 32])
 def test_medusa_different_k(vllm_runner, common_llm_kwargs,
                             per_test_common_llm_kwargs, baseline_llm_kwargs,
                             test_llm_kwargs, batch_size: int, output_len: int,
-                            seed: int):
+                            seed: int, prefill_chunk_size: int):
     """Verify that medusa speculative decoding produces exact equality
     to without spec decode with different values of num_speculative_tokens.
     """
+    maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
     run_equality_correctness_test(vllm_runner,
                                   common_llm_kwargs,
                                   per_test_common_llm_kwargs,
@@ -313,14 +325,17 @@ def test_medusa_different_k(vllm_runner, common_llm_kwargs,
     32,
 ])
 @pytest.mark.parametrize("seed", [1])
+@pytest.mark.parametrize("prefill_chunk_size", [-1, 32])
 def test_medusa_disable_queue(vllm_runner, common_llm_kwargs,
                               per_test_common_llm_kwargs, baseline_llm_kwargs,
                               test_llm_kwargs, batch_size: int,
-                              output_len: int, seed: int):
+                              output_len: int, seed: int,
+                              prefill_chunk_size: int):
     """Verify that medusa speculative decoding produces exact equality
     to without spec decode when speculation is disabled for large
     batch sizes.
     """
+    maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
     run_equality_correctness_test(vllm_runner,
                                   common_llm_kwargs,
                                   per_test_common_llm_kwargs,
@@ -361,12 +376,14 @@ def test_medusa_disable_queue(vllm_runner, common_llm_kwargs,
     32,
 ])
 @pytest.mark.parametrize("seed", [1])
+@pytest.mark.parametrize("prefill_chunk_size", [-1, 32])
 def test_mqa_scorer(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
                     baseline_llm_kwargs, test_llm_kwargs, batch_size: int,
-                    output_len: int, seed: int):
+                    output_len: int, seed: int, prefill_chunk_size: int):
     """Verify that speculative decoding generates the same output
     with batch expansion scorer and mqa scorer.
     """
+    maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
     run_equality_correctness_test(vllm_runner,
                                   common_llm_kwargs,
                                   per_test_common_llm_kwargs,