|
 # Test the LLMEngine with multi-step-decoding
 
+import copy
 from typing import Optional
 
 import pytest
@@ -196,3 +197,160 @@ def test_multi_step_llm_w_prompt_logprobs(
         name_0="hf",
         name_1="vllm",
     )
+
+
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("dtype", ["half"])
+@pytest.mark.parametrize("tp_size", [1])
+@pytest.mark.parametrize("max_tokens", [5])
+@pytest.mark.parametrize("enforce_eager", [True])
+@pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS)
+@pytest.mark.parametrize("num_prompts", NUM_PROMPTS)
+@pytest.mark.parametrize("num_logprobs", [None, 5])
+def test_multi_step_llm_chunked_prefill_prefix_cache(
+    vllm_runner,
+    example_prompts,
+    model: str,
+    dtype: str,
+    tp_size: int,
+    max_tokens: int,
+    enforce_eager: bool,
+    num_scheduler_steps: int,
+    num_prompts: int,
+    num_logprobs: Optional[int],
+) -> None:
+    """Test vLLM engine with multi-step+"single-step chunked prefill"+APC.
+
+    Set up a contrived scenario which tests for a possible failure mode of
+    scheduling with multi-step+"single-step chunked prefill"+APC.
+
+    "single-step chunked prefill" here refers to the current vLLM multi-step+
+    chunked-prefill implementation, which requires that a prefill may only
+    be scheduled in the same step as decodes if the prefill prompt fits in a
+    single chunk. (Note that "complete" multi-step+chunked-prefill would allow
+    a prefill to span multiple chunks & multiple steps, but that is not yet
+    the case.)
+
+    "APC" is short for "automatic prefix caching".
+
+    This test creates a scenario where the scheduler must decide whether and
+    how to schedule a prefill whose prompt exceeds the available token budget.
+    The correct behavior for multi-step+"single-step chunked prefill"+APC is
+    to defer scheduling the prefill until a future step.
+
+    Validate that:
+    * multi-step kernels do not raise an exception due to incorrect scheduler
+      behavior
+    * generated tokens match between
+      multi-step+"single-step chunked prefill"+APC and
+      single-step scheduling
+    * (if logprobs are enabled) logprobs are close enough
+
+    Args:
+      vllm_runner: vLLM model runner fixture
+      example_prompts: test fixture providing example prompts
+      model: model under test (same for single- and multi-step engines)
+      dtype: tensor datatype for engine to utilize
+      tp_size: degree of tensor-parallelism
+      max_tokens: the maximum number of tokens to generate
+      enforce_eager: whether to enforce eager-mode execution (no CUDA graphs)
+      num_scheduler_steps: for multi-step scheduling, GPU-side steps per
+                           GPU -> CPU output transfer
+      num_prompts: number of example prompts under test
+      num_logprobs: corresponds to the `logprobs` argument to the OpenAI
+                    completions endpoint; `None` -> 1 logprob returned.
+    """
+
+    # Set up a contrived test for correct scheduling behavior with
+    # multi-step+"single-step chunked prefill"+APC.
+    #
+    # Assume block_size=16
+    #
+    # Assume max_num_batched_tokens=48
+    #   => Per-step token budget=48
+    #
+    # 1. Scheduler schedules 0th prompt (24 tokens)
+    #      => Remaining token budget=24
+    # 2. Scheduler attempts to schedule 1st prompt (30 tokens)
+    #      * 30 tokens exceeds the 24-token remaining budget
+    #      * Correct behavior: do not schedule this prompt in this step
+    #      * Incorrect behavior: schedule a prompt chunk
+    #        * `do_sample=False` for this prompt in this step
+    #        * Chunk size = (remaining tokens // block size) * block size
+    #
+    # The incorrect scheduling behavior, if it occurs, will cause an exception
+    # in the model runner resulting from `do_sample=False`.
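+    #
+    # NOTE: the 24- and 30-token counts below assume the tokenizer of the
+    # model under test; the prompts were chosen so that the first fits within
+    # the 48-token per-step budget while the two together exceed it.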
+    assert len(example_prompts) >= 2
+    challenge_prompts = copy.deepcopy(example_prompts)
+    challenge_prompts[0] = ('vLLM is a high-throughput and memory-efficient '
+                            'inference and serving engine for LLMs.\n'
+                            )  # 24 tok
+    challenge_prompts[1] = (
+        'Briefly describe the major milestones in the '
+        'development of artificial intelligence from 1950 to 2020.\n'
+    )  # 30 tok
+
+    # If necessary, adjust the length of `challenge_prompts` to match
+    # `num_prompts`.
+    if len(challenge_prompts) < num_prompts:
+        challenge_prompts = (challenge_prompts *
+                             ((num_prompts // len(challenge_prompts)) + 1))
+    challenge_prompts = challenge_prompts[:num_prompts]
+    assert len(challenge_prompts) == num_prompts
+
+    # Single-step scheduler baseline
+    with vllm_runner(
+            model,
+            dtype=dtype,
+            enforce_eager=enforce_eager,
+            gpu_memory_utilization=0.7,
+            tensor_parallel_size=tp_size,
+            use_v2_block_manager=True,
+            num_scheduler_steps=num_scheduler_steps,
+            max_model_len=48,
+            max_num_batched_tokens=48,
+            max_num_seqs=4,
+            block_size=16,
+    ) as vllm_model:
+        outputs_baseline = (vllm_model.generate_greedy(
+            challenge_prompts, max_tokens) if num_logprobs is None else
+                            vllm_model.generate_greedy_logprobs(
+                                challenge_prompts, max_tokens, num_logprobs))
+
+    # multi-step+"single-step chunked prefill"+APC
+    with vllm_runner(
+            model,
+            dtype=dtype,
+            enforce_eager=enforce_eager,
+            gpu_memory_utilization=0.7,
+            tensor_parallel_size=tp_size,
+            use_v2_block_manager=True,
+            enable_chunked_prefill=True,
+            enable_prefix_caching=True,
+            num_scheduler_steps=num_scheduler_steps,
+            max_model_len=48,
+            max_num_batched_tokens=48,
+            max_num_seqs=4,
+            block_size=16,
+    ) as vllm_model:
+        outputs_w_features = (vllm_model.generate_greedy(
+            challenge_prompts, max_tokens) if num_logprobs is None else
+                              vllm_model.generate_greedy_logprobs(
+                                  challenge_prompts, max_tokens, num_logprobs))
+
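+    # The two engines should agree: token-for-token equality when logprobs
+    # are not requested, and close logprobs otherwise.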
+    if num_logprobs is None:
+        # No-logprobs test
+        check_outputs_equal(
+            outputs_0_lst=outputs_baseline,
+            outputs_1_lst=outputs_w_features,
+            name_0="multi-step",
+            name_1="multi-step+features",
+        )
+    else:
+        # Yes-logprobs test
+        check_logprobs_close(
+            outputs_0_lst=outputs_baseline,
+            outputs_1_lst=outputs_w_features,
+            name_0="multi-step",
+            name_1="multi-step+features",
+        )