
Commit c63ec8b

NickLucche authored and shreyankg committed
[Bugfix][CI] ALiBi test case in xformers multi_query_kv_attention (vllm-project#11301)
1 parent d889367 commit c63ec8b

File tree: 3 files changed (+83 -22 lines)


tests/kernels/test_attention.py

Lines changed: 78 additions & 17 deletions
@@ -17,6 +17,8 @@
 from xformers import ops as xops
 from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask
 
+from vllm.attention.backends.xformers import _make_alibi_bias
+
 FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
 # This will change depending on the compute capability.
 # - 512 as a buffer
@@ -345,20 +347,26 @@ def ref_multi_query_kv_attention(
     key: torch.Tensor,
     value: torch.Tensor,
     scale: float,
+    alibi_bias: Optional[list[torch.Tensor]],
     dtype: torch.dtype,
 ) -> torch.Tensor:
     num_seqs = len(cu_seq_lens) - 1
     ref_outputs: list[torch.Tensor] = []
+    if alibi_bias:
+        assert len(alibi_bias) == num_seqs
     for i in range(num_seqs):
         start_idx = cu_seq_lens[i]
         end_idx = cu_seq_lens[i + 1]
         seq_len = end_idx - start_idx
 
-        # Create attention mask.
-        attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=dtype),
-                               diagonal=1)
-        attn_mask = attn_mask * torch.finfo(dtype).min
-        attn_mask = attn_mask.to(dtype=dtype)
+        # Create attention mask. ALiBi already includes a tril causal mask.
+        if alibi_bias:
+            attn_mask = alibi_bias[i]
+        else:
+            attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=dtype),
+                                   diagonal=1)
+            attn_mask = attn_mask * torch.finfo(dtype).min
+            attn_mask = attn_mask.to(dtype=dtype)
 
         ref_output = ref_masked_attention(
             query[start_idx:end_idx],
@@ -372,7 +380,6 @@ def ref_multi_query_kv_attention(
     return torch.cat(ref_outputs, dim=0)
 
 
-# TODO(woosuk): Add tests for USE_ALIBI=True.
 @pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS)
 @pytest.mark.parametrize("num_heads", NUM_HEADS)
 @pytest.mark.parametrize("head_size", HEAD_SIZES)
@@ -389,6 +396,7 @@ def test_multi_query_kv_attention(
     dtype: torch.dtype,
     seed: int,
     device: str,
+    use_alibi: bool = False,
 ) -> None:
     current_platform.seed_everything(seed)
     torch.set_default_device(device)
@@ -414,16 +422,40 @@ def test_multi_query_kv_attention(
     # Handle MQA and GQA
     key = torch.repeat_interleave(key, num_queries_per_kv, dim=1)
     value = torch.repeat_interleave(value, num_queries_per_kv, dim=1)
-    attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_lens)
-    output = xops.memory_efficient_attention_forward(
-        query.unsqueeze(0),
-        key.unsqueeze(0),
-        value.unsqueeze(0),
-        attn_bias=attn_bias,
-        p=0.0,
-        scale=scale,
-    )
-    output = output.squeeze(0)
+    alibi_bias = None
+    if use_alibi:
+        alibi_slopes = torch.randn(num_query_heads, dtype=torch.float)
+        attn_bias = _make_alibi_bias(alibi_slopes, num_kv_heads, dtype,
+                                     seq_lens)
+        output = torch.empty_like(query)
+        start = 0
+        # Dynamic sequence length not supported with custom attn_bias.
+        for i, seq_len in enumerate(seq_lens):
+            end = start + seq_len
+            out = xops.memory_efficient_attention_forward(
+                query[None, start:end],
+                key[None, start:end],
+                value[None, start:end],
+                attn_bias=attn_bias[i],
+                p=0.0,
+                scale=scale)
+            output[start:end].copy_(out.view_as(query[start:end]))
+            start += seq_len
+        # xformers.AttentionBias to Tensor for use in reference impl.
+        alibi_bias = [
+            b.materialize(b.shape, device=device).squeeze() for b in attn_bias
+        ]
+    else:
+        attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_lens)
+        output = xops.memory_efficient_attention_forward(
+            query.unsqueeze(0),
+            key.unsqueeze(0),
+            value.unsqueeze(0),
+            attn_bias=attn_bias,
+            p=0.0,
+            scale=scale,
+        )
+        output = output.squeeze(0)
 
     cu_seq_lens = [0]
     for seq_len in seq_lens:
@@ -434,8 +466,37 @@ def test_multi_query_kv_attention(
         key,
         value,
         scale,
+        alibi_bias,
         dtype,
     )
     atol = get_default_atol(output) if current_platform.is_rocm() else 1e-3
     rtol = get_default_rtol(output) if current_platform.is_rocm() else 1e-5
-    torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol)
+    torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol)
+
+
+@pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS)
+@pytest.mark.parametrize("num_heads", NUM_HEADS)
+@pytest.mark.parametrize("head_size", [64])
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("seed", SEEDS)
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+@pytest.mark.skipif(current_platform.is_rocm(),
+                    reason="Xformers backend is not supported on ROCm.")
+@torch.inference_mode()
+def test_multi_query_kv_attention_with_alibi(
+    num_seqs: int,
+    num_heads: tuple[int, int],
+    head_size: int,
+    dtype: torch.dtype,
+    seed: int,
+    device: str,
+) -> None:
+    return test_multi_query_kv_attention(
+        num_seqs,
+        num_heads,
+        head_size,
+        dtype,
+        seed,
+        device,
+        use_alibi=True,
+    )
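
Note on the reference path above: when alibi_bias is given, each per-sequence bias is a dense tensor in which the causal mask is already folded in, so the plain triu mask is skipped. Below is a minimal, self-contained sketch of that kind of bias. The helper name make_alibi_causal_bias and the shapes are illustrative only; this is not vLLM's _make_alibi_bias.

    import torch


    def make_alibi_causal_bias(alibi_slopes: torch.Tensor, seq_len: int,
                               dtype: torch.dtype) -> torch.Tensor:
        # Relative positions (j - i), shape [seq_len, seq_len].
        pos = torch.arange(seq_len, dtype=dtype)
        rel = pos[None, :] - pos[:, None]
        # Scale by per-head slopes -> [num_heads, seq_len, seq_len].
        bias = alibi_slopes.to(dtype)[:, None, None] * rel[None, :, :]
        # Fold in the causal mask so future positions are already disallowed.
        causal = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool),
                            diagonal=1)
        return bias.masked_fill(causal, torch.finfo(dtype).min)


    # Example: 4 heads, sequence length 3 -> bias of shape [4, 3, 3].
    slopes = torch.tensor([1.0, 0.5, 0.25, 0.125])
    print(make_alibi_causal_bias(slopes, 3, torch.float32).shape)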

tests/kernels/test_prefix_prefill.py

Lines changed: 5 additions & 3 deletions
@@ -439,14 +439,16 @@ def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
     # heads.
     #
     # see also: vllm/model_executor/layers/attention.py
-    query = query.view(query.shape[0], num_kv_heads, num_queries_per_kv,
-                       query.shape[-1])
     key = key[:, :, None, :].expand(key.shape[0], num_kv_heads,
                                     num_queries_per_kv, key.shape[-1])
     value = value[:, :,
                   None, :].expand(value.shape[0], num_kv_heads,
                                   num_queries_per_kv, value.shape[-1])
-
+    # [seq, num_kv_heads, num_queries_per_kv, dk]=>
+    # [seq, num_kv_heads*num_queries_per_kv, dk] to comply with rest of the
+    # codebase. We save some time reshaping alibi matrix at runtime.
+    key = key.reshape(key.shape[0], -1, key.shape[-1])
+    value = value.reshape(value.shape[0], -1, value.shape[-1])
     query = query.unsqueeze(0)
     key = key.unsqueeze(0)
     value = value.unsqueeze(0)
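
A quick toy illustration of the expand-then-reshape above, with made-up shapes rather than the test's actual configuration: each KV head is replicated num_queries_per_kv times so key/value line up with the query heads, then the two head dimensions are flattened back into one.

    import torch

    # Illustrative shapes only.
    seq, num_kv_heads, num_queries_per_kv, dk = 5, 2, 4, 64
    key = torch.randn(seq, num_kv_heads, dk)

    # Replicate each KV head for its group of query heads.
    key = key[:, :, None, :].expand(seq, num_kv_heads, num_queries_per_kv, dk)
    # [seq, num_kv_heads, num_queries_per_kv, dk] =>
    # [seq, num_kv_heads * num_queries_per_kv, dk]; reshape copies here
    # because the expanded tensor is not contiguous.
    key = key.reshape(seq, -1, dk)
    assert key.shape == (seq, num_kv_heads * num_queries_per_kv, dk)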

vllm/attention/backends/xformers.py

Lines changed: 0 additions & 2 deletions
@@ -788,8 +788,6 @@ def _make_alibi_bias(
             dtype=dtype,
         )[:, :, :, :seq_len].copy_(bias)
         bias.mul_(alibi_slopes[:, None, None])
-        if num_heads != num_kv_heads:
-            bias = bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads))
         attn_biases.append(LowerTriangularMaskWithTensorBias(bias))
 
     return attn_biases
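
For context on what the deleted lines did, the sketch below uses purely illustrative head counts, not tied to any particular model: the removed unflatten regrouped the per-head bias into a [num_kv_heads, num_queries_per_kv] split, which the bias now keeps flat.

    import torch

    num_heads, num_kv_heads, seq_len = 8, 2, 4
    bias = torch.zeros(1, num_heads, seq_len, seq_len)

    # The removed step grouped query heads under their KV head:
    grouped = bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads))
    print(bias.shape)     # torch.Size([1, 8, 4, 4])
    print(grouped.shape)  # torch.Size([1, 2, 4, 4, 4])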
