[AMD] support mfma i32_16x16x32_i8 #800
Conversation
Walkthrough

Adds int8 MFMA support across HIP codegen and templates: fixes dtype placeholder names, extends the dtype map with int8x8 → int64_t, adjusts MFMA prefix tokens, and introduces MfmaTraits<int8_t> using the AMD i8 MFMA intrinsic. Updates macro generation for the int8 suffix and k-dimension, and expands tests to include int8 with k_pack scaling and kernel source printing.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    autonumber
    participant T as Test Runner
    participant TL as TileLang GEMM (tests)
    participant MG as MFMA Macro Generator
    participant CG as HIP Codegen
    participant K as Generated Kernel
    participant HW as AMD MFMA Intrinsic
    T->>TL: Invoke GEMM with dtype=int8, accum=int32, k_pack
    TL->>MG: Request MFMA config (dtype=int8)
    MG-->>TL: k_dim=32, suffix "..._x32_i8"
    TL->>CG: Build kernel with MFMA tokens and dtypes
    CG-->>TL: HIP kernel source (uses {A/B/C_dtype}, int8x8→int64_t)
    TL->>K: Launch kernel
    K->>HW: __builtin_amdgcn_mfma_i32_16x16x32_i8(...)
    HW-->>K: Accumulate results
    K-->>T: Completion (optionally print kernel source)
```
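For context, here is a minimal sketch of the new MfmaTraits<int8_t> specialization described in the walkthrough, assuming the const-correct form suggested later in this review; the member name mfma_op and the vector typedef are assumptions, not the verbatim template in src/tl_templates/hip/gemm.h:

```cpp
#include <hip/hip_runtime.h>
#include <cstdint>

// 4 x int32 accumulator fragment used by the 16x16x32 MFMA shape.
typedef int int32x4 __attribute__((ext_vector_type(4)));

template <typename T> struct MfmaTraits;

// int8 specialization: each operand pointer covers 8 int8 values, which the
// intrinsic consumes as a single packed 64-bit lane operand.
template <> struct MfmaTraits<int8_t> {
  static __device__ void mfma_op(const int8_t *a, const int8_t *b, int32x4 *c) {
    const int64_t *a_packed = reinterpret_cast<const int64_t *>(a);
    const int64_t *b_packed = reinterpret_cast<const int64_t *>(b);
    // Trailing zeros are the cbsz/abid/blgp modifiers (unused here).
    *c = __builtin_amdgcn_mfma_i32_16x16x32_i8(*a_packed, *b_packed, *c, 0, 0, 0);
  }
};
```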
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes
👋 Hi! Thank you for contributing to the TileLang project. Please remember to run the format checks before pushing. Our team will review your contribution, and we look forward to your awesome work! 🚀
Summary of Changes
Hello @Paran0idy, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request primarily focuses on extending the AMD backend to fully support int8 matrix multiplication operations using the __builtin_amdgcn_mfma_i32_16x16x32_i8 intrinsic. It includes necessary updates to the code generation logic, type mappings, and MFMA macro generation, along with new test cases to verify the integration and correctness of these low-precision computations.
Highlights
- MFMA i32_16x16x32_i8 Support: Added comprehensive support for the __builtin_amdgcn_mfma_i32_16x16x32_i8 intrinsic, enabling int8 matrix multiplication with int32 accumulation on AMD GPUs (see the sketch after this list).
- Backend Codegen Enhancements: Updated the HIP backend code generation to correctly handle int8x8 types and fixed a typo in the MFMA code generation template.
- MFMA Macro Generation Logic: Modified the MFMA macro generator to correctly determine k_dim and generate the appropriate suffix for int8 inputs.
- Testing and Validation: Introduced new test cases to validate the correctness of int8 matrix multiplication with int32 accumulation, ensuring proper functionality.
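To make the operand layout concrete, here is an illustrative sketch that is not taken from the PR: the function name, fragment pointers, and vector typedef are assumptions, while the intrinsic call itself matches the builtin named above. Each issue multiplies 16x32 and 32x16 int8 tiles per wave and accumulates 4 int32 outputs per lane.

```cpp
#include <hip/hip_runtime.h>
#include <cstdint>

typedef int int32x4 __attribute__((ext_vector_type(4)));

// Each MFMA issue advances K by 32; every lane supplies 8 int8 values per
// operand, pre-packed into one int64_t, and accumulates 4 int32 outputs.
__device__ void accumulate_i8_tile(const int64_t *a_frag, const int64_t *b_frag,
                                   int32x4 &acc, int num_issues) {
  for (int k = 0; k < num_issues; ++k) {
    acc = __builtin_amdgcn_mfma_i32_16x16x32_i8(a_frag[k], b_frag[k], acc, 0, 0, 0);
  }
}
```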
Code Review
This pull request adds support for the __builtin_amdgcn_mfma_i32_16x16x32_i8 instruction for AMD GPUs, enabling int8 matrix multiplication. The changes are well-structured, touching the HIP codegen, C++ templates, and Python layers for intrinsic generation and testing. The implementation is largely correct. I have two suggestions for improvement: one to enhance code quality and safety in a C++ template by removing an unnecessary const_cast, and another to remove a leftover debugging print statement from a Python test file.
```cpp
int64_t *b_packed = reinterpret_cast<int64_t *>(const_cast<int8_t *>(b));
int64_t *a_packed = reinterpret_cast<int64_t *>(const_cast<int8_t *>(a));
```
The use of const_cast here is unnecessary and not considered best practice. Since the data pointed to by a and b is not modified, you can directly reinterpret_cast to a const pointer type. This preserves const correctness and improves code safety.
```cpp
const int64_t *b_packed = reinterpret_cast<const int64_t *>(b);
const int64_t *a_packed = reinterpret_cast<const int64_t *>(a);
```

The second suggestion concerns a leftover debugging print in the test file:

```python
C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, out_dtype))

kernel(A, B, C)
print(kernel.get_kernel_source())
```
Actionable comments posted: 1
Caution: Some comments are outside the diff and can't be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
testing/python/amd/test_tilelang_gemm_mfma_intrinsic.py (2)
45-53: Scale of block_K now depends on k_pack — add a divisibility guard to avoid silent tail-drop.

With chunk = 32 * k_pack and block_K = chunk, the outer K loop uses floor division and can silently skip a remainder when K % block_K != 0 (for example, k_pack=2 gives block_K = 64, so K = 96 would silently drop the last 32 elements). Add an explicit assert to fail fast. The ki inner loop math remains consistent.

Apply this diff:

```diff
 block_M = block_row_warps * warp_row_tiles
 block_N = block_col_warps * warp_col_tiles
 block_K = chunk
+# Ensure we don't drop a tail on K
+assert K % block_K == 0, f"K ({K}) must be divisible by block_K ({block_K})"
```

Also applies to: 126-127
211-214: Typo causes AttributeError: use A.T.to(...), not A.Tto(...).

This path will crash for a_transposed=True, b_transposed=False.

Apply this diff:

```diff
-        ref_c = torch.matmul(A.Tto(torch.float32),
+        ref_c = torch.matmul(A.T.to(torch.float32),
                              B.to(torch.float32)).to(getattr(torch, out_dtype))
```
🧹 Nitpick comments (4)
src/target/codegen_hip.cc (1)
883-883: Comment improvement: fix the parameter description.

The comment mislabels the MFMA name token: it is the intrinsic suffix ({otype}_{intrM}x{intrN}x{intrK}_{itype}), not a prefix.

Apply this diff to improve the comment:

```diff
-  // arg 0: prefix: {otype}_{intrM}x{intrN}x{intrK}_{itype}
+  // arg 0: mfma suffix: {otype}_{intrM}x{intrN}x{intrK}_{itype}
```

testing/python/amd/test_tilelang_gemm_mfma_intrinsic.py (3)
182-183: Gate noisy prints and reuse the fetched source to keep CI logs clean.

Unconditional prints bloat logs. Make them opt-in via env vars and avoid calling get_kernel_source() twice.

Apply this diff:

```diff
-    print(matmul)
+    if os.getenv("TL_DEBUG"):
+        print(matmul)
@@
-    print(kernel.get_kernel_source())
+    if os.getenv("TL_DUMP_KERNEL"):
+        print(src_code)
@@
-    print(C)
-    print(ref_c)
+    if os.getenv("TL_DEBUG"):
+        print(C)
+        print(ref_c)
```

And add this import near the top of the file:

```python
import os
```

Also applies to: 198-198, 223-224
190-191: Include the full int8 range in randint.

high is exclusive; using 127 excludes 127. Use 128 to cover [-128, 127].

Apply this diff:

```diff
-    A = torch.randint(-128, 127, A_shape, device="cuda", dtype=torch.int8)
-    B = torch.randint(-128, 127, B_shape, device="cuda", dtype=torch.int8)
+    A = torch.randint(-128, 128, A_shape, device="cuda", dtype=torch.int8)
+    B = torch.randint(-128, 128, B_shape, device="cuda", dtype=torch.int8)
```
185-185: Clarify the backend in the comment.

This runs on ROCm; "CUDA source" is misleading.

Apply this diff:

```diff
-    # src_code is the generated cuda source
+    # src_code is the generated GPU kernel source (HIP/CUDA)
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)
- src/target/codegen_hip.cc (3 hunks)
- src/tl_templates/hip/gemm.h (1 hunk)
- testing/python/amd/test_tilelang_gemm_mfma_intrinsic.py (3 hunks)
- tilelang/intrinsics/mfma_macro_generator.py (2 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
testing/python/amd/test_tilelang_gemm_mfma_intrinsic.py (4)
- tilelang/jit/kernel.py (1): get_kernel_source (378-389)
- tilelang/jit/adapter/ctypes/adapter.py (1): get_kernel_source (290-296)
- tilelang/jit/adapter/cython/adapter.py (1): get_kernel_source (516-522)
- tilelang/jit/adapter/base.py (1): get_kernel_source (51-52)
⏰ Context from checks skipped due to timeout of 90000ms (2)
- GitHub Check: format-check
- GitHub Check: bot-task
🔇 Additional comments (5)
tilelang/intrinsics/mfma_macro_generator.py (2)
84-84: LGTM! Support for the int8 k_dim is correctly added.

The change correctly adds "int8" to the set of dtypes that trigger k_dim = 32, aligning with AMD's int8 MFMA instructions, which consume 32 elements along the K dimension.
126-127: LGTM! Correct suffix generation for the int8 MFMA intrinsic.

The new branch properly generates the suffix with an underscore before "i8" (e.g., i32_16x16x32_i8), matching the AMD intrinsic naming convention __builtin_amdgcn_mfma_i32_16x16x32_i8.

src/target/codegen_hip.cc (2)
917-917: LGTM! Correct mapping of int8x8 to int64_t.

The int8x8 → int64_t mapping packs 8 int8 values into a 64-bit integer, matching the operand type the AMD int8 MFMA intrinsic expects.
929-931: LGTM! Fixed typo in placeholder names.

Good catch fixing the typo from {C_dytpe}, {A_dytpe}, {B_dytpe} to the correct {C_dtype}, {A_dtype}, {B_dtype}. This ensures proper type casting in the generated MFMA code.

Also applies to: 937-939
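As an illustration of this mapping, a hypothetical sketch of the lookup the codegen performs; the map name and the neighboring entry are assumptions, and only the int8x8 → int64_t pairing comes from the review:

```cpp
#include <string>
#include <unordered_map>

// Hypothetical sketch: resolve TIR vector dtypes to the HIP carrier type used
// in generated casts. Eight int8 lanes travel as one 64-bit MFMA operand.
static const std::unordered_map<std::string, std::string> dtype_map = {
    {"float32x4", "float32x4"}, // illustrative pre-existing entry
    {"int8x8", "int64_t"},      // new: pack 8 x int8 into one 64-bit value
};
```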
testing/python/amd/test_tilelang_gemm_mfma_intrinsic.py (1)
233-235: Int8/int32 test coverage looks good.

Solid additions covering k_pack=1 and k_pack=2. No issues spotted.
```cpp
int64_t *b_packed = reinterpret_cast<int64_t *>(const_cast<int8_t *>(b));
int64_t *a_packed = reinterpret_cast<int64_t *>(const_cast<int8_t *>(a));
```
Potential undefined behavior: const_cast removes const qualifier inappropriately.
Using const_cast to strip const from the const int8_t* parameters yields non-const pointers into memory that may actually be const; any write through them would be undefined behavior, and the cast is unnecessary since the data is only read. The pointers should be cast directly to a const pointer type without removing const.
Apply this diff to fix the const-correctness issue:

```diff
-  int64_t *b_packed = reinterpret_cast<int64_t *>(const_cast<int8_t *>(b));
-  int64_t *a_packed = reinterpret_cast<int64_t *>(const_cast<int8_t *>(a));
+  const int64_t *b_packed = reinterpret_cast<const int64_t *>(b);
+  const int64_t *a_packed = reinterpret_cast<const int64_t *>(a);
```
+ const int64_t *a_packed = reinterpret_cast<const int64_t *>(a);📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| int64_t *b_packed = reinterpret_cast<int64_t *>(const_cast<int8_t *>(b)); | |
| int64_t *a_packed = reinterpret_cast<int64_t *>(const_cast<int8_t *>(a)); | |
| const int64_t *b_packed = reinterpret_cast<const int64_t *>(b); | |
| const int64_t *a_packed = reinterpret_cast<const int64_t *>(a); |
🤖 Prompt for AI Agents
In src/tl_templates/hip/gemm.h around lines 15 to 16, the code uses const_cast
to strip const from const int8_t* and then reinterpret_cast to int64_t*, which
can invoke undefined behavior; change the target pointer types to const int64_t*
and cast directly from the original const int8_t* using reinterpret_cast<const
int64_t*>(...), removing any const_cast so the const qualifier is preserved and
you don't attempt to modify potentially const memory.
This PR adds support for __builtin_amdgcn_mfma_i32_16x16x32_i8.

Summary by CodeRabbit
New Features
- int8 MFMA support on AMD GPUs: int8 GEMM with int32 accumulation via the i32_16x16x32_i8 intrinsic.

Bug Fixes
- Corrected dtype placeholder names in the HIP MFMA codegen template.

Tests
- Added int8/int32 GEMM tests covering k_pack=1 and k_pack=2, with kernel source printing.