
Commit 6bd1dd9

[Kernel] [V1] Improved performance for V1 Triton (ROCm) backend (#14152)

1 parent 4f27044 · commit 6bd1dd9
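The test diff below parametrizes the existing prefix-prefill tests over both kernel entry points, the new chunked_prefill_paged_decode and the existing context_attention_fwd, so the same assertions exercise both implementations. A minimal, self-contained sketch of that pytest parametrization pattern follows; the fake_* functions are placeholders for illustration, not the real vLLM kernels.

import pytest


def fake_context_attention_fwd(x):
    # Placeholder standing in for context_attention_fwd.
    return x * 2


def fake_chunked_prefill_paged_decode(x):
    # Placeholder standing in for chunked_prefill_paged_decode.
    return x + x


OPS = [fake_chunked_prefill_paged_decode, fake_context_attention_fwd]


@pytest.mark.parametrize("op", OPS)
def test_ops_agree(op):
    # The same assertion runs once per kernel implementation.
    assert op(3) == 6

In the real tests, the kernel arrives as an op: Callable parameter and is invoked with the same argument list that context_attention_fwd already accepted.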

File tree

4 files changed (+398, -77 lines)


tests/kernels/test_prefix_prefill.py

Lines changed: 76 additions & 59 deletions
@@ -3,13 +3,16 @@
 import math
 import random
 import time
+from collections.abc import Callable
 
 import pytest
 import torch
 from xformers import ops as xops
 from xformers.ops.fmha.attn_bias import BlockDiagonalCausalFromBottomRightMask
 
 from vllm.attention.backends.xformers import _make_alibi_bias
+from vllm.attention.ops.chunked_prefill_paged_decode import (
+    chunked_prefill_paged_decode)
 from vllm.attention.ops.prefix_prefill import context_attention_fwd
 from vllm.platforms import current_platform
 from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
@@ -24,6 +27,8 @@
 SLIDING_WINDOW = [0, 16, 64, 128, 256, 512, 2048]
 KV_CACHE_DTYPES = ["auto", "fp8", "fp8_e5m2"]
 
+OPS = [chunked_prefill_paged_decode, context_attention_fwd]
+
 
 @pytest.mark.parametrize("num_heads", NUM_HEADS)
 @pytest.mark.parametrize("num_queries_per_kv", NUM_QUERIES_PER_KV)
@@ -32,6 +37,7 @@
 @pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES)
 @pytest.mark.parametrize("device", CUDA_DEVICES)
 @pytest.mark.parametrize("sliding_window", SLIDING_WINDOW)
+@pytest.mark.parametrize("op", OPS)
 @torch.inference_mode()
 def test_contexted_kv_attention(
     num_heads: int,
@@ -41,6 +47,7 @@ def test_contexted_kv_attention(
     dtype: torch.dtype,
     kv_cache_dtype: str,
     device: str,
+    op: Callable,
 ) -> None:
 
     if 'fp8' in kv_cache_dtype and not current_platform.has_device_capability(
@@ -65,6 +72,9 @@ def test_contexted_kv_attention(
     block_size = 32
     max_block_per_request = 64
     query_lens = [random.randint(16, MAX_SEQ_LEN) for _ in range(BS)]
+    # ensure one sequence in batch is a decode
+    query_lens[-1] = 1
+
     ctx_lens = [random.randint(16, MAX_CTX_LEN) for _ in range(BS)]
     seq_lens = [a + b for a, b in zip(query_lens, ctx_lens)]
     num_kv_heads = num_heads // num_queries_per_kv
@@ -144,36 +154,36 @@ def test_contexted_kv_attention(
 
     # Warm up the Triton kernel by calling it once before actually measuring
     # generation time
-    context_attention_fwd(query,
-                          k,
-                          v,
-                          output,
-                          kv_cache_dtype,
-                          k_cache,
-                          v_cache,
-                          block_table,
-                          b_start_loc,
-                          b_seq_len,
-                          max_input_len,
-                          k_scale,
-                          v_scale,
-                          sliding_window=sliding_window)
+    op(query,
+       k,
+       v,
+       output,
+       kv_cache_dtype,
+       k_cache,
+       v_cache,
+       block_table,
+       b_start_loc,
+       b_seq_len,
+       max_input_len,
+       k_scale,
+       v_scale,
+       sliding_window=sliding_window)
     torch.cuda.synchronize()
     start_time = time.time()
-    context_attention_fwd(query,
-                          k,
-                          v,
-                          output,
-                          kv_cache_dtype,
-                          k_cache,
-                          v_cache,
-                          block_table,
-                          b_start_loc,
-                          b_seq_len,
-                          max_input_len,
-                          k_scale,
-                          v_scale,
-                          sliding_window=sliding_window)
+    op(query,
+       k,
+       v,
+       output,
+       kv_cache_dtype,
+       k_cache,
+       v_cache,
+       block_table,
+       b_start_loc,
+       b_seq_len,
+       max_input_len,
+       k_scale,
+       v_scale,
+       sliding_window=sliding_window)
     torch.cuda.synchronize()
     end_time = time.time()
     print(f"triton Time: {(end_time - start_time)*1000:.2f} ms")
@@ -228,7 +238,7 @@ def test_contexted_kv_attention(
     end_time = time.time()
     print(f"xformers Time: {(end_time - start_time)*1000:.2f} ms")
     output_ref = output_ref.reshape(output.shape)
-    atol = 1e-3 if "fp8" in kv_cache_dtype else 1e-6
+    atol = 1e-3 if "fp8" in kv_cache_dtype else 1e-4
     torch.testing.assert_close(output, output_ref, atol=atol, rtol=0)
 
 
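The hunk above loosens the non-fp8 absolute tolerance from 1e-6 to 1e-4. Because the comparison uses rtol=0, atol is the only budget for elementwise differences between the Triton output and the xformers reference. A standalone illustration of what the looser bound admits, using synthetic tensors rather than the test's real outputs:

import torch

ref = torch.randn(4, 8)
out = ref + 5e-5  # a small deviation: within 1e-4, but outside 1e-6

torch.testing.assert_close(out, ref, atol=1e-4, rtol=0)  # passes
try:
    torch.testing.assert_close(out, ref, atol=1e-6, rtol=0)
except AssertionError:
    print("fails the tighter 1e-6 tolerance, as expected")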
@@ -238,6 +248,7 @@ def test_contexted_kv_attention(
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES)
 @pytest.mark.parametrize("device", CUDA_DEVICES)
+@pytest.mark.parametrize("op", OPS)
 @torch.inference_mode()
 def test_contexted_kv_attention_alibi(
     num_heads: int,
@@ -246,6 +257,7 @@ def test_contexted_kv_attention_alibi(
     dtype: torch.dtype,
     kv_cache_dtype: str,
     device: str,
+    op: Callable,
 ) -> None:
 
     if 'fp8' in kv_cache_dtype and not current_platform.has_device_capability(
@@ -375,36 +387,36 @@ def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
 
     # Warm up the Triton kernel by calling it once before actually measuring
     # generation time
-    context_attention_fwd(query,
-                          k,
-                          v,
-                          output,
-                          kv_cache_dtype,
-                          k_cache,
-                          v_cache,
-                          block_table,
-                          b_start_loc,
-                          b_seq_len,
-                          max_input_len,
-                          k_scale,
-                          v_scale,
-                          alibi_slopes=alibi_slopes)
+    op(query,
+       k,
+       v,
+       output,
+       kv_cache_dtype,
+       k_cache,
+       v_cache,
+       block_table,
+       b_start_loc,
+       b_seq_len,
+       max_input_len,
+       k_scale,
+       v_scale,
+       alibi_slopes=alibi_slopes)
     torch.cuda.synchronize()
     start_time = time.time()
-    context_attention_fwd(query,
-                          k,
-                          v,
-                          output,
-                          kv_cache_dtype,
-                          k_cache,
-                          v_cache,
-                          block_table,
-                          b_start_loc,
-                          b_seq_len,
-                          max_input_len,
-                          k_scale,
-                          v_scale,
-                          alibi_slopes=alibi_slopes)
+    op(query,
+       k,
+       v,
+       output,
+       kv_cache_dtype,
+       k_cache,
+       v_cache,
+       block_table,
+       b_start_loc,
+       b_seq_len,
+       max_input_len,
+       k_scale,
+       v_scale,
+       alibi_slopes=alibi_slopes)
     torch.cuda.synchronize()
     end_time = time.time()
     print(f"triton Time: {(end_time - start_time)*1000:.2f} ms")
@@ -503,6 +515,7 @@ def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
 @pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES)
 @pytest.mark.parametrize("device", CUDA_DEVICES)
 @pytest.mark.parametrize("sliding_window", SLIDING_WINDOW)
+@pytest.mark.parametrize("op", OPS)
 @torch.inference_mode()
 def test_contexted_kv_attention_f32(
     num_heads: int,
@@ -512,9 +525,11 @@ def test_contexted_kv_attention_f32(
     dtype: torch.dtype,
     kv_cache_dtype: str,
     device: str,
+    op: Callable,
 ) -> None:
     test_contexted_kv_attention(num_heads, num_queries_per_kv, head_size,
-                                sliding_window, dtype, kv_cache_dtype, device)
+                                sliding_window, dtype, kv_cache_dtype, device,
+                                op)
 
 
 @pytest.mark.optional
@@ -524,6 +539,7 @@ def test_contexted_kv_attention_f32(
 @pytest.mark.parametrize("dtype", [torch.float32])
 @pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES)
 @pytest.mark.parametrize("device", CUDA_DEVICES)
+@pytest.mark.parametrize("op", OPS)
 @torch.inference_mode()
 def test_contexted_kv_attention_alibi_f32(
     num_heads: int,
@@ -532,6 +548,7 @@ def test_contexted_kv_attention_alibi_f32(
     dtype: torch.dtype,
     kv_cache_dtype: str,
    device: str,
+    op: Callable,
 ) -> None:
     test_contexted_kv_attention_alibi(num_heads, num_queries_per_kv, head_size,
-                                      dtype, kv_cache_dtype, device)
+                                      dtype, kv_cache_dtype, device, op)
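To run only the new kernel's cases from this file locally, a pytest -k filter on the callable's name should work, since pytest normally derives parameter IDs from a function's __name__. The snippet below is a usage sketch, not part of the commit, and assumes it is run from the vLLM repository root.

import pytest

if __name__ == "__main__":
    raise SystemExit(
        pytest.main([
            "tests/kernels/test_prefix_prefill.py",
            "-k", "chunked_prefill_paged_decode",
        ]))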
