[Core] Refactor Attention Take 2 #3462
@@ -0,0 +1,10 @@
from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata
from vllm.attention.layer import Attention
from vllm.attention.selector import get_attn_backend

__all__ = [
    "AttentionBackend",
    "AttentionMetadata",
    "Attention",
    "get_attn_backend",
]
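Assuming this new file is the package `__init__.py` (the `__all__` re-exports suggest so), downstream code can then pull the whole public attention API from one place; a trivial illustration, not part of the diff:

```python
# Illustrative only: importing the re-exported names from the package root.
from vllm.attention import (Attention, AttentionBackend, AttentionMetadata,
                            get_attn_backend)
```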
@@ -0,0 +1,78 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Type

import torch


class AttentionBackend(ABC):
    """Abstract class for attention backends."""

    @staticmethod
    @abstractmethod
    def get_impl_cls() -> Type["AttentionImpl"]:
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def make_metadata(*args, **kwargs) -> "AttentionMetadata":
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: Dict[int, int],
    ) -> None:
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: Dict[int, List[int]],
    ) -> None:
        raise NotImplementedError


@dataclass
class AttentionMetadata:

    ...


class AttentionImpl(ABC):

    @abstractmethod
    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
        alibi_slopes: Optional[List[float]] = None,
        sliding_window: Optional[int] = None,
    ) -> None:
        raise NotImplementedError

    @abstractmethod
    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        raise NotImplementedError
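Before moving on to the concrete FlashAttention backend, a minimal sketch of what any new backend has to supply under this contract may help. Everything below is illustrative (the `My*` names and the chosen cache layout are not in the PR); only the abstract signatures come from the file above.

```python
# Hypothetical sketch (not part of the PR): the minimum a new backend has to
# provide under the abstract interface above.
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Type

import torch

from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                              AttentionMetadata)


@dataclass
class MyAttentionMetadata(AttentionMetadata):
    slot_mapping: torch.Tensor  # one cache slot index per token


class MyAttentionImpl(AttentionImpl):

    def __init__(self,
                 num_heads: int,
                 head_size: int,
                 scale: float,
                 num_kv_heads: Optional[int] = None,
                 alibi_slopes: Optional[List[float]] = None,
                 sliding_window: Optional[int] = None) -> None:
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = scale
        self.num_kv_heads = num_kv_heads or num_heads

    def forward(self, query: torch.Tensor, key: torch.Tensor,
                value: torch.Tensor, kv_cache: torch.Tensor,
                attn_metadata: MyAttentionMetadata) -> torch.Tensor:
        raise NotImplementedError  # the actual kernel call would go here


class MyAttentionBackend(AttentionBackend):

    @staticmethod
    def get_impl_cls() -> Type[MyAttentionImpl]:
        return MyAttentionImpl

    @staticmethod
    def make_metadata(*args, **kwargs) -> MyAttentionMetadata:
        return MyAttentionMetadata(*args, **kwargs)

    @staticmethod
    def get_kv_cache_shape(num_blocks: int, block_size: int,
                           num_kv_heads: int,
                           head_size: int) -> Tuple[int, ...]:
        # Layout choice for this sketch: dim 0 stacks K and V, dim 1 is the
        # physical block index, so kv_cache[:, i] addresses one block.
        return (2, num_blocks, block_size, num_kv_heads, head_size)

    @staticmethod
    def swap_blocks(src_kv_cache: torch.Tensor, dst_kv_cache: torch.Tensor,
                    src_to_dst: Dict[int, int]) -> None:
        for src, dst in src_to_dst.items():
            dst_kv_cache[:, dst].copy_(src_kv_cache[:, src])

    @staticmethod
    def copy_blocks(kv_caches: List[torch.Tensor],
                    src_to_dists: Dict[int, List[int]]) -> None:
        for kv_cache in kv_caches:
            for src, dsts in src_to_dists.items():
                for dst in dsts:
                    kv_cache[:, dst].copy_(kv_cache[:, src])
```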
@@ -0,0 +1,182 @@
"""Attention layer with Flash and PagedAttention."""
# NOTE(woosuk): This file is temporary and will be replaced by the
# FlashInfer backend. At the moment, this file includes much duplicated
# code from the XFormers backend. The duplicated code will be removed once
# the FlashInfer backend is implemented.
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Type

from flash_attn import flash_attn_func
import torch

from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                              AttentionMetadata)
from vllm.attention.ops.paged_attn import PagedAttention


class FlashAttentionBackend(AttentionBackend):

    @staticmethod
    def get_impl_cls() -> Type["FlashAttentionImpl"]:
        return FlashAttentionImpl

    @staticmethod
    def make_metadata(*args, **kwargs) -> "FlashAttentionMetadata":
        return FlashAttentionMetadata(*args, **kwargs)

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        return (2, num_blocks, block_size * num_kv_heads * head_size)

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: Dict[int, int],
    ) -> None:
        PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: Dict[int, List[int]],
    ) -> None:
        PagedAttention.copy_blocks(kv_caches, src_to_dists)


@dataclass
class FlashAttentionMetadata(AttentionMetadata):

    is_prompt: bool
    slot_mapping: torch.Tensor
    prompt_lens: Optional[torch.Tensor]
    max_seq_len: Optional[int]
    start_loc: Optional[torch.Tensor]
    max_context_len: Optional[int]
    context_lens: Optional[torch.Tensor]
    block_tables: Optional[torch.Tensor]
    # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
    use_cuda_graph: bool
Review comment on lines +102 to +103 (the `use_cuda_graph` field above): Rename to something more device-neutral like …

Reply: This variable doesn't have to be device-neutral, since this is defined only for …
    kv_cache_dtype: str


class FlashAttentionImpl(AttentionImpl):

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
        alibi_slopes: Optional[List[float]] = None,
        sliding_window: Optional[int] = None,
    ) -> None:
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
        self.sliding_window = ((sliding_window, sliding_window)
                               if sliding_window is not None else (-1, -1))
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        supported_head_sizes = PagedAttention.get_supported_head_sizes()
        if head_size not in supported_head_sizes:
            raise ValueError(
                f"Head size {head_size} is not supported by PagedAttention. "
                f"Supported head sizes are: {supported_head_sizes}.")

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: FlashAttentionMetadata,
    ) -> torch.Tensor:
        """Forward pass with FlashAttention and PagedAttention.

        Args:
            query: shape = [batch_size, seq_len, num_heads * head_size]
            key: shape = [batch_size, seq_len, num_kv_heads * head_size]
            value: shape = [batch_size, seq_len, num_kv_heads * head_size]
            kv_cache: shape = [2, num_blocks, block_size * num_kv_heads * head_size]
            attn_metadata: Metadata for attention.
        Returns:
            shape = [batch_size, seq_len, num_heads * head_size]
        """
        batch_size, seq_len, hidden_size = query.shape
        # Reshape the query, key, and value tensors.
        query = query.view(-1, self.num_heads, self.head_size)
        key = key.view(-1, self.num_kv_heads, self.head_size)
        value = value.view(-1, self.num_kv_heads, self.head_size)

        if kv_cache is not None:
            key_cache, value_cache = PagedAttention.split_kv_cache(
                kv_cache, self.num_kv_heads, self.head_size)

            # Reshape the input keys and values and store them in the cache.
            # If kv_cache is not provided, the new key and value tensors are
            # not cached. This happens during the initial memory profiling run.
            PagedAttention.reshape_and_cache(key, value, key_cache,
                                             value_cache,
                                             attn_metadata.slot_mapping,
                                             attn_metadata.kv_cache_dtype)

        if attn_metadata.is_prompt:
            # Prompt run.
            if kv_cache is None or attn_metadata.block_tables.numel() == 0:
                # Normal attention.
                query = query.unflatten(0, (batch_size, seq_len))
                key = key.unflatten(0, (batch_size, seq_len))
                value = value.unflatten(0, (batch_size, seq_len))
                output = flash_attn_func(
                    query,
                    key,
                    value,
                    softmax_scale=self.scale,
                    causal=True,
                    window_size=self.sliding_window,
                    alibi_slopes=self.alibi_slopes,
                )
            else:
                # Prefix-enabled attention.
                output = PagedAttention.forward_prefix(
                    query,
                    key,
                    value,
                    key_cache,
                    value_cache,
                    attn_metadata.block_tables,
                    attn_metadata.start_loc,
                    attn_metadata.prompt_lens,
                    attn_metadata.context_lens,
                    attn_metadata.max_seq_len,
                    self.alibi_slopes,
                )
        else:
            # Decoding run.
            output = PagedAttention.forward_decode(
                query,
                key_cache,
                value_cache,
                attn_metadata.block_tables,
                attn_metadata.context_lens,
                attn_metadata.max_context_len,
                attn_metadata.kv_cache_dtype,
                self.num_kv_heads,
                self.scale,
                self.alibi_slopes,
            )

        # Reshape the output tensor.
        return output.view(batch_size, seq_len, hidden_size)
Review comment: Is this true given that AMD will use the flash_attn path?

Reply: As discussed in another comment, I'm not sure whether we will use flash-attn for AMD GPUs, since xformers started to officially support AMD GPUs.

Reply: Regardless, I fixed the comment to open the possibility that we have a flash-attn backend as well as a flashinfer backend.