
[Bugfix] Add padding for block-scale fused-moe weights for AITER lib #19234


Open · qli88 wants to merge 5 commits into main from qiang_qwen3_moe_padding

Conversation

qli88 (Contributor) commented Jun 5, 2025

[Bugfix] Add necessary padding for block-scale fused_moe weights for AITER lib

github-actions bot commented Jun 5, 2025

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, they only run the fastcheck CI, which runs a small and essential subset of CI tests to quickly catch errors. You can run the other CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

gemini-code-assist bot (Contributor) left a comment

Hello @qli88, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

Summary of Changes

Hello team, I, gemini-code-assist, have reviewed this pull request. The core purpose of this PR is to address a bug specifically impacting the use of block-scaled fused Mixture-of-Experts (MoE) layers with the AITER library on ROCm platforms. The AITER kernel fmoe_fp8_blockscale_g1u1 has a requirement that the intermediate dimension (inter_dim) must be a multiple of 256, and the block scales (block_k, block_n) must be 128. This pull request introduces padding logic to ensure that MoE weights and their corresponding inverse scales meet these dimensional constraints when they are not naturally aligned, thereby fixing compatibility issues and enabling correct execution.
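To make the constraint concrete, here is a minimal sketch of what "pad inter_dim to a multiple of 256, zeros for weights, ones for inverse scales" looks like. This is not the PR's actual implementation: the function name, the assumed tensor layouts, and the assumption that inter_dim is already a multiple of the 128 block size are all illustrative.

```python
# Minimal sketch (not the PR's code). Assumed fused-MoE layouts:
#   w13_weight:           [E, 2 * inter_dim, hidden_size]
#   w2_weight:            [E, hidden_size, inter_dim]
#   w13_weight_scale_inv: [E, 2 * inter_dim // 128, hidden_size // 128]
#   w2_weight_scale_inv:  [E, hidden_size // 128, inter_dim // 128]
# inter_dim is assumed to already be a multiple of the 128 block size.
import torch

ALIGN = 256  # fmoe_fp8_blockscale_g1u1 needs inter_dim % 256 == 0
BLOCK = 128  # and block_k == block_n == 128


def pad_block_scaled_moe_weights(w13_weight, w2_weight,
                                 w13_scale_inv, w2_scale_inv):
    num_experts, hidden_size, inter_dim = w2_weight.shape
    pad = (-inter_dim) % ALIGN  # elements needed to reach the next multiple of 256
    if pad == 0:
        return w13_weight, w2_weight, w13_scale_inv, w2_scale_inv

    def append(t, dim, n, value):
        # Append `n` entries of `value` along `dim`. The fill is built in fp32
        # and then cast, so it also works for FP8 weight dtypes.
        shape = list(t.shape)
        shape[dim] = n
        fill = torch.full(shape, value, dtype=torch.float32,
                          device=t.device).to(t.dtype)
        return torch.cat([t, fill], dim=dim)

    # Weights: pad inter_dim with zeros (w13 is padded per gate/up half).
    w2_weight = append(w2_weight, -1, pad, 0.0)
    w13 = w13_weight.reshape(num_experts, 2, inter_dim, hidden_size)
    w13_weight = append(w13, 2, pad, 0.0).reshape(num_experts, -1, hidden_size)

    # Inverse scales: pad with ones so the padded blocks dequantize to zero.
    pad_blocks = pad // BLOCK
    w2_scale_inv = append(w2_scale_inv, -1, pad_blocks, 1.0)
    s13 = w13_scale_inv.reshape(num_experts, 2, inter_dim // BLOCK, -1)
    w13_scale_inv = append(s13, 2, pad_blocks, 1.0).reshape(
        num_experts, -1, s13.shape[-1])

    return w13_weight, w2_weight, w13_scale_inv, w2_scale_inv
```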

Highlights

  • ROCm AITER MoE Bugfix: Fixes a specific bug where block-scaled fused MoE layers fail with the ROCm AITER library due to strict dimensional requirements of the kernel.
  • Weight and Scale Padding: Adds a new function to conditionally pad MoE weights (w2_weight, w13_weight) and their inverse scales (w2_weight_scale_inv, w13_weight_scale_inv).
  • Dimension Alignment for AITER: Ensures the intermediate dimension (inter_dim) of the weights is padded to the next multiple of 256, using 0.0 for weights and 1.0 for scales, to satisfy AITER kernel requirements (a toy check of why those fill values are safe follows this list).
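Regarding the 0.0/1.0 choice in the last highlight, the following toy fp32 stand-in (arbitrary names and a 32x32 block grid, not AITER or fp8) shows why those fill values are safe: zero-padded weight columns with one-padded inverse scales dequantize to zero, so they contribute nothing to the GEMM output.

```python
# Toy check that zero-padded weights with one-padded inverse scales are a
# numerical no-op (fp32 stand-in, not AITER/fp8).
import torch

torch.manual_seed(0)
hidden, inter, pad, block = 64, 96, 32, 32
w = torch.randn(hidden, inter)                         # stand-in for a quantized weight
s = torch.rand(hidden // block, inter // block) + 0.5  # per-block inverse scales

w_pad = torch.cat([w, torch.zeros(hidden, pad)], dim=-1)
s_pad = torch.cat([s, torch.ones(hidden // block, pad // block)], dim=-1)


def dequant(w, s):
    # Broadcast each per-block scale over its block and multiply elementwise.
    return w * s.repeat_interleave(block, 0).repeat_interleave(block, 1)


x = torch.randn(4, hidden)
out = x @ dequant(w, s)
out_pad = (x @ dequant(w_pad, s_pad))[:, :inter]  # padded columns are all zero
assert torch.allclose(out, out_pad)
```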

Changelog

  • vllm/model_executor/layers/quantization/fp8.py
    • Added a new private method _maybe_pad_rocm_aiter_block_scaled_fused_moe_weights (starting around line 473) to encapsulate the padding logic for ROCm AITER block-scaled MoE weights and scales.
    • Modified the process_weights_after_loading method (around line 733) to call the new padding function for block-scaled FP8 MoE layers before creating the final weight parameters.

gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request introduces a bugfix by adding necessary padding for block-scale fused MoE weights when using the AITER library on ROCm. The new method _maybe_pad_rocm_aiter_block_scaled_fused_moe_weights correctly handles the padding of weights and their inverse scales to meet AITER's alignment requirements.

The overall implementation is clear, and the padding logic appears correct. The detailed comments explaining the tensor shapes and padding steps are particularly helpful. I have a couple of suggestions to improve the clarity and robustness of the conditional checks.

Summary of Findings

  • Redundant conditional check: The self.rocm_aiter_moe_enabled check in the second conditional (lines 493-497) is redundant as it's already covered by the first conditional (lines 489-491).
  • Clarity of early exit condition for padding: The condition for skipping padding (w2_weight.shape[-1] % 256 == 0 and w13_weight.shape[-2] % 256 == 0) could potentially be simplified to primarily check w2_weight.shape[-1] % 256 == 0, as inter_dim (derived from w2_weight) is the key dimension requiring alignment according to the docstring and padding logic.
  • Type Hinting: The arguments w2_weight, w2_weight_scale_inv, w13_weight, w13_weight_scale_inv in _maybe_pad_rocm_aiter_block_scaled_fused_moe_weights are not type-hinted. Adding torch.Tensor type hints would improve readability. The return type could also be hinted as tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor].
  • Docstring Grammar: Minor grammatical corrections in the docstring of _maybe_pad_rocm_aiter_block_scaled_fused_moe_weights (e.g., "support" -> "supports", "is" -> "are") would enhance clarity.
  • Magic Numbers: The alignment value 256 is used in calculations (e.g., line 509). Consider defining it as a constant like AITER_INTER_DIM_ALIGNMENT = 256 for better maintainability.
  • Commented-out Code: There are commented-out lines (510-511) that seem like remnants or notes. These should be removed if they are no longer necessary.
  • Redundant Assertion: In _maybe_pad_rocm_aiter_block_scaled_fused_moe_weights, the assertion assert block_k == block_n (line 505) is redundant if the preceding assertions assert block_k == 128 and assert block_n == 128 both pass.
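Folding the type-hint, named-constant, and simplified-check suggestions together, the early-exit guard might look roughly like this (hypothetical names, shown as a free function rather than the PR's private method):

```python
import torch

# Named constant instead of a literal 256 (the name is a suggestion, not from the PR).
AITER_INTER_DIM_ALIGNMENT = 256


def needs_aiter_inter_dim_padding(w2_weight: torch.Tensor) -> bool:
    """True when inter_dim (the last dim of w2) is not 256-aligned."""
    return w2_weight.shape[-1] % AITER_INTER_DIM_ALIGNMENT != 0
```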

Merge Readiness

The pull request addresses an important bugfix for AITER compatibility. The core padding logic is sound. However, there are a couple of medium-severity issues related to conditional checks that could be improved for clarity and robustness. I recommend addressing these points before merging. I am unable to approve the pull request myself; please ensure other reviewers approve this code before merging.

Co-authored-by: tjtanaavllm <tunjian.tan@embeddedllm.com>
Signed-off-by: Qiang Li <qiang.li2@amd.com>
qli88 force-pushed the qiang_qwen3_moe_padding branch from 55ab43c to ccf25ca on June 5, 2025 at 19:39
qli88 added 4 commits June 5, 2025 19:51
Signed-off-by: Qiang Li <qiang.li2@amd.com>
Signed-off-by: Qiang Li <qiang.li2@amd.com>
Signed-off-by: Qiang Li <qiang.li2@amd.com>
Signed-off-by: Qiang Li <qiang.li2@amd.com>
gshtras added a commit to ROCm/vllm that referenced this pull request Jun 5, 2025
SageMoore (Contributor) left a comment

Looks reasonable. What models does this padding fix? I assume this bug manifests as a correctness issue?

return (w2_weight, w2_weight_scale_inv, w13_weight,
        w13_weight_scale_inv)

if (self.rocm_aiter_moe_enabled

Nit: you don't need to check self.rocm_aiter_moe_enabled here

# after reshape
# [expert(local_expert:EP), 2 * padded_inter_dim_block_scale, k_block_scale] # noqa: E501
k_block_scale = w13_weight_scale_inv.shape[
    2]  # k_block_scale = (hidden_size + block_k - 1) // block_k

Nit: Can you move the comment to the line above so that the code is all on the same line?
