[Attention] Support multiple attention metadata builders per kv_cache_spec + proper local attention no hybrid kv cache fix #21588
@@ -0,0 +1,88 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
from typing import List, Optional

import torch

from vllm import envs
from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.selector import get_attn_backend
from vllm.config import CacheConfig, QuantizationConfig
from vllm.v1.attention.backends.utils import (
    CommonAttentionMetadata, make_local_attention_virtual_batches,
    subclass_attention_backend, subclass_attention_metadata_builder)

from ..layer import Attention


@functools.lru_cache
def create_chunked_local_attention_backend(
    underlying_attn_backend: AttentionBackend,
    attention_chunk_size: int,
    block_size: int,
) -> type[AttentionBackend]:
    prefix = f"ChunkedLocalAttention_{attention_chunk_size}_{block_size}_"

    def build_preprocess_fn(cm: CommonAttentionMetadata):
        return make_local_attention_virtual_batches(attention_chunk_size, cm,
                                                    block_size)

    # Dynamically create a new attention backend that wraps the
    # underlying attention backend but applies
    # `make_local_attention_virtual_batches` before calling `build(...)`
    builder_cls = subclass_attention_metadata_builder(
        name_prefix=prefix,
        builder_cls=underlying_attn_backend.get_builder_cls(),
        build_preprocess_fn=build_preprocess_fn)
    attn_backend = subclass_attention_backend(
        name_prefix=prefix,
        attention_backend_cls=underlying_attn_backend,
        builder_cls=builder_cls)

    return attn_backend


class ChunkedLocalAttention(Attention):

    def __init__(self,
                 num_heads: int,
                 head_size: int,
                 scale: float,
                 attention_chunk_size: int,
                 num_kv_heads: Optional[int] = None,
                 alibi_slopes: Optional[List[float]] = None,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 kv_sharing_target_layer_name: Optional[str] = None,
                 prefix: str = ""):
        dtype = torch.get_default_dtype()
        if cache_config is not None:
            kv_cache_dtype = cache_config.cache_dtype
            block_size = cache_config.block_size
        else:
            kv_cache_dtype = "auto"
            block_size = 16

        if envs.VLLM_USE_V1:
            underlying_attn_backend = get_attn_backend(head_size, dtype,
                                                       kv_cache_dtype,
                                                       block_size)

            attn_backend = create_chunked_local_attention_backend(
                underlying_attn_backend, attention_chunk_size, block_size)
        else:
            # in v0 the local attention is handled inside the backends
            attn_backend = None

        super().__init__(
            num_heads=num_heads,
            head_size=head_size,
            scale=scale,
            num_kv_heads=num_kv_heads,
            alibi_slopes=alibi_slopes,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=prefix,
            kv_sharing_target_layer_name=kv_sharing_target_layer_name,
            attn_backend=attn_backend)

Review comment on the `__init__` signature: "I know we don't need this currently, but can we add …"
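The `functools.lru_cache` on `create_chunked_local_attention_backend` means every layer asking for the same (underlying backend, chunk size, block size) combination receives the same dynamically generated backend class; presumably this is why the factory is cached, so identical chunked-local layers do not each register their own backend. Below is a minimal, self-contained sketch of that caching-plus-`type()` pattern, using toy stand-in classes rather than vLLM's real `AttentionBackend` API:

```python
import functools


class ToyBackend:
    """Stand-in for an underlying attention backend class (not vLLM's API)."""


@functools.lru_cache
def create_wrapped_backend(underlying: type, chunk_size: int,
                           block_size: int) -> type:
    # Generate a uniquely named subclass per (backend, chunk_size, block_size).
    name = f"ChunkedLocal_{chunk_size}_{block_size}_{underlying.__name__}"
    return type(name, (underlying,), {})


a = create_wrapped_backend(ToyBackend, 8192, 16)
b = create_wrapped_backend(ToyBackend, 8192, 16)
c = create_wrapped_backend(ToyBackend, 4096, 16)
assert a is b       # lru_cache: identical parameters -> the same generated class
assert a is not c   # a different chunk size yields a distinct backend class
print(a.__name__)   # ChunkedLocal_8192_16_ToyBackend
```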
@@ -5,12 +5,12 @@
import functools
from abc import abstractmethod
from dataclasses import dataclass, make_dataclass
from typing import TYPE_CHECKING, Any, ClassVar, Generic, Optional, TypeVar
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Generic, Optional,
                    TypeVar)

import numpy as np
import torch

from vllm.attention.layer import Attention
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.utils import cdiv
@@ -20,6 +20,8 @@
from vllm.v1.worker.gpu_input_batch import InputBatch

import vllm.envs as envs
from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.layer import Attention
from vllm.distributed.kv_transfer.kv_connector.utils import (
    get_kv_connector_cache_layout)
from vllm.logger import init_logger
@@ -522,6 +524,48 @@ def make_local_attention_virtual_batches(
    )


def subclass_attention_metadata_builder(
    name_prefix: str,
    builder_cls: type[AttentionMetadataBuilder[M]],
    build_preprocess_fn: Callable[[CommonAttentionMetadata],
                                  CommonAttentionMetadata],
) -> type[AttentionMetadataBuilder[M]]:
    """
    Return a new subclass of `builder_cls` whose .build(...) method
    first calls build_preprocess_fn(common_attn_metadata) on the metadata.
    """
    name: str = name_prefix + builder_cls.__name__  # type: ignore

    def build(self,
              common_prefix_len: int,
              common_attn_metadata: CommonAttentionMetadata,
              fast_build: bool = False):
        return builder_cls.build(self, common_prefix_len,
                                 build_preprocess_fn(common_attn_metadata),
                                 fast_build)

    Wrapped = type(
        name,
        (builder_cls, ),  # inherit from the original
        {
            "build": build,
        })
    return Wrapped  # type: ignore


def subclass_attention_backend(
    name_prefix: str, attention_backend_cls: type[AttentionBackend],
    builder_cls: type[AttentionMetadataBuilder[M]]
) -> type[AttentionBackend]:
    """
    Return a new subclass where `get_builder_cls` returns `builder_cls`.
    """
    name: str = name_prefix + attention_backend_cls.__name__  # type: ignore

    return type(name, (attention_backend_cls, ),
                {"get_builder_cls": lambda: builder_cls})


def split_decodes_and_prefills(
    common_attn_metadata: CommonAttentionMetadata,
    decode_threshold: int = 1,

Review thread on the wrapped `build(...)`:
- "Yeah, that's a good idea; let's land the …"
- "Sounds good!"
- "OK, I think I understand; yeah, that's tough. Let me think about it for a bit."
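Taken together, the two helpers above generate a builder subclass whose `build(...)` first rewrites the common attention metadata, and a backend subclass that advertises that builder. Here is a toy, self-contained illustration of the same dynamic-subclass pattern, with simplified signatures and plain dicts standing in for vLLM's metadata types (none of the names below are the project's real API):

```python
class BaseBuilder:
    """Stand-in for an AttentionMetadataBuilder."""

    def build(self, common_prefix_len: int, metadata: dict) -> dict:
        return {"prefix_len": common_prefix_len, **metadata}


def subclass_builder(name_prefix: str, builder_cls: type, preprocess) -> type:
    # Same trick as subclass_attention_metadata_builder: override build() so it
    # preprocesses the metadata, then delegates to the parent implementation.
    def build(self, common_prefix_len: int, metadata: dict) -> dict:
        return builder_cls.build(self, common_prefix_len, preprocess(metadata))

    return type(name_prefix + builder_cls.__name__, (builder_cls,),
                {"build": build})


LocalBuilder = subclass_builder(
    "ChunkedLocal_", BaseBuilder,
    lambda md: {**md, "virtual_batches": True})  # stand-in for the local split

print(LocalBuilder.__name__)  # ChunkedLocal_BaseBuilder
print(LocalBuilder().build(4, {"num_tokens": 128}))
# {'prefix_len': 4, 'num_tokens': 128, 'virtual_batches': True}


class BaseBackend:
    """Stand-in for an AttentionBackend."""


# Mirrors subclass_attention_backend: the generated backend class reports the
# generated builder class from get_builder_cls().
WrappedBackend = type("ChunkedLocal_BaseBackend", (BaseBackend,),
                      {"get_builder_cls": lambda: LocalBuilder})
assert WrappedBackend.get_builder_cls() is LocalBuilder
```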
Review comment: what is the scenario for passing `attn_backend` explicitly?

Reply: for `class ChunkedLocalAttention(Attention)`, where we wrap the attention backend in the LocalAttention wrapper. I think in the future we can modularize this better to make subclasses like `ChunkedLocalAttention` cleaner, but I'm a bit scared to refactor it too heavily before V0 is fully deprecated.
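A small sketch of the pattern that reply describes, with toy classes standing in for vLLM's `Attention` and backend types (none of this is the project's real API): the base layer selects a backend automatically when `attn_backend` is `None`, and a wrapper subclass such as `ChunkedLocalAttention` is exactly the case that passes a pre-wrapped backend explicitly.

```python
from typing import Optional


class ToyBackend:
    name = "auto-selected"


class ChunkedToyBackend(ToyBackend):
    name = "chunked-local wrapper"


def select_backend() -> type:
    # Stand-in for get_attn_backend(head_size, dtype, kv_cache_dtype, block_size).
    return ToyBackend


class ToyAttention:
    def __init__(self, attn_backend: Optional[type] = None):
        # An explicitly passed backend wins; otherwise fall back to selection.
        self.attn_backend = (attn_backend
                             if attn_backend is not None else select_backend())


class ToyChunkedLocalAttention(ToyAttention):
    def __init__(self):
        super().__init__(attn_backend=ChunkedToyBackend)


print(ToyAttention().attn_backend.name)              # auto-selected
print(ToyChunkedLocalAttention().attn_backend.name)  # chunked-local wrapper
```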