[Core] Refactor Attention Take 2 #3462

Merged: 88 commits, merged on Mar 25, 2024
Changes from 47 commits
Commits (88 total, all authored by WoosukKwon):

aa4b6c0  Move (Mar 17, 2024)
cf1c96c  Attention (Mar 17, 2024)
88f1ff6  Fix (Mar 17, 2024)
2044252  Fix import errors (Mar 17, 2024)
6378222  Remove (Mar 17, 2024)
65b5988  Add Abstract AttentionBackend (Mar 17, 2024)
6605797  FlashAttn -> FlashInfer (Mar 17, 2024)
b6f2cec  Remove KVCache Layout (Mar 17, 2024)
e5e0f49  Add attention dispatcher (Mar 17, 2024)
d6985fa  Fix Abstract attention backend (Mar 17, 2024)
c76e5c1  Minor (Mar 17, 2024)
0d5bc56  Fix models (Mar 17, 2024)
5dde0d7  FlashInfer -> FlashAttention (Mar 17, 2024)
6bed6da  PagedAttentionImpl -> PagedAttention (Mar 17, 2024)
deb02da  Minor (Mar 17, 2024)
2fd5b2c  WIP (Mar 17, 2024)
8ce58cb  Minor (Mar 18, 2024)
32894ce  Merge branch 'main' into flashinfer-take3 (Mar 18, 2024)
1344e73  Fix Neuron (Mar 18, 2024)
144f6cb  Remove cache events (Mar 18, 2024)
47c59b4  Remove (Mar 18, 2024)
61d2f63  KVCache -> torch.Tensor (Mar 18, 2024)
f820e1c  yapf (Mar 18, 2024)
d2c0bf8  Move (Mar 18, 2024)
0ec1786  Get attn backend (Mar 18, 2024)
ec1904f  Move (Mar 18, 2024)
86918c8  Fix (Mar 18, 2024)
19820f9  Fix (Mar 18, 2024)
f0f6a96  yapf (Mar 18, 2024)
15a36b7  type (Mar 18, 2024)
90d91cd  Remove InputMetadata (Mar 18, 2024)
d569f5c  Fix PagedAttention (Mar 18, 2024)
f88877f  Fix (Mar 18, 2024)
fbed6b0  Minor (Mar 18, 2024)
7b20793  InputMetadata -> AttentionMetadata (Mar 18, 2024)
a6062d3  Comment (Mar 18, 2024)
faa1806  Fix swap and copy (Mar 18, 2024)
2e14d70  Minor (Mar 18, 2024)
82ee3d6  Comment (Mar 18, 2024)
d95176a  Minor (Mar 18, 2024)
1ffbe21  Fix FlashAttention backend (Mar 18, 2024)
53e49c2  yapf (Mar 18, 2024)
b3b99fd  Minor (Mar 18, 2024)
de7f764  Minor (Mar 18, 2024)
3ee77fd  yapfg (Mar 18, 2024)
bbcd032  Minor (Mar 18, 2024)
83fbfcd  Minor refactor (Mar 18, 2024)
8a89930  Merge branch 'main' into flashinfer-take3 (Mar 20, 2024)
0da3a5d  Merge branch 'main' into flashinfer-take3 (Mar 20, 2024)
6142ce9  Merge branch 'main' into flashinfer-take3 (Mar 21, 2024)
c6ea553  Minor (Mar 21, 2024)
1740d1f  Fix (Mar 21, 2024)
ea434dd  Fix (Mar 21, 2024)
1c6c06e  Fix (Mar 21, 2024)
3cee5ad  yapf (Mar 21, 2024)
a5ac88e  Delete unused logger (Mar 21, 2024)
f87ee0d  Remove max_size=1 (Mar 21, 2024)
8b1dc2e  Use max_prompt_len instead of max_seq_len (Mar 21, 2024)
55010d9  Add PagedAttentionMetadata (Mar 21, 2024)
c4ff6ee  ruff (Mar 21, 2024)
ab2faeb  Merge branch 'main' into flashinfer-take3 (Mar 21, 2024)
bbb1f11  Fix (Mar 21, 2024)
927f788  Fix Jais (Mar 21, 2024)
cbefadb  Merge branch 'main' into flashinfer-take3 (Mar 22, 2024)
5d87e8f  Rename (Mar 24, 2024)
64734f8  Fix comment (Mar 24, 2024)
1888a3e  Fix comment (Mar 24, 2024)
5591caa  Minor (Mar 24, 2024)
fba63c0  Merge allocate_gpu_cache and allocate_cpu_cache (Mar 24, 2024)
786110d  Fix comment on ref attention (Mar 24, 2024)
4cbae94  Minor (Mar 24, 2024)
587f11e  ref -> naive (Mar 24, 2024)
d6964f7  Fix (Mar 24, 2024)
a49e87d  PagedAttention.get_kv_cache_shape (Mar 24, 2024)
0ce6259  Empty cache (Mar 24, 2024)
843cb16  use enforce_eager for test_logprob (Mar 24, 2024)
76552c4  Revert (Mar 24, 2024)
0dbaed6  Merge branch 'main' into flashinfer-take3 (Mar 24, 2024)
6747d14  log when using naive attention (Mar 24, 2024)
2f1db63  Merge branch 'main' into flashinfer-take3 (Mar 25, 2024)
fffdeea  lru_cache(maxsize=None) (Mar 25, 2024)
f723ac1  Minor (Mar 25, 2024)
722e09b  Minor (Mar 25, 2024)
43bb346  Add assert (Mar 25, 2024)
12eba34  Minor fix in err msg (Mar 25, 2024)
e84fa8a  soft fail for samplers test (Mar 25, 2024)
58de636  Add explicit gc to test_beam_search (Mar 25, 2024)
6a8b538  Minor (Mar 25, 2024)
3 changes: 1 addition & 2 deletions tests/kernels/test_prefix_prefill.py
@@ -3,8 +3,7 @@
 import time

 import torch
-from vllm.model_executor.layers.attention.ops.prefix_prefill import (
-    context_attention_fwd)
+from vllm.attention.ops.prefix_prefill import context_attention_fwd
 from xformers import ops as xops
 from xformers.ops.fmha.attn_bias import BlockDiagonalCausalFromBottomRightMask
10 changes: 10 additions & 0 deletions vllm/attention/__init__.py
@@ -0,0 +1,10 @@
from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata
from vllm.attention.layer import Attention
from vllm.attention.selector import get_attn_backend

__all__ = [
    "AttentionBackend",
    "AttentionMetadata",
    "Attention",
    "get_attn_backend",
]
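The package above is the new single import point for attention. Below is a minimal sketch of how the re-exported names are meant to compose, based on the abstract interface that appears later in this diff; the argument to get_attn_backend (a dtype) is an assumption, since vllm/attention/selector.py is not included in this excerpt.

import torch

from vllm.attention import get_attn_backend

# Assumption: the selector keys off the model dtype; its real signature is
# not shown in this excerpt.
backend = get_attn_backend(torch.float16)

# The backend class hands out its implementation class and KV-cache layout.
impl_cls = backend.get_impl_cls()
kv_cache_shape = backend.get_kv_cache_shape(num_blocks=1024,
                                            block_size=16,
                                            num_kv_heads=8,
                                            head_size=128)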
78 changes: 78 additions & 0 deletions vllm/attention/backends/abstract.py
@@ -0,0 +1,78 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Type

import torch


class AttentionBackend(ABC):
    """Abstract class for attention backends."""

    @staticmethod
    @abstractmethod
    def get_impl_cls() -> Type["AttentionImpl"]:
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def make_metadata(*args, **kwargs) -> "AttentionMetadata":
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: Dict[int, int],
    ) -> None:
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: Dict[int, List[int]],
    ) -> None:
        raise NotImplementedError


@dataclass
class AttentionMetadata:

    ...


class AttentionImpl(ABC):

    @abstractmethod
    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
        alibi_slopes: Optional[List[float]] = None,
        sliding_window: Optional[int] = None,
    ) -> None:
        raise NotImplementedError

    @abstractmethod
    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        raise NotImplementedError
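To make the contract concrete, here is a minimal, hypothetical backend sketch. It is not part of this PR: NaiveTorchBackend, NaiveTorchImpl, and the one-field metadata class are invented for illustration. It fills in every abstract method with plain PyTorch causal attention and skips KV caching entirely; real backends, like the FlashAttention one below, do the same wiring but delegate the heavy lifting to fused kernels and the paged KV cache.

from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Type

import torch

from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                              AttentionMetadata)


@dataclass
class NaiveTorchMetadata(AttentionMetadata):
    # Only what this toy needs; real backends carry slot mappings,
    # block tables, etc. (see FlashAttentionMetadata below).
    is_prompt: bool


class NaiveTorchBackend(AttentionBackend):

    @staticmethod
    def get_impl_cls() -> Type["NaiveTorchImpl"]:
        return NaiveTorchImpl

    @staticmethod
    def make_metadata(*args, **kwargs) -> "NaiveTorchMetadata":
        return NaiveTorchMetadata(*args, **kwargs)

    @staticmethod
    def get_kv_cache_shape(num_blocks: int, block_size: int,
                           num_kv_heads: int,
                           head_size: int) -> Tuple[int, ...]:
        # Keys and values stored side by side, one token per slot.
        return (2, num_blocks, block_size, num_kv_heads, head_size)

    @staticmethod
    def swap_blocks(src_kv_cache: torch.Tensor, dst_kv_cache: torch.Tensor,
                    src_to_dst: Dict[int, int]) -> None:
        for src, dst in src_to_dst.items():
            dst_kv_cache[:, dst].copy_(src_kv_cache[:, src])

    @staticmethod
    def copy_blocks(kv_caches: List[torch.Tensor],
                    src_to_dists: Dict[int, List[int]]) -> None:
        for kv_cache in kv_caches:
            for src, dsts in src_to_dists.items():
                for dst in dsts:
                    kv_cache[:, dst].copy_(kv_cache[:, src])


class NaiveTorchImpl(AttentionImpl):

    def __init__(self,
                 num_heads: int,
                 head_size: int,
                 scale: float,
                 num_kv_heads: Optional[int] = None,
                 alibi_slopes: Optional[List[float]] = None,
                 sliding_window: Optional[int] = None) -> None:
        # The toy assumes plain MHA (num_kv_heads == num_heads) and ignores
        # ALiBi slopes and sliding windows.
        assert num_kv_heads in (None, num_heads)
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)

    def forward(self, query: torch.Tensor, key: torch.Tensor,
                value: torch.Tensor, kv_cache: torch.Tensor,
                attn_metadata: NaiveTorchMetadata) -> torch.Tensor:
        # Plain causal attention over the current batch; KV caching is
        # deliberately omitted, so kv_cache is unused here.
        batch_size, seq_len, hidden_size = query.shape
        q = query.view(batch_size, seq_len, self.num_heads, self.head_size)
        k = key.view(batch_size, seq_len, self.num_heads, self.head_size)
        v = value.view(batch_size, seq_len, self.num_heads, self.head_size)
        scores = torch.einsum("bqhd,bkhd->bhqk", q, k) * self.scale
        causal_mask = torch.triu(torch.ones(seq_len, seq_len,
                                            dtype=torch.bool,
                                            device=query.device), diagonal=1)
        scores = scores.masked_fill(causal_mask, float("-inf"))
        out = torch.einsum("bhqk,bkhd->bqhd", scores.softmax(dim=-1), v)
        return out.reshape(batch_size, seq_len, hidden_size)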
182 changes: 182 additions & 0 deletions vllm/attention/backends/flash_attn.py
@@ -0,0 +1,182 @@
"""Attention layer with Flash and PagedAttention."""
# NOTE(woosuk): This file is temporary and will be replaced by
# FlashInfer backend. At the moment, this file includes many duplicated
# code from XFormers backend. The duplicated code will be removed once
# FlashInfer backend is implemented.

[Review thread on the NOTE above]
Reviewer (Member): Is this true given that AMD will use the flash_attn path?
Author (WoosukKwon): As discussed in another comment, I'm not sure whether we will use flash-attn for AMD GPUs since xformers started to officially support AMD GPUs.
Author (WoosukKwon): Regardless, I fixed the comment to open the possibility that we have a flash-attn backend as well as a flashinfer backend.

from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Type

from flash_attn import flash_attn_func
import torch

from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                              AttentionMetadata)
from vllm.attention.ops.paged_attn import PagedAttention


class FlashAttentionBackend(AttentionBackend):

    @staticmethod
    def get_impl_cls() -> Type["FlashAttentionImpl"]:
        return FlashAttentionImpl

    @staticmethod
    def make_metadata(*args, **kwargs) -> "FlashAttentionMetadata":
        return FlashAttentionMetadata(*args, **kwargs)

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        return (2, num_blocks, block_size * num_kv_heads * head_size)

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: Dict[int, int],
    ) -> None:
        PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: Dict[int, List[int]],
    ) -> None:
        PagedAttention.copy_blocks(kv_caches, src_to_dists)


@dataclass
class FlashAttentionMetadata(AttentionMetadata):

    is_prompt: bool
    slot_mapping: torch.Tensor
    prompt_lens: Optional[torch.Tensor]
    max_seq_len: Optional[int]
    start_loc: Optional[torch.Tensor]
    max_context_len: Optional[int]
    context_lens: Optional[torch.Tensor]
    block_tables: Optional[torch.Tensor]
    # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
    use_cuda_graph: bool
    kv_cache_dtype: str

[Review thread on lines +102 to +103, `use_cuda_graph`]
Reviewer (Member): Rename to something more device-neutral like pad_graph_inputs?
Author (WoosukKwon): This variable doesn't have to be device-neutral since it is defined only for FlashAttentionBackend, which is only for NVIDIA GPUs (or plus AMD GPUs). I feel we just need to move this variable outside of AttentionMetadata since it is not used inside Attention.


class FlashAttentionImpl(AttentionImpl):

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
        alibi_slopes: Optional[List[float]] = None,
        sliding_window: Optional[int] = None,
    ) -> None:
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
        self.sliding_window = ((sliding_window, sliding_window)
                               if sliding_window is not None else (-1, -1))
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        suppored_head_sizes = PagedAttention.get_supported_head_sizes()
        if head_size not in suppored_head_sizes:
            raise ValueError(
                f"Head size {head_size} is not supported by PagedAttention. "
                f"Supported head sizes are: {suppored_head_sizes}.")

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: FlashAttentionMetadata,
    ) -> torch.Tensor:
        """Forward pass with FlashAttention and PagedAttention.

        Args:
            query: shape = [batch_size, seq_len, num_heads * head_size]
            key: shape = [batch_size, seq_len, num_kv_heads * head_size]
            value: shape = [batch_size, seq_len, num_kv_heads * head_size]
            kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
            attn_metadata: Metadata for attention.
        Returns:
            shape = [batch_size, seq_len, num_heads * head_size]
        """
        batch_size, seq_len, hidden_size = query.shape
        # Reshape the query, key, and value tensors.
        query = query.view(-1, self.num_heads, self.head_size)
        key = key.view(-1, self.num_kv_heads, self.head_size)
        value = value.view(-1, self.num_kv_heads, self.head_size)

        if kv_cache is not None:
            key_cache, value_cache = PagedAttention.split_kv_cache(
                kv_cache, self.num_kv_heads, self.head_size)

            # Reshape the input keys and values and store them in the cache.
            # If kv_cache is not provided, the new key and value tensors are
            # not cached. This happens during the initial memory profiling run.
            PagedAttention.reshape_and_cache(key, value, key_cache,
                                             value_cache,
                                             attn_metadata.slot_mapping,
                                             attn_metadata.kv_cache_dtype)

        if attn_metadata.is_prompt:
            # Prompt run.
            if kv_cache is None or attn_metadata.block_tables.numel() == 0:
                # normal attention
                query = query.unflatten(0, (batch_size, seq_len))
                key = key.unflatten(0, (batch_size, seq_len))
                value = value.unflatten(0, (batch_size, seq_len))
                output = flash_attn_func(
                    query,
                    key,
                    value,
                    softmax_scale=self.scale,
                    causal=True,
                    window_size=self.sliding_window,
                    alibi_slopes=self.alibi_slopes,
                )
            else:
                # prefix-enabled attention
                output = PagedAttention.forward_prefix(
                    query,
                    key,
                    value,
                    key_cache,
                    value_cache,
                    attn_metadata.block_tables,
                    attn_metadata.start_loc,
                    attn_metadata.prompt_lens,
                    attn_metadata.context_lens,
                    attn_metadata.max_seq_len,
                    self.alibi_slopes,
                )
        else:
            # Decoding run.
            output = PagedAttention.forward_decode(
                query,
                key_cache,
                value_cache,
                attn_metadata.block_tables,
                attn_metadata.context_lens,
                attn_metadata.max_context_len,
                attn_metadata.kv_cache_dtype,
                self.num_kv_heads,
                self.scale,
                self.alibi_slopes,
            )

        # Reshape the output tensor.
        return output.view(batch_size, seq_len, hidden_size)
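A short aside on the shape bookkeeping in forward above: the layer receives [batch_size, seq_len, num_heads * head_size] tensors, flattens them to per-token [num_tokens, num_heads, head_size] views for reshape_and_cache and the paged kernels, unflattens them again for flash_attn_func in the regular prompt branch, and finally restores the original layout. A tiny standalone illustration of that round trip (numbers chosen arbitrarily):

import torch

batch_size, seq_len = 2, 8
num_heads, head_size = 4, 64
hidden_size = num_heads * head_size

query = torch.randn(batch_size, seq_len, hidden_size)

# Per-token view used for caching and for the paged attention kernels.
q_tokens = query.view(-1, num_heads, head_size)          # [16, 4, 64]

# Batched view expected by flash_attn_func in the prompt branch.
q_flash = q_tokens.unflatten(0, (batch_size, seq_len))   # [2, 8, 4, 64]

# The backend returns the output in the layer's original layout.
output = q_flash.reshape(batch_size, seq_len, hidden_size)
assert output.shape == query.shape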