Skip to content

[Bugfix] get_num_blocks_to_allocate with null_block #19031

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jun 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions tests/v1/core/test_specialized_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,3 +143,26 @@ def assert_block_id(block_table, ids):
# of removed blocks should be [1003, 1002].
manager.remove_skipped_blocks("test", 11)
assert_block_id(block_table, [null_block_id] * 4 + original_block_ids[4:])


def test_get_num_blocks_to_allocate():
    """Check that null blocks are not counted by get_num_blocks_to_allocate.

    Two requests each ask for 20 blocks' worth of tokens and each bring 10
    already-computed blocks:

    * request "1" passes 10 freshly constructed ``KVCacheBlock`` objects and
      expects 20 blocks to allocate;
    * request "2" passes 5 references to the pool's ``null_block`` followed
      by 5 freshly constructed blocks and expects 15, i.e. the null-block
      entries must not add to the count.
    """
    block_size = 2
    num_tokens = 20 * block_size

    spec = SlidingWindowSpec(
        block_size=block_size,
        num_kv_heads=1,
        head_size=1,
        dtype=torch.float32,
        # Placeholder value, not related to test result
        sliding_window=4,
        use_mla=False,
    )

    pool = BlockPool(num_gpu_blocks=100, enable_caching=True)
    manager = get_sliding_window_manager(spec, pool)

    # Ten ordinary computed blocks with ids 1..10.
    all_real_blocks = [KVCacheBlock(block_id) for block_id in range(1, 11)]

    # Five references to the shared null block, then five ordinary blocks
    # with ids 1..5.
    null_prefixed_blocks = [pool.null_block] * 5
    null_prefixed_blocks += [KVCacheBlock(block_id) for block_id in range(1, 6)]

    needed_for_real = manager.get_num_blocks_to_allocate(
        "1", num_tokens, all_real_blocks)
    assert needed_for_real == 20

    needed_with_nulls = manager.get_num_blocks_to_allocate(
        "2", num_tokens, null_prefixed_blocks)
    assert needed_with_nulls == 15
5 changes: 3 additions & 2 deletions vllm/v1/core/block_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ def __init__(
# The ref_cnt of null_block is not maintained, needs special care to
# avoid freeing it.
self.null_block = self.free_block_queue.popleft()
self.null_block.is_null = True

self.enable_kv_cache_events = enable_kv_cache_events
self.kv_event_queue: list[KVCacheEvent] = []
Expand Down Expand Up @@ -251,7 +252,7 @@ def touch(self, blocks: list[KVCacheBlock]) -> None:
for block in blocks:
# ref_cnt=0 means this block is in the free list (i.e. eviction
# candidate), so remove it.
if block.ref_cnt == 0 and block != self.null_block:
if block.ref_cnt == 0 and not block.is_null:
self.free_block_queue.remove(block)
block.incr_ref()

Expand All @@ -266,7 +267,7 @@ def free_blocks(self, ordered_blocks: Iterable[KVCacheBlock]) -> None:
for block in ordered_blocks:
block.decr_ref()
# null_block should not be added to the free list.
if block.ref_cnt == 0 and block != self.null_block:
if block.ref_cnt == 0 and not block.is_null:
self.free_block_queue.append(block)

def reset_prefix_cache(self) -> bool:
Expand Down
3 changes: 3 additions & 0 deletions vllm/v1/core/kv_cache_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,9 @@ class KVCacheBlock:
prev_free_block: Optional["KVCacheBlock"] = None
next_free_block: Optional["KVCacheBlock"] = None

# Whether the block is a null block that should never be cached.
is_null: bool = False

def incr_ref(self):
self.ref_cnt += 1

Expand Down
5 changes: 3 additions & 2 deletions vllm/v1/core/single_type_kv_cache_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,9 @@ def get_num_blocks_to_allocate(
# free queue and ref_cnt == 0), it will be changed from a free block
# to a computed block when the request is allocated, so we also count
# it as needed to be allocated.
num_evictable_computed_blocks = sum(blk.ref_cnt == 0
for blk in new_computed_blocks)
num_evictable_computed_blocks = sum(
blk.ref_cnt == 0 and not blk.is_null
for blk in new_computed_blocks)
return ((num_new_blocks + num_evictable_computed_blocks) *
self.num_kv_cache_groups)

Expand Down