3 changes: 2 additions & 1 deletion tests/v1/spec_decode/test_eagle.py
@@ -313,7 +313,8 @@ def create_deterministic_logits(token_ids):

# Mock runner for attention metadata building
proposer.runner = mock.MagicMock()
proposer.runner.attn_metadata_builders = [attn_metadata_builder]
proposer.runner.attn_groups.append([mock.MagicMock()])
proposer.runner.attn_groups[0][0].metadata_builder = attn_metadata_builder

result = proposer.propose(target_token_ids=target_token_ids,
target_positions=target_positions,
6 changes: 3 additions & 3 deletions tests/v1/worker/test_gpu_model_runner.py
@@ -417,12 +417,12 @@ def rnd_stride_order():
return rnd_stride

# Patch the attention backend class and re-trigger the KV cache creation.
for attn_backend in model_runner.attn_backends:
for attn_group in model_runner._attn_group_iterator():
attn_backend = attn_group.backend
monkeypatch.setattr(attn_backend, "get_kv_cache_stride_order",
rnd_stride_order)

model_runner.attn_backends = []
model_runner.attn_metadata_builders = []
model_runner.attn_groups = []
model_runner.initialize_kv_cache(model_runner.kv_cache_config)

# Shape is unchanged, but layout may differ
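
Note: both test diffs above exercise the runner's new `attn_groups` container, which replaces the old parallel `attn_backends` / `attn_metadata_builders` lists. A minimal sketch of the structure these accesses imply — the attribute names come from the diffs, but the exact class definition and iterator in the GPU model runner are assumptions here:

```python
from dataclasses import dataclass
from typing import Any, Iterator


@dataclass
class AttentionGroup:
    # Assumed shape; the real dataclass lives in the GPU model runner.
    backend: Any           # AttentionBackend class used by this group
    metadata_builder: Any  # AttentionMetadataBuilder instance for this group


class RunnerSketch:

    def __init__(self) -> None:
        # Outer list: one entry per KV cache group; inner list: one entry per
        # attention backend within that group. Hence the
        # `attn_groups[0][0].metadata_builder` accesses seen here and in the
        # eagle.py / cpu_model_runner.py diffs below.
        self.attn_groups: list[list[AttentionGroup]] = []

    def _attn_group_iterator(self) -> Iterator[AttentionGroup]:
        # Flattened iteration, mirroring the loop in test_gpu_model_runner.py.
        for group_list in self.attn_groups:
            yield from group_list
```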
4 changes: 4 additions & 0 deletions vllm/attention/backends/abstract.py
@@ -106,6 +106,10 @@ def advance_step(self, model_input: "ModelRunnerInputBase",
block_size: int, num_seqs: int, num_queries: int) -> None:
raise NotImplementedError

@classmethod
def full_cls_name(cls) -> tuple[str, str]:
return (cls.__module__, cls.__qualname__)


@dataclass
class AttentionMetadata:
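
Note: the new `full_cls_name` helper simply reports where a backend class is defined, as a (module, qualified name) pair. A quick illustration — the dummy subclass below is hypothetical and exists only for this example (it is never instantiated, so the abstract methods can stay unimplemented):

```python
from vllm.attention.backends.abstract import AttentionBackend


class _DummyBackend(AttentionBackend):
    # Illustrative only; never instantiated.
    pass


print(_DummyBackend.full_cls_name())
# -> ('__main__', '_DummyBackend') when run as a script
```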
36 changes: 18 additions & 18 deletions vllm/attention/layer.py
@@ -9,6 +9,7 @@

import vllm.envs as envs
from vllm.attention import AttentionType
from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.selector import backend_name_to_enum, get_attn_backend
from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target
from vllm.config import CacheConfig, get_current_vllm_config
@@ -80,6 +81,7 @@ def __init__(
prefix: str = "",
attn_type: str = AttentionType.DECODER,
kv_sharing_target_layer_name: Optional[str] = None,
attn_backend: Optional[type[AttentionBackend]] = None,
Collaborator:
What is the scenario for passing attn_backend explicitly?

Collaborator (Author):
For class ChunkedLocalAttention(Attention), where we wrap the attention backend in a LocalAttention wrapper.

I think in the future we can modularize this better to make subclasses like ChunkedLocalAttention cleaner, but I'm a bit scared to refactor it too heavily before V0 is fully deprecated.

**extra_impl_args,
) -> None:
"""
@@ -137,15 +139,6 @@ def __init__(
self.num_kv_heads = num_kv_heads
self.sliding_window = sliding_window

# For v1 we have backend agnostic iRoPE (local chunked attention)
# we have to store the flag on the layer so gpu model runner can
# set KVSpec appropriately (and pop it so it doesnt get passed to
# the backends)
if envs.VLLM_USE_V1:
self.use_irope = extra_impl_args.pop("use_irope", False)
else:
self.use_irope = extra_impl_args.get("use_irope", False)

quant_method = quant_config.get_quant_method(
self, prefix=prefix) if quant_config else None
if quant_method is not None and not isinstance(
@@ -166,18 +159,22 @@
# During model initialization, the default dtype is set as the model
# weight and activation dtype.
dtype = torch.get_default_dtype()
attn_backend = get_attn_backend(head_size,
dtype,
kv_cache_dtype,
block_size,
is_attention_free,
use_mla=use_mla)
impl_cls = attn_backend.get_impl_cls()
if attn_backend is None:
self.attn_backend = get_attn_backend(head_size,
dtype,
kv_cache_dtype,
block_size,
is_attention_free,
use_mla=use_mla)
else:
self.attn_backend = attn_backend

impl_cls = self.attn_backend.get_impl_cls()
self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads,
alibi_slopes, sliding_window, kv_cache_dtype,
logits_soft_cap, attn_type,
kv_sharing_target_layer_name, **extra_impl_args)
self.backend = backend_name_to_enum(attn_backend.get_name())
self.backend = backend_name_to_enum(self.attn_backend.get_name())
self.dtype = dtype

# For cuda-alike (CUDA and ROCM) and cpu platforms, we control how
@@ -187,7 +184,7 @@ def __init__(
self.use_direct_call = not current_platform.is_cuda_alike(
) and not current_platform.is_cpu()

self.use_output = attn_backend.accept_output_buffer
self.use_output = self.attn_backend.accept_output_buffer
compilation_config = get_current_vllm_config().compilation_config
if prefix in compilation_config.static_forward_context:
raise ValueError(f"Duplicate layer name: {prefix}")
@@ -309,6 +306,9 @@ def process_weights_after_loading(self, act_dtype: torch.dtype):
if hasattr(self.impl, "process_weights_after_loading"):
self.impl.process_weights_after_loading(act_dtype)

def get_attn_backend(self) -> type[AttentionBackend]:
return self.attn_backend


class MultiHeadAttention(nn.Module):
"""Multi-headed attention without any cache, used for ViT."""
88 changes: 88 additions & 0 deletions vllm/attention/layers/chunked_local_attention.py
@@ -0,0 +1,88 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
from typing import List, Optional

import torch

from vllm import envs
from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.selector import get_attn_backend
from vllm.config import CacheConfig, QuantizationConfig
from vllm.v1.attention.backends.utils import (
    CommonAttentionMetadata, make_local_attention_virtual_batches,
    subclass_attention_backend, subclass_attention_metadata_builder)

from ..layer import Attention


@functools.lru_cache
def create_chunked_local_attention_backend(
        underlying_attn_backend: AttentionBackend,
        attention_chunk_size: int,
        block_size: int,
) -> type[AttentionBackend]:
    prefix = f"ChunkedLocalAttention_{attention_chunk_size}_{block_size}_"

    def build_preprocess_fn(cm: CommonAttentionMetadata):
        return make_local_attention_virtual_batches(attention_chunk_size, cm,
                                                    block_size)

    # Dynamically create a new attention backend that wraps the
    # underlying attention backend but applies
    # `make_local_attention_virtual_batches` before calling `build(...)`
    builder_cls = subclass_attention_metadata_builder(
        name_prefix=prefix,
        builder_cls=underlying_attn_backend.get_builder_cls(),
        build_preprocess_fn=build_preprocess_fn)
    attn_backend = subclass_attention_backend(
        name_prefix=prefix,
        attention_backend_cls=underlying_attn_backend,
        builder_cls=builder_cls)

    return attn_backend


class ChunkedLocalAttention(Attention):

    def __init__(self,
                 num_heads: int,
                 head_size: int,
                 scale: float,
                 attention_chunk_size: int,
                 num_kv_heads: Optional[int] = None,
                 alibi_slopes: Optional[List[float]] = None,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 kv_sharing_target_layer_name: Optional[str] = None,
                 prefix: str = ""):
Collaborator:
I know we don't need this currently, but can we add kv_sharing_target_layer_name as an arg to be feature complete with the base Attention layer?

        dtype = torch.get_default_dtype()
        if cache_config is not None:
            kv_cache_dtype = cache_config.cache_dtype
            block_size = cache_config.block_size
        else:
            kv_cache_dtype = "auto"
            block_size = 16

        if envs.VLLM_USE_V1:
            underlying_attn_backend = get_attn_backend(head_size, dtype,
                                                       kv_cache_dtype,
                                                       block_size)

            attn_backend = create_chunked_local_attention_backend(
                underlying_attn_backend, attention_chunk_size, block_size)
        else:
            # in v0 the local attention is handled inside the backends
            attn_backend = None

        super().__init__(
            num_heads=num_heads,
            head_size=head_size,
            scale=scale,
            num_kv_heads=num_kv_heads,
            alibi_slopes=alibi_slopes,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=prefix,
            kv_sharing_target_layer_name=kv_sharing_target_layer_name,
            attn_backend=attn_backend)
2 changes: 1 addition & 1 deletion vllm/attention/selector.py
@@ -142,7 +142,7 @@ def get_attn_backend(
dtype: torch.dtype,
kv_cache_dtype: Optional[str],
block_size: int,
is_attention_free: bool,
is_attention_free: bool = False,
use_mla: bool = False,
) -> type[AttentionBackend]:
"""Selects which attention backend to use and lazily imports it."""
10 changes: 6 additions & 4 deletions vllm/model_executor/models/llama4.py
@@ -25,6 +25,7 @@
from transformers import Llama4TextConfig

from vllm.attention import Attention
from vllm.attention.layers.chunked_local_attention import ChunkedLocalAttention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
@@ -194,17 +195,18 @@ def __init__(self,
is_neox_style=is_neox_style,
) if not self.nope else None

self.attn = Attention(
attn_cls = Attention if self.nope else ChunkedLocalAttention
self.attn = attn_cls(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config,
per_layer_sliding_window=None,
use_irope=not self.nope,
prefix=f"{prefix}.attn",
)
**({
"attention_chunk_size": config.attention_chunk_size
} if not self.nope else {}))

def _get_attn_scale(self, positions: torch.Tensor) -> torch.Tensor:
floor = torch.floor((positions + 1.0) / self.floor_scale)
48 changes: 46 additions & 2 deletions vllm/v1/attention/backends/utils.py
@@ -5,12 +5,12 @@
import functools
from abc import abstractmethod
from dataclasses import dataclass, make_dataclass
from typing import TYPE_CHECKING, Any, ClassVar, Generic, Optional, TypeVar
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Generic, Optional,
TypeVar)

import numpy as np
import torch

from vllm.attention.layer import Attention
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.utils import cdiv

@@ -20,6 +20,8 @@
from vllm.v1.worker.gpu_input_batch import InputBatch

import vllm.envs as envs
from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.layer import Attention
from vllm.distributed.kv_transfer.kv_connector.utils import (
get_kv_connector_cache_layout)
from vllm.logger import init_logger
@@ -522,6 +524,48 @@ def make_local_attention_virtual_batches(
)


def subclass_attention_metadata_builder(
        name_prefix: str,
        builder_cls: type[AttentionMetadataBuilder[M]],
        build_preprocess_fn: Callable[[CommonAttentionMetadata],
                                      CommonAttentionMetadata],
) -> type[AttentionMetadataBuilder[M]]:
    """
    Return a new subclass of `builder_cls` whose .build(...) method
    first calls build_preprocess_fn(common_attn_metadata) on the metadata.
    """
    name: str = name_prefix + builder_cls.__name__  # type: ignore

    def build(self,
              common_prefix_len: int,
              common_attn_metadata: CommonAttentionMetadata,
              fast_build: bool = False):
        return builder_cls.build(self, common_prefix_len,
                                 build_preprocess_fn(common_attn_metadata),
Collaborator:
Should we also add a build_postprocess_fn? I think #21590 needs it for adding some special attribute to YOCO layers. Also CC @sarckk.
And @sarckk, remember to update #21590 to a cleaner approach after this PR lands.

Collaborator (Author):
Yeah, that's a good idea; let's land the build_postprocess_fn with the YOCO clean-up, since it's hard for me to test otherwise.

Collaborator:
Sounds good!

sarckk (Collaborator), Jul 30, 2025:
build_postprocess_fn for YOCO would involve passing logits_indices each iteration, so we cannot build it once at Attention layer init. Options I can think of:

  1. Keep the postprocess in the GPU model runner (what we have currently)
  2. Add build_postprocess_fn as an arg to build(...) (would affect all build impls)
  3. Add the extra fields to CommonAttentionMetadata that are only used by fast-prefill KV sharing layers

Collaborator (Author):
OK, I think I understand; yeah, that's tough. Let me think about it for a bit.
                                 fast_build)

    Wrapped = type(
        name,
        (builder_cls, ),  # inherit from the original
        {
            "build": build,
        })
    return Wrapped  # type: ignore


def subclass_attention_backend(
        name_prefix: str, attention_backend_cls: type[AttentionBackend],
        builder_cls: type[AttentionMetadataBuilder[M]]
) -> type[AttentionBackend]:
    """
    Return a new subclass where `get_builder_cls` returns `builder_cls`.
    """
    name: str = name_prefix + attention_backend_cls.__name__  # type: ignore

    return type(name, (attention_backend_cls, ),
                {"get_builder_cls": lambda: builder_cls})


def split_decodes_and_prefills(
common_attn_metadata: CommonAttentionMetadata,
decode_threshold: int = 1,
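
Note: to make the build_postprocess_fn idea from the thread above concrete, here is a minimal sketch of the hook the first reviewer suggested, captured when the builder subclass is created. It is not part of this PR, the helper name and hook placement are assumptions, and as the thread points out this init-time form would not cover the YOCO case on its own, because `logits_indices` changes every iteration:

```python
from typing import Any, Callable, Optional


def subclass_builder_with_postprocess(
        name_prefix: str,
        builder_cls: type,
        build_preprocess_fn: Callable[[Any], Any],
        build_postprocess_fn: Optional[Callable[[Any], Any]] = None,
) -> type:
    """Hypothetical variant of subclass_attention_metadata_builder that also
    post-processes the metadata returned by the wrapped builder."""
    name = name_prefix + builder_cls.__name__

    def build(self,
              common_prefix_len: int,
              common_attn_metadata: Any,
              fast_build: bool = False):
        metadata = builder_cls.build(self, common_prefix_len,
                                     build_preprocess_fn(common_attn_metadata),
                                     fast_build)
        # Hook for attaching extra attributes to the built metadata (e.g. for
        # KV-sharing layers); per-step inputs would still need another path.
        if build_postprocess_fn is not None:
            metadata = build_postprocess_fn(metadata)
        return metadata

    return type(name, (builder_cls, ), {"build": build})
```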
9 changes: 5 additions & 4 deletions vllm/v1/spec_decode/eagle.py
@@ -158,9 +158,9 @@ def propose(
assert self.runner is not None

# FIXME: need to consider multiple kv_cache_groups
attn_metadata = self.runner.attn_metadata_builders[
0].build_for_drafting(common_attn_metadata=common_attn_metadata,
draft_index=0)
attn_metadata = self.runner.attn_groups[0][0].metadata_builder\
.build_for_drafting(common_attn_metadata=common_attn_metadata,
draft_index=0)

# At this moment, we assume all eagle layers belong to the same KV
# cache group, thus using the same attention metadata.
@@ -349,7 +349,8 @@ def propose_tree(
hidden_states: torch.Tensor,
common_attn_metadata: CommonAttentionMetadata,
) -> list[torch.Tensor]:
tree_attn_metadata_builder = self.runner.attn_metadata_builders[0]
tree_attn_metadata_builder = \
self.runner.attn_groups[0][0].metadata_builder
assert isinstance(tree_attn_metadata_builder,
TreeAttentionMetadataBuilder)

8 changes: 4 additions & 4 deletions vllm/v1/worker/cpu_model_runner.py
@@ -53,11 +53,11 @@ def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None:
raise ValueError("Multiple KVCacheGroups is not"
"currently supported with CPU model runner.")

assert type(
self.attn_metadata_builders[0]) is TorchSDPAMetadataBuilderV1
assert type(self.attn_groups[0]
[0].metadata_builder) is TorchSDPAMetadataBuilderV1

self.attn_metadata_builders[0].reorder_batch(self.input_batch,
scheduler_output)
self.attn_groups[0][0].metadata_builder.reorder_batch(
self.input_batch, scheduler_output)

def _postprocess_tenosrs(self) -> None:
# Note: replace device tensors with cpu tensors