-
-
Notifications
You must be signed in to change notification settings - Fork 11.9k
[Performance] Fused blockwise quant RMS norm #27883
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Performance] Fused blockwise quant RMS norm #27883
Conversation
Signed-off-by: ElizaWszola <ewszola@redhat.com>
…or int8 Signed-off-by: ElizaWszola <ewszola@redhat.com>
Signed-off-by: ElizaWszola <ewszola@redhat.com>
Signed-off-by: ElizaWszola <ewszola@redhat.com>
Signed-off-by: ElizaWszola <ewszola@redhat.com>
Signed-off-by: ElizaWszola <ewszola@redhat.com>
|
The optimization of this commit is beneficial: [-------------------------------------------- rms-norm-dynamic-per-token-quant --------------------------------------------]
| unfused_groupwise_fp8_impl | fused_groupwise_fp8_impl
1 threads: -----------------------------------------------------------------------------------------------------------------
N 1 x D 1024 x R True x DT torch.bfloat16x GS [1, 128] | 31.4 | 29.4
N 1 x D 1024 x R False x DT torch.bfloat16x GS [1, 128] | 34.0 | 30.4
N 1 x D 5120 x R True x DT torch.bfloat16x GS [1, 128] | 31.3 | 29.6
N 1 x D 5120 x R False x DT torch.bfloat16x GS [1, 128] | 34.0 | 29.5
N 4 x D 1024 x R True x DT torch.bfloat16x GS [1, 128] | 30.1 | 29.5
N 4 x D 1024 x R False x DT torch.bfloat16x GS [1, 128] | 35.1 | 31.2
N 4 x D 5120 x R True x DT torch.bfloat16x GS [1, 128] | 32.4 | 32.5
N 4 x D 5120 x R False x DT torch.bfloat16x GS [1, 128] | 36.1 | 30.7
N 16 x D 1024 x R True x DT torch.bfloat16x GS [1, 128] | 31.6 | 31.4
N 16 x D 1024 x R False x DT torch.bfloat16x GS [1, 128] | 35.2 | 32.3
N 16 x D 5120 x R True x DT torch.bfloat16x GS [1, 128] | 32.8 | 32.2
N 16 x D 5120 x R False x DT torch.bfloat16x GS [1, 128] | 35.1 | 31.6
N 64 x D 1024 x R True x DT torch.bfloat16x GS [1, 128] | 31.8 | 31.5
N 64 x D 1024 x R False x DT torch.bfloat16x GS [1, 128] | 35.2 | 32.7
N 64 x D 5120 x R True x DT torch.bfloat16x GS [1, 128] | 31.8 | 31.6
N 64 x D 5120 x R False x DT torch.bfloat16x GS [1, 128] | 36.1 | 32.1
N 256 x D 1024 x R True x DT torch.bfloat16x GS [1, 128] | 32.8 | 32.3
N 256 x D 1024 x R False x DT torch.bfloat16x GS [1, 128] | 36.1 | 32.0
N 256 x D 5120 x R True x DT torch.bfloat16x GS [1, 128] | 32.6 | 32.3
N 256 x D 5120 x R False x DT torch.bfloat16x GS [1, 128] | 35.2 | 31.5
N 1024 x D 1024 x R True x DT torch.bfloat16x GS [1, 128] | 31.4 | 39.0
N 1024 x D 1024 x R False x DT torch.bfloat16x GS [1, 128] | 35.1 | 36.9
N 1024 x D 5120 x R True x DT torch.bfloat16x GS [1, 128] | 31.8 | 53.3
N 1024 x D 5120 x R False x DT torch.bfloat16x GS [1, 128] | 35.5 | 49.3 now [-------------------------------------------- rms-norm-dynamic-per-token-quant --------------------------------------------]
| unfused_groupwise_fp8_impl | fused_groupwise_fp8_impl
1 threads: -----------------------------------------------------------------------------------------------------------------
N 1 x D 1024 x R True x DT torch.bfloat16x GS [1, 128] | 30.9 | 19.6
N 1 x D 1024 x R False x DT torch.bfloat16x GS [1, 128] | 36.5 | 19.4
N 1 x D 5120 x R True x DT torch.bfloat16x GS [1, 128] | 30.5 | 19.6
N 1 x D 5120 x R False x DT torch.bfloat16x GS [1, 128] | 36.5 | 19.6
N 4 x D 1024 x R True x DT torch.bfloat16x GS [1, 128] | 30.4 | 19.5
N 4 x D 1024 x R False x DT torch.bfloat16x GS [1, 128] | 34.2 | 19.3
N 4 x D 5120 x R True x DT torch.bfloat16x GS [1, 128] | 30.5 | 19.6
N 4 x D 5120 x R False x DT torch.bfloat16x GS [1, 128] | 34.2 | 19.4
N 16 x D 1024 x R True x DT torch.bfloat16x GS [1, 128] | 31.8 | 19.6
N 16 x D 1024 x R False x DT torch.bfloat16x GS [1, 128] | 36.4 | 19.5
N 16 x D 5120 x R True x DT torch.bfloat16x GS [1, 128] | 30.7 | 19.7
N 16 x D 5120 x R False x DT torch.bfloat16x GS [1, 128] | 36.5 | 19.7
N 64 x D 1024 x R True x DT torch.bfloat16x GS [1, 128] | 31.8 | 19.7
N 64 x D 1024 x R False x DT torch.bfloat16x GS [1, 128] | 36.5 | 19.6
N 64 x D 5120 x R True x DT torch.bfloat16x GS [1, 128] | 30.4 | 19.6
N 64 x D 5120 x R False x DT torch.bfloat16x GS [1, 128] | 34.3 | 19.5
N 256 x D 1024 x R True x DT torch.bfloat16x GS [1, 128] | 30.1 | 19.4
N 256 x D 1024 x R False x DT torch.bfloat16x GS [1, 128] | 34.4 | 19.8
N 256 x D 5120 x R True x DT torch.bfloat16x GS [1, 128] | 30.7 | 19.6
N 256 x D 5120 x R False x DT torch.bfloat16x GS [1, 128] | 34.2 | 19.5
N 1024 x D 1024 x R True x DT torch.bfloat16x GS [1, 128] | 30.7 | 19.4
N 1024 x D 1024 x R False x DT torch.bfloat16x GS [1, 128] | 34.4 | 19.4
N 1024 x D 5120 x R True x DT torch.bfloat16x GS [1, 128] | 30.7 | 28.7
N 1024 x D 5120 x R False x DT torch.bfloat16x GS [1, 128] | 34.5 | 28.7 |
|
This pull request has merge conflicts that must be resolved before it can be |
Signed-off-by: ElizaWszola <ewszola@redhat.com>
…agic/vllm into blockwise-quant-rms-norm Signed-off-by: ElizaWszola <ewszola@redhat.com>
Signed-off-by: ElizaWszola <ewszola@redhat.com>
Signed-off-by: ElizaWszola <ewszola@redhat.com>
Signed-off-by: ElizaWszola <ewszola@redhat.com>
Signed-off-by: ElizaWszola <ewszola@redhat.com>
Signed-off-by: ElizaWszola <ewszola@redhat.com>
Signed-off-by: ElizaWszola <ewszola@redhat.com>
Signed-off-by: ElizaWszola <ewszola@redhat.com>
ProExpertProg
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
cc @yewentao256 @varun-sundar-rabindranath for kernel review as well
| const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); | ||
|
|
||
| if (residual.has_value()) { | ||
| if (is_scale_transposed) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we have a bool dispatch macro
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I found one in SM100 CUTLASS file, but it didn't do quite what I needed it for, so I ended up adding my own macro in dispatch_utils. If it duplicates some already existing code, please lmk
csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu
Outdated
Show resolved
Hide resolved
csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu
Outdated
Show resolved
Hide resolved
Signed-off-by: ElizaWszola <ewszola@redhat.com>
|
Figured it out now, pushed the fix :) |
Signed-off-by: ElizaWszola <ewszola@redhat.com>
Signed-off-by: ElizaWszola <ewszola@redhat.com>
e4aa624 to
f4a206c
Compare
Signed-off-by: ElizaWszola <ewszola@redhat.com> Signed-off-by: yewentao256 <zhyanwentao@126.com> Co-authored-by: yewentao256 <zhyanwentao@126.com>
|
After this PR, Qwen3 VLs (and most likely other FP8 VLMs I guess) are failing with the following error: which is raised at |
Summary: Fix AMD compilation failure for DeepSeek models introduced in vllm-project#27883. The issue was that RMSNormQuantFusionPass unconditionally creates FusedAddRMSNormGroupQuantPattern and RMSNormGroupQuantPattern for group quantization (GroupShape 64 and 128), but the underlying C++ operation per_token_group_fp8_quant is only available on CUDA (wrapped in #ifndef USE_ROCM in torch_bindings.cpp). On AMD platforms, this caused an assertion failure: AssertionError: unsupported quantization scheme QuantKey(f8e4m3fnuz,scale(f32,dynamic,GroupShape(row=1, col=128)),symmetric) The fix guards the creation of group quant patterns with current_platform.is_cuda(), matching the guard used for registering these keys in QUANT_OPS. Test Plan: Waiting for this deepseek job on amd to complete: https://www.internalfb.com/vanguard/serving_test_cases/1967790977283741 Will also wait for external CI Differential Revision: D88608586 Privacy Context Container: L1370295
Signed-off-by: ElizaWszola <ewszola@redhat.com> Signed-off-by: yewentao256 <zhyanwentao@126.com> Co-authored-by: yewentao256 <zhyanwentao@126.com> Signed-off-by: mayoohee <yiweiii.fang@gmail.com>
Signed-off-by: ElizaWszola <ewszola@redhat.com> Signed-off-by: yewentao256 <zhyanwentao@126.com> Co-authored-by: yewentao256 <zhyanwentao@126.com>
CUDA kernel and fusion code for Fused Groupwise FP8-Quantized RMS Norm. This code allows to fuse RMS Norm + FP8 Quantization of the RMS Norm's output when
enable_fusion==True.Testing:
Test fused op
pytest tests/kernels/core/test_fused_quant_layernorm.pyTest fusion
pytest tests/compile/test_fusion.py(tested with both
VLLM_USE_DEEP_GEMM=1andVLLM_USE_DEEP_GEMM=0)Offline inference
Run with
(tested with both
VLLM_USE_DEEP_GEMM=1andVLLM_USE_DEEP_GEMM=0, verified that the fused kernel is being produced)Benchmarking:
Microbenchmark isolated op:
python benchmarks/fused_kernels/layernorm_rms_benchmarks.pyResults on H100 (click to show)
Results of E2E sonnet benchmark of
Qwen/Qwen3-30B-A3B-FP8compared to main (H100):