
Commit cba31c4

[v1] AttentionMetadata for each layer (#17394)
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
1 parent a6fed02 commit cba31c4

9 files changed: 126 additions and 46 deletions


vllm/attention/layer.py

Lines changed: 12 additions & 3 deletions
@@ -210,6 +210,8 @@ def forward(
         if self.use_direct_call:
             forward_context: ForwardContext = get_forward_context()
             attn_metadata = forward_context.attn_metadata
+            if isinstance(attn_metadata, dict):
+                attn_metadata = attn_metadata[self.layer_name]
             self_kv_cache = self.kv_cache[forward_context.virtual_engine]
             self.impl.forward(self,
                               query,
@@ -226,6 +228,8 @@ def forward(
         if self.use_direct_call:
             forward_context = get_forward_context()
             attn_metadata = forward_context.attn_metadata
+            if isinstance(attn_metadata, dict):
+                attn_metadata = attn_metadata[self.layer_name]
             self_kv_cache = self.kv_cache[forward_context.virtual_engine]
             return self.impl.forward(self, query, key, value,
                                      self_kv_cache, attn_metadata)
@@ -343,7 +347,7 @@ def wait_for_kv_layer_from_connector(layer_name: str):
     attn_metadata = forward_context.attn_metadata
     if attn_metadata is None:
         return
-
+    assert isinstance(attn_metadata, dict)
     connector.wait_for_layer_load(layer_name)


@@ -360,8 +364,9 @@ def maybe_save_kv_layer_to_connector(
     attn_metadata = forward_context.attn_metadata
     if attn_metadata is None:
         return
-
-    connector.save_kv_layer(layer_name, kv_cache_layer, attn_metadata)
+    assert isinstance(attn_metadata, dict)
+    connector.save_kv_layer(layer_name, kv_cache_layer,
+                            attn_metadata[layer_name])


 def unified_attention(
@@ -374,6 +379,8 @@ def unified_attention(

     forward_context: ForwardContext = get_forward_context()
     attn_metadata = forward_context.attn_metadata
+    if isinstance(attn_metadata, dict):
+        attn_metadata = attn_metadata[layer_name]
     self = forward_context.no_compile_layers[layer_name]
     kv_cache = self.kv_cache[forward_context.virtual_engine]
     output = self.impl.forward(self, query, key, value, kv_cache,
@@ -411,6 +418,8 @@ def unified_attention_with_output(
     wait_for_kv_layer_from_connector(layer_name)
     forward_context: ForwardContext = get_forward_context()
     attn_metadata = forward_context.attn_metadata
+    if isinstance(attn_metadata, dict):
+        attn_metadata = attn_metadata[layer_name]
     self = forward_context.no_compile_layers[layer_name]
     kv_cache = self.kv_cache[forward_context.virtual_engine]
     self.impl.forward(self,
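
The call sites above all repeat the same lookup. A rough sketch of that pattern (illustrative only, not vLLM code; the helper name is invented):

# Illustrative helper; vLLM inlines this isinstance check at each call site.
from typing import Any, Optional, Union

def resolve_attn_metadata(
        attn_metadata: Union[Any, dict[str, Any], None],
        layer_name: str) -> Optional[Any]:
    # v0 stores a single AttentionMetadata shared by every layer;
    # v1 stores {layer_name: AttentionMetadata}, so each layer picks its own.
    if isinstance(attn_metadata, dict):
        return attn_metadata[layer_name]
    return attn_metadata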

vllm/forward_context.py

Lines changed: 8 additions & 3 deletions
@@ -4,7 +4,7 @@
 from collections import defaultdict
 from contextlib import contextmanager
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, Optional
+from typing import TYPE_CHECKING, Any, Optional, Union

 import torch
 import torch.distributed as dist
@@ -38,8 +38,13 @@ class DPMetadata:
 class ForwardContext:
     # copy from vllm_config.compilation_config.static_forward_context
     no_compile_layers: dict[str, Any]
-    # TODO: extend to support per-layer dynamic forward context
-    attn_metadata: "AttentionMetadata"  # set dynamically for each forward pass
+    """
+    Type AttentionMetadata for v0,
+    Type Dict[str, AttentionMetadata] for v1, map from layer_name of each
+    attention layer to its attention metadata
+    set dynamically for each forward pass
+    """
+    attn_metadata: Union["AttentionMetadata", dict[str, "AttentionMetadata"]]
     # TODO: remove after making all virtual_engines share the same kv cache
     virtual_engine: int  # set dynamically for each forward pass
     # set dynamically for each forward pass
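
A small toy example of the v1 shape (not vLLM code): layers that belong to the same KV cache group reference the same metadata object, so the dict adds no per-layer copies.

# Toy example; ToyMetadata stands in for a backend's AttentionMetadata class.
from dataclasses import dataclass

@dataclass
class ToyMetadata:
    num_actual_tokens: int

shared = ToyMetadata(num_actual_tokens=128)
attn_metadata = {
    "model.layers.0.self_attn.attn": shared,
    "model.layers.1.self_attn.attn": shared,
}
assert attn_metadata["model.layers.1.self_attn.attn"] is shared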

vllm/v1/attention/backends/flash_attn.py

Lines changed: 5 additions & 6 deletions
@@ -18,6 +18,7 @@
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
 from vllm.utils import cdiv
+from vllm.v1.attention.backends.utils import CommonAttentionMetadata

 if TYPE_CHECKING:
     from vllm.v1.core.sched.output import SchedulerOutput
@@ -309,13 +310,11 @@ def reorder_batch(self, input_batch: "InputBatch",
         return False

     def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
-              common_prefix_len: int):
+              common_prefix_len: int,
+              common_attn_metadata: CommonAttentionMetadata):
         max_seq_len = self.runner.seq_lens_np[:num_reqs].max()
-        query_start_loc_cpu = self.runner.query_start_loc_cpu[:num_reqs + 1]
-        query_start_loc = query_start_loc_cpu.to(self.runner.device,
-                                                 non_blocking=True)
-        seq_lens_cpu = self.runner.seq_lens_cpu[:num_reqs]
-        seq_lens = seq_lens_cpu.to(self.runner.device, non_blocking=True)
+        query_start_loc = common_attn_metadata.query_start_loc
+        seq_lens = common_attn_metadata.seq_lens
         block_table = (
             self.runner.input_batch.block_table.get_device_tensor()[:num_reqs])
         slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to(

vllm/v1/attention/backends/flashinfer.py

Lines changed: 5 additions & 5 deletions
@@ -18,6 +18,7 @@
                          get_layers_from_vllm_config)
 from vllm.logger import init_logger
 from vllm.v1.attention.backends.flash_attn import use_cascade_attention
+from vllm.v1.attention.backends.utils import CommonAttentionMetadata

 if TYPE_CHECKING:
     from vllm.v1.core.sched.output import SchedulerOutput
@@ -394,16 +395,15 @@ def _plan(self, attn_metadata: FlashInferMetadata):
         )

     def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
-              common_prefix_len: int):
+              common_prefix_len: int,
+              common_attn_metadata: CommonAttentionMetadata):
         assert self._num_decodes + self._num_prefills == num_reqs
         assert (self._num_decode_tokens +
                 self._num_prefill_tokens == num_actual_tokens)
         page_size = self.runner.block_size
         device = self.runner.device
-        qo_indptr = self.runner.query_start_loc_cpu[:num_reqs + 1].to(
-            self.runner.device, non_blocking=True)
-        seq_lens = self.runner.seq_lens_cpu[:num_reqs].to(self.runner.device,
-                                                          non_blocking=True)
+        qo_indptr = common_attn_metadata.query_start_loc
+        seq_lens = common_attn_metadata.seq_lens
         block_table = (
             self.runner.input_batch.block_table.get_device_tensor()[:num_reqs])
         slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to(

vllm/v1/attention/backends/mla/common.py

Lines changed: 5 additions & 5 deletions
@@ -207,6 +207,7 @@
 from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
 from vllm.platforms import current_platform
 from vllm.utils import cdiv, round_down
+from vllm.v1.attention.backends.utils import CommonAttentionMetadata

 try:
     from vllm.vllm_flash_attn import flash_attn_varlen_func
@@ -451,7 +452,8 @@ def _build_decode(self, input_positions: torch.Tensor,
         )

     def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
-              common_prefix_len: int) -> M:
+              common_prefix_len: int,
+              common_attn_metadata: CommonAttentionMetadata) -> M:
         assert self._num_decodes + self._num_prefills == num_reqs

         # Note(simon): be careful about the CPU <> GPU memory movement in this
@@ -460,15 +462,13 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
         device = self.runner.device
         block_table = (
             self.runner.input_batch.block_table.get_device_tensor()[:num_reqs])
-        query_start_loc = self.runner.query_start_loc_cpu[:num_reqs + 1].to(
-            device, non_blocking=True)
         slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to(
             device, non_blocking=True).long()
         input_positions = self.runner.positions_cpu[:num_actual_tokens].to(
             device, non_blocking=True).long()

-        seq_lens_cpu = self.runner.seq_lens_cpu[:num_reqs]
-        seq_lens = seq_lens_cpu.to(device, non_blocking=True)
+        query_start_loc = common_attn_metadata.query_start_loc
+        seq_lens = common_attn_metadata.seq_lens

         prefill_metadata = None
         if self._num_prefills > 0:

vllm/v1/attention/backends/utils.py

Lines changed: 18 additions & 0 deletions
@@ -0,0 +1,18 @@
+# SPDX-License-Identifier: Apache-2.0
+from dataclasses import dataclass
+
+import torch
+
+
+@dataclass
+class CommonAttentionMetadata:
+    """
+    Attention metadata attributes that can be shared by layers in different KV
+    cache groups and thus having different block table.
+    """
+
+    query_start_loc: torch.Tensor
+    """(batch_size + 1,), the start location of each request in query Tensor"""
+    seq_lens: torch.Tensor
+    """(batch_size,), the length of each request including both computed tokens
+    and newly scheduled tokens"""
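
For illustration, the new dataclass can be constructed directly; the tensor values below are made up and would normally come from the runner's CPU buffers.

import torch

from vllm.v1.attention.backends.utils import CommonAttentionMetadata

# Three requests: cumulative query token offsets and per-request sequence lengths.
common_attn_metadata = CommonAttentionMetadata(
    query_start_loc=torch.tensor([0, 3, 7, 12], dtype=torch.int32),
    seq_lens=torch.tensor([10, 4, 25], dtype=torch.int32))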

vllm/v1/spec_decode/eagle.py

Lines changed: 10 additions & 1 deletion
@@ -2,7 +2,9 @@
 import torch
 import torch.nn as nn

-from vllm.config import CompilationLevel, VllmConfig, set_current_vllm_config
+from vllm.attention.layer import Attention
+from vllm.config import (CompilationLevel, VllmConfig,
+                         get_layers_from_vllm_config, set_current_vllm_config)
 from vllm.forward_context import set_forward_context
 from vllm.logger import init_logger
 from vllm.model_executor.model_loader.loader import get_model_loader
@@ -276,6 +278,8 @@ def load_model(self, target_model: nn.Module) -> None:
         loader = get_model_loader(self.vllm_config.load_config)
         target_layer_num = self.vllm_config.model_config.get_num_layers(
             self.vllm_config.parallel_config)
+        target_attn_layer_names = set(
+            get_layers_from_vllm_config(self.vllm_config, Attention).keys())

         draft_model_config = \
             self.vllm_config.speculative_config.draft_model_config
@@ -292,6 +296,11 @@ def load_model(self, target_model: nn.Module) -> None:
             vllm_config=self.vllm_config,
             start_layer_id=target_layer_num).to(target_device)

+        draft_attn_layer_names = (
+            get_layers_from_vllm_config(self.vllm_config, Attention).keys() -
+            target_attn_layer_names)
+        assert len(draft_attn_layer_names) == 1
+        self.attn_layer_name = next(iter(draft_attn_layer_names))
         loaded_weights = self.model.load_weights(
             loader.get_all_weights(draft_model_config, self.model))
         if self.vllm_config.speculative_config.method == "eagle3":
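
The set difference above is how the drafter finds its single attention layer: every Attention layer registered after loading the draft model, minus the layers that belonged to the target model. A toy illustration with fabricated layer names:

# Fabricated layer names, not the real registry contents.
target_attn_layer_names = {"model.layers.0.self_attn.attn",
                           "model.layers.1.self_attn.attn"}
all_attn_layer_names = target_attn_layer_names | {"drafter.layers.0.self_attn.attn"}

draft_attn_layer_names = all_attn_layer_names - target_attn_layer_names
assert len(draft_attn_layer_names) == 1
attn_layer_name = next(iter(draft_attn_layer_names))  # the draft head's layer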

vllm/v1/worker/gpu_model_runner.py

Lines changed: 47 additions & 21 deletions
@@ -30,6 +30,7 @@
                         GiB_bytes, LayerBlockType, LazyLoader, cdiv,
                         check_use_alibi, is_pin_memory_available)
 from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
+from vllm.v1.attention.backends.utils import CommonAttentionMetadata
 from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
 from vllm.v1.kv_cache_interface import (AttentionSpec, FullAttentionSpec,
                                         KVCacheConfig, KVCacheSpec,
@@ -157,9 +158,12 @@ def __init__(
         # Sampler
         self.sampler = Sampler()

-        # Lazy initialization
+        # Lazy initializations
         # self.model: nn.Module  # Set after load_model
+        # Initialize in initialize_kv_cache
         self.kv_caches: list[torch.Tensor] = []
+        # self.kv_cache_config: KVCacheConfig
+
         # req_id -> (input_id -> encoder_output)
         self.encoder_cache: dict[str, dict[int, torch.Tensor]] = {}
@@ -488,7 +492,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
     def _prepare_inputs(
         self,
         scheduler_output: "SchedulerOutput",
-    ) -> tuple[FlashAttentionMetadata, torch.Tensor,
+    ) -> tuple[dict[str, FlashAttentionMetadata], torch.Tensor,
               Optional[SpecDecodeMetadata]]:
         total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
         assert total_num_scheduled_tokens > 0
@@ -585,20 +589,39 @@ def _prepare_inputs(
             self.positions_cpu[:total_num_scheduled_tokens],
             non_blocking=True)

-        # Prepare for cascade attention if enabled & beneficial.
-        common_prefix_len = 0
-        if self.cascade_attn_enabled:
-            common_prefix_len = self._compute_cascade_attn_prefix_len(
-                num_scheduled_tokens,
-                scheduler_output.num_common_prefix_blocks,
-            )
+        query_start_loc = self.query_start_loc_cpu[:num_reqs + 1].to(
+            self.device, non_blocking=True)
+        seq_lens = self.seq_lens_cpu[:num_reqs].to(self.device,
+                                                   non_blocking=True)
+        common_attn_metadata = CommonAttentionMetadata(
+            query_start_loc=query_start_loc, seq_lens=seq_lens)
+
+        attn_metadata: dict[str, FlashAttentionMetadata] = {}
+        # Prepare the attention metadata for each KV cache group and make layers
+        # in the same group share the same metadata.
+        # NOTE(Chen): there is exactly one KV cache group that contains all
+        # attetnion layers in the model for now, so the current logic for
+        # getting attn_metadata is not related to kv_cache_group information.
+        # Will extend this part to support multiple KV cache groups later.
+        for kv_cache_group_id, kv_cache_group_spec in enumerate(
+                self.kv_cache_config.kv_cache_groups):
+
+            # Prepare for cascade attention if enabled & beneficial.
+            common_prefix_len = 0
+            if self.cascade_attn_enabled:
+                common_prefix_len = self._compute_cascade_attn_prefix_len(
+                    num_scheduled_tokens,
+                    scheduler_output.num_common_prefix_blocks,
+                )

-        attn_metadata = self.attn_metadata_builder.build(
-            num_reqs=num_reqs,
-            num_actual_tokens=total_num_scheduled_tokens,
-            max_query_len=max_num_scheduled_tokens,
-            common_prefix_len=common_prefix_len,
-        )
+            attn_metadata_i = self.attn_metadata_builder.build(
+                num_reqs=num_reqs,
+                num_actual_tokens=total_num_scheduled_tokens,
+                max_query_len=max_num_scheduled_tokens,
+                common_prefix_len=common_prefix_len,
+                common_attn_metadata=common_attn_metadata)
+            for layer_name in kv_cache_group_spec.layer_names:
+                attn_metadata[layer_name] = attn_metadata_i

         use_spec_decode = len(
             scheduler_output.scheduled_spec_decode_tokens) > 0
@@ -608,7 +631,7 @@ def _prepare_inputs(
             # from these partial requests, we do so for simplicity.
             # We will ignore the sampled tokens from the partial requests.
             # TODO: Support prompt logprobs.
-            logits_indices = attn_metadata.query_start_loc[1:] - 1
+            logits_indices = query_start_loc[1:] - 1
             spec_decode_metadata = None
         else:
             # Get the number of draft tokens for each request.
@@ -1230,6 +1253,7 @@ def execute_model(
             next_token_ids = torch.tensor(next_token_ids,
                                           dtype=torch.int32,
                                           device=self.device)
+            eagle_attn_metadata = attn_metadata[self.drafter.attn_layer_name]

             if spec_decode_metadata is None:
                 # input_ids can be None for multimodal models.
@@ -1241,8 +1265,8 @@ def execute_model(
                         dim=-1)
                 else:
                     target_hidden_states = hidden_states[:num_scheduled_tokens]
-                target_slot_mapping = attn_metadata.slot_mapping
-                cu_num_tokens = attn_metadata.query_start_loc
+                target_slot_mapping = eagle_attn_metadata.slot_mapping
+                cu_num_tokens = eagle_attn_metadata.query_start_loc
             else:
                 # TODO(woosuk): Refactor this.
                 num_draft_tokens = spec_decode_metadata.num_draft_tokens
@@ -1256,7 +1280,7 @@ def execute_model(
                     device=self.device,
                 )
                 cu_num_tokens, token_indices = self.drafter.prepare_inputs(
-                    attn_metadata.query_start_loc,
+                    eagle_attn_metadata.query_start_loc,
                     num_rejected_tokens,
                 )
                 target_token_ids = self.input_ids[token_indices]
@@ -1266,7 +1290,8 @@ def execute_model(
                         [h[token_indices] for h in aux_hidden_states], dim=-1)
                 else:
                     target_hidden_states = hidden_states[token_indices]
-                target_slot_mapping = attn_metadata.slot_mapping[token_indices]
+                target_slot_mapping = eagle_attn_metadata.slot_mapping[
+                    token_indices]

             draft_token_ids = self.drafter.propose(
                 target_token_ids=target_token_ids,
@@ -1275,7 +1300,7 @@ def execute_model(
                 target_slot_mapping=target_slot_mapping,
                 next_token_ids=next_token_ids,
                 cu_num_tokens=cu_num_tokens,
-                block_table=attn_metadata.block_table,
+                block_table=eagle_attn_metadata.block_table,
                 sampling_metadata=sampling_metadata,
             )
             spec_token_ids = draft_token_ids.tolist()
@@ -1708,6 +1733,7 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
                 raise NotImplementedError(
                     "Hybrid models with more than one KV cache type are not "
                     "supported yet.")
+        self.kv_cache_config = kv_cache_config

         kv_caches: dict[str, torch.Tensor] = {}