 import math
 import random
 import time
+from collections.abc import Callable

 import pytest
 import torch
 from xformers import ops as xops
 from xformers.ops.fmha.attn_bias import BlockDiagonalCausalFromBottomRightMask

 from vllm.attention.backends.xformers import _make_alibi_bias
+from vllm.attention.ops.chunked_prefill_paged_decode import (
+    chunked_prefill_paged_decode)
 from vllm.attention.ops.prefix_prefill import context_attention_fwd
 from vllm.platforms import current_platform
 from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
@@ ... @@
 SLIDING_WINDOW = [0, 16, 64, 128, 256, 512, 2048]
 KV_CACHE_DTYPES = ["auto", "fp8", "fp8_e5m2"]

+OPS = [chunked_prefill_paged_decode, context_attention_fwd]
+

 @pytest.mark.parametrize("num_heads", NUM_HEADS)
 @pytest.mark.parametrize("num_queries_per_kv", NUM_QUERIES_PER_KV)
@@ ... @@
 @pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES)
 @pytest.mark.parametrize("device", CUDA_DEVICES)
 @pytest.mark.parametrize("sliding_window", SLIDING_WINDOW)
+@pytest.mark.parametrize("op", OPS)
 @torch.inference_mode()
 def test_contexted_kv_attention(
     num_heads: int,
@@ -41,6 +47,7 @@ def test_contexted_kv_attention(
     dtype: torch.dtype,
     kv_cache_dtype: str,
     device: str,
+    op: Callable,
 ) -> None:

     if 'fp8' in kv_cache_dtype and not current_platform.has_device_capability(
@@ -65,6 +72,9 @@ def test_contexted_kv_attention(
     block_size = 32
     max_block_per_request = 64
     query_lens = [random.randint(16, MAX_SEQ_LEN) for _ in range(BS)]
+    # ensure one sequence in batch is a decode
+    query_lens[-1] = 1
+
     ctx_lens = [random.randint(16, MAX_CTX_LEN) for _ in range(BS)]
     seq_lens = [a + b for a, b in zip(query_lens, ctx_lens)]
     num_kv_heads = num_heads // num_queries_per_kv
@@ -144,36 +154,36 @@ def test_contexted_kv_attention(

     # Warm up the Triton kernel by calling it once before actually measuring
     # generation time
-    context_attention_fwd(query,
-                          k,
-                          v,
-                          output,
-                          kv_cache_dtype,
-                          k_cache,
-                          v_cache,
-                          block_table,
-                          b_start_loc,
-                          b_seq_len,
-                          max_input_len,
-                          k_scale,
-                          v_scale,
-                          sliding_window=sliding_window)
+    op(query,
+       k,
+       v,
+       output,
+       kv_cache_dtype,
+       k_cache,
+       v_cache,
+       block_table,
+       b_start_loc,
+       b_seq_len,
+       max_input_len,
+       k_scale,
+       v_scale,
+       sliding_window=sliding_window)
     torch.cuda.synchronize()
     start_time = time.time()
-    context_attention_fwd(query,
-                          k,
-                          v,
-                          output,
-                          kv_cache_dtype,
-                          k_cache,
-                          v_cache,
-                          block_table,
-                          b_start_loc,
-                          b_seq_len,
-                          max_input_len,
-                          k_scale,
-                          v_scale,
-                          sliding_window=sliding_window)
+    op(query,
+       k,
+       v,
+       output,
+       kv_cache_dtype,
+       k_cache,
+       v_cache,
+       block_table,
+       b_start_loc,
+       b_seq_len,
+       max_input_len,
+       k_scale,
+       v_scale,
+       sliding_window=sliding_window)
     torch.cuda.synchronize()
     end_time = time.time()
     print(f"triton Time: {(end_time - start_time)*1000:.2f} ms")
@@ -228,7 +238,7 @@ def test_contexted_kv_attention(
     end_time = time.time()
     print(f"xformers Time: {(end_time - start_time)*1000:.2f} ms")
     output_ref = output_ref.reshape(output.shape)
-    atol = 1e-3 if "fp8" in kv_cache_dtype else 1e-6
+    atol = 1e-3 if "fp8" in kv_cache_dtype else 1e-4
     torch.testing.assert_close(output, output_ref, atol=atol, rtol=0)


@@ -238,6 +248,7 @@ def test_contexted_kv_attention(
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES)
 @pytest.mark.parametrize("device", CUDA_DEVICES)
+@pytest.mark.parametrize("op", OPS)
 @torch.inference_mode()
 def test_contexted_kv_attention_alibi(
     num_heads: int,
@@ -246,6 +257,7 @@ def test_contexted_kv_attention_alibi(
     dtype: torch.dtype,
     kv_cache_dtype: str,
     device: str,
+    op: Callable,
 ) -> None:

     if 'fp8' in kv_cache_dtype and not current_platform.has_device_capability(
@@ -375,36 +387,36 @@ def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:

     # Warm up the Triton kernel by calling it once before actually measuring
     # generation time
-    context_attention_fwd(query,
-                          k,
-                          v,
-                          output,
-                          kv_cache_dtype,
-                          k_cache,
-                          v_cache,
-                          block_table,
-                          b_start_loc,
-                          b_seq_len,
-                          max_input_len,
-                          k_scale,
-                          v_scale,
-                          alibi_slopes=alibi_slopes)
+    op(query,
+       k,
+       v,
+       output,
+       kv_cache_dtype,
+       k_cache,
+       v_cache,
+       block_table,
+       b_start_loc,
+       b_seq_len,
+       max_input_len,
+       k_scale,
+       v_scale,
+       alibi_slopes=alibi_slopes)
     torch.cuda.synchronize()
     start_time = time.time()
-    context_attention_fwd(query,
-                          k,
-                          v,
-                          output,
-                          kv_cache_dtype,
-                          k_cache,
-                          v_cache,
-                          block_table,
-                          b_start_loc,
-                          b_seq_len,
-                          max_input_len,
-                          k_scale,
-                          v_scale,
-                          alibi_slopes=alibi_slopes)
+    op(query,
+       k,
+       v,
+       output,
+       kv_cache_dtype,
+       k_cache,
+       v_cache,
+       block_table,
+       b_start_loc,
+       b_seq_len,
+       max_input_len,
+       k_scale,
+       v_scale,
+       alibi_slopes=alibi_slopes)
     torch.cuda.synchronize()
     end_time = time.time()
     print(f"triton Time: {(end_time - start_time)*1000:.2f} ms")
@@ -503,6 +515,7 @@ def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
 @pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES)
 @pytest.mark.parametrize("device", CUDA_DEVICES)
 @pytest.mark.parametrize("sliding_window", SLIDING_WINDOW)
+@pytest.mark.parametrize("op", OPS)
 @torch.inference_mode()
 def test_contexted_kv_attention_f32(
     num_heads: int,
@@ -512,9 +525,11 @@ def test_contexted_kv_attention_f32(
     dtype: torch.dtype,
     kv_cache_dtype: str,
     device: str,
+    op: Callable,
 ) -> None:
     test_contexted_kv_attention(num_heads, num_queries_per_kv, head_size,
-                                sliding_window, dtype, kv_cache_dtype, device)
+                                sliding_window, dtype, kv_cache_dtype, device,
+                                op)


 @pytest.mark.optional
@@ -524,6 +539,7 @@ def test_contexted_kv_attention_f32(
 @pytest.mark.parametrize("dtype", [torch.float32])
 @pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES)
 @pytest.mark.parametrize("device", CUDA_DEVICES)
+@pytest.mark.parametrize("op", OPS)
 @torch.inference_mode()
 def test_contexted_kv_attention_alibi_f32(
     num_heads: int,
@@ -532,6 +548,7 @@ def test_contexted_kv_attention_alibi_f32(
     dtype: torch.dtype,
     kv_cache_dtype: str,
     device: str,
+    op: Callable,
 ) -> None:
     test_contexted_kv_attention_alibi(num_heads, num_queries_per_kv, head_size,
-                                      dtype, kv_cache_dtype, device)
+                                      dtype, kv_cache_dtype, device, op)