Skip to content

[Core] Use tuple for kv cache group block ids #19175

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jun 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 22 additions & 22 deletions tests/v1/core/test_prefix_caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def test_prefill(hash_algo):
blocks = manager.allocate_slots(req0, 55,
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
assert blocks.get_block_ids() == [[1, 2, 3, 4]]
assert blocks.get_block_ids() == ([1, 2, 3, 4], )

# Check full block metadata
parent_block_hash = None
Expand All @@ -141,13 +141,13 @@ def test_prefill(hash_algo):
req1 = make_request("1", common_token_ids + unique_token_ids)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
assert len(manager.req_to_block_hashes[req1.request_id]) == 3
assert computed_blocks.get_block_ids() == [[1, 2, 3]]
assert computed_blocks.get_block_ids() == ([1, 2, 3], )
assert num_computed_tokens == 3 * 16
num_new_tokens = 53 - 3 * 16
blocks = manager.allocate_slots(req1, num_new_tokens,
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
assert blocks.get_block_ids() == [[5]]
assert blocks.get_block_ids() == ([5], )
for block in computed_blocks.blocks[0]:
assert block.ref_cnt == 2

Expand Down Expand Up @@ -175,13 +175,13 @@ def test_prefill(hash_algo):
req2 = make_request("2", common_token_ids + unique_token_ids)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
assert len(manager.req_to_block_hashes[req2.request_id]) == 3
assert computed_blocks.get_block_ids() == [[1, 2, 3]]
assert computed_blocks.get_block_ids() == ([1, 2, 3], )
assert num_computed_tokens == 3 * 16
num_new_tokens = 53 - 3 * 16
blocks = manager.allocate_slots(req2, num_new_tokens,
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
assert blocks.get_block_ids() == [[6]]
assert blocks.get_block_ids() == ([6], )

# Although we only have 6 free blocks, we have 8 blocks in
# the free block queue due to lazy removal.
Expand All @@ -205,7 +205,7 @@ def test_prefill(hash_algo):
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
# This block ID order also checks the eviction order.
assert blocks.get_block_ids() == [[7, 8, 9, 10, 4, 5, 6, 3, 2, 1]]
assert blocks.get_block_ids() == ([7, 8, 9, 10, 4, 5, 6, 3, 2, 1], )
assert manager.block_pool.free_block_queue.num_free_blocks == 0
assert manager.block_pool.free_block_queue.free_list_head is None
assert manager.block_pool.free_block_queue.free_list_tail is None
Expand Down Expand Up @@ -236,8 +236,8 @@ def test_prefill_hybrid_model():
blocks = manager.allocate_slots(req0, 55,
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
assert blocks.get_block_ids() == [[1, 2, 3, 4], [5, 6, 7, 8],
[9, 10, 11, 12]]
assert blocks.get_block_ids() == ([1, 2, 3, 4], [5, 6, 7,
8], [9, 10, 11, 12])

# Check full block metadata
parent_block_hash = None
Expand All @@ -263,14 +263,14 @@ def test_prefill_hybrid_model():
req1 = make_request("1", common_token_ids + unique_token_ids)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
assert len(manager.req_to_block_hashes[req1.request_id]) == 3
assert computed_blocks.get_block_ids() == [[1, 2, 3], [0, 6, 7],
[0, 10, 11]]
assert computed_blocks.get_block_ids() == ([1, 2, 3], [0, 6,
7], [0, 10, 11])
assert num_computed_tokens == 3 * 16
num_new_tokens = 53 - 3 * 16
blocks = manager.allocate_slots(req1, num_new_tokens,
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
assert blocks.get_block_ids() == [[13], [14], [15]]
assert blocks.get_block_ids() == ([13], [14], [15])
for block_per_group in computed_blocks.blocks:
for block in block_per_group:
if block != manager.block_pool.null_block:
Expand Down Expand Up @@ -374,7 +374,7 @@ def test_prefill_plp():
blocks = manager.allocate_slots(req0, 55,
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
assert blocks.get_block_ids() == [[1, 2, 3, 4]]
assert blocks.get_block_ids() == ([1, 2, 3, 4], )
req0_block_hashes = [b.block_hash for b in blocks.blocks[0]]

# Check full block metadata
Expand All @@ -400,13 +400,13 @@ def test_prefill_plp():
req1 = make_request("1", common_token_ids + unique_token_ids)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
assert len(manager.req_to_block_hashes[req1.request_id]) == 3
assert computed_blocks.get_block_ids() == [[1, 2, 3]]
assert computed_blocks.get_block_ids() == ([1, 2, 3], )
assert num_computed_tokens == 3 * 16
num_new_tokens = 53 - 3 * 16
blocks = manager.allocate_slots(req1, num_new_tokens,
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
assert blocks.get_block_ids() == [[5]]
assert blocks.get_block_ids() == ([5], )
for block in computed_blocks.blocks[0]:
assert block.ref_cnt == 2

Expand Down Expand Up @@ -444,7 +444,7 @@ def test_prefill_plp():
block_ids = blocks.get_block_ids()
# Duplicate cached blocks have different ids but same hashes vs request #0
assert [b.block_hash for b in blocks.blocks[0]] == req0_block_hashes
assert block_ids != [[1, 2, 3, 4]]
assert block_ids != ([1, 2, 3, 4], )

# Request #2 block hashes are valid since request #0 hashes are.
# Check block reference counts.
Expand Down Expand Up @@ -474,7 +474,7 @@ def test_decode():
blocks = manager.allocate_slots(req0, 55,
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
assert blocks.get_block_ids() == [[1, 2, 3, 4]]
assert blocks.get_block_ids() == ([1, 2, 3, 4], )

# Append slots without allocating a new block.
req0.num_computed_tokens = 55
Expand Down Expand Up @@ -546,12 +546,12 @@ def test_evict():
# Touch the first 2 blocks.
req2 = make_request("2", list(range(2 * 16 + 3)))
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
assert computed_blocks.get_block_ids() == [[1, 2]]
assert computed_blocks.get_block_ids() == ([1, 2], )
assert num_computed_tokens == 2 * 16
blocks = manager.allocate_slots(req2, 3,
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
assert blocks.get_block_ids() == [[10]]
assert blocks.get_block_ids() == ([10], )
assert manager.block_pool.free_block_queue.num_free_blocks == 7


Expand Down Expand Up @@ -865,7 +865,7 @@ def test_mm_prefix_caching():
blocks = manager.allocate_slots(req0, 59,
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
assert blocks.get_block_ids() == [[1, 2, 3, 4]]
assert blocks.get_block_ids() == ([1, 2, 3, 4], )
req0.num_computed_tokens = 59

# Append slots without allocating a new block.
Expand Down Expand Up @@ -926,7 +926,7 @@ def test_cache_key_salting():
blocks = manager.allocate_slots(req0, 59,
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
assert blocks.get_block_ids() == [[1, 2, 3, 4]]
assert blocks.get_block_ids() == ([1, 2, 3, 4], )
req0.num_computed_tokens = 59

# Append slots without allocating a new block.
Expand Down Expand Up @@ -1042,7 +1042,7 @@ def test_reset_prefix_cache():
all_token_ids = full_block_token_ids + unique_token_ids
req0 = make_request("0", all_token_ids)
blocks = manager.allocate_slots(req0, 55)
assert blocks.get_block_ids() == [[1, 2, 3, 4]]
assert blocks.get_block_ids() == ([1, 2, 3, 4], )

unique_token_ids = [4] * 7
all_token_ids = full_block_token_ids + unique_token_ids
Expand All @@ -1053,7 +1053,7 @@ def test_reset_prefix_cache():
blocks = manager.allocate_slots(req1, 7,
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
assert blocks.get_block_ids() == [[5]]
assert blocks.get_block_ids() == ([5], )

# Failed to reset prefix cache because some blocks are not freed yet.
assert not manager.reset_prefix_cache()
Expand Down
8 changes: 4 additions & 4 deletions tests/v1/tpu/worker/test_tpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
mm_hashes=[],
mm_positions=[],
sampling_params=SamplingParams(),
block_ids=[[0]], # block_ids should be list[list[int]]
block_ids=([0], ), # block_ids should be tuple[list[int]]
num_computed_tokens=0,
lora_request=None,
))
Expand Down Expand Up @@ -132,10 +132,10 @@ def _is_req_state_block_table_match(model_runner, req_id: str) -> bool:
# This is safe since we currently only use single KV cache groups
block_table = multi_group_block_table[0]

# req_state.block_ids is now list[list[int]] for MultiGroupBlockTable
# req_state.block_ids is now tuple[list[int], ...] for MultiGroupBlockTable
# Extract the first group's block IDs
if isinstance(req_state.block_ids[0], list):
# New format: list[list[int]] - extract first group
# New format: tuple[list[int], ...] - extract first group
req_block_ids = req_state.block_ids[0]
else:
# Legacy format: list[int] - use directly
Expand Down Expand Up @@ -226,7 +226,7 @@ def test_update_states_request_resumed(model_runner):
req_id=req_id,
resumed_from_preemption=False,
new_token_ids=[],
new_block_ids=[[]],
new_block_ids=([], ),
num_computed_tokens=0,
)

Expand Down
2 changes: 1 addition & 1 deletion tests/v1/worker/test_gpu_input_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ def _construct_cached_request_state(req_id_suffix: int):
sampling_params=_create_sampling_params(),
mm_inputs=[],
mm_positions=[],
block_ids=[[]],
block_ids=([], ),
generator=None,
num_computed_tokens=len(output_token_ids),
output_token_ids=output_token_ids,
Expand Down
4 changes: 2 additions & 2 deletions tests/v1/worker/test_gpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
mm_hashes=[],
mm_positions=[],
sampling_params=SamplingParams(),
block_ids=[[0]],
block_ids=([0], ),
num_computed_tokens=0,
lora_request=None,
))
Expand Down Expand Up @@ -250,7 +250,7 @@ def test_update_states_request_resumed(model_runner):
req_id=req_id,
resumed_from_preemption=False,
new_token_ids=[],
new_block_ids=[[]],
new_block_ids=([], ),
num_computed_tokens=0,
)

Expand Down
8 changes: 4 additions & 4 deletions vllm/v1/core/block_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,8 @@ def get_cached_block(
BlockHashWithGroupId(block_hash, group_id))
if not cached_blocks_one_group:
return None
first_block_id = next(iter(cached_blocks_one_group))
cached_blocks.append(cached_blocks_one_group[first_block_id])
first_block = next(iter(cached_blocks_one_group.values()))
cached_blocks.append(first_block)
return cached_blocks

def cache_full_blocks(
Expand Down Expand Up @@ -260,7 +260,7 @@ def _maybe_evict_cached_block(self, block: KVCacheBlock) -> bool:
return True
return False

def touch(self, blocks: list[list[KVCacheBlock]]) -> None:
def touch(self, blocks: tuple[list[KVCacheBlock], ...]) -> None:
"""Touch a block increases its reference count by 1, and may remove
the block from the free queue. This is used when a block is hit by
another request with the same prefix.
Expand Down Expand Up @@ -299,7 +299,7 @@ def reset_prefix_cache(self) -> bool:
bool: True if the prefix cache is successfully reset,
False otherwise.
"""
num_used_blocks = (self.num_gpu_blocks - self.get_num_free_blocks())
num_used_blocks = self.num_gpu_blocks - self.get_num_free_blocks()
if num_used_blocks != 1: # The null block is always marked as used
logger.warning(
"Failed to reset prefix cache because some "
Expand Down
Loading