[AMD] Support T.gemm_v2 for AMD Backend #1136
Conversation
Walkthrough

Adds MFMA-based GEMM support: new GemmMFMA implementation with layout inference and lowering for four data-flow variants, updated MatrixCoreIntrinEmitter with layout/thread-binding helpers and mfma signature changes, GEMM enum updates, tests, and a layout example.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant User as Client Code
    participant Gemm as GemmMFMA
    participant Emitter as MatrixCoreIntrinEmitter
    participant TIR as TIR Lowering
    User->>Gemm: infer_layout(target, thread_nums)
    Gemm->>Emitter: instantiate(thread_var)
    Emitter-->>Gemm: A/B/C layouts
    User->>Gemm: lower(layout_map, target, thread_nums, thread_var)
    Gemm->>Gemm: select variant (SS/SR/RS/RR)
    alt SS (shared/shared)
        Gemm->>Emitter: make_mfma_load_layout(A), make_mfma_load_layout(B)
    else Fragment involved
        Gemm->>Emitter: use swizzled/fragment layouts
    end
    Gemm->>Emitter: mfma(A_local, B_local, C_local, k_inner)
    Emitter->>TIR: emit prim_func
    TIR-->>User: compiled kernel
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
Actionable comments posted: 3
🧹 Nitpick comments (10)
tilelang/tileop/gemm/gemm_mfma.py (4)
18-19: Remove stray debug prints from codegen paths.

print(...) in infer_layout/mfma pollutes logs and slows CI.

```diff
-    print(f"thread_nums: {thread_nums}, m_warp: {m_warp}, n_warp: {n_warp}")
```

And in mfma():

```diff
-    print(a_local_stride, b_local_stride)
```

Also applies to: 361-362
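If the stride values are still useful during development, one low-cost alternative (a sketch, not an existing tilelang convention) is a module-level logger whose debug calls are no-ops unless the embedding application enables DEBUG; the emitter function below is a hypothetical stand-in:

```python
import logging

# Module-level logger; silent unless the embedding app enables DEBUG.
logger = logging.getLogger("tilelang.mfma")


def emit_mfma(a_local_stride: int, b_local_stride: int) -> str:
    # Hypothetical stand-in for the real emitter body.
    # Lazy %-formatting: the message is never built when DEBUG is off.
    logger.debug("strides a=%d b=%d", a_local_stride, b_local_stride)
    return f"mfma({a_local_stride}, {b_local_stride})"
```

With the default root level (WARNING), the debug call costs one level check and emits nothing.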
63-81: layout_map is unused in lower().

Acknowledge or consume to appease linters without behavior change. Based on static analysis hints.

```diff
 def lower(self, layout_map: dict, target: Target, thread_nums: int, thread_var: tir.Var):
+    _ = layout_map  # unused: kept for API parity
```
93-94: Assertion message style (nit).

Long f-strings in asserts trigger TRY003; consider a shorter message or raising ValueError with the message. Based on static analysis hints.

```diff
-    assert block_K >= micro_size_k, f"block_K ({block_K}) must be >= micro_size_k ({micro_size_k})"
+    assert block_K >= micro_size_k, "block_K must be >= micro_size_k"
```
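TRY003 can also be satisfied by moving the message into a dedicated exception class, which keeps the detailed text without a long literal at the raise site. A sketch with hypothetical names:

```python
class BlockKTooSmallError(ValueError):
    """Raised when block_K is smaller than the MFMA micro tile (hypothetical name)."""

    def __init__(self, block_k: int, micro_size_k: int):
        super().__init__(f"block_K ({block_k}) must be >= micro_size_k ({micro_size_k})")


def check_block_k(block_k: int, micro_size_k: int) -> None:
    # Raise instead of assert so the check survives `python -O`.
    if block_k < micro_size_k:
        raise BlockKTooSmallError(block_k, micro_size_k)
```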
160-202: Duplicate helper name _gemm_rsr for RS and RR (clarity).

Both branches define _gemm_rsr; harmless but confusing during debugging.

- Rename RR's inner prim_func to _gemm_rrr (or similar).

Also applies to: 188-203
examples/plot_layout/fragment_mfma_load_a.py (2)
18-22: Docstring wording (minor).

This builds a load layout, not "storing MFMA results".

```diff
- Create a layout function for storing MFMA results into a fragment buffer.
+ Create a load layout function for MFMA fragments.
```
46-48: Typing nit.

transform_func_sr_a/b initialized to None but typed as Callable. Use Optional[Callable].

```diff
-from typing import Literal, Callable
+from typing import Literal, Callable, Optional
@@
-    transform_func_sr_a: Callable = None
-    transform_func_sr_b: Callable = None
+    transform_func_sr_a: Optional[Callable] = None
+    transform_func_sr_b: Optional[Callable] = None
```

tilelang/intrinsics/mfma_macro_generator.py (1)
356-362: Strip debug prints and unused local.

- mfma(): print(...) in the hot path.
- preshuffle ldmatrix_a(): stray print(self.a_preshuffle).
- make_mfma_load_layout(): unused local dtype. Based on static analysis hints.

```diff
@@ def mfma(...):
-    print(a_local_stride, b_local_stride)
@@ def _warp_ldmatrix_a_shared(...):
-    print(self.a_preshuffle)
@@ def make_mfma_load_layout(...):
-    dtype = self.a_dtype if matrix_is_a else self.b_dtype
```

Also applies to: 763-763, 459-461
testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_amd.py (3)
85-93: Pin target to AMD for stability.

Tests are AMD-specific; relying on target="auto" can select CUDA on CI hosts.

```diff
     kernel = tilelang.compile(
         program,
         out_idx=[2],
+        target="rocm",  # or "amdgpu"; choose the identifier your build uses
         pass_configs={
```

If your build expects "amdgpu" instead of "rocm", please adjust. Run a quick build locally to confirm.

Also applies to: 211-218, 333-341, 459-466
108-110: Don't benchmark or spam logs in unit tests.

Benchmarks and kernel dumps are noisy and slow CI.

```diff
-    latency = profiler.do_bench(profiler.func, warmup=100)
-    print(f"GEMM SS latency: {latency} ms")
+    # Optional perf-only: enable via TL_BENCH=1
+    import os
+    if os.environ.get("TL_BENCH") == "1":
+        _ = profiler.do_bench(profiler.func, warmup=50)
@@
-    print(program)
-    print(kernel.get_kernel_source())
+    # Debug-only; guarded to avoid CI noise
+    import os
+    if os.environ.get("TL_DEBUG") == "1":
+        print(program)
+        print(kernel.get_kernel_source())
```

Also applies to: 459-469
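The env-var gate can be factored into a tiny helper so every test uses the same switch; a sketch (the helper name and variable spellings are illustrative, not existing project conventions):

```python
import os


def flag_enabled(name: str) -> bool:
    # Treat "1", "true", "yes" (any case) as on; everything else as off.
    return os.environ.get(name, "").strip().lower() in {"1", "true", "yes"}


# Example gates, mirroring the suggestion above:
# if flag_enabled("TL_BENCH"):
#     profiler.do_bench(profiler.func, warmup=50)
# if flag_enabled("TL_DEBUG"):
#     print(kernel.get_kernel_source())
```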
95-106: Reference cast path (small).

matmul for int8 uses float accumulate then casts to out_dtype. This is fine; if flakiness arises, consider clamping before cast to emulate int8 saturation.

```python
# inside ref_program before final cast:
# C = C.clamp(min_val, max_val)  # emulate saturation if needed
```

Also applies to: 221-230, 342-352, 472-481
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (5)
- examples/plot_layout/fragment_mfma_load_a.py (1 hunks)
- testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_amd.py (1 hunks)
- tilelang/intrinsics/mfma_macro_generator.py (11 hunks)
- tilelang/tileop/gemm/__init__.py (3 hunks)
- tilelang/tileop/gemm/gemm_mfma.py (1 hunks)
🪛 Ruff (0.14.1)
examples/plot_layout/fragment_mfma_load_a.py
62-62: Avoid specifying long messages outside the exception class
(TRY003)
83-83: Avoid specifying long messages outside the exception class
(TRY003)
tilelang/tileop/gemm/gemm_mfma.py
60-61: Avoid specifying long messages outside the exception class
(TRY003)
63-63: Unused method argument: layout_map
(ARG002)
204-205: Avoid specifying long messages outside the exception class
(TRY003)
tilelang/intrinsics/mfma_macro_generator.py
459-459: Local variable dtype is assigned to but never used
Remove assignment to unused variable dtype
(F841)
486-486: Avoid specifying long messages outside the exception class
(TRY003)
501-501: Avoid specifying long messages outside the exception class
(TRY003)
558-558: Avoid specifying long messages outside the exception class
(TRY003)
572-572: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Quick Lint
🔇 Additional comments (2)
tilelang/intrinsics/mfma_macro_generator.py (1)
209-216: Minor: share helpers via existing getters (consistency).

The new get_store_index_map/forward_* closures look good and align with MMA emitter interfaces. LGTM.
Also applies to: 608-643
tilelang/tileop/gemm/__init__.py (1)
31-34: Enum ordinals are correctly aligned; however, Python lacks TCGEN5MMA support while C++ can select it.

The C++ and Python enum ordinals (MMA=0, WGMMA=1, TCGEN5MMA=2, MFMA=3) are already synchronized and the FFI integer-to-enum conversion works correctly. However, when C++ returns GemmInst::kTCGEN5MMA (for SM100 targets matching buffer/dtype constraints), the Python side immediately raises NotImplementedError in _get_implementation_class since TCGEN5MMA is not yet implemented.
This is not an enum alignment issue, but an incomplete feature: the C++ backend supports SM100 TCGEN5MMA instruction selection, while the Python dispatch layer does not. When a compatible SM100 configuration is encountered, the call will fail at runtime in Python.
If TCGEN5MMA support is planned, the Python implementation class needs to be added. If it is not yet ready, the C++ AllowTCGEN5MMA check should be disabled or guarded until Python support is complete.
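A minimal sketch of the dispatch gap described above, with hypothetical stand-in names (the real logic lives in `_get_implementation_class`; the string table here replaces the actual implementation classes):

```python
from enum import IntEnum


class GemmInst(IntEnum):
    # Ordinals mirror the C++ enum: kMMA, kWGMMA, kTCGEN5MMA, kMFMA.
    MMA = 0
    WGMMA = 1
    TCGEN5MMA = 2
    MFMA = 3


def get_implementation_class(inst: GemmInst) -> str:
    # Hypothetical stand-ins for the real implementation classes.
    table = {
        GemmInst.MMA: "GemmMMA",
        GemmInst.WGMMA: "GemmWGMMA",
        GemmInst.MFMA: "GemmMFMA",
    }
    if inst not in table:
        # This is the runtime failure mode when C++ selects TCGEN5MMA.
        raise NotImplementedError(f"{inst.name} not implemented on the Python side")
    return table[inst]
```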
```python
from tilelang.intrinsics.mfma_layout import (
    shared_16x4_to_local_64x1_layout_A,
    shared_16x16_to_local_64x4_layout_A,
    shared_16x32_to_local_64x8_layout_A,
    shared_16x64_to_local_64x16_layout_A,
)
```
Use B-specific layout transforms for matrix='B'.
Currently B paths reuse A transforms, which is incorrect for MFMA B operand.
```diff
-from tilelang.intrinsics.mfma_layout import (
-    shared_16x4_to_local_64x1_layout_A,
-    shared_16x16_to_local_64x4_layout_A,
-    shared_16x32_to_local_64x8_layout_A,
-    shared_16x64_to_local_64x16_layout_A,
-)
+from tilelang.intrinsics.mfma_layout import (
+    shared_16x4_to_local_64x1_layout_A,
+    shared_16x16_to_local_64x4_layout_A,
+    shared_16x32_to_local_64x8_layout_A,
+    shared_16x64_to_local_64x16_layout_A,
+    shared_4x16_to_local_64x1_layout_B,
+    shared_16x16_to_local_64x4_layout_B,
+    shared_16x32_to_local_64x8_layout_B,
+    shared_16x64_to_local_64x16_layout_B,
+)
@@ def make_mfma_load_base_layout(...):
-    if k_dim == 4:
-        transform_func_sr_a = shared_16x4_to_local_64x1_layout_A
-        transform_func_sr_b = shared_16x4_to_local_64x1_layout_A
+    if k_dim == 4:
+        transform_func_sr_a = shared_16x4_to_local_64x1_layout_A
+        transform_func_sr_b = shared_4x16_to_local_64x1_layout_B
@@
-    elif k_dim == 16:
-        transform_func_sr_a = shared_16x16_to_local_64x4_layout_A
-        transform_func_sr_b = shared_16x16_to_local_64x4_layout_A
+    elif k_dim == 16:
+        transform_func_sr_a = shared_16x16_to_local_64x4_layout_A
+        transform_func_sr_b = shared_16x16_to_local_64x4_layout_B
@@
-    elif k_dim == 32:
-        transform_func_sr_a = shared_16x32_to_local_64x8_layout_A
-        transform_func_sr_b = shared_16x32_to_local_64x8_layout_A
+    elif k_dim == 32:
+        transform_func_sr_a = shared_16x32_to_local_64x8_layout_A
+        transform_func_sr_b = shared_16x32_to_local_64x8_layout_B
@@
-    elif k_dim == 64:
-        transform_func_sr_a = shared_16x64_to_local_64x16_layout_A
-        transform_func_sr_b = shared_16x64_to_local_64x16_layout_A
+    elif k_dim == 64:
+        transform_func_sr_a = shared_16x64_to_local_64x16_layout_A
+        transform_func_sr_b = shared_16x64_to_local_64x16_layout_B
```

Also applies to: 49-61, 71-83
🤖 Prompt for AI Agents
In examples/plot_layout/fragment_mfma_load_a.py around lines 6-11 (and also
affecting blocks at 49-61 and 71-83), the code imports and uses A-specific MFMA
layout transforms for both matrix='A' and matrix='B'; replace those with the
B-specific transform functions for the B operand (e.g., import and use
shared_16x4_to_local_64x1_layout_B, shared_16x16_to_local_64x4_layout_B,
shared_16x32_to_local_64x8_layout_B, shared_16x64_to_local_64x16_layout_B) and
update all code paths that handle matrix=='B' to call the corresponding _B
transform functions instead of the A variants.
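Whichever transforms end up selected, a useful local sanity check is that a 16xK shared tile maps onto the 64 lanes and their local slots exactly once (a bijection). The mapping below is illustrative only, not tilelang's actual A or B layout:

```python
def demo_16x16_to_64x4(i: int, j: int):
    # Illustrative mapping: lane from the row plus the 4-wide column group,
    # local id from the position inside the group. Not tilelang's real layout.
    return i + 16 * (j // 4), j % 4


# Enumerate the whole 16x16 tile and collect (lane, local) pairs.
seen = set()
for i in range(16):
    for j in range(16):
        seen.add(demo_16x16_to_64x4(i, j))
```

If the A transforms were wrongly reused for B, a check like this against the intended B tile shape would fail to cover every lane/local pair.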
```python
class GemmMFMA(GemmBase):

    def infer_layout(self, target: Target, thread_nums: int):
        m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target,
```
Constructor missing: GemmMFMA() is instantiated with an argument and will TypeError.

GemmPy calls impl_class(self), passing the GemmPy node. Without __init__(gemm_node), this raises at runtime.

Apply:
```diff
@@
-class GemmMFMA(GemmBase):
+from tvm.ir.base import Node  # at top with other imports
+
+class GemmMFMA(GemmBase):
+    def __init__(self, gemm_node: Node):
+        # Store GemmPy node for GemmBase accessors (A/B/C, shapes, policy, etc.)
+        self.gemm_node = gemm_node
```
+ self.gemm_node = gemm_node📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
from tvm.ir.base import Node


class GemmMFMA(GemmBase):

    def __init__(self, gemm_node: Node):
        # Store GemmPy node for GemmBase accessors (A/B/C, shapes, policy, etc.)
        self.gemm_node = gemm_node

    def infer_layout(self, target: Target, thread_nums: int):
        m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target,
```
🤖 Prompt for AI Agents
In tilelang/tileop/gemm/gemm_mfma.py around lines 13 to 16, the GemmMFMA class
lacks an __init__ that accepts the gemm_node argument which causes a TypeError
when instantiated via impl_class(self); add a constructor def __init__(self,
gemm_node): that calls the base initializer (super().__init__(gemm_node)) and
performs any GemmMFMA-specific setup (e.g., store gemm_node or initialize
policy/state) so instances can be created with the GemmPy node without error.
```python
        m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target,
                                                            False)
        print(f"thread_nums: {thread_nums}, m_warp: {m_warp}, n_warp: {n_warp}")

        warp_row_tiles = int(self.M // m_warp)
        warp_col_tiles = int(self.N // n_warp)
        mfma_emitter = MatrixCoreIntrinEmitter(
```
🛠️ Refactor suggestion | 🟠 Major
Guard warp tiling assumptions.
Ensure M and N are divisible by warp partitions; otherwise emitter math breaks.
```diff
@@ def infer_layout(self, target: Target, thread_nums: int):
-    m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target,
-                                                        False)
-    print(f"thread_nums: {thread_nums}, m_warp: {m_warp}, n_warp: {n_warp}")
+    m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target, False)
+    assert self.M % m_warp == 0 and self.N % n_warp == 0, (
+        f"M ({self.M}) % m_warp ({m_warp}) or N ({self.N}) % n_warp ({n_warp}) != 0"
+    )
```
Actionable comments posted: 4
♻️ Duplicate comments (2)
examples/plot_layout/fragment_mfma_load_a.py (1)
6-11: Use B-specific transforms for matrix='B'.

B paths currently reuse A transforms; this yields wrong thread/local mapping for MFMA B.

```diff
 from tilelang.intrinsics.mfma_layout import (
     shared_16x4_to_local_64x1_layout_A,
+    shared_4x16_to_local_64x1_layout_B,
     shared_16x16_to_local_64x4_layout_A,
+    shared_16x16_to_local_64x4_layout_B,
     shared_16x32_to_local_64x8_layout_A,
+    shared_16x32_to_local_64x8_layout_B,
     shared_16x64_to_local_64x16_layout_A,
+    shared_16x64_to_local_64x16_layout_B,
 )
@@
 if k_dim == 4:
     transform_func_sr_a = shared_16x4_to_local_64x1_layout_A
-    transform_func_sr_b = shared_16x4_to_local_64x1_layout_A
+    transform_func_sr_b = shared_4x16_to_local_64x1_layout_B
 elif k_dim == 16:
     transform_func_sr_a = shared_16x16_to_local_64x4_layout_A
-    transform_func_sr_b = shared_16x16_to_local_64x4_layout_A
+    transform_func_sr_b = shared_16x16_to_local_64x4_layout_B
 elif k_dim == 32:
     transform_func_sr_a = shared_16x32_to_local_64x8_layout_A
-    transform_func_sr_b = shared_16x32_to_local_64x8_layout_A
+    transform_func_sr_b = shared_16x32_to_local_64x8_layout_B
 elif k_dim == 64:
     transform_func_sr_a = shared_16x64_to_local_64x16_layout_A
-    transform_func_sr_b = shared_16x64_to_local_64x16_layout_A
+    transform_func_sr_b = shared_16x64_to_local_64x16_layout_B
```

Also applies to: 49-61
tilelang/tileop/gemm/gemm_mfma.py (1)
13-14: Constructor missing: wire GemmPy node into GemmMFMA.

GemmPy instantiates impl_class(self); without __init__(gemm_node) the subclass can't access properties via GemmBase.

```diff
+from tvm.ir.base import Node
@@
-class GemmMFMA(GemmBase):
+class GemmMFMA(GemmBase):
+
+    def __init__(self, gemm_node: Node):
+        self.gemm_node = gemm_node
```
🧹 Nitpick comments (5)
testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_amd.py (1)
54-68: Plumb an explicit AMD target through helpers to reliably hit MFMA. Gate verbose IR prints.

"auto" may pick a non-AMD backend on CI/dev hosts, bypassing MFMA. Add a target parameter to run_gemm_* and pass it into tilelang.compile; also add a verbose flag for RR to avoid noisy logs by default.

```diff
@@ def run_gemm_ss(
     num_stages=3,
     num_threads=256,
+    target="auto",
 ):
@@
     kernel = tilelang.compile(
         program,
         out_idx=[2],
+        target=target,
         pass_configs={
             tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
             tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
         })
@@ def run_gemm_rr(
     num_stages=3,
     num_threads=256,
+    target="auto",
+    verbose=False,
 ):
@@
-    print(program)
-
-    print(kernel.get_kernel_source())
+    if verbose:
+        print(program)
+        print(kernel.get_kernel_source())
```

The same target plumbing applies to run_gemm_rs and run_gemm_sr.

To verify MFMA is selected at runtime, run the tests with an explicit AMD target (e.g. "rocm"/"amdgpu" as supported by your TVM build) and confirm the GemmInst resolves to MFMA in the lowered IR or logs.
Also applies to: 85-91, 180-194, 211-217, 302-316, 333-339, 428-442, 459-465
examples/plot_layout/fragment_mfma_load_a.py (1)
18-22: Docstring nit: this builds a load layout, not a store layout.

Rename "storing MFMA results" → "loading MFMA operands" for clarity.

```diff
- Create a layout function for storing MFMA results into a fragment buffer.
+ Create a layout function for loading MFMA operands into a fragment buffer.
```

tilelang/tileop/gemm/gemm_mfma.py (1)
61-79: Strengthen K-chunk check; optionally validate layout_map presence.

Add modulus check on K and touch layout_map to avoid ARG002.

```diff
 def lower(self, layout_map: dict, target: Target, thread_nums: int, thread_var: tir.Var):
     m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target,
                                                         False)
@@
         thread_var=thread_var,
     )
+    # Ensure expected entries exist (A,B,C) – helps catch plumbing mistakes
+    _ = layout_map
@@
-    assert block_K >= micro_size_k, f"block_K ({block_K}) must be >= micro_size_k ({micro_size_k})"
+    assert block_K >= micro_size_k, f"block_K ({block_K}) must be >= micro_size_k ({micro_size_k})"
+    assert (block_K % micro_size_k) == 0, (
+        f"block_K ({block_K}) must be a multiple of micro_size_k ({micro_size_k})"
+    )
```

Also applies to: 80-91
tilelang/intrinsics/mfma_macro_generator.py (2)
455-456: Drop redundant local import.

Already imported at module top; keep one place.

```diff
-    from tilelang.utils import is_fragment
```
433-454: Docstring clarity: load vs store.

These helpers generate load/store layouts respectively; make wording explicit.

```diff
- Create a layout function for storing MFMA results into a fragment buffer.
+ Create a layout function describing MFMA load layout for a fragment buffer.
@@
- Create a layout function for storing MFMA results into a fragment buffer.
+ Create a layout function describing MFMA store layout for a fragment buffer.
```

Also applies to: 575-594
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (5)
- examples/plot_layout/fragment_mfma_load_a.py (1 hunks)
- testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_amd.py (1 hunks)
- tilelang/intrinsics/mfma_macro_generator.py (11 hunks)
- tilelang/tileop/gemm/__init__.py (3 hunks)
- tilelang/tileop/gemm/gemm_mfma.py (1 hunks)
🪛 Ruff (0.14.1)
tilelang/tileop/gemm/gemm_mfma.py
58-59: Avoid specifying long messages outside the exception class
(TRY003)
61-61: Unused method argument: layout_map
(ARG002)
202-203: Avoid specifying long messages outside the exception class
(TRY003)
tilelang/intrinsics/mfma_macro_generator.py
485-485: Avoid specifying long messages outside the exception class
(TRY003)
500-500: Avoid specifying long messages outside the exception class
(TRY003)
557-557: Avoid specifying long messages outside the exception class
(TRY003)
571-571: Avoid specifying long messages outside the exception class
(TRY003)
examples/plot_layout/fragment_mfma_load_a.py
62-62: Avoid specifying long messages outside the exception class
(TRY003)
83-83: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
- GitHub Check: Test for Python 3.12 with ROCm-6.3 (on self-hosted-amd)
🔇 Additional comments (2)
tilelang/tileop/gemm/__init__.py (2)
118-126: MFMA mapping looks good.

Routing MFMA to GemmMFMA is correct; TCGEN5MMA guarded with NotImplemented is fine for now.

30-35: No issues found. Enums are properly synchronized across FFI boundary.

The Python GemmInst enum (MMA=0, WGMMA=1, TCGEN5MMA=2, MFMA=3) matches the C++ GemmInst enum (kMMA, kWGMMA, kTCGEN5MMA, kMFMA) with identical implicit numeric values. The FFI call at line 103 receives the native enum value directly and wraps it in the Python IntEnum—no mis-dispatch risk exists because the numeric values align. The code comment at line 29 confirms intentional synchronization.

Likely an incorrect or invalid review comment.
```python
        print(a_local_stride, b_local_stride)
```
Remove debug prints in hot paths.

These prints execute per-tile and will severely slow kernels and spam logs.

```diff
-    print(a_local_stride, b_local_stride)
+    # debug print removed
@@
-    print(self.a_preshuffle)
+    # debug print removed
```

Also applies to: 762-762
🤖 Prompt for AI Agents
In tilelang/intrinsics/mfma_macro_generator.py around lines 361-362 (and also at
line 762), remove the debug print statements that emit a_local_stride and
b_local_stride because they run in hot per-tile code paths; replace them with
either conditional debug logging gated by a verbose/debug flag or remove them
entirely so no I/O occurs during normal kernel generation. Ensure any needed
diagnostic info can be toggled via a logger.debug(...) behind a
runtime/compile-time flag rather than unguarded print calls.
```python
if k_dim == 4:
    transform_func_sr_a = shared_16x4_to_local_64x1_layout_A
    transform_func_sr_b = shared_16x4_to_local_64x1_layout_A
elif k_dim == 16:
    transform_func_sr_a = shared_16x16_to_local_64x4_layout_A
    transform_func_sr_b = shared_16x16_to_local_64x4_layout_A
elif k_dim == 32:
    transform_func_sr_a = shared_16x32_to_local_64x8_layout_A
    transform_func_sr_b = shared_16x32_to_local_64x8_layout_A
elif k_dim == 64:
    transform_func_sr_a = shared_16x64_to_local_64x16_layout_A
    transform_func_sr_b = shared_16x64_to_local_64x16_layout_A
else:
```
Correct B‑operand load transforms in make_mfma_load_layout.
Use B‑specific transforms; current A‑ones misplace lanes/local IDs for matrix B.
```diff
 if k_dim == 4:
     transform_func_sr_a = shared_16x4_to_local_64x1_layout_A
-    transform_func_sr_b = shared_16x4_to_local_64x1_layout_A
+    transform_func_sr_b = shared_4x16_to_local_64x1_layout_B
 elif k_dim == 16:
     transform_func_sr_a = shared_16x16_to_local_64x4_layout_A
-    transform_func_sr_b = shared_16x16_to_local_64x4_layout_A
+    transform_func_sr_b = shared_16x16_to_local_64x4_layout_B
 elif k_dim == 32:
     transform_func_sr_a = shared_16x32_to_local_64x8_layout_A
-    transform_func_sr_b = shared_16x32_to_local_64x8_layout_A
+    transform_func_sr_b = shared_16x32_to_local_64x8_layout_B
 elif k_dim == 64:
     transform_func_sr_a = shared_16x64_to_local_64x16_layout_A
-    transform_func_sr_b = shared_16x64_to_local_64x16_layout_A
+    transform_func_sr_b = shared_16x64_to_local_64x16_layout_B
@@
 else:
     raise ValueError(f"Unsupported matrix {matrix}")
```

Also applies to: 493-501
🤖 Prompt for AI Agents
In tilelang/intrinsics/mfma_macro_generator.py around lines 472 to 484 (and
similarly lines 493 to 501), the code assigns A-specific shared-to-local
transform functions to transform_func_sr_b which misplaces lanes/local IDs for
matrix B; replace the transform_func_sr_b assignments with the corresponding
B-specific transform functions for each k_dim case (e.g., use
shared_16x4_to_local_64x1_layout_B for k_dim==4,
shared_16x16_to_local_64x4_layout_B for k_dim==16,
shared_16x32_to_local_64x8_layout_B for k_dim==32, and
shared_16x64_to_local_64x16_layout_B for k_dim==64) so B uses the correct
lane/local ID mapping.
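The fix pattern — distinct shared-to-local mappings for A and B keyed by k_dim — can be sketched as a lookup table. The layout-function names mirror those in mfma_macro_generator.py, but they are stubbed here as placeholders since the real implementations live in the TileLang intrinsics module.

```python
def _stub(name: str):
    """Placeholder for a real shared->local layout function; only the name matters here."""
    def fn(row: int, col: int):
        return (name, row, col)
    fn.__name__ = name
    return fn

# k_dim -> A-side transform (names taken from the review diff)
_A_TRANSFORMS = {
    4: _stub("shared_16x4_to_local_64x1_layout_A"),
    16: _stub("shared_16x16_to_local_64x4_layout_A"),
    32: _stub("shared_16x32_to_local_64x8_layout_A"),
    64: _stub("shared_16x64_to_local_64x16_layout_A"),
}
# k_dim -> B-side transform; note the distinct _B variants per the review
_B_TRANSFORMS = {
    4: _stub("shared_4x16_to_local_64x1_layout_B"),
    16: _stub("shared_16x16_to_local_64x4_layout_B"),
    32: _stub("shared_16x32_to_local_64x8_layout_B"),
    64: _stub("shared_16x64_to_local_64x16_layout_B"),
}


def select_transforms(k_dim: int):
    """Return the (A, B) shared->local transform pair for a given k_dim."""
    if k_dim not in _A_TRANSFORMS:
        raise ValueError(f"Unsupported k_dim {k_dim}")
    return _A_TRANSFORMS[k_dim], _B_TRANSFORMS[k_dim]
```

Keeping the pairs in one table instead of an if/elif ladder makes it harder to accidentally assign the A-side function to the B operand again.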
    def infer_layout(self, target: Target, thread_nums: int):
        m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target,
                                                            False)
        warp_row_tiles = int(self.M // m_warp)
        warp_col_tiles = int(self.N // n_warp)
        mfma_emitter = MatrixCoreIntrinEmitter(
            a_dtype=self.in_dtype,
            b_dtype=self.in_dtype,
            accum_dtype=self.accum_dtype,
            a_transposed=self.trans_A,
            b_transposed=self.trans_B,
            block_row_warps=m_warp,
            block_col_warps=n_warp,
            warp_row_tiles=warp_row_tiles,
            warp_col_tiles=warp_col_tiles,
            chunk=self.chunk,
        )
Guard warp tiling divisibility before constructing the emitter.
Ensure per-warp tile counts are multiples of the MFMA micro sizes to avoid silent floor division in the emitter.
 def infer_layout(self, target: Target, thread_nums: int):
     m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target,
                                                         False)
     warp_row_tiles = int(self.M // m_warp)
     warp_col_tiles = int(self.N // n_warp)
     mfma_emitter = MatrixCoreIntrinEmitter(
@@
         chunk=self.chunk,
     )
+    # Validate divisibility: warp tiles must be integer multiples of micro tile sizes
+    assert warp_row_tiles % mfma_emitter.micro_size_x == 0, (
+        f"warp_row_tiles ({warp_row_tiles}) not divisible by micro_size_x "
+        f"({mfma_emitter.micro_size_x})"
+    )
+    assert warp_col_tiles % mfma_emitter.micro_size_y == 0, (
+        f"warp_col_tiles ({warp_col_tiles}) not divisible by micro_size_y "
+        f"({mfma_emitter.micro_size_y})"
+    )

🤖 Prompt for AI Agents
In tilelang/tileop/gemm/gemm_mfma.py around lines 15 to 31, you compute
warp_row_tiles and warp_col_tiles then instantiate MatrixCoreIntrinEmitter
without verifying that those per-warp tile counts are multiples of the MFMA
micro-tile sizes; add a guard that fetches the MFMA micro-tile dimensions (e.g.,
from self.policy (mfma_micro_m, mfma_micro_n) or from MatrixCoreIntrinEmitter if
that API exists), check warp_row_tiles % mfma_micro_m == 0 and warp_col_tiles %
mfma_micro_n == 0, and if either check fails raise a clear ValueError explaining
which dimension is not divisible and what the required micro-tile size is,
before creating the emitter.
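The suggested guard can be sketched as a standalone helper. The micro-tile sizes default to 16 here purely as an assumption for illustration; in the real code they come from the MatrixCoreIntrinEmitter instance (micro_size_x / micro_size_y).

```python
def validate_warp_tiles(warp_row_tiles: int, warp_col_tiles: int,
                        micro_size_x: int = 16, micro_size_y: int = 16) -> None:
    """Raise a clear error when per-warp tiles don't align to MFMA micro tiles.

    Catching this up front avoids silent truncation later, when the emitter
    computes counts like warp_row_tiles // micro_size_x via floor division.
    """
    if warp_row_tiles % micro_size_x != 0:
        raise ValueError(
            f"warp_row_tiles ({warp_row_tiles}) not divisible by "
            f"micro_size_x ({micro_size_x})")
    if warp_col_tiles % micro_size_y != 0:
        raise ValueError(
            f"warp_col_tiles ({warp_col_tiles}) not divisible by "
            f"micro_size_y ({micro_size_y})")
```

For example, a 128x128 block split over a 2x2 warp grid yields 64x64 per warp, which divides a 16x16 MFMA tile cleanly; a 40-row warp tile would fail loudly instead of being silently floored.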
    @T.prim_func
    def _gemm_rsr() -> None:
        """
        The inner macro that loads data from shared buffers A_shared and
        B_shared into local fragments, then issues Matrix Core mfma ops,
        accumulating into C_local.
        """
        for ki in T.serial(0, (block_K // micro_size_k)):
            # Perform Matrix Multiplication
            mfma_emitter.mfma(A_local, B_local, C_local, ki)

    # Simplify to optimize the index computing
    # Must inline let statements to simplify the analysis
    return _Simplify(_gemm_rsr, inline_let=True)
else:
Rename the duplicate inner function to reflect the RR path.
Avoid two "_gemm_rsr" definitions; name the RR kernel accordingly.
-    def _gemm_rsr() -> None:
+    def _gemm_rr() -> None:
@@
-    return _Simplify(_gemm_rsr, inline_let=True)
+    return _Simplify(_gemm_rr, inline_let=True)

🤖 Prompt for AI Agents
In tilelang/tileop/gemm/gemm_mfma.py around lines 186 to 201, there is a
duplicate inner function named "_gemm_rsr" used for the RR path; rename this
inner function to a distinct RR-specific name (e.g., "_gemm_rr") and update its
decorator/definition and the subsequent return call (change _Simplify(_gemm_rsr,
...) to _Simplify(_gemm_rr, ...)). Also scan and update any local references to
that inner function within this block so they point to the new name.
@codex review
Codex Review: Didn't find any major issues. What shall we delve into next?
Summary by CodeRabbit

- New Features
- Tests
- Documentation