
[Feature] Enable triton scaled mm for NVIDIA GPUs with ahead-of-time autotuning #20163


Draft
gau-nernst wants to merge 15 commits into main from thien/triton_w8a8_int8

Conversation

gau-nernst
Contributor

@gau-nernst commented Jun 27, 2025

Essential Elements of an Effective PR Description Checklist

  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.

Purpose

When deploying Qwen3-4B on RTX A6000, I noticed that the cutlass scaled mm kernel is not performant enough. Switching to the Triton implementation is significantly faster for some shapes, thanks to autotuning. Hence, this PR enables Triton scaled mm on NVIDIA GPUs with ahead-of-time autotuning, similar to the Triton MoE kernels.

Summary of changes

  • Add TritonScaledMMLinearKernel as a possible kernel for the CUDA platform. Cutlass still takes precedence, so this is only enabled when VLLM_DISABLED_KERNELS="CutlassScaledMMLinearKernel" is set
  • Revamp the existing triton scaled mm kernel
    • Add pid swizzling to improve the L2 cache hit rate (see the sketch after this list)
    • Simplify tensor-wise vs row-wise scaling logic
    • Remove unnecessary masks_am and masks_bn computation by doing offsets_am % M and offsets_bn % N
  • Add ahead-of-time autotuning facility
    • For small M (<=32), choose the tile size based on a heuristic (this regime is bandwidth-bound, so tuning is not necessary)
    • For large M (>32), read pre-tuned configs from a file. If no pre-tuned configs exist, fall back to a default config
    • Add tuning script at tune_triton_w8a8.py
  • Add tuned configs for Qwen3-4B on RTX A6000
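
For reference, the pid swizzling mentioned above follows the usual grouped launch-order scheme (as in the standard Triton matmul tutorial). The sketch below is illustrative only; the helper name and the GROUP_SIZE_M constant are assumptions, not code copied from this PR.

# Illustrative only: remap a linear program id so that GROUP_SIZE_M tiles along M
# are launched together and reuse the same columns of B, improving the L2 hit rate.
def swizzle_pid(pid, M, N, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M):
    num_pid_m = (M + BLOCK_SIZE_M - 1) // BLOCK_SIZE_M
    num_pid_n = (N + BLOCK_SIZE_N - 1) // BLOCK_SIZE_N
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % num_pid_in_group) % group_size_m
    pid_n = (pid % num_pid_in_group) // group_size_m
    return pid_m, pid_n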

My focus is Qwen3-4B W8A8 INT8 on RTX A6000, hence I only added those configs. Technically this triton scaled_mm can support FP8 as well for sm89 GPUs.
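
As a rough sketch of the ahead-of-time lookup flow described above (the file naming, the JSON layout keyed by M, the nearest-M rule, and all block sizes below are assumptions modeled on how the fused MoE configs work, not necessarily what this PR implements):

import functools
import json
import os

@functools.lru_cache
def load_tuned_configs(N, K, device_name):
    # Assumed naming scheme: one JSON file per (N, K, device), keyed by M.
    path = f"N={N},K={K},device_name={device_name}.json"
    if not os.path.exists(path):
        return None
    with open(path) as f:
        return {int(m): cfg for m, cfg in json.load(f).items()}

def get_config(M, N, K, device_name):
    if M <= 32:
        # Bandwidth-bound regime: a fixed heuristic tile, no tuning needed.
        return dict(BLOCK_SIZE_M=16, BLOCK_SIZE_N=64, BLOCK_SIZE_K=256,
                    GROUP_SIZE_M=1, num_warps=4, num_stages=4)
    configs = load_tuned_configs(N, K, device_name)
    if configs is None:
        # No pre-tuned file for this shape/device: fall back to a default config.
        return dict(BLOCK_SIZE_M=64, BLOCK_SIZE_N=128, BLOCK_SIZE_K=64,
                    GROUP_SIZE_M=8, num_warps=4, num_stages=4)
    # Use the tuned entry for the smallest M that covers the requested M.
    best_m = min((m for m in configs if m >= M), default=max(configs))
    return configs[best_m]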

Test Plan

  • For the correctness test, there is an existing test at tests/kernels/quantization/test_triton_scaled_mm.py
  • For the kernel perf test, I use the script below (not included in the PR)
  • For the e2e perf test, I will be using vllm bench
Kernel microbenchmark
import time

import torch
import pandas as pd
from triton.testing import do_bench
from vllm._custom_ops import cutlass_scaled_mm
from vllm.model_executor.layers.quantization.compressed_tensors.triton_scaled_mm import triton_scaled_mm

torch.set_default_device("cuda")

# Qwen3-4B shapes
problem_sizes = [
    (256, 6144, 2560),
    (256, 2560, 4096),
    (256, 19456, 2560),
    (256, 2560, 9728),
    (512, 6144, 2560),
    (512, 2560, 4096),
    (512, 19456, 2560),
    (512, 2560, 9728),
    (1024, 6144, 2560),
    (1024, 2560, 4096),
    (1024, 19456, 2560),
    (1024, 2560, 9728),
]

results = []

for M, N, K in problem_sizes:
    print(f"{M=}, {N=}, {K=}")

    def measure_tflops(f, *args, **kwargs):
        # warmup
        for _ in range(5):
            f(*args, **kwargs)

        # do_bench returns latency in ms: 2*M*N*K / ms / 1e9 = TFLOPS
        latency = do_bench(lambda: f(*args, **kwargs))
        return 2 * M * N * K / latency / 1e9

    A = torch.randn(M, K, dtype=torch.bfloat16)
    B = torch.randn(N, K, dtype=torch.bfloat16)
    bf16_tflops = measure_tflops(torch.mm, A, B.T)

    A = torch.randint(-128, 127, size=(M, K), dtype=torch.int8)
    B = torch.randint(-128, 127, size=(N, K), dtype=torch.int8)
    scaleA = torch.randn(M)
    scaleB = torch.randn(N)

    cutlass_i8_tflops = measure_tflops(
        cutlass_scaled_mm, A, B.T, scaleA, scaleB, torch.bfloat16
    )
    triton_i8_tflops = measure_tflops(
        triton_scaled_mm, A, B.T, scaleA, scaleB, torch.bfloat16
    )

    results.append([M, N, K, bf16_tflops, cutlass_i8_tflops, triton_i8_tflops])
    time.sleep(2)  # let the GPU cool down

columns = ["M", "N", "K", "BF16 TFLOPS", "Cutlass INT8 TFLOPS", "Triton INT8 TFLOPS"]
df = pd.DataFrame(results, columns=columns)

float_columns = columns[3:]
df[float_columns] = df[float_columns].round(2)

print(df.to_markdown(index=False))

Test Result

Kernel microbenchmark result

| M | N | K | BF16 TFLOPS | Cutlass INT8 TFLOPS | Triton INT8 TFLOPS |
|---:|---:|---:|---:|---:|---:|
| 256 | 6144 | 2560 | 83.5 | 111.42 | 159.04 |
| 256 | 2560 | 4096 | 88.93 | 104.61 | 163.69 |
| 256 | 19456 | 2560 | 121.65 | 167.72 | 211.82 |
| 256 | 2560 | 9728 | 89.57 | 126.91 | 184.21 |
| 512 | 6144 | 2560 | 90.35 | 148.47 | 193.32 |
| 512 | 2560 | 4096 | 129.29 | 191.52 | 210.17 |
| 512 | 19456 | 2560 | 106.04 | 172.77 | 228.83 |
| 512 | 2560 | 9728 | 134.5 | 238.38 | 240.23 |
| 1024 | 6144 | 2560 | 107.14 | 173.69 | 233.81 |
| 1024 | 2560 | 4096 | 136.2 | 208.76 | 250.56 |
| 1024 | 19456 | 2560 | 117.65 | 176.73 | 226.36 |
| 1024 | 2560 | 9728 | 129.91 | 236.77 | 250.57 |

Note that the speed of light for RTX A6000 is 154.8 BF16 TFLOPS and 309.7 INT8 TFLOPS. Hence, the results here still do not reach the hardware limit (e.g. the best Triton number, 250.57 TFLOPS, is about 81% of the INT8 peak).

E2E benchmark result (TODO)

Python overhead is killing the gains. This will need a bit more work...

Notes to reviewers

From what I understand, currently only ROCm uses the Triton scaled mm code path. Hence, ROCm maintainers, please help review this PR to check that it does not break existing functionality or cause a performance regression. Though I see that ROCm now has AITER for scaled mm.

I have also tried to keep the existing code structure, though some of it doesn't quite make sense to me, e.g. why is triton_scaled_mm.py under compressed_tensors? Feel free to tell me how the code should be structured.


👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which covers a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

Contributor

@gemini-code-assist bot left a comment


Summary of Changes

Hello @gau-nernst, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request aims to boost the performance of scaled matrix multiplication for quantized models on NVIDIA GPUs by integrating an optimized Triton kernel with an ahead-of-time autotuning mechanism. It refines the existing Triton implementation, introduces a comprehensive tuning workflow, and provides pre-tuned configurations for specific hardware and model scenarios, offering a faster alternative to the default Cutlass kernel.

Highlights

  • Performance Improvement for NVIDIA GPUs: This PR introduces the TritonScaledMMLinearKernel for NVIDIA GPUs, specifically targeting improved performance for scaled matrix multiplication (MM) in INT8 quantization, particularly for models like Qwen3-4B on RTX A6000. While Cutlass remains the default, users can explicitly enable Triton via VLLM_DISABLED_KERNELS.
  • Triton Kernel Enhancements: The existing Triton scaled MM kernel has been significantly revamped. Key improvements include the addition of PID swizzling for better L2 cache utilization, simplification of tensor-wise vs. row-wise scaling logic, and optimization of index calculations within the kernel.
  • Ahead-of-Time Autotuning: A new ahead-of-time autotuning facility has been implemented. This includes a dedicated tuning script (tune_triton_w8a8.py) to generate optimal kernel configurations for various shapes and devices. For small input sizes (M <= 32), a heuristic is used, while larger sizes leverage pre-tuned configurations loaded from JSON files, with a fallback to default heuristics if no tuned config is found.
  • Integration and Configuration: The TritonScaledMMLinearKernel is now a selectable option for CUDA platforms. The cutlass_scaled_mm function has been streamlined, removing its internal fallback to Triton, as kernel selection is now managed at a higher level. Pre-tuned configurations for Qwen3-4B on RTX A6000 have been added to the repository.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point in your pull request via creating an issue comment (i.e. comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.

| Feature | Command | Description |
| --- | --- | --- |
| Code Review | /gemini review | Performs a code review for the current pull request in its current state. |
| Pull Request Summary | /gemini summary | Provides a summary of the current pull request in its current state. |
| Comment | @gemini-code-assist | Responds in comments when explicitly tagged, both in issue comments and review comments. |
| Help | /gemini help | Displays a list of available commands. |

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist is currently in preview and may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments to provide feedback.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

@mergify bot added the performance (Performance-related issues) label on Jun 27, 2025
Signed-off-by: Thien Tran <gau.nernst@yahoo.com.sg>
@gau-nernst force-pushed the thien/triton_w8a8_int8 branch from 95bf565 to 3b1fd0b on June 27, 2025 04:17
Contributor

@gemini-code-assist bot left a comment


Code Review

This pull request introduces a performant Triton-based scaled matrix multiplication kernel for NVIDIA GPUs, with ahead-of-time autotuning to improve performance. The changes include a new tuning script and pre-tuned configurations.

Comment on lines 69 to 93
offsets_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
masks_am = offsets_am < M
offsets_am = offsets_am % M

offsets_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)
masks_bn = offsets_bn < N
offsets_bn = offsets_bn % N

offsets_k = tl.arange(0, BLOCK_SIZE_K).to(tl.int64)
offsets_a = (stride_am * offsets_am[:, None] +
             stride_ak * offsets_k[None, :])
offsets_b = (stride_bk * offsets_k[:, None] +
             stride_bn * offsets_bn[None, :])

# NOTE: BLOCK_SIZE_SCALE_A could be 1 or BLOCK_SIZE_M, so need to create
# appropriate offsets and masks for each case. Same goes for
# BLOCK_SIZE_SCALE_B.
offsets_scale_am = (tl.arange(0, BLOCK_SIZE_SCALE_A) +
                    (BLOCK_SIZE_SCALE_A > 1) * pid_m * BLOCK_SIZE_M)
masks_scale_am = offsets_scale_am < M

offsets_scale_bn = (tl.arange(0, BLOCK_SIZE_SCALE_B) +
                    (BLOCK_SIZE_SCALE_B > 1) * pid_n * BLOCK_SIZE_N)
masks_scale_bn = offsets_scale_bn < N

a_ptrs = a_ptr + offsets_a
b_ptrs = b_ptr + offsets_b

scale_a_ptrs = scale_a_ptr + offsets_scale_am
scale_b_ptrs = scale_b_ptr + offsets_scale_bn

for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
    masks_k = offsets_k < K
    masks_a = masks_am[:, None] & masks_k[None, :]
    a = tl.load(a_ptrs, mask=masks_a)

    masks_b = masks_k[:, None] & masks_bn[None, :]
    b = tl.load(b_ptrs, mask=masks_b)
for k in range(K, 0, -BLOCK_SIZE_K):
    masks_k = offsets_k < k
    a = tl.load(a_ptrs, mask=masks_k[None, :])
    b = tl.load(b_ptrs, mask=masks_k[:, None])

    # Accumulate results.
    accumulator = tl.dot(a, b, accumulator, out_dtype=accumulator_dtype)

    offsets_k += BLOCK_SIZE_K
    a_ptrs += BLOCK_SIZE_K * stride_ak
    b_ptrs += BLOCK_SIZE_K * stride_bk
Contributor


critical

The updated scaled_mm_kernel has a critical issue in how it handles boundary conditions for the M and N dimensions.

Using offsets_am = offsets_am % M and offsets_bn = offsets_bn % N for boundary checks is incorrect. This will cause the kernel to read from incorrect memory locations (wrapping around to the beginning of the tensor) for out-of-bounds indices, leading to incorrect results. This happens whenever M or N are not perfectly divisible by BLOCK_SIZE_M or BLOCK_SIZE_N, respectively.

The correct approach is to use masks to safely load data, padding with zeros for elements that are out of bounds. The tl.load calls for matrices a and b are missing these masks for the M and N dimensions.

Please restore the masking logic for correctness.

    offsets_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
    masks_am = offsets_am < M

    offsets_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)
    masks_bn = offsets_bn < N

    offsets_k = tl.arange(0, BLOCK_SIZE_K).to(tl.int64)
    offsets_a = (stride_am * offsets_am[:, None] +
                 stride_ak * offsets_k[None, :])
    offsets_b = (stride_bk * offsets_k[:, None] +
                 stride_bn * offsets_bn[None, :])

    a_ptrs = a_ptr + offsets_a
    b_ptrs = b_ptr + offsets_b

    for k in range(K, 0, -BLOCK_SIZE_K):
        masks_k = offsets_k < k
        a = tl.load(a_ptrs, mask=masks_am[:, None] & masks_k[None, :], other=0.0)
        b = tl.load(b_ptrs, mask=masks_bn[None, :] & masks_k[:, None], other=0.0)

        # Accumulate results.
        accumulator = tl.dot(a, b, accumulator, out_dtype=accumulator_dtype)

        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

Contributor Author


Wrap-around is a correct trick to minimize mask computation. Since we apply a mask when writing the output to global memory, the out-of-bounds M and N elements loaded here do not matter.
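
A toy NumPy illustration of this point (not the kernel itself): rows loaded with wrapped indices only pollute accumulator entries that the output-store mask drops anyway.

import numpy as np

# M is deliberately not divisible by the tile size.
M, N, K, BLOCK_M = 5, 4, 3, 4
A = np.random.randn(M, K)
B = np.random.randn(K, N)
C = np.zeros((M, N))

for tile in range(-(-M // BLOCK_M)):        # ceil-div over row tiles
    rows = tile * BLOCK_M + np.arange(BLOCK_M)
    valid = rows < M                        # mask used only at the store
    acc = A[rows % M] @ B                   # wrap-around load, no load mask
    C[rows[valid]] = acc[valid]             # masked store drops wrapped rows

assert np.allclose(C, A @ B)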

best_configs_list = pool.map(tune_on_gpu, process_args)

# merge configs from all GPUs, sorted by M
best_configs = dict(sorted(sum(best_configs_list, start=[])))
Contributor


medium

Using sum(best_configs_list, start=[]) to flatten the list of lists can be inefficient for a large number of sublists, as it can lead to quadratic runtime complexity due to intermediate list creations.

A more memory-efficient and performant approach is to use a list comprehension, like [item for sublist in best_configs_list for item in sublist]. Since this is a utility script and the number of GPUs is likely small, this is not a critical issue, but it's a good practice to be aware of.

Suggested change
best_configs = dict(sorted(sum(best_configs_list, start=[])))
best_configs = dict(sorted([item for sublist in best_configs_list for item in sublist]))

Signed-off-by: Thien Tran <gau.nernst@yahoo.com.sg>
Labels
performance Performance-related issues