Skip to content

Conversation

hlu1
Copy link
Collaborator

@hlu1 hlu1 commented Oct 19, 2025

Motivation

Add a fused kernel for fast_topk_transform_ragged_fused. The difference between this kernel and fast_topk_transform_fused is that fast_topk_transform_fused outputs indices into the paged kvcache and fast_topk_transform_ragged_fused outputs indices into the ragged kv that's the input to the flashmla_prefill kernel.

Accuracy Tests

Tested with #11655.
Before
Repeat: 4, mean: 0.787
Scores: ['0.768', '0.823', '0.773', '0.783']

After
Repeat: 4, mean: 0.784
Scores: ['0.783', '0.803', '0.793', '0.758']

Benchmarking and Profiling

Benchmarked with with #11655.

Before
image

After
image

The two kernels are about the same performance

Checklist

Signed-off-by: Hao Lu <14827759+hlu1@users.noreply.github.com>
Copy link
Contributor

Summary of Changes

Hello @hlu1, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request enhances the sgl-kernel library by adding a specialized fused CUDA kernel for topk index transformation. The primary goal is to support ragged Key-Value (KV) caches, which are essential for specific prefill operations like those used in flashmla_prefill. This new kernel complements existing functionality by providing an alternative indexing mechanism, ensuring broader compatibility and optimized performance for different memory layouts.

Highlights

  • New Fused Kernel: Introduced a new fused CUDA kernel, fast_topk_transform_ragged_fused, designed to transform top-k indices specifically for ragged KV caches.
  • Ragged KV Cache Support: This new kernel outputs indices into a ragged KV cache, differentiating it from the existing fast_topk_transform_fused which targets paged KV caches. This is crucial for inputs to the flashmla_prefill kernel.
  • Performance and Accuracy: Accuracy tests show comparable results before and after the change, with mean scores of 0.787 and 0.784 respectively. Benchmarking indicates that the new kernel offers similar performance to the existing fast_topk_transform_fused.
  • Python API and Testing: A new Python wrapper fast_topk_transform_ragged_fused has been added, along with a reference PyTorch implementation and comprehensive unit tests to ensure correctness across various batch sizes and sequence lengths.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces a new fused CUDA kernel, fast_topk_transform_ragged_fused, aimed at optimizing performance for specific use cases like the DeepseekV32 model. The changes are well-implemented, including the C++ kernel, Python bindings, and corresponding tests. The codebase is also improved by consistently using const & for tensor arguments in function signatures. My review includes a couple of suggestions to enhance code clarity and correctness.

Comment on lines +55 to +61
__device__ void naive_topk_transform_ragged(
const float* __restrict__ score, int32_t length, int32_t* __restrict__ topk_indices_ragged, int32_t offset) {
const auto tid = threadIdx.x;
for (auto i = tid; i < TopK; i += kThreadsPerBlock) {
topk_indices_ragged[i] = (i < length) ? static_cast<int32_t>(i) + offset : -1;
}
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The score parameter in the newly added naive_topk_transform_ragged function is not used. To improve code clarity and maintainability, it's best to remove unused parameters. This would also require updating the call to this function at line 348.

__device__ void naive_topk_transform_ragged(
    int32_t length, int32_t* __restrict__ topk_indices_ragged, int32_t offset) {
  const auto tid = threadIdx.x;
  for (auto i = tid; i < TopK; i += kThreadsPerBlock) {
    topk_indices_ragged[i] = (i < length) ? static_cast<int32_t>(i) + offset : -1;
  }
}

Comment on lines +56 to +57
topk == 2048
), "fast_topk_transform_fused_ragged is only optimized for deepseek v3.2 model, where topk=2048"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

There is a minor typo in the assertion message. The function is named fast_topk_transform_ragged_fused, but the error message refers to it as fast_topk_transform_fused_ragged. Correcting this will improve clarity when debugging.

Suggested change
topk == 2048
), "fast_topk_transform_fused_ragged is only optimized for deepseek v3.2 model, where topk=2048"
topk == 2048
), "fast_topk_transform_ragged_fused is only optimized for deepseek v3.2 model, where topk=2048"

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant