
feat: Fused rope fp8 quantize kernel for MLA #1339


Merged
merged 16 commits into flashinfer-ai:main from fused-rope-quantize on Aug 9, 2025

Conversation

@yzh119 (Collaborator) commented Jul 28, 2025

📌 Description

Fuses RoPE and FP8 quantization into a single kernel to prepare inputs for the FP8 MLA kernel.
Reference:
https://github.com/NVIDIA/TensorRT-LLM/blob/0df758ec9f8409410bac8b60d117374054391c2d/cpp/tensorrt_llm/kernels/mlaKernels.cu#L358
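
For reviewers who want the intended semantics at a glance, below is a minimal unfused PyTorch sketch of what the fused kernel computes. It assumes a DeepSeek-style MLA layout (64-dim RoPE part, 512-dim no-RoPE part), a cos_sin_cache whose first half holds cos and second half holds sin for each position, NeoX-style (non-interleaved) rotation, and scalar per-tensor quantization scales; these shapes and the cache layout are assumptions for illustration, not the exact contract of the kernel.

import torch

def mla_rope_quantize_reference(
    q_rope_in,       # (nnz, num_heads, 64), fp16/bf16
    k_rope_in,       # (nnz, 64)
    q_nope_in,       # (nnz, num_heads, 512)
    k_nope_in,       # (nnz, 512)
    cos_sin_cache,   # (max_pos, 64), assumed [cos | sin] layout per row
    pos_ids,         # (nnz,) integer positions
    quant_scale_q,
    quant_scale_kv,
    quantize_dtype=torch.float8_e4m3fn,
):
    half = cos_sin_cache.size(1) // 2                  # 32
    cos = cos_sin_cache[pos_ids, :half].float()        # (nnz, 32)
    sin = cos_sin_cache[pos_ids, half:].float()        # (nnz, 32)

    def rope(x, cos, sin):
        # NeoX-style (non-interleaved) rotation: pair x[..., i] with x[..., i + half].
        x1, x2 = x[..., :half].float(), x[..., half:].float()
        return torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1)

    q_rope = rope(q_rope_in, cos.unsqueeze(1), sin.unsqueeze(1))
    k_rope = rope(k_rope_in, cos, sin)

    # Quantize: scale, then cast to FP8 (the same "scale then cast_store"
    # pattern discussed in the review below).
    q_rope_out = (q_rope * quant_scale_q).to(quantize_dtype)
    q_nope_out = (q_nope_in.float() * quant_scale_q).to(quantize_dtype)
    k_rope_out = (k_rope * quant_scale_kv).to(quantize_dtype)
    k_nope_out = (k_nope_in.float() * quant_scale_kv).to(quantize_dtype)
    return q_rope_out, q_nope_out, k_rope_out, k_nope_out

The fused kernel performs the same math in one pass and writes FP8 directly, avoiding the intermediate fp16/bf16 round trip through global memory.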

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

@gemini-code-assist bot left a comment

Summary of Changes

Hello @yzh119, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a significant performance optimization by implementing a new fused CUDA kernel that combines the application of Rotary Positional Embeddings (RoPE) with FP8 quantization. This specialized kernel is designed to efficiently prepare query and key/value tensors for subsequent FP8 multi-head latent attention (MLA) operations, aiming to reduce computational overhead and memory bandwidth requirements in large language models.

Highlights

  • Fused RoPE and FP8 Quantization: Introduces a new fused kernel that combines Rotary Positional Embedding (RoPE) application with FP8 quantization. This is specifically designed to prepare inputs for FP8 multi-head latent attention (MLA) kernels, optimizing the data pipeline for efficient model inference.
  • Performance Optimization: The primary goal of this change is to improve computational efficiency and reduce memory bandwidth requirements by fusing two sequential operations (RoPE and quantization) into a single, optimized CUDA kernel.
  • CUDA Kernel Implementation: A new CUDA kernel, MLARopeQuantizeKernel, is implemented to perform the fused operation directly on the GPU. It handles both the RoPE and non-RoPE components of query and key/value tensors, supporting both interleaved and non-interleaved RoPE application, and applies quantization scales to output FP8 tensors.
  • PyTorch Integration: The new fused operation is exposed to the PyTorch framework through a mla_rope_quantize function, allowing it to be seamlessly integrated and called from Python code within the FlashInfer library; a hedged call sketch follows this list.
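
As a usage illustration, the keyword arguments visible later in this thread (cos_sin_cache, pos_ids, is_neox, quantize_dtype) suggest a call roughly like the sketch below. The public name mla_rope_quantize, the positional tensor order, and the return convention are assumptions taken from this conversation rather than the released API, which was still being adjusted during review.

import torch
import flashinfer  # assumes a build that includes this PR

nnz, num_heads, device = 4, 16, "cuda"
q_rope = torch.randn(nnz, num_heads, 64, dtype=torch.bfloat16, device=device)
q_nope = torch.randn(nnz, num_heads, 512, dtype=torch.bfloat16, device=device)
k_rope = torch.randn(nnz, 64, dtype=torch.bfloat16, device=device)
k_nope = torch.randn(nnz, 512, dtype=torch.bfloat16, device=device)
pos_ids = torch.arange(nnz, dtype=torch.int32, device=device)
# (max_pos, 64) cache, first half cos / second half sin (assumed layout)
cos_sin_cache = torch.rand(4096, 64, dtype=torch.float32, device=device)

# Hypothetical call -- name, argument order, and outputs are assumptions
# drawn from this thread, not the released API.
q_rope_o, k_rope_o, q_nope_o, k_nope_o = flashinfer.mla_rope_quantize(
    q_rope, k_rope, q_nope, k_nope,
    cos_sin_cache=cos_sin_cache,
    pos_ids=pos_ids,
    is_neox=True,
    quantize_dtype=torch.float8_e4m3fn,
)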

@gemini-code-assist bot left a comment

Code Review

This pull request introduces a new fused RoPE and FP8 quantization kernel for MLA, which is a great addition for performance. However, the current implementation has several critical issues that will prevent it from compiling and running correctly. These include undefined variables, incorrect function calls, and flawed logic within the new CUDA kernel. I've provided detailed feedback and suggestions to address these points. Once fixed, this will be a solid contribution.

csrc/rope.cu Outdated
const uint32_t k_rope_out_stride = k_rope_out.stride(0);
const uint32_t k_nope_out_stride = k_nope_out.stride(0);

const c10::cuda::OptionalCUDAGuard device_guard(q.device());
critical

The variable q is not defined in this scope. This appears to be a copy-paste error and will cause a compilation failure. You should use one of the input tensors, such as q_rope_in, to get the device.

  const c10::cuda::OptionalCUDAGuard device_guard(q_rope_in.device());

csrc/rope.cu Outdated
Comment on lines 334 to 351
cudaError_t status = BatchQKApplyRotaryPosIdsCosSinCache(
static_cast<c_type*>(q_rope_in.data_ptr()), static_cast<c_type*>(k_rope_in.data_ptr()),
static_cast<c_type*>(q_nope_in.data_ptr()), static_cast<c_type*>(k_nope_in.data_ptr()),
static_cast<c_quant_type*>(q_rope_out.data_ptr()),
static_cast<c_quant_type*>(k_rope_out.data_ptr()),
static_cast<c_quant_type*>(q_nope_out.data_ptr()),
static_cast<c_quant_type*>(k_nope_out.data_ptr()),
static_cast<c_idtype*>(pos_ids.data_ptr()), nnz, num_heads, q_rope_in_stride_n,
q_rope_in_stride_h, q_nope_in_stride_n, q_nope_in_stride_h, q_rope_out_stride_n,
q_rope_out_stride_h, q_nope_out_stride_n, q_nope_out_stride_h, k_rope_in_stride,
k_nope_in_stride, k_rope_out_stride, k_nope_out_stride, quant_scale_q, quant_scale_kv,
interleave, stream);
TORCH_CHECK(status == cudaSuccess,
"BatchQKApplyRotaryPosIdsCosSinCache failed with error code " +
std::string(cudaGetErrorString(status)));
critical

This is calling BatchQKApplyRotaryPosIdsCosSinCache, but the function signature and arguments do not match. It seems you intended to call the new flashinfer::MLARopeQuantize function. This will cause a compilation error.

        cudaError_t status = flashinfer::MLARopeQuantize(
            static_cast<c_type*>(q_rope_in.data_ptr()), static_cast<c_type*>(k_rope_in.data_ptr()),
            static_cast<c_type*>(q_nope_in.data_ptr()), static_cast<c_type*>(k_nope_in.data_ptr()),
            static_cast<c_quant_type*>(q_rope_out.data_ptr()),
            static_cast<c_quant_type*>(k_rope_out.data_ptr()),
            static_cast<c_quant_type*>(q_nope_out.data_ptr()),
            static_cast<c_quant_type*>(k_nope_out.data_ptr()),
            static_cast<c_idtype*>(pos_ids.data_ptr()), nnz, num_heads, q_rope_in_stride_n,
            q_rope_in_stride_h, q_nope_in_stride_n, q_nope_in_stride_h, q_rope_out_stride_n,
            q_rope_out_stride_h, q_nope_out_stride_n, q_nope_out_stride_h, k_rope_in_stride,
            k_nope_in_stride, k_rope_out_stride, k_nope_out_stride, quant_scale_q, quant_scale_kv,
            interleave, stream);
        TORCH_CHECK(status == cudaSuccess,
                    "MLARopeQuantize failed with error code " +
                        std::string(cudaGetErrorString(status)));

constexpr uint32_t rotary_dim = 64;

vec_t<float, vec_size> cos, sin;
if (bx * bdy + ty < nnz) {
critical

The variable bdy is not defined within the kernel's scope. This will cause a compilation error. You should use blockDim.y instead, as it holds the value of bdy passed during kernel launch.

  if (bx * blockDim.y + ty < nnz) {

Comment on lines +379 to +382
cos.load(cos_sin_cache + (pos * rotary_dim) + vec_idx);
sin.load(cos_sin_cache + (pos * rotary_dim) + (sin_offset + vec_idx));
critical

The cos_sin_cache variable is used here but is not defined in the MLARopeQuantizeKernel's signature. It needs to be passed as a kernel argument. This will also require updating the calling function MLARopeQuantize to accept and pass cos_sin_cache.

Comment on lines 401 to 407
#pragma unroll
for (uint32_t i = 0; i < vec_size; ++i) {
  q_rope_out_ptr[i] = q_rope_vec[i] * quant_scale_q;
}
q_rope_vec.cast_store(q_rope_out_ptr + tx * vec_size);
critical

The quantization logic here is incorrect and will lead to incorrect results and race conditions.

  1. The loop at lines 402-404 writes to q_rope_out_ptr[i] without a thread-specific offset, causing multiple threads in a warp to write to the same memory locations.
  2. The cast_store on line 405 overwrites the result of the loop without applying the quant_scale_q.

The correct approach is to scale the q_rope_vec vector and then store it. This same issue exists for k_rope_vec (lines 422-426), k_nope_vec (lines 439-443), and q_nope_vec (lines 457-461).

#pragma unroll
      for (uint32_t i = 0; i < vec_size; ++i) {
        q_rope_vec[i] *= quant_scale_q;
      }
      q_rope_vec.cast_store(q_rope_out_ptr + tx * vec_size);


DISPATCH_INTERLEAVE(interleave, INTERLEAVE, {
constexpr uint32_t rotary_dim = 64;
constexpr uint32_t vec_size = std::max(16 / sizeof(DType));
critical

std::max requires two arguments, but only one is provided. This will cause a compilation error. Please correct the expression to be a valid constexpr value.

    constexpr uint32_t vec_size = 16 / sizeof(DType);

(void*)&k_nope_out_stride,
(void*)&quant_scale_q,
(void*)&quant_scale_kv};
auto kernel = MLARopeQuantizeKernel<INTERLEAVE, rotary_dim, vec_size, bdx, DType, IdType>;
critical

The template arguments for MLARopeQuantizeKernel are incorrect. rotary_dim is a runtime value, not a template parameter. Additionally, the QuantType template parameter is missing. This will cause a compilation error.

    auto kernel = MLARopeQuantizeKernel<INTERLEAVE, vec_size, bdx, DType, IdType, QuantType>;

(void*)&quant_scale_q,
(void*)&quant_scale_kv};
auto kernel = MLARopeQuantizeKernel<INTERLEAVE, rotary_dim, vec_size, bdx, DType, IdType>;
dim3 nblks(nblks_x, num_heads + 8 + 1 + num_heads * 8) dim3 nthrs(bdx, bdy);
critical

There is a missing semicolon between the dim3 declarations, which will cause a compilation error.

    dim3 nblks(nblks_x, num_heads + 8 + 1 + num_heads * 8); dim3 nthrs(bdx, bdy);

csrc/rope.cu Outdated
CHECK_EQ(q_rope_in.size(1), num_heads);
CHECK_EQ(q_nope_in.size(1), num_heads);
CHECK_EQ(q_rope_out.size(1), num_heads);
CHECK_EQ(k_rope_out.size(1), num_heads);
high

The tensor k_rope_out is 2D with shape (nnz, 64), as specified by the check on line 300. Its second dimension is the rotary dimension (64), not num_heads. This check incorrectly constrains num_heads to be 64. Since k_rope_out does not have a head dimension, this check should be removed.

@yzh119 marked this pull request as ready for review July 30, 2025 09:33
cos_sin_cache: torch.Tensor,
pos_ids: torch.Tensor,
is_neox: bool = True,
quantize_dtype: Optional[torch.dtype] = None,


Nit: are you also supporting applying RoPE without quantizing? If not, this function should either make the quantize_dtype argument required or convey in its name that it quantizes to FP8 (e.g. rope_quantize_to_fp8).

Collaborator reply: updated

yzh119 and others added 4 commits August 8, 2025 22:47:
  • upd
  • double
  • cos/sin
  • bugfix
  • udp
@cyx-6 force-pushed the fused-rope-quantize branch from 9b0d7fd to d7d142a on August 8, 2025 22:55
@fzyzcjy (Collaborator) commented Aug 9, 2025

Looking forward to this

@cyx-6 merged commit fc88829 into flashinfer-ai:main on Aug 9, 2025
2 checks passed