|
5 | 5 | import pytest |
6 | 6 | import torch |
7 | 7 |
|
| 8 | +from vllm.attention.selector import (_Backend, |
| 9 | + global_force_attn_backend_context_manager) |
8 | 10 | from vllm.model_executor.layers.sampler import SamplerOutput |
9 | 11 | from vllm.model_executor.utils import set_random_seed |
10 | 12 | from vllm.sequence import (ExecuteModelRequest, HiddenStates, Logprob, |
@@ -303,6 +305,7 @@ def test_multi_step_with_batch_expansion_correct_output(): |
303 | 305 | seed, |
304 | 306 | model_runner_cls=TP1DraftModelRunner, |
305 | 307 | ) |
| 308 | + multi_step_worker.set_include_gpu_probs_tensor() |
306 | 309 | worker = create_worker( |
307 | 310 | Worker, |
308 | 311 | model_name, |
@@ -397,6 +400,7 @@ def test_multi_step_with_batch_expansion_incorrect_output(): |
397 | 400 | seed, |
398 | 401 | model_runner_cls=TP1DraftModelRunner, |
399 | 402 | ) |
| 403 | + multi_step_worker.set_include_gpu_probs_tensor() |
400 | 404 | worker = create_worker( |
401 | 405 | Worker, |
402 | 406 | model_name, |
@@ -477,6 +481,109 @@ def test_multi_step_with_batch_expansion_incorrect_output(): |
477 | 481 | assert (num_mismatch > 0) |
478 | 482 |
|
479 | 483 |
|
@torch.inference_mode()
@pytest.mark.parametrize('num_steps', [1, 2, 3, 4])
# The choice of backends forces the multi_step_worker to choose between
# the vanilla model_runner and TP1DraftModelRunner and that we can test
# both code paths.
@pytest.mark.parametrize('attn_backend',
                         [_Backend.XFORMERS, _Backend.FLASH_ATTN])
def test_multi_step_correct_kvcache(num_steps, attn_backend):
    """Verify that the KV cache of the draft model is correctly updated
    for sequences with a bonus token.

    Runs the same generation twice — once through the multi-step draft
    worker, once step-by-step through a vanilla worker — and asserts the
    resulting per-layer GPU KV caches match within a loose tolerance.
    """
    seed = 100
    model_name = "JackFram/llama-68m"

    block_size = 16
    num_gpu_blocks = 2048 // block_size
    batch_size = 1

    with global_force_attn_backend_context_manager(attn_backend):
        # FLASH_ATTN requires a half-precision KV cache; XFORMERS is kept
        # at float32 for tighter numerics.
        dtype = 'float16' if attn_backend == _Backend.FLASH_ATTN else 'float32'
        multi_step_worker = create_worker(MultiStepWorker,
                                          model_name,
                                          block_size,
                                          num_gpu_blocks,
                                          seed,
                                          model_runner_cls=TP1DraftModelRunner,
                                          dtype=dtype)
        multi_step_worker.set_include_gpu_probs_tensor()
        worker = create_worker(Worker,
                               model_name,
                               block_size,
                               num_gpu_blocks,
                               seed,
                               dtype=dtype)

        prompts = [[0] for _ in range(batch_size)]
        # Already generate two tokens for the sequence
        # so that we can simulate the bonus token case
        multi_step_continuations = [[
            random.randint(0, 1000),
            random.randint(0, 1000)
        ] for _ in prompts]
        final_prompt_lens = [len(prompt) + 2 + num_steps for prompt in prompts]

        seq_ids_with_bonus_token_in_last_step = set(range(batch_size))
        seq_group_metadata_list = create_seq_group_metadata_from_prompts(
            prompts,
            num_gpu_blocks,
            block_size,
            continuations=multi_step_continuations,
            final_prompt_lens=final_prompt_lens)

        # Run multi-step.
        zero_kv_cache(multi_step_worker.cache_engine)
        multi_step_worker.sampler_output(execute_model_req=ExecuteModelRequest(
            seq_group_metadata_list=seq_group_metadata_list),
                                         sample_len=num_steps,
                                         seq_ids_with_bonus_token_in_last_step=
                                         seq_ids_with_bonus_token_in_last_step)

        # Run single-step repeatedly.
        zero_kv_cache(worker.cache_engine)
        # Generate the kv cache for the bonus token first. The model output
        # of this step is intentionally discarded: only the KV-cache side
        # effect matters here.
        single_step_continuations = [c[:1] for c in multi_step_continuations]
        seq_group_metadata_list = create_seq_group_metadata_from_prompts(
            prompts,
            num_gpu_blocks,
            block_size,
            continuations=single_step_continuations,
            final_prompt_lens=final_prompt_lens)
        worker.execute_model(execute_model_req=ExecuteModelRequest(
            seq_group_metadata_list=seq_group_metadata_list))
        for _ in range(num_steps):
            seq_group_metadata_list = create_seq_group_metadata_from_prompts(
                prompts,
                num_gpu_blocks,
                block_size,
                continuations=multi_step_continuations,
                final_prompt_lens=final_prompt_lens)

            single_step_output = worker.execute_model(
                execute_model_req=ExecuteModelRequest(
                    seq_group_metadata_list=seq_group_metadata_list))

            # Feed each sampled token back in as the next continuation so the
            # single-step worker replays the same sequence the multi-step
            # worker generated.
            for i, seq_group_output in enumerate(single_step_output[-1]):
                multi_step_continuations[i].append(
                    seq_group_output.samples[0].output_token)

        # Verify that the KV cache of the single-step and
        # multi-step workers are the same.
        single_step_gpu_cache = worker.cache_engine[0].gpu_cache
        multi_step_gpu_cache = multi_step_worker.cache_engine[0].gpu_cache
        num_layers = len(single_step_gpu_cache)

        # PEP 8 (E731): a named def instead of a lambda assignment. Loose
        # tolerances absorb the small numeric divergence between the two
        # execution paths (especially in float16).
        def allclose(a: torch.Tensor, b: torch.Tensor) -> bool:
            return torch.allclose(a.cuda(), b.cuda(), rtol=1e-2, atol=1e-2)

        for i in range(num_layers):
            # Index 0 / 1 are the key and value caches for layer i.
            assert allclose(single_step_gpu_cache[i][0],
                            multi_step_gpu_cache[i][0])
            assert allclose(single_step_gpu_cache[i][1],
                            multi_step_gpu_cache[i][1])
| 586 | + |
480 | 587 | @torch.inference_mode() |
481 | 588 | def test_draft_proposals_full_speculation_len(): |
482 | 589 | """Verify Top1Proposer correctly handles case where all sequences |
|
0 commit comments