
Conversation

Contributor

@varun-sundar-rabindranath commented Jun 3, 2025

Purpose

Integration of the batched DeepGEMM kernels. These are GEMM kernels used in the fused MoE operation for block-quantized matmuls.

Test

  • Added unit tests. Verified that the tests pass locally on an H100.
  • Verified correctness with lm_eval on the Qwen/Qwen3-30B-A3B-FP8 model for the DP=2, TP=1, expert-parallel case.

Test Result

Server command: VLLM_ALL2ALL_BACKEND="deepep_low_latency" VLLM_USE_DEEP_GEMM=1 vllm serve Qwen/Qwen3-30B-A3B-FP8 --trust-remote-code --data-parallel-size 2 --enable-expert-parallel --port 9010 --no-enable-prefix-caching

lm_eval command: lm_eval --model local-completions --tasks gsm8k --model_args model=Qwen/Qwen3-30B-A3B-FP8,base_url=http://127.0.0.1:9010/v1/completions,num_concurrent=30,max_retries=3,tokenized_requests=False --limit 100

Output:

...
local-completions (model=Qwen/Qwen3-30B-A3B-FP8,base_url=http://127.0.0.1:9010/v1/completions,num_concurrent=30,max_retries=3,tokenized_requests=False), gen_kwargs: (None), limit: 100.0, num_fewshot: None, batch_size: 1
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  | 0.85|±  |0.0359|
|     |       |strict-match    |     5|exact_match|↑  | 0.94|±  |0.0239|

Signed-off-by: Varun <vsundarr@redhat.com>

github-actions bot commented Jun 3, 2025

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which covers a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

Contributor

@gemini-code-assist bot left a comment

Hello @varun-sundar-rabindranath, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

Summary of Changes

Hello! Gemini or gemini-code-assist here, providing a summary of this pull request. This PR focuses on integrating batched and masked deep_gemm kernels into the fused MoE operation, specifically targeting block-quantized matrix multiplications. The goal is to provide an alternative, potentially optimized, kernel implementation for this specific use case within the vLLM framework. The changes involve adding the new kernel implementation, a dispatcher to select between the new kernel and the existing Triton one, updating the DeepEP low-latency path to handle block quantization and dispatching, and adding a new test suite to verify correctness.

Highlights

  • DeepGEMM Kernel Integration: Introduces a new BatchedDeepGemmExperts class (vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py) that wraps the deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked kernel for performing batched and masked matrix multiplications required in the fused MoE layer, specifically for FP8 block-quantized inputs and weights with a 128x128 block shape.
  • Kernel Dispatcher: Adds a new BatchedTritonOrDeepGemmExperts class (vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py) that acts as a dispatcher. It selects between the new BatchedDeepGemmExperts and the existing BatchedTritonExperts based on whether allow_deep_gemm is enabled, the quantization type is FP8, and the block shape is 128x128 (a minimal sketch of this selection rule follows this list).
  • DeepEP Low-Latency Path Updates: Modifies the DeepEPLLPrepareAndFinalize class (vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py) to correctly handle input quantization and dequantization when using the FP8 dispatch path, especially in the context of block quantization. A new helper method _do_quant is introduced for this logic.
  • New Test Suite: Adds a new test file (tests/kernels/moe/test_ll_deepep_deepgemm_moe.py) dedicated to testing the Low-Latency DeepEP integration with the new DeepGEMM kernels. This test includes helper functions for block quantization and compares the output of the DeepEP/DeepGEMM path against a Triton reference implementation.
  • Quantization Configuration Update: Updates the FP8QuantizeConfig (vllm/model_executor/layers/quantization/fp8.py) to use the new BatchedTritonOrDeepGemmExperts dispatcher for batched MoE cases, passing the configured block shape and the allow_deep_gemm flag.
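
The Kernel Dispatcher highlight above boils down to a single predicate over the quantization configuration. Below is a minimal sketch of that selection rule; the helper name and constant are illustrative assumptions, not the exact vLLM code:

```python
from typing import Optional

# Illustrative constant; vLLM may expose an equivalent in batched_deep_gemm_moe.py.
DEEPGEMM_BLOCK_SIZE = 128


def should_use_batched_deep_gemm(allow_deep_gemm: bool,
                                 use_fp8_w8a8: bool,
                                 block_shape: Optional[list[int]]) -> bool:
    """BatchedDeepGemmExperts is selected only for FP8 weights/activations
    quantized with a 128x128 block shape."""
    return (allow_deep_gemm and use_fp8_w8a8
            and block_shape is not None
            and len(block_shape) == 2
            and all(b == DEEPGEMM_BLOCK_SIZE for b in block_shape))
```

When the predicate is false, the dispatcher simply falls back to the existing BatchedTritonExperts.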

Changelog

  • tests/kernels/moe/deepep_utils.py
    • Added block_shape parameter to DeepEPLLPrepareAndFinalize constructor call (line 172).
    • Added block_shape argument to make_deepep_ll_a2a call (line 191).
  • tests/kernels/moe/test_ht_deepep_deepgemm_moe.py
    • Updated docstring to specify 'High-Throughput' integration (line 3).
  • tests/kernels/moe/test_ll_deepep_deepgemm_moe.py
    • Added new file containing tests for Low-Latency DeepEP + DeepGEMM MoE integration.
  • vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py
    • Added new file defining the BatchedDeepGemmExperts class for DeepGEMM kernel execution.
  • vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py
    • Added new file defining the BatchedTritonOrDeepGemmExperts class for dispatching between Triton and DeepGEMM kernels.
  • vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py
    • Added Union to type hints import (line 2).
    • Added _do_quant helper method for input quantization logic (lines 68-111).
    • Updated assertion message related to per-token scales (line 139).
    • Replaced inline quantization logic with a call to _do_quant in the prepare method (line 158).
  • vllm/model_executor/layers/quantization/fp8.py
    • Modified select_gemm_impl to use BatchedTritonOrDeepGemmExperts for batched MoE (line 788).
    • Passed block_shape and allow_deep_gemm to the new dispatcher (lines 797, 798).
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point in your pull request by creating an issue comment (i.e. commenting on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.

|Feature             |Command            |Description                                                                              |
|--------------------|-------------------|-----------------------------------------------------------------------------------------|
|Code Review         |/gemini review     |Performs a code review for the current pull request in its current state.                 |
|Pull Request Summary|/gemini summary    |Provides a summary of the current pull request in its current state.                      |
|Comment             |@gemini-code-assist|Responds in comments when explicitly tagged, both in issue comments and review comments.  |
|Help                |/gemini help       |Displays a list of available commands.                                                    |

Customization

To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist is currently in preview and may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments to provide feedback.

You can also get AI-powered code generation, chat, and code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.


Kernels run fast,
Matrix math, built to last.
DeepGEMM joins the race,
Bringing speed and grace.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Contributor

@gemini-code-assist bot left a comment

Code Review

This pull request introduces significant enhancements by integrating batched/masked DeepGEMM kernels for fused MoE operations, particularly targeting block-quantized matmuls. The changes include new kernel implementations, a new test suite for low-latency DeepEP + DeepGEMM, and updates to existing utilities and configurations to support these new kernels.

The overall structure is well-organized, with clear separation of concerns in the new BatchedDeepGemmExperts and BatchedTritonOrDeepGemmExperts classes. The refactoring in DeepEPLLPrepareAndFinalize to centralize quantization logic is also a good improvement.

There are a few areas that require attention, primarily related to testing and potential performance optimizations, which are detailed in the comments below.

Summary of Findings

  • Disabled Tests for FP8 Dispatch: Tests for USE_FP8_DISPATCH=True in test_ll_deepep_deepgemm_moe.py are currently disabled, indicating a gap in test coverage for this feature. This should be addressed before merging.
  • High Test Tolerances: The numerical comparison tolerances (atol/rtol of 6e-2) in test_ll_deepep_deepgemm_moe.py are relatively high and might mask subtle issues. Further investigation or justification for these values is recommended.
  • Potential Performance Optimizations (TODOs): There are TODO comments in batched_deep_gemm_moe.py and deepep_ll_prepare_finalize.py regarding the use of batched operations for activation and quantization. Addressing these could lead to performance improvements.

Merge Readiness

This pull request makes valuable contributions by integrating DeepGEMM kernels. However, before it can be considered ready for merging, I recommend addressing the following:

  1. Disabled Tests: The tests for USE_FP8_DISPATCH=True in test_ll_deepep_deepgemm_moe.py should be enabled and passing to ensure full coverage of the FP8 dispatch functionality.
  2. Test Tolerances: The high atol and rtol values in the tests should be reviewed. If tighter tolerances are not possible, a brief explanation of why these values are necessary would be helpful.

Addressing the TODO items for batched operations would be a good follow-up for performance optimization but is not strictly blocking for this integration PR, assuming correctness is verified.

As an AI reviewer, I am not authorized to approve pull requests. Please ensure these points are discussed and resolved, and that other maintainers review and approve the changes before merging.

]
TOPKS = [2, 6]
# Fix tests for USE_FP8_DISPATCH=True
USE_FP8_DISPATCH = [False]
Contributor

high

The comment # Fix tests for USE_FP8_DISPATCH=True above USE_FP8_DISPATCH = [False] indicates that tests for the use_fp8_dispatch=True scenario are currently disabled. This is a concern as it implies a lack of test coverage for a key feature variant.

Could you please provide more context on this?

  • Is there an underlying issue preventing these tests from being enabled?
  • Are there plans to enable these tests in this PR or a follow-up?

Ensuring full test coverage, especially for different dispatch modes, is crucial for maintaining the reliability of these new kernels.

Contributor Author

This flag is specifically set to false when constructing the DeepEPLLPrepareAndFinalize object here:

# Note (varun): Whether to use FP8 dispatch or not needs some

The test failures are also not abysmal; they look like minor numerical issues. Enabling the tests can be a follow-up.

Contributor Author

I think we could turn fp8_dispatch on by default; the lm_eval results don't look that bad:

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  | 0.73|±  |0.0446|
|     |       |strict-match    |     5|exact_match|↑  | 0.86|±  |0.0349|

🤷

Comment on lines 350 to 355
torch.testing.assert_close(
    triton_moe,
    deepep_moe,
    atol=6e-2,
    rtol=6e-2,
)
Contributor

medium

The tolerances atol=6e-2 and rtol=6e-2 for torch.testing.assert_close seem somewhat high, even considering FP8 arithmetic. Could this potentially mask subtle numerical discrepancies between the Triton reference and the DeepEP implementation? It might be beneficial to investigate if tighter tolerances are achievable or to confirm if this level of tolerance is expected due to the specific kernels and data types involved.
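
One way to ground this discussion is to report the observed error directly rather than only asserting against a fixed bound. The helper below is a small diagnostic sketch in plain PyTorch (not part of this PR) for checking how close the two implementations actually are:

```python
import torch


def report_error(ref: torch.Tensor, test: torch.Tensor) -> None:
    # Worst-case absolute and relative error between the Triton reference
    # and the DeepEP/DeepGEMM output; useful for judging whether 6e-2 is
    # really needed or could be tightened.
    abs_err = (ref - test).abs()
    rel_err = abs_err / ref.abs().clamp_min(1e-6)
    print(f"max abs err: {abs_err.max().item():.3e}, "
          f"max rel err: {rel_err.max().item():.3e}")
```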

Comment on lines +101 to +105
# TODO (varun): Optimization - Use a batched version of quant
x = x.view((-1, hidden_dim))
x, x_scales = moe_kernel_quantize_input(x, a1_scale, self.quant_dtype,
                                        per_token_quant,
                                        self.block_shape)
Contributor

medium

Similar to the batched_deep_gemm_moe.py file, this TODO points out an optimization opportunity by using a batched version of quantization within the _do_quant method.

Is this related to the optimization mentioned in batched_deep_gemm_moe.py? Consolidating these efforts or tracking them together might be beneficial.
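
For illustration only, a "batched" quantization along the lines of the TODO could process the full [E, M, K] activation tensor in one shot, producing one scale per 128-wide group along the hidden dimension. The sketch below is a generic per-group FP8 quantizer, not vLLM's moe_kernel_quantize_input or the planned optimization:

```python
import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3fn


def batched_block_quant(x: torch.Tensor, group: int = 128):
    # x: [E, M, K] activations with K divisible by `group`.
    E, M, K = x.shape
    xg = x.reshape(E, M, K // group, group).float()
    # One scale per group, chosen so each group fits the FP8 range.
    scales = xg.abs().amax(dim=-1, keepdim=True).clamp_min(1e-12) / FP8_MAX
    xq = (xg / scales).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)
    return xq.reshape(E, M, K), scales.squeeze(-1)  # scales: [E, M, K // group]
```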

DEEPEP_BLOCK_SIZE = [128, 128]


def next_power_of_2(x):
Contributor

Can you factor out all these utilities so they aren't duplicated?

Contributor Author

Fixed it 👍 Collapsed test_ll_deepep_deepgemm_moe.py and test_ht_deepep_deepgemm_moe.py into test_deepep_deepgemm_moe.py

Varun added 5 commits June 3, 2025 23:01
Signed-off-by: Varun <vsundarr@redhat.com>
Signed-off-by: Varun <vsundarr@redhat.com>
Signed-off-by: Varun <vsundarr@redhat.com>
Signed-off-by: Varun <vsundarr@redhat.com>
Signed-off-by: Varun <vsundarr@redhat.com>
Contributor

@bnellnm left a comment

I think everything looks good but it would be nice to have a pplx + masked deep gemm kernel test. I think it should be pretty straightforward to modify test_pplx_moe.py to use the deep gemm experts in addition to batched triton experts.

Comment on lines +50 to +51
max_num_tokens = a.size(
    0) if self.max_num_tokens is None else self.max_num_tokens
Contributor

nit: self.max_num_tokens isn't optional

is_fp8_128_block_quantized = (self.use_fp8_w8a8
                              and self.block_shape is not None
                              and len(self.block_shape) == 2 and all(
                                  [b == 128
Contributor

@bnellnm Jun 4, 2025

nit: use the deep gemm block size constant here instead of 128? (or make this a function in the batched_deep_gemm_moe.py file, similar to _valid_deep_gemm_shape?)

Contributor

Does the masked deep gemm kernel have any other size restrictions? iirc the grouped kernel requires almost all the sizes to be 128 aligned.

Comment on lines +141 to +142
assert not has_per_token_scales, (
    "low_latency kernels doesn't support dispatching per-token scales")
Contributor

Maybe this condition should be part of BatchedTritonOrDeepGemmExperts logic to select which kernels to use?

@tlrmchlsmth added the ready label on Jun 4, 2025
@tlrmchlsmth enabled auto-merge (squash) June 4, 2025 18:41
@tlrmchlsmth merged commit c3fd4d6 into vllm-project:main Jun 4, 2025
80 of 81 checks passed
leoli1208 pushed a commit to leoli1208/vllm that referenced this pull request Jul 22, 2025
Signed-off-by: Varun <vsundarr@redhat.com>
Co-authored-by: Varun <vsundarr@redhat.com>
Labels
ready (ONLY add when PR is ready to merge/full CI is needed)
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants