[Kernel][Triton][FP8] Adding fp8 and variable length sequence support to Triton FAv2 kernel #12591
Conversation
Very exciting! I'll give it a try tomorrow morning. Thanks for the kernel!
Hi, @rasmith. Are you planning to add block_table support to this kernel? If so, great! Feel free to either add that here or make a follow-on PR. If not, I'd be happy to look into this, but we will need it in order to integrate this kernel into V1.
Hi @SageMoore, please try this if you need cache/block support: vllm/vllm/attention/ops/prefix_prefill.py (line 20 at e3f7ff6).
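For context, "block_table support" means the kernel gathers K/V through vLLM's paged KV cache, using a per-sequence table that maps logical token blocks to physical cache blocks. A minimal indexing sketch follows; the block size, tensor names, and `gather_keys` helper are illustrative only, not the actual layout used by prefix_prefill:

```python
import torch

BLOCK_SIZE = 16                       # tokens per KV-cache block (illustrative)
num_blocks, num_heads, head_dim = 64, 8, 128

# Paged KV cache: physical blocks of BLOCK_SIZE tokens each.
k_cache = torch.randn(num_blocks, BLOCK_SIZE, num_heads, head_dim)

# One sequence of 40 tokens scattered across physical blocks 7, 3, and 21.
block_table = torch.tensor([7, 3, 21])
seq_len = 40

def gather_keys(k_cache, block_table, seq_len):
    """Reassemble this sequence's keys in logical token order."""
    positions = torch.arange(seq_len)
    physical_block = block_table[positions // BLOCK_SIZE]  # which cache block
    offset = positions % BLOCK_SIZE                        # slot within the block
    return k_cache[physical_block, offset]                 # (seq_len, heads, head_dim)

keys = gather_keys(k_cache, block_table, seq_len)
```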
I think there are a couple of things getting lost in communication here.
@SageMoore Did this work for you?
Would you be OK with a follow-on PR? Do you have any other comments on the PR?
@rasmith in this kernel's current state, it's only usable on v0 when the kv cache is empty. Are you seeing speedups over the previous implementation? If so, could you post your results to the PR? Assuming it does make things faster and we want to merge it, I'll give it a full review. It would also be good to add in some unit tests since I suspect the CI integration testing for this pathway is somewhat lacking. We should still remove the backwards pass stuff as well.
This pull request has merge conflicts that must be resolved before it can be merged.
Force-pushed from 71f89c5 to b260782.
This LGTM now, thank you!
@mgoin could you manually kick off the docker build with extended timeout?
This PR adds fp8 and variable length sequence support to Triton FAv2 kernel.
This kernel supports an 8-bit (fp8) KV cache and variable-length sequences (forward pass only).
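As a rough illustration of the variable-length calling convention (packed tokens plus cumulative sequence-length offsets, as varlen flash-attention kernels generally use), here is a minimal sketch; the tensor shapes and the commented-out `triton_attention` call are assumptions, not the exact entry point added in this PR:

```python
import torch

# Two prompts of different lengths packed into one tensor with no padding.
seq_lens = [5, 3]
total_tokens = sum(seq_lens)              # 8
num_heads, head_dim = 8, 128

q = torch.randn(total_tokens, num_heads, head_dim, dtype=torch.float16)
k = torch.randn(total_tokens, num_heads, head_dim, dtype=torch.float16)
v = torch.randn(total_tokens, num_heads, head_dim, dtype=torch.float16)

# Cumulative sequence starts [0, 5, 8]: each kernel program looks up its
# sequence's start/end here instead of assuming a fixed sequence length.
cu_seqlens = torch.zeros(len(seq_lens) + 1, dtype=torch.int32)
cu_seqlens[1:] = torch.cumsum(torch.tensor(seq_lens), dim=0)

# Hypothetical call; the real entry point, argument names, and fp8 scale
# arguments in this PR differ.
# out = triton_attention(q, k, v,
#                        cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
#                        max_seqlen_q=max(seq_lens), max_seqlen_k=max(seq_lens),
#                        causal=True)
```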
This kernel is slightly faster than the current one:
Benchmarks:
================================
Llama-3-8B-Instruct
VLLM_USE_TRITON_FLASH_ATTN=1 python benchmarks/benchmark_latency.py --enable-chunked-prefill False --load-format dummy --batch-size 64 --num-iters-warmup 2 --num-iters 5 --input-len 2048 --output-len 128 --model /models/Llama-3-8B-Instruct/
old kernel:
Avg latency: 7.028674733600928 seconds
new kernel:
Avg latency: 6.267468033730983 seconds
===============================
Phi-3-medium-128k-instruct-quantized.w8a8
VLLM_USE_TRITON_FLASH_ATTN=1 python benchmarks/benchmark_latency.py --enable-chunked-prefill False --load-format dummy --batch-size 64 --num-iters-warmup 2 --num-iters 5 --input-len 2048 --output-len 128 --model /models/Phi-3-medium-128k-instruct-quantized.w8a8/
old kernel:
Avg latency: 10.266006066184492 seconds
new kernel:
Avg latency: 9.983664013911039 seconds
PPL Measurements:
context-size=1024
sample-size=512
max-model-len 32768
model=Llama-3.1-8B-Instruct-FP8-QKV-Prob
PPL=6.9046958999710615
model=Llama-3.1-8B-Instruct
PPL=6.5381874291070545
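For context on these PPL figures, perplexity is the exponential of the mean per-token negative log-likelihood. A minimal sketch of that reduction, where `token_logprobs` is a hypothetical list of per-token log-probabilities collected from the model:

```python
import math

def perplexity(token_logprobs: list[float]) -> float:
    """exp of the mean negative log-likelihood over all scored tokens."""
    nll = -sum(token_logprobs) / len(token_logprobs)
    return math.exp(nll)

# A model assigning every token probability ~0.145 lands near the
# FP8-QKV-Prob figure above (~6.9), since 1 / 0.145 ≈ 6.9.
print(perplexity([math.log(0.145)] * 512))
```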
LM EVAL:
I included lm-eval results for both the KV and QKV-Prob models. In the QKV-Prob model, the output of softmax(QK^T) is quantized to fp8 before the second dot, which does appear to reduce accuracy; since QKV-Prob applies the most quantization, it is expected to show the lowest accuracy of the three models below.
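As a numerical illustration of that QKV-Prob path (a minimal PyTorch sketch, not the Triton kernel itself; the per-tensor e4m3 scaling shown is an assumption):

```python
import torch

def attention_with_fp8_probs(q, k, v):
    """Toy reference: quantize softmax(QK^T) to fp8 before the second dot."""
    scale = q.shape[-1] ** -0.5
    p = torch.softmax(q @ k.transpose(-2, -1) * scale, dim=-1)

    # Quantize the probabilities to e4m3 with a per-tensor scale, then
    # dequantize for the PV matmul; the round trip is where accuracy is lost.
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    p_scale = p.abs().max().clamp(min=1e-12) / fp8_max
    p_fp8 = (p / p_scale).to(torch.float8_e4m3fn)
    p_deq = p_fp8.to(p.dtype) * p_scale

    return p_deq @ v

q, k, v = (torch.randn(1, 4, 16, 64) for _ in range(3))  # (batch, heads, seq, dim)
out = attention_with_fp8_probs(q, k, v)
```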
LM EVAL: Llama-3.1-8B-Instruct-FP8-KV
LM EVAL: Llama-3.1-8B-Instruct-FP8-QKV-Prob
LM EVAL: Llama-3.1-8B-Instruct