feat: Fused rope fp8 quantize kernel for MLA #1339
Conversation
Summary of Changes
Hello @yzh119, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request introduces a significant performance optimization: a new fused CUDA kernel that combines the application of Rotary Positional Embeddings (RoPE) with FP8 quantization. This specialized kernel efficiently prepares query and key/value tensors for subsequent FP8 Multi-head Latent Attention (MLA) operations, reducing computational overhead and memory bandwidth requirements in large language models.
Highlights
- Fused RoPE and FP8 Quantization: Introduces a new fused kernel that combines Rotary Positional Embedding (RoPE) application with FP8 quantization, specifically designed to prepare inputs for FP8 Multi-head Latent Attention (MLA) kernels and to optimize the data pipeline for efficient model inference.
- Performance Optimization: The primary goal of this change is to improve computational efficiency and reduce memory bandwidth requirements by fusing two sequential operations (RoPE and quantization) into a single, optimized CUDA kernel.
- CUDA Kernel Implementation: A new CUDA kernel, MLARopeQuantizeKernel, performs the fused operation directly on the GPU. It handles both the RoPE and non-RoPE components of the query and key/value tensors, supports both interleaved and non-interleaved RoPE application, and applies quantization scales to the output FP8 tensors.
- PyTorch Integration: The new fused operation is exposed to PyTorch through an mla_rope_quantize function, allowing it to be called seamlessly from Python code within the FlashInfer library (a hedged usage sketch follows this list).
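A minimal usage sketch of the new entry point, assuming it takes the rope/nope splits of q and k together with the cos/sin cache, position ids, and per-tensor quantization scales. The argument names, tensor shapes, and defaults below are assumptions inferred from this review thread, not the finalized FlashInfer API (the outputs could equally be preallocated and passed in as out tensors):

import torch
import flashinfer

nnz, num_heads, rope_dim, nope_dim = 16, 128, 64, 512
q_rope = torch.randn(nnz, num_heads, rope_dim, dtype=torch.bfloat16, device="cuda")
q_nope = torch.randn(nnz, num_heads, nope_dim, dtype=torch.bfloat16, device="cuda")
k_rope = torch.randn(nnz, rope_dim, dtype=torch.bfloat16, device="cuda")
k_nope = torch.randn(nnz, nope_dim, dtype=torch.bfloat16, device="cuda")
pos_ids = torch.arange(nnz, dtype=torch.int32, device="cuda")
cos_sin_cache = torch.randn(4096, rope_dim, dtype=torch.float32, device="cuda")  # assumed [max_position, rotary_dim]

# Hypothetical call: RoPE is applied to the rope parts, then all four outputs are
# scaled and quantized to FP8 in a single kernel launch.
q_rope_fp8, k_rope_fp8, q_nope_fp8, k_nope_fp8 = flashinfer.mla_rope_quantize(
    q_rope, k_rope, q_nope, k_nope, cos_sin_cache, pos_ids,
    is_neox=False,
    quantize_dtype=torch.float8_e4m3fn,
    quant_scale_q=1.0,
    quant_scale_kv=1.0,
)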
Code Review
This pull request introduces a new fused RoPE and FP8 quantization kernel for MLA, which is a great addition for performance. However, the current implementation has several critical issues that will prevent it from compiling and running correctly. These include undefined variables, incorrect function calls, and flawed logic within the new CUDA kernel. I've provided detailed feedback and suggestions to address these points. Once fixed, this will be a solid contribution.
csrc/rope.cu
Outdated
const uint32_t k_rope_out_stride = k_rope_out.stride(0);
const uint32_t k_nope_out_stride = k_nope_out.stride(0);

const c10::cuda::OptionalCUDAGuard device_guard(q.device());
csrc/rope.cu
Outdated
cudaError_t status = BatchQKApplyRotaryPosIdsCosSinCache(
    static_cast<c_type*>(q_rope_in.data_ptr()), static_cast<c_type*>(k_rope_in.data_ptr()),
    static_cast<c_type*>(q_nope_in.data_ptr()), static_cast<c_type*>(k_nope_in.data_ptr()),
    static_cast<c_quant_type*>(q_rope_out.data_ptr()),
    static_cast<c_quant_type*>(k_rope_out.data_ptr()),
    static_cast<c_quant_type*>(q_nope_out.data_ptr()),
    static_cast<c_quant_type*>(k_nope_out.data_ptr()),
    static_cast<c_idtype*>(pos_ids.data_ptr()), nnz, num_heads, q_rope_in_stride_n,
    q_rope_in_stride_h, q_nope_in_stride_n, q_nope_in_stride_h, q_rope_out_stride_n,
    q_rope_out_stride_h, q_nope_out_stride_n, q_nope_out_stride_h, k_rope_in_stride,
    k_nope_in_stride, k_rope_out_stride, k_nope_out_stride, quant_scale_q, quant_scale_kv,
    interleave, stream);
TORCH_CHECK(status == cudaSuccess,
            "BatchQKApplyRotaryPosIdsCosSinCache failed with error code " +
                std::string(cudaGetErrorString(status)));
This is calling BatchQKApplyRotaryPosIdsCosSinCache, but the function signature and arguments do not match. It seems you intended to call the new flashinfer::MLARopeQuantize function. This will cause a compilation error.
cudaError_t status = flashinfer::MLARopeQuantize(
static_cast<c_type*>(q_rope_in.data_ptr()), static_cast<c_type*>(k_rope_in.data_ptr()),
static_cast<c_type*>(q_nope_in.data_ptr()), static_cast<c_type*>(k_nope_in.data_ptr()),
static_cast<c_quant_type*>(q_rope_out.data_ptr()),
static_cast<c_quant_type*>(k_rope_out.data_ptr()),
static_cast<c_quant_type*>(q_nope_out.data_ptr()),
static_cast<c_quant_type*>(k_nope_out.data_ptr()),
static_cast<c_idtype*>(pos_ids.data_ptr()), nnz, num_heads, q_rope_in_stride_n,
q_rope_in_stride_h, q_nope_in_stride_n, q_nope_in_stride_h, q_rope_out_stride_n,
q_rope_out_stride_h, q_nope_out_stride_n, q_nope_out_stride_h, k_rope_in_stride,
k_nope_in_stride, k_rope_out_stride, k_nope_out_stride, quant_scale_q, quant_scale_kv,
interleave, stream);
TORCH_CHECK(status == cudaSuccess,
"MLARopeQuantize failed with error code " +
std::string(cudaGetErrorString(status)));
constexpr uint32_t rotary_dim = 64;

vec_t<float, vec_size> cos, sin;
if (bx * bdy + ty < nnz) {
cos.load(cos_sin_cache + (pos * rotary_dim) + vec_idx);
sin.load(cos_sin_cache + (pos * rotary_dim) + (sin_offset + vec_idx));
include/flashinfer/pos_enc.cuh
Outdated
#pragma unroll
for (uint32_t i = 0; i < vec_size; ++i) {
  q_rope_out_ptr[i] = q_rope_vec[i] * quant_scale_q;
}
q_rope_vec.cast_store(q_rope_out_ptr + tx * vec_size);
The quantization logic here is incorrect and will lead to incorrect results and race conditions.
- The loop at lines 402-404 writes to q_rope_out_ptr[i] without a thread-specific offset, causing multiple threads in a warp to write to the same memory locations.
- The cast_store on line 405 overwrites the result of the loop without applying the quant_scale_q.

The correct approach is to scale the q_rope_vec vector and then store it. The same issue exists for k_rope_vec (lines 422-426), k_nope_vec (lines 439-443), and q_nope_vec (lines 457-461).
#pragma unroll
for (uint32_t i = 0; i < vec_size; ++i) {
q_rope_vec[i] *= quant_scale_q;
}
q_rope_vec.cast_store(q_rope_out_ptr + tx * vec_size);
include/flashinfer/pos_enc.cuh
Outdated
DISPATCH_INTERLEAVE(interleave, INTERLEAVE, {
  constexpr uint32_t rotary_dim = 64;
  constexpr uint32_t vec_size = std::max(16 / sizeof(DType));
include/flashinfer/pos_enc.cuh
Outdated
(void*)&k_nope_out_stride,
(void*)&quant_scale_q,
(void*)&quant_scale_kv};
auto kernel = MLARopeQuantizeKernel<INTERLEAVE, rotary_dim, vec_size, bdx, DType, IdType>;
The template arguments for MLARopeQuantizeKernel are incorrect: rotary_dim is a runtime value, not a template parameter, and the QuantType template parameter is missing. This will cause a compilation error.
auto kernel = MLARopeQuantizeKernel<INTERLEAVE, vec_size, bdx, DType, IdType, QuantType>;
include/flashinfer/pos_enc.cuh
Outdated
(void*)&quant_scale_q,
(void*)&quant_scale_kv};
auto kernel = MLARopeQuantizeKernel<INTERLEAVE, rotary_dim, vec_size, bdx, DType, IdType>;
dim3 nblks(nblks_x, num_heads + 8 + 1 + num_heads * 8) dim3 nthrs(bdx, bdy);
csrc/rope.cu
Outdated
CHECK_EQ(q_rope_in.size(1), num_heads);
CHECK_EQ(q_nope_in.size(1), num_heads);
CHECK_EQ(q_rope_out.size(1), num_heads);
CHECK_EQ(k_rope_out.size(1), num_heads);
cos_sin_cache: torch.Tensor,
pos_ids: torch.Tensor,
is_neox: bool = True,
quantize_dtype: Optional[torch.dtype] = None,
Nit: are you also supporting applying RoPE without quantizing? If not, this function should either make the quantize_dtype argument non-optional or convey explicitly through its name that it quantizes to FP8 (e.g. rope_quantize_to_fp8).
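One possible shape for that suggestion, sketched with assumed parameter names (only the last four parameters appear in the diff above); this illustrates the proposed API change, not the PR's actual signature:

import torch

def mla_rope_quantize_fp8(
    q_rope: torch.Tensor,
    k_rope: torch.Tensor,
    cos_sin_cache: torch.Tensor,
    pos_ids: torch.Tensor,
    is_neox: bool = True,
    quantize_dtype: torch.dtype = torch.float8_e4m3fn,  # non-optional, FP8 by default
) -> None:
    ...  # body omitted; this only illustrates the suggested API shape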
updated
9b0d7fd to d7d142a
Looking forward to this
📌 Description
Fusing RoPE and FP8 quantization into a single kernel to prepare input for the FP8 MLA kernel.
Reference:
https://github.com/NVIDIA/TensorRT-LLM/blob/0df758ec9f8409410bac8b60d117374054391c2d/cpp/tensorrt_llm/kernels/mlaKernels.cu#L358
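For intuition, the fused kernel replaces a two-pass sequence along the lines of the simplified, unfused PyTorch reference below. The non-interleaved (NeoX-style) rotation and the single per-tensor scale are assumptions for illustration only, not the kernel's exact semantics:

import torch

def rope_quantize_reference(x_rope, x_nope, cos, sin, quant_scale,
                            fp8_dtype=torch.float8_e4m3fn):
    # Pass 1: apply rotary embedding to the rope part (non-interleaved split assumed).
    d = x_rope.shape[-1]
    x1, x2 = x_rope[..., : d // 2], x_rope[..., d // 2 :]
    rotated = torch.cat((-x2, x1), dim=-1)
    x_rope = x_rope * cos + rotated * sin

    # Pass 2: scale and cast both the rope and nope parts to FP8. The fused kernel does
    # both passes in a single trip through memory, which is where the bandwidth saving
    # comes from.
    def to_fp8(t):
        return (t.float() * quant_scale).to(fp8_dtype)

    return to_fp8(x_rope), to_fp8(x_nope)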
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- I have installed pre-commit by running pip install pre-commit (or used your preferred method).
- I have installed the hooks with pre-commit install.
- I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

🧪 Tests
- All tests are passing (unittest, etc.).

Reviewer Notes