 from xformers import ops as xops
 from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask
 
+from vllm.attention.backends.xformers import _make_alibi_bias
+
 FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
 # This will change depending on the compute capability.
 # - 512 as a buffer
@@ -345,20 +347,26 @@ def ref_multi_query_kv_attention(
     key: torch.Tensor,
     value: torch.Tensor,
     scale: float,
+    alibi_bias: Optional[list[torch.Tensor]],
     dtype: torch.dtype,
 ) -> torch.Tensor:
     num_seqs = len(cu_seq_lens) - 1
     ref_outputs: list[torch.Tensor] = []
+    if alibi_bias:
+        assert len(alibi_bias) == num_seqs
     for i in range(num_seqs):
         start_idx = cu_seq_lens[i]
         end_idx = cu_seq_lens[i + 1]
         seq_len = end_idx - start_idx
 
-        # Create attention mask.
-        attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=dtype),
-                               diagonal=1)
-        attn_mask = attn_mask * torch.finfo(dtype).min
-        attn_mask = attn_mask.to(dtype=dtype)
+        # Create attention mask. ALiBi already includes a tril causal mask.
+        if alibi_bias:
+            attn_mask = alibi_bias[i]
+        else:
+            attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=dtype),
+                                   diagonal=1)
+            attn_mask = attn_mask * torch.finfo(dtype).min
+            attn_mask = attn_mask.to(dtype=dtype)
 
         ref_output = ref_masked_attention(
             query[start_idx:end_idx],
@@ -372,7 +380,6 @@ def ref_multi_query_kv_attention(
     return torch.cat(ref_outputs, dim=0)
 
 
-# TODO(woosuk): Add tests for USE_ALIBI=True.
 @pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS)
 @pytest.mark.parametrize("num_heads", NUM_HEADS)
 @pytest.mark.parametrize("head_size", HEAD_SIZES)
@@ -389,6 +396,7 @@ def test_multi_query_kv_attention(
     dtype: torch.dtype,
     seed: int,
     device: str,
+    use_alibi: bool = False,
 ) -> None:
     current_platform.seed_everything(seed)
     torch.set_default_device(device)
@@ -414,16 +422,40 @@ def test_multi_query_kv_attention(
     # Handle MQA and GQA
     key = torch.repeat_interleave(key, num_queries_per_kv, dim=1)
     value = torch.repeat_interleave(value, num_queries_per_kv, dim=1)
-    attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_lens)
-    output = xops.memory_efficient_attention_forward(
-        query.unsqueeze(0),
-        key.unsqueeze(0),
-        value.unsqueeze(0),
-        attn_bias=attn_bias,
-        p=0.0,
-        scale=scale,
-    )
-    output = output.squeeze(0)
+    alibi_bias = None
+    if use_alibi:
+        alibi_slopes = torch.randn(num_query_heads, dtype=torch.float)
+        attn_bias = _make_alibi_bias(alibi_slopes, num_kv_heads, dtype,
+                                     seq_lens)
+        output = torch.empty_like(query)
+        start = 0
+        # Dynamic sequence length not supported with custom attn_bias.
+        for i, seq_len in enumerate(seq_lens):
+            end = start + seq_len
+            out = xops.memory_efficient_attention_forward(
+                query[None, start:end],
+                key[None, start:end],
+                value[None, start:end],
+                attn_bias=attn_bias[i],
+                p=0.0,
+                scale=scale)
+            output[start:end].copy_(out.view_as(query[start:end]))
+            start += seq_len
+        # xformers.AttentionBias to Tensor for use in reference impl.
+        alibi_bias = [
+            b.materialize(b.shape, device=device).squeeze() for b in attn_bias
+        ]
+    else:
+        attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_lens)
+        output = xops.memory_efficient_attention_forward(
+            query.unsqueeze(0),
+            key.unsqueeze(0),
+            value.unsqueeze(0),
+            attn_bias=attn_bias,
+            p=0.0,
+            scale=scale,
+        )
+        output = output.squeeze(0)
 
     cu_seq_lens = [0]
     for seq_len in seq_lens:
@@ -434,8 +466,37 @@ def test_multi_query_kv_attention(
         key,
         value,
         scale,
+        alibi_bias,
         dtype,
     )
     atol = get_default_atol(output) if current_platform.is_rocm() else 1e-3
     rtol = get_default_rtol(output) if current_platform.is_rocm() else 1e-5
-    torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol)
+    torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol)
+
+
+@pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS)
+@pytest.mark.parametrize("num_heads", NUM_HEADS)
+@pytest.mark.parametrize("head_size", [64])
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("seed", SEEDS)
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+@pytest.mark.skipif(current_platform.is_rocm(),
+                    reason="Xformers backend is not supported on ROCm.")
+@torch.inference_mode()
+def test_multi_query_kv_attention_with_alibi(
+    num_seqs: int,
+    num_heads: tuple[int, int],
+    head_size: int,
+    dtype: torch.dtype,
+    seed: int,
+    device: str,
+) -> None:
+    return test_multi_query_kv_attention(
+        num_seqs,
+        num_heads,
+        head_size,
+        dtype,
+        seed,
+        device,
+        use_alibi=True,
+    )
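For context on why the reference implementation above can reuse the ALiBi bias in place of the explicit triu causal mask: an ALiBi bias is a per-head, slope-scaled distance penalty whose future positions are already set to -inf. The sketch below illustrates that shape and structure; it is an assumption-labeled illustration, not the _make_alibi_bias helper imported in this diff, and the function name is hypothetical.

# Illustrative sketch only -- not the _make_alibi_bias helper from this diff.
# Builds a (num_heads, seq_len, seq_len) ALiBi bias: a slope-scaled distance
# penalty per head, with future (upper-triangular) positions set to the dtype
# minimum so the bias doubles as a causal mask.
import torch


def alibi_bias_sketch(alibi_slopes: torch.Tensor, seq_len: int,
                      dtype: torch.dtype) -> torch.Tensor:
    positions = torch.arange(seq_len)
    # rel[i, j] = j - i: zero on the diagonal, negative for past positions.
    rel = (positions[None, :] - positions[:, None]).to(dtype)
    bias = alibi_slopes.to(dtype)[:, None, None] * rel
    causal = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool),
                        diagonal=1)
    return bias.masked_fill(causal, torch.finfo(dtype).min)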