
Commit 91e4174

Support multiple attention groups for KV sharing (vllm-project#22672)
Summary:

vllm-project#21588 added support for multiple attention metadata builders per KV cache spec. As part of that change, each KV cache group now maps to one or more `AttentionGroup`, with one attention group created for each type of attention backend used. However, enabling KV sharing with more than one attention group previously tripped the following assertion:

```
assert len(attn_groups[group_idx]) == 1, (
    "Only one attention group per KV cache group is supported "
    "for KV-cache sharing for now.")
```

This PR makes the implementation more flexible, so that KV cache sharing is supported even when a KV cache group has multiple attention groups.

Test Plan:

The newly added unit test passes:

```
pytest tests/v1/test_kv_sharing.py
```

Rollback Plan:

Differential Revision: D80020191

Signed-off-by: Yong Hoon Shin <yhshin@meta.com>
1 parent 7f89ed2 commit 91e4174
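
At a glance, the new behavior resolves each shared layer's target to both its KV cache group index and its attention group index, then registers the sharing layer in both places. Below is a minimal, self-contained sketch of that logic, simplified from the `vllm/v1/worker/utils.py` hunk further down; `SimpleAttnGroup`, `SimpleKVCacheGroup`, and `share_kv_caches` are illustrative stand-ins, not vLLM APIs.

```
from dataclasses import dataclass, field


@dataclass
class SimpleAttnGroup:
    # Stand-in for vLLM's AttentionGroup; only layer_names matters here.
    layer_names: list[str] = field(default_factory=list)


@dataclass
class SimpleKVCacheGroup:
    # Stand-in for vLLM's KVCacheGroupSpec.
    layer_names: list[str] = field(default_factory=list)


def share_kv_caches(shared_kv_cache_layers, kv_cache_groups, kv_caches,
                    attn_groups=None):
    # Map each layer that owns a KV cache to (kv_cache_group_idx, attn_group_idx).
    layer_to_idx: dict[str, tuple[int, int]] = {}
    if attn_groups:
        for kv_idx, groups in enumerate(attn_groups):
            for attn_idx, group in enumerate(groups):
                for name in group.layer_names:
                    layer_to_idx[name] = (kv_idx, attn_idx)
    else:
        for kv_idx, group in enumerate(kv_cache_groups):
            for name in group.layer_names:
                layer_to_idx[name] = (kv_idx, 0)

    for layer, target in shared_kv_cache_layers.items():
        # Alias the target layer's cache; no copy is made.
        kv_caches[layer] = kv_caches[target]
        kv_idx, attn_idx = layer_to_idx[target]
        kv_cache_groups[kv_idx].layer_names.append(layer)
        if attn_groups:
            attn_groups[kv_idx][attn_idx].layer_names.append(layer)
```

When `attn_groups` is not provided (as for the TPU model runner, which does not support attention groups yet), the attention group index simply defaults to 0 in the mapping and only the KV cache groups are updated.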

File tree

tests/v1/test_kv_sharing.py
vllm/v1/worker/utils.py

2 files changed: +213, −16 lines

tests/v1/test_kv_sharing.py

Lines changed: 189 additions & 0 deletions
@@ -0,0 +1,189 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from unittest.mock import Mock

import torch

from vllm.v1.attention.backends.flash_attn import (
    FlashAttentionBackend, FlashAttentionMetadataBuilder)
from vllm.v1.attention.backends.flex_attention import (
    FlexAttentionBackend, FlexAttentionMetadataBuilder)
from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheGroupSpec
from vllm.v1.worker.utils import (AttentionGroup,
                                  initialize_kv_cache_for_kv_sharing)


def new_kv_cache_spec():
    return FullAttentionSpec(16, 1, 1, torch.float32, False)


def test_initialize_kv_cache_for_kv_sharing_different_attn_groups():
    """
    Test initializing KV cache sharing with different attention groups.
    Layers in the same KV cache group might be placed in different attn groups
    if they have different attention backends.
    """
    shared_kv_cache_layers = {
        "model.layers.2": "model.layers.0",
        "model.layers.3": "model.layers.1",
    }

    # Layers 0 and 1 both belong to KV cache group 0.
    # However, if they have different attention backends, they will be
    # placed in different attention groups for KV cache group 0.
    kv_cache_groups = [
        KVCacheGroupSpec(["model.layers.0", "model.layers.1"],
                         new_kv_cache_spec()),
    ]

    attn_groups = [
        # KV cache group 0 has two attention groups
        [
            AttentionGroup(
                backend=FlashAttentionBackend,
                metadata_builder=Mock(spec=FlashAttentionMetadataBuilder),
                layer_names=["model.layers.0"],
            ),
            AttentionGroup(
                backend=FlexAttentionBackend,
                metadata_builder=Mock(spec=FlexAttentionMetadataBuilder),
                layer_names=["model.layers.1"],
            ),
        ],
    ]

    # Only layers 0 and 1 will have KV caches allocated
    kv_caches = {
        "model.layers.0": torch.zeros(1, 2, 3),
        "model.layers.1": torch.ones(1, 2, 3),
    }

    initialize_kv_cache_for_kv_sharing(
        shared_kv_cache_layers=shared_kv_cache_layers,
        kv_cache_groups=kv_cache_groups,
        kv_caches=kv_caches,
        attn_groups=attn_groups,
    )

    # Check that the KV caches were shared correctly
    assert kv_caches["model.layers.2"].data_ptr(
    ) == kv_caches["model.layers.0"].data_ptr()
    assert kv_caches["model.layers.3"].data_ptr(
    ) == kv_caches["model.layers.1"].data_ptr()

    # Check that the layers were added to the correct KV cache group
    assert len(kv_cache_groups) == 1
    assert kv_cache_groups[0].layer_names == [
        "model.layers.0", "model.layers.1", "model.layers.2", "model.layers.3"
    ]

    # Check that the layers were added to the attention groups
    assert len(attn_groups) == 1 and len(attn_groups[0]) == 2
    assert attn_groups[0][0].layer_names == [
        "model.layers.0", "model.layers.2"
    ]
    assert attn_groups[0][1].layer_names == [
        "model.layers.1", "model.layers.3"
    ]


def test_initialize_kv_cache_for_kv_sharing_same_attn_groups():
    """
    Test case assuming that all layers in the same KV cache group have the
    same attention backend. This is true for most models.
    """
    shared_kv_cache_layers = {
        "model.layers.2": "model.layers.0",
        "model.layers.3": "model.layers.1",
    }

    kv_cache_groups = [
        KVCacheGroupSpec(["model.layers.0", "model.layers.1"],
                         new_kv_cache_spec()),
    ]

    attn_groups = [
        # KV cache group 0 has a single attention group
        # as all layers have the same flash attention backend
        [
            AttentionGroup(
                backend=FlashAttentionBackend,
                metadata_builder=Mock(spec=FlashAttentionMetadataBuilder),
                layer_names=["model.layers.0", "model.layers.1"],
            ),
        ],
    ]

    kv_caches = {
        "model.layers.0": torch.zeros(1, 2, 3),
        "model.layers.1": torch.ones(1, 2, 3),
    }

    initialize_kv_cache_for_kv_sharing(
        shared_kv_cache_layers=shared_kv_cache_layers,
        kv_cache_groups=kv_cache_groups,
        kv_caches=kv_caches,
        attn_groups=attn_groups,
    )

    # Check that the KV caches were shared correctly
    assert kv_caches["model.layers.2"].data_ptr(
    ) == kv_caches["model.layers.0"].data_ptr()
    assert kv_caches["model.layers.3"].data_ptr(
    ) == kv_caches["model.layers.1"].data_ptr()

    # Check that the layers were added to the correct KV cache group
    assert len(kv_cache_groups) == 1
    assert kv_cache_groups[0].layer_names == [
        "model.layers.0", "model.layers.1", "model.layers.2", "model.layers.3"
    ]

    # Check that the layers were added to the attention groups
    assert len(attn_groups) == 1 and len(attn_groups[0]) == 1
    assert attn_groups[0][0].layer_names == [
        "model.layers.0", "model.layers.1", "model.layers.2", "model.layers.3"
    ]


def test_initialize_kv_cache_for_kv_sharing_no_attn_groups():
    """
    Test KV sharing setup when no attention groups are provided.
    This is the case for the TPU model runner, which doesn't have
    support for attention groups yet.
    """
    shared_kv_cache_layers = {
        "model.layers.2": "model.layers.0",
        "model.layers.3": "model.layers.1",
    }

    kv_cache_groups = [
        KVCacheGroupSpec(["model.layers.0"], new_kv_cache_spec()),
        KVCacheGroupSpec(["model.layers.1"], new_kv_cache_spec()),
    ]

    kv_caches = {
        "model.layers.0": torch.zeros(1, 2, 3),
        "model.layers.1": torch.ones(1, 2, 3),
    }

    initialize_kv_cache_for_kv_sharing(
        shared_kv_cache_layers=shared_kv_cache_layers,
        kv_cache_groups=kv_cache_groups,
        kv_caches=kv_caches,
    )

    # Check that the KV caches were shared correctly
    assert kv_caches["model.layers.2"].data_ptr(
    ) == kv_caches["model.layers.0"].data_ptr()
    assert kv_caches["model.layers.3"].data_ptr(
    ) == kv_caches["model.layers.1"].data_ptr()

    # Check that the layers were added to the correct KV cache group
    assert len(kv_cache_groups) == 2
    assert kv_cache_groups[0].layer_names == [
        "model.layers.0", "model.layers.2"
    ]
    assert kv_cache_groups[1].layer_names == [
        "model.layers.1", "model.layers.3"
    ]
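
A note on the `data_ptr()` assertions used throughout these tests: KV sharing is implemented by aliasing the target layer's tensor under the sharing layer's key, so both dictionary entries point at the same storage. A standalone illustration, independent of vLLM:

```
import torch

kv_caches = {"model.layers.0": torch.zeros(1, 2, 3)}
kv_caches["model.layers.2"] = kv_caches["model.layers.0"]  # alias, no copy

# Both keys reference the same underlying storage.
assert (kv_caches["model.layers.2"].data_ptr()
        == kv_caches["model.layers.0"].data_ptr())
```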

vllm/v1/worker/utils.py

Lines changed: 24 additions & 16 deletions
@@ -225,26 +225,34 @@ def initialize_kv_cache_for_kv_sharing(
         Note that layers in shared_kv_cache_layers.keys() are not
         originally included as it only contains layers which have its own
         KV cache allocation.
+        attn_groups: Optional list of attention groups. Layers in the same KV
+            cache group may be placed in different attention groups if they
+            have different attention backends. Currently only provided by
+            GPU model runner.
     """
-    # Record index of KV cache group for each layer that allocates a KV cache.
-    layer_to_kv_cache_group_idx: dict[str, int] = {}
-    for i, kv_cache_group in enumerate(kv_cache_groups):
-        for layer_name in kv_cache_group.layer_names:
-            layer_to_kv_cache_group_idx[layer_name] = i
+    # Mapping from layer name to tuple of (kv_cache_group_idx, attn_group_idx)
+    layer_to_attn_group_idx: dict[str, tuple[int, int]] = {}
+    if attn_groups:
+        for kv_cache_group_idx, kv_attn_groups in enumerate(attn_groups):
+            for attn_group_idx, attn_group in enumerate(kv_attn_groups):
+                for layer_name in attn_group.layer_names:
+                    layer_to_attn_group_idx[layer_name] = (kv_cache_group_idx,
+                                                           attn_group_idx)
+    else:
+        for kv_cache_group_idx, kv_cache_group in enumerate(kv_cache_groups):
+            for layer_name in kv_cache_group.layer_names:
+                # attn_group_idx defaults to 0 if attn_groups is not provided
+                layer_to_attn_group_idx[layer_name] = (kv_cache_group_idx, 0)
 
     for layer_name, target_layer_name in shared_kv_cache_layers.items():
         kv_caches[layer_name] = kv_caches[target_layer_name]
-        group_idx = layer_to_kv_cache_group_idx[target_layer_name]
-        kv_cache_groups[group_idx].layer_names.append(layer_name)
-
-        if attn_groups is not None:
-            assert len(attn_groups[group_idx]) == 1, (
-                "Only one attention group per KV cache group is supported "
-                "for KV-cache sharing for now.")
-            # TODO(lucas): I think in the future the layers that re-use a
-            # KV cache will be in a different attention group so we can
-            # remove this code from here.
-            attn_groups[group_idx][0].layer_names.append(layer_name)
+        kv_cache_group_idx = layer_to_attn_group_idx[target_layer_name][0]
+        kv_cache_groups[kv_cache_group_idx].layer_names.append(layer_name)
+
+        if attn_groups:
+            attn_group_idx = layer_to_attn_group_idx[target_layer_name][1]
+            attn_groups[kv_cache_group_idx][attn_group_idx].layer_names.append(
+                layer_name)
 
 
 def bind_kv_cache(
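
For intuition on why a single KV cache group can hold several attention groups: per the commit summary, the runner partitions a group's layers by attention backend, creating one attention group per backend type. A hypothetical sketch of that partitioning follows; `group_layers_by_backend` is not a vLLM function, and the GPU model runner has its own construction logic.

```
from collections import defaultdict


def group_layers_by_backend(layer_backends):
    """Group layer names by attention backend class: one group per backend."""
    by_backend = defaultdict(list)
    for layer_name, backend_cls in layer_backends.items():
        by_backend[backend_cls].append(layer_name)
    return list(by_backend.values())


# e.g. {"model.layers.0": FlashAttentionBackend,
#       "model.layers.1": FlexAttentionBackend}
# -> [["model.layers.0"], ["model.layers.1"]]
#    (two attention groups within one KV cache group)
```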
