[2/N] Chunked prefill data update #3538

Merged

Changes from 1 commit
127 commits
06fe872
[1/n] Support efficient reshape caching.
rkooo567 Feb 28, 2024
9a0b6be
[2/n] support flash attention kernel
rkooo567 Feb 28, 2024
6947167
oss flash attention works
rkooo567 Feb 28, 2024
4769a26
in progress
rkooo567 Feb 28, 2024
963db44
flash attn enabled.
rkooo567 Feb 29, 2024
2b9c36b
ip
rkooo567 Feb 29, 2024
2c1bb6c
support every model
rkooo567 Feb 29, 2024
2bb5e62
Fixed broken tests.
rkooo567 Feb 29, 2024
4d6a05f
[2/n] scheduler changes
rkooo567 Feb 29, 2024
0831f84
[2/n] ip
rkooo567 Feb 29, 2024
f31371f
[2/n]ip
rkooo567 Feb 29, 2024
78bb887
ip
rkooo567 Feb 29, 2024
b9d93c5
Merge branch 'chunked-prefill-3' into chunked-prefill-scheduler
rkooo567 Feb 29, 2024
42dd362
[2/n] ip
rkooo567 Mar 1, 2024
74ac900
seems to work.
rkooo567 Mar 1, 2024
e3afc25
Merge branch 'chunked-prefill-3' into chunked-prefill-scheduler
rkooo567 Mar 1, 2024
6141885
[2/n] ip
rkooo567 Mar 1, 2024
71bdada
.
rkooo567 Mar 1, 2024
d4c3b5d
ip?
rkooo567 Mar 1, 2024
baef7c6
block tables updated correctly
rkooo567 Mar 1, 2024
d503a22
Merge branch 'chunked-prefill-3' into chunked-prefill-scheduler
rkooo567 Mar 1, 2024
a12ec68
hopefully tests pass
rkooo567 Mar 1, 2024
85760db
Merge branch 'chunked-prefill-3' into chunked-prefill-scheduler
rkooo567 Mar 3, 2024
e40bc45
[2/n] update sequence data
rkooo567 Mar 3, 2024
d85670f
[2/n] add prefill range apis
rkooo567 Mar 3, 2024
0d8785f
Merge branch 'main' into chunked-prefill-3
rkooo567 Mar 3, 2024
08c8541
.
rkooo567 Mar 3, 2024
3bac9af
ip
rkooo567 Mar 3, 2024
0ca1284
add data.
rkooo567 Mar 3, 2024
2487bda
ip
rkooo567 Mar 3, 2024
81151e8
ip
rkooo567 Mar 3, 2024
31aa920
ip
rkooo567 Mar 4, 2024
2049b35
.
rkooo567 Mar 4, 2024
ef679d7
.
rkooo567 Mar 4, 2024
71bda97
.
rkooo567 Mar 4, 2024
4e00e7f
done?
rkooo567 Mar 4, 2024
c5f3a0d
Merge branch 'chunked-prefill-3' into chunked-prefill-scheduler
rkooo567 Mar 4, 2024
7fd70f2
Merge branch 'main' into chunked-prefill-3
rkooo567 Mar 5, 2024
9bbb04e
Merge branch 'chunked-prefill-3' into chunked-prefill-scheduler-data-…
rkooo567 Mar 5, 2024
9177d54
Merge branch 'main' into chunked-prefill-3
rkooo567 Mar 6, 2024
5e47c1e
Merge branch 'chunked-prefill-3' into chunked-prefill-scheduler-data-…
rkooo567 Mar 6, 2024
c0384a4
Refactor 2d query to 1d query
rkooo567 Mar 6, 2024
6032edf
.,
rkooo567 Mar 6, 2024
c1ab0b0
done
rkooo567 Mar 6, 2024
f48dc72
Addressed code review.
rkooo567 Mar 7, 2024
769b2b4
working
rkooo567 Mar 7, 2024
4a20f4a
Merge branch 'main' into 1dquery
rkooo567 Mar 7, 2024
f7347b8
working
rkooo567 Mar 7, 2024
d931725
Merge branch 'main' into 1dquery
rkooo567 Mar 7, 2024
f91d73e
fix lora
rkooo567 Mar 8, 2024
f7d79da
fixed
rkooo567 Mar 8, 2024
851c018
Merge branch 'main' into 1dquery
rkooo567 Mar 8, 2024
406f1d4
fix
rkooo567 Mar 8, 2024
c66ec36
Merge branch '1dquery' into chunked-prefill-scheduler-data-update
rkooo567 Mar 11, 2024
c067a4c
working.
rkooo567 Mar 11, 2024
e1f244a
clean up.
rkooo567 Mar 11, 2024
d09eaf5
.
rkooo567 Mar 11, 2024
4a8ab3c
Merge branch 'main' into chunked-prefill-scheduler-data-update
rkooo567 Mar 11, 2024
a08e65e
Merge branch 'main' into 1dquery
rkooo567 Mar 11, 2024
d9532f8
Merge branch '1dquery' into chunked-prefill-scheduler-data-update
rkooo567 Mar 11, 2024
93a7b90
.
rkooo567 Mar 12, 2024
b4b94c6
Merge branch '1dquery' into chunked-prefill-scheduler-data-update
rkooo567 Mar 12, 2024
647d8cc
.
rkooo567 Mar 12, 2024
65ac6ce
Merge branch '1dquery' into chunked-prefill-scheduler-data-update
rkooo567 Mar 12, 2024
b2f4b3e
ip
rkooo567 Mar 12, 2024
cc8419f
.
rkooo567 Mar 12, 2024
76e7ca8
Merge branch '1dquery' into chunked-prefill-scheduler-data-update
rkooo567 Mar 12, 2024
d3d0336
Merge branch 'main' into 1dquery
rkooo567 Mar 15, 2024
11ec167
Merge branch '1dquery' into chunked-prefill-scheduler-data-update
rkooo567 Mar 15, 2024
3cb8093
ip addressing comments.
rkooo567 Mar 16, 2024
5391129
Alibi slopes working now.
rkooo567 Mar 18, 2024
6b04443
Merge branch 'main' into 1dquery
rkooo567 Mar 18, 2024
fe344f6
add new fields
rkooo567 Mar 18, 2024
e619c4e
Flash attn works now
rkooo567 Mar 18, 2024
9c86aa3
Linting
rkooo567 Mar 18, 2024
5b4aa09
temporary
rkooo567 Mar 18, 2024
03dd155
Merge branch '1dquery' into chunked-prefill-scheduler-data-update
rkooo567 Mar 18, 2024
4cced78
fix tests
rkooo567 Mar 18, 2024
cdb7a2c
Fixed
rkooo567 Mar 18, 2024
276be06
Merge branch '1dquery' into chunked-prefill-scheduler-data-update
rkooo567 Mar 18, 2024
d87b651
Pass unit tests.
rkooo567 Mar 18, 2024
2c18896
experiment
rkooo567 Mar 18, 2024
b46f902
.
rkooo567 Mar 18, 2024
07b22f8
.
rkooo567 Mar 18, 2024
9bd7ea1
.
rkooo567 Mar 18, 2024
c55402f
trial
rkooo567 Mar 18, 2024
a13cf7e
remove --fork
rkooo567 Mar 18, 2024
c5c5581
Merge branch 'main' into 1dquery
rkooo567 Mar 18, 2024
ec91304
fixed
rkooo567 Mar 19, 2024
4977e53
Merge branch '1dquery' into chunked-prefill-scheduler-data-update
rkooo567 Mar 19, 2024
4a54688
Merge branch 'main' into 1dquery
rkooo567 Mar 19, 2024
2e6e919
Addressed code review.
rkooo567 Mar 19, 2024
1f6f6b0
Merge branch 'main' into 1dquery
rkooo567 Mar 19, 2024
ac7828c
revert removing forked
rkooo567 Mar 19, 2024
3d7f1a1
done
rkooo567 Mar 19, 2024
bcdd74a
Merge branch 'main' into 1dquery
rkooo567 Mar 20, 2024
fa3ce4e
final code review.
rkooo567 Mar 20, 2024
a83b235
Merge branch '1dquery' into chunked-prefill-scheduler-data-update
rkooo567 Mar 20, 2024
7205ef9
Merge branch 'main' into chunked-prefill-scheduler-data-update
rkooo567 Mar 21, 2024
8bc0af5
.
rkooo567 Mar 21, 2024
97bcb6f
ip
rkooo567 Mar 21, 2024
df34350
working except tests.
rkooo567 Mar 21, 2024
e70e03d
.
rkooo567 Mar 21, 2024
f89f428
ip
rkooo567 Mar 21, 2024
bf02f8e
done
rkooo567 Mar 21, 2024
ad43095
done
rkooo567 Mar 21, 2024
16b6196
Addressed code review.
rkooo567 Mar 22, 2024
916abc8
merge conflict fixed
rkooo567 Mar 25, 2024
5002e61
update
rkooo567 Mar 25, 2024
80f51ea
test fix
rkooo567 Mar 25, 2024
3cc5e99
Merge branch 'main' into chunked-prefill-scheduler-data-update
rkooo567 Mar 25, 2024
fa7ba35
lint
rkooo567 Mar 25, 2024
51cf7f2
fix broken tests.
rkooo567 Mar 25, 2024
cdee1c6
.
rkooo567 Mar 26, 2024
16e3a7d
done
rkooo567 Mar 26, 2024
e0d301c
remove num chunked prefill from seq group metadata
rkooo567 Mar 27, 2024
5e0f87e
change apis
rkooo567 Mar 27, 2024
6e72648
cleaned
rkooo567 Mar 27, 2024
4f869be
now working
rkooo567 Mar 27, 2024
4f63c57
update with new apis
rkooo567 Mar 27, 2024
5c3abf4
working!
rkooo567 Mar 27, 2024
66f3fcf
fixed
rkooo567 Mar 27, 2024
9c12d8e
Merge branch 'main' into chunked-prefill-scheduler-data-update
rkooo567 Mar 27, 2024
9d4b65c
Addressed code review.
rkooo567 Mar 28, 2024
54a58b2
Merge branch 'main' into chunked-prefill-scheduler-data-update
rkooo567 Mar 28, 2024
9bdb9dc
fix tests.
rkooo567 Mar 28, 2024
88126a9
fixed a bug
rkooo567 Mar 28, 2024
oss flash attention works
rkooo567 committed Feb 28, 2024
commit 6947167ea42f592df37ead195dbfb5c6906609bd
250 changes: 18 additions & 232 deletions tests/kernels/test_flash_attention.py
@@ -1,16 +1,12 @@
import random
from typing import List, Optional, Tuple
from typing import Optional, Tuple

import pytest
import torch
import torch.nn.functional as F
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask

from vllm.model_executor.layers.attention import (
flash_single_query_cached_kv_attention,
flash_multi_query_cached_kv_attention_varlen,
)
flash_attn_with_kvcache_paged, )
from vllm.utils import get_max_shared_memory_bytes

FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
@@ -25,11 +21,16 @@
NUM_PREFILL_SEQS = [3, 6, 17] # Arbitrary values for testing
NUM_HEADS = [(40, 40), (64, 8)] # Arbitrary values for testing
NUM_HEADS_SMALL = NUM_HEADS
HEAD_SIZES = [64, 80, 96, 112, 128, 256]
BLOCK_SIZES = [32]
# head size should be bigger than or equal to block size.
HEAD_SIZES = [256]
# TODO(sang): https://github.com/Dao-AILab/flash-attention/pull/824
# should fix the block size. But right now, the block size should be
# divisible by 256.
BLOCK_SIZES = [256]
USE_ALIBI = [False, True]
SEEDS = [0]
PAD_CONFIGS = [(0, 0), (8, MAX_SEQ_LEN - 1000), (16, MAX_SEQ_LEN - 2000)]
# PAD_CONFIGS = [(0, 0), (8, MAX_SEQ_LEN - 1000), (16, MAX_SEQ_LEN - 2000)]
PAD_CONFIGS = [(0, 0)]


def pad_attention_inputs(
@@ -137,22 +138,14 @@ def ref_single_query_cached_kv_attention(
output[i].copy_(out, non_blocking=True)


# @pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS)
# @pytest.mark.parametrize("num_heads", NUM_HEADS)
# @pytest.mark.parametrize("head_size", HEAD_SIZES)
# @pytest.mark.parametrize("use_alibi", [False])
# @pytest.mark.parametrize("block_size", [32])
# @pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16])
# @pytest.mark.parametrize("seed", SEEDS)
# @pytest.mark.parametrize("pad_config", PAD_CONFIGS)
@pytest.mark.parametrize("num_seqs", [3])
@pytest.mark.parametrize("num_heads", [(40, 40), (64, 8)])
@pytest.mark.parametrize("head_size", [80, 96])
@pytest.mark.parametrize("use_alibi", [False])
@pytest.mark.parametrize("block_size", [32])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("use_alibi", [False, True])
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16])
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("pad_config", [(0, 0)])
@pytest.mark.parametrize("pad_config", PAD_CONFIGS)
@torch.inference_mode()
def test_flash_paged_attention(
kv_cache_factory,
@@ -180,9 +173,6 @@ def test_flash_paged_attention(

assert num_query_heads % num_kv_heads == 0
num_queries_per_kv = num_query_heads // num_kv_heads
head_mapping = torch.repeat_interleave(
torch.arange(num_kv_heads, dtype=torch.int32, device="cuda"),
num_queries_per_kv)
alibi_slopes = None
if use_alibi:
alibi_slopes = torch.randn(num_query_heads,
@@ -218,15 +208,13 @@
key_cache, value_cache = key_caches[0], value_caches[0]

# Call the paged attention kernel.
num_valid_tokens = torch.cuda.IntTensor([num_seqs])
output = torch.empty_like(query)

padded_query, padded_block_table, padded_context_lens, pad_max_context_len = \
pad_attention_inputs(pad_config, block_size, query,
block_tables, context_lens, max_context_len)

flash_single_query_cached_kv_attention(
output,
output = flash_attn_with_kvcache_paged(
padded_query,
key_cache,
value_cache,
@@ -256,205 +244,3 @@
# implementations, there is a small numerical difference in the two
# outputs. Thus, we use a relaxed tolerance for the test.
assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5)


def ref_multi_query_kv_attention_padded(
query: torch.Tensor,
num_queries_per_kv: int,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
block_tables: torch.Tensor,
cu_seq_lens: List[int],
context_lens: List[int],
scale: float,
dtype: torch.dtype,
) -> torch.Tensor:
num_seqs = len(cu_seq_lens) - 1
block_size = value_cache.shape[-3]
ref_outputs = []

for i in range(num_seqs):
q_start_idx = cu_seq_lens[i]
q_end_idx = cu_seq_lens[i + 1]
seq_len = q_end_idx - q_start_idx

context_len = context_lens[i]

block_table = block_tables[i]
keys = []
values = []

for j in range(context_len):
block_number = int(block_table[j // block_size])
block_offset = j % block_size

k = key_cache[block_number, block_offset, :, :]
keys.append(k)

v = value_cache[block_number, block_offset, :, :]
values.append(v)

keys = torch.stack(keys, dim=0)
values = torch.stack(values, dim=0)

if num_queries_per_kv > 1:
# Handle MQA and GQA
keys = torch.repeat_interleave(keys, num_queries_per_kv, dim=1)
values = torch.repeat_interleave(values, num_queries_per_kv, dim=1)

q = query[q_start_idx:q_end_idx, :, :]
k = keys[:context_len, :, :]
v = values[:context_len, :, :]

assert seq_len <= context_len

# pad q if seq_len is less than context_len
# this is for correct calculation of attention.
if seq_len < context_len:
indices = [i % seq_len for i in range(context_len - seq_len)]
q_left_pad = q[indices, :, :]
q = torch.cat([q_left_pad, q], dim=0)

# Create attention mask.
attn_mask = torch.triu(torch.ones(context_len,
context_len,
dtype=dtype),
diagonal=1)
attn_mask = attn_mask * torch.finfo(dtype).min
attn_mask = attn_mask.to(dtype=dtype, device="cuda")

ref_output = ref_masked_attention(
q,
k,
v,
scale,
attn_mask=attn_mask,
)
ref_outputs.append(ref_output[-seq_len:, :, :])
ref_output = torch.cat(ref_outputs, dim=0)
return ref_output


def is_a100():
return torch.cuda.get_device_name().find("NVIDIA A100") >= 0


if not is_a100():
NUM_HEADS_SMALL = [(16, 16), (16, 8)]
MAX_SEQ_LEN_SMALL = max(MAX_SEQ_LEN // 4, 8192)

NUM_BLOCKS = 1024
BLOCK_SIZE = 32


@pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS)
@pytest.mark.parametrize("num_heads", NUM_HEADS_SMALL)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("version", ["flash"])
@pytest.mark.parametrize("seed", SEEDS)
@torch.inference_mode()
def test_multi_query_kv_attention(
num_seqs: int,
num_heads: Tuple[int, int],
head_size: int,
dtype: torch.dtype,
version: str,
seed: int,
) -> None:
random.seed(seed)
torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed)

# MAX_SEQ_LEN sometimes causes OOM in the reference implementation.
# As the xformers library is already tested with its own tests, we can use
# a smaller MAX_SEQ_LEN here.
max_len = min(MAX_SEQ_LEN, 4096)

seq_lens = [random.randint(1, max_len / 2) for i in range(num_seqs)]
max_seq_len = max(seq_lens)
seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int, device="cuda")

context_lens = random.sample(range(max_seq_len, max_len), num_seqs)
max_context_len = max(context_lens)
context_lens_tensor = torch.tensor(context_lens,
dtype=torch.int,
device="cuda")

num_tokens = sum(seq_lens)
cu_seq_lens = [0]
for seq_len in seq_lens:
cu_seq_lens.append(cu_seq_lens[-1] + seq_len)

cu_context_lens = [0]
for context_len in context_lens:
cu_context_lens.append(cu_context_lens[-1] + context_len)

scale = float(1.0 / (head_size**0.5))
num_query_heads, num_kv_heads = num_heads
num_queries_per_kv = num_query_heads // num_kv_heads

value_cache = torch.empty(NUM_BLOCKS,
BLOCK_SIZE,
num_kv_heads,
head_size,
dtype=dtype,
device="cuda")
key_cache = torch.empty(NUM_BLOCKS,
BLOCK_SIZE,
num_kv_heads,
head_size,
dtype=dtype,
device="cuda")
query = torch.empty(num_tokens,
num_query_heads,
head_size,
dtype=dtype,
device="cuda")
value_cache.uniform_(-scale, scale)
key_cache.uniform_(-scale, scale)
query.uniform_(-scale, scale)

# Create the block tables.
max_num_blocks_per_seq = (max_context_len + BLOCK_SIZE - 1) // BLOCK_SIZE
block_tables = []
for _ in range(num_seqs):
block_table = [
random.randint(0, NUM_BLOCKS - 1)
for _ in range(max_num_blocks_per_seq)
]
block_tables.append(block_table)
block_tables = torch.tensor(block_tables, dtype=torch.int, device="cuda")

output = torch.empty_like(query)

if version == "flash":
flash_multi_query_cached_kv_attention_varlen(
output,
query,
key_cache,
value_cache,
scale,
block_tables,
torch.cuda.IntTensor(cu_seq_lens),
torch.cuda.IntTensor(cu_context_lens),
BLOCK_SIZE,
max_seq_len,
max_context_len,
None,
)
else:
assert False, f"{version=} is not supported"

ref_output = ref_multi_query_kv_attention_padded(
query,
num_queries_per_kv,
key_cache,
value_cache,
block_tables,
cu_seq_lens,
context_lens,
scale,
dtype,
)
assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5)
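
For context on the paged-cache bookkeeping exercised above: the reference implementation ref_multi_query_kv_attention_padded (removed in this commit) rebuilds each sequence's keys and values by walking its block table one token position at a time before running ordinary masked attention. The snippet below is a minimal standalone sketch of that gather step, not part of the diff; the helper name gather_paged_kv and the toy shapes are illustrative, and the assumed cache layout [num_blocks, block_size, num_kv_heads, head_size] follows the tensors allocated in the test.

import torch


def gather_paged_kv(key_cache: torch.Tensor,
                    value_cache: torch.Tensor,
                    block_table: torch.Tensor,
                    context_len: int):
    """Collect the first `context_len` key/value vectors of one sequence
    from a block-paged KV cache using its block table."""
    block_size = key_cache.shape[1]
    keys, values = [], []
    for pos in range(context_len):
        # The block table maps a logical block index to a physical block id.
        block_number = int(block_table[pos // block_size])
        block_offset = pos % block_size
        keys.append(key_cache[block_number, block_offset])    # [num_kv_heads, head_size]
        values.append(value_cache[block_number, block_offset])
    return torch.stack(keys, dim=0), torch.stack(values, dim=0)


if __name__ == "__main__":
    num_blocks, block_size, num_kv_heads, head_size = 8, 4, 2, 16
    key_cache = torch.randn(num_blocks, block_size, num_kv_heads, head_size)
    value_cache = torch.randn(num_blocks, block_size, num_kv_heads, head_size)
    # Four physical blocks cover up to 16 token positions at block_size=4.
    block_table = torch.randint(0, num_blocks, (4,))
    k, v = gather_paged_kv(key_cache, value_cache, block_table, context_len=10)
    print(k.shape, v.shape)  # torch.Size([10, 2, 16]) for both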