Use FlashAttention for multi_query_kv_attention #4

Merged (8 commits, Mar 2, 2023)
Changes from all commits
README.md: 1 addition & 0 deletions
@@ -4,6 +4,7 @@
 
 ```bash
 pip install cmake torch transformers
+pip install flash-attn # This may take up to 10 mins.
 pip install -e .
 ```
 
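A quick way to sanity-check the new dependency after following the README steps is to import the exact interface this PR uses. This is a minimal sketch, assuming the flash-attn 1.x package layout imported in the diff below (`flash_attn.flash_attention.FlashAttention`):

```python
# Minimal install check for the flash-attn 1.x interface used by this PR.
from flash_attn.flash_attention import FlashAttention

head_size = 64
# cacheflow constructs the module the same way, with softmax_scale = 1/sqrt(head_size).
flash_attn = FlashAttention(softmax_scale=head_size ** -0.5)
print(flash_attn)  # prints the module repr if the CUDA extension built correctly
```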
cacheflow/models/attention.py: 37 additions & 30 deletions
@@ -1,5 +1,6 @@
 from typing import List, Optional
 
+from flash_attn.flash_attention import FlashAttention
 import torch
 import torch.nn as nn
 
@@ -14,20 +15,7 @@ def __init__(self, scale: float) -> None:
         super().__init__()
         self.scale = float(scale)
 
-    def _masked_attention(
-        self,
-        query: torch.Tensor,  # [num_queries, num_heads, head_size]
-        key: torch.Tensor,  # [num_keys, num_heads, head_size]
-        value: torch.Tensor,  # [num_keys, num_heads, head_size]
-        attn_mask: Optional[torch.Tensor] = None,  # [num_queries, num_keys]
-    ) -> torch.Tensor:  # [num_queries, num_heads, head_size]
-        query = query * self.scale
-        attn = torch.einsum('qhd,khd->hqk', query, key)
-        if attn_mask is not None:
-            attn = attn + attn_mask
-        attn = torch.softmax(attn, dim=-1)
-        out = torch.einsum('hqk,khd->qhd', attn, value)
-        return out
+        self.flash_attn = FlashAttention(softmax_scale=self.scale)
 
     def multi_query_kv_attention(
         self,
@@ -37,21 +25,31 @@ def multi_query_kv_attention
         value: torch.Tensor,  # [num_prompt_tokens, num_heads, head_size]
         prompt_lens: List[int],
     ) -> None:
-        # FIXME(woosuk): Replace the following with a custom op.
-        start_idx = 0
+        if query.dtype == torch.float:
+            raise ValueError('The float data type is not supported by '
+                             'FlashAttention. Use the half data type instead.')
+        head_size = query.shape[2]
+        if head_size > 128:
+            raise ValueError('FlashAttention does not support head_size > 128.')
+
+        device = query.device
+        prefix_sum = [0]
         for prompt_len in prompt_lens:
-            out = output[start_idx:start_idx + prompt_len]
-            q = query[start_idx:start_idx + prompt_len]
-            k = key[start_idx:start_idx + prompt_len]
-            v = value[start_idx:start_idx + prompt_len]
-
-            attention_mask = torch.triu(
-                torch.ones(q.shape[0], k.shape[0]), diagonal=1) * -1e5
-            attention_mask = attention_mask.to(dtype=q.dtype, device=q.device)
-            attention_out = self._masked_attention(q, k, v, attention_mask)
-            out.copy_(attention_out, non_blocking=True)
-
-            start_idx += prompt_len
+            prefix_sum.append(prefix_sum[-1] + prompt_len)
+        prefix_sum = torch.tensor(prefix_sum, dtype=torch.int, device=device)
+        max_prompt_len = max(prompt_lens)
+
+        # FIXME(woosuk): Unnecessary copy. Optimize this.
+        qkv = torch.stack([query, key, value], dim=1)
+        out = self.flash_attn(
+            qkv,
+            cu_seqlens=prefix_sum,
+            max_s=max_prompt_len,
+            causal=True,
+        )[0]
+        num_tokens = prefix_sum[-1]
+        # FIXME(woosuk): Unnecessary copy. Optimize this.
+        output[:num_tokens].copy_(out, non_blocking=True)
 
     def single_query_cached_kv_attention(
         self,
@@ -61,6 +59,14 @@ def single_query_cached_kv_attention
         value_cache: torch.Tensor,  # [num_blocks, num_heads, head_size, block_size]
         input_metadata: InputMetadata,
    ) -> None:
+        head_size = value_cache.shape[2]
+        supported_head_sizes = [32, 64, 80, 96, 128, 160, 192, 256]
+        if head_size not in supported_head_sizes:
+            raise ValueError(f'head_size ({head_size}) is not supported by '
+                             'the single_query_cached_kv_attention kernel. '
+                             'Use one of the following head sizes: '
+                             f'{supported_head_sizes}.')
+
         block_size = value_cache.shape[3]
         attention_ops.single_query_cached_kv_attention(
             output,
@@ -101,8 +107,9 @@ def forward(
         output = output.view(-1, num_heads, head_size)
 
         # Compute the attention op for prompts.
-        self.multi_query_kv_attention(
-            output, query, key, value, input_metadata.prompt_lens)
+        if input_metadata.num_prompts > 0:
+            self.multi_query_kv_attention(
+                output, query, key, value, input_metadata.prompt_lens)
 
         # Wait until the cache op is done.
         if cache_event is not None:
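Reviewer note: the packed varlen call in `multi_query_kv_attention` can be exercised standalone. The sketch below mirrors the call pattern above and is only illustrative; it assumes flash-attn 1.x, a CUDA device, and made-up prompt lengths:

```python
# Illustrative sketch of the packed (varlen) FlashAttention call used above.
# Assumes flash-attn 1.x and a CUDA GPU; prompt lengths and shapes are made up.
import torch
from flash_attn.flash_attention import FlashAttention

num_heads, head_size = 3, 64
prompt_lens = [5, 9, 2]          # three prompts of different lengths
num_tokens = sum(prompt_lens)    # 16 tokens packed along one dimension

# Cumulative prompt lengths; prompt i occupies rows cu_seqlens[i]:cu_seqlens[i + 1].
cu_seqlens = [0]
for prompt_len in prompt_lens:
    cu_seqlens.append(cu_seqlens[-1] + prompt_len)
cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int, device='cuda')

# Packed qkv tensor: [num_tokens, 3, num_heads, head_size], half precision only.
qkv = torch.randn(num_tokens, 3, num_heads, head_size,
                  dtype=torch.half, device='cuda')

flash_attn = FlashAttention(softmax_scale=head_size ** -0.5)
out = flash_attn(qkv, cu_seqlens=cu_seqlens, max_s=max(prompt_lens), causal=True)[0]
print(out.shape)  # torch.Size([16, 3, 64])
```

This is why the diff builds `prefix_sum` with a Python loop and stacks query/key/value into a single tensor before the call: the kernel expects one contiguous packed batch rather than a list of per-prompt tensors, which is also the source of the two FIXME copies noted in the code.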
server.py: 5 additions & 2 deletions
@@ -9,10 +9,12 @@
 parser.add_argument('--model', type=str, default='facebook/opt-125m', help='model name')
 parser.add_argument('--num-nodes', type=int, default=1, help='number of nodes')
 parser.add_argument('--num-workers', type=int, default=1, help='number of workers per node')
-parser.add_argument('--block-size', type=int, default=8, help='token block size')
+parser.add_argument('--block-size', type=int, default=8, choices=[8, 16], help='token block size')
 # TODO(woosuk): Add an analytical model to determine the maximum number of GPU/CPU blocks.
 parser.add_argument('--num-gpu-blocks', type=int, default=1024, help='number of GPU blocks (per GPU)')
-parser.add_argument('--num-cpu-blocks', type=int, default=256, help='number of CPU blocks (per GPU)')
+parser.add_argument('--num-cpu-blocks', type=int, default=32, help='number of CPU blocks (per GPU)')
+# NOTE(woosuk): If FlashAttention is used, the float data type is not supported.
+parser.add_argument('--dtype', type=str, default='half', choices=['half', 'float'], help='data type')
 args = parser.parse_args()
 
 
@@ -27,6 +29,7 @@ def main():
             block_size=args.block_size,
             num_gpu_blocks=args.num_gpu_blocks,
             num_cpu_blocks=args.num_cpu_blocks,
+            dtype=args.dtype,
         )
         controllers.append(controller)
 
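The new `--dtype` flag is validated by argparse itself, so an unsupported value fails before any model is loaded. Below is a standalone sketch of just the two changed options (illustrative only; the conversion from the `'half'`/`'float'` strings to a torch dtype is assumed to happen inside the controller):

```python
# Standalone sketch of the changed CLI options; not the actual server wiring.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--block-size', type=int, default=8, choices=[8, 16], help='token block size')
# With FlashAttention, only 'half' works; 'float' is rejected later in multi_query_kv_attention.
parser.add_argument('--dtype', type=str, default='half', choices=['half', 'float'], help='data type')

args = parser.parse_args(['--block-size', '16'])
print(args.block_size, args.dtype)            # 16 half
# parser.parse_args(['--dtype', 'float32'])   # exits with "invalid choice: 'float32'"
```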
tests/kernels/attention.py: 65 additions & 2 deletions
@@ -1,10 +1,13 @@
 import random
 from typing import Optional
 
+from flash_attn.flash_attention import FlashAttention
 import torch
 
 from cacheflow import attention_ops
 
+MAX_SEQ_LEN = 4096
+
 
 def ref_masked_attention(
     query: torch.Tensor,
@@ -79,7 +82,7 @@ def test_single_query_cached_kv_attention(
     value_cache = torch.randn(
         size=(num_blocks, *value_block_shape), dtype=dtype, device='cuda')
 
-    context_lens = [random.randint(1, 4096) for _ in range(num_tokens)]
+    context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_tokens)]
     max_context_len = max(context_lens)
     context_lens = torch.tensor(context_lens, dtype=torch.int, device='cuda')
 
@@ -123,11 +126,60 @@ def test_single_query_cached_kv_attention(
     assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5)
 
 
+def test_multi_query_kv_attention(
+    num_seqs: int,
+    num_heads: int,
+    head_size: int,
+    dtype: torch.dtype,
+) -> None:
+    seq_lens = random.sample(range(1, MAX_SEQ_LEN), num_seqs)
+    max_seq_len = max(seq_lens)
+    num_tokens = sum(seq_lens)
+
+    cu_seq_lens = [0]
+    for seq_len in seq_lens:
+        cu_seq_lens.append(cu_seq_lens[-1] + seq_len)
+    cu_seq_lens = torch.tensor(cu_seq_lens, dtype=torch.int, device='cuda')
+
+    scale = float(1.0 / (head_size ** 0.5))
+    query = torch.randn(
+        num_tokens, num_heads, head_size, dtype=dtype, device='cuda')
+    key = torch.rand_like(query)
+    value = torch.rand_like(query)
+
+    qkv = torch.stack([query, key, value], dim=1)
+    flash_attn = FlashAttention(softmax_scale=scale)
+    output = flash_attn(
+        qkv,
+        cu_seqlens=cu_seq_lens,
+        max_s=max_seq_len,
+        causal=True,
+    )[0]
+
+    ref_outputs = []
+    for i, seq_len in enumerate(seq_lens):
+        attn_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1) * -1e5
+        attn_mask = attn_mask.to(dtype=dtype, device='cuda')
+        start_idx = cu_seq_lens[i]
+        end_idx = cu_seq_lens[i + 1]
+        ref_output = ref_masked_attention(
+            query[start_idx:end_idx],
+            key[start_idx:end_idx],
+            value[start_idx:end_idx],
+            scale,
+            attn_mask=attn_mask,
+        )
+        ref_outputs.append(ref_output)
+    ref_output = torch.cat(ref_outputs, dim=0)
+
+    assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5)
+
+
 @torch.inference_mode()
 def test_attention() -> None:
     for dtype in [torch.half, torch.float]:
         for block_size in [8, 16]:
-            for head_size in [64, 80, 96, 128, 256]:
+            for head_size in [32, 64, 80, 96, 128, 160, 192, 256]:
                 test_single_query_cached_kv_attention(
                     num_tokens=37,
                     num_heads=3,
@@ -137,6 +189,17 @@ def test_attention() -> None:
                     dtype=dtype,
                 )
 
+    # NOTE(woosuk): FlashAttention does not support FP32.
+    for dtype in [torch.half]:
+        # NOTE(woosuk): FlashAttention does not support head_size > 128.
+        for head_size in [64, 80, 96, 128]:
+            test_multi_query_kv_attention(
+                num_seqs=11,
+                num_heads=3,
+                head_size=head_size,
+                dtype=dtype,
+            )
+
 
 if __name__ == '__main__':
     test_attention()
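Since the test module keeps its `__main__` guard, the new kernel test can be run directly with `python tests/kernels/attention.py` on a CUDA machine with flash-attn installed; the FlashAttention output is compared against `ref_masked_attention` with `atol=1e-3, rtol=1e-5`, the same tolerances already used for the cached-KV kernel test.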