
Conversation

@Paran0idy (Contributor) commented Oct 27, 2025

Summary by CodeRabbit

  • New Features

    • MFMA-based GEMM backend with support for multiple data-flow configurations, swizzled load/store layouts, and mixed-precision math
    • Added an enum path for a new MMA variant (TCGEN5MMA) and updated GEMM implementation selection
  • Tests

    • End-to-end GEMM tests covering float16 and int8 across transposition and layout combinations, with correctness checks and profiling
  • Documentation

    • New example that constructs, visualizes, and prints MFMA load layouts

@github-actions

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run pre-commit run --all-files in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀


coderabbitai bot commented Oct 27, 2025

Walkthrough

Adds MFMA-based GEMM support: new GemmMFMA implementation with layout inference and lowering for four data-flow variants, updated MatrixCoreIntrinEmitter with layout/thread-binding helpers and mfma signature changes, GEMM enum updates, tests, and a layout example.

Changes

  • MFMA Intrinsics Infrastructure (tilelang/intrinsics/mfma_macro_generator.py): Added thread_var to MatrixCoreIntrinEmitter; new APIs get_store_index_map(), get_thread_binding(), make_mfma_load_layout(), make_mfma_store_layout(); adjusted the mfma() signature and fragment-aware layout/stride handling; updated ldmatrix/stmatrix to use the new abstractions.
  • MFMA GEMM Implementation (tilelang/tileop/gemm/gemm_mfma.py): New GemmMFMA class with infer_layout() and lower() implementing four data-flow variants (SS/SR/RS/RR), layout inference, lowering to TIR, and helper predicates (is_gemm_*).
  • GEMM Framework & Enum Updates (tilelang/tileop/gemm/__init__.py): Renamed WGMMMA → WGMMA, added TCGEN5MMA, reordered enum values, added is_tcgen5mma(), imported GemmMFMA, and updated implementation dispatch to return GemmMFMA for MFMA.
  • GEMM Test Coverage (testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_amd.py): New comprehensive tests and helpers for GEMM variants (matmul, run_gemm_*/test_gemm_* for SS/SR/RS/RR) covering float16/int8, transpositions, reference checks, and profiling.
  • MFMA Layout Example (examples/plot_layout/fragment_mfma_load_a.py): New example with make_mfma_load_base_layout() and module constants (block_rows, block_cols, warp_rows, warp_cols, chunk); composes and visualizes base/warp/block layouts.

Sequence Diagram(s)

sequenceDiagram
    participant User as Client Code
    participant Gemm as GemmMFMA
    participant Emitter as MatrixCoreIntrinEmitter
    participant TIR as TIR Lowering

    User->>Gemm: infer_layout(target, thread_nums)
    Gemm->>Emitter: instantiate(thread_var)
    Emitter-->>Gemm: A/B/C layouts

    User->>Gemm: lower(layout_map, target, thread_nums, thread_var)
    Gemm->>Gemm: select variant (SS/SR/RS/RR)
    alt SS (shared/shared)
        Gemm->>Emitter: make_mfma_load_layout(A), make_mfma_load_layout(B)
    else Fragment involved
        Gemm->>Emitter: use swizzled/fragment layouts
    end
    Gemm->>Emitter: mfma(A_local, B_local, C_local, k_inner)
    Emitter->>TIR: emit prim_func
    TIR-->>User: compiled kernel

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

  • Pay attention to: tilelang/tileop/gemm/gemm_mfma.py (branching per variant, buffer/layout correctness), tilelang/intrinsics/mfma_macro_generator.py (mfma signature, thread binding, stride computation), and tests in testing/python/... (correctness and coverage). Enum reordering in tilelang/tileop/gemm/__init__.py may require cross-checks with callers.

Possibly related PRs

Suggested reviewers

  • LeiWang1999

Poem

🐰 Hopping through registers, layouts weave and play,
Threads bind, fragments load, MFMA leads the way.
SS to RR, four dances spun,
Tiles align, the kernels run—
A rabbit cheers: compute's in bloom today!

Pre-merge checks

❌ Failed checks (1 warning)
  • Docstring Coverage ⚠️ Warning: Docstring coverage is 26.79%, below the required threshold of 80.00%. Run @coderabbitai generate docstrings to improve coverage.
✅ Passed checks (2 passed)
  • Description Check ✅ Passed: Check skipped because CodeRabbit's high-level summary is enabled.
  • Title Check ✅ Passed: The pull request title "[AMD] Supoort T.gemm_v2 for AMD Backend" clearly and specifically communicates the main objective of the changeset: MFMA-based GEMM support for the AMD backend via the new GemmMFMA class, updated macro-generation infrastructure, comprehensive tests across the GEMM variants (ss, rs, sr, rr), and an example layout construction. Note, however, that the title contains a typo: "Supoort" should be "Support."



coderabbitai bot left a comment


Actionable comments posted: 3

🧹 Nitpick comments (10)
tilelang/tileop/gemm/gemm_mfma.py (4)

18-19: Remove stray debug prints from codegen paths.

print(...) in infer_layout/mfma pollutes logs and slows CI.

-        print(f"thread_nums: {thread_nums}, m_warp: {m_warp}, n_warp: {n_warp}")

And in mfma():

-        print(a_local_stride, b_local_stride)

Also applies to: 361-362


63-81: layout_map is unused in lower().

Acknowledge or consume it to satisfy linters without changing behavior (based on static analysis hints).

 def lower(self, layout_map: dict, target: Target, thread_nums: int, thread_var: tir.Var):
+        _ = layout_map  # unused: kept for API parity

93-94: Assertion message style (nit).

Long f-strings in asserts trigger TRY003; consider a shorter message or raising a ValueError instead (based on static analysis hints).

-        assert block_K >= micro_size_k, f"block_K ({block_K}) must be >= micro_size_k ({micro_size_k})"
+        assert block_K >= micro_size_k, "block_K must be >= micro_size_k"
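For illustration, a minimal sketch of the ValueError alternative (same names as in the diff above):

    # Sketch only: the same guard expressed as an exception instead of an
    # assert; unlike an assert, this still fires under `python -O`.
    if block_K < micro_size_k:
        raise ValueError("block_K must be >= micro_size_k")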

160-202: Duplicate helper name _gemm_rsr for RS and RR (clarity).

Both branches define _gemm_rsr; harmless but confusing during debugging.

  • Rename RR’s inner prim_func to _gemm_rrr (or similar).

Also applies to: 188-203

examples/plot_layout/fragment_mfma_load_a.py (2)

18-22: Docstring wording (minor).

This builds a load layout, not “storing MFMA results”.

-    Create a layout function for storing MFMA results into a fragment buffer.
+    Create a load layout function for MFMA fragments.

46-48: Typing nit.

transform_func_sr_a/b initialized to None but typed as Callable. Use Optional[Callable].

-from typing import Literal, Callable
+from typing import Literal, Callable, Optional
@@
-    transform_func_sr_a: Callable = None
-    transform_func_sr_b: Callable = None
+    transform_func_sr_a: Optional[Callable] = None
+    transform_func_sr_b: Optional[Callable] = None
tilelang/intrinsics/mfma_macro_generator.py (1)

356-362: Strip debug prints and unused local.

  • mfma(): print(...) in the hot path.
  • preshuffle ldmatrix_a(): stray print(self.a_preshuffle).
  • make_mfma_load_layout(): unused local dtype. Based on static analysis hints.
@@ def mfma(...):
-        print(a_local_stride, b_local_stride)
@@ def _warp_ldmatrix_a_shared(...):
-                print(self.a_preshuffle)
@@ def make_mfma_load_layout(...):
-        dtype = self.a_dtype if matrix_is_a else self.b_dtype

Also applies to: 763-763, 459-461

testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_amd.py (3)

85-93: Pin target to AMD for stability.

Tests are AMD‑specific; relying on target="auto" can select CUDA on CI hosts.

-    kernel = tilelang.compile(
+    kernel = tilelang.compile(
         program,
         out_idx=[2],
+        target="rocm",  # or "amdgpu"; choose the identifier your build uses
         pass_configs={

If your build expects "amdgpu" instead of "rocm", please adjust. Run a quick build locally to confirm.

Also applies to: 211-218, 333-341, 459-466


108-110: Don’t benchmark or spam logs in unit tests.

Benchmarks and kernel dumps are noisy and slow CI.

-    latency = profiler.do_bench(profiler.func, warmup=100)
-    print(f"GEMM SS latency: {latency} ms")
+    # Optional perf-only: enable via TL_BENCH=1
+    import os
+    if os.environ.get("TL_BENCH") == "1":
+        _ = profiler.do_bench(profiler.func, warmup=50)
@@
-    print(program)
-    print(kernel.get_kernel_source())
+    # Debug-only; guarded to avoid CI noise
+    import os
+    if os.environ.get("TL_DEBUG") == "1":
+        print(program)
+        print(kernel.get_kernel_source())

Also applies to: 459-469


95-106: Reference cast path (small).

matmul for int8 uses float accumulation and then casts to out_dtype. This is fine; if flakiness arises, consider clamping before the cast to emulate int8 saturation (see the sketch below).

# inside ref_program before final cast:
# C = C.clamp(min_val, max_val)  # emulate saturation if needed

Also applies to: 221-230, 342-352, 472-481
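For concreteness, a hedged sketch of that clamp in the torch-based reference path (torch assumed; the int8 output here is illustrative, since the tests' actual out_dtype may differ):

    import torch

    def saturating_cast_int8(C_fp: torch.Tensor) -> torch.Tensor:
        # Clamp the float accumulator to the int8 range before the final
        # cast, emulating saturation instead of wrap-around.
        info = torch.iinfo(torch.int8)
        return C_fp.clamp(info.min, info.max).to(torch.int8)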

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 6e1dc6a and 492a0c4.

📒 Files selected for processing (5)
  • examples/plot_layout/fragment_mfma_load_a.py (1 hunks)
  • testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_amd.py (1 hunks)
  • tilelang/intrinsics/mfma_macro_generator.py (11 hunks)
  • tilelang/tileop/gemm/__init__.py (3 hunks)
  • tilelang/tileop/gemm/gemm_mfma.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (5)
examples/plot_layout/fragment_mfma_load_a.py (5)
tilelang/intrinsics/utils.py (1)
  • get_mma_micro_size (89-109)
tilelang/intrinsics/mfma_layout.py (4)
  • shared_16x4_to_local_64x1_layout_A (6-8)
  • shared_16x16_to_local_64x4_layout_A (46-49)
  • shared_16x32_to_local_64x8_layout_A (88-91)
  • shared_16x64_to_local_64x16_layout_A (112-115)
tilelang/intrinsics/mfma_macro_generator.py (4)
  • forward_thread (517-522)
  • forward_thread (608-623)
  • forward_index (524-529)
  • forward_index (625-637)
tilelang/tools/plot_layout.py (1)
  • plot_layout (4-207)
tilelang/layout/fragment.py (2)
  • repeat (123-144)
  • replicate (146-160)
tilelang/tileop/gemm/gemm_mfma.py (6)
tilelang/tileop/gemm/gemm_base.py (10)
  • GemmBase (12-120)
  • policy (119-120)
  • M (34-35)
  • N (38-39)
  • trans_A (46-47)
  • trans_B (50-51)
  • chunk (63-64)
  • A (67-68)
  • B (71-72)
  • C (75-76)
tilelang/layout/swizzle.py (1)
  • make_swizzled_layout (10-18)
tilelang/intrinsics/mfma_macro_generator.py (8)
  • MatrixCoreIntrinEmitter (35-643)
  • make_mfma_store_layout (576-643)
  • make_mfma_load_layout (433-574)
  • ldmatrix_a (254-292)
  • ldmatrix_a (700-773)
  • ldmatrix_b (294-337)
  • ldmatrix_b (775-850)
  • mfma (339-382)
tilelang/utils/language.py (2)
  • is_shared (25-39)
  • is_fragment (68-78)
tilelang/transform/simplify.py (1)
  • _Simplify (31-49)
tilelang/tileop/gemm/__init__.py (2)
  • infer_layout (75-79)
  • lower (81-85)
tilelang/intrinsics/mfma_macro_generator.py (7)
tilelang/intrinsics/utils.py (1)
  • mfma_store_index_map (85-86)
tilelang/utils/language.py (1)
  • is_fragment (68-78)
tilelang/intrinsics/mfma_layout.py (16)
  • shared_16x4_to_local_64x1_layout_A (6-8)
  • shared_4x16_to_local_64x1_layout_B (17-19)
  • shared_16x16_to_local_64x4_layout_A (46-49)
  • shared_16x16_to_local_64x4_layout_B (58-61)
  • shared_16x32_to_local_64x8_layout_A (88-91)
  • shared_16x32_to_local_64x8_layout_B (100-103)
  • shared_16x64_to_local_64x16_layout_A (112-115)
  • shared_16x64_to_local_64x16_layout_B (124-127)
  • thread_id_shared_access_64x1_to_16x4_layout_A (11-14)
  • thread_id_shared_access_64x1_to_4x16_layout_B (22-25)
  • thread_id_shared_access_64x4_to_16x16_layout_A (40-43)
  • thread_id_shared_access_64x4_to_16x16_layout_B (52-55)
  • thread_id_shared_access_64x8_to_16x32_layout_A (82-85)
  • thread_id_shared_access_64x8_to_16x32_layout_B (94-97)
  • thread_id_shared_access_64x16_to_16x64_layout_A (106-109)
  • thread_id_shared_access_64x16_to_16x64_layout_B (118-121)
tilelang/intrinsics/mma_macro_generator.py (7)
  • get_store_index_map (160-166)
  • get_thread_binding (152-158)
  • stmatrix (398-451)
  • forward_thread (536-541)
  • forward_thread (629-644)
  • forward_index (543-548)
  • forward_index (646-658)
tilelang/language/kernel.py (4)
  • get_thread_binding (171-176)
  • get_thread_binding (306-310)
  • KernelLaunchFrame (95-226)
  • Current (135-141)
tilelang/language/tir/entry.py (1)
  • macro (66-117)
tilelang/layout/fragment.py (3)
  • Fragment (13-213)
  • replicate (146-160)
  • repeat (123-144)
tilelang/tileop/gemm/__init__.py (1)
tilelang/tileop/gemm/gemm_mfma.py (1)
  • GemmMFMA (13-217)
testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_amd.py (11)
tilelang/language/allocate.py (2)
  • alloc_shared (24-39)
  • alloc_fragment (56-67)
tilelang/language/fill.py (1)
  • clear (24-48)
tilelang/language/pipeline.py (1)
  • Pipelined (9-46)
tilelang/language/copy.py (1)
  • copy (11-87)
tilelang/language/gemm.py (1)
  • gemm_v2 (215-426)
tilelang/jit/__init__.py (1)
  • compile (30-79)
tilelang/jit/kernel.py (2)
  • out_idx (453-454)
  • get_profiler (367-383)
tilelang/transform/pass_config.py (1)
  • PassConfigKey (6-144)
tilelang/profiler/__init__.py (1)
  • assert_allclose (77-146)
tilelang/language/annotations.py (1)
  • annotate_layout (25-36)
tilelang/layout/swizzle.py (1)
  • make_swizzled_layout (10-18)
🪛 Ruff (0.14.1)
examples/plot_layout/fragment_mfma_load_a.py

62-62: Avoid specifying long messages outside the exception class

(TRY003)


83-83: Avoid specifying long messages outside the exception class

(TRY003)

tilelang/tileop/gemm/gemm_mfma.py

60-61: Avoid specifying long messages outside the exception class

(TRY003)


63-63: Unused method argument: layout_map

(ARG002)


204-205: Avoid specifying long messages outside the exception class

(TRY003)

tilelang/intrinsics/mfma_macro_generator.py

459-459: Local variable dtype is assigned to but never used

Remove assignment to unused variable dtype

(F841)


486-486: Avoid specifying long messages outside the exception class

(TRY003)


501-501: Avoid specifying long messages outside the exception class

(TRY003)


558-558: Avoid specifying long messages outside the exception class

(TRY003)


572-572: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Quick Lint
🔇 Additional comments (2)
tilelang/intrinsics/mfma_macro_generator.py (1)

209-216: Minor: share helpers via existing getters (consistency).

The new get_store_index_map/forward_* closures look good and align with MMA emitter interfaces. LGTM.

Also applies to: 608-643

tilelang/tileop/gemm/__init__.py (1)

31-34: Enum ordinals are correctly aligned; however, Python lacks TCGEN5MMA support while C++ can select it.

The C++ and Python enum ordinals (MMA=0, WGMMA=1, TCGEN5MMA=2, MFMA=3) are already synchronized and the FFI integer-to-enum conversion works correctly. However, when C++ returns GemmInst::kTCGEN5MMA (for SM100 targets matching buffer/dtype constraints), the Python side immediately raises NotImplementedError in _get_implementation_class since TCGEN5MMA is not yet implemented.

This is not an enum alignment issue, but an incomplete feature: the C++ backend supports SM100 TCGEN5MMA instruction selection, while the Python dispatch layer does not. When a compatible SM100 configuration is encountered, the call will fail at runtime in Python.

If TCGEN5MMA support is planned, the Python implementation class needs to be added. If it is not yet ready, the C++ AllowTCGEN5MMA check should be disabled or guarded until Python support is complete.
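A hedged sketch of the gap described above (the function and enum names follow the review text; the body is illustrative, not the actual implementation):

    def _get_implementation_class(gemm_inst: GemmInst):
        # Illustrative dispatch: MFMA maps to a concrete class, while
        # TCGEN5MMA (selectable by C++ on SM100) has no Python class yet,
        # so a compatible configuration fails here at runtime.
        if gemm_inst == GemmInst.MFMA:
            return GemmMFMA
        if gemm_inst == GemmInst.TCGEN5MMA:
            raise NotImplementedError("TCGEN5MMA is not implemented in Python yet")
        ...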

Comment on lines +6 to +11

from tilelang.intrinsics.mfma_layout import (
    shared_16x4_to_local_64x1_layout_A,
    shared_16x16_to_local_64x4_layout_A,
    shared_16x32_to_local_64x8_layout_A,
    shared_16x64_to_local_64x16_layout_A,
)

⚠️ Potential issue | 🟠 Major

Use B-specific layout transforms for matrix='B'.

Currently the B paths reuse the A transforms, which is incorrect for the MFMA B operand.

-from tilelang.intrinsics.mfma_layout import (
-    shared_16x4_to_local_64x1_layout_A,
-    shared_16x16_to_local_64x4_layout_A,
-    shared_16x32_to_local_64x8_layout_A,
-    shared_16x64_to_local_64x16_layout_A,
-)
+from tilelang.intrinsics.mfma_layout import (
+    shared_16x4_to_local_64x1_layout_A,
+    shared_16x16_to_local_64x4_layout_A,
+    shared_16x32_to_local_64x8_layout_A,
+    shared_16x64_to_local_64x16_layout_A,
+    shared_4x16_to_local_64x1_layout_B,
+    shared_16x16_to_local_64x4_layout_B,
+    shared_16x32_to_local_64x8_layout_B,
+    shared_16x64_to_local_64x16_layout_B,
+)
@@ def make_mfma_load_base_layout(...):
-    if k_dim == 4:
-        transform_func_sr_a = shared_16x4_to_local_64x1_layout_A
-        transform_func_sr_b = shared_16x4_to_local_64x1_layout_A
+    if k_dim == 4:
+        transform_func_sr_a = shared_16x4_to_local_64x1_layout_A
+        transform_func_sr_b = shared_4x16_to_local_64x1_layout_B
@@
-    elif k_dim == 16:
-        transform_func_sr_a = shared_16x16_to_local_64x4_layout_A
-        transform_func_sr_b = shared_16x16_to_local_64x4_layout_A
+    elif k_dim == 16:
+        transform_func_sr_a = shared_16x16_to_local_64x4_layout_A
+        transform_func_sr_b = shared_16x16_to_local_64x4_layout_B
@@
-    elif k_dim == 32:
-        transform_func_sr_a = shared_16x32_to_local_64x8_layout_A
-        transform_func_sr_b = shared_16x32_to_local_64x8_layout_A
+    elif k_dim == 32:
+        transform_func_sr_a = shared_16x32_to_local_64x8_layout_A
+        transform_func_sr_b = shared_16x32_to_local_64x8_layout_B
@@
-    elif k_dim == 64:
-        transform_func_sr_a = shared_16x64_to_local_64x16_layout_A
-        transform_func_sr_b = shared_16x64_to_local_64x16_layout_A
+    elif k_dim == 64:
+        transform_func_sr_a = shared_16x64_to_local_64x16_layout_A
+        transform_func_sr_b = shared_16x64_to_local_64x16_layout_B

Also applies to: 49-61, 71-83

🤖 Prompt for AI Agents
In examples/plot_layout/fragment_mfma_load_a.py around lines 6-11 (and also
affecting blocks at 49-61 and 71-83), the code imports and uses A-specific MFMA
layout transforms for both matrix='A' and matrix='B'; replace those with the
B-specific transform functions for the B operand (e.g., import and use
shared_16x4_to_local_64x1_layout_B, shared_16x16_to_local_64x4_layout_B,
shared_16x32_to_local_64x8_layout_B, shared_16x64_to_local_64x16_layout_B) and
update all code paths that handle matrix=='B' to call the corresponding _B
transform functions instead of the A variants.

Comment on lines +13 to +16

class GemmMFMA(GemmBase):

    def infer_layout(self, target: Target, thread_nums: int):
        m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target,

⚠️ Potential issue | 🔴 Critical

Constructor missing: GemmMFMA() is instantiated with an argument and will raise a TypeError.

GemmPy calls impl_class(self), passing the GemmPy node. Without __init__(gemm_node), this raises at runtime.

Apply:

@@
-class GemmMFMA(GemmBase):
+from tvm.ir.base import Node  # at top with other imports
+
+class GemmMFMA(GemmBase):
+    def __init__(self, gemm_node: Node):
+        # Store GemmPy node for GemmBase accessors (A/B/C, shapes, policy, etc.)
+        self.gemm_node = gemm_node
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change

-class GemmMFMA(GemmBase):
-
-    def infer_layout(self, target: Target, thread_nums: int):
-        m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target,
+from tvm.ir.base import Node
+
+class GemmMFMA(GemmBase):
+
+    def __init__(self, gemm_node: Node):
+        # Store GemmPy node for GemmBase accessors (A/B/C, shapes, policy, etc.)
+        self.gemm_node = gemm_node
+
+    def infer_layout(self, target: Target, thread_nums: int):
+        m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target,
🤖 Prompt for AI Agents
In tilelang/tileop/gemm/gemm_mfma.py around lines 13 to 16, the GemmMFMA class
lacks an __init__ that accepts the gemm_node argument which causes a TypeError
when instantiated via impl_class(self); add a constructor def __init__(self,
gemm_node): that calls the base initializer (super().__init__(gemm_node)) and
performs any GemmMFMA-specific setup (e.g., store gemm_node or initialize
policy/state) so instances can be created with the GemmPy node without error.

Comment on lines 16 to 20
m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target,
False)
print(f"thread_nums: {thread_nums}, m_warp: {m_warp}, n_warp: {n_warp}")

warp_row_tiles = int(self.M // m_warp)
warp_col_tiles = int(self.N // n_warp)
mfma_emitter = MatrixCoreIntrinEmitter(

🛠️ Refactor suggestion | 🟠 Major

Guard warp tiling assumptions.

Ensure M and N are divisible by the warp partition counts; otherwise the emitter math breaks.

@@ def infer_layout(self, target: Target, thread_nums: int):
-        m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target,
-                                                            False)
-        print(f"thread_nums: {thread_nums}, m_warp: {m_warp}, n_warp: {n_warp}")
+        m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target, False)
+        assert self.M % m_warp == 0 and self.N % n_warp == 0, (
+            f"M ({self.M}) % m_warp ({m_warp}) or N ({self.N}) % n_warp ({n_warp}) != 0"
+        )


coderabbitai bot left a comment


Actionable comments posted: 4

♻️ Duplicate comments (2)
examples/plot_layout/fragment_mfma_load_a.py (1)

6-11: Use B‑specific transforms for matrix='B'.

B paths currently reuse the A transforms; this yields a wrong thread/local mapping for MFMA B.

 from tilelang.intrinsics.mfma_layout import (
     shared_16x4_to_local_64x1_layout_A,
+    shared_4x16_to_local_64x1_layout_B,
     shared_16x16_to_local_64x4_layout_A,
+    shared_16x16_to_local_64x4_layout_B,
     shared_16x32_to_local_64x8_layout_A,
+    shared_16x32_to_local_64x8_layout_B,
     shared_16x64_to_local_64x16_layout_A,
+    shared_16x64_to_local_64x16_layout_B,
 )
@@
     if k_dim == 4:
         transform_func_sr_a = shared_16x4_to_local_64x1_layout_A
-        transform_func_sr_b = shared_16x4_to_local_64x1_layout_A
+        transform_func_sr_b = shared_4x16_to_local_64x1_layout_B
     elif k_dim == 16:
         transform_func_sr_a = shared_16x16_to_local_64x4_layout_A
-        transform_func_sr_b = shared_16x16_to_local_64x4_layout_A
+        transform_func_sr_b = shared_16x16_to_local_64x4_layout_B
     elif k_dim == 32:
         transform_func_sr_a = shared_16x32_to_local_64x8_layout_A
-        transform_func_sr_b = shared_16x32_to_local_64x8_layout_A
+        transform_func_sr_b = shared_16x32_to_local_64x8_layout_B
     elif k_dim == 64:
         transform_func_sr_a = shared_16x64_to_local_64x16_layout_A
-        transform_func_sr_b = shared_16x64_to_local_64x16_layout_A
+        transform_func_sr_b = shared_16x64_to_local_64x16_layout_B

Also applies to: 49-61

tilelang/tileop/gemm/gemm_mfma.py (1)

13-14: Constructor missing: wire the GemmPy node into GemmMFMA.

GemmPy instantiates impl_class(self); without __init__(gemm_node) the subclass can't access properties via GemmBase.

+from tvm.ir.base import Node
@@
-class GemmMFMA(GemmBase):
+class GemmMFMA(GemmBase):
+
+    def __init__(self, gemm_node: Node):
+        self.gemm_node = gemm_node
🧹 Nitpick comments (5)
testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_amd.py (1)

54-68: Plumb an explicit AMD target through helpers to reliably hit MFMA. Gate verbose IR prints.

“auto” may pick a non‑AMD backend on CI/dev hosts, bypassing MFMA. Add a target parameter to run_gemm_* and pass it into tilelang.compile; also add a verbose flag for RR to avoid noisy logs by default.

@@ def run_gemm_ss(
-    num_stages=3,
-    num_threads=256,
+    num_stages=3,
+    num_threads=256,
+    target="auto",
 ):
@@
-    kernel = tilelang.compile(
+    kernel = tilelang.compile(
         program,
         out_idx=[2],
+        target=target,
         pass_configs={
             tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
             tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
         })
@@ def run_gemm_rs(
-    num_stages=3,
-    num_threads=256,
+    num_stages=3,
+    num_threads=256,
+    target="auto",
 ):
@@
-    kernel = tilelang.compile(
+    kernel = tilelang.compile(
         program,
         out_idx=[2],
+        target=target,
         pass_configs={
             tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
             tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
         })
@@ def run_gemm_sr(
-    num_stages=3,
-    num_threads=256,
+    num_stages=3,
+    num_threads=256,
+    target="auto",
 ):
@@
-    kernel = tilelang.compile(
+    kernel = tilelang.compile(
         program,
         out_idx=[2],
+        target=target,
         pass_configs={
             tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
             tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
         })
@@ def run_gemm_rr(
-    num_stages=3,
-    num_threads=256,
+    num_stages=3,
+    num_threads=256,
+    target="auto",
+    verbose=False,
 ):
@@
-    kernel = tilelang.compile(
+    kernel = tilelang.compile(
         program,
         out_idx=[2],
+        target=target,
         pass_configs={
             tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
             tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
         })
-    print(program)
-
-    print(kernel.get_kernel_source())
+    if verbose:
+        print(program)
+        print(kernel.get_kernel_source())

To verify MFMA is selected at runtime, run the tests with an explicit AMD target (e.g., “rocm”/“amdgpu” as supported by your TVM build) and confirm the GemmInst resolves to MFMA in the lowered IR or logs.

Also applies to: 85-91, 180-194, 211-217, 302-316, 333-339, 428-442, 459-465

examples/plot_layout/fragment_mfma_load_a.py (1)

18-22: Docstring nit: this builds a load layout, not a store layout.

Rename “storing MFMA results” → “loading MFMA operands” for clarity.

-    Create a layout function for storing MFMA results into a fragment buffer.
+    Create a layout function for loading MFMA operands into a fragment buffer.
tilelang/tileop/gemm/gemm_mfma.py (1)

61-79: Strengthen K‑chunk check; optionally validate layout_map presence.

Add modulus check on K and touch layout_map to avoid ARG002.

     def lower(self, layout_map: dict, target: Target, thread_nums: int, thread_var: tir.Var):
         m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target,
                                                             False)
@@
             thread_var=thread_var,
         )
+        # Ensure expected entries exist (A,B,C) – helps catch plumbing mistakes
+        _ = layout_map
@@
-        assert block_K >= micro_size_k, f"block_K ({block_K}) must be >= micro_size_k ({micro_size_k})"
+        assert block_K >= micro_size_k, f"block_K ({block_K}) must be >= micro_size_k ({micro_size_k})"
+        assert (block_K % micro_size_k) == 0, (
+            f"block_K ({block_K}) must be a multiple of micro_size_k ({micro_size_k})"
+        )

Also applies to: 80-91

tilelang/intrinsics/mfma_macro_generator.py (2)

455-456: Drop redundant local import.

Already imported at module top; keep one place.

-        from tilelang.utils import is_fragment

433-454: Docstring clarity: load vs store.

These helpers generate load/store layouts respectively; make wording explicit.

-        Create a layout function for storing MFMA results into a fragment buffer.
+        Create a layout function describing MFMA load layout for a fragment buffer.
@@
-        Create a layout function for storing MFMA results into a fragment buffer.
+        Create a layout function describing MFMA store layout for a fragment buffer.

Also applies to: 575-594

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 492a0c4 and 62a3e32.

📒 Files selected for processing (5)
  • examples/plot_layout/fragment_mfma_load_a.py (1 hunks)
  • testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_amd.py (1 hunks)
  • tilelang/intrinsics/mfma_macro_generator.py (11 hunks)
  • tilelang/tileop/gemm/__init__.py (3 hunks)
  • tilelang/tileop/gemm/gemm_mfma.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (5)
tilelang/tileop/gemm/gemm_mfma.py (7)
tilelang/tileop/gemm/gemm_base.py (10)
  • GemmBase (12-120)
  • policy (119-120)
  • M (34-35)
  • N (38-39)
  • trans_A (46-47)
  • trans_B (50-51)
  • chunk (63-64)
  • A (67-68)
  • B (71-72)
  • C (75-76)
tilelang/layout/swizzle.py (1)
  • make_swizzled_layout (10-18)
tilelang/intrinsics/mfma_macro_generator.py (8)
  • MatrixCoreIntrinEmitter (35-642)
  • make_mfma_store_layout (575-642)
  • make_mfma_load_layout (433-573)
  • ldmatrix_a (254-292)
  • ldmatrix_a (699-772)
  • ldmatrix_b (294-337)
  • ldmatrix_b (774-849)
  • mfma (339-382)
tilelang/utils/language.py (2)
  • is_shared (25-39)
  • is_fragment (68-78)
tilelang/transform/simplify.py (1)
  • _Simplify (31-49)
tilelang/tileop/gemm/__init__.py (2)
  • infer_layout (75-79)
  • lower (81-85)
tilelang/language/allocate.py (1)
  • alloc_local (42-53)
testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_amd.py (10)
tilelang/language/allocate.py (2)
  • alloc_shared (24-39)
  • alloc_fragment (56-67)
tilelang/language/pipeline.py (1)
  • Pipelined (9-46)
tilelang/language/copy.py (1)
  • copy (11-87)
tilelang/language/gemm.py (1)
  • gemm_v2 (215-426)
tilelang/jit/__init__.py (1)
  • compile (30-79)
tilelang/transform/pass_config.py (1)
  • PassConfigKey (6-144)
tilelang/utils/tensor.py (1)
  • TensorSupplyType (11-18)
tilelang/profiler/__init__.py (1)
  • assert_allclose (77-146)
tilelang/language/annotations.py (1)
  • annotate_layout (25-36)
tilelang/layout/swizzle.py (1)
  • make_swizzled_layout (10-18)
tilelang/tileop/gemm/__init__.py (1)
tilelang/tileop/gemm/gemm_mfma.py (1)
  • GemmMFMA (13-215)
tilelang/intrinsics/mfma_macro_generator.py (6)
tilelang/intrinsics/utils.py (1)
  • mfma_store_index_map (85-86)
tilelang/utils/language.py (1)
  • is_fragment (68-78)
tilelang/intrinsics/mfma_layout.py (16)
  • shared_16x4_to_local_64x1_layout_A (6-8)
  • shared_4x16_to_local_64x1_layout_B (17-19)
  • shared_16x16_to_local_64x4_layout_A (46-49)
  • shared_16x16_to_local_64x4_layout_B (58-61)
  • shared_16x32_to_local_64x8_layout_A (88-91)
  • shared_16x32_to_local_64x8_layout_B (100-103)
  • shared_16x64_to_local_64x16_layout_A (112-115)
  • shared_16x64_to_local_64x16_layout_B (124-127)
  • thread_id_shared_access_64x1_to_16x4_layout_A (11-14)
  • thread_id_shared_access_64x1_to_4x16_layout_B (22-25)
  • thread_id_shared_access_64x4_to_16x16_layout_A (40-43)
  • thread_id_shared_access_64x4_to_16x16_layout_B (52-55)
  • thread_id_shared_access_64x8_to_16x32_layout_A (82-85)
  • thread_id_shared_access_64x8_to_16x32_layout_B (94-97)
  • thread_id_shared_access_64x16_to_16x64_layout_A (106-109)
  • thread_id_shared_access_64x16_to_16x64_layout_B (118-121)
tilelang/intrinsics/mma_macro_generator.py (7)
  • get_store_index_map (160-166)
  • get_thread_binding (152-158)
  • stmatrix (398-451)
  • forward_thread (536-541)
  • forward_thread (629-644)
  • forward_index (543-548)
  • forward_index (646-658)
tilelang/language/kernel.py (4)
  • get_thread_binding (171-176)
  • get_thread_binding (306-310)
  • KernelLaunchFrame (95-226)
  • Current (135-141)
tilelang/language/tir/entry.py (1)
  • macro (66-117)
examples/plot_layout/fragment_mfma_load_a.py (4)
tilelang/intrinsics/utils.py (1)
  • get_mma_micro_size (89-109)
tilelang/intrinsics/mfma_layout.py (4)
  • shared_16x4_to_local_64x1_layout_A (6-8)
  • shared_16x16_to_local_64x4_layout_A (46-49)
  • shared_16x32_to_local_64x8_layout_A (88-91)
  • shared_16x64_to_local_64x16_layout_A (112-115)
tilelang/tools/plot_layout.py (1)
  • plot_layout (4-207)
tilelang/layout/fragment.py (2)
  • repeat (123-144)
  • replicate (146-160)
🪛 Ruff (0.14.1)
tilelang/tileop/gemm/gemm_mfma.py

58-59: Avoid specifying long messages outside the exception class

(TRY003)


61-61: Unused method argument: layout_map

(ARG002)


202-203: Avoid specifying long messages outside the exception class

(TRY003)

tilelang/intrinsics/mfma_macro_generator.py

485-485: Avoid specifying long messages outside the exception class

(TRY003)


500-500: Avoid specifying long messages outside the exception class

(TRY003)


557-557: Avoid specifying long messages outside the exception class

(TRY003)


571-571: Avoid specifying long messages outside the exception class

(TRY003)

examples/plot_layout/fragment_mfma_load_a.py

62-62: Avoid specifying long messages outside the exception class

(TRY003)


83-83: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
  • GitHub Check: Test for Python 3.12 with ROCm-6.3 (on self-hosted-amd)
🔇 Additional comments (2)
tilelang/tileop/gemm/__init__.py (2)

118-126: MFMA mapping looks good.

Routing MFMA to GemmMFMA is correct; TCGEN5MMA guarded with NotImplemented is fine for now.


30-35: No issues found. Enums are properly synchronized across FFI boundary.

The Python GemmInst enum (MMA=0, WGMMA=1, TCGEN5MMA=2, MFMA=3) matches the C++ GemmInst enum (kMMA, kWGMMA, kTCGEN5MMA, kMFMA) with identical implicit numeric values. The FFI call at line 103 receives the native enum value directly and wraps it in the Python IntEnum—no mis-dispatch risk exists because the numeric values align. The code comment at line 29 confirms intentional synchronization.

Likely an incorrect or invalid review comment.
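Written out, the alignment being asserted looks like the following sketch (ordinals as stated above; the real definitions live in tilelang/tileop/gemm/__init__.py and the C++ header):

    from enum import IntEnum

    class GemmInst(IntEnum):
        # Must stay in sync with the C++ GemmInst enum
        # (kMMA, kWGMMA, kTCGEN5MMA, kMFMA) across the FFI boundary.
        MMA = 0
        WGMMA = 1
        TCGEN5MMA = 2
        MFMA = 3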

Comment on lines +361 to +362

        print(a_local_stride, b_local_stride)


⚠️ Potential issue | 🟠 Major

Remove debug prints in hot paths.

These prints execute per‑tile and will severely slow kernels and spam logs.

-        print(a_local_stride, b_local_stride)
+        # debug print removed
@@
-                print(self.a_preshuffle)
+                # debug print removed

Also applies to: 762-762

🤖 Prompt for AI Agents
In tilelang/intrinsics/mfma_macro_generator.py around lines 361-362 (and also at
line 762), remove the debug print statements that emit a_local_stride and
b_local_stride because they run in hot per-tile code paths; replace them with
either conditional debug logging gated by a verbose/debug flag or remove them
entirely so no I/O occurs during normal kernel generation. Ensure any needed
diagnostic info can be toggled via a logger.debug(...) behind a
runtime/compile-time flag rather than unguarded print calls.
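A minimal sketch of the gated logging the prompt suggests (standard-library logging; variable names taken from the diff above):

    import logging

    logger = logging.getLogger(__name__)

    # Emits only when DEBUG is enabled; lazy %-formatting avoids building
    # the message string on the hot path when the level is off.
    logger.debug("local strides: a=%s b=%s", a_local_stride, b_local_stride)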

Comment on lines +472 to +484

        if k_dim == 4:
            transform_func_sr_a = shared_16x4_to_local_64x1_layout_A
            transform_func_sr_b = shared_16x4_to_local_64x1_layout_A
        elif k_dim == 16:
            transform_func_sr_a = shared_16x16_to_local_64x4_layout_A
            transform_func_sr_b = shared_16x16_to_local_64x4_layout_A
        elif k_dim == 32:
            transform_func_sr_a = shared_16x32_to_local_64x8_layout_A
            transform_func_sr_b = shared_16x32_to_local_64x8_layout_A
        elif k_dim == 64:
            transform_func_sr_a = shared_16x64_to_local_64x16_layout_A
            transform_func_sr_b = shared_16x64_to_local_64x16_layout_A
        else:

⚠️ Potential issue | 🟠 Major

Correct B‑operand load transforms in make_mfma_load_layout.

Use B-specific transforms; the current A-specific ones misplace lanes/local IDs for matrix B.

         if k_dim == 4:
             transform_func_sr_a = shared_16x4_to_local_64x1_layout_A
-            transform_func_sr_b = shared_16x4_to_local_64x1_layout_A
+            transform_func_sr_b = shared_4x16_to_local_64x1_layout_B
         elif k_dim == 16:
             transform_func_sr_a = shared_16x16_to_local_64x4_layout_A
-            transform_func_sr_b = shared_16x16_to_local_64x4_layout_A
+            transform_func_sr_b = shared_16x16_to_local_64x4_layout_B
         elif k_dim == 32:
             transform_func_sr_a = shared_16x32_to_local_64x8_layout_A
-            transform_func_sr_b = shared_16x32_to_local_64x8_layout_A
+            transform_func_sr_b = shared_16x32_to_local_64x8_layout_B
         elif k_dim == 64:
             transform_func_sr_a = shared_16x64_to_local_64x16_layout_A
-            transform_func_sr_b = shared_16x64_to_local_64x16_layout_A
+            transform_func_sr_b = shared_16x64_to_local_64x16_layout_B
@@
-        else:
-            raise ValueError(f"Unsupported matrix {matrix}")
+        else:
+            raise ValueError(f"Unsupported matrix {matrix}")

Also applies to: 493-501

🤖 Prompt for AI Agents
In tilelang/intrinsics/mfma_macro_generator.py around lines 472 to 484 (and
similarly lines 493 to 501), the code assigns A-specific shared-to-local
transform functions to transform_func_sr_b which misplaces lanes/local IDs for
matrix B; replace the transform_func_sr_b assignments with the corresponding
B-specific transform functions for each k_dim case (e.g., use
shared_16x4_to_local_64x1_layout_B for k_dim==4,
shared_16x16_to_local_64x4_layout_B for k_dim==16,
shared_16x32_to_local_64x8_layout_B for k_dim==32, and
shared_16x64_to_local_64x16_layout_B for k_dim==64) so B uses the correct
lane/local ID mapping.

Comment on lines +15 to +31

    def infer_layout(self, target: Target, thread_nums: int):
        m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target,
                                                            False)
        warp_row_tiles = int(self.M // m_warp)
        warp_col_tiles = int(self.N // n_warp)
        mfma_emitter = MatrixCoreIntrinEmitter(
            a_dtype=self.in_dtype,
            b_dtype=self.in_dtype,
            accum_dtype=self.accum_dtype,
            a_transposed=self.trans_A,
            b_transposed=self.trans_B,
            block_row_warps=m_warp,
            block_col_warps=n_warp,
            warp_row_tiles=warp_row_tiles,
            warp_col_tiles=warp_col_tiles,
            chunk=self.chunk,
        )

⚠️ Potential issue | 🟠 Major

Guard warp tiling divisibility before running the emitter math.

Ensure the per-warp tile counts are multiples of the MFMA micro sizes to avoid silent floor division in the emitter.

     def infer_layout(self, target: Target, thread_nums: int):
         m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target,
                                                             False)
         warp_row_tiles = int(self.M // m_warp)
         warp_col_tiles = int(self.N // n_warp)
         mfma_emitter = MatrixCoreIntrinEmitter(
@@
             chunk=self.chunk,
         )
+        # Validate divisibility: warp tiles must be integer multiples of micro tile sizes
+        assert warp_row_tiles % mfma_emitter.micro_size_x == 0, (
+            f"warp_row_tiles ({warp_row_tiles}) not divisible by micro_size_x "
+            f"({mfma_emitter.micro_size_x})"
+        )
+        assert warp_col_tiles % mfma_emitter.micro_size_y == 0, (
+            f"warp_col_tiles ({warp_col_tiles}) not divisible by micro_size_y "
+            f"({mfma_emitter.micro_size_y})"
+        )
🤖 Prompt for AI Agents
In tilelang/tileop/gemm/gemm_mfma.py around lines 15 to 31, you compute
warp_row_tiles and warp_col_tiles then instantiate MatrixCoreIntrinEmitter
without verifying that those per-warp tile counts are multiples of the MFMA
micro-tile sizes; add a guard that fetches the MFMA micro-tile dimensions (e.g.,
from self.policy (mfma_micro_m, mfma_micro_n) or from MatrixCoreIntrinEmitter if
that API exists), check warp_row_tiles % mfma_micro_m == 0 and warp_col_tiles %
mfma_micro_n == 0, and if either check fails raise a clear ValueError explaining
which dimension is not divisible and what the required micro-tile size is,
before creating the emitter.

Comment on lines +186 to +201

            @T.prim_func
            def _gemm_rsr() -> None:
                """
                The inner macro that loads data from shared buffers A_shared and
                B_shared into local fragments, then issues Matrix Core mfma ops,
                accumulating into C_local.
                """

                for ki in T.serial(0, (block_K // micro_size_k)):
                    # Perform Matrix Multiplication
                    mfma_emitter.mfma(A_local, B_local, C_local, ki)

            # Simplify to optimize the index computing
            # Must inline let statements to simplify the analysis
            return _Simplify(_gemm_rsr, inline_let=True)
        else:

⚠️ Potential issue | 🟡 Minor

Rename duplicate inner func to reflect RR path.

Avoid two “_gemm_rsr” definitions; name RR kernel accordingly.

-            def _gemm_rsr() -> None:
+            def _gemm_rr() -> None:
@@
-            return _Simplify(_gemm_rsr, inline_let=True)
+            return _Simplify(_gemm_rr, inline_let=True)
🤖 Prompt for AI Agents
In tilelang/tileop/gemm/gemm_mfma.py around lines 186 to 201, there is a
duplicate inner function named "_gemm_rsr" used for the RR path; rename this
inner function to a distinct RR-specific name (e.g., "_gemm_rr") and update its
decorator/definition and the subsequent return call (change _Simplify(_gemm_rsr,
...) to _Simplify(_gemm_rr, ...)). Also scan and update any local references to
that inner function within this block so they point to the new name.

@LeiWang1999 (Member) commented:

@codex review

@chatgpt-codex-connector

Codex Review: Didn't find any major issues. What shall we delve into next?


LeiWang1999 merged commit 60567ba into tile-ai:main Oct 28, 2025
5 of 6 checks passed