[feat] GB300 (sm103a / cu13 arm64) support #638
Conversation
Summary of Changes

Hello @kaixih, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request enables support for running Miles on GB300 / Blackwell hardware (sm103a / cu13 arm64). It resolves compatibility issues by updating core dependencies, applying targeted patches for specific bugs in Triton and Flash-Linear-Attention, and refining the weight-update workflow to work across diverse model configurations.
Code Review
This pull request adds support for GB300/Blackwell hardware. The changes include upgrading transformer_engine, adding a patch script for fla and Triton on Blackwell, and making post_process_weights calls unconditional. The changes look good and are well-documented. I have a couple of suggestions for improvement. In the Dockerfile, the command to add an environment variable to .bashrc can be made more robust. In the new patch script, the hardcoded Python paths could be made dynamic to improve maintainability.
```dockerfile
COPY docker/patch/cu13/patch_fla_blackwell.py /tmp/patch_fla_blackwell.py
RUN if [ "${ENABLE_CUDA_13}" = "1" ]; then \
    python3 /tmp/patch_fla_blackwell.py && \
    echo 'export TRITON_PTXAS_PATH=/usr/local/cuda/bin/ptxas' >> /root/.bashrc; \
```
Appending to .bashrc without checking whether the line already exists can lead to duplicate entries if this Docker layer is rebuilt. It's better to make this operation idempotent:

```dockerfile
grep -qF 'export TRITON_PTXAS_PATH' /root/.bashrc || echo 'export TRITON_PTXAS_PATH=/usr/local/cuda/bin/ptxas' >> /root/.bashrc; \
```
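A quick sketch of why the guarded append is idempotent (the temp file here is an illustrative stand-in for /root/.bashrc, not part of the PR): running the same guarded append twice leaves exactly one copy of the line.

```shell
# Demo of the idempotent-append pattern: run the guarded append twice
# and the target file still contains exactly one copy of the line.
rc=$(mktemp)   # stand-in for /root/.bashrc

append_once() {
  grep -qF 'export TRITON_PTXAS_PATH' "$rc" || \
    echo 'export TRITON_PTXAS_PATH=/usr/local/cuda/bin/ptxas' >> "$rc"
}

append_once
append_once

grep -cF 'export TRITON_PTXAS_PATH' "$rc"   # prints 1, not 2
rm -f "$rc"
```

An unguarded `echo ... >> /root/.bashrc` would print 2 here, which is exactly the duplicate-entry problem on layer rebuilds.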
```python
FLA_WY_FAST = "/usr/local/lib/python3.12/dist-packages/fla/ops/gated_delta_rule/wy_fast.py"
TRITON_JIT = "/usr/local/lib/python3.12/dist-packages/triton/runtime/jit.py"
TRITON_GLUON_SEMANTIC = "/usr/local/lib/python3.12/dist-packages/triton/experimental/gluon/language/_semantic.py"
```
The file paths hardcode Python version 3.12. This makes the script brittle: it will break if the Python version in the base Docker image changes. It's better to determine the site-packages directory dynamically, e.g. with the sysconfig module:

```python
import os
import sysconfig

_site_packages = sysconfig.get_paths()["purelib"]
FLA_WY_FAST = os.path.join(_site_packages, "fla/ops/gated_delta_rule/wy_fast.py")
# ... and so on for other paths
```

Three changes for running Miles on GB300 / Blackwell hardware:
1. docker/Dockerfile.dev — cu13 build improvements:
- Upgrade transformer_engine to 2.12.0 for cu13 (was 2.10.0). TE 2.12 adds
THD layout + head_dim=256 support in the unfused attention backend, required
for Qwen3-Next-80B-A3B on sm103a (cuDNN and FA2/FA3 don't support this config).
- Apply docker/patch/cu13/patch_fla_blackwell.py at build time to fix three
Triton 3.5.1 / fla 0.4.0 bugs on Blackwell (see patch docstring for details).
- Set TRITON_PTXAS_PATH=/usr/local/cuda/bin/ptxas in /root/.bashrc since
Triton's bundled ptxas does not support sm_103a.
2. docker/patch/cu13/patch_fla_blackwell.py — new file:
Patches installed fla 0.4.0 + Triton 3.5.1 packages for Blackwell (sm103a):
- Bug 1: fla wy_fast.py backward kernel: Triton 3.5.1 lowers `+= tl.dot()`
to a three-operand SSA form that TritonGPUHoistTMEMAlloc rejects with a
dominance violation. Fix: safe_dot() wraps the result in inline ASM to
prevent the fusion (fla PR #687, Triton issue #8695).
- Bug 2/3: gluon_ir not compiled in the arm64+cu130 Triton build. Affects
any @triton.autotune kernel on a cache miss. Fix: try/except guards in
triton/runtime/jit.py and triton/experimental/gluon/language/_semantic.py.
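The patch script's overall approach, rewriting installed package files in place at build time, can be sketched as a small idempotent helper. The function name, marker string, and replacement strings below are illustrative assumptions, not the actual contents of patch_fla_blackwell.py:

```python
import pathlib

# Illustrative sentinel so a rebuilt Docker layer can skip an already-patched file.
MARKER = "# patched-for-blackwell"

def patch_file(path: str, old: str, new: str) -> bool:
    """Replace `old` with `new` in `path` once; skip if already patched."""
    p = pathlib.Path(path)
    src = p.read_text()
    if MARKER in src:
        return False  # already patched: keep the operation idempotent
    if old not in src:
        # Fail loudly rather than silently shipping an unpatched file,
        # e.g. if an upstream package release changed the target code.
        raise RuntimeError(f"pattern not found in {path}; upstream may have changed")
    p.write_text(src.replace(old, new) + "\n" + MARKER + "\n")
    return True
```

The loud failure on a missing pattern is the important design choice: a build-time patch that no longer matches should break the image build, not produce a container that crashes at runtime.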
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Force-pushed from c37b799 to 6434349
3. update_weight_from_distributed.py / update_weight_from_tensor.py — call post_process_weights() unconditionally instead of only for compressed-tensors / mxfp8. The function is already a no-op when no module implements restore_weights_before_loading / process_weights_after_loading, so this is safe for all existing models. Required for BF16 MoE models, where the SGLang flashinfer_trtllm backend reshapes w13_weight after load and needs the restore → load → re-apply cycle on update_weights. (Item 3 has been removed from this PR; it will be added back when fixing the trtllm_moe backend. Currently the triton backend is used.)
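The no-op argument for change 3 can be sketched as follows. The hook names match the description above, but the dispatch logic and module classes are a hypothetical simplification, not SGLang's actual implementation:

```python
class PlainLinear:
    """A module with neither weight hook: the cycle leaves it untouched."""
    pass

class ReshapingMoE:
    """A module that rewrites its weights after load and can undo that on update."""
    def __init__(self):
        self.state = "loaded"
    def restore_weights_before_loading(self):
        self.state = "restored"
    def process_weights_after_loading(self):
        self.state = "processed"

def post_process_weights(modules):
    """Run restore -> (new weights copied in) -> re-apply on every module
    that implements the hooks; a no-op for modules implementing neither."""
    for m in modules:
        if hasattr(m, "restore_weights_before_loading"):
            m.restore_weights_before_loading()
        # ... updated weights would be loaded here ...
        if hasattr(m, "process_weights_after_loading"):
            m.process_weights_after_loading()

plain, moe = PlainLinear(), ReshapingMoE()
post_process_weights([plain, moe])   # safe to call for every model
print(moe.state)                     # prints "processed"
```

Because the hooks are discovered per module, calling the function unconditionally costs nothing for models without them, which is why gating it on compressed-tensors / mxfp8 was unnecessary.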
For repro and more details: https://gist.github.com/kaixih/7ca3ff5ca41e7e78b6cf6c6bc56d4dc1