[Neuron][Kernel] Support Longer Sequences in NKI-based Flash PagedAttention and Improve Efficiency #12921

Merged 7 commits on Feb 12, 2025

118 changes: 67 additions & 51 deletions tests/neuron/test_prefix_prefill.py
@@ -1,6 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0

-import random
 from typing import Optional

 import pytest
@@ -171,19 +170,31 @@ def ref_context_attention(
     return output


+@pytest.mark.parametrize(
+    "block_size, large_tile_size",
+    [
+        (32, 2048),  # 64 blocks
+        (32, 4096),  # 128 blocks
+        (32, 8192),  # 256 blocks
+        (64, 8192),  # 128 blocks
+    ],
+)
 @pytest.mark.parametrize(
     "num_heads,num_queries_per_kv,head_size,mixed_precision",
     [
         (4, 2, 8, False),
         (4, 2, 8, True),
         (32, 8, 64, True),
         (16, 2, 128, True),
     ],
 )
 @torch.inference_mode()
 def test_contexted_kv_attention(
     num_heads: int,
     num_queries_per_kv: int,
     head_size: int,
+    block_size: int,
+    large_tile_size,
     mixed_precision: bool,
 ) -> None:
     import os
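The new `(block_size, large_tile_size)` axis controls how many KV cache blocks make up one large tile. As a quick orientation (a standalone sketch reusing the parametrized values above, not part of the diff), the pairs are chosen so that each large tile is a whole number of cache blocks, which the test also checks with the added `assert large_tile_size % block_size == 0`:

```python
# Standalone sketch: blocks per large KV tile for each parametrized pair.
for block_size, large_tile_size in [(32, 2048), (32, 4096), (32, 8192), (64, 8192)]:
    assert large_tile_size % block_size == 0          # whole number of blocks per tile
    blocks_per_tile = large_tile_size // block_size   # 64, 128, 256, 128
    print(f"{block_size=} {large_tile_size=} {blocks_per_tile=}")
```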
@@ -192,40 +203,46 @@ def test_contexted_kv_attention(

     from vllm.attention.ops.nki_flash_attn import flash_attn_varlen_nkifunc

+    assert large_tile_size % block_size == 0
+
     device = xm.xla_device()

-    os.environ["NEURON_CC_FLAGS"] = (
-        " --model-type=transformer -O1 "
-        " --internal-hlo2tensorizer-options='--verify-hlo' ")
+    compiler_flags = [
+        "--model-type=transformer -O1",
+        "--internal-hlo2tensorizer-options='--verify-hlo'",
+        "--retry_failed_compilation",
+    ]
+    compiler_flags_str = " ".join(compiler_flags)
+    os.environ["NEURON_CC_FLAGS"] = compiler_flags_str

-    random.seed(0)
     torch.manual_seed(0)
     torch.set_printoptions(sci_mode=False)

-    min_ctx_len = 2
-    max_ctx_len = 64
-    min_query_len = 2
-    max_query_len = 64
-    prefill_batch_size = 2
-    decode_batch_size = 6
+    min_ctx_len = 32
+    max_ctx_len = 1024
+    min_query_len = 16
+    max_query_len = 512
+    prefill_batch_size = 4
+    decode_batch_size = 12
     batch_size = prefill_batch_size + decode_batch_size
-    block_size = 32
     max_model_len = (max_query_len + max_ctx_len) * 4

     max_block_per_request = max_model_len // block_size
     dtype = torch.float32
     cache_size = (batch_size * max_block_per_request) + 2
-    ctx_lens = [
-        random.randint(min_ctx_len, max_ctx_len)
-        for _ in range(prefill_batch_size)
-    ] + [
-        random.randint(min_ctx_len, max_ctx_len)
-        for _ in range(decode_batch_size)
-    ]
-    query_lens = [
-        random.randint(min_query_len, max_query_len)
-        for _ in range(prefill_batch_size)
-    ] + [1 for _ in range(decode_batch_size)]
+    prefill_ctx_lens = torch.randint(min_ctx_len,
+                                     max_ctx_len + 1, (prefill_batch_size, ),
+                                     dtype=torch.long).tolist()
+    decode_ctx_lens = torch.randint(min_ctx_len,
+                                    max_ctx_len + 1, (decode_batch_size, ),
+                                    dtype=torch.long).tolist()
+    ctx_lens = prefill_ctx_lens + decode_ctx_lens
+    query_lens = torch.randint(
+        min_query_len,
+        max_query_len + 1,
+        (prefill_batch_size, ),
+        dtype=torch.long,
+    ).tolist() + [1 for _ in range(decode_batch_size)]
     seq_lens = [a + b for a, b in zip(query_lens, ctx_lens)]
     num_kv_heads = num_heads // num_queries_per_kv

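A side note on the switch from Python `random` to `torch.randint` for sampling the context and query lengths (illustrative sketch below, not part of the diff): once `torch.manual_seed(0)` is set, the sampled lengths are reproducible, so the test's randomness is controlled from a single seed.

```python
import torch

# Illustrative only: torch.randint length sampling is reproducible under torch.manual_seed.
torch.manual_seed(0)
lens_a = torch.randint(32, 1024 + 1, (4, ), dtype=torch.long).tolist()

torch.manual_seed(0)
lens_b = torch.randint(32, 1024 + 1, (4, ), dtype=torch.long).tolist()

assert lens_a == lens_b
print(lens_a)
```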
@@ -254,7 +271,6 @@ def test_contexted_kv_attention(
     values = values[torch.randperm(cache_size)]
     block_table = values[:batch_size * max_block_per_request].view(
         batch_size, max_block_per_request)
-    torch.tensor(seq_lens, dtype=torch.long)
     b_ctx_len = torch.tensor(ctx_lens, dtype=torch.long)
     b_start_loc = torch.cumsum(torch.tensor([0] + query_lens[:-1],
                                             dtype=torch.long),
@@ -311,9 +327,7 @@ def test_contexted_kv_attention(
     # build neuron program
     return_debug_tensors = False
     B_P_SIZE = 128
-    LARGE_TILE_SZ = 2048
-    max_num_queries = (
-        (sum(query_lens) + block_size - 1) // block_size) * block_size
+    LARGE_TILE_SZ = large_tile_size

     def get_active_block_tables(block_tables, query_lens, seq_lens, block_size,
                                 num_blocks):
@@ -332,26 +346,28 @@ def get_active_block_tables(block_tables, query_lens, seq_lens, block_size,
             0,
         )

-    def shift_bit_length(x):
-        return 1 << (x - 1).bit_length()
+    def ceil_div(a, b):
+        return (a + b - 1) // b

+    def pad_to_multiple(a, b):
+        return ceil_div(a, b) * b
+
+    def pad_to_next_power_of_2(a):
+        assert a > 0
+        return 2**int(a - 1).bit_length()

     # calculate input shapes
-    max_num_queries_shifted = shift_bit_length(max_num_queries)
-    max_num_queries_factor = B_P_SIZE // max_num_queries_shifted
-    max_num_queries_padded = max_num_queries_shifted * max_num_queries_factor
-    assert (max_num_queries_padded == B_P_SIZE
-            ), "invalid {max_num_queries_padded=}"
+    max_num_queries = pad_to_multiple(sum(query_lens), block_size)
+    max_num_queries = pad_to_next_power_of_2(max_num_queries)
+    head_size_padded = B_P_SIZE
+    assert head_size_padded >= head_size
     context_lens = torch.tensor(seq_lens) - torch.tensor(query_lens)
-    num_active_blocks_shifted = shift_bit_length(
-        ((context_lens + block_size - 1) // block_size).sum().item())
-    num_active_blocks_factor = (LARGE_TILE_SZ // block_size //
-                                num_active_blocks_shifted)
-    num_active_blocks = num_active_blocks_shifted * num_active_blocks_factor
-    assert (num_active_blocks *
-            block_size) == LARGE_TILE_SZ, "invalid {num_active_blocks=}"
+    num_active_blocks = ceil_div(context_lens, block_size).sum().item()
+    num_active_blocks = pad_to_multiple(num_active_blocks,
+                                        LARGE_TILE_SZ // block_size)
     context_kv_len = num_active_blocks * block_size
-    assert context_kv_len == LARGE_TILE_SZ, f"invalid {context_kv_len=}"
+    assert (context_kv_len %
+            LARGE_TILE_SZ == 0), f"invalid context_kv_len={context_kv_len}"

     # pad QKV tensors
     pad_dims = (
@@ -360,7 +376,7 @@ def shift_bit_length(x):
         0,
         0,
         0,
-        max_num_queries_padded - query.shape[0],
+        max_num_queries - query.shape[0],
     )
     query = F.pad(query, pad_dims, "constant", 0)
     k = F.pad(k, pad_dims, "constant", 0)
@@ -397,7 +413,7 @@ def shift_bit_length(x):
             0,
             context_kv_len - prior_mask.shape[1],
             0,
-            B_P_SIZE - prior_mask.shape[0],
+            max_num_queries - prior_mask.shape[0],
         ),
         "constant",
         0,
@@ -406,9 +422,9 @@ def shift_bit_length(x):
         active_mask,
         (
             0,
-            B_P_SIZE - active_mask.shape[1],
+            max_num_queries - active_mask.shape[1],
             0,
-            B_P_SIZE - active_mask.shape[0],
+            max_num_queries - active_mask.shape[0],
         ),
         "constant",
         0,
@@ -430,6 +446,8 @@ def shift_bit_length(x):
         n_kv_head=num_kv_heads,
         head_size=head_size,
         mixed_precision=mixed_precision,
+        LARGE_TILE_SZ=LARGE_TILE_SZ,
+        return_debug_tensors=return_debug_tensors,
     )

     if return_debug_tensors:
@@ -439,17 +457,15 @@ def shift_bit_length(x):
         output_nki = flash_attn_varlen_nkifunc(*input_args, **input_kwargs)
         debug_tensors = []

-    output_nki = torch.tensor(output_nki).cpu()
-    debug_tensors = [torch.tensor(dt).cpu() for dt in debug_tensors]

     num_actual_tokens = sum(query_lens)
     print(f"{num_actual_tokens=}")
     # - o: shape (bs, n_heads, seq_q, d) -> (bs, seq_q, n_heads, d)
-    output_nki = output_nki.permute(
-        0, 2, 1, 3)[:, :, :, :head_size].cpu()[0, :num_actual_tokens, :, :]
+    output_nki = output_nki.cpu().permute(0, 2, 1, 3)[:, :, :, :head_size]
+    output_nki = output_nki[0, :num_actual_tokens, :, :]
     output_ref_padded = F.pad(
         output_ref,
-        (0, 0, 0, 0, 0, 0, 0, max_num_queries_padded - output_ref.shape[0]),
+        (0, 0, 0, 0, 0, 0, 0, max_num_queries - output_ref.shape[0]),
         "constant",
         0,
     )
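To see why the reworked sizing math supports longer sequences, here is a standalone sketch (helper definitions copied from the diff; the query/context lengths and the `(32, 2048)` configuration are made up for illustration, and plain ints replace tensors). The padded context KV length now only has to be a multiple of `LARGE_TILE_SZ`, whereas the old `shift_bit_length`-based code asserted it equaled exactly one tile:

```python
# Standalone sketch of the new tile-sizing arithmetic.
def ceil_div(a, b):
    return (a + b - 1) // b

def pad_to_multiple(a, b):
    return ceil_div(a, b) * b

def pad_to_next_power_of_2(a):
    assert a > 0
    return 2**int(a - 1).bit_length()

block_size, LARGE_TILE_SZ = 32, 2048      # one of the parametrized configurations
query_lens = [300, 17, 1, 1]              # hypothetical per-sequence query lengths
context_lens = [900, 1500, 640, 77]       # hypothetical per-sequence context lengths

# Queries are padded to a block multiple, then up to the next power of two.
max_num_queries = pad_to_next_power_of_2(pad_to_multiple(sum(query_lens), block_size))

# Context blocks are padded so the KV length is a whole number of large tiles.
num_active_blocks = sum(ceil_div(c, block_size) for c in context_lens)
num_active_blocks = pad_to_multiple(num_active_blocks, LARGE_TILE_SZ // block_size)
context_kv_len = num_active_blocks * block_size

assert context_kv_len % LARGE_TILE_SZ == 0
print(max_num_queries, num_active_blocks, context_kv_len)  # 512 128 4096
```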