
[feat] GB300 (sm103a / cu13 arm64) support#638

Open
kaixih wants to merge 1 commit into radixark:main from kaixih:feat/gb300-support

Conversation


@kaixih kaixih commented Feb 25, 2026

Three changes for running Miles on GB300 / Blackwell hardware:

  1. docker/Dockerfile.dev — cu13 build improvements:

    • Upgrade transformer_engine to 2.12.0 for cu13 (was 2.10.0). TE 2.12 adds THD layout + head_dim=256 support in the unfused attention backend, required for Qwen3-Next-80B-A3B on sm103a (cuDNN and FA2/FA3 don't support this config).
    • Apply docker/patch/cu13/patch_fla_blackwell.py at build time to fix three Triton 3.5.1 / fla 0.4.0 bugs on Blackwell (see patch docstring for details).
    • Set TRITON_PTXAS_PATH=/usr/local/cuda/bin/ptxas in /root/.bashrc since Triton's bundled ptxas does not support sm_103a.
  2. docker/patch/cu13/patch_fla_blackwell.py — new file: Patches installed fla 0.4.0 + Triton 3.5.1 packages for Blackwell (sm103a):

    • Bug 1: fla wy_fast.py backward kernel: Triton 3.5.1 lowers += tl.dot() to a three-operand SSA form that TritonGPUHoistTMEMAlloc rejects with a dominance violation. Fix: safe_dot() wraps the result in inline ASM to prevent the fusion (fla PR #687, Triton issue #8695).
    • Bug 2/3: gluon_ir not compiled in the arm64+cu130 Triton build. Affects any @triton.autotune kernel on a cache miss. Fix: try/except guards in triton/runtime/jit.py and triton/experimental/gluon/language/_semantic.py.

3. update_weight_from_distributed.py / update_weight_from_tensor.py: Call post_process_weights() unconditionally instead of only for compressed-tensors / mxfp8. The function is already a no-op when no module implements restore_weights_before_loading / process_weights_after_loading, so this is safe for all existing models. Required for BF16 MoE models where the SGLang flashinfer_trtllm backend reshapes w13_weight after load and needs the restore→load→re-apply cycle on update_weights.
(Item 3 has been removed from this PR; it will be added back when the trtllm_moe backend is fixed. We currently use the triton backend.)
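The guard style for Bugs 2/3 can be sketched as below. This is a minimal illustration, not the patch's actual contents, and the module name is a stub, not Triton's real import path:

```python
import importlib

def optional_import(module_name):
    """Return the module if this build compiled it in, else None."""
    try:
        return importlib.import_module(module_name)
    except ImportError:
        return None

# gluon_ir is absent from the arm64+cu130 Triton wheel, so code on the
# @triton.autotune cache-miss path must tolerate a failed import instead
# of crashing. The module name below is an illustrative stand-in.
gluon_ir = optional_import("triton_gluon_ir_stub")
HAS_GLUON = gluon_ir is not None
```

The patched call sites then branch on availability rather than assuming the import succeeded.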

For repro and more details: https://gist.github.com/kaixih/7ca3ff5ca41e7e78b6cf6c6bc56d4dc1

@gemini-code-assist

Summary of Changes

Hello @kaixih, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request enables comprehensive support for running Miles on GB300 / Blackwell hardware (sm103a / cu13 arm64). It resolves critical compatibility issues by updating core dependencies, applying targeted patches to address specific bugs in Triton and Flash-Linear-Attention, and refining the weight update workflow to ensure robust functionality across diverse model configurations.

Highlights

  • CUDA 13 Build Improvements: Upgraded the transformer_engine dependency to version 2.12.0 for CUDA 13 builds, applied a new patch for Triton and Flash-Linear-Attention (FLA) bugs on Blackwell hardware, and configured the TRITON_PTXAS_PATH environment variable to support sm_103a.
  • Blackwell-Specific Patches for Triton and FLA: Introduced a new Python script, docker/patch/cu13/patch_fla_blackwell.py, which applies critical fixes to fla 0.4.0 and Triton 3.5.1 packages. These patches address a tl.dot() fusion issue in FLA's backward kernel and gluon_ir import crashes in Triton on arm64 cu130 builds.
  • Unconditional Weight Post-Processing: Modified update_weight_from_distributed.py and update_weight_from_tensor.py to unconditionally call post_process_weights(). This change ensures proper weight restoration and re-application for various model types, including BF16 MoE models using the SGLang flashinfer_trtllm backend, where weight reshaping is required after loading.
Changelog
  • docker/Dockerfile.dev
    • Updated transformer_engine version from 2.10.0 to 2.12.0 for CUDA 13 builds.
    • Added steps to copy and execute patch_fla_blackwell.py when ENABLE_CUDA_13 is set.
    • Added an environment variable TRITON_PTXAS_PATH to .bashrc for CUDA 13 builds to ensure correct ptxas usage.
  • docker/patch/cu13/patch_fla_blackwell.py
    • Added a new Python script to patch fla/ops/gated_delta_rule/wy_fast.py, triton/runtime/jit.py, and triton/experimental/gluon/language/_semantic.py.
    • Introduced a safe_dot function in wy_fast.py to prevent tl.dot() fusion issues on Blackwell.
    • Wrapped gluon_ir imports in try/except blocks within Triton's jit.py and _semantic.py to handle cases where gluon_ir is not compiled.
    • Included logic to clear the Triton cache after applying patches.
  • miles/backends/megatron_utils/update_weight/update_weight_from_distributed.py
    • Removed the conditional check for quantization_config and quant_method before calling post_process_weights for both weight restoration and post-quantization phases.
    • Updated comments to clarify the unconditional call's necessity for BF16 MoE models and its no-op behavior when not applicable.
  • miles/backends/megatron_utils/update_weight/update_weight_from_tensor.py
    • Removed the conditional check for quantization_config and quant_method before calling post_process_weights for both weight restoration and post-quantization phases.
    • Updated comments to clarify the unconditional call's necessity for BF16 MoE models and its no-op behavior when not applicable.
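The no-op contract that makes the unconditional post_process_weights() call safe can be sketched as follows. The class names and hook names mirror the description above, but the implementation is illustrative, not SGLang's or Miles' real code:

```python
class PlainLinear:
    """A module with neither hook: untouched by post-processing."""

class TrtllmMoeLinear:
    """Stand-in for a module whose backend reshapes weights after load."""
    def __init__(self):
        self.reshaped = False
    def restore_weights_before_loading(self):
        self.reshaped = False  # undo the backend-specific reshape
    def process_weights_after_loading(self):
        self.reshaped = True   # re-apply the reshape once new weights land

def post_process_weights(modules, load_fn):
    # restore -> load -> re-apply; modules lacking the hooks are skipped,
    # so the whole cycle is a no-op for models that don't need it.
    for m in modules:
        if hasattr(m, "restore_weights_before_loading"):
            m.restore_weights_before_loading()
    load_fn()
    for m in modules:
        if hasattr(m, "process_weights_after_loading"):
            m.process_weights_after_loading()
```

Because the hooks are looked up with hasattr, calling this for a plain BF16 dense model changes nothing, while a module like the flashinfer_trtllm MoE layer gets its reshape re-applied after every weight update.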


@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request adds support for GB300/Blackwell hardware. The changes include upgrading transformer_engine, adding a patch script for fla and Triton on Blackwell, and making post_process_weights calls unconditional. The changes look good and are well-documented. I have a couple of suggestions for improvement. In the Dockerfile, the command to add an environment variable to .bashrc can be made more robust. In the new patch script, the hardcoded Python paths could be made dynamic to improve maintainability.

COPY docker/patch/cu13/patch_fla_blackwell.py /tmp/patch_fla_blackwell.py
RUN if [ "${ENABLE_CUDA_13}" = "1" ]; then \
python3 /tmp/patch_fla_blackwell.py && \
echo 'export TRITON_PTXAS_PATH=/usr/local/cuda/bin/ptxas' >> /root/.bashrc; \

Severity: medium

Appending to .bashrc without first checking whether the line already exists can produce duplicate entries when this Docker layer is rebuilt; making the operation idempotent avoids that.

      grep -qF 'export TRITON_PTXAS_PATH' /root/.bashrc || echo 'export TRITON_PTXAS_PATH=/usr/local/cuda/bin/ptxas' >> /root/.bashrc; \

Comment on lines +26 to +28
FLA_WY_FAST = "/usr/local/lib/python3.12/dist-packages/fla/ops/gated_delta_rule/wy_fast.py"
TRITON_JIT = "/usr/local/lib/python3.12/dist-packages/triton/runtime/jit.py"
TRITON_GLUON_SEMANTIC = "/usr/local/lib/python3.12/dist-packages/triton/experimental/gluon/language/_semantic.py"

Severity: medium

The file paths are hardcoded with the Python version 3.12. This makes the script brittle and likely to break if the Python version in the base Docker image changes. It's better to determine the site-packages directory dynamically. You can use the sysconfig module for this. For example:

import os
import sysconfig

_site_packages = sysconfig.get_paths()["purelib"]
FLA_WY_FAST = os.path.join(_site_packages, "fla/ops/gated_delta_rule/wy_fast.py")
# ... and so on for other paths

Three changes for running Miles on GB300 / Blackwell hardware:

1. docker/Dockerfile.dev — cu13 build improvements:
   - Upgrade transformer_engine to 2.12.0 for cu13 (was 2.10.0). TE 2.12 adds
     THD layout + head_dim=256 support in the unfused attention backend, required
     for Qwen3-Next-80B-A3B on sm103a (cuDNN and FA2/FA3 don't support this config).
   - Apply docker/patch/cu13/patch_fla_blackwell.py at build time to fix three
     Triton 3.5.1 / fla 0.4.0 bugs on Blackwell (see patch docstring for details).
   - Set TRITON_PTXAS_PATH=/usr/local/cuda/bin/ptxas in /root/.bashrc since
     Triton's bundled ptxas does not support sm_103a.

2. docker/patch/cu13/patch_fla_blackwell.py — new file:
   Patches installed fla 0.4.0 + Triton 3.5.1 packages for Blackwell (sm103a):
   - Bug 1: fla wy_fast.py backward kernel: Triton 3.5.1 lowers `+= tl.dot()`
     to a three-operand SSA form that TritonGPUHoistTMEMAlloc rejects with a
     dominance violation. Fix: safe_dot() wraps the result in inline ASM to
     prevent the fusion (fla PR #687, Triton issue #8695).
   - Bug 2/3: gluon_ir not compiled in the arm64+cu130 Triton build. Affects
     any @triton.autotune kernel on a cache miss. Fix: try/except guards in
     triton/runtime/jit.py and triton/experimental/gluon/language/_semantic.py.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
@kaixih force-pushed the feat/gb300-support branch from c37b799 to 6434349 on February 25, 2026 at 22:47