[V1] Support cross-layer KV sharing #18212

Merged 6 commits on Jun 3, 2025
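
The change lets an Attention layer reuse an earlier layer's KV cache instead of allocating its own, opted into via the new kv_sharing_target_layer_name constructor argument. A minimal sketch of the API, assuming a populated VllmConfig; the layer names and sizes below are illustrative and mirror the tests in this PR:

from vllm.attention.layer import Attention
from vllm.config import set_current_vllm_config

def build_attention_layers(vllm_config):
    # Attention layers created under set_current_vllm_config register
    # themselves in the config's static forward context by prefix.
    with set_current_vllm_config(vllm_config):
        # layer 0 allocates and owns its KV cache
        attn_0 = Attention(num_heads=8, head_size=64, scale=1.0,
                           prefix="model.layers.0.self_attn.attn")
        # layer 1 reuses layer 0's cache instead of allocating its own;
        # the target must be a layer initialized before this one
        attn_1 = Attention(num_heads=8, head_size=64, scale=1.0,
                           prefix="model.layers.1.self_attn.attn",
                           kv_sharing_target_layer_name="model.layers.0.self_attn.attn")
    return attn_0, attn_1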
227 changes: 226 additions & 1 deletion tests/v1/tpu/worker/test_tpu_model_runner.py
@@ -3,8 +3,13 @@

import pytest

from vllm.config import CacheConfig, ModelConfig, SchedulerConfig, VllmConfig
from vllm.attention.layer import Attention
from vllm.config import (CacheConfig, ModelConfig, SchedulerConfig, VllmConfig,
set_current_vllm_config)
from vllm.sampling_params import SamplingParams
from vllm.utils import GiB_bytes
from vllm.v1.core.kv_cache_utils import (estimate_max_model_len,
get_kv_cache_config)
from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData,
SchedulerOutput)
from vllm.v1.worker.tpu_model_runner import (
@@ -362,3 +367,223 @@ def test_get_req_paddings():
assert _get_req_paddings(1, 32) == [8, 16, 32]
assert _get_req_paddings(8, 32) == [8, 16, 32]
assert _get_req_paddings(8, 36) == [8, 16, 32, 36]


def test_init_kv_cache_with_kv_sharing_invalid_target_layer_order():
layer_0 = "model.layers.0.self_attn.attn"
layer_1 = "model.layers.1.self_attn.attn"
error_msg = f"{layer_1} must come before the current layer"
with pytest.raises(ValueError, match=error_msg):
fwd_context = {
# initialization of layer 0 below will fail because its KV sharing target
# (layer 1) has not been initialized yet; the target must come earlier
# in the model than the layer that reuses its KV cache
layer_0:
Attention(
num_heads=8,
head_size=64,
scale=1.0,
prefix=layer_0,
kv_sharing_target_layer_name=layer_1,
),
layer_1:
Attention(
num_heads=8,
head_size=64,
scale=1.0,
prefix=layer_1,
)
}
# suppress var not used error
assert fwd_context is not None


def test_init_kv_cache_with_kv_sharing_target_layer_not_exist():
layer_0 = "model.layers.0.self_attn.attn"
layer_1 = "model.layers.1.self_attn.attn"
invalid_layer = "model.layers.0.cross_attn.attn"
error_msg = f"{invalid_layer} is not a valid Attention layer in the model"
with pytest.raises(ValueError, match=error_msg):
fwd_context = {
layer_0:
Attention(
num_heads=8,
head_size=64,
scale=1.0,
prefix=layer_0,
),
layer_1:
Attention(
num_heads=8,
head_size=64,
scale=1.0,
prefix=layer_1,
# invalid layer: cross_attn.attn doesn't exist!
kv_sharing_target_layer_name=invalid_layer,
)
}
# suppress var not used error
assert fwd_context is not None


def test_init_kv_cache_with_kv_sharing_target_same_as_current():
layer_0 = "model.layers.0.self_attn.attn"
layer_1 = "model.layers.1.self_attn.attn"
error_msg = f"{layer_1} cannot be the same as the current layer"
with pytest.raises(ValueError, match=error_msg):
fwd_context = {
# initialization of layer 1 below will fail because a layer cannot
# share KV cache with itself
layer_0:
Attention(
num_heads=8,
head_size=64,
scale=1.0,
prefix=layer_0,
),
layer_1:
Attention(
num_heads=8,
head_size=64,
scale=1.0,
prefix=layer_1,
kv_sharing_target_layer_name=layer_1,
)
}
# suppress var not used error
assert fwd_context is not None


def test_init_kv_cache_without_kv_sharing(model_runner):
layer_0 = "model.layers.0.self_attn.attn"
layer_1 = "model.layers.1.self_attn.attn"
vllm_config = model_runner.vllm_config
with set_current_vllm_config(vllm_config):
fwd_context = {
layer_0:
Attention(
num_heads=8,
head_size=64,
scale=1.0,
prefix=layer_0,
),
layer_1:
Attention(
num_heads=8,
head_size=64,
scale=1.0,
prefix=layer_1,
)
}
# suppress var not used error
assert fwd_context is not None
# Set high context length to test max context length estimation
vllm_config.model_config.max_model_len = 3_000_000
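# note: creating the Attention layers under set_current_vllm_config above
# registers each one in compilation_config.static_forward_context keyed by
# its prefix, which is what vllm_ctx below indexes into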
vllm_ctx = vllm_config.compilation_config.static_forward_context
kv_cache_spec = model_runner.get_kv_cache_spec()
assert len(kv_cache_spec) == 2
assert len(model_runner.shared_kv_cache_layers) == 0

available_memory = 20 * GiB_bytes
# page size for layer 0's kv_cache_spec is 32KB
num_expected_blocks = 327680 # 20GB / 32KB / 2 (num layers)
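# (sketch of the arithmetic, assuming a bfloat16 KV cache and block_size 16:
# 2 tensors (K and V) * 16 tokens * 8 heads * 64 head_size * 2 bytes = 32KB
# per block per layer; 20GiB / 32KB = 655,360 pages, split across the two
# layers = 327,680 blocks)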
kv_cache_config = get_kv_cache_config(vllm_config, kv_cache_spec,
available_memory)
assert kv_cache_config.num_blocks == num_expected_blocks
assert len(kv_cache_config.tensors) == 2
assert kv_cache_config.tensors[layer_0].size == available_memory // 2
assert kv_cache_config.tensors[layer_1].size == available_memory // 2

max_context_len =\
estimate_max_model_len(vllm_config, kv_cache_spec, 5 * GiB_bytes)
# baseline max context len without KV sharing; the KV sharing test below
# expects twice this value
assert max_context_len == 1310720
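# (rough arithmetic, assuming bfloat16 KV: each token needs
# 2 layers * 2 (K and V) * 8 heads * 64 head_size * 2 bytes = 4KB of KV cache,
# so 5GiB / 4KB = 1,310,720 tokens)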

# important: override the tensor size to prevent a large memory allocation
# during the test; this only allocates 2 blocks' worth of memory (2 * 32KB)
kv_cache_config.num_blocks = 1
for layer in kv_cache_config.tensors:
kv_cache_config.tensors[layer].size =\
kv_cache_spec[layer].page_size_bytes

model_runner.initialize_kv_cache(kv_cache_config)

layer_0_kv = vllm_ctx[layer_0].kv_cache[0]
layer_1_kv = vllm_ctx[layer_1].kv_cache[0]
# check layer 1 kv cache does NOT share memory with layer 0
assert id(layer_1_kv) != id(layer_0_kv)

# check layer 1 added to kv cache group's layer names
assert len(kv_cache_config.kv_cache_groups) == 1
assert len(kv_cache_config.kv_cache_groups[0].layer_names) == 2
assert kv_cache_config.kv_cache_groups[0].layer_names[0] == layer_0
assert kv_cache_config.kv_cache_groups[0].layer_names[1] == layer_1


def test_init_kv_cache_with_kv_sharing_valid(model_runner):
layer_0 = "model.layers.0.self_attn.attn"
layer_1 = "model.layers.1.self_attn.attn"
vllm_config = model_runner.vllm_config
with set_current_vllm_config(vllm_config):
fwd_context = {
layer_0:
Attention(
num_heads=8,
head_size=64,
scale=1.0,
prefix=layer_0,
),
layer_1:
Attention(
num_heads=8,
head_size=64,
scale=1.0,
prefix=layer_1,
kv_sharing_target_layer_name="model.layers.0.self_attn.attn",
)
}
# suppress var not used error
assert fwd_context is not None
# Set high context length to test max context length estimation
vllm_config.model_config.max_model_len = 3_000_000
vllm_ctx = vllm_config.compilation_config.static_forward_context
kv_cache_spec = model_runner.get_kv_cache_spec()
assert len(kv_cache_spec) == 1
assert layer_0 in kv_cache_spec
assert model_runner.shared_kv_cache_layers[layer_1] == layer_0
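# only layer 0 appears in the KV cache spec; layer 1 is recorded in
# shared_kv_cache_layers so that initialize_kv_cache can later point its
# kv_cache at layer 0's tensor (verified at the end of this test)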

available_memory = 20 * GiB_bytes
# page size for layer 0's kv_cache_spec is 32KB
# with KV sharing, we can allocate (available_mem//page_size//1) blocks
# which is twice as many as without KV sharing
num_expected_blocks = 655360 # 20GB / 32KB
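# (same 32KB page size as above, but only layer 0 owns a cache, so all
# 20GiB funds a single layer's pages: 20GiB / 32KB = 655,360 blocks)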
kv_cache_config = get_kv_cache_config(vllm_config, kv_cache_spec,
available_memory)
assert kv_cache_config.num_blocks == num_expected_blocks
assert len(kv_cache_config.tensors) == 1
# Each layer now has twice the available memory for KV cache
# compared to no KV sharing
assert kv_cache_config.tensors[layer_0].size == available_memory

max_context_len =\
estimate_max_model_len(vllm_config, kv_cache_spec, 5 * GiB_bytes)
# max context len with KV sharing should be 2x as large as without
assert max_context_len == 2 * 1310720
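# (with sharing only one layer stores KV, so each token needs 2KB instead
# of 4KB: 5GiB / 2KB = 2,621,440 tokens)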

# important: override the tensor size to prevent a large memory allocation
# during the test; this only allocates 1 block's worth of memory (32KB)
kv_cache_config.num_blocks = 1
kv_cache_config.tensors[layer_0].size =\
kv_cache_spec[layer_0].page_size_bytes

model_runner.initialize_kv_cache(kv_cache_config)

layer_0_kv = vllm_ctx[layer_0].kv_cache[0]
layer_1_kv = vllm_ctx[layer_1].kv_cache[0]
# check layer 1 kv cache shares memory with layer 0
assert id(layer_1_kv) == id(layer_0_kv)

# check layer 1 added to kv cache group's layer names
assert len(kv_cache_config.kv_cache_groups) == 1
assert len(kv_cache_config.kv_cache_groups[0].layer_names) == 2
assert kv_cache_config.kv_cache_groups[0].layer_names[0] == layer_0
assert kv_cache_config.kv_cache_groups[0].layer_names[1] == layer_1