[Feature] Enable triton scaled mm for NVIDIA GPUs with ahead-of-time autotuning #20163
Conversation
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs do not trigger a full CI run by default; only a small subset of checks runs automatically. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. To run CI, PR reviewers can either add a ready label to the PR or enable auto-merge. 🚀
Summary of Changes
Hello @gau-nernst, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request aims to boost the performance of scaled matrix multiplication for quantized models on NVIDIA GPUs by integrating an optimized Triton kernel with an ahead-of-time autotuning mechanism. It refines the existing Triton implementation, introduces a comprehensive tuning workflow, and provides pre-tuned configurations for specific hardware and model scenarios, offering a faster alternative to the default Cutlass kernel.
Highlights
- Performance Improvement for NVIDIA GPUs: This PR introduces the `TritonScaledMMLinearKernel` for NVIDIA GPUs, specifically targeting improved performance for scaled matrix multiplication (MM) in INT8 quantization, particularly for models like Qwen3-4B on RTX A6000. While Cutlass remains the default, users can explicitly enable Triton via `VLLM_DISABLED_KERNELS`.
- Triton Kernel Enhancements: The existing Triton scaled MM kernel has been significantly revamped. Key improvements include the addition of PID swizzling for better L2 cache utilization, simplification of tensor-wise vs. row-wise scaling logic, and optimization of index calculations within the kernel.
- Ahead-of-Time Autotuning: A new ahead-of-time autotuning facility has been implemented. This includes a dedicated tuning script (`tune_triton_w8a8.py`) to generate optimal kernel configurations for various shapes and devices. For small input sizes (M <= 32), a heuristic is used, while larger sizes leverage pre-tuned configurations loaded from JSON files, with a fallback to default heuristics if no tuned config is found (a sketch of this lookup flow follows this list).
- Integration and Configuration: The `TritonScaledMMLinearKernel` is now a selectable option for CUDA platforms. The `cutlass_scaled_mm` function has been streamlined, removing its internal fallback to Triton, as kernel selection is now managed at a higher level. Pre-tuned configurations for Qwen3-4B on RTX A6000 have been added to the repository.
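To make the lookup flow above concrete, here is a minimal sketch of how a heuristic-plus-JSON config selection of this kind can work. The helper names, file-name scheme, cutoff value, and block sizes are illustrative assumptions, not the exact code added in this PR.

```python
# Hypothetical sketch of an ahead-of-time config lookup; names, paths and the
# fallback block sizes are illustrative, not the exact helpers from this PR.
import functools
import json
import os
from typing import Optional

import torch


@functools.lru_cache
def _load_tuned_configs(n: int, k: int) -> Optional[dict]:
    """Load pre-tuned configs for this (N, K) and the current GPU, if any."""
    device_name = torch.cuda.get_device_name().replace(" ", "_")
    path = os.path.join(
        os.path.dirname(__file__), "configs",
        f"N={n},K={k},device_name={device_name}.json")
    if not os.path.exists(path):
        return None
    with open(path) as f:
        # JSON keys are strings; convert them back to integer M values.
        return {int(m): cfg for m, cfg in json.load(f).items()}


def get_scaled_mm_config(m: int, n: int, k: int) -> dict:
    """Pick a kernel config: heuristic for small M, tuned JSON otherwise."""
    if m <= 32:
        # Small-M heuristic: a small BLOCK_SIZE_M avoids wasted work.
        return {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128}
    tuned = _load_tuned_configs(n, k)
    if tuned is None:
        # No tuned file for this shape/device: fall back to a default heuristic.
        return {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64}
    # Use the config tuned for the smallest M >= the requested M,
    # falling back to the largest tuned M.
    candidates = sorted(tuned)
    chosen = next((mm for mm in candidates if mm >= m), candidates[-1])
    return tuned[chosen]
```

The small-M branch keeps the hot decode path free of file I/O, while the `lru_cache` makes the JSON read a one-time cost per shape.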
Using Gemini Code Assist
The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.
Invoking Gemini
You can request assistance from Gemini at any point in your pull request by creating an issue comment (i.e. a comment on the pull request page) using either `/gemini <command>` or `@gemini-code-assist <command>`. Below is a summary of the supported commands.
| Feature | Command | Description |
|---|---|---|
| Code Review | `/gemini review` | Performs a code review for the current pull request in its current state. |
| Pull Request Summary | `/gemini summary` | Provides a summary of the current pull request in its current state. |
| Comment | `@gemini-code-assist` | Responds in comments when explicitly tagged, both in issue comments and review comments. |
| Help | `/gemini help` | Displays a list of available commands. |
Customization
To customize Gemini Code Assist for the GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a `.gemini/` folder in the base of the repository. Detailed instructions can be found here.
Limitations & Feedback
Gemini Code Assist is currently in preview and may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments to provide feedback.
You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.
Footnotes
1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution. ↩
Signed-off-by: Thien Tran <gau.nernst@yahoo.com.sg>
Force-pushed from 95bf565 to 3b1fd0b (Compare)
Code Review
This pull request introduces a performant Triton-based scaled matrix multiplication kernel for NVIDIA GPUs, with ahead-of-time autotuning to improve performance. The changes include a new tuning script and pre-tuned configurations.
```diff
 offsets_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
 masks_am = offsets_am < M
+offsets_am = offsets_am % M

 offsets_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)
 masks_bn = offsets_bn < N
+offsets_bn = offsets_bn % N

 offsets_k = tl.arange(0, BLOCK_SIZE_K).to(tl.int64)
 offsets_a = (stride_am * offsets_am[:, None] +
              stride_ak * offsets_k[None, :])
 offsets_b = (stride_bk * offsets_k[:, None] +
              stride_bn * offsets_bn[None, :])

 # NOTE: BLOCK_SIZE_SCALE_A could be 1 or BLOCK_SIZE_M, so need to create
 # appropriate offsets and masks for each case. Same goes for
 # BLOCK_SIZE_SCALE_B.
 offsets_scale_am = (tl.arange(0, BLOCK_SIZE_SCALE_A) +
                     (BLOCK_SIZE_SCALE_A > 1) * pid_m * BLOCK_SIZE_M)
 masks_scale_am = offsets_scale_am < M

 offsets_scale_bn = (tl.arange(0, BLOCK_SIZE_SCALE_B) +
                     (BLOCK_SIZE_SCALE_B > 1) * pid_n * BLOCK_SIZE_N)
 masks_scale_bn = offsets_scale_bn < N

 a_ptrs = a_ptr + offsets_a
 b_ptrs = b_ptr + offsets_b

 scale_a_ptrs = scale_a_ptr + offsets_scale_am
 scale_b_ptrs = scale_b_ptr + offsets_scale_bn

-for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
-    masks_k = offsets_k < K
-    masks_a = masks_am[:, None] & masks_k[None, :]
-    a = tl.load(a_ptrs, mask=masks_a)
-
-    masks_b = masks_k[:, None] & masks_bn[None, :]
-    b = tl.load(b_ptrs, mask=masks_b)
+for k in range(K, 0, -BLOCK_SIZE_K):
+    masks_k = offsets_k < k
+    a = tl.load(a_ptrs, mask=masks_k[None, :])
+    b = tl.load(b_ptrs, mask=masks_k[:, None])

     # Accumulate results.
     accumulator = tl.dot(a, b, accumulator, out_dtype=accumulator_dtype)

     offsets_k += BLOCK_SIZE_K
     a_ptrs += BLOCK_SIZE_K * stride_ak
     b_ptrs += BLOCK_SIZE_K * stride_bk
```
The updated `scaled_mm_kernel` has a critical issue in how it handles boundary conditions for the M and N dimensions.

Using `offsets_am = offsets_am % M` and `offsets_bn = offsets_bn % N` for boundary checks is incorrect. This will cause the kernel to read from incorrect memory locations (wrapping around to the beginning of the tensor) for out-of-bounds indices, leading to incorrect results. This happens whenever `M` or `N` are not perfectly divisible by `BLOCK_SIZE_M` or `BLOCK_SIZE_N`, respectively.

The correct approach is to use masks to safely load data, padding with zeros for elements that are out of bounds. The `tl.load` calls for matrices `a` and `b` are missing these masks for the M and N dimensions.

Please restore the masking logic for correctness.
```python
offsets_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
masks_am = offsets_am < M

offsets_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)
masks_bn = offsets_bn < N

offsets_k = tl.arange(0, BLOCK_SIZE_K).to(tl.int64)
offsets_a = (stride_am * offsets_am[:, None] +
             stride_ak * offsets_k[None, :])
offsets_b = (stride_bk * offsets_k[:, None] +
             stride_bn * offsets_bn[None, :])

a_ptrs = a_ptr + offsets_a
b_ptrs = b_ptr + offsets_b

for k in range(K, 0, -BLOCK_SIZE_K):
    masks_k = offsets_k < k
    a = tl.load(a_ptrs, mask=masks_am[:, None] & masks_k[None, :], other=0.0)
    b = tl.load(b_ptrs, mask=masks_bn[None, :] & masks_k[:, None], other=0.0)

    # Accumulate results.
    accumulator = tl.dot(a, b, accumulator, out_dtype=accumulator_dtype)

    a_ptrs += BLOCK_SIZE_K * stride_ak
    b_ptrs += BLOCK_SIZE_K * stride_bk
```
Wrap-around is a correct trick to minimize mask computation. Since we apply masking when writing the output to global memory, the out-of-bounds M and N elements here do not matter.
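For readers unfamiliar with the trick, here is a minimal standalone Triton sketch of the pattern being described (a toy copy kernel, not the PR's scaled-mm kernel): loads use wrapped indices so they never go out of bounds, and correctness comes from masking the final store.

```python
# Toy illustration of the wrap-around-load / masked-store pattern
# (standalone example, not the scaled_mm kernel from this PR).
import torch
import triton
import triton.language as tl


@triton.jit
def copy_kernel(x_ptr, out_ptr, M, BLOCK_M: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK_M + tl.arange(0, BLOCK_M)
    # Wrap out-of-bounds indices back into range so the load needs no mask.
    # Wrapped lanes simply re-read valid elements; they are discarded below.
    x = tl.load(x_ptr + offs % M)
    # The store is masked, so wrapped lanes never reach global memory and the
    # result stays correct even when M is not a multiple of BLOCK_M.
    tl.store(out_ptr + offs, x, mask=offs < M)


def copy(x: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    M = x.numel()
    grid = (triton.cdiv(M, 128),)
    copy_kernel[grid](x, out, M, BLOCK_M=128)
    return out
```

The same reasoning applies to the matmul: tiles computed from wrapped rows and columns produce accumulator entries that are never written back, because the C-tile store is masked to the valid M and N range.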
```python
best_configs_list = pool.map(tune_on_gpu, process_args)

# merge configs from all GPU. sort by M
best_configs = dict(sorted(sum(best_configs_list, start=[])))
```
Using `sum(best_configs_list, start=[])` to flatten the list of lists can be inefficient for a large number of sublists, as it can lead to quadratic runtime complexity due to intermediate list creations.

A more memory-efficient and performant approach is to use a list comprehension, like `[item for sublist in best_configs_list for item in sublist]`. Since this is a utility script and the number of GPUs is likely small, this is not a critical issue, but it's a good practice to be aware of.
```diff
-best_configs = dict(sorted(sum(best_configs_list, start=[])))
+best_configs = dict(sorted([item for sublist in best_configs_list for item in sublist]))
```
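An equivalent alternative (not part of the suggestion above, only an illustration) is `itertools.chain.from_iterable`, which flattens lazily and avoids both the repeated copying of `sum` and the nested comprehension. The sample data below assumes the per-GPU results are lists of `(M, config)` pairs:

```python
# Hypothetical stand-in data: one list of (M, config) pairs per GPU.
import itertools

best_configs_list = [
    [(16, {"BLOCK_SIZE_M": 16}), (64, {"BLOCK_SIZE_M": 64})],
    [(128, {"BLOCK_SIZE_M": 64})],
]

# Chain the per-GPU lists lazily, then sort by M and build the final dict.
best_configs = dict(sorted(itertools.chain.from_iterable(best_configs_list)))
print(best_configs)  # {16: {'BLOCK_SIZE_M': 16}, 64: {...}, 128: {...}}
```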
Purpose
When deploying Qwen3-4B on RTX A6000, I noticed that the Cutlass scaled mm kernel is not performant enough. Switching to the Triton implementation is significantly faster for some shapes, thanks to autotuning. Hence, this PR enables Triton scaled mm on NVIDIA GPUs with ahead-of-time autotuning, similar to the Triton MoE kernels.
Summary of changes
- Add `TritonScaledMMLinearKernel` as a possible kernel for the CUDA platform. Cutlass still takes precedence, hence this is only enabled if `VLLM_DISABLED_KERNELS="CutlassScaledMMLinearKernel"` is set.
- Simplify `masks_am` and `masks_bn` computation by doing `offsets_am % M` and `offsets_bn % N`.
- Add an ahead-of-time tuning script, `tune_triton_w8a8.py` (a sketch of this kind of tuning loop follows the list). My focus is Qwen3-4B W8A8 INT8 on RTX A6000, hence I only added those configs. Technically this Triton scaled_mm can support FP8 as well for sm89 GPUs.
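For context, a tuning loop of this kind typically sweeps a grid of block-size configs with `triton.testing.do_bench` and records the fastest one per M. The sketch below is an illustrative assumption about how such a script can be structured; the `run_kernel` callable, the config grid, the M list, and the file-name scheme are made up for the example and are not the exact contents of `tune_triton_w8a8.py`.

```python
# Illustrative ahead-of-time tuning sweep (names and grids are examples only,
# not the actual tune_triton_w8a8.py added in this PR).
import itertools
import json
from typing import Callable

import torch
import triton


def tune_shape(
    run_kernel: Callable[[int, dict], None],
    n: int,
    k: int,
    ms: tuple = (64, 128, 256, 512, 1024, 2048, 4096),
) -> dict:
    """For each M, benchmark every config in the grid and keep the fastest."""
    grid = {
        "BLOCK_SIZE_M": [32, 64, 128],
        "BLOCK_SIZE_N": [64, 128, 256],
        "BLOCK_SIZE_K": [64, 128],
        "num_warps": [4, 8],
    }
    best = {}
    for m in ms:
        timings = {}
        for values in itertools.product(*grid.values()):
            config = dict(zip(grid, values))
            # do_bench runs warmup plus repeated timed launches, returns ms.
            timings[values] = triton.testing.do_bench(lambda: run_kernel(m, config))
        best[m] = dict(zip(grid, min(timings, key=timings.get)))
    # Persist one JSON file per (N, K, device), mirroring the MoE-style AOT configs.
    device = torch.cuda.get_device_name().replace(" ", "_")
    with open(f"N={n},K={k},device_name={device},dtype=int8_w8a8.json", "w") as f:
        json.dump(best, f, indent=2)
    return best
```

In practice `run_kernel` would launch the Triton scaled-mm kernel for an (M, K) int8 activation and (K, N) int8 weight with the given block sizes, and the resulting JSON files would be consumed by the config lookup at inference time.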
Test Plan
- `tests/kernels/quantization/test_triton_scaled_mm.py`
- `vllm bench`
- Kernel microbenchmark
Test Result
Kernel microbenchmark result
Note that the speed of light for RTX A6000 is 154.8 BF16 TFLOPS and 309.7 INT8 TOPS, so the results here still do not reach the hardware limit.
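For reference, those peak numbers can be reproduced from the published RTX A6000 specs; the core count, boost clock, and the dense tensor-core ratios (4x FP32 for BF16, 8x for INT8) used below are taken from public spec sheets and are assumptions on my part, not something measured in this PR.

```python
# Back-of-the-envelope check of the quoted "speed of light" numbers, assuming
# the published RTX A6000 specs (10752 CUDA cores, ~1.8 GHz boost clock).
cuda_cores = 10752
boost_clock_ghz = 1.8

fp32_tflops = cuda_cores * 2 * boost_clock_ghz / 1e3  # 2 FLOPs per FMA -> ~38.7
bf16_tensor_tflops = fp32_tflops * 4                  # dense BF16 tensor -> ~154.8
int8_tensor_tops = fp32_tflops * 8                    # dense INT8 tensor -> ~309.7
print(fp32_tflops, bf16_tensor_tflops, int8_tensor_tops)
```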
E2E benchmark result (TODO)
Python overhead is killing the gain. Will need a bit more work on this...
Notes to reviewers
From what I understand, currently only ROCm uses the Triton scaled mm codepath. Hence, ROCm maintainers, please help review this PR to see if it breaks existing functionality or causes a performance regression, though I see that ROCm now has AITER for scaled mm.
I also tried to keep the existing code structure, though some of it doesn't quite make sense to me, e.g. why is `triton_scaled_mm.py` under `compressed_tensors`? Feel free to tell me how to structure the code.