11# SPDX-License-Identifier: Apache-2.0
22# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
33import importlib
4- from typing import Optional
4+ from typing import Callable , Optional
55
66import pytest
77import torch
1919 FreeKVCacheBlockQueue , KVCacheBlock , PrefixCachingMetrics ,
2020 estimate_max_model_len , generate_block_hash_extra_keys ,
2121 get_kv_cache_config , get_max_concurrency_for_kv_cache_config ,
22- hash_block_tokens , hash_request_tokens , init_none_hash ,
22+ get_request_block_hasher , hash_block_tokens , init_none_hash ,
2323 is_kv_cache_type_uniform , unify_kv_cache_configs )
2424from vllm .v1 .kv_cache_interface import (FullAttentionSpec , KVCacheConfig ,
2525 KVCacheGroupSpec , KVCacheTensor ,
3333def make_request (
3434 request_id : str ,
3535 prompt_token_ids : list [int ],
36+ block_size : int = 3 ,
37+ hash_fn : Callable = hash ,
3638 mm_positions : Optional [list [PlaceholderRange ]] = None ,
3739 mm_hashes : Optional [list [str ]] = None ,
3840 cache_salt : Optional [str ] = None ,
@@ -49,18 +51,17 @@ def make_request(
4951 mm_item = MultiModalKwargsItem .from_elems ([mm_elem ])
5052 mm_kwargs = [mm_item ] * len (mm_positions )
5153
52- return Request (
53- request_id = request_id ,
54- prompt_token_ids = prompt_token_ids ,
55- multi_modal_kwargs = mm_kwargs ,
56- multi_modal_hashes = mm_hashes ,
57- multi_modal_placeholders = mm_positions ,
58- sampling_params = SamplingParams (max_tokens = 17 ),
59- pooling_params = None ,
60- eos_token_id = 100 ,
61- lora_request = None ,
62- cache_salt = cache_salt ,
63- )
54+ return Request (request_id = request_id ,
55+ prompt_token_ids = prompt_token_ids ,
56+ multi_modal_kwargs = mm_kwargs ,
57+ multi_modal_hashes = mm_hashes ,
58+ multi_modal_placeholders = mm_positions ,
59+ sampling_params = SamplingParams (max_tokens = 17 ),
60+ pooling_params = None ,
61+ eos_token_id = 100 ,
62+ lora_request = None ,
63+ cache_salt = cache_salt ,
64+ block_hasher = get_request_block_hasher (block_size , hash_fn ))
6465
6566
6667def new_kv_cache_spec (block_size = 16 ,
@@ -428,22 +429,22 @@ def test_hash_block_tokens(hash_fn):
428429
429430
430431@pytest .mark .parametrize ("hash_fn" , [sha256 , sha256_cbor_64bit , hash ])
431- def test_hash_request_tokens (hash_fn ):
432+ def test_request_block_hasher (hash_fn ):
432433 import vllm .v1 .core .kv_cache_utils
433434 init_none_hash (hash_fn )
434435 request = make_request (
435436 request_id = "0" ,
436437 prompt_token_ids = [_ for _ in range (6 )],
438+ block_size = 3 ,
439+ hash_fn = hash_fn ,
437440 mm_positions = [
438441 PlaceholderRange (offset = 0 , length = 3 ),
439442 PlaceholderRange (offset = 3 , length = 3 ),
440443 ],
441444 mm_hashes = ["hash1" , "hash2" ],
442445 )
443446
444- block_size = 3
445- block_hashes = hash_request_tokens (hash_fn , block_size , request )
446-
447+ block_hashes = request .block_hashes
447448 assert len (block_hashes ) == 2
448449 assert isinstance (block_hashes [0 ], vllm .v1 .core .kv_cache_utils .BlockHash )
449450 assert isinstance (block_hashes [1 ], vllm .v1 .core .kv_cache_utils .BlockHash )
@@ -464,6 +465,8 @@ def test_hash_tokens_different_mm_input(hash_fn):
464465 request1 = make_request (
465466 request_id = "0" ,
466467 prompt_token_ids = [_ for _ in range (6 )],
468+ block_size = 3 ,
469+ hash_fn = hash_fn ,
467470 mm_positions = [
468471 PlaceholderRange (offset = 0 , length = 3 ),
469472 PlaceholderRange (offset = 3 , length = 3 ),
@@ -479,9 +482,8 @@ def test_hash_tokens_different_mm_input(hash_fn):
479482 ],
480483 mm_hashes = ["hash3" , "hash2" ],
481484 )
482- block_size = 3
483- block_hashes1 = hash_request_tokens (hash_fn , block_size , request1 )
484- block_hashes2 = hash_request_tokens (hash_fn , block_size , request2 )
485+ block_hashes1 = request1 .block_hashes
486+ block_hashes2 = request2 .block_hashes
485487 assert block_hashes1 [0 ] != block_hashes2 [0 ]
486488 assert block_hashes1 [1 ] != block_hashes2 [1 ]
487489
@@ -493,12 +495,13 @@ def test_hash_request_tokens_no_mm_inputs(hash_fn):
493495 request = make_request (
494496 request_id = "0" ,
495497 prompt_token_ids = [_ for _ in range (6 )],
498+ block_size = 3 ,
499+ hash_fn = hash_fn ,
496500 mm_positions = None ,
497501 mm_hashes = None ,
498502 )
499503
500- block_size = 3
501- block_hashes = hash_request_tokens (hash_fn , block_size , request )
504+ block_hashes = request .block_hashes
502505
503506 assert len (block_hashes ) == 2
504507 assert block_hashes [0 ].token_ids == (0 , 1 , 2 )
@@ -858,6 +861,7 @@ def test_allocate_with_lookahead():
858861 request = make_request (
859862 request_id = "0" ,
860863 prompt_token_ids = [],
864+ block_size = block_size ,
861865 mm_positions = None ,
862866 mm_hashes = None ,
863867 )
0 commit comments