
[Kernel][Triton][FP8] Adding fp8 and variable length sequence support to Triton FAv2 kernel #12591


Merged
merged 101 commits into vllm-project:main from ransmith_triton_fav2_vsl on Apr 27, 2025

Conversation

rasmith
Contributor

@rasmith rasmith commented Jan 30, 2025

This PR adds fp8 and variable length sequence support to Triton FAv2 kernel.

This kernel supports an 8-bit KV cache as well as the following (forward pass only):

  1. Fwd with causal masking
  2. Arbitrary Q and KV sequence lengths
  3. Arbitrary head sizes
  4. Multi and grouped query attention
  5. Variable sequence lengths
  6. ALiBi and matrix bias
  7. fp8 support for models, currently validated with Llama-3.1-8B-Instruct-FP8-QKV-Prob
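
For readers less familiar with the variable-length ("varlen") layout these features imply, below is a minimal PyTorch reference sketch, not the Triton kernel itself: all sequences in a batch are packed along one token dimension and delimited by cu_seqlens offsets, grouped-query attention is handled by repeating KV heads, and causal masking is bottom-right aligned as in FlashAttention. Function and variable names are illustrative, not vLLM APIs.

```python
import torch

def varlen_attention_reference(q, k, v, cu_seqlens_q, cu_seqlens_k, causal=True):
    """q: [total_q, num_heads, head_dim]; k, v: [total_kv, num_kv_heads, head_dim]."""
    num_heads, num_kv_heads = q.shape[1], k.shape[1]
    group = num_heads // num_kv_heads               # grouped-query attention ratio
    scale = q.shape[-1] ** -0.5
    out = torch.empty_like(q)
    for i in range(len(cu_seqlens_q) - 1):          # one packed sequence at a time
        qs, qe = int(cu_seqlens_q[i]), int(cu_seqlens_q[i + 1])
        ks, ke = int(cu_seqlens_k[i]), int(cu_seqlens_k[i + 1])
        qi = q[qs:qe].transpose(0, 1)                                  # [H, Lq, D]
        ki = k[ks:ke].repeat_interleave(group, dim=1).transpose(0, 1)  # [H, Lk, D]
        vi = v[ks:ke].repeat_interleave(group, dim=1).transpose(0, 1)
        scores = qi @ ki.transpose(-1, -2) * scale                     # [H, Lq, Lk]
        if causal:
            lq, lk = qe - qs, ke - ks
            # bottom-right aligned causal mask: each query sees all earlier keys
            mask = torch.ones(lq, lk, dtype=torch.bool, device=q.device)
            scores = scores.masked_fill(~mask.tril(diagonal=lk - lq), float("-inf"))
        out[qs:qe] = (scores.softmax(dim=-1) @ vi).transpose(0, 1)
    return out

# Example: two packed sequences of lengths 3 and 5, 8 query heads sharing 2 KV heads.
q = torch.randn(8, 8, 64); k = torch.randn(8, 2, 64); v = torch.randn(8, 2, 64)
cu = torch.tensor([0, 3, 8])
o = varlen_attention_reference(q, k, v, cu, cu)     # [8, 8, 64]
```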

This kernel is slightly faster than the current one:

Benchmarks:

================================
Llama-3-8B-Instruct

VLLM_USE_TRITON_FLASH_ATTN=1 python benchmarks/benchmark_latency.py --enable-chunked-prefill False --load-format dummy --batch-size 64 --num-iters-warmup 2 --num-iters 5 --input-len 2048 --output-len 128 --model /models/Llama-3-8B-Instruct/

old kernel:
Avg latency: 7.028674733600928 seconds

new kernel:
Avg latency: 6.267468033730983 seconds

===============================
Phi-3-medium-128k-instruct-quantized.w8a8

VLLM_USE_TRITON_FLASH_ATTN=1 python benchmarks/benchmark_latency.py --enable-chunked-prefill False --load-format dummy --batch-size 64 --num-iters-warmup 2 --num-iters 5 --input-len 2048 --output-len 128 --model /models/Phi-3-medium-128k-instruct-quantized.w8a8/

old kernel:
Avg latency: 10.266006066184492 seconds

new kernel:
Avg latency: 9.983664013911039 seconds

PPL Measurements:

context-size=1024
sample-size=512
max-model-len=32768

model=Llama-3.1-8B-Instruct-FP8-QKV-Prob

PPL=6.9046958999710615

model=Llama-3.1-8B-Instruct

PPL=6.5381874291070545

LM EVAL:

I included lm_eval results for both the KV and QKV-Prob models. The QKV-Prob model quantizes the output of softmax(QK^T) to fp8 before the second dot product, which does appear to reduce accuracy; since QKV-Prob applies the most quantization, it is expected to show the lowest accuracy.
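
To make that accuracy note concrete, the snippet below simulates quantizing the softmax output to fp8 before the second matmul by round-tripping it through torch.float8_e4m3fn. This is only an illustration of the rounding effect; the actual kernel applies a scale factor and performs the fp8 dot product inside Triton.

```python
import torch

torch.manual_seed(0)
scores = torch.randn(128, 128)
v = torch.randn(128, 64)
p = scores.softmax(dim=-1)                         # attention probabilities

out_ref = p @ v                                    # full-precision probabilities
p_fp8 = p.to(torch.float8_e4m3fn).float()          # round-trip through fp8 (e4m3)
out_fp8 = p_fp8 @ v                                # fp8-rounded probabilities

print("max abs error from fp8 probs:", (out_ref - out_fp8).abs().max().item())
```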

LM EVAL: Llama-3.1-8B-Instruct-FP8-KV

| Tasks | Version | Filter | n-shot | Metric | Value | Stderr |
|-------|---------|--------|--------|--------|-------|--------|
| gsm8k | 3 | flexible-extract | 5 | exact_match | 0.796 | ± 0.0255 |
| | | strict-match | 5 | exact_match | 0.732 | ± 0.0281 |

LM EVAL: Llama-3.1-8B-Instruct-FP8-QKV-Prob

| Tasks | Version | Filter | n-shot | Metric | Value | Stderr |
|-------|---------|--------|--------|--------|-------|--------|
| gsm8k | 3 | flexible-extract | 5 | exact_match | 0.756 | ± 0.0272 |
| | | strict-match | 5 | exact_match | 0.592 | ± 0.0311 |

LM EVAL: Llama-3.1-8B-Instruct

| Tasks | Version | Filter | n-shot | Metric | Value | Stderr |
|-------|---------|--------|--------|--------|-------|--------|
| gsm8k | 3 | flexible-extract | 5 | exact_match | 0.780 | ± 0.0263 |
| | | strict-match | 5 | exact_match | 0.724 | ± 0.0283 |

Signed-off-by: Randall Smith <Randall.Smith@amd.com>

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small and essential subset of tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add ready label to the PR
  • Enable auto-merge.

🚀

@rasmith rasmith changed the title from "[Kernel][Triton][Quantization] Adding variable length sequence support to Triton FAv2 kernel" to "[Kernel][Triton] Adding variable length sequence support to Triton FAv2 kernel" on Jan 30, 2025
@SageMoore
Contributor

Very Exciting! I'll give it a try tomorrow morning. Thanks for the kernel!

@SageMoore
Contributor

Hi, @rasmith. Are you planning to add block_table support to this kernel? If so, great! Feel free to either add that here or make a follow on PR. If not, I'd be happy to look into this, but we will need it in order to integrate this kernel into V1.

@maleksan85
Contributor

Hi, @rasmith. Are you planning to add block_table support to this kernel? If so, great! Feel free to either add that here or make a follow on PR. If not, I'd be happy to look into this, but we will need it in order to integrate this kernel into V1.

Hi @SageMoore, please try this if you need cache/block support:

@SageMoore
Contributor

I think there are a couple of things getting lost in communication here.

  1. In order for this kernel to be used in vllm going forward, it needs to support mixing prefills and decodes into the same batch.
  2. There's no need for a backwards pass kernel.
  3. I think there's some general confusion regarding the rocm backend. This is a v0 only backend that will be deprecated once we have a working kernel for v1. A good template to look at would be the vllm/v1/attention/backends/flash_attn.py backend. Matching the flash_attn_varlen_func signature with this kernel would be a good goal.
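
For concreteness, the flash_attn_varlen_func-style entry point the V1 backend expects looks roughly like the stub below. The parameter list is paraphrased from flash-attn and may not match the exact signature of any particular release; the function name here is hypothetical.

```python
import torch

def triton_flash_attn_varlen_func(      # hypothetical name for this kernel's entry point
    q: torch.Tensor,                     # [total_q, num_heads, head_dim]
    k: torch.Tensor,                     # [total_k, num_kv_heads, head_dim]
    v: torch.Tensor,                     # [total_k, num_kv_heads, head_dim]
    cu_seqlens_q: torch.Tensor,          # [batch_size + 1] prefix sums of query lengths
    cu_seqlens_k: torch.Tensor,          # [batch_size + 1] prefix sums of key lengths
    max_seqlen_q: int,
    max_seqlen_k: int,
    softmax_scale: float | None = None,  # defaults to head_dim ** -0.5 when None
    causal: bool = False,
    alibi_slopes: torch.Tensor | None = None,
) -> torch.Tensor:                       # [total_q, num_heads, head_dim]
    ...
```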

@rasmith
Contributor Author

rasmith commented Feb 4, 2025

Very Exciting! I'll give it a try tomorrow morning. Thanks for the kernel!

@SageMoore Did this work for you?

@rasmith
Contributor Author

rasmith commented Feb 4, 2025

Hi, @rasmith. Are you planning to add block_table support to this kernel? If so, great! Feel free to either add that here or make a follow on PR. If not, I'd be happy to look into this, but we will need it in order to integrate this kernel into V1.

Would you be OK with a follow on PR? Do you have any other comments for the PR?

@SageMoore
Contributor

@rasmith in this kernel's current state, it's only usable on v0 when the kv cache is empty. Are you seeing speedups over the previous implementation? If so, could you post your results to the PR?

Assuming it does make things faster and we want to merge it, I'll give it a full review. It would also be good to add in some unit tests since I suspect the CI integration testing for this pathway is somewhat lacking. We should still remove the backwards pass stuff as well.


mergify bot commented Feb 6, 2025

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @rasmith.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Feb 6, 2025
@rasmith rasmith closed this Feb 7, 2025
@rasmith rasmith force-pushed the ransmith_triton_fav2_vsl branch from 71f89c5 to b260782 on February 7, 2025 03:22
@rasmith rasmith reopened this Feb 7, 2025
@mergify mergify bot removed the needs-rebase label Feb 7, 2025
Collaborator

@tlrmchlsmth tlrmchlsmth left a comment


This LGTM now, thank you!

@tlrmchlsmth tlrmchlsmth added the ready ONLY add when PR is ready to merge/full CI is needed label Apr 22, 2025
@ProExpertProg
Collaborator

@mgoin could you manually kick off the docker build with extended timeout?

rasmith added 9 commits April 23, 2025 17:30
@tlrmchlsmth tlrmchlsmth enabled auto-merge (squash) April 26, 2025 01:23
@tlrmchlsmth tlrmchlsmth merged commit 8e4b351 into vllm-project:main Apr 27, 2025
44 checks passed
jikunshang pushed a commit to jikunshang/vllm that referenced this pull request Apr 29, 2025
lk-chen pushed a commit to lk-chen/vllm that referenced this pull request Apr 29, 2025
adobrzyn pushed a commit to HabanaAI/vllm-fork that referenced this pull request Apr 30, 2025
RichardoMrMu pushed a commit to RichardoMrMu/vllm that referenced this pull request May 12, 2025
zzzyq pushed a commit to zzzyq/vllm that referenced this pull request May 24, 2025
minpeter pushed a commit to minpeter/vllm that referenced this pull request Jun 24, 2025
Labels
ready ONLY add when PR is ready to merge/full CI is needed

9 participants