[DeepseekV32] Add fast_topk_transform_ragged_fused kernel #11815
Conversation
Signed-off-by: Hao Lu <14827759+hlu1@users.noreply.github.com>
Code Review

This pull request introduces a new fused CUDA kernel, `fast_topk_transform_ragged_fused`, aimed at optimizing performance for specific use cases such as the DeepseekV32 model. The changes are well implemented, including the C++ kernel, Python bindings, and corresponding tests. The codebase is also improved by consistently using `const &` for tensor arguments in function signatures. My review includes a couple of suggestions to enhance code clarity and correctness.
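As a rough illustration of the `const &` convention mentioned above, here is a minimal sketch of such a signature; the parameter names below are hypothetical, not taken from the PR:

```cpp
#include <torch/extension.h>

// Passing tensors by const reference avoids copying the at::Tensor handle
// (and the associated refcount traffic) on every call. All parameter names
// here are illustrative only.
void fast_topk_transform_ragged_fused(
    const at::Tensor& scores,           // input scores (hypothetical name)
    const at::Tensor& lengths,          // per-sequence lengths (hypothetical name)
    at::Tensor& topk_indices_ragged);   // output indices (hypothetical name)
```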
```cuda
__device__ void naive_topk_transform_ragged(
    const float* __restrict__ score, int32_t length, int32_t* __restrict__ topk_indices_ragged, int32_t offset) {
  const auto tid = threadIdx.x;
  for (auto i = tid; i < TopK; i += kThreadsPerBlock) {
    topk_indices_ragged[i] = (i < length) ? static_cast<int32_t>(i) + offset : -1;
  }
}
```
The `score` parameter in the newly added `naive_topk_transform_ragged` function is not used. To improve code clarity and maintainability, it's best to remove unused parameters. This would also require updating the call to this function at line 348.
```cuda
__device__ void naive_topk_transform_ragged(
    int32_t length, int32_t* __restrict__ topk_indices_ragged, int32_t offset) {
  const auto tid = threadIdx.x;
  for (auto i = tid; i < TopK; i += kThreadsPerBlock) {
    topk_indices_ragged[i] = (i < length) ? static_cast<int32_t>(i) + offset : -1;
  }
}
```
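The call site at line 348 would then drop the `score` argument accordingly; as a sketch, assuming the surrounding variable names match the parameters:

```cuda
// Before: naive_topk_transform_ragged(score, length, topk_indices_ragged, offset);
naive_topk_transform_ragged(length, topk_indices_ragged, offset);
```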
```python
    topk == 2048
), "fast_topk_transform_fused_ragged is only optimized for deepseek v3.2 model, where topk=2048"
```
There is a minor typo in the assertion message. The function is named `fast_topk_transform_ragged_fused`, but the error message refers to it as `fast_topk_transform_fused_ragged`. Correcting this will improve clarity when debugging.
```diff
     topk == 2048
-), "fast_topk_transform_fused_ragged is only optimized for deepseek v3.2 model, where topk=2048"
+), "fast_topk_transform_ragged_fused is only optimized for deepseek v3.2 model, where topk=2048"
```
Motivation

Add a fused kernel, `fast_topk_transform_ragged_fused`. The difference between this kernel and `fast_topk_transform_fused` is that `fast_topk_transform_fused` outputs indices into the paged kvcache, while `fast_topk_transform_ragged_fused` outputs indices into the ragged kv that is the input to the flashmla_prefill kernel.
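As a rough illustration of that difference, here is a minimal sketch (not the PR's code; `block_table`, `page_size`, and `cu_seqlen_start` are assumed names) of how a selected token position maps to each layout:

```cuda
// Paged layout: logical position i is translated through a page table into
// the paged KV cache (assumed page-table layout; names are illustrative).
__device__ int32_t paged_index(const int32_t* block_table, int32_t page_size, int32_t i) {
  return block_table[i / page_size] * page_size + i % page_size;
}

// Ragged layout: logical position i becomes a flat offset from the sequence's
// start in the contiguous ragged KV buffer consumed by flashmla_prefill,
// matching the `i + offset` pattern in naive_topk_transform_ragged above.
__device__ int32_t ragged_index(int32_t cu_seqlen_start, int32_t i) {
  return cu_seqlen_start + i;
}
```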
Accuracy Tests
Tested with #11655.
Before
Repeat: 4, mean: 0.787
Scores: ['0.768', '0.823', '0.773', '0.783']
After
Repeat: 4, mean: 0.784
Scores: ['0.783', '0.803', '0.793', '0.758']
Benchmarking and Profiling
Benchmarked with #11655.
Before

*(profiler screenshot)*

After

*(profiler screenshot)*
The two kernels show about the same performance.
Checklist