Skip to content

Commit 273c949

Browse files
sanyalington, gshtras, shajrawi, tjtanaa, and vllmellm
authored
Faster Custom Paged Attention kernels (#372)
* integrate new cpa kernel, update tests and benchmark
* added comments to mfma4 kernel
* further comments for mfma16 kernel
* clang-format
* Lint
* add flag for logits rtz conversion and disable by default
* lint
* [Bugfix]: Fix paged attention unit tests of #372 (#389)
* [Bugfix]: fix paged attention tests based on the updated kernels in `csrc/attention/paged_attention_v1.cu`, `csrc/attention/paged_attention_v2.cu` and `csrc/rocm/attention.cu`.
* improve code documentation.
* lint

---------

Co-authored-by: vllmellm <vllm.ellm@embeddedllm.com>

---------

Co-authored-by: Gregory Shtrasberg <156009573+gshtras@users.noreply.github.com>
Co-authored-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
Co-authored-by: Joe Shajrawi <17753158+shajrawi@users.noreply.github.com>
Co-authored-by: TJian <tunjian1996@gmail.com>
Co-authored-by: vllmellm <vllm.ellm@embeddedllm.com>
1 parent 7a292f9 commit 273c949

File tree

3 files changed

+1016
-402
lines changed

3 files changed

+1016
-402
lines changed

benchmarks/kernels/benchmark_paged_attention.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,9 @@
99
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser,
1010
create_kv_caches_with_random)
1111

12-
NUM_BLOCKS = 1024 * 1024
12+
NUM_BLOCKS = 128 * 1024
1313
PARTITION_SIZE = 512
14+
PARTITION_SIZE_ROCM = 256
1415

1516

1617
@torch.inference_mode()
@@ -78,9 +79,12 @@ def main(
7879
# Prepare for the paged attention kernel.
7980
output = torch.empty_like(query)
8081
if version == "v2":
81-
if current_platform.is_rocm() and not args.custom_paged_attn:
82+
if current_platform.is_rocm():
8283
global PARTITION_SIZE
83-
PARTITION_SIZE = 1024
84+
if not args.custom_paged_attn:
85+
PARTITION_SIZE = 1024
86+
else:
87+
PARTITION_SIZE = PARTITION_SIZE_ROCM
8488
num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE)
8589
tmp_output = torch.empty(
8690
size=(num_seqs, num_query_heads, num_partitions, head_size),
@@ -163,6 +167,8 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
163167
kv_cache_dtype,
164168
k_scale,
165169
v_scale,
170+
None,
171+
PARTITION_SIZE,
166172
)
167173
else:
168174
raise ValueError(f"Invalid version: {version}")
@@ -176,13 +182,13 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
176182
# Warmup.
177183
print("Warming up...")
178184
run_benchmark = run_cuda_benchmark
179-
run_benchmark(num_iters=3, profile=False)
185+
run_benchmark(num_iters=500, profile=False)
180186

181187
# Benchmark.
182188
if do_profile:
183189
latency = run_benchmark(num_iters=1, profile=True)
184190
else:
185-
latency = run_benchmark(num_iters=1000, profile=False)
191+
latency = run_benchmark(num_iters=10000, profile=False)
186192
print(f"Kernel running time: {latency * 1000000:.3f} us")
187193

188194

0 commit comments

Comments (0)