[TPU] Fix KV cache sharing tests #19371

Merged (8 commits, Jun 9, 2025)
Changes from all commits
112 changes: 52 additions & 60 deletions tests/v1/tpu/worker/test_tpu_model_runner.py
@@ -1,6 +1,5 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import unittest.mock as mock

import pytest

@@ -17,24 +16,8 @@
TPUModelRunner, _get_padded_num_reqs_with_upper_limit,
_get_padded_token_len, _get_req_paddings, _get_token_paddings)

# Mock torch_xla module since it may not be available in the test environments
torch_xla_patcher = mock.patch.dict(
"sys.modules", {
"torch_xla": mock.MagicMock(),
"torch_xla.core.xla_model": mock.MagicMock(),
"torch_xla.runtime": mock.MagicMock(),
})
torch_xla_patcher.start()

# Mock the PallasAttentionBackend
pallas_attention_backend_patcher = mock.patch(
"vllm.v1.worker.tpu_model_runner.PallasAttentionBackend", )
pallas_attention_backend_patcher.start()


@pytest.fixture
def model_runner():
# Patchers have already been started at module level.
def get_vllm_config():
scheduler_config = SchedulerConfig(
max_num_seqs=10,
max_num_batched_tokens=512,
@@ -60,18 +43,19 @@ def model_runner():
cache_config=cache_config,
scheduler_config=scheduler_config,
)
return vllm_config


def get_model_runner(vllm_config):
device = "xla:0" # Mocking TPU device
with mock.patch("vllm.v1.worker.tpu_model_runner.torch"), \
mock.patch("vllm.v1.worker.tpu_model_runner.xm"), \
mock.patch("vllm.v1.worker.tpu_model_runner.xr"):
return TPUModelRunner(vllm_config, device)
return TPUModelRunner(vllm_config, device)


@pytest.fixture(autouse=True, scope="session")
def cleanup_patches():
yield
torch_xla_patcher.stop()
pallas_attention_backend_patcher.stop()
@pytest.fixture
def model_runner():
# Patchers have already been started at module level.
vllm_config = get_vllm_config()
return get_model_runner(vllm_config)


def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
@@ -370,12 +354,14 @@ def test_get_req_paddings():
assert _get_req_paddings(8, 36) == [8, 16, 32, 36]


@pytest.mark.skip(reason="Test is broken on TPU when it's added.")
def test_init_kv_cache_with_kv_sharing_invalid_target_layer_order():
def test_init_kv_cache_with_kv_sharing_invalid_target_layer_order(
model_runner):
layer_0 = "model.layers.0.self_attn.attn"
layer_1 = "model.layers.1.self_attn.attn"
error_msg = f"{layer_1} must come before the current layer"
with pytest.raises(ValueError, match=error_msg):
vllm_config = model_runner.vllm_config
with pytest.raises(ValueError, match=error_msg), \
set_current_vllm_config(vllm_config):
fwd_context = {
# initialization below will fail because target layer is invalid;
# the target layer needs to come before layer 1
@@ -399,13 +385,14 @@ def test_init_kv_cache_with_kv_sharing_invalid_target_layer_order():
assert fwd_context is not None


@pytest.mark.skip(reason="Test is broken on TPU when it's added.")
def test_init_kv_cache_with_kv_sharing_target_layer_not_exist():
def test_init_kv_cache_with_kv_sharing_target_layer_not_exist(model_runner):
layer_0 = "model.layers.0.self_attn.attn"
layer_1 = "model.layers.1.self_attn.attn"
invalid_layer = "model.layers.0.cross_attn.attn"
error_msg = f"{invalid_layer} is not a valid Attention layer in the model"
with pytest.raises(ValueError, match=error_msg):
vllm_config = model_runner.vllm_config
with pytest.raises(ValueError, match=error_msg), \
set_current_vllm_config(vllm_config):
fwd_context = {
layer_0:
Attention(
@@ -428,12 +415,13 @@ def test_init_kv_cache_with_kv_sharing_target_layer_not_exist():
assert fwd_context is not None


@pytest.mark.skip(reason="Test is broken on TPU when it's added.")
def test_init_kv_cache_with_kv_sharing_target_same_as_current():
def test_init_kv_cache_with_kv_sharing_target_same_as_current(model_runner):
layer_0 = "model.layers.0.self_attn.attn"
layer_1 = "model.layers.1.self_attn.attn"
error_msg = f"{layer_1} cannot be the same as the current layer"
with pytest.raises(ValueError, match=error_msg):
vllm_config = model_runner.vllm_config
with pytest.raises(ValueError, match=error_msg), \
set_current_vllm_config(vllm_config):
fwd_context = {
# initialization below will fail because target layer is invalid;
# the target layer needs to come before layer 1
@@ -457,11 +445,10 @@ def test_init_kv_cache_with_kv_sharing_target_same_as_current():
assert fwd_context is not None


@pytest.mark.skip(reason="Test is broken on TPU when it's added.")
def test_init_kv_cache_without_kv_sharing(model_runner):
def test_init_kv_cache_without_kv_sharing():
layer_0 = "model.layers.0.self_attn.attn"
layer_1 = "model.layers.1.self_attn.attn"
vllm_config = model_runner.vllm_config
vllm_config = get_vllm_config()
with set_current_vllm_config(vllm_config):
fwd_context = {
layer_0:
@@ -482,33 +469,38 @@ def test_init_kv_cache_without_kv_sharing(model_runner):
# suppress var not used error
assert fwd_context is not None
# Set high context length to test max context length estimation
vllm_config.model_config.max_model_len = 3_000_000
vllm_config.model_config.max_model_len = 1_000_000
vllm_ctx = vllm_config.compilation_config.static_forward_context
model_runner = get_model_runner(vllm_config)
kv_cache_spec = model_runner.get_kv_cache_spec()
assert len(kv_cache_spec) == 2
assert len(model_runner.shared_kv_cache_layers) == 0

available_memory = 20 * GiB_bytes
# page size for layer 0's kv_cache_spec is 32KB
num_expected_blocks = 327680 # 20GB / 32KB / 2 (num layers)
# page size for each layer KV can be calculated as
# 2 (non-MLA) * 8 (num_heads) * 128 (head_dim)
# * 2 (bfloat16, kv_cache dtype) * 128 (block_size) = 512KB
Review comment (Contributor, severity: medium) on lines +480 to +482: The updated comment explaining the page size calculation is much clearer and more detailed. This is a great improvement for understanding the test's assumptions!
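For reference, a minimal standalone sketch of the arithmetic behind the updated expected values. The head count (8), head dimension (128), bfloat16 width, and block size (128) are taken from the test's own comments; the GiB/KiB names are just local constants, and nothing below imports or calls vLLM.

    GiB = 1024 ** 3
    KiB = 1024

    # Per-layer KV page size, as described in the updated test comment:
    # 2 (K and V, non-MLA) * 8 heads * 128 head_dim * 2 bytes (bfloat16) * 128 block_size
    page_size = 2 * 8 * 128 * 2 * 128
    assert page_size == 512 * KiB

    # Expected block count: 20 GiB split across 2 layers, one page per block.
    available_memory = 20 * GiB
    num_layers = 2
    assert available_memory // page_size // num_layers == 20480

    # Max context length estimate with a 5 GiB budget:
    # each token costs page_size / block_size bytes per layer.
    bytes_per_token_per_layer = page_size // 128
    assert (5 * GiB) // bytes_per_token_per_layer // num_layers == 655360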

num_expected_blocks = 20480 # 20GB / 512KB / 2 (num layers)
kv_cache_config = get_kv_cache_config(vllm_config, kv_cache_spec,
available_memory)
assert kv_cache_config.num_blocks == num_expected_blocks
assert len(kv_cache_config.tensors) == 2
assert kv_cache_config.tensors[layer_0].size == available_memory // 2
assert kv_cache_config.tensors[layer_1].size == available_memory // 2
assert len(kv_cache_config.kv_cache_tensors) == 2
assert kv_cache_config.kv_cache_tensors[0].size == available_memory // 2
assert kv_cache_config.kv_cache_tensors[1].size == available_memory // 2

max_context_len =\
estimate_max_model_len(vllm_config, kv_cache_spec, 5 * GiB_bytes)
# max context len with KV sharing should be 2x as large as without
assert max_context_len == 1310720
# max_context_len = available_memory / (page_size / block_size) / num_caches
# max_context_len = 5GB / (512KB / 128) / 2 = 655360
assert max_context_len == 655360

# important: override tensor size to prevent large mem alloc during test
# this will only allocate 2 block worth of memory (2 * 32kb)
# this will only allocate 2 block worth of memory (2 * 512kb)
kv_cache_config.num_blocks = 1
for layer in kv_cache_config.tensors:
kv_cache_config.tensors[layer].size =\
kv_cache_spec[layer].page_size_bytes
for kv_cache_tensor in kv_cache_config.kv_cache_tensors:
kv_cache_tensor.size = (
kv_cache_spec[kv_cache_tensor.shared_by[0]].page_size_bytes)
Review comment (Contributor, severity: medium) on lines +501 to +503: The change from iterating kv_cache_config.tensors (which was a dict) to kv_cache_config.kv_cache_tensors (which is a list) and using kv_cache_tensor.shared_by[0] to get the layer name for kv_cache_spec looks correct given the structure of KVCacheConfig and KVCacheTensor. This aligns well with how KV cache tensors can be defined and potentially shared.
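To make the shape of that change concrete, here is a toy illustration. The FakeKVCacheTensor and FakeKVCacheConfig dataclasses and the page_size_bytes dict below are simplified, hypothetical stand-ins for vLLM's KVCacheConfig, KVCacheTensor, and kv_cache_spec, not their real definitions; only the size, shared_by, and kv_cache_tensors fields mirrored from the diff are assumed.

    from dataclasses import dataclass, field

    @dataclass
    class FakeKVCacheTensor:
        size: int
        shared_by: list[str]  # names of the attention layers backed by this tensor

    @dataclass
    class FakeKVCacheConfig:
        kv_cache_tensors: list[FakeKVCacheTensor] = field(default_factory=list)

    # One tensor shared by two layers, as in the KV-sharing test.
    page_size_bytes = {"model.layers.0.self_attn.attn": 512 * 1024}
    config = FakeKVCacheConfig(kv_cache_tensors=[
        FakeKVCacheTensor(size=20 * 1024**3,
                          shared_by=["model.layers.0.self_attn.attn",
                                     "model.layers.1.self_attn.attn"]),
    ])

    # Mirrors the updated test loop: iterate the list and use the first sharing
    # layer's spec to look up the page size when shrinking each tensor.
    for kv_cache_tensor in config.kv_cache_tensors:
        kv_cache_tensor.size = page_size_bytes[kv_cache_tensor.shared_by[0]]

    assert config.kv_cache_tensors[0].size == 512 * 1024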


model_runner.initialize_kv_cache(kv_cache_config)

@@ -524,11 +516,10 @@ def test_init_kv_cache_without_kv_sharing(model_runner):
assert kv_cache_config.kv_cache_groups[0].layer_names[1] == layer_1


@pytest.mark.skip(reason="Test is broken on TPU when it's added.")
def test_init_kv_cache_with_kv_sharing_valid(model_runner):
def test_init_kv_cache_with_kv_sharing_valid():
layer_0 = "model.layers.0.self_attn.attn"
layer_1 = "model.layers.1.self_attn.attn"
vllm_config = model_runner.vllm_config
vllm_config = get_vllm_config()
with set_current_vllm_config(vllm_config):
fwd_context = {
layer_0:
@@ -552,33 +543,34 @@ def test_init_kv_cache_with_kv_sharing_valid(model_runner):
# Set high context length to test max context length estimation
vllm_config.model_config.max_model_len = 3_000_000
vllm_ctx = vllm_config.compilation_config.static_forward_context
model_runner = get_model_runner(vllm_config)
kv_cache_spec = model_runner.get_kv_cache_spec()
assert len(kv_cache_spec) == 1
assert layer_0 in kv_cache_spec
assert model_runner.shared_kv_cache_layers[layer_1] == layer_0

available_memory = 20 * GiB_bytes
# page size for layer 0's kv_cache_spec is 32KB
# page size for layer 0's kv_cache_spec is 512KB
# with KV sharing, we can allocate (available_mem//page_size//1) blocks
# which is twice as many as without KV sharing
num_expected_blocks = 655360 # 20GB / 32KB
num_expected_blocks = 2 * 20480 # 20GB / 512KB
kv_cache_config = get_kv_cache_config(vllm_config, kv_cache_spec,
available_memory)
assert kv_cache_config.num_blocks == num_expected_blocks
assert len(kv_cache_config.tensors) == 1
assert len(kv_cache_config.kv_cache_tensors) == 1
# Each layer now has twice the available memory for KV cache
# compared to no KV sharing
assert kv_cache_config.tensors[layer_0].size == available_memory
assert kv_cache_config.kv_cache_tensors[0].size == available_memory

max_context_len =\
estimate_max_model_len(vllm_config, kv_cache_spec, 5 * GiB_bytes)
# max context len with KV sharing should be 2x as large as without
assert max_context_len == 2 * 1310720
assert max_context_len == (2 * 655360)

# important: override tensor size to prevent large mem alloc during test
# this will only allocate 1 block worth of memory (32kb)
# this will only allocate 1 block worth of memory (512kb)
kv_cache_config.num_blocks = 1
kv_cache_config.tensors[layer_0].size =\
kv_cache_config.kv_cache_tensors[0].size =\
kv_cache_spec[layer_0].page_size_bytes

model_runner.initialize_kv_cache(kv_cache_config)