@@ -477,3 +477,70 @@ def test_auto_prefix_caching_with_preemption(baseline_llm_generator,
477
477
assert expected_token_ids == actual_token_ids
478
478
479
479
assert baseline_token_ids == test_token_ids
480
+
481
+
482
@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        # Use a small model for a fast test.
        "model": "facebook/opt-125m",

        # skip cuda graph creation for fast test.
        "enforce_eager": True,

        # Keep the blocks small so that eviction kicks in quickly.
        "max_model_len": 48,
        "block_size": 16,
        "num_gpu_blocks_override": 3,

        # Test APC in v2 block
        "use_v2_block_manager": True,
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{
    "enable_prefix_caching": False
}])
@pytest.mark.parametrize("test_llm_kwargs", [{
    "enable_prefix_caching": True,
}])
@pytest.mark.parametrize("seed", [1])
def test_auto_prefix_caching_after_evition_start(baseline_llm_generator,
                                                 test_llm_generator):
    """Verify block manager v2 with auto prefix caching works correctly
    even after eviction has started.

    With APC enabled, all blocks are held by the naive allocator at the
    beginning. Once the small cache fills, blocks are managed by the
    evictor instead. If a cache hit lands on a block held by the evictor,
    the block can be reused; otherwise its KV cache must be recomputed.
    The test checks that greedy outputs with APC on match the APC-off
    baseline under these eviction conditions.
    """
    output_len = 10
    # Greedy decoding so baseline and test outputs are deterministic and
    # directly comparable.
    temperature = 0.0

    # Prompts long enough (relative to max_model_len=48 / 3 GPU blocks)
    # to force block eviction during generation.
    prompts = [
        "You are a helpful assistant. Please answer truthfully and write "
        "out your thinking step by step to be sure you get the right answer. "
        "If you make a mistake, attempt to correct it. who are you?",
        "You are a helpful assistant. Please answer truthfully and write out "
        "your thinking step by step to be sure you get the right answer. You "
        "are helpful and harmless and you follow ethical guidelines. "
        "who are you?"
    ]

    sampling_params = SamplingParams(
        max_tokens=output_len,
        ignore_eos=True,
        temperature=temperature,
    )

    print('Getting token ids with APC disabled')
    baseline_token_ids = get_token_ids_from_llm_generator(
        baseline_llm_generator, prompts, sampling_params)

    print('Getting token ids with APC enabled')
    test_token_ids = get_token_ids_from_llm_generator(test_llm_generator,
                                                      prompts, sampling_params)

    # Compare per-prompt first for a precise failure message, then the
    # full lists as a final sanity check.
    for expected_token_ids, actual_token_ids in zip(baseline_token_ids,
                                                    test_token_ids):
        assert expected_token_ids == actual_token_ids

    assert baseline_token_ids == test_token_ids
0 commit comments