[Model] Implement DualChunkAttention for Qwen2 Models #6139

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions vllm/attention/__init__.py
@@ -1,12 +1,15 @@
 from vllm.attention.backends.abstract import (AttentionBackend,
                                               AttentionMetadata)
-from vllm.attention.layer import Attention
-from vllm.attention.selector import get_attn_backend
+from vllm.attention.layer import Attention, DualChunkAttention
+from vllm.attention.selector import (get_attn_backend,
+                                     get_dual_chunk_attn_backend)
 
 __all__ = [
     "Attention",
     "AttentionBackend",
     "AttentionMetadata",
     "Attention",
+    "DualChunkAttention",
     "get_attn_backend",
+    "get_dual_chunk_attn_backend",
 ]
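
For orientation only, and not part of the PR's diff: after this change the dual chunk symbols are importable from the package root alongside the existing ones. The snippet below is a minimal sketch limited to the import surface shown above; the call signatures of `DualChunkAttention` and `get_dual_chunk_attn_backend` live in other files of the PR and are deliberately not used here.

```python
# Sketch only: this diff changes the public exports, so the example sticks to
# imports. Constructor/selector signatures are defined elsewhere in the PR and
# are intentionally not exercised.
from vllm.attention import (Attention, AttentionBackend, AttentionMetadata,
                            DualChunkAttention, get_attn_backend,
                            get_dual_chunk_attn_backend)

# The dual chunk layer and backend selector now sit next to their standard
# counterparts in the public namespace.
print(DualChunkAttention, get_dual_chunk_attn_backend)
```
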
129 changes: 129 additions & 0 deletions vllm/attention/backends/abstract.py
@@ -130,3 +130,132 @@ def forward(
         kv_scale: float = 1.0,
     ) -> torch.Tensor:
         raise NotImplementedError
+
+
+class DualChunkAttentionBackend(ABC):
+    """Abstract class for dual chunk attention backends."""
+
+    @staticmethod
+    @abstractmethod
+    def get_name() -> str:
+        raise NotImplementedError
+
+    @staticmethod
+    @abstractmethod
+    def get_impl_cls() -> Type["DualChunkAttentionImpl"]:
+        raise NotImplementedError
+
+    @staticmethod
+    @abstractmethod
+    def get_metadata_cls() -> Type["DualChunkAttentionMetadata"]:
+        raise NotImplementedError
+
+    @classmethod
+    def make_metadata(cls, *args, **kwargs) -> "DualChunkAttentionMetadata":
+        return cls.get_metadata_cls()(*args, **kwargs)
+
+    @staticmethod
+    @abstractmethod
+    def get_kv_cache_shape(
+        num_blocks: int,
+        block_size: int,
+        num_kv_heads: int,
+        head_size: int,
+    ) -> Tuple[int, ...]:
+        raise NotImplementedError
+
+    @staticmethod
+    @abstractmethod
+    def swap_blocks(
+        src_kv_cache: torch.Tensor,
+        dst_kv_cache: torch.Tensor,
+        src_to_dst: torch.Tensor,
+    ) -> None:
+        raise NotImplementedError
+
+    @staticmethod
+    @abstractmethod
+    def copy_blocks(
+        kv_caches: List[torch.Tensor],
+        src_to_dists: torch.Tensor,
+    ) -> None:
+        raise NotImplementedError
+
+
+@dataclass
+class DualChunkAttentionMetadata:
+    """DualChunkAttention metadata for prefill and decode batched together."""
+    # Total number of prefill requests.
+    num_prefills: int
+    # Number of prefill tokens.
+    num_prefill_tokens: int
+    # Number of decode tokens. Note that it is equivalent to the number of
+    # decode requests.
+    num_decode_tokens: int
+    # (num_tokens,). The indices of the token slots that input tokens will be
+    # stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size
+    # is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot
+    # in block 0, and 1st slot in block 1, respectively.
+    slot_mapping: torch.Tensor
+
+    @property
+    @abstractmethod
+    def prefill_metadata(self) -> Optional["DualChunkAttentionMetadata"]:
+        """Return the attention metadata that's required to run prefill
+        attention."""
+        pass
+
+    @property
+    @abstractmethod
+    def decode_metadata(self) -> Optional["DualChunkAttentionMetadata"]:
+        """Return the attention metadata that's required to run decode
+        attention."""
+        pass
+
+    def asdict_zerocopy(self,
+                        skip_fields: Optional[Set[str]] = None
+                        ) -> Dict[str, Any]:
+        """Similar to dataclasses.asdict, but avoids deepcopying."""
+        if skip_fields is None:
+            skip_fields = set()
+        # Note that if we add dataclasses as fields, they will need
+        # similar handling.
+        return {
+            field.name: getattr(self, field.name)
+            for field in fields(self) if field.name not in skip_fields
+        }
+
+
+T2 = TypeVar("T2", bound=DualChunkAttentionMetadata)
+
+
+class DualChunkAttentionImpl(ABC, Generic[T2]):
+
+    @abstractmethod
+    def __init__(
+        self,
+        num_heads: int,
+        head_size: int,
+        scale: float,
+        num_kv_heads: Optional[int] = None,
+        alibi_slopes: Optional[List[float]] = None,
+        sliding_window: Optional[int] = None,
+        kv_cache_dtype: str = "auto",
+        blocksparse_params: Optional[Dict[str, Any]] = None,
+        dual_chunk_attention_config: Optional[Dict[str, Any]] = None,
+    ) -> None:
+        raise NotImplementedError
+
+    @abstractmethod
+    def forward(
+        self,
+        query: torch.Tensor,
+        query_succ: torch.Tensor,
+        query_inter: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        kv_cache: torch.Tensor,
+        attn_metadata: T2,
+        kv_scale: float = 1.0,
+    ) -> torch.Tensor:
+        raise NotImplementedError
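
To make the contract concrete, here is a hypothetical minimal backend written against the interface added above. Every `Dummy*` name is invented for illustration and is not part of this PR or of vLLM; the block operations are no-ops and `forward` is a stub where a real backend would combine the three query views (`query`, `query_succ`, `query_inter`) with the cached keys and values.

```python
# Hypothetical example only: none of the Dummy* classes exist in vLLM or in
# this PR. It shows how get_impl_cls / get_metadata_cls / make_metadata are
# meant to be wired together.
from dataclasses import dataclass
from typing import List, Optional, Tuple, Type

import torch

from vllm.attention.backends.abstract import (DualChunkAttentionBackend,
                                              DualChunkAttentionImpl,
                                              DualChunkAttentionMetadata)


@dataclass
class DummyDualChunkMetadata(DualChunkAttentionMetadata):

    @property
    def prefill_metadata(self) -> Optional["DummyDualChunkMetadata"]:
        # Hand back metadata only when the batch actually contains prefills.
        return self if self.num_prefill_tokens > 0 else None

    @property
    def decode_metadata(self) -> Optional["DummyDualChunkMetadata"]:
        return self if self.num_decode_tokens > 0 else None


class DummyDualChunkImpl(DualChunkAttentionImpl[DummyDualChunkMetadata]):

    def __init__(self, num_heads: int, head_size: int, scale: float,
                 **kwargs) -> None:
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = scale

    def forward(self, query, query_succ, query_inter, key, value, kv_cache,
                attn_metadata, kv_scale: float = 1.0) -> torch.Tensor:
        # A real backend would combine the three query views with the cached
        # keys and values here; the stub just echoes the primary query.
        return query


class DummyDualChunkBackend(DualChunkAttentionBackend):

    @staticmethod
    def get_name() -> str:
        return "dummy-dual-chunk"

    @staticmethod
    def get_impl_cls() -> Type[DummyDualChunkImpl]:
        return DummyDualChunkImpl

    @staticmethod
    def get_metadata_cls() -> Type[DummyDualChunkMetadata]:
        return DummyDualChunkMetadata

    @staticmethod
    def get_kv_cache_shape(num_blocks: int, block_size: int,
                           num_kv_heads: int,
                           head_size: int) -> Tuple[int, ...]:
        # Placeholder layout: separate key/value planes over paged blocks.
        return (2, num_blocks, block_size, num_kv_heads, head_size)

    @staticmethod
    def swap_blocks(src_kv_cache: torch.Tensor, dst_kv_cache: torch.Tensor,
                    src_to_dst: torch.Tensor) -> None:
        pass  # No-op for the sketch.

    @staticmethod
    def copy_blocks(kv_caches: List[torch.Tensor],
                    src_to_dists: torch.Tensor) -> None:
        pass  # No-op for the sketch.


# make_metadata forwards to get_metadata_cls(), so the dataclass fields from
# DualChunkAttentionMetadata are filled by keyword here.
meta = DummyDualChunkBackend.make_metadata(num_prefills=1,
                                           num_prefill_tokens=4,
                                           num_decode_tokens=0,
                                           slot_mapping=torch.arange(4))
assert meta.prefill_metadata is meta and meta.decode_metadata is None
```

The final lines show the intended flow: `make_metadata` on the backend instantiates whatever class `get_metadata_cls` returns, and the per-phase properties decide whether prefill or decode metadata is available.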