
opt flashinfer mla cat #5822


Open · wants to merge 5 commits into main from flashinfer_cat_opt
Conversation

@xu-yfei commented Apr 28, 2025

Motivation

Based on #5748 and #5638: for the flashinfer MLA backend, remove the q and k concatenations (torch.cat).
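A minimal sketch of the pattern being removed. The shapes and variable names here are illustrative, not taken from this PR: MLA splits the query into a "nope" part and a rotary "rope" part, and concatenating them with torch.cat allocates and copies a new tensor on every forward pass. When the attention kernel can consume the two parts separately (as flashinfer's MLA interface allows), that copy can be dropped.

```python
import torch

# Illustrative shapes only (not the actual model config)
bs, heads, d_nope, d_rope = 16, 128, 512, 64
q_nope = torch.randn(bs, heads, d_nope)
q_rope = torch.randn(bs, heads, d_rope)

# Before: an extra allocation + copy per forward pass just to build q
q = torch.cat([q_nope, q_rope], dim=-1)

# After: pass q_nope and q_rope to the kernel separately; no cat needed.
# The concatenated view only exists for kernels that require it.
assert q.shape == (bs, heads, d_nope + d_rope)
```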

Accuracy

Accuracy: 0.951
Invalid: 0.000
Latency: 228.672 s
Output throughput: 554.173 token/s

Performance

main branch:

{"run_name": "default", "batch_size": 1, "input_len": 1024, "output_len": 1024, "latency": 14.4687, "output_throughput": 70.77, "overall_throughput": 141.55}

{"run_name": "default", "batch_size": 16, "input_len": 1024, "output_len": 1024, "latency": 28.8723, "output_throughput": 567.47, "overall_throughput": 1134.93}

{"run_name": "default", "batch_size": 32, "input_len": 1024, "output_len": 1024, "latency": 38.0349, "output_throughput": 861.52, "overall_throughput": 1723.05}

this PR:

{"run_name": "default", "batch_size": 1, "input_len": 1024, "output_len": 1024, "latency": 14.5066, "output_throughput": 70.59, "overall_throughput": 141.18}

{"run_name": "default", "batch_size": 16, "input_len": 1024, "output_len": 1024, "latency": 28.4372, "output_throughput": 576.15, "overall_throughput": 1152.29}

{"run_name": "default", "batch_size": 32, "input_len": 1024, "output_len": 1024, "latency": 37.2972, "output_throughput": 878.57, "overall_throughput": 1757.13}

Profile

Prefill

main branch: [profile screenshot]

this PR: [profile screenshot]

The cat kernel time drops from 47 us to 3 us.

Decode

main branch, bs=1, CUDA graph + torch compile: [profile screenshot]

this PR, bs=1, CUDA graph + torch compile: [profile screenshot]

With bs=1, CUDA graph, and torch compile, the timings are almost the same (torch compile already fuses the cat with other ops).

main branch, bs=1, CUDA graph without torch compile: [profile screenshot]

this PR, bs=1, CUDA graph without torch compile: [profile screenshot]

With bs=1 and CUDA graph but without torch compile, the cat time drops from 6~7 us to 1 us.

Modifications

  • Update the deepseek_v2 code to remove the q and k cat.
  • In flashinfer_mla_backend:
    ◦ Cat only in the ragged (prefill) path; no cat in the other cases.
    ◦ Use set_mla_kv_buffer when k_rope is not empty.
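The set_mla_kv_buffer path above can be sketched as follows. This is a toy stand-in, not the sglang implementation: the pool class, shapes, and method signature are assumptions; only the idea comes from the PR, namely writing k_nope and k_rope directly into their slices of a single cache buffer instead of doing torch.cat followed by a plain buffer write.

```python
import torch

class MlaKvPool:
    """Toy MLA KV cache: each token slot holds the compressed "nope"
    part and the small "rope" part in one contiguous row."""

    def __init__(self, num_tokens, d_nope, d_rope):
        self.d_nope = d_nope
        self.buf = torch.zeros(num_tokens, d_nope + d_rope)

    def set_mla_kv_buffer(self, loc, k_nope, k_rope):
        # Write both halves straight into their slices of the buffer,
        # so no intermediate torch.cat tensor is ever materialized.
        self.buf[loc, : self.d_nope] = k_nope
        self.buf[loc, self.d_nope :] = k_rope

# Usage: write two tokens at cache slots 0 and 3
pool = MlaKvPool(num_tokens=8, d_nope=512, d_rope=64)
loc = torch.tensor([0, 3])
pool.set_mla_kv_buffer(loc, torch.ones(2, 512), torch.full((2, 64), 2.0))
```

When k_rope is empty (the path where the rope part was already folded in), a plain buffer write suffices, which is why the PR only takes the fused path when k_rope is non-empty.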


@xu-yfei force-pushed the flashinfer_cat_opt branch from 4f25ae3 to 69840db on April 28, 2025 07:30
@lambert0312 (Contributor) commented:

I pulled the latest commit and did some experiments, and it seems to be consistent with the optimizations mentioned above.

4 participants