
[v1][KVCacheManager] Avoid full cache hit by controlling max_length #17999


Merged · 4 commits · May 13, 2025
7 changes: 5 additions & 2 deletions tests/v1/core/test_specialized_manager.py
@@ -17,8 +17,9 @@ def get_sliding_window_manager(sliding_window_spec, block_pool):
 
 
 def test_sliding_window_possible_cached_prefix():
+    block_size = 2
     sliding_window_spec = SlidingWindowSpec(
-        block_size=2,
+        block_size=block_size,
         num_kv_heads=1,
         head_size=1,
         dtype=torch.float32,
@@ -44,7 +45,9 @@ def run_one_case(block_is_cached, expect_length):
                 i: block_pool.blocks[i + 10]
             }
 
-        computed_blocks = manager.find_longest_cache_hit(block_hash_list)
+        computed_blocks = manager.find_longest_cache_hit(
+            block_hash_list,
+            len(block_hash_list) * block_size)
         assert len(computed_blocks) == expect_length
 
         assert all(block == block_pool.null_block
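For orientation, a minimal, self-contained sketch (illustrative values, not code from the PR) of the length cap the updated test passes: with max_length equal to len(block_hash_list) * block_size, every block in the list stays eligible for a cache hit.

# Illustrative only: stand-in hashes, not the test fixtures above.
block_size = 2
block_hash_list = ["h0", "h1", "h2", "h3"]

max_length = len(block_hash_list) * block_size  # 8 tokens
max_num_blocks = max_length // block_size       # 4 blocks may be matched

assert max_num_blocks == len(block_hash_list)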
31 changes: 10 additions & 21 deletions vllm/v1/core/kv_cache_manager.py
@@ -146,21 +146,16 @@ def get_computed_blocks(self,
             assert self.prefix_cache_stats is not None
             self.prefix_cache_stats.requests += 1
 
-        if len(block_hashes) * self.block_size == request.num_tokens:
-            # When prompt length is divisible by the block size and all
-            # blocks are cached, we need to recompute the last token. This
-            # have to be achieved by re-computing an entire block because
-            # allocate_slots() assumes num_computed_tokens is always a
-            # multiple of the block size. To achieve this, remove the last
-            # block hash from the block_hashes for find_longest_cache_hit
-            # This limitation can potentially be removed in the future to
-            # slightly improve the performance.
-            last_block_hash = block_hashes.pop()
-        else:
-            last_block_hash = None
-
-        computed_blocks = (
-            self.single_type_manager.find_longest_cache_hit(block_hashes))
+        # NOTE: When all tokens hit the cache, we must recompute the last token
+        # to obtain logits. Thus, set max_cache_hit_length to prompt_length - 1.
+        # This can trigger recomputation of an entire block, rather than just
+        # the single last token, because allocate_slots() requires
+        # num_computed_tokens to be block-size aligned. Removing this limitation
+        # could slightly improve performance in the future.
+        max_cache_hit_length = request.num_tokens - 1
+
+        computed_blocks = self.single_type_manager.find_longest_cache_hit(
+            block_hashes, max_cache_hit_length)
         # NOTE(woosuk): Since incomplete blocks are not eligible for
         # sharing, `num_computed_tokens` is always a multiple of
         # `block_size`.
@@ -171,12 +166,6 @@
             self.prefix_cache_stats.queries += request.num_tokens
             self.prefix_cache_stats.hits += num_computed_tokens
 
-        if last_block_hash is not None:
-            # Add back the last block hash if it was removed.
-            # NOTE: Because block_hashes is cached in req_to_block_hashes,
-            # we shouldn't modify it directly.
-            block_hashes.append(last_block_hash)
-
         return KVCacheBlocks(computed_blocks), num_computed_tokens
 
     def allocate_slots(
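To see how this change avoids a full cache hit, consider a small worked example (assumed numbers, not from the PR): a fully cached 32-token prompt with block size 16.

# Assumed values for illustration; not vLLM code.
block_size = 16
num_tokens = 32                                      # fully cached prompt

max_cache_hit_length = num_tokens - 1                # 31
max_num_blocks = max_cache_hit_length // block_size  # 31 // 16 == 1

# Only the first block counts as a hit, so the second block, including
# the final prompt token, is recomputed and its logits become available.
assert max_num_blocks == 1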
37 changes: 21 additions & 16 deletions vllm/v1/core/single_type_kv_cache_manager.py
@@ -187,17 +187,19 @@ def get_num_common_prefix_blocks(self, request_id: str,
         raise NotImplementedError
 
     @abstractmethod
-    def find_longest_cache_hit(
-            self, block_hashes: list[BlockHashType]) -> list[KVCacheBlock]:
+    def find_longest_cache_hit(self, block_hashes: list[BlockHashType],
+                               max_length: int) -> list[KVCacheBlock]:
         """
-        Get the longest cache hit prefix of the blocks. If no cache hit is
-        found, return an empty list. if eagle is enabled, drop the last matched
-        block to force recompute the last block to get the required hidden
-        states for eagle drafting head. Need to be customized for each attention
-        type.
+        Get the longest cache hit prefix of the blocks that is not longer than
+        `max_length`. If no cache hit is found, return an empty list.
+        If eagle is enabled, drop the last matched block to force recompute the
+        last block to get the required hidden states for eagle drafting head.
+        Need to be customized for each attention type.
 
         Args:
             block_hashes: The block hashes of the request.
+            max_length: The maximum length of the cache hit prefix.
 
         Returns:
             A list of cached blocks with skipped blocks replaced by null block.
             For example, sliding window manager should return a list like
@@ -226,10 +228,12 @@ def remove_skipped_blocks(self, request_id: str,
 
 class FullAttentionManager(SingleTypeKVCacheManager):
 
-    def find_longest_cache_hit(
-            self, block_hashes: list[BlockHashType]) -> list[KVCacheBlock]:
+    def find_longest_cache_hit(self, block_hashes: list[BlockHashType],
+                               max_length: int) -> list[KVCacheBlock]:
         computed_blocks: list[KVCacheBlock] = []
-        for block_hash in block_hashes:
+        max_num_blocks = max_length // self.block_size
+        for i in range(max_num_blocks):
+            block_hash = block_hashes[i]
             # block_hashes is a chain of block hashes. If a block hash is not
             # in the cached_block_hash_to_id, the following block hashes are
             # not computed yet for sure.
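The full-attention variant is a plain left-to-right prefix walk capped at max_num_blocks. A standalone sketch of the idea (simplified stand-in types; the real manager returns KVCacheBlock objects from the block pool):

# Simplified sketch; `cached` maps block hash -> block id (an assumption,
# standing in for the pool's cached_block_hash_to_block lookup).
def longest_prefix_hit(block_hashes: list[str], cached: dict[str, int],
                       max_length: int, block_size: int) -> list[int]:
    max_num_blocks = max_length // block_size
    hits: list[int] = []
    for block_hash in block_hashes[:max_num_blocks]:
        if block_hash not in cached:
            break  # hashes are chained: one miss implies all later blocks miss
        hits.append(cached[block_hash])
    return hits

# max_length=6 with block_size=2 caps the scan at 3 blocks; the run of
# hits stops at the first miss ("x"), so only two blocks are returned.
print(longest_prefix_hit(["a", "b", "x", "d"], {"a": 0, "b": 1, "d": 3},
                         max_length=6, block_size=2))  # [0, 1]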
@@ -276,19 +280,20 @@ def __init__(self, kv_cache_spec: SlidingWindowSpec, block_pool: BlockPool,
             self.sliding_window_contiguous_blocks += 1
         self._null_block = block_pool.null_block
 
-    def find_longest_cache_hit(
-            self, block_hashes: list[BlockHashType]) -> list[KVCacheBlock]:
+    def find_longest_cache_hit(self, block_hashes: list[BlockHashType],
+                               max_length: int) -> list[KVCacheBlock]:
         # TODO: reduce i by sliding_window_contiguous_blocks when cache miss, to
-        # optimize the time complexity from O(len(block_hashes)) to
-        # O(len(block_hashes) / sliding_window_contiguous_blocks +
+        # optimize the time complexity from O(max_num_blocks) to
+        # O(max_num_blocks / sliding_window_contiguous_blocks +
         # sliding_window_contiguous_blocks),
         # which is good for low cache hit rate scenarios.
-        computed_blocks = [self._null_block] * len(block_hashes)
+        max_num_blocks = max_length // self.block_size
+        computed_blocks = [self._null_block] * max_num_blocks
         num_contiguous_blocks = 0
 
         match_found = False
         # Search from right to left and early stop when a match is found.
-        for i in range(len(block_hashes) - 1, -1, -1):
+        for i in range(max_num_blocks - 1, -1, -1):
             if cached_block := self.block_pool.get_cached_block(
                     block_hashes[i]):
                 computed_blocks[i] = cached_block
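To make the sliding-window search concrete, a hedged, standalone sketch (stand-in types, and it omits the special handling of the sequence head that the real manager performs): scan right to left, keep a run counter, and stop as soon as enough contiguous cached blocks cover the window; everything to their left is reported as the null block.

NULL = -1  # stand-in for block_pool.null_block

def sliding_window_hit(block_hashes: list[str], cached: dict[str, int],
                       max_length: int, block_size: int,
                       contiguous_needed: int) -> list[int]:
    max_num_blocks = max_length // block_size
    computed = [NULL] * max_num_blocks
    num_contiguous = 0
    for i in range(max_num_blocks - 1, -1, -1):
        if block_hashes[i] in cached:
            computed[i] = cached[block_hashes[i]]
            num_contiguous += 1
            if num_contiguous >= contiguous_needed:
                # Early stop: this contiguous run covers the window. The hit
                # ends at the run's right edge; drop anything beyond it.
                del computed[i + num_contiguous:]
                return computed
        else:
            num_contiguous = 0
    return []  # simplification: the real code also checks the sequence head

# The last two blocks are cached and contiguous, so under a 2-block
# sliding window the first two blocks can be skipped (returned as NULL).
print(sliding_window_hit(["a", "b", "c", "d"], {"c": 2, "d": 3},
                         max_length=8, block_size=2, contiguous_needed=2))
# -> [-1, -1, 2, 3]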