
[v1] AttentionMetadata for each layer #17394


Merged
merged 22 commits on May 6, 2025
15 changes: 12 additions & 3 deletions vllm/attention/layer.py
@@ -210,6 +210,8 @@ def forward(
if self.use_direct_call:
forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata
if isinstance(attn_metadata, dict):
attn_metadata = attn_metadata[self.layer_name]
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
self.impl.forward(self,
query,
@@ -226,6 +228,8 @@ def forward(
if self.use_direct_call:
forward_context = get_forward_context()
attn_metadata = forward_context.attn_metadata
if isinstance(attn_metadata, dict):
attn_metadata = attn_metadata[self.layer_name]
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
return self.impl.forward(self, query, key, value,
self_kv_cache, attn_metadata)
@@ -343,7 +347,7 @@ def wait_for_kv_layer_from_connector(layer_name: str):
attn_metadata = forward_context.attn_metadata
if attn_metadata is None:
return

assert isinstance(attn_metadata, dict)
connector.wait_for_layer_load(layer_name)


@@ -360,8 +364,9 @@ def maybe_save_kv_layer_to_connector(
attn_metadata = forward_context.attn_metadata
if attn_metadata is None:
return

connector.save_kv_layer(layer_name, kv_cache_layer, attn_metadata)
assert isinstance(attn_metadata, dict)
connector.save_kv_layer(layer_name, kv_cache_layer,
attn_metadata[layer_name])


def unified_attention(
@@ -374,6 +379,8 @@ def unified_attention(

forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata
if isinstance(attn_metadata, dict):
attn_metadata = attn_metadata[layer_name]
self = forward_context.no_compile_layers[layer_name]
kv_cache = self.kv_cache[forward_context.virtual_engine]
output = self.impl.forward(self, query, key, value, kv_cache,
@@ -411,6 +418,8 @@ def unified_attention_with_output(
wait_for_kv_layer_from_connector(layer_name)
forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata
if isinstance(attn_metadata, dict):
attn_metadata = attn_metadata[layer_name]
self = forward_context.no_compile_layers[layer_name]
kv_cache = self.kv_cache[forward_context.virtual_engine]
self.impl.forward(self,
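
For reference, a minimal sketch (not part of this PR) of the lookup pattern these hunks repeat inline; only the dict-vs-object shape and the layer_name key are taken from the diff, and the helper name is made up.

from vllm.forward_context import get_forward_context


def get_layer_attn_metadata(layer_name: str):
    """Return this layer's attention metadata for both the v0 and v1 layouts."""
    attn_metadata = get_forward_context().attn_metadata
    if isinstance(attn_metadata, dict):
        # v1: dict keyed by each attention layer's name.
        return attn_metadata[layer_name]
    # v0: a single shared object (it can also be None, e.g. outside a real
    # forward pass, as the connector helpers above check).
    return attn_metadata
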
11 changes: 8 additions & 3 deletions vllm/forward_context.py
@@ -4,7 +4,7 @@
from collections import defaultdict
from contextlib import contextmanager
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Any, Optional, Union

import torch
import torch.distributed as dist
@@ -38,8 +38,13 @@ class DPMetadata:
class ForwardContext:
# copy from vllm_config.compilation_config.static_forward_context
no_compile_layers: dict[str, Any]
# TODO: extend to support per-layer dynamic forward context
attn_metadata: "AttentionMetadata" # set dynamically for each forward pass
"""
Type AttentionMetadata for v0,
Type Dict[str, AttentionMetadata] for v1, map from layer_name of each
attention layer to its attention metadata
set dynamically for each forward pass
"""
attn_metadata: Union["AttentionMetadata", dict[str, "AttentionMetadata"]]
# TODO: remove after making all virtual_engines share the same kv cache
virtual_engine: int # set dynamically for each forward pass
# set dynamically for each forward pass
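
A small illustration (with made-up layer names) of the two layouts the Union annotation above allows; layers that share a KV cache group map to the same metadata object, per the gpu_model_runner.py changes below.

shared_metadata = object()  # stand-in for a real AttentionMetadata instance

# v0: one object for the whole forward pass.
attn_metadata_v0 = shared_metadata

# v1: layer_name -> metadata; layers in the same KV cache group share one object.
attn_metadata_v1 = {
    "model.layers.0.self_attn.attn": shared_metadata,
    "model.layers.1.self_attn.attn": shared_metadata,
}
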
11 changes: 5 additions & 6 deletions vllm/v1/attention/backends/flash_attn.py
@@ -18,6 +18,7 @@
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import cdiv
from vllm.v1.attention.backends.utils import CommonAttentionMetadata

if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
@@ -309,13 +310,11 @@ def reorder_batch(self, input_batch: "InputBatch",
return False

def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
common_prefix_len: int):
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata):
max_seq_len = self.runner.seq_lens_np[:num_reqs].max()
query_start_loc_cpu = self.runner.query_start_loc_cpu[:num_reqs + 1]
query_start_loc = query_start_loc_cpu.to(self.runner.device,
non_blocking=True)
seq_lens_cpu = self.runner.seq_lens_cpu[:num_reqs]
seq_lens = seq_lens_cpu.to(self.runner.device, non_blocking=True)
query_start_loc = common_attn_metadata.query_start_loc
seq_lens = common_attn_metadata.seq_lens
block_table = (
self.runner.input_batch.block_table.get_device_tensor()[:num_reqs])
slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to(
10 changes: 5 additions & 5 deletions vllm/v1/attention/backends/flashinfer.py
@@ -18,6 +18,7 @@
get_layers_from_vllm_config)
from vllm.logger import init_logger
from vllm.v1.attention.backends.flash_attn import use_cascade_attention
from vllm.v1.attention.backends.utils import CommonAttentionMetadata

if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
@@ -394,16 +395,15 @@ def _plan(self, attn_metadata: FlashInferMetadata):
)

def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
common_prefix_len: int):
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata):
assert self._num_decodes + self._num_prefills == num_reqs
assert (self._num_decode_tokens +
self._num_prefill_tokens == num_actual_tokens)
page_size = self.runner.block_size
device = self.runner.device
qo_indptr = self.runner.query_start_loc_cpu[:num_reqs + 1].to(
self.runner.device, non_blocking=True)
seq_lens = self.runner.seq_lens_cpu[:num_reqs].to(self.runner.device,
non_blocking=True)
qo_indptr = common_attn_metadata.query_start_loc
seq_lens = common_attn_metadata.seq_lens
block_table = (
self.runner.input_batch.block_table.get_device_tensor()[:num_reqs])
slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to(
10 changes: 5 additions & 5 deletions vllm/v1/attention/backends/mla/common.py
@@ -207,6 +207,7 @@
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
from vllm.platforms import current_platform
from vllm.utils import cdiv, round_down
from vllm.v1.attention.backends.utils import CommonAttentionMetadata

try:
from vllm.vllm_flash_attn import flash_attn_varlen_func
@@ -451,7 +452,8 @@ def _build_decode(self, input_positions: torch.Tensor,
)

def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
common_prefix_len: int) -> M:
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata) -> M:
assert self._num_decodes + self._num_prefills == num_reqs

# Note(simon): be careful about the CPU <> GPU memory movement in this
@@ -460,15 +462,13 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
device = self.runner.device
block_table = (
self.runner.input_batch.block_table.get_device_tensor()[:num_reqs])
query_start_loc = self.runner.query_start_loc_cpu[:num_reqs + 1].to(
device, non_blocking=True)
slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to(
device, non_blocking=True).long()
input_positions = self.runner.positions_cpu[:num_actual_tokens].to(
device, non_blocking=True).long()

seq_lens_cpu = self.runner.seq_lens_cpu[:num_reqs]
seq_lens = seq_lens_cpu.to(device, non_blocking=True)
query_start_loc = common_attn_metadata.query_start_loc
seq_lens = common_attn_metadata.seq_lens

prefill_metadata = None
if self._num_prefills > 0:
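
The flash_attn, flashinfer, and MLA builders all pick up the same change: build() now receives a CommonAttentionMetadata and reads query_start_loc and seq_lens from it instead of slicing and copying the runner's CPU tensors per backend. A toy sketch of the consuming side; ToyBackendMetadata and ToyMetadataBuilder are made up, only the build() signature and the two reused fields come from the diff.

from dataclasses import dataclass

import torch

from vllm.v1.attention.backends.utils import CommonAttentionMetadata


@dataclass
class ToyBackendMetadata:
    query_start_loc: torch.Tensor
    seq_lens: torch.Tensor
    num_actual_tokens: int


class ToyMetadataBuilder:
    def build(self, num_reqs: int, num_actual_tokens: int,
              max_query_len: int, common_prefix_len: int,
              common_attn_metadata: CommonAttentionMetadata):
        # The shared tensors are already on the target device, so the builder
        # no longer does its own CPU -> GPU copies for these fields.
        return ToyBackendMetadata(
            query_start_loc=common_attn_metadata.query_start_loc,
            seq_lens=common_attn_metadata.seq_lens,
            num_actual_tokens=num_actual_tokens)
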
18 changes: 18 additions & 0 deletions vllm/v1/attention/backends/utils.py
@@ -0,0 +1,18 @@
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass

import torch


@dataclass
class CommonAttentionMetadata:
"""
Attention metadata attributes that can be shared by layers in different KV
cache groups, even though those layers may have different block tables.
"""

query_start_loc: torch.Tensor
"""(batch_size + 1,), the start location of each request in query Tensor"""
seq_lens: torch.Tensor
"""(batch_size,), the length of each request including both computed tokens
and newly scheduled tokens"""
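
To make the two fields concrete, an assumed toy instance: three requests with 2, 1, and 4 newly scheduled tokens and total sequence lengths (computed plus newly scheduled) of 7, 3, and 12. In the runner these tensors live on the GPU; plain CPU tensors are used here only for illustration.

import torch

from vllm.v1.attention.backends.utils import CommonAttentionMetadata

common = CommonAttentionMetadata(
    query_start_loc=torch.tensor([0, 2, 3, 7], dtype=torch.int32),
    seq_lens=torch.tensor([7, 3, 12], dtype=torch.int32),
)
assert common.query_start_loc.numel() == common.seq_lens.numel() + 1
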
11 changes: 10 additions & 1 deletion vllm/v1/spec_decode/eagle.py
@@ -4,7 +4,9 @@
import triton
import triton.language as tl

from vllm.config import CompilationLevel, VllmConfig, set_current_vllm_config
from vllm.attention.layer import Attention
from vllm.config import (CompilationLevel, VllmConfig,
get_layers_from_vllm_config, set_current_vllm_config)
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.model_executor.model_loader.loader import get_model_loader
@@ -277,6 +279,8 @@ def load_model(self, target_model: nn.Module) -> None:
loader = get_model_loader(self.vllm_config.load_config)
target_layer_num = self.vllm_config.model_config.get_num_layers(
self.vllm_config.parallel_config)
target_attn_layer_names = set(
get_layers_from_vllm_config(self.vllm_config, Attention).keys())

draft_model_config = \
self.vllm_config.speculative_config.draft_model_config
@@ -293,6 +297,11 @@ def load_model(self, target_model: nn.Module) -> None:
vllm_config=self.vllm_config,
start_layer_id=target_layer_num).to(target_device)

draft_attn_layer_names = (
get_layers_from_vllm_config(self.vllm_config, Attention).keys() -
target_attn_layer_names)
assert len(draft_attn_layer_names) == 1
self.attn_layer_name = next(iter(draft_attn_layer_names))
loaded_weights = self.model.load_weights(
loader.get_all_weights(draft_model_config, self.model))
if self.vllm_config.speculative_config.method == "eagle3":
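
The EAGLE drafter finds its own attention layer by set difference: the layer names recorded before the draft model is built belong to the target model, so whatever appears afterwards must be the single draft attention layer. A sketch with made-up names; only the set difference and the single-layer assertion come from the diff.

target_attn_layer_names = {
    "model.layers.0.self_attn.attn",
    "model.layers.1.self_attn.attn",
}  # recorded before the draft model is loaded
all_attn_layer_names = target_attn_layer_names | {
    "draft_model.layers.0.self_attn.attn",
}  # what get_layers_from_vllm_config(...) would return afterwards

draft_attn_layer_names = all_attn_layer_names - target_attn_layer_names
assert len(draft_attn_layer_names) == 1
attn_layer_name = next(iter(draft_attn_layer_names))
# attn_layer_name later selects the draft layer's entry from the per-layer
# attn_metadata dict (see gpu_model_runner.py below).
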
68 changes: 47 additions & 21 deletions vllm/v1/worker/gpu_model_runner.py
@@ -30,6 +30,7 @@
GiB_bytes, LayerBlockType, LazyLoader, cdiv,
check_use_alibi, is_pin_memory_available)
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
from vllm.v1.kv_cache_interface import (AttentionSpec, FullAttentionSpec,
KVCacheConfig, KVCacheSpec,
@@ -157,9 +158,12 @@ def __init__(
# Sampler
self.sampler = Sampler()

# Lazy initialization
# Lazy initializations
# self.model: nn.Module # Set after load_model
# Initialize in initialize_kv_cache
self.kv_caches: list[torch.Tensor] = []
# self.kv_cache_config: KVCacheConfig

# req_id -> (input_id -> encoder_output)
self.encoder_cache: dict[str, dict[int, torch.Tensor]] = {}

@@ -488,7 +492,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
def _prepare_inputs(
self,
scheduler_output: "SchedulerOutput",
) -> tuple[FlashAttentionMetadata, torch.Tensor,
) -> tuple[dict[str, FlashAttentionMetadata], torch.Tensor,
Optional[SpecDecodeMetadata]]:
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
assert total_num_scheduled_tokens > 0
@@ -585,20 +589,39 @@ def _prepare_inputs(
self.positions_cpu[:total_num_scheduled_tokens],
non_blocking=True)

# Prepare for cascade attention if enabled & beneficial.
common_prefix_len = 0
if self.cascade_attn_enabled:
common_prefix_len = self._compute_cascade_attn_prefix_len(
num_scheduled_tokens,
scheduler_output.num_common_prefix_blocks,
)
query_start_loc = self.query_start_loc_cpu[:num_reqs + 1].to(
self.device, non_blocking=True)
seq_lens = self.seq_lens_cpu[:num_reqs].to(self.device,
non_blocking=True)
common_attn_metadata = CommonAttentionMetadata(
query_start_loc=query_start_loc, seq_lens=seq_lens)

attn_metadata: dict[str, FlashAttentionMetadata] = {}
# Prepare the attention metadata for each KV cache group and make layers
# in the same group share the same metadata.
# NOTE(Chen): there is exactly one KV cache group that contains all
# attention layers in the model for now, so the current logic for
# getting attn_metadata is not related to kv_cache_group information.
# Will extend this part to support multiple KV cache groups later.
for kv_cache_group_id, kv_cache_group_spec in enumerate(
self.kv_cache_config.kv_cache_groups):

# Prepare for cascade attention if enabled & beneficial.
common_prefix_len = 0
if self.cascade_attn_enabled:
common_prefix_len = self._compute_cascade_attn_prefix_len(
num_scheduled_tokens,
scheduler_output.num_common_prefix_blocks,
)

attn_metadata = self.attn_metadata_builder.build(
num_reqs=num_reqs,
num_actual_tokens=total_num_scheduled_tokens,
max_query_len=max_num_scheduled_tokens,
common_prefix_len=common_prefix_len,
)
attn_metadata_i = self.attn_metadata_builder.build(
num_reqs=num_reqs,
num_actual_tokens=total_num_scheduled_tokens,
max_query_len=max_num_scheduled_tokens,
common_prefix_len=common_prefix_len,
common_attn_metadata=common_attn_metadata)
for layer_name in kv_cache_group_spec.layer_names:
attn_metadata[layer_name] = attn_metadata_i

use_spec_decode = len(
scheduler_output.scheduled_spec_decode_tokens) > 0
@@ -608,7 +631,7 @@ def _prepare_inputs(
# from these partial requests, we do so for simplicity.
# We will ignore the sampled tokens from the partial requests.
# TODO: Support prompt logprobs.
logits_indices = attn_metadata.query_start_loc[1:] - 1
logits_indices = query_start_loc[1:] - 1
spec_decode_metadata = None
else:
# Get the number of draft tokens for each request.
@@ -1230,6 +1253,7 @@ def execute_model(
next_token_ids = torch.tensor(next_token_ids,
dtype=torch.int32,
device=self.device)
eagle_attn_metadata = attn_metadata[self.drafter.attn_layer_name]

if spec_decode_metadata is None:
# input_ids can be None for multimodal models.
@@ -1241,8 +1265,8 @@
dim=-1)
else:
target_hidden_states = hidden_states[:num_scheduled_tokens]
target_slot_mapping = attn_metadata.slot_mapping
cu_num_tokens = attn_metadata.query_start_loc
target_slot_mapping = eagle_attn_metadata.slot_mapping
cu_num_tokens = eagle_attn_metadata.query_start_loc
else:
# TODO(woosuk): Refactor this.
num_draft_tokens = spec_decode_metadata.num_draft_tokens
@@ -1256,7 +1280,7 @@
device=self.device,
)
cu_num_tokens, token_indices = self.drafter.prepare_inputs(
attn_metadata.query_start_loc,
eagle_attn_metadata.query_start_loc,
num_rejected_tokens,
)
target_token_ids = self.input_ids[token_indices]
@@ -1266,7 +1290,8 @@
[h[token_indices] for h in aux_hidden_states], dim=-1)
else:
target_hidden_states = hidden_states[token_indices]
target_slot_mapping = attn_metadata.slot_mapping[token_indices]
target_slot_mapping = eagle_attn_metadata.slot_mapping[
token_indices]

draft_token_ids = self.drafter.propose(
target_token_ids=target_token_ids,
@@ -1275,7 +1300,7 @@
target_slot_mapping=target_slot_mapping,
next_token_ids=next_token_ids,
cu_num_tokens=cu_num_tokens,
block_table=attn_metadata.block_table,
block_table=eagle_attn_metadata.block_table,
sampling_metadata=sampling_metadata,
)
spec_token_ids = draft_token_ids.tolist()
@@ -1708,6 +1733,7 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
raise NotImplementedError(
"Hybrid models with more than one KV cache type are not "
"supported yet.")
self.kv_cache_config = kv_cache_config

kv_caches: dict[str, torch.Tensor] = {}

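
A condensed sketch (hypothetical names and stand-in objects) of what _prepare_inputs now produces: one metadata object is built per KV cache group and fanned out to every layer in that group, so layers in the same group share the identical object rather than copies.

attn_metadata: dict[str, object] = {}
kv_cache_groups = [  # stand-in for self.kv_cache_config.kv_cache_groups
    ["model.layers.0.self_attn.attn", "model.layers.1.self_attn.attn"],
]
for layer_names in kv_cache_groups:
    # One build() call per group ...
    attn_metadata_i = object()  # stand-in for self.attn_metadata_builder.build(...)
    # ... shared by every layer in the group.
    for layer_name in layer_names:
        attn_metadata[layer_name] = attn_metadata_i

# Layers in the same group point at the same object, not a copy.
assert len({id(m) for m in attn_metadata.values()}) == len(kv_cache_groups)
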