Commit cf5f000

[torch.compile] Hide KV cache behind torch.compile boundary (#11677)
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
1 parent 3de2b1e commit cf5f000

18 files changed (+198 −44 lines)

tests/kernels/test_encoder_decoder_attn.py

Lines changed: 12 additions & 6 deletions
@@ -142,12 +142,18 @@ class that Attention will automatically select when it is constructed.
         torch.tensor([], dtype=torch.float32, device=CUDA_DEVICE))

     # Construct KV cache
-    kv_cache = make_kv_cache(test_pt.num_blocks,
-                             test_pt.num_heads,
-                             test_pt.head_size,
-                             test_pt.block_size,
-                             device=CUDA_DEVICE,
-                             backend=test_pt.backend_name)
+    if test_pt.attn_type in (AttentionType.DECODER,
+                             AttentionType.ENCODER_DECODER):
+        kv_cache = make_kv_cache(test_pt.num_blocks,
+                                 test_pt.num_heads,
+                                 test_pt.head_size,
+                                 test_pt.block_size,
+                                 device=CUDA_DEVICE,
+                                 backend=test_pt.backend_name)
+    else:
+        kv_cache = torch.tensor([])
+
+    attn.kv_cache = [kv_cache]
     return TestResources(scale, attn, kv_cache)

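A note in code form (not part of the diff): torch.tensor([]) is the commit's placeholder for "this attention has no KV cache", which is why the encoder-only branch assigns it directly. bind_kv_cache, added in vllm/utils.py further down, checks for exactly this placeholder before binding a real tensor:

import torch

placeholder = torch.tensor([])   # what the encoder-only branch assigns above
assert placeholder.numel() == 0  # the pre-condition bind_kv_cache asserts before binding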

tests/test_utils.py

Lines changed: 83 additions & 2 deletions
@@ -7,9 +7,11 @@
 import torch
 from vllm_test_utils import monitor

+from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
 from vllm.utils import (FlexibleArgumentParser, PlaceholderModule,
-                        StoreBoolean, deprecate_kwargs, get_open_port,
-                        memory_profiling, merge_async_iterators, supports_kw)
+                        StoreBoolean, bind_kv_cache, deprecate_kwargs,
+                        get_open_port, memory_profiling, merge_async_iterators,
+                        supports_kw)

 from .utils import error_on_warning, fork_new_process_for_each_test

@@ -325,6 +327,85 @@ def measure_current_non_torch():
     lib.cudaFree(handle2)


+def test_bind_kv_cache():
+    from vllm.attention import Attention
+
+    ctx = {
+        'layers.0.self_attn': Attention(32, 128, 0.1),
+        'layers.1.self_attn': Attention(32, 128, 0.1),
+        'layers.2.self_attn': Attention(32, 128, 0.1),
+        'layers.3.self_attn': Attention(32, 128, 0.1),
+    }
+    kv_cache = [
+        torch.zeros((1, )),
+        torch.zeros((1, )),
+        torch.zeros((1, )),
+        torch.zeros((1, )),
+    ]
+    bind_kv_cache(ctx, [kv_cache])
+    assert ctx['layers.0.self_attn'].kv_cache[0] is kv_cache[0]
+    assert ctx['layers.1.self_attn'].kv_cache[0] is kv_cache[1]
+    assert ctx['layers.2.self_attn'].kv_cache[0] is kv_cache[2]
+    assert ctx['layers.3.self_attn'].kv_cache[0] is kv_cache[3]
+
+def test_bind_kv_cache_non_attention():
+    from vllm.attention import Attention
+
+    # example from Jamba PP=2
+    ctx = {
+        'model.layers.20.attn': Attention(32, 128, 0.1),
+        'model.layers.28.attn': Attention(32, 128, 0.1),
+    }
+    kv_cache = [
+        torch.zeros((1, )),
+        torch.zeros((1, )),
+    ]
+    bind_kv_cache(ctx, [kv_cache])
+    assert ctx['model.layers.20.attn'].kv_cache[0] is kv_cache[0]
+    assert ctx['model.layers.28.attn'].kv_cache[0] is kv_cache[1]


+def test_bind_kv_cache_encoder_decoder():
+    from vllm.attention import Attention, AttentionType
+
+    # example from bart
+    ctx = {
+        'encoder.layers.0.self_attn.attn':
+        Attention(32, 128, 0.1, attn_type=AttentionType.ENCODER),
+        'decoder.layers.0.encoder_attn.attn':
+        Attention(32, 128, 0.1, attn_type=AttentionType.ENCODER_DECODER),
+        'decoder.layers.0.self_attn.attn':
+        Attention(32, 128, 0.1, attn_type=AttentionType.DECODER),
+    }
+
+    kv_cache = [
+        torch.zeros((1, )),
+    ]
+    encoder_kv_cache = ctx['encoder.layers.0.self_attn.attn'].kv_cache
+
+    bind_kv_cache(ctx, [kv_cache])
+    assert ctx['encoder.layers.0.self_attn.attn'].kv_cache is encoder_kv_cache
+    assert ctx['decoder.layers.0.encoder_attn.attn'].kv_cache[0] is kv_cache[0]
+    assert ctx['decoder.layers.0.self_attn.attn'].kv_cache[0] is kv_cache[0]


+def test_bind_kv_cache_pp():
+    cfg = VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=2))
+    with set_current_vllm_config(cfg):
+        from vllm.attention import Attention
+
+        ctx = {
+            'layers.0.self_attn': Attention(32, 128, 0.1),
+        }
+    kv_cache = [
+        [torch.zeros((1, ))],
+        [torch.zeros((1, ))]
+    ]
+    bind_kv_cache(ctx, kv_cache)
+    assert ctx['layers.0.self_attn'].kv_cache[0] is kv_cache[0][0]
+    assert ctx['layers.0.self_attn'].kv_cache[1] is kv_cache[1][0]


 def test_placeholder_module_error_handling():
     placeholder = PlaceholderModule("placeholder_1234")
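The Jamba PP=2 test above relies on one mapping rule: layer names keep their global indices (20 and 28 on this rank), while the rank-local cache list is dense, so names are routed through the sorted set of indices they contain. A minimal sketch of that rule; _toy_layer_index is a hypothetical stand-in for extract_layer_index from vllm.model_executor.models.utils:

def _toy_layer_index(layer_name: str) -> int:
    # pull the integer out of a dotted layer name, e.g. 'model.layers.20.attn' -> 20
    return next(int(part) for part in layer_name.split('.') if part.isdigit())

layer_names = ['model.layers.20.attn', 'model.layers.28.attn']
sorted_indices = sorted({_toy_layer_index(name) for name in layer_names})
mapping = {name: sorted_indices.index(_toy_layer_index(name)) for name in layer_names}
assert mapping == {'model.layers.20.attn': 0, 'model.layers.28.attn': 1}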

tests/v1/engine/test_engine_core.py

Lines changed: 3 additions & 0 deletions
@@ -4,6 +4,7 @@
 import pytest
 from transformers import AutoTokenizer

+from tests.utils import fork_new_process_for_each_test
 from vllm import SamplingParams
 from vllm.engine.arg_utils import EngineArgs
 from vllm.platforms import current_platform
@@ -36,6 +37,7 @@ def make_request() -> EngineCoreRequest:
     )


+@fork_new_process_for_each_test
 def test_engine_core(monkeypatch):

     with monkeypatch.context() as m:
@@ -138,6 +140,7 @@ def test_engine_core(monkeypatch):
     assert len(engine_core.scheduler.running) == 0


+@fork_new_process_for_each_test
 def test_engine_core_advanced_sampling(monkeypatch):
     """
     A basic end-to-end test to verify that the engine functions correctly

tests/v1/engine/test_engine_core_client.py

Lines changed: 3 additions & 0 deletions
@@ -6,6 +6,7 @@
 import pytest
 from transformers import AutoTokenizer

+from tests.utils import fork_new_process_for_each_test
 from vllm import SamplingParams
 from vllm.engine.arg_utils import EngineArgs
 from vllm.platforms import current_platform
@@ -75,6 +76,7 @@ async def loop_until_done_async(client: EngineCoreClient, outputs: Dict):
             break


+@fork_new_process_for_each_test
 @pytest.mark.parametrize("multiprocessing_mode", [True, False])
 def test_engine_core_client(monkeypatch, multiprocessing_mode: bool):

@@ -143,6 +145,7 @@ def test_engine_core_client(monkeypatch, multiprocessing_mode: bool):
     client.abort_requests([request.request_id])


+@fork_new_process_for_each_test
 @pytest.mark.asyncio
 async def test_engine_core_client_asyncio(monkeypatch):


vllm/attention/layer.py

Lines changed: 17 additions & 12 deletions
@@ -121,6 +121,13 @@ def __init__(
         compilation_config.static_forward_context[prefix] = self
         self.layer_name = prefix
         self.attn_type = attn_type
+        # use a placeholder kv cache tensor during init, which will be replaced
+        # by bind_kv_cache
+        # this variable will not be accessed if use_direct_call is True
+        self.kv_cache = [
+            torch.tensor([]) for _ in range(get_current_vllm_config(
+            ).parallel_config.pipeline_parallel_size)
+        ]

     def forward(
         self,
@@ -148,11 +155,11 @@ def forward(
             if value is not None:
                 value = value.view(-1, self.num_kv_heads, self.head_size)
             torch.ops.vllm.unified_attention_with_output(
-                query, key, value, output, kv_cache, self.layer_name)
+                query, key, value, output, self.layer_name)
             return output.view(-1, hidden_size)
         else:
             return torch.ops.vllm.unified_attention(query, key, value,
-                                                    kv_cache, self.layer_name)
+                                                    self.layer_name)

     def extra_repr(self) -> str:
         s = f"head_size={self.impl.head_size}"  # type: ignore
@@ -230,12 +237,12 @@ def unified_attention(
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
-    kv_cache: torch.Tensor,
     layer_name: str,
 ) -> torch.Tensor:
     forward_context: ForwardContext = get_forward_context()
-    attn_metadata = forward_context.dynamic_forward_context
-    self = forward_context.static_forward_context[layer_name]
+    attn_metadata = forward_context.attn_metadata
+    self = forward_context.attn_layers[layer_name]
+    kv_cache = self.kv_cache[forward_context.virtual_engine]
     return self.impl.forward(query, key, value, kv_cache, attn_metadata,
                              self._k_scale, self._v_scale)

@@ -244,7 +251,6 @@ def unified_attention_fake(
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
-    kv_cache: torch.Tensor,
     layer_name: str,
 ) -> torch.Tensor:
     return torch.empty_like(query).contiguous()
@@ -253,7 +259,7 @@ def unified_attention_fake(
 direct_register_custom_op(
     op_name="unified_attention",
     op_func=unified_attention,
-    mutates_args=["kv_cache"],
+    mutates_args=[],
     fake_impl=unified_attention_fake,
     dispatch_key=current_platform.dispatch_key,
 )
@@ -264,12 +270,12 @@ def unified_attention_with_output(
     key: torch.Tensor,
     value: torch.Tensor,
     output: torch.Tensor,
-    kv_cache: torch.Tensor,
     layer_name: str,
 ) -> None:
     forward_context: ForwardContext = get_forward_context()
-    attn_metadata = forward_context.dynamic_forward_context
-    self = forward_context.static_forward_context[layer_name]
+    attn_metadata = forward_context.attn_metadata
+    self = forward_context.attn_layers[layer_name]
+    kv_cache = self.kv_cache[forward_context.virtual_engine]
     self.impl.forward(query,
                       key,
                       value,
@@ -285,7 +291,6 @@ def unified_attention_with_output_fake(
     key: torch.Tensor,
     value: torch.Tensor,
     output: torch.Tensor,
-    kv_cache: torch.Tensor,
     layer_name: str,
 ) -> None:
     return
@@ -294,7 +299,7 @@ def unified_attention_with_output_fake(
 direct_register_custom_op(
     op_name="unified_attention_with_output",
     op_func=unified_attention_with_output,
-    mutates_args=["kv_cache", "output"],
+    mutates_args=["output"],
     fake_impl=unified_attention_with_output_fake,
     dispatch_key=current_platform.dispatch_key,
 )
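
Taken together, these hunks implement one pattern: the custom op receives only layer_name, then resolves the layer and its per-virtual-engine KV cache from the forward context at call time, so the cache tensor is never an op argument (and no longer needs to be declared as mutated) under torch.compile. A hedged, self-contained sketch of that pattern using toy stand-ins rather than vLLM's real classes:

from dataclasses import dataclass, field
from typing import Dict, List

import torch


@dataclass
class _ToyLayer:
    # one placeholder tensor per virtual engine, mirroring Attention.__init__ above
    kv_cache: List[torch.Tensor] = field(default_factory=lambda: [torch.tensor([])])


@dataclass
class _ToyForwardContext:
    attn_layers: Dict[str, _ToyLayer]
    virtual_engine: int = 0


_ctx = _ToyForwardContext(attn_layers={'layers.0.self_attn': _ToyLayer()})


def _toy_unified_attention(query: torch.Tensor, layer_name: str) -> torch.Tensor:
    layer = _ctx.attn_layers[layer_name]            # resolved by name at call time
    kv_cache = layer.kv_cache[_ctx.virtual_engine]  # fetched from the layer, not the signature
    _ = kv_cache  # the real op reads and writes the cache here
    return torch.empty_like(query)


out = _toy_unified_attention(torch.randn(2, 8), 'layers.0.self_attn')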

vllm/config.py

Lines changed: 0 additions & 1 deletion
@@ -2780,7 +2780,6 @@ def model_post_init(self, __context: Any) -> None:
     compilation_time: float = PrivateAttr

     # Per-model forward context
-    # Mainly used to store attention cls
     # Map from layer name to the attention cls
     static_forward_context: Dict[str, Any] = PrivateAttr

vllm/forward_context.py

Lines changed: 20 additions & 13 deletions
@@ -2,14 +2,17 @@
 from collections import defaultdict
 from contextlib import contextmanager
 from dataclasses import dataclass
-from typing import Any, Dict, Optional
+from typing import TYPE_CHECKING, Any, Dict, Optional

 import torch

 import vllm.envs as envs
 from vllm.config import VllmConfig
 from vllm.logger import init_logger

+if TYPE_CHECKING:
+    from vllm.attention.backends.abstract import AttentionMetadata
+
 logger = init_logger(__name__)

 track_batchsize: bool = envs.VLLM_LOG_BATCHSIZE_INTERVAL >= 0
@@ -21,9 +24,12 @@

 @dataclass
 class ForwardContext:
-    static_forward_context: Dict[str, Any]
+    # copy from vllm_config.compilation_config.static_forward_context
+    attn_layers: Dict[str, Any]
     # TODO: extend to support per-layer dynamic forward context
-    dynamic_forward_context: Any
+    attn_metadata: "AttentionMetadata"  # set dynamically for each forward pass
+    # TODO: remove after making all virtual_engines share the same kv cache
+    virtual_engine: int  # set dynamically for each forward pass


 _forward_context: Optional[ForwardContext] = None
@@ -38,34 +44,35 @@ def get_forward_context() -> ForwardContext:


 @contextmanager
-def set_forward_context(context: Any, vllm_config: VllmConfig):
+def set_forward_context(attn_metadata: Any,
+                        vllm_config: VllmConfig,
+                        virtual_engine: int = 0):
     """A context manager that stores the current forward context,
     can be attention metadata, etc.
     Here we can inject common logic for every model forward pass.
     """
     global forward_start_time
-    need_to_track_batchsize = track_batchsize and context is not None
+    need_to_track_batchsize = track_batchsize and attn_metadata is not None
     if need_to_track_batchsize:
         forward_start_time = time.perf_counter()
     global _forward_context
     prev_context = _forward_context
     _forward_context = ForwardContext(
-        static_forward_context=vllm_config.compilation_config.
-        static_forward_context,
-        dynamic_forward_context=context)
+        attn_layers=vllm_config.compilation_config.static_forward_context,
+        virtual_engine=virtual_engine,
+        attn_metadata=attn_metadata)
     try:
         yield
     finally:
-        global batchsize_counter
         global last_logging_time, batchsize_logging_interval
         if need_to_track_batchsize:
-            if hasattr(context, "num_prefill_tokens"):
+            if hasattr(attn_metadata, "num_prefill_tokens"):
                 # for v0 attention backends
-                batchsize = context.num_prefill_tokens + \
-                    context.num_decode_tokens
+                batchsize = attn_metadata.num_prefill_tokens + \
+                    attn_metadata.num_decode_tokens
             else:
                 # for v1 attention backends
-                batchsize = context.num_input_tokens
+                batchsize = attn_metadata.num_input_tokens
             # we use synchronous scheduling right now,
             # adding a sync point here should not affect
             # scheduling of the next batch
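
For orientation, a hedged sketch of the call pattern the new signature implies; run_one_step, model and model_inputs are placeholders rather than vLLM's actual call sites:

from vllm.config import VllmConfig
from vllm.forward_context import set_forward_context


def run_one_step(model, model_inputs: dict, attn_metadata,
                 vllm_config: VllmConfig, virtual_engine: int = 0):
    # attn_metadata replaces the old generic `context` argument, and the
    # virtual engine index travels with it so attention layers can select the
    # matching entry of their per-virtual-engine kv_cache list.
    with set_forward_context(attn_metadata, vllm_config, virtual_engine):
        return model(**model_inputs)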

vllm/utils.py

Lines changed: 35 additions & 0 deletions
@@ -2138,3 +2138,38 @@ def get_mp_context():
     _check_multiproc_method()
     mp_method = envs.VLLM_WORKER_MULTIPROC_METHOD
     return multiprocessing.get_context(mp_method)
+
+
+def bind_kv_cache(
+    ctx: Dict[str, Any],
+    kv_cache: List[List[torch.Tensor]],  # [virtual_engine][layer_index]
+) -> None:
+    # Bind the kv_cache tensor to Attention modules, similar to
+    # ctx[layer_name].kv_cache[ve]=kv_cache[ve][extract_layer_index(layer_name)]
+    # Special things handled here:
+    # 1. Some models have non-attention layers, e.g., Jamba
+    # 2. Pipeline parallelism, each rank only has a subset of layers
+    # 3. Encoder attention has no kv cache
+    # 4. Encoder-decoder models, encoder-decoder attention and decoder-only
+    #    attention of the same layer (e.g., bart's decoder.layers.1.self_attn
+    #    and decoder.layers.1.encoder_attn) is mapped to the same kv cache
+    #    tensor
+    from vllm.attention import AttentionType
+    from vllm.model_executor.models.utils import extract_layer_index
+    layer_need_kv_cache = [
+        layer_name for layer_name in ctx
+        if ctx[layer_name].attn_type in (AttentionType.DECODER,
+                                         AttentionType.ENCODER_DECODER)
+    ]
+    layer_index_sorted = sorted(
+        set(
+            extract_layer_index(layer_name)
+            for layer_name in layer_need_kv_cache))
+    for layer_name in layer_need_kv_cache:
+        kv_cache_idx = layer_index_sorted.index(
+            extract_layer_index(layer_name))
+        forward_ctx = ctx[layer_name]
+        assert len(forward_ctx.kv_cache) == len(kv_cache)
+        for ve, ve_kv_cache in enumerate(kv_cache):
+            assert forward_ctx.kv_cache[ve].numel() == 0
+            forward_ctx.kv_cache[ve] = ve_kv_cache[kv_cache_idx]
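
A hedged, self-contained illustration of case 2 in the comment above (pipeline parallelism, where a rank owns a non-contiguous slice of layers); _ToyAttn is a hypothetical stand-in exposing only the two attributes bind_kv_cache reads, not vLLM's real Attention layer:

import torch

from vllm.attention import AttentionType
from vllm.utils import bind_kv_cache


class _ToyAttn:
    """Stand-in with just the attributes bind_kv_cache touches."""

    def __init__(self):
        self.attn_type = AttentionType.DECODER
        self.kv_cache = [torch.tensor([])]  # one placeholder per virtual engine


# A rank owning global layers 20 and 28; its cache list is still dense.
ctx = {'model.layers.20.attn': _ToyAttn(), 'model.layers.28.attn': _ToyAttn()}
kv_caches = [[torch.zeros(2, 4), torch.zeros(2, 4)]]  # [virtual_engine][layer]
bind_kv_cache(ctx, kv_caches)
assert ctx['model.layers.20.attn'].kv_cache[0] is kv_caches[0][0]
assert ctx['model.layers.28.attn'].kv_cache[0] is kv_caches[0][1]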
