Skip to content

[v1][KVCacheManager] Add a special KVCacheNullBlock class #18652

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 11 additions & 11 deletions tests/v1/core/test_prefix_caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,14 +92,14 @@ def test_prefill(hash_algo):
block_tokens = tuple(all_token_ids[(block_id - 1) * 16:block_id * 16])
block_hash = hash_block_tokens(hash_fn, parent_block_hash,
block_tokens)
assert manager.block_pool.blocks[block_id].block_hash == block_hash
assert manager.block_pool.blocks[block_id].ref_cnt == 1
assert manager.block_pool.blocks[block_id - 1].block_hash == block_hash
assert manager.block_pool.blocks[block_id - 1].ref_cnt == 1
parent_block_hash = block_hash.hash_value

# Check partial block metadata
for block_id in (4, ):
assert manager.block_pool.blocks[block_id].block_hash is None
assert manager.block_pool.blocks[block_id].ref_cnt == 1
assert manager.block_pool.blocks[block_id - 1].block_hash is None
assert manager.block_pool.blocks[block_id - 1].ref_cnt == 1

# Cache hit in the common prefix when the original block is still in use.
# Incomplete 1 block (5 tokens)
Expand Down Expand Up @@ -217,14 +217,14 @@ def test_prefill_plp():
block_tokens = tuple(all_token_ids[(block_id - 1) * 16:block_id * 16])
block_hash = hash_block_tokens(hash_fn, parent_block_hash,
block_tokens)
assert manager.block_pool.blocks[block_id].block_hash == block_hash
assert manager.block_pool.blocks[block_id].ref_cnt == 1
assert manager.block_pool.blocks[block_id - 1].block_hash == block_hash
assert manager.block_pool.blocks[block_id - 1].ref_cnt == 1
parent_block_hash = block_hash.hash_value

# Check partial block metadata
for block_id in (4, ):
assert manager.block_pool.blocks[block_id].block_hash is None
assert manager.block_pool.blocks[block_id].ref_cnt == 1
assert manager.block_pool.blocks[block_id - 1].block_hash is None
assert manager.block_pool.blocks[block_id - 1].ref_cnt == 1

# Request #1 is a non-prompt-logprobs request:
# Cache hit in the common prefix when the original block is still in use.
Expand Down Expand Up @@ -282,7 +282,7 @@ def test_prefill_plp():
# Request #2 block hashes are valid since request #0 hashes are.
# Check block reference counts.
for block_id in block_ids[0]:
assert manager.block_pool.blocks[block_id].ref_cnt == 1
assert manager.block_pool.blocks[block_id - 1].ref_cnt == 1

manager.free(req2)

Expand Down Expand Up @@ -425,8 +425,8 @@ def test_hash_block_correct_reuse():
computed_blocks)
assert len(blocks.blocks) == 1

assert manager.block_pool.blocks[
blocks.blocks[0].block_id].block_hash is None
assert manager.block_pool.blocks[blocks.blocks[0].block_id -
1].block_hash is None


def test_computed_blocks_not_evicted():
Expand Down
4 changes: 3 additions & 1 deletion tests/v1/core/test_specialized_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,10 @@ def run_one_case(block_is_cached, expect_length):
for i, (block_hash,
is_cached) in enumerate(zip(block_hash_list, block_is_cached)):
if is_cached:
# blocks[i+9] is with block_id=i+10, as null_block with
# block_id=0 is not included in block_pool.blocks.
block_pool.cached_block_hash_to_block[block_hash] = {
i: block_pool.blocks[i + 10]
i: block_pool.blocks[i + 9]
}

computed_blocks = manager.find_longest_cache_hit(
Expand Down
19 changes: 8 additions & 11 deletions vllm/v1/core/block_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
BlockStored, KVCacheEvent)
from vllm.logger import init_logger
from vllm.v1.core.kv_cache_utils import (BlockHashType, FreeKVCacheBlockQueue,
KVCacheBlock,
KVCacheBlock, KVCacheNullBlock,
generate_block_hash_extra_keys,
hash_block_tokens)
from vllm.v1.request import Request
Expand Down Expand Up @@ -37,9 +37,12 @@ def __init__(
assert isinstance(num_gpu_blocks, int) and num_gpu_blocks > 0
self.num_gpu_blocks = num_gpu_blocks
self.enable_caching = enable_caching
# All kv-cache blocks.

# A special placeholder block with block_id=0 and is never freed.
self.null_block = KVCacheNullBlock(0)
# All other kv-cache blocks.
self.blocks: list[KVCacheBlock] = [
KVCacheBlock(idx) for idx in range(num_gpu_blocks)
KVCacheBlock(idx) for idx in range(1, num_gpu_blocks)
]
# Free block queue that constructs and manipulates a doubly linked
# list of free blocks (including eviction candidates when caching is
Expand All @@ -58,11 +61,6 @@ def __init__(
self.cached_block_hash_to_block: dict[BlockHashType, dict[
int, KVCacheBlock]] = defaultdict(dict)

# To represent a placeholder block with block_id=0.
# The ref_cnt of null_block is not maintained, needs special care to
# avoid freeing it.
self.null_block = self.free_block_queue.popleft()

self.enable_kv_cache_events = enable_kv_cache_events
self.kv_event_queue: list[KVCacheEvent] = []

Expand Down Expand Up @@ -251,7 +249,7 @@ def touch(self, blocks: list[KVCacheBlock]) -> None:
for block in blocks:
# ref_cnt=0 means this block is in the free list (i.e. eviction
# candidate), so remove it.
if block.ref_cnt == 0 and block != self.null_block:
if block.ref_cnt == 0:
self.free_block_queue.remove(block)
block.incr_ref()

Expand All @@ -265,8 +263,7 @@ def free_blocks(self, ordered_blocks: Iterable[KVCacheBlock]) -> None:
"""
for block in ordered_blocks:
block.decr_ref()
# null_block should not be added to the free list.
if block.ref_cnt == 0 and block != self.null_block:
if block.ref_cnt == 0:
self.free_block_queue.append(block)

def reset_prefix_cache(self) -> bool:
Expand Down
22 changes: 22 additions & 0 deletions vllm/v1/core/kv_cache_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,28 @@ def __repr__(self) -> str:
f"next_free_block={next_block_id})")


class KVCacheNullBlock(KVCacheBlock):
    """A special placeholder KVCacheBlock whose ref_cnt is fixed at 1 and
    therefore can never be freed or placed on the free-block queue.

    It is used as a sentinel where no real block is needed (the block pool
    constructs it with block_id=0): incr_ref/decr_ref are no-ops and
    block_hash is permanently None, so it can never be cached or evicted.
    """

    def __init__(self, block_id: int):
        # Start with ref_cnt=1 so the block never looks like an eviction
        # candidate to code that checks `ref_cnt == 0`.
        super().__init__(block_id, ref_cnt=1)

    def incr_ref(self):
        # No-op: the null block's ref_cnt is pinned at 1.
        pass

    def decr_ref(self):
        # No-op: keeps ref_cnt from reaching 0, which would otherwise let
        # callers (e.g. BlockPool.free_blocks) append it to the free list.
        pass

    @property
    def block_hash(self) -> Optional[BlockHashType]:
        # The null block holds no cached tokens, so it never has a hash and
        # can never produce a prefix-cache hit.
        return None

    @block_hash.setter
    def block_hash(self, block_hash: BlockHashType):
        # Attempting to cache the null block is a logic error; fail loudly
        # rather than silently recording a bogus hash.
        raise ValueError("Should not set block_hash for null block.")


class FreeKVCacheBlockQueue:
"""This class organizes a list of KVCacheBlock objects to a doubly linked
list of free blocks. We implement this class instead of using Python
Expand Down
3 changes: 2 additions & 1 deletion vllm/v1/core/single_type_kv_cache_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,8 @@ def find_longest_cache_hit(self, block_hashes: list[BlockHashType],
# sliding_window_contiguous_blocks),
# which is good for low cache hit rate scenarios.
max_num_blocks = max_length // self.block_size
computed_blocks = [self._null_block] * max_num_blocks
computed_blocks: list[KVCacheBlock] = [self._null_block
] * max_num_blocks
num_contiguous_blocks = 0

match_found = False
Expand Down