
Conversation

@LeiWang1999 LeiWang1999 commented Sep 14, 2025

TODO Items

  • Introduce T.alloc_descriptor to create a static descriptor, allowing only start_address updates within tiled GEMM (a usage sketch follows this list).
  • Add support for additional data types: int8, fp8, tf32.
  • Implement gemm_rs.
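
As a rough companion to the first item above, here is a minimal usage sketch of the descriptor helpers. Only alloc_descriptor and its defaults, plus the arities of initialize_descriptor (five inputs) and increase_descriptor_offset (two inputs), are taken from this PR; the concrete arguments passed below are hypothetical placeholders, not the real signatures.

import tilelang.language as T

def descriptor_sketch():
    # alloc_descriptor is sugar for a 1-element uint64 buffer in the
    # dedicated "local.descriptor" scope (see allocate.py in this PR).
    desc = T.alloc_descriptor(dtype="uint64", scope="local.descriptor")

    # Fill the static fields once; the four trailing values are
    # hypothetical stand-ins for the real five-input signature.
    T.initialize_descriptor(desc, 0x0, 1, 16, 128)

    # Inside the tiled GEMM loop, only the start address moves.
    for _ko in range(4):
        T.increase_descriptor_offset(desc, 64)  # two-input builtin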

Summary by CodeRabbit

  • New Features

    • WGMMA Tensor Core support with new intrinsics, descriptor primitives, a descriptor allocation helper, and expanded PTX/type tooling.
    • New swizzled layout builders (WGMMA, full/half/quarter-bank, linear) plus layout/fragment equality and improved representations.
  • API Changes

    • GEMM layout/creation and lowering surfaces accept a boolean K-orientation flag and layout-map argument; GEMM backend selection is now dynamic and includes a WGMMA backend.
  • Chores

    • CI: pytest now clears cache before running.

- Added support for the WGMMA intrinsic in the TileLang framework, enabling efficient matrix multiplication on newer architectures.
- Refactored GEMM layout functions to accept a boolean parameter for K dimension handling, improving flexibility in layout generation.
- Updated layout inference logic to accommodate new WGMMA configurations and ensure compatibility with existing GEMM operations.
- Enhanced Python bindings for layout functions, allowing for better integration and usability in user-defined operations.
- Improved documentation for layout functions and GEMM operations to clarify usage and parameters.

These changes enhance the performance and usability of GEMM operations, particularly for advanced architectures, while maintaining backward compatibility with existing implementations.
…bility

- Improved code formatting across multiple files for better readability, including consistent indentation and line breaks.
- Updated layout function signatures to enhance clarity, particularly in `gemm_layouts.cc`, `layout.cc`, and `layout.h`.
- Refactored lambda functions in `builtin.cc` and `gemm_py.cc` for improved structure and maintainability.
- Enhanced comments and documentation in layout-related files to clarify usage and parameters.

These changes contribute to a cleaner codebase and improved maintainability of layout functions in the TileLang framework.
@github-actions

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run bash format.sh in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work!

🚀


coderabbitai bot commented Sep 14, 2025

Caution

Review failed

The pull request is closed.

Note

Other AI code review bot(s) detected

CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough

Replace integer K-factor parameters with boolean k_inner in GEMM layout APIs; add WGMMA support (intrinsics, descriptors, PTX/WGMMA codegen, TL templates, Python FFI/bindings); extend layout FFI/equality and swizzle helpers; add GemmWGMMA backend and runtime GEMM instruction selection; treat .descriptor scope specially in transforms.
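
To make the k-orientation change concrete on the Python surface, a hedged sketch (the k_major and allow_pad parameter names come from this PR; the defaults and the buffer shape are assumptions):

import tvm
from tilelang.layout import make_swizzled_layout

# K orientation is now a boolean flag instead of an integer kfactor;
# allow_pad selects between the padded and Hopper layout paths.
buf = tvm.tir.decl_buffer((128, 64), "float16", name="A_shared")
layout = make_swizzled_layout(buf, k_major=True, allow_pad=True)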

Changes

Cohort / File(s) Summary
GEMM layout API (kfactor → k_inner)
src/layout/gemm_layouts.cc, src/layout/layout.h, src/op/gemm.cc
Replace integer kfactor/kPack with boolean k_inner (defaults added); update branching, callers, and error messages across Volta/Hopper/SM100/64-bit paths.
Layout FFI & Python helpers
src/layout/layout.cc, tilelang/layout/swizzle.py, tilelang/layout/__init__.py, tilelang/layout/layout.py, tilelang/layout/fragment.py
Add FFI bindings tl.Layout_is_equal, tl.Fragment_is_equal; extend make_swizzled_layout (k_major, allow_pad); add make_wgmma_swizzled_layout, bank/linear helpers; expose Layout/Fragment equality, repr, and forward-index accessors.
New builtins / intrinsics
src/op/builtin.cc, src/op/builtin.h
Add ptx_wgmma_ss, ptx_wgmma_rs, initialize_descriptor, increase_descriptor_offset; mark some intrinsics TVM_DLL.
GemmPy dispatch & API surface
src/op/gemm_py.cc, src/op/gemm_py.h, tilelang/tileop/gemm/__init__.py, tilelang/tileop/gemm/gemm_mma.py, tilelang/tileop/gemm/gemm_base.py
Add FFI tl.GemmPyGemmInst; expand gemm_py.lower/GemmMMA.lower to accept layout_map; introduce GemmInst enum and runtime instruction selection via FFI; change clear_accum typing and add compiler-safety returns.
New GemmWGMMA backend & intrinsics emitter
tilelang/tileop/gemm/gemm_wgmma.py, tilelang/intrinsics/wgmma_macro_generator.py, tilelang/tileop/gemm/__init__.py, tilelang/tileop/gemm/gemm_mma.py
Add GemmWGMMA backend and TensorCoreIntrinEmitter with swizzle-aware WGMMA-like emission; infer_layout/lower accept/use layout_map; dispatch between MMA and WGMMA implementations.
CUDA/PTX codegen: WGMMA & descriptors
src/target/codegen_cuda.cc, src/target/ptx.cc, src/target/ptx.h
Implement emission for ptx_wgmma_ss/rs and descriptor intrinsics; add PTX DataType enum, WGMMA config validation, operand assembly helpers and PrintWGMMAAssembly.
CUDA TL templates & WGMMA header
src/tl_templates/cuda/common.h, src/tl_templates/cuda/gemm.h, src/tl_templates/cuda/instruction/wgmma.h
Add tl::DataType, GmmaDescriptor union and initialize/increase helpers; include wgmma.h in SM90+ path; implement WGMMA instruction specializations and dispatch.
Descriptor handling in transforms & storage rewrite
src/transform/lower_device_storage_access_info.cc, src/transform/storage_rewrite.cc
Treat .descriptor scope as non-special: skip memory-info population, exclude from special-tag merging and lowering adjustments.
TileLang language & TIR wrappers
tilelang/language/builtin.py, tilelang/language/allocate.py, tilelang/language/__init__.py, tilelang/language/ast/ir.py, tilelang/language/tir/ir.py, tilelang/language/tir/op.py
Export alloc_descriptor; add initialize_descriptor/increase_descriptor_offset wrappers; add PTX WGMMA wrappers ptx_wgmma_ss/rs to IR/TIR/op surfaces.
TileLang layout Python utilities & fragment API
tilelang/layout/fragment.py, tilelang/layout/layout.py
Add Fragment.is_equal, remove module-level make_swizzled_layout in fragment module; add Layout.is_equal, get_forward_index, and improved __repr__.
Gemm Python tests & examples
examples/deepseek_v32/*, examples/flash_attention/*
Add check_correctness flags, adjust SKV defaults and test call sites; parameterize several example mains to pass explicit runtime args.
Misc / toolchain & CI
.clang-tidy, tilelang/language/utils.py, .github/workflows/*.yml
Disable a clang-tidy check, remove a debug print, and add --cache-clear to pytest steps in CI workflows.

Sequence Diagram(s)

sequenceDiagram
  autonumber
  participant User
  participant GemmPy as GemmPy (tileop)
  participant FFI as _ffi_api.GemmPyGemmInst
  participant Impl as Impl (GemmMMA / GemmWGMMA)
  participant Lower as TIR Lowering

  User->>GemmPy: infer_layout(target, thread_nums)
  GemmPy->>FFI: GemmPyGemmInst(self, thread_nums, target)
  FFI-->>GemmPy: GemmInst (MMA/WGMMA/MFMA)
  GemmPy->>Impl: select implementation class
  GemmPy->>Impl: infer_layout(...)
  Impl-->>GemmPy: A/B/C layouts

  User->>GemmPy: lower(layout_map, target, thread_nums, thread_var)
  GemmPy->>Impl: lower(layout_map, target, thread_nums, thread_var)
  Impl->>Lower: build PrimFunc(s)
  Lower-->>User: Stmt / PrimFunc
sequenceDiagram
  autonumber
  participant TIR as TIR (tilelang.tir.op)
  participant Builtin as tl Builtins
  participant CGen as CUDA Codegen
  participant PTX as PTX Utils

  TIR->>Builtin: call tl.ptx_wgmma_ss/rs(...)
  Builtin-->>CGen: CallNode(op=ptx_wgmma_ss/rs, args)
  CGen->>PTX: PrintWGMMAAssembly(shape, layouts, dtypes, descs, offsets, ...)
  PTX-->>CGen: PTX assembly string
  CGen-->>TIR: Emit inline asm

Estimated code review effort

🎯 5 (Critical) | ⏱️ ~120 minutes

Possibly related PRs

Suggested labels

enhancement

Suggested reviewers

  • chengyupku

Poem

I nibble code beneath the moonlit sky,
K_inner hops, descriptors held high.
WGMMA hums in tiny rows,
Swizzles weave where shared memory flows.
A rabbit cheers — GEMM takes flight 🥕

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 0.00% which is insufficient. The required threshold is 80.00%. You can run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (2 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title Check ✅ Passed The title clearly describes the addition of WGMMA support for the T.gemm_v2 TileOp, matching the main objectives of the pull request. It is concise, specific, and immediately indicates the primary feature being implemented.

📜 Recent review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 4cdd131 and 8783aad.

📒 Files selected for processing (1)
  • testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py (3 hunks)


…GMMA

- Introduced new TileLang builtins `initialize_descriptor` and `increase_descriptor_offset` to facilitate descriptor management for WGMMA operations.
- Updated `builtin.cc` and `builtin.h` to define and document the new builtins, enhancing the framework's capabilities for descriptor handling.
- Modified `codegen_cuda.cc` and `ptx.cc` to integrate the new builtins into the code generation process, ensuring proper assembly generation for WGMMA operations.
- Enhanced the `GemmWGMMA` class to utilize the new descriptor functionalities, improving the efficiency of matrix multiplication operations.
- Updated related tests and documentation to reflect the new features and ensure comprehensive coverage.

These changes enhance the TileLang framework's support for advanced matrix operations on newer architectures, improving performance and usability.
- Enhanced code formatting across multiple files for better readability, including consistent indentation and line breaks.
- Updated function signatures and comments in `builtin.h`, `codegen_cuda.cc`, and `ptx.cc` to improve clarity.
- Refactored descriptor initialization and offset manipulation functions in `builtin.py` and `wgmma_macro_generator.py` for improved structure.
- Cleaned up unnecessary whitespace and improved alignment in `common.h` and `allocate.py`.

These changes contribute to a cleaner and more maintainable codebase in the TileLang framework.
- Updated the subproject commit for `cutlass` to indicate a dirty state.
- Refactored the `UpdateAnalyzer` function in `layout.cc` to call `LayoutNode::getVarMap()` instead of `getVarMap()`, improving clarity and ensuring proper context for variable mapping.

These changes enhance the maintainability and clarity of the layout handling in the TileLang framework.
@LeiWang1999 LeiWang1999 marked this pull request as ready for review September 16, 2025 08:12

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 6

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
tilelang/language/customize.py (1)

193-197: Bug: extent computation uses list max instead of elementwise max

max(src_extent, dst_extent) returns one of the lists lexicographically, not per-dimension. This can mis-size regions and corrupt memory ops.

Apply:

-    extent = max(src_extent, dst_extent)
+    # elementwise extent unification
+    extent = [T.max(a, b) for a, b in zip(src_extent, dst_extent)]
🧹 Nitpick comments (32)
tilelang/language/__init__.py (1)

45-46: Export looks good; clean up unused noqa.

  • Re-exporting alloc_descriptor is correct and matches allocate.py.
  • Ruff flags Line 45 for unused # noqa: F401 (RUF100). Drop it here (others weren’t flagged in this diff).

Apply:

-    alloc_descriptor,  # noqa: F401
+    alloc_descriptor,
tilelang/layout/layout.py (3)

92-94: Add docstring/return type for get_forward_index for parity with getters.

Minor consistency nit. Consider:

-    def get_forward_index(self):
-        return self.index
+    def get_forward_index(self) -> PrimExpr | list[PrimExpr]:
+        """Return the computed forward index expression(s)."""
+        return self.index

136-146: API parity: provide Pythonic equality too.

Keep is_equal, but also implement eq delegating to it for ergonomic comparisons; keep hash=None to avoid hashing mutable objects.

     def is_equal(self, other: "Layout") -> bool:
         """
         Check if the current layout is equal to another layout.
         """
         return _ffi_api.Layout_is_equal(self, other)

+    def __eq__(self, other: object) -> bool:
+        if not isinstance(other, Layout):
+            return NotImplemented
+        return self.is_equal(other)
+    __hash__ = None

147-148: repr: avoid huge dumps when vars/index grow.

Consider truncating sequences for readability in logs.

-        return f"Layout<{self.get_input_shape()}->{self.get_output_shape()}, {self.get_forward_vars()} -> {self.get_forward_index()}>"
+        fv = self.get_forward_vars()
+        fi = self.get_forward_index()
+        def _short(x): 
+            s = str(x)
+            return s if len(s) <= 120 else s[:117] + "..."
+        return f"Layout<{self.get_input_shape()}->{self.get_output_shape()}, {_short(fv)} -> {_short(fi)}>"
src/op/builtin.h (2)

164-185: Fix WGMMA doc comments to match actual RS/SS operand forms.

Current comment for ptx_wgmma_rs refers to A_descriptor; RS variant uses A_buf (regular pointer/buffer) per Python wrappers. Also B_offset is documented as Var (should be PrimExpr). Update comments to avoid API confusion.

- *  void ptx_wgmma_rs(StringImm accum_dtype, StringImm wgmma_prefix, bool
- * trans_a, bool trans_b, StringImm a_dtype_abbrv, StringImm b_dtype_abbrv,
- * StringImm accum_dtype_abbrv, Var A_descriptor, PrimExpr A_offset, Var
- * B_descriptor, Var B_offset, Var C_data, Var C_offset, bool scale_out, bool
+ *  void ptx_wgmma_rs(StringImm accum_dtype, StringImm wgmma_prefix, bool
+ * trans_a, bool trans_b, StringImm a_dtype_abbrv, StringImm b_dtype_abbrv,
+ * StringImm accum_dtype_abbrv, Var A_buf, PrimExpr A_offset, Var
+ * B_descriptor, PrimExpr B_offset, Var C_data, PrimExpr C_offset, bool scale_out, bool
  * scale_in_a, bool scale_in_b);

344-361: Descriptor intrinsic docs: correct operation wording.

increase_descriptor_offset increments the offset; the block comment says “setting the start address.” Align wording to avoid misuse.

- * \brief tilelang intrinsic for setting the start address of a descriptor
- * buffer for wgmma/utcmma.
+ * \brief tilelang intrinsic for increasing the offset of a descriptor
+ * buffer for wgmma/utcmma.
tilelang/language/tir/op.py (2)

1064-1104: Nit: Docstring says “wmma” but this is WGMMA; also consider briefly documenting operand order.

Purely cosmetic; helps future maintainers and avoids confusion with WMMA.

-    """TVM intrinsic for ptx tensor core wmma instructions
+    """TVM intrinsic for PTX warp-group MMA (WGMMA) instructions
+    Operand order: prefix, trans_a, trans_b, a/b/accum dtype abbrvs,
+    A_desc, A_offset, B_desc, B_offset, C_data, C_offset, scale_out, scale_in_a, scale_in_b.

1106-1144: Nit: Same WGMMA terminology/doc tweak as above; otherwise wrapper matches builtin.

LGTM functionally; mirrors 15-arg registration.

-    return call_intrin(
+    # PTX warp-group MMA (WGMMA) RS variant: A from register, B from descriptor.
+    return call_intrin(
tilelang/layout/__init__.py (1)

6-13: Remove unused “noqa: F401” directives or enable F401 in Ruff config.

Ruff flags these as unused (RUF100). Either drop them or configure Ruff to honor F401.

-from .swizzle import (
-    make_swizzled_layout,  # noqa: F401
-    make_wgmma_swizzled_layout,  # noqa: F401
-    make_full_bank_swizzled_layout,  # noqa: F401
-    make_half_bank_swizzled_layout,  # noqa: F401
-    make_quarter_bank_swizzled_layout,  # noqa: F401
-)
-from .gemm_sp import make_metadata_layout  # noqa: F401
+from .swizzle import (
+    make_swizzled_layout,
+    make_wgmma_swizzled_layout,
+    make_full_bank_swizzled_layout,
+    make_half_bank_swizzled_layout,
+    make_quarter_bank_swizzled_layout,
+)
+from .gemm_sp import make_metadata_layout
tilelang/tileop/gemm/gemm_mma.py (2)

60-77: Use layout_map (or underscore the arg) to avoid ARG002 and keep parity with WGMMA.

Mirror WGMMA: when provided, feed A/B shared layouts into the emitter; otherwise prefix arg as _layout_map to appease linters.

-    def lower(self, layout_map: dict, target: Target, thread_nums: int, thread_var: tir.Var):
+    def lower(self, layout_map: dict, target: Target, thread_nums: int, thread_var: tir.Var):
@@
         mma_emitter = TensorCoreIntrinEmitter(
@@
             thread_var=thread_var,
         )
+
+        # Optional: honor externally inferred layouts if present (parity with WGMMA)
+        if self.A in layout_map:
+            mma_emitter._assign_a_shared_layout(layout_map[self.A])
+        if self.B in layout_map:
+            mma_emitter._assign_b_shared_layout(layout_map[self.B])

90-91: Replace assert with explicit validation (and check divisibility).

Python asserts can be stripped with -O; raise a ValueError and also enforce block_K % micro_size_k == 0 to match loop step.

-        assert block_K >= micro_size_k, f"block_K ({block_K}) must be >= micro_size_k ({micro_size_k})"
+        if block_K < micro_size_k or (block_K % micro_size_k) != 0:
+            raise ValueError(
+                f"Invalid K tile: block_K={block_K}, micro_size_k={micro_size_k} "
+                "(must be >= and divisible)."
+            )
src/op/gemm.cc (1)

45-48: Documentation inconsistency between comment and implementation.

The documentation states that kPack must be 1, but the implementation at lines 71-73 allows values of both 1 and 2. This creates confusion about the actual requirements.

Either update the documentation to reflect that kPack can be 1 or 2, or enforce the restriction that it must be 1:

- * @note If `kPack` is provided it must be 1; otherwise the constructor
- *       fails with an ICHECK (runtime assertion). No other validation is
- *       performed here.
+ * @note If `kPack` is provided it must be 1 or 2; otherwise the constructor
+ *       fails with an ICHECK (runtime assertion). No other validation is
+ *       performed here.
tilelang/language/builtin.py (1)

375-375: Consider moving error messages to exception classes.

While not critical, defining error messages inside exception classes improves maintainability and reusability.

Consider creating custom exception classes:

class InvalidDescriptorTypeError(TypeError):
    """Raised when descriptor is not a Buffer or BufferLoad."""
    def __init__(self):
        super().__init__("Descriptor must be a tvm.tir.Buffer or tvm.tir.BufferLoad.")

class InvalidDescriptorShapeError(ValueError):
    """Raised when descriptor is not a 1D buffer of size 1."""
    def __init__(self):
        super().__init__("Descriptor must be a 1D buffer of size 1.")

Also applies to: 378-378, 401-401, 404-404

tilelang/layout/swizzle.py (2)

23-23: Use Optional type annotation for nullable parameter.

The continuity parameter can be None but isn't typed as Optional.

Add proper type annotation:

+from typing import Optional
+
 def make_wgmma_swizzled_layout(buffer: tvm.tir.Buffer,
-                               continuity: int = None,
+                               continuity: Optional[int] = None,
                                k_major: bool = True):

54-54: Consider more descriptive error messages.

The error messages could be more helpful by describing what arguments are expected.

Improve error messages:

-        raise ValueError(f"Invalid arguments: {args}")
+        raise ValueError(f"Expected either a single buffer or (stride, continuous, element_size), got {len(args)} arguments: {args}")

Also applies to: 79-79, 104-104

src/layout/gemm_layouts.cc (2)

573-574: Parameter name inconsistency with function documentation.

The parameter name k_inner doesn't match the documentation's description which refers to whether the "K dimension is in the inner loop". Consider renaming to k_is_inner or is_k_inner for better clarity.


532-541: Missing implementation of k_major parameter in Volta layout.

The function signature was updated to use bool k_inner but the implementation still uses k_inner directly as a boolean flag without considering the k-major semantics that the rest of the codebase expects.

Based on the pattern in makeGemmABLayout and makeGemmABLayoutHopper, you may want to verify that this implementation correctly handles the k-major/k-inner semantics.

tilelang/tileop/gemm/gemm_wgmma.py (1)

110-111: Verify RS implementation for A in fragment and B in shared.

The _gemm_rsr function name and docstring mention loading data from "shared buffers A_shared and B_shared", but the RS variant should have A in registers/fragments. The docstring appears to be copied from the SS variant.

Apply this diff to fix the docstring:

             @T.prim_func
             def _gemm_rsr() -> None:
                 """
-                The inner macro that loads data from shared buffers A_shared and
-                B_shared into local fragments, then issues Tensor Core mma ops,
-                accumulating into C_local.
+                The inner macro that uses data from local fragment A_local and
+                shared buffer B_shared, then issues Tensor Core mma ops,
+                accumulating into C_local.
                 """
                 mma_emitter.wgmma(A_local, B_shared, C_local, clear_accum)
tilelang/intrinsics/wgmma_macro_generator.py (3)

104-104: Remove unused parameter n_dim from method signature.

The method _initialize_wgmma_prefix has an unused parameter n_dim=16 that shadows the instance variable self.n_dim.

Apply this diff to fix:

-    def _initialize_wgmma_prefix(self, n_dim: int = 16):
+    def _initialize_wgmma_prefix(self):
         inst_m, inst_n = 64, self.block_col_warps * self.warp_col_tiles

133-143: Extract swizzle mode detection to centralized utility.

The _determinate_swizzle_mode method performs layout equality checks against multiple swizzle patterns. This logic could be refactored into a centralized layout utility to avoid duplication if similar detection is needed elsewhere.

Would you like me to help create a centralized swizzle mode detection utility that could be reused across the codebase?


345-353: Improve error message specificity for unsupported dtypes.

The error message for unsupported dtypes could be more informative by including the actual bit width.

Apply this diff:

         else:
-            raise ValueError(f"Unsupported dtype {dtype}")
+            raise ValueError(f"Unsupported dtype {dtype} with {dtype_bits} bits for MMA load layout")
src/layout/layout.cc (1)

495-505: Parameter naming inconsistency in swizzled layout creation.

The make_swizzled_layout function uses allow_pad to choose between makeGemmABLayout and makeGemmABLayoutHopper. The parameter name allow_pad doesn't clearly convey that it's selecting between different hardware layout strategies (standard vs Hopper).

Consider renaming to be more descriptive:

       .def("tl.make_swizzled_layout",
            [](int stride, int continuous, int element_size, bool k_inner,
-              bool allow_pad = true) {
-             if (allow_pad) {
+              bool use_hopper_layout = false) {
+             if (!use_hopper_layout) {
                return makeGemmABLayout(stride, continuous, continuous,
                                        element_size, k_inner);
              } else {
                return makeGemmABLayoutHopper(stride, continuous, continuous,
                                              element_size, k_inner);
              }
            })
tilelang/tileop/gemm/__init__.py (2)

82-98: Add parameter validation for thread_nums.

While the FFI call handles the selection logic, it would be good to validate that thread_nums is positive before passing it to the FFI.

 def _select_gemm_instruction(self, thread_nums: int, target: Target) -> GemmInst:
     """Select the appropriate GEMM instruction based on target and thread configuration.
     
     The selection logic follows this priority:
     1. WGMMA for Hopper architecture with sufficient matrix size and warp count
     2. MFMA for CDNA (AMD) architecture
     3. MMA for CUDA architecture
     4. Fallback to MMA for other cases
     
     Args:
         thread_nums: Number of threads in the block
         target: Target architecture
         
     Returns:
         GemmInst: The selected GEMM instruction type
     """
+    if thread_nums <= 0:
+        raise ValueError(f"thread_nums must be positive, got {thread_nums}")
     return GemmInst(_ffi_api.GemmPyGemmInst(self, int(thread_nums), target))

118-118: Consider using a more specific exception message.

The error message could be more informative by including what implementations are available.

-            raise NotImplementedError("MFMA is not implemented")
+            raise NotImplementedError("MFMA is not implemented. Available implementations: MMA, WGMMA")
src/target/ptx.cc (2)

1053-1168: Complex operand generation but missing a_is_shared validation for register path.

The function generates WGMMA operands correctly but doesn't validate that when a_is_shared is false, the operation is actually supported for register-based A operands.

Consider adding validation:

 inline std::tuple<std::string, std::string, std::string, std::string>
 GetWGMMAOperands(int m, int n, int k, ptx::DataType dtype_a,
                  ptx::DataType dtype_b, ptx::DataType dtype_c, bool sparse,
                  bool a_is_shared) {
+  // WGMMA with register-based A operand has limitations
+  if (!a_is_shared) {
+    // Add any specific validation for register-based A operands if needed
+    // based on NVIDIA documentation
+  }
   std::stringstream templates, inputs, outputs, predicate;

1263-1266: Consider extracting predicate setup as a constant.

The predicate setup code could be defined as a constant string for better maintainability.

+  constexpr const char* PREDICATE_SETUP = 
+      "{.reg .pred p;\n"
+      "setp.ne.b32 p, {predicate}, 0;\n";
+
   std::string asm_code = R"(
   {
     __asm__ __volatile__(
-      "{.reg .pred p;\n"
-      "setp.ne.b32 p, {predicate}, 0;\n"
+      ")" PREDICATE_SETUP R"(
       "wgmma.mma_async{.sparse}.sync.aligned{.shape}{.dtype}{.atype}{.btype}"
       "{templates};\n}"
       : {outputs}
       : {inputs});
   }
 )";
tilelang/language/customize.py (2)

160-179: Deduplicate get_extent; current helper misses BufferLoad and can diverge

You re-implement get_extent but omit BufferLoad handling (supported in tilelang/language/copy.py). Prefer reusing the canonical helper to avoid drift.

Apply:

@@
-    def get_extent(data):
-        """
-        Return the inferred extent (shape) of a buffer-like object.
-        ...
-        """
-        if isinstance(data, Var) and T.has_let_value(data):
-            data = T.get_let_value(data)
-        if isinstance(data, Buffer):
-            return data.shape
-        elif isinstance(data, BufferRegion):
-            return [x.extent for x in data.region]
-        else:
-            return None
-
-    src_extent = get_extent(value)
-    dst_extent = get_extent(dst)
+    src_extent = _get_extent(value)
+    dst_extent = _get_extent(dst)

Also add the import near the top of this file:

from tilelang.language.copy import get_extent as _get_extent

82-105: Honor provided extents in buffer_region_to_tile_region

The extents parameter is only asserted but ignored. This prevents aligning extents when mixing BufferRegion with BufferLoad.

Apply:

@@
-    return region(T.BufferLoad(buffer_region.buffer, mins), access_type, *region_extents)
+    # Override trailing dims with requested extents if provided
+    if extents:
+        region_extents = list(region_extents)
+        region_extents[-len(extents):] = extents
+    return region(T.BufferLoad(buffer_region.buffer, mins), access_type, *region_extents)
tilelang/language/allocate.py (1)

129-136: Add type hints for API clarity

Tighten signature and return type.

Apply:

-def alloc_descriptor(dtype="uint64", scope="local.descriptor"):
+def alloc_descriptor(dtype: str = "uint64", scope: str = "local.descriptor") -> T.Buffer:
     """Allocate a descriptor buffer for wgmma and utcmma.
 
     Returns:
         T.Buffer: A TVM buffer object allocated as a descriptor
     """
     return T.alloc_buffer([1], dtype, scope=scope)
src/layout/layout.h (3)

163-167: Prefer explicit enum over boolean for K placement

A boolean is easy to misuse. Consider a strongly-typed enum (e.g., enum class KPlacement { Inner, Outer };) to make call sites self-documenting and prevent silent int→bool coercions.

Example:

enum class KPlacement { Inner, Outer };
Layout makeGemmABLayout(int mat_stride, int mat_continuous, int continuity,
                        int element_size, KPlacement k_place = KPlacement::Inner);

168-169: CDNA kPack rename verified — code updated; fix lingering doc/comments

makeGemmABLayoutCDNA (declaration/definition) and callers use int kPack and Python bindings expose "kPack".

  • Update remaining comment references to the old name:
    • src/op/gemm_py.h — comment referencing "k_pack" (around line ~30).
    • src/op/gemm_sp.h — similar comment (around line ~25).

177-178: Volta layout: no functional change required — callers pass explicit k_inner; remove or flip default for clarity.
makeGemmVoltaABLayout(..., bool k_inner = true) is declared in src/layout/layout.h; call sites pass explicit values (src/op/gemm.cc:479, 492).

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between ae9b706 and eac5433.

📒 Files selected for processing (30)
  • src/layout/gemm_layouts.cc (4 hunks)
  • src/layout/layout.cc (4 hunks)
  • src/layout/layout.h (2 hunks)
  • src/op/builtin.cc (2 hunks)
  • src/op/builtin.h (2 hunks)
  • src/op/gemm.cc (7 hunks)
  • src/op/gemm_py.cc (4 hunks)
  • src/op/gemm_py.h (1 hunks)
  • src/target/codegen_cuda.cc (5 hunks)
  • src/target/ptx.cc (6 hunks)
  • src/target/ptx.h (1 hunks)
  • src/tl_templates/cuda/common.h (4 hunks)
  • src/transform/lower_device_storage_access_info.cc (1 hunks)
  • src/transform/storage_rewrite.cc (2 hunks)
  • tilelang/intrinsics/wgmma_macro_generator.py (1 hunks)
  • tilelang/language/__init__.py (4 hunks)
  • tilelang/language/allocate.py (1 hunks)
  • tilelang/language/ast/ir.py (2 hunks)
  • tilelang/language/builtin.py (2 hunks)
  • tilelang/language/customize.py (10 hunks)
  • tilelang/language/tir/ir.py (1 hunks)
  • tilelang/language/tir/op.py (1 hunks)
  • tilelang/layout/__init__.py (1 hunks)
  • tilelang/layout/fragment.py (1 hunks)
  • tilelang/layout/layout.py (2 hunks)
  • tilelang/layout/swizzle.py (1 hunks)
  • tilelang/tileop/gemm/__init__.py (3 hunks)
  • tilelang/tileop/gemm/gemm_base.py (2 hunks)
  • tilelang/tileop/gemm/gemm_mma.py (2 hunks)
  • tilelang/tileop/gemm/gemm_wgmma.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (23)
tilelang/language/tir/ir.py (2)
tilelang/language/tir/op.py (2)
  • ptx_wgmma_ss (1064-1103)
  • ptx_wgmma_rs (1106-1143)
tilelang/language/ast/ir.py (1)
  • _dtype_forward (1876-1884)
tilelang/language/allocate.py (2)
src/transform/storage_rewrite.cc (4)
  • dtype (696-702)
  • dtype (696-696)
  • scope (674-678)
  • scope (674-674)
tilelang/language/ast/ir.py (1)
  • alloc_buffer (441-508)
src/transform/lower_device_storage_access_info.cc (1)
src/transform/storage_rewrite.cc (2)
  • scope (674-678)
  • scope (674-674)
tilelang/layout/layout.py (1)
tilelang/layout/fragment.py (1)
  • is_equal (209-213)
tilelang/language/ast/ir.py (2)
tilelang/language/tir/op.py (2)
  • ptx_wgmma_ss (1064-1103)
  • ptx_wgmma_rs (1106-1143)
tilelang/language/tir/ir.py (1)
  • _dtype_forward (156-164)
tilelang/layout/swizzle.py (2)
tilelang/language/ast/ir.py (1)
  • buffer (93-161)
src/layout/swizzle.h (1)
  • tvm (12-70)
tilelang/tileop/gemm/gemm_wgmma.py (6)
tilelang/tileop/gemm/gemm_base.py (17)
  • GemmBase (12-120)
  • infer_layout (15-16)
  • policy (119-120)
  • M (34-35)
  • N (38-39)
  • in_dtype (54-56)
  • accum_dtype (59-60)
  • trans_A (46-47)
  • trans_B (50-51)
  • chunk (63-64)
  • is_gemm_ss (21-22)
  • K (42-43)
  • A (67-68)
  • B (71-72)
  • C (75-76)
  • lower (18-19)
  • clear_accum (107-108)
tilelang/layout/swizzle.py (1)
  • make_wgmma_swizzled_layout (22-34)
tilelang/intrinsics/wgmma_macro_generator.py (6)
  • TensorCoreIntrinEmitter (63-477)
  • make_mma_store_layout (423-477)
  • make_mma_load_layout (311-421)
  • _assign_a_shared_layout (96-98)
  • _assign_b_shared_layout (100-102)
  • wgmma (145-233)
tilelang/utils/language.py (2)
  • is_shared (25-39)
  • is_fragment (68-78)
tilelang/transform/simplify.py (1)
  • _Simplify (30-49)
tilelang/tileop/gemm/gemm_mma.py (3)
  • infer_layout (15-58)
  • is_gemm_ss (204-205)
  • lower (60-202)
tilelang/intrinsics/wgmma_macro_generator.py (6)
tilelang/utils/language.py (1)
  • is_fragment (68-78)
tilelang/layout/swizzle.py (3)
  • make_full_bank_swizzled_layout (39-59)
  • make_half_bank_swizzled_layout (64-84)
  • make_quarter_bank_swizzled_layout (89-109)
tilelang/layout/fragment.py (4)
  • is_equal (209-213)
  • Fragment (13-213)
  • replicate (146-160)
  • repeat (123-144)
tilelang/language/builtin.py (1)
  • initialize_descriptor (355-386)
tilelang/language/tir/op.py (2)
  • ptx_wgmma_ss (1064-1103)
  • ptx_wgmma_rs (1106-1143)
tilelang/intrinsics/mma_macro_generator.py (1)
  • get_store_index_map (159-165)
src/target/codegen_cuda.cc (4)
src/transform/storage_rewrite.cc (2)
  • scope (674-678)
  • scope (674-674)
tilelang/language/tir/op.py (2)
  • ptx_wgmma_ss (1064-1103)
  • ptx_wgmma_rs (1106-1143)
src/target/ptx.cc (2)
  • PrintWGMMAAssembly (1235-1306)
  • PrintWGMMAAssembly (1235-1244)
tilelang/language/builtin.py (2)
  • initialize_descriptor (355-386)
  • increase_descriptor_offset (389-411)
src/op/builtin.h (2)
tilelang/language/tir/op.py (2)
  • ptx_wgmma_ss (1064-1103)
  • ptx_wgmma_rs (1106-1143)
tilelang/language/builtin.py (2)
  • initialize_descriptor (355-386)
  • increase_descriptor_offset (389-411)
src/tl_templates/cuda/common.h (3)
src/tl_templates/cuda/copy_sm90.h (1)
  • void (255-258)
src/tl_templates/cuda/ldsm.h (12)
  • void (7-14)
  • void (16-23)
  • void (25-33)
  • void (35-42)
  • void (44-52)
  • void (54-62)
  • void (64-70)
  • void (72-79)
  • void (81-89)
  • void (91-98)
  • void (100-108)
  • void (110-119)
tilelang/language/builtin.py (2)
  • initialize_descriptor (355-386)
  • increase_descriptor_offset (389-411)
src/target/ptx.h (1)
src/target/ptx.cc (2)
  • PrintWGMMAAssembly (1235-1306)
  • PrintWGMMAAssembly (1235-1244)
tilelang/language/builtin.py (3)
src/op/builtin.h (1)
  • tvm (13-363)
tilelang/language/ast/ir.py (1)
  • evaluate (1319-1331)
tilelang/language/tir/op.py (1)
  • call_intrin (119-144)
src/op/gemm.cc (2)
tilelang/tileop/gemm/gemm_base.py (4)
  • trans_A (46-47)
  • trans_B (50-51)
  • A (67-68)
  • B (71-72)
src/layout/gemm_layouts.cc (2)
  • makeGemmABLayout (573-592)
  • makeGemmABLayout (573-574)
tilelang/language/customize.py (1)
tilelang/language/copy.py (1)
  • get_extent (105-118)
src/op/builtin.cc (2)
tilelang/language/tir/op.py (2)
  • ptx_wgmma_ss (1064-1103)
  • ptx_wgmma_rs (1106-1143)
tilelang/language/builtin.py (2)
  • initialize_descriptor (355-386)
  • increase_descriptor_offset (389-411)
tilelang/tileop/gemm/gemm_mma.py (3)
tilelang/tileop/gemm/gemm_wgmma.py (1)
  • lower (64-125)
tilelang/tileop/gemm/__init__.py (1)
  • lower (76-80)
tilelang/tileop/gemm/gemm_base.py (1)
  • lower (18-19)
tilelang/language/__init__.py (1)
tilelang/language/allocate.py (1)
  • alloc_descriptor (129-135)
tilelang/layout/__init__.py (2)
tilelang/layout/swizzle.py (5)
  • make_swizzled_layout (10-18)
  • make_wgmma_swizzled_layout (22-34)
  • make_full_bank_swizzled_layout (39-59)
  • make_half_bank_swizzled_layout (64-84)
  • make_quarter_bank_swizzled_layout (89-109)
tilelang/layout/gemm_sp.py (1)
  • make_metadata_layout (98-109)
tilelang/tileop/gemm/__init__.py (4)
tilelang/ir.py (1)
  • GemmWarpPolicy (30-39)
tilelang/tileop/gemm/gemm_mma.py (3)
  • GemmMMA (13-214)
  • lower (60-202)
  • infer_layout (15-58)
tilelang/tileop/gemm/gemm_wgmma.py (3)
  • GemmWGMMA (13-137)
  • lower (64-125)
  • infer_layout (15-62)
tilelang/tileop/gemm/gemm_base.py (2)
  • lower (18-19)
  • infer_layout (15-16)
tilelang/layout/fragment.py (1)
tilelang/layout/layout.py (3)
  • get_input_shape (59-68)
  • get_output_shape (70-79)
  • is_equal (136-145)
src/layout/layout.h (1)
src/layout/gemm_layouts.cc (4)
  • makeGemmABLayoutHopper (594-615)
  • makeGemmABLayoutHopper (594-595)
  • makeGemmABLayoutCDNA (617-625)
  • makeGemmABLayoutCDNA (617-618)
src/layout/layout.cc (1)
src/layout/gemm_layouts.cc (10)
  • makeGemmABLayout (573-592)
  • makeGemmABLayout (573-574)
  • makeGemmABLayoutHopper (594-615)
  • makeGemmABLayoutHopper (594-595)
  • makeFullBankSwizzleLayout (375-391)
  • makeFullBankSwizzleLayout (375-375)
  • makeHalfBankSwizzleLayout (356-372)
  • makeHalfBankSwizzleLayout (356-356)
  • makeQuarterBankSwizzleLayout (336-353)
  • makeQuarterBankSwizzleLayout (336-337)
🪛 Ruff (0.12.2)
tilelang/layout/swizzle.py

23-23: PEP 484 prohibits implicit Optional

Convert to Optional[T]

(RUF013)


54-54: Avoid specifying long messages outside the exception class

(TRY003)


79-79: Avoid specifying long messages outside the exception class

(TRY003)


104-104: Avoid specifying long messages outside the exception class

(TRY003)

tilelang/tileop/gemm/gemm_wgmma.py

61-62: Avoid specifying long messages outside the exception class

(TRY003)


124-125: Avoid specifying long messages outside the exception class

(TRY003)

tilelang/intrinsics/wgmma_macro_generator.py

104-104: Unused method argument: n_dim

(ARG002)


143-143: Avoid specifying long messages outside the exception class

(TRY003)


353-353: Avoid specifying long messages outside the exception class

(TRY003)

tilelang/language/builtin.py

375-375: Avoid specifying long messages outside the exception class

(TRY003)


377-377: Parenthesize a and b expressions when chaining and and or together, to make the precedence clear

Parenthesize the and subexpression

(RUF021)


378-378: Avoid specifying long messages outside the exception class

(TRY003)


401-401: Avoid specifying long messages outside the exception class

(TRY003)


403-403: Parenthesize a and b expressions when chaining and and or together, to make the precedence clear

Parenthesize the and subexpression

(RUF021)


404-404: Avoid specifying long messages outside the exception class

(TRY003)

tilelang/tileop/gemm/gemm_mma.py

60-60: Unused method argument: layout_map

(ARG002)

tilelang/language/__init__.py

45-45: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)

tilelang/layout/__init__.py

7-7: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


8-8: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


9-9: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


10-10: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


11-11: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


13-13: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)

tilelang/tileop/gemm/__init__.py

120-120: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Cursor Bugbot
🔇 Additional comments (43)
tilelang/tileop/gemm/gemm_base.py (2)

8-8: Add PrimExpr import for updated clear_accum type annotation.

The import is correctly added to support the type annotation change.


107-108: Update type annotation from bool to PrimExpr.

The change from bool to PrimExpr aligns with the WGMMA backend requirements where clear_accum can be an expression rather than just a boolean value. This provides more flexibility for the accumulation control logic.
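
A small sketch of what the wider typing enables (variable names are illustrative only):

from tvm import tir

# With clear_accum as a PrimExpr, the clear flag can depend on loop
# state instead of being a host-side Python constant.
ko = tir.Var("ko", "int32")
clear_accum = tir.EQ(ko, 0)  # clear only on the first K iteration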

src/op/gemm_py.h (1)

112-113: Clean private section reorganization.

The reorganization of private members with explicit private: labeling improves code organization and readability. The GemmInst enum and GetGemmInst method remain appropriately private while allowing FFI access through wrapper functions.

src/transform/lower_device_storage_access_info.cc (1)

47-48: Correct exclusion of descriptor scope from storage lowering.

The addition of scope.tag != ".descriptor" properly excludes descriptor buffers from the generic storage access lowering path, which is necessary for the WGMMA descriptor handling. This change aligns with the corresponding exclusion in storage_rewrite.cc and allows descriptor buffers to maintain their special semantics.

tilelang/language/tir/ir.py (1)

294-295: Add WGMMA PTX intrinsic wrappers.

The new ptx_wgmma_ss and ptx_wgmma_rs wrappers correctly follow the established pattern using _dtype_forward decorator, providing consistent API access to the underlying PTX WGMMA intrinsics for shared-shared and register-shared variants.

tilelang/layout/fragment.py (2)

207-207: Enhanced Fragment representation with shape information.

The updated __repr__ method now includes input and output shapes, providing more comprehensive debugging information. The use of get_input_shape() and get_output_shape() follows the established pattern from the base Layout class.


209-213: Add Fragment equality check method.

The new is_equal method provides a proper way to compare Fragment instances by delegating to the FFI implementation. This aligns with the corresponding method in the base Layout class and enables proper equality testing for Fragment objects.
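
A usage sketch of the equality hook (shown with the sibling Layout.is_equal and one of this PR's swizzled-layout builders; buffer shape and dtype are arbitrary):

import tvm
from tilelang.layout import make_full_bank_swizzled_layout

buf = tvm.tir.decl_buffer((64, 64), "float16", name="A_shared")
a = make_full_bank_swizzled_layout(buf)
b = make_full_bank_swizzled_layout(buf)
# Structural comparison via the FFI binding, not object identity.
assert a.is_equal(b)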

src/transform/storage_rewrite.cc (2)

677-678: Exclude .descriptor from “special tagged memory” — OK.

This prevents special-merge and memory‑info constraints from applying to descriptor buffers, which is desirable.


847-849: Skip memory‑info path for .descriptor in NewAllocTagMerged — OK; add test.

Change correctly treats descriptor allocations neutrally (GetMemoryInfo is not invoked for .descriptor). Add a minimal IR test that ensures:

  • local.descriptor buffers are not grouped into special merges
  • no GetMemoryInfo lookups occur for .descriptor scopes

Relevant files: src/transform/storage_rewrite.cc (≈lines 846–850), src/transform/lower_device_storage_access_info.cc (≈line 48).

tilelang/language/ast/ir.py (1)

1895-1896: Approve: WGMMA wrappers wired and registered

Wrapper bindings exist in tilelang/language/ast/ir.py (ptx_wgmma_ss/rs), TIR op wrappers call _tvm_op.Op.get("tl.ptx_wgmma_ss/rs") in tilelang/language/tir/op.py, builtins are registered in src/op/builtin.cc/.h (TIR_DEFINE_TL_BUILTIN + declarations) and referenced by codegen in src/target/codegen_cuda.cc — no issues found. Add a small CI check to assert builtin registration if desired.

src/op/builtin.cc (2)

174-183: Descriptor intrinsics API consistent with wrappers.

5 inputs for initialize and 2 for increase align with the Python layer; opaque effect is appropriate given side effects. LGTM.


83-92: WGMMA intrinsics: 15-arg arity verified; codegen enforces 15 operands. Python wrappers define ptx_wgmma_ss/rs and builtin registers .set_num_inputs(15); src/target/codegen_cuda.cc contains ICHECK_EQ(op->args.size(), 15U) for both handlers — operand alignment confirmed.

src/op/gemm_py.cc (3)

109-111: Unreachable-return safety: fine as-is.

This satisfies compilers post-ICHECK; harmless.


255-257: Unreachable-return safety: fine as-is.

Consistent with the ICHECK case; no functional impact.


284-290: FFI reflection for GemmInst selection: LGTM.

Simple, side‑effect free binding; useful for runtime dispatch.

src/op/gemm.cc (1)

481-481: LGTM! Consistent layout API migration to boolean flags.

The transition from numeric codes (1/2) to boolean flags for k_inner parameter across all backends (Volta, Ampere/Turing/SM120, Hopper) improves API clarity and type safety.

Also applies to: 507-507, 520-520, 545-545, 547-547, 563-563, 565-565

src/target/codegen_cuda.cc (5)

898-898: Including descriptor scope in pointer access logic.

The addition of "local.descriptor" scope alongside "local.var" is correct for treating descriptors as local objects.


1788-1789: LGTM! Proper allocation of descriptor storage.

The allocation of tl::GmmaDescriptor for descriptor scope is correctly implemented.


1305-1340: Check for additional error conditions in WGMMA intrinsics.

The WGMMA SS intrinsic implementation looks correct, but consider adding validation for the descriptor parameters.

Should we verify that the descriptors are properly initialized before use? Consider adding checks to ensure a_desc and b_desc are valid descriptor handles.


1701-1713: LGTM! Descriptor operations correctly implemented.

The implementation of initialize_descriptor and increase_descriptor_offset intrinsics is correct and properly forwards all parameters to the TL template functions.

Also applies to: 1714-1721


1823-1823: Correct handling of descriptor scope in allocation check.

The condition properly excludes local.descriptor from the unsupported scope error path.

tilelang/language/builtin.py (1)

355-386: LGTM! Well-documented descriptor initialization function.

The initialize_descriptor function is well-implemented with proper type checking, parameter documentation, and error handling.

tilelang/layout/swizzle.py (2)

10-18: LGTM! Well-structured swizzle layout functions.

The updated make_swizzled_layout and new make_wgmma_swizzled_layout functions are properly implemented with clear parameter forwarding to the FFI API.

Also applies to: 22-34


39-109: LGTM! Flexible bank-swizzled layout helpers.

The three bank-swizzled layout functions (full/half/quarter) are well-implemented with flexible argument handling that supports both buffer objects and explicit parameters.

src/tl_templates/cuda/common.h (1)

307-367: LGTM! Well-structured descriptor union implementation.

The GmmaDescriptor union is properly designed with:

  • Multiple access patterns via desc_, reg32_[], and reg16_[]
  • Clear bitfield layout matching the CUDA GMMA (WGMMA) descriptor format
  • Proper copy and move semantics
  • Convenient arithmetic operator for offset adjustments
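
For intuition, a back-of-the-envelope decoder of the 64-bit descriptor word, assuming CUTLASS-style field placement (14-bit fields on 16-bit boundaries, a 3-bit base offset at bit 49, a 2-bit layout type at bit 62); the exact bit positions are an assumption, not lifted from this PR's header:

def decode_gmma_descriptor(desc: int) -> dict:
    # Field positions assume the CUTLASS GmmaDescriptor layout.
    return {
        "start_address":       (desc >> 0) & 0x3FFF,
        "leading_byte_offset": (desc >> 16) & 0x3FFF,
        "stride_byte_offset":  (desc >> 32) & 0x3FFF,
        "base_offset":         (desc >> 49) & 0x7,
        "layout_type":         (desc >> 62) & 0x3,
    }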
src/layout/layout.cc (2)

461-478: LGTM! Equality check methods properly implemented.

The new Layout_is_equal and Fragment_is_equal FFI bindings correctly expose the underlying equality check functionality with proper node casting.


506-511: LGTM! WGMMA swizzled layout properly wired.

The make_wgmma_swizzled_layout correctly passes through the continuity parameter separately from mat_continuous, which is essential for WGMMA's layout requirements.

tilelang/tileop/gemm/gemm_wgmma.py (1)

36-37: Verify continuity calculation for k-major layouts

tilelang/tileop/gemm/gemm_wgmma.py:36-37,50 — current code: a_continuity = self.M if a_is_k_major else 4 * self.K // m_warp and b_continuity = self.N if b_is_k_major else 4 * self.K // n_warp. Confirm whether the k-major branch should intentionally use self.M/self.N (instead of a K-derived continuity) and whether the 4 * factor is correct; if intentional, add an inline comment explaining the rationale and add a unit test, otherwise correct the formula to derive continuity from K.

tilelang/intrinsics/wgmma_macro_generator.py (1)

220-231: Use 64-bit arithmetic for shared-memory offset calculations.

A_offset / B_offset in tilelang/intrinsics/wgmma_macro_generator.py (lines 220–231) perform multiplications like i * 64 * A_buf.shape[-1] that can overflow 32-bit; ensure intermediate arithmetic uses 64-bit (cast the PrimExprs to int64 or use an explicit int64 cast / bounds check) before multiplying by elems_in_bytes and passing to ptx_wgmma_ss.
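
One way to apply the suggested widening in the emitter (names mirror the comment above; the concrete values are placeholders):

from tvm import tir

i = tir.Var("i", "int32")
row_stride = tir.IntImm("int32", 4096)  # stands in for A_buf.shape[-1]
elems_in_bytes = 2  # fp16
# Cast before multiplying so every intermediate product is 64-bit.
a_offset = tir.Cast("int64", i) * 64 * tir.Cast("int64", row_stride) * elems_in_bytes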

tilelang/tileop/gemm/__init__.py (6)

1-1: LGTM! Addition of IntEnum import is appropriate for the new GemmInst class.

The import is necessary for creating the strongly-typed enumeration that will be used to represent different GEMM instruction types.


10-11: LGTM! Import additions align with the new WGMMA support.

The imports for GemmWGMMA and _ffi_api are necessary for the new architecture-aware dispatch mechanism.


21-24: Good addition of layout_map parameter to support WGMMA requirements.

The parameter addition correctly passes the layout information through to the lower-level implementation.


27-42: LGTM! Well-designed enumeration for GEMM instruction types.

The GemmInst enumeration provides a clean abstraction for different GEMM implementations with convenient helper methods for type checking.
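
A sketch of the shape described here; the member values are assumptions, and only the MMA/WGMMA/MFMA members and the helper-method style are taken from this review:

from enum import IntEnum

class GemmInst(IntEnum):
    MMA = 1
    WGMMA = 2
    MFMA = 3

    def is_mma(self) -> bool:
        return self is GemmInst.MMA

    def is_wgmma(self) -> bool:
        return self is GemmInst.WGMMA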


71-81: Architecture-aware dispatch implementation looks good.

The new infer_layout and lower methods properly delegate to architecture-specific implementations using a clean dispatch pattern.


100-120: LGTM! Clean dispatch with proper error handling.

The method provides a clear mapping from instruction types to implementation classes with appropriate error handling for unsupported types.

src/target/ptx.cc (7)

149-156: LGTM! Useful helper function for integer type checking.

This function provides a convenient way to check if a PTX data type represents an integer, which is needed for WGMMA scale input validation.


188-197: LGTM! Appropriate layout type conversion helper.

The function correctly maps boolean values to layout types for WGMMA assembly generation.


278-291: Well-structured WGMMA configuration struct.

The struct follows the same pattern as MMAConfig and provides proper equality comparison for configuration validation.


293-720: Comprehensive WGMMA configuration table.

The extensive configuration table covers all supported WGMMA operations including dense and sparse variants across multiple data types. This aligns well with the PR objective of supporting int8, fp8, and tf32 data types.


859-879: Solid configuration validation function.

The function properly validates WGMMA configurations by checking data type compatibility and searching for matching configurations in the valid list.


1119-1121: Fatal error for sparse WGMMA is reasonable.

Since sparse WGMMA is not yet implemented, failing fast with a clear error message is the appropriate approach.


1235-1306: Well-structured WGMMA assembly generation.

The function properly generates PTX assembly for WGMMA operations with support for various configurations including scale parameters and layout options. The template replacement approach is clean and maintainable.
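
The replacement-driven templating can be illustrated in a few lines of plain Python; the template fragment comes from the diff earlier in this review, and the fill-ins are hypothetical:

template = ("wgmma.mma_async{.sparse}.sync.aligned"
            "{.shape}{.dtype}{.atype}{.btype}")
fills = {"{.sparse}": "", "{.shape}": ".m64n128k16",
         "{.dtype}": ".f32", "{.atype}": ".f16", "{.btype}": ".f16"}
for key, value in fills.items():
    template = template.replace(key, value)
# -> wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16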

tilelang/language/allocate.py (1)

129-136: New descriptor allocator looks good

Minimal, aligned with ".descriptor" scope handling in codegen. LGTM.

Comment on lines +231 to 233
Downcast<PrimFunc>((*f)(GetRef<GemmPy>(this), T.layout_map, T.target,
T.thread_bounds, T.thread_var));
ICHECK(prim_func->attrs.defined());
🛠️ Refactor suggestion

🧩 Analysis chain

FFI lower signature now includes layout_map; ensure all providers updated.

The C++ call passes T.layout_map. Verify the Python FFI target “tl.gemm_py.lower” accepts (gemm_py, layout_map, target, thread_bounds, thread_var) everywhere to avoid runtime errors.


🏁 Script executed:

#!/bin/bash
# Confirm the new signature at Python side
rg -nP 'def\s+lower\((self|gemm_py)[^)]*layout_map' -C2

# Find the FFI registration or dispatcher for "tl.gemm_py.lower"
rg -n 'gemm_py\.lower' -S

# Check base class signatures to avoid LSP mismatches
rg -nP 'class\s+GemmBase\b.*?\n' -n tilelang/tileop/gemm -C2
rg -nP 'def\s+lower\(' tilelang/tileop/gemm -S

Length of output: 2364


Add layout_map to GemmBase.lower

Change tilelang/tileop/gemm/gemm_base.py:18 from
def lower(self, target, thread_nums, thread_var)
to
def lower(self, layout_map: dict, target: Target, thread_nums: int, thread_var: tir.Var)
so the base signature matches the FFI tl.gemm_py.lower and existing subclass implementations.

src/target/ptx.h Outdated
Comment on lines 68 to 86
/*!
* \brief Print WGMMA assembly string given parameters.
* \param shape The shape string mMnNkK
* \param A_layout The layout of multiplicand A, can be either "row" or "col".
* \param B_layout The layout of multiplicand B, can be either "row" or "col".
* \param A_dtype The data type of multiplicand A.
* \param B_dtype The data type of multiplicand B.
* \param C_dtype The data type of multiplicand C.
*/
std::string PrintWGMMAAssembly(
const std::string &shape, const bool &A_layout, const bool &B_layout,
const std::string &A_dtype, const std::string &B_dtype,
const std::string &C_dtype, const std::string &a_desc,
const std::string &A_offset, const std::string &b_desc,
const std::string &B_offset, const std::string &c_ptr,
const std::string &c_offset, const bool &scale_out, const bool &scale_in_a,
const bool &scale_in_b, const bool &a_is_shared,
const std::string &metadata, const std::string &metadata_offset,
const std::string &sparsity_selector, bool sparse);
⚠️ Potential issue

Documentation incomplete and misaligned parameters.

The documentation comment only describes the first 6 parameters, but the function signature has 17 parameters. This incomplete documentation makes it difficult to understand the purpose of the remaining parameters.

Please update the documentation to include all parameters:

 /*!
  * \brief Print WGMMA assembly string given parameters.
  * \param shape The shape string mMnNkK
  * \param A_layout The layout of multiplicand A, can be either "row" or "col".
  * \param B_layout The layout of multiplicand B, can be either "row" or "col".
  * \param A_dtype The data type of multiplicand A.
  * \param B_dtype The data type of multiplicand B.
  * \param C_dtype The data type of multiplicand C.
+ * \param a_desc Descriptor for operand A (shared memory descriptor or register pointer).
+ * \param A_offset Offset for operand A.
+ * \param b_desc Descriptor for operand B.
+ * \param B_offset Offset for operand B.
+ * \param c_ptr Pointer to accumulator C.
+ * \param c_offset Offset for accumulator C.
+ * \param scale_out Scaling output flag.
+ * \param scale_in_a Scaling input A flag.
+ * \param scale_in_b Scaling input B flag.
+ * \param a_is_shared Whether operand A is in shared memory.
+ * \param metadata Pointer to metadata buffer (for sparse operations).
+ * \param metadata_offset Offset in metadata buffer.
+ * \param sparsity_selector Sparsity selector for sparse operations.
+ * \param sparse Whether this is a sparse WGMMA operation.
  */
🤖 Prompt for AI Agents
In src/target/ptx.h around lines 68 to 86, the function comment documents only
the first six parameters while the declaration has 17; update the block comment
to include @param entries for every parameter in the signature (A_layout,
B_layout — explain what true/false means or that they represent "row"/"col";
a_desc and b_desc — describe descriptor strings for A/B; A_offset and B_offset —
offset expressions or registers for A/B; c_ptr and c_offset — pointer and offset
used for C output; scale_out, scale_in_a, scale_in_b — boolean flags controlling
scaling behavior; a_is_shared — whether A is in shared memory; metadata and
metadata_offset — metadata pointer/string and its offset; sparsity_selector —
sparsity control selector; sparse — whether sparse mode is enabled), keeping
each description concise and aligned with the existing doc style.

src/target/ptx.h Outdated
Comment on lines 78 to 79
const std::string &shape, const bool &A_layout, const bool &B_layout,
const std::string &A_dtype, const std::string &B_dtype,

🛠️ Refactor suggestion

Inconsistent use of const references for boolean parameters.

The function signature uses const bool & for boolean parameters, which is unnecessary overhead for primitive types. This differs from the existing PrintMMAAssembly function that passes booleans by value.

Apply this diff to pass booleans by value:

-std::string PrintWGMMAAssembly(
-    const std::string &shape, const bool &A_layout, const bool &B_layout,
+std::string PrintWGMMAAssembly(
+    const std::string &shape, bool A_layout, bool B_layout,
     const std::string &A_dtype, const std::string &B_dtype,
     const std::string &C_dtype, const std::string &a_desc,
     const std::string &A_offset, const std::string &b_desc,
     const std::string &B_offset, const std::string &c_ptr,
-    const std::string &c_offset, const bool &scale_out, const bool &scale_in_a,
-    const bool &scale_in_b, const bool &a_is_shared,
+    const std::string &c_offset, bool scale_out, bool scale_in_a,
+    bool scale_in_b, bool a_is_shared,
     const std::string &metadata, const std::string &metadata_offset,
     const std::string &sparsity_selector, bool sparse);
🤖 Prompt for AI Agents
In src/target/ptx.h around lines 78-79, the function signature uses `const bool
&A_layout` and `const bool &B_layout` which is unnecessary for primitive types;
change both parameters to pass by value (`bool A_layout, bool B_layout`) and
update any matching function declarations/definitions and callers to use value
semantics so signatures remain consistent with PrintMMAAssembly.

Comment on lines +405 to +421
template <int layout_type = 0, int leading_byte_offset = 0,
int stride_byte_offset = 0, typename T>
TL_DEVICE void initialize_descriptor(GmmaDescriptor &descriptor,
T *start_address) {
descriptor.bitfield.start_address_ =
cute::cast_smem_ptr_to_uint(start_address) >> 4;
descriptor.bitfield.layout_type_ = layout_type;
descriptor.bitfield.base_offset_ = 0;
descriptor.bitfield.leading_byte_offset_ = leading_byte_offset;
descriptor.bitfield.stride_byte_offset_ = stride_byte_offset;
}

template <typename T>
TL_DEVICE void increase_descriptor_offset(GmmaDescriptor &descriptor,
T offset) {
descriptor.reg32_[0] += (offset >> 4);
}

🛠️ Refactor suggestion

🧩 Analysis chain

Consider validating descriptor parameters.

The initialize_descriptor and increase_descriptor_offset functions directly manipulate hardware descriptor fields without validation. Consider adding assertions or checks for:

  • Valid layout_type values (0-3)
  • Byte offset alignment requirements
  • Address alignment to 16-byte boundaries

🏁 Script executed:

#!/bin/bash
# Check if there are any existing validation patterns for descriptor initialization
ast-grep --pattern 'initialize_descriptor($$$) {
  $$$
  assert($$$)
  $$$
}'

# Look for any documentation or comments about descriptor constraints
rg -A 3 -B 3 "descriptor.*align|align.*descriptor" --type cpp --type h

Length of output: 333


Validate GMMA descriptor parameters before writing hardware fields

These functions write descriptor bitfields with no checks; src/target/cuda.h documents a 64‑byte alignment requirement for tensor map descriptors. Add assertions/guards in src/tl_templates/cuda/common.h (initialize_descriptor and increase_descriptor_offset, ~lines 405–421):

  • Assert layout_type is in the supported range (0–3).
  • Assert start_address is 16‑byte aligned (e.g., (uintptr_t)start_address & 0xF == 0).
  • Assert leading_byte_offset, stride_byte_offset, and the offset passed to increase_descriptor_offset are multiples of 16 (since code uses >> 4) and fit the descriptor bitfield widths; also guard against reg32_[0] overflow when adding (offset >> 4).

Files to edit: src/tl_templates/cuda/common.h (initialize_descriptor, increase_descriptor_offset). Reference: src/target/cuda.h (tensor map descriptor alignment comment).

🤖 Prompt for AI Agents
In src/tl_templates/cuda/common.h around lines 405 to 421, add validation before
writing descriptor fields: assert layout_type is within 0..3; assert
start_address is 16-byte aligned ((uintptr_t)start_address & 0xF == 0); assert
leading_byte_offset and stride_byte_offset are multiples of 16 and within the
bitfield width limits documented for GmmaDescriptor; in
increase_descriptor_offset assert offset is a multiple of 16, compute delta =
offset >> 4 and check that descriptor.reg32_[0] + delta will not overflow the
destination bitfield (or reg32_) before applying it; if any check fails, either
trigger a fast-fail assertion or return/handle the error path so invalid values
are never written to hardware fields.
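For concreteness, a minimal sketch of the proposed checks (the 14-bit field width and 16-byte granularity come from the review above; the helper name is illustrative):

def encode_smem_field(byte_value: int) -> int:
    # Sketch of the checks proposed above: 16-byte aligned input, value
    # stored in 16-byte units, result must fit a 14-bit bitfield.
    assert byte_value % 16 == 0, "must be 16-byte aligned"
    quanta = byte_value >> 4
    assert quanta < (1 << 14), "must fit the 14-bit start_address_ field"
    return quanta

assert encode_smem_field(0x480) == 0x48  # 1152 bytes -> 72 sixteen-byte units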

    if not isinstance(descriptor, (BufferLoad, Buffer)):
        raise TypeError("Descriptor must be a tvm.tir.Buffer or tvm.tir.BufferLoad.")

    if isinstance(descriptor, Buffer) and len(descriptor.shape) != 1 or descriptor.shape[0] != 1:

⚠️ Potential issue

Fix operator precedence ambiguity.

The condition mixes "and" and "or" operators without proper parentheses, which can lead to unexpected behavior.

Apply this diff to fix the operator precedence:

-    if isinstance(descriptor, Buffer) and len(descriptor.shape) != 1 or descriptor.shape[0] != 1:
+    if isinstance(descriptor, Buffer) and (len(descriptor.shape) != 1 or descriptor.shape[0] != 1):

The same issue exists in increase_descriptor_offset:

-    if isinstance(descriptor, Buffer) and len(descriptor.shape) != 1 or descriptor.shape[0] != 1:
+    if isinstance(descriptor, Buffer) and (len(descriptor.shape) != 1 or descriptor.shape[0] != 1):

Also applies to: 403-403

🧰 Tools
🪛 Ruff (0.12.2)

377-377: Parenthesize "a and b" expressions when chaining "and" and "or" together, to make the precedence clear

Parenthesize the "and" subexpression

(RUF021)

🤖 Prompt for AI Agents
In tilelang/language/builtin.py around lines 377 and 403, the conditional mixes
"and" and "or" without parentheses causing ambiguous precedence; wrap the shape
checks in parentheses so the intent is clear — e.g. make the condition check
"isinstance(descriptor, Buffer) and (len(descriptor.shape) != 1 or
descriptor.shape[0] != 1)" (and apply the same parenthesized grouping in
increase_descriptor_offset at line 403) to ensure the logical grouping is
correct.
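
A quick demonstration of the precedence trap: "A and B or C" parses as "(A and B) or C", so the or-operand is still evaluated even when the isinstance check fails.

A, B, C = False, True, True
assert (A and B or C) is True       # parses as (A and B) or C
assert (A and (B or C)) is False    # the intended grouping short-circuits on A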


@coderabbitai bot left a comment

Actionable comments posted: 6

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
src/target/ptx.h (1)

29-31: Missing headers for Replacer (compile blocker).

Replacer uses std::vector and std::pair, but <vector> (and ideally <utility>) aren't included.

Apply:

 #include <string>
 #include <tuple>
+#include <vector>
+#include <utility>
src/target/ptx.cc (1)

26-31: Missing <sstream> (compile blocker).

std::stringstream is used but the <sstream> header isn't included.

Apply:

 #include <algorithm>
+#include <sstream>
 #include <string>
 #include <tuple>
 #include <utility>
 #include <vector>
♻️ Duplicate comments (5)
src/target/ptx.h (2)

156-164: WGMMA Doxygen is incomplete and mismatched with the signature.

Only the first 6 params are documented; names refer to A/B_layout strings while the API uses a_is_k_major/b_is_k_major and many more args.

Recommend adding @param docs for all parameters and aligning names/semantics with the signature (a_is_k_major, b_is_k_major, a_desc, A_offset, b_desc, B_offset, c_ptr, c_offset, scale_out, scale_in_a, scale_in_b, a_is_shared, metadata, metadata_offset, sparsity_selector, sparse). Want a ready-to-apply doc patch?


165-174: Pass booleans by value, not const references.

Primitive bools shouldn’t be passed by const&. Keep consistent with other APIs.

Apply:

-std::string PrintWGMMAAssembly(
-    const std::string &shape, const bool &a_is_k_major, const bool &b_is_k_major,
+std::string PrintWGMMAAssembly(
+    const std::string &shape, bool a_is_k_major, bool b_is_k_major,
     const std::string &A_dtype, const std::string &B_dtype,
     const std::string &C_dtype, const std::string &a_desc,
     const std::string &A_offset, const std::string &b_desc,
     const std::string &B_offset, const std::string &c_ptr,
-    const std::string &c_offset, const bool &scale_out, const bool &scale_in_a,
-    const bool &scale_in_b, const bool &a_is_shared,
+    const std::string &c_offset, bool scale_out, bool scale_in_a,
+    bool scale_in_b, bool a_is_shared,
     const std::string &metadata, const std::string &metadata_offset,
     const std::string &sparsity_selector, bool sparse);
src/tl_templates/cuda/common.h (3)

438-454: Add basic parameter validation (layout_type range).

Follow-up to earlier review on descriptor validation.

Add: assert(0 <= layout_type && layout_type <= 3).


450-454: increase_descriptor_offset risks overflow/carry into other fields.

Directly adding to reg32_[0] can spill into adjacent fields; enforce 16B granularity and update the bitfield safely.

Apply:

 TL_DEVICE void increase_descriptor_offset(GmmaDescriptor &descriptor,
                                           T offset) {
-  descriptor.reg32_[0] += (offset >> 4);
+  assert((offset % 16) == 0);
+  uint32_t delta = uint32_t(offset >> 4);
+  uint32_t new_sa = uint32_t(descriptor.bitfield.start_address_) + delta;
+  // 14-bit field
+  assert((new_sa & ~0x3FFFu) == 0);
+  descriptor.bitfield.start_address_ = uint16_t(new_sa & 0x3FFFu);
 }

438-448: initialize_descriptor: missing >>4 for byte-offset fields; no validation.

Bitfields exclude 4 LSBs; current code writes raw byte offsets. Also lacks alignment/range checks.

Apply:

 TL_DEVICE void initialize_descriptor(GmmaDescriptor &descriptor,
                                      T *start_address) {
-  descriptor.bitfield.start_address_ =
-      cute::cast_smem_ptr_to_uint(start_address) >> 4;
+  auto sa = cute::cast_smem_ptr_to_uint(start_address);
+  // 16B alignment and 14-bit range
+  assert((sa & 0xF) == 0);
+  descriptor.bitfield.start_address_ = uint16_t((sa >> 4) & 0x3FFF);
   descriptor.bitfield.layout_type_ = layout_type;
   descriptor.bitfield.base_offset_ = 0;
-  descriptor.bitfield.leading_byte_offset_ = leading_byte_offset;
-  descriptor.bitfield.stride_byte_offset_ = stride_byte_offset;
+  assert((leading_byte_offset % 16) == 0);
+  assert((stride_byte_offset % 16) == 0);
+  descriptor.bitfield.leading_byte_offset_ =
+      uint16_t(((leading_byte_offset >> 4) & 0x3FFF));
+  descriptor.bitfield.stride_byte_offset_ =
+      uint16_t(((stride_byte_offset >> 4) & 0x3FFF));
 }

Also assert layout_type in [0,3].

🧹 Nitpick comments (14)
tilelang/language/tir/op.py (1)

1106-1144: Add a docstring for ptx_wgmma_rs; clarify the variant.

Mirror the ss variant doc and note A is a buffer (not a descriptor).

Apply this diff:

 def ptx_wgmma_rs(
     dtype,
     wgmma_prefix,
     a_is_k_major,
     b_is_k_major,
     a_dtype_abbrv,
     b_dtype_abbrv,
     accum_dtype_abbrv,
     A_buf,
     A_offset,
     B_desc,
     B_offset,
     C_data,
     C_offset,
     scale_out,
     scale_in_a,
     scale_in_b,
 ):
-
-
+    """PTX WGMMA (warp-group MMA) intrinsic wrapper.
+
+    Variant: rs (A uses buffer pointer, B uses descriptor).
+    See NVIDIA PTX ISA — Warpgroup matrix instructions (WGMMA).
+    https://docs.nvidia.com/cuda/parallel-thread-execution/index.html
+    """

Optional: consider renaming A_buf to A_data or A_ptr for consistency with C_data (only if no keyword callers rely on A_buf).

src/target/ptx.h (1)

45-69: Two DataType enums exist (ptx::DataType here and tl::DataType in common.h).

Risk of drift and conversion friction.

Add explicit conversion helpers and static_assert ordinal equivalence in one place, or consolidate to a single enum declared in a shared header.

src/tl_templates/cuda/common.h (1)

18-19: Redundant SMEM pointer casters.

Both cute::cast_smem_ptr_to_uint and local helpers exist (cast_smem_ptr_to_int/smem_ptr_to_uint).

Pick one idiom and remove duplicates to avoid confusion.

src/tl_templates/cuda/instruction/wgmma.h (1)

22-35: Don’t rely on device printf; fail fast when unspecialized.

Keep printf only under a debug guard; restore a compile-time error otherwise.

Apply:

-    ) {
-        printf("DEBUG: WgmmaSSImpl fallback - A_type=%d (kFloat16=%d), B_type=%d, C_type=%d, M=%d, N=%d, K=%d, tnspA=%d, tnspB=%d, scaleA=%d, scaleB=%d\n",
-               (int)A_type, (int)DataType::kFloat16, (int)B_type, (int)C_type, M, N, K, (int)tnspA, (int)tnspB, scaleA, scaleB);
-        // 暂时注释掉 static_assert 来看调试输出
-        // static_assert(always_false_v<decltype(c)>,
-        //     "wgmma_ss: No specialization available for given template parameters!");
-    };
+    ) {
+#if defined(TL_DEBUG_WGMMA_FALLBACK)
+        printf("WGMMA fallback A=%d B=%d C=%d M=%d N=%d K=%d tnspA=%d tnspB=%d scaleA=%d scaleB=%d\n",
+               (int)A_type, (int)B_type, (int)C_type, M, N, K, (int)tnspA, (int)tnspB, scaleA, scaleB);
+#else
+        static_assert(always_false_v<decltype(c)>,
+                      "wgmma_ss: No specialization available for given template parameters");
+#endif
+    };
src/target/ptx.cc (2)

171-177: Pass bool by value in LayoutTypeFromBool.

Avoid const& for primitives.

Apply:

-LayoutType LayoutTypeFromBool(const bool &layout) {
+LayoutType LayoutTypeFromBool(bool layout) {

1186-1195: Unnecessary const& on bools in PrintWGMMAAssembly.

Match the header change; pass by value.

Apply:

-std::string PrintWGMMAAssembly(
-    const std::string &shape, const bool &a_is_k_major, const bool &b_is_k_major,
+std::string PrintWGMMAAssembly(
+    const std::string &shape, bool a_is_k_major, bool b_is_k_major,
     const std::string &A_dtype, const std::string &B_dtype,
     const std::string &C_dtype, const std::string &a_desc,
     const std::string &A_offset, const std::string &b_desc,
     const std::string &B_offset, const std::string &c_ptr,
-    const std::string &c_offset, const bool &scale_out, const bool &scale_in_a,
-    const bool &scale_in_b, const bool &a_is_shared,
+    const std::string &c_offset, bool scale_out, bool scale_in_a,
+    bool scale_in_b, bool a_is_shared,
     const std::string &metadata, const std::string &metadata_offset,
     const std::string &sparsity_selector, bool sparse) {
src/target/codegen_cuda.cc (3)

1306-1317: Fix argument docs for tl::ptx_wgmma_ss branch

The comments don’t match the actual args (dtype is the call’s return type, not in op->args). Please correct to avoid future regressions.

-    // arg 0: dtype
-    // arg 1: shape
-    // arg 2: A_layout
-    // arg 3: B_layout
-    // arg 4: A_dtype
-    // arg 5: B_dtype
-    // arg 6: C_dtype
-    // arg 7: multiplicand_a
-    // arg 8: multiplicand_b
-    // arg 9: accumulator
-    // arg 10: saturate
+    // arg 0: wgmma_prefix (shape string, e.g. "m64n128k32")
+    // arg 1: a_is_k_major (bool)
+    // arg 2: b_is_k_major (bool)
+    // arg 3: a_dtype_abbrv (string)
+    // arg 4: b_dtype_abbrv (string)
+    // arg 5: accum_dtype_abbrv (string)
+    // arg 6: A_desc (descriptor, shared path)
+    // arg 7: A_offset (byte offset)
+    // arg 8: B_desc (descriptor)
+    // arg 9: B_offset (byte offset)
+    // arg 10: C_data (ptr)
+    // arg 11: C_offset (byte offset)
+    // arg 12: scale_out (bool)
+    // arg 13: scale_in_a (bool)
+    // arg 14: scale_in_b (bool)

1364-1375: Align RS branch docs with actual args

Mirror the corrected SS docs for RS; current comments mention “dtype/saturate”, which don’t exist here.

-    // arg 0: dtype
-    // arg 1: shape
-    // arg 2: A_layout
-    // arg 3: B_layout
-    // arg 4: A_dtype
-    // arg 5: B_dtype
-    // arg 6: C_dtype
-    // arg 7: multiplicand_a
-    // arg 8: multiplicand_b
-    // arg 9: accumulator
-    // arg 10: saturate
+    // arg 0: wgmma_prefix (shape)
+    // arg 1: a_is_k_major (bool)
+    // arg 2: b_is_k_major (bool)
+    // arg 3: a_dtype_abbrv (string)
+    // arg 4: b_dtype_abbrv (string)
+    // arg 5: accum_dtype_abbrv (string)
+    // arg 6: A_buf (global pointer, non-shared path)
+    // arg 7: A_offset (byte offset)
+    // arg 8: B_desc (descriptor)
+    // arg 9: B_offset (byte offset)
+    // arg 10: C_data (ptr)
+    // arg 11: C_offset (byte offset)
+    // arg 12: scale_out (bool)
+    // arg 13: scale_in_a (bool)
+    // arg 14: scale_in_b (bool)

1724-1743: Descriptor intrinsics emission — consider explicit offset type and statement-only usage

  • Template parameter for increase_descriptor_offset should be explicit-width to avoid ABI surprises on different toolchains.
  • These are side-effect calls; ensure they’re only used in EvaluateNode contexts.
-    os << "tl::increase_descriptor_offset<int>(" << PrintExpr(descriptor)
+    os << "tl::increase_descriptor_offset<int32_t>(" << PrintExpr(descriptor)
        << ", " << PrintExpr(offset) << ")";

Would you like me to scan call sites to confirm they’re only used as statements?

tilelang/intrinsics/wgmma_macro_generator.py (5)

216-219: Use read access for descriptor base pointers.

These descriptors are read by the instruction; use access_ptr("r"), not "w". Covered in the earlier diffs.

Also applies to: 283-283


104-109: Remove unused n_dim parameter.

_initialize_wgmma_prefix(self, n_dim: int = 16) doesn’t use n_dim (Ruff ARG002). Simplify:

-    def _initialize_wgmma_prefix(self, n_dim: int = 16):
+    def _initialize_wgmma_prefix(self):
@@
-        self._initialize_wgmma_prefix(self.n_dim)
+        self._initialize_wgmma_prefix()

Also applies to: 94-95


333-333: Drop duplicate import of is_fragment.

Already imported at the top; remove the inner import.

-        from tilelang.utils import is_fragment

312-317: Fix docstring: this is a load layout for operand A, not a store layout.

-        Create a layout function for storing MMA results into a fragment buffer.
-        This layout is used in conjunction with `inverse_mma_store_layout` to
-        map fragment indices to threads and local indices.
+        Create a layout describing how to load MMA operand A from a fragment buffer.
+        This layout is used in conjunction with `inverse_mma_load_layout` to
+        map fragment indices to threads and per-thread local indices.

41-59: Clarify swizzle_atom_size() semantics or compute from bytes to avoid confusion.

If you keep this helper, consider defining it as swizzle_byte_size() // 16 (16‑byte atoms), to match descriptor-unit terminology; the current literal // 16 applied to the raw 32/64/128 constants is non‑obvious. Usage has been corrected in earlier diffs, but this improves readability:

     def swizzle_atom_size(self) -> int:
-        if self.is_swizzle_32b():
-            return 32 // 16
-        elif self.is_swizzle_64b():
-            return 64 // 16
-        elif self.is_swizzle_128b():
-            return 128 // 16
-        else:
-            return 1
+        # number of 16-byte atoms in the swizzle size (32B→2, 64B→4, 128B→8)
+        return self.swizzle_byte_size() // 16
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between eac5433 and 51fcf15.

📒 Files selected for processing (9)
  • src/op/builtin.h (2 hunks)
  • src/target/codegen_cuda.cc (5 hunks)
  • src/target/ptx.cc (9 hunks)
  • src/target/ptx.h (2 hunks)
  • src/tl_templates/cuda/common.h (5 hunks)
  • src/tl_templates/cuda/gemm.h (1 hunks)
  • src/tl_templates/cuda/instruction/wgmma.h (1 hunks)
  • tilelang/intrinsics/wgmma_macro_generator.py (1 hunks)
  • tilelang/language/tir/op.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • src/op/builtin.h
🧰 Additional context used
🧬 Code graph analysis (5)
src/tl_templates/cuda/instruction/wgmma.h (2)
src/op/builtin.h (1)
  • tl (22-362)
src/tl_templates/cuda/common.h (5)
  • tl (306-400)
  • DataType (315-360)
  • int (97-100)
  • int (135-142)
  • uint32_t (118-120)
src/target/codegen_cuda.cc (4)
src/transform/storage_rewrite.cc (6)
  • scope (674-678)
  • scope (674-674)
  • n (366-370)
  • n (366-366)
  • n (371-375)
  • n (371-371)
tilelang/language/tir/op.py (2)
  • ptx_wgmma_ss (1064-1103)
  • ptx_wgmma_rs (1106-1143)
src/target/ptx.cc (8)
  • PrintWGMMAAssembly (1186-1257)
  • PrintWGMMAAssembly (1186-1195)
  • ParseMMAShape (140-148)
  • ParseMMAShape (140-140)
  • DTypeEnumToString (107-109)
  • DTypeEnumToString (107-107)
  • DTypeEnumToString (111-113)
  • DTypeEnumToString (111-111)
tilelang/language/builtin.py (2)
  • initialize_descriptor (355-386)
  • increase_descriptor_offset (389-411)
src/tl_templates/cuda/common.h (3)
src/target/ptx.h (1)
  • DataType (45-91)
src/tl_templates/cuda/ldsm.h (12)
  • void (7-14)
  • void (16-23)
  • void (25-33)
  • void (35-42)
  • void (44-52)
  • void (54-62)
  • void (64-70)
  • void (72-79)
  • void (81-89)
  • void (91-98)
  • void (100-108)
  • void (110-119)
tilelang/language/builtin.py (2)
  • initialize_descriptor (355-386)
  • increase_descriptor_offset (389-411)
tilelang/intrinsics/wgmma_macro_generator.py (6)
tilelang/utils/language.py (1)
  • is_fragment (68-78)
tilelang/layout/swizzle.py (3)
  • make_full_bank_swizzled_layout (39-59)
  • make_half_bank_swizzled_layout (64-84)
  • make_quarter_bank_swizzled_layout (89-109)
tilelang/layout/fragment.py (3)
  • Fragment (13-213)
  • replicate (146-160)
  • repeat (123-144)
tilelang/language/builtin.py (1)
  • initialize_descriptor (355-386)
tilelang/language/tir/op.py (3)
  • ptx_wgmma_ss (1064-1103)
  • ptx_wgmma_rs (1106-1143)
  • any (1774-1790)
tilelang/intrinsics/mma_macro_generator.py (1)
  • get_store_index_map (159-165)
src/target/ptx.h (1)
src/target/ptx.cc (10)
  • DTypeFromString (54-104)
  • DTypeFromString (54-54)
  • DTypeEnumToString (107-109)
  • DTypeEnumToString (107-107)
  • DTypeEnumToString (111-113)
  • DTypeEnumToString (111-111)
  • ParseMMAShape (140-148)
  • ParseMMAShape (140-140)
  • PrintWGMMAAssembly (1186-1257)
  • PrintWGMMAAssembly (1186-1195)
🪛 Ruff (0.12.2)
tilelang/intrinsics/wgmma_macro_generator.py

104-104: Unused method argument: n_dim

(ARG002)


143-143: Avoid specifying long messages outside the exception class

(TRY003)


353-353: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Cursor Bugbot
🔇 Additional comments (9)
tilelang/language/tir/op.py (1)

1064-1104: Change docstring from WMMA → WGMMA and update link

File: tilelang/language/tir/op.py (ptx_wgmma_ss). Docstring currently says WMMA; it should reference WGMMA (warp-group). Verified tl.ptx_wgmma_* ops are registered and expect 15 args.

 def ptx_wgmma_ss(
 @@
-    """TVM intrinsic for ptx tensor core wmma instructions
-    https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-for-wmma
-    """
+    """PTX WGMMA (warp-group MMA) intrinsic wrapper.
+
+    Variant: ss (A uses descriptor, B uses descriptor).
+    See NVIDIA PTX ISA — Warpgroup matrix instructions (WGMMA).
+    https://docs.nvidia.com/cuda/parallel-thread-execution/index.html
+    """
src/tl_templates/cuda/instruction/wgmma.h (2)

417-452: INT8 coverage seems partial (only N=8/16).

valid_wgmma_configs include N up to 256. If codegen may request N=32/64/… you’ll hit the fallback at compile time.

Would you like me to generate the missing specializations (or a meta-emitter) for s8->s32 shapes N∈{32,64,96,128,192,256}?


457-491: FP8 E4M3→F16/F32: only N=8 variants present.

Configs list N up to 256. Confirm dispatcher never asks for larger N, or add the remaining specializations.

I can auto-generate these blocks to match the config table.

src/tl_templates/cuda/gemm.h (1)

5-6: Include order LGTM.

Including wgmma.h in the sm90+ path before gemm_sm90.h is appropriate.

src/target/ptx.cc (1)

1222-1256: WGMMA operand builder: sparse path abort is fine; ensure scale_in use matches dtype.

Scale inputs are emitted only for non-integer types; that matches spec. No action needed.

Please confirm codegen never requests sparse WGMMA yet; otherwise return a diagnostic instead of LOG(FATAL) at higher layers.

src/tl_templates/cuda/common.h (1)

315-339: Enum duplication with ptx::DataType — parity unverified

No matches found for tvm::tl::codegen::ptx::DataType; ensure src/tl_templates/cuda/common.h (DataType, lines 315–339) has identical ordinals/values to the ptx definition. To verify locally, run:

rg -nP '\benum class DataType\b' -S
rg -nP 'ptx::DataType|namespace\s+ptx' -S
src/target/codegen_cuda.cc (2)

898-901: Local descriptor deref handling in GetBufferRef — LGTM

Returning the scalar vid for "local.descriptor" (same as "local.var") avoids bogus []/casts on tl::GmmaDescriptor. This matches the Python-side Buffer(shape=[1]) abstraction.


1810-1812: Ensure tl::GmmaDescriptor is visible from generated TU

src/target/codegen_cuda.cc already emits "#include <tl_templates/cuda/gemm.h>" in Finish() (decl_stream at src/target/codegen_cuda.cc:186). No declaration of tl::GmmaDescriptor was found in the repo—confirm that gemm.h (or an included header like common.h) actually defines tl::GmmaDescriptor; if it does not, add an explicit include for the header that declares it in Finish().

tilelang/intrinsics/wgmma_macro_generator.py (1)

104-109: Validate inst_n in wgmma_prefix against supported shapes.

Hardware supports specific N sizes (e.g., 8/16/32/64/128). inst_n = block_col_warps * warp_col_tiles may produce unsupported values. Add a guard or normalize.

Comment on lines +389 to +399
  // Decay to a uint64_t
  CUTE_HOST_DEVICE constexpr operator uint64_t() const noexcept {
    return desc_;
  }
  template <typename T>
  CUTE_HOST_DEVICE constexpr GmmaDescriptor operator+(const T &offset) const {
    GmmaDescriptor ret;
    ret.reg32_[0] = reg32_[0] + uint32_t(offset);
    ret.reg32_[1] = reg32_[1];
    return ret;
  }

⚠️ Potential issue

operator+ adds raw bytes; descriptor fields are in 16‑byte quanta.

The descriptor encodes addresses/offsets with 4 LSBs stripped. Adding raw offset likely mis-encodes.

Apply:

-  template <typename T>
-  CUTE_HOST_DEVICE constexpr GmmaDescriptor operator+(const T &offset) const {
+  template <typename T>
+  CUTE_HOST_DEVICE constexpr GmmaDescriptor operator+(const T &offset) const {
     GmmaDescriptor ret;
-    ret.reg32_[0] = reg32_[0] + uint32_t(offset);
+    ret.reg32_[0] = reg32_[0] + uint32_t(offset >> 4);
     ret.reg32_[1] = reg32_[1];
     return ret;
   }

And document that offset is in bytes.
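
For a concrete sense of the mismatch, a small worked example (byte values illustrative; the field layout follows the description above):

reg_lo = 0x480 >> 4                  # descriptor encodes smem byte address 0x480
byte_offset = 128                    # advance by 128 bytes
wrong = reg_lo + byte_offset         # decodes to (0x48 + 128) << 4 = 0xC80 bytes
right = reg_lo + (byte_offset >> 4)  # decodes to 0x500 bytes, as intended
assert right << 4 == 0x480 + 128
assert wrong << 4 != 0x480 + 128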


Comment on lines +168 to +169
assert k_dim >= micro_size_k, f"k_dim must be greater than or equal to {micro_size_k}, got k_dim: {k_dim}"


⚠️ Potential issue

Add divisibility assertion for K micro-tiles.

You iterate range(k_dim // micro_size_k) but don’t enforce divisibility; tails would be silently dropped. Add:

-        assert k_dim >= micro_size_k, f"k_dim must be greater than or equal to {micro_size_k}, got k_dim: {k_dim}"
+        assert k_dim >= micro_size_k, f"k_dim must be greater than or equal to {micro_size_k}, got k_dim: {k_dim}"
+        assert (k_dim % micro_size_k) == 0, f"k_dim must be divisible by micro_size_k={micro_size_k}, got k_dim: {k_dim}"

Apply in both SS and RS.

Also applies to: 255-256

🤖 Prompt for AI Agents
In tilelang/intrinsics/wgmma_macro_generator.py around lines 168-169 (and
similarly around lines 255-256), you assert k_dim >= micro_size_k but do not
require k_dim to be divisible by micro_size_k, causing any tail to be silently
dropped when iterating range(k_dim // micro_size_k); add an assertion ensuring
k_dim % micro_size_k == 0 with a clear error message (e.g., "k_dim must be
divisible by micro_size_k, got k_dim: {k_dim}, micro_size_k: {micro_size_k}") in
both the SS and RS sections so the code fails fast instead of silently
truncating tails.
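
A quick demonstration of the silent truncation this guard prevents:

k_dim, micro_size_k = 40, 16
tiles = k_dim // micro_size_k          # range() would visit only 2 tiles
covered = tiles * micro_size_k         # 32 of 40 K-elements
assert k_dim - covered == 8            # 8 K-elements silently dropped
assert k_dim % micro_size_k != 0       # the proposed assert would fail fast here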

Comment on lines +255 to +256
assert k_dim >= micro_size_k, f"k_dim must be greater than or equal to {micro_size_k}, got k_dim: {k_dim}"


⚠️ Potential issue

Mirror the fixes in RS path (per‑operand sizes, swizzle LBO units, pointer mode).

Do the same corrections for RS:

  • Use separate a_elems_in_bytes/b_elems_in_bytes.
  • Use swizzle_byte_size() for LBO and keep >>4.
  • Use access_ptr("r") for descriptors.

Also add the divisibility assert (see separate comment). Apply:

-        assert k_dim >= micro_size_k, f"k_dim must be greater than or equal to {micro_size_k}, got k_dim: {k_dim}"
+        assert k_dim >= micro_size_k, f"k_dim must be greater than or equal to {micro_size_k}, got k_dim: {k_dim}"
+        assert (k_dim % micro_size_k) == 0, f"k_dim must be divisible by micro_size_k={micro_size_k}, got k_dim: {k_dim}"
@@
-        elems_in_bytes = DataType(self.a_dtype).bits // 8
+        a_elems_in_bytes = DataType(self.a_dtype).bits // 8
+        b_elems_in_bytes = DataType(self.b_dtype).bits // 8
@@
-        b_leading_byte_offset = (8 * 8 * elems_in_bytes) if b_is_k_major else (8 * n_dim *
-                                                                               elems_in_bytes)
-        b_stride_byte_offset = (8 * k_dim * elems_in_bytes) if b_is_k_major else (8 * 8 *
-                                                                                  elems_in_bytes)
+        b_leading_byte_offset = (8 * 8 * b_elems_in_bytes) if b_is_k_major else (8 * n_dim *
+                                                                                 b_elems_in_bytes)
+        b_stride_byte_offset = (8 * k_dim * b_elems_in_bytes) if b_is_k_major else (8 * 8 *
+                                                                                    b_elems_in_bytes)
         if not b_swizzle_mode.is_none():
             # swizzle mode doesn't require LBO/SBO to be 1
             # https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-warpgroup-level-leading-dimension-byte-offset
             if b_is_k_major:
                 b_leading_byte_offset = 16
             else:
                 # MN Major
                 # LBO represents the distance between two atoms along the N dimension
                 # SBO represents the distance between two atoms along the K dimension
-                b_leading_byte_offset = b_swizzle_mode.swizzle_atom_size()
-                b_stride_byte_offset = 8 * n_dim * elems_in_bytes
+                b_leading_byte_offset = b_swizzle_mode.swizzle_byte_size()
+                b_stride_byte_offset = 8 * n_dim * b_elems_in_bytes
@@
-            T.initialize_descriptor(desc_b, B_buf.access_ptr("w"), b_swizzle_mode,
+            T.initialize_descriptor(desc_b, B_buf.access_ptr("r"), b_swizzle_mode,
                                     int(b_leading_byte_offset >> 4), int(b_stride_byte_offset >> 4))
@@
-                        (B_offset * elems_in_bytes) >> 4,
+                        (B_offset * b_elems_in_bytes) >> 4,

Also applies to: 257-278, 283-284, 301-302

🤖 Prompt for AI Agents
In tilelang/intrinsics/wgmma_macro_generator.py around lines 255-256 (and
similarly apply changes to ranges 257-278, 283-284, 301-302), the RS path
currently mirrors the SS fixes incorrectly: replace any single shared
elems_in_bytes with separate a_elems_in_bytes and b_elems_in_bytes computed from
their per‑operand element sizes; for LBO swizzle sizes call
swizzle_byte_size(...) and retain the bitshift >> 4 as before; change descriptor
pointer creation to use access_ptr("r") instead of the current mode; and add the
divisibility assert that ensures k_dim (or relevant dim) is divisible by
micro_size_k (matching the other path). Ensure these edits are applied
consistently across the listed line ranges.

Comment on lines +291 to +295
accum_dtype,
wgmma_prefix,
self.a_transposed,
not self.b_transposed,
a_dtype_abbrv,

⚠️ Potential issue

🧩 Analysis chain

Likely wrong K‑major flags passed to RS intrinsic (inverted vs SS).

ptx_wgmma_rs takes the same a_is_k_major/b_is_k_major semantics as SS. SS passes not self.a_transposed and self.b_transposed, but RS passes their negations. Align them:

-                        self.a_transposed,
-                        not self.b_transposed,
+                        (not self.a_transposed),
+                        self.b_transposed,

Please confirm against the PTX WGMMA docs for RS.


🌐 Web query:

What are the `a_is_k_major` / `b_is_k_major` semantics for `wgmma.mma_async` RS vs SS variants in NVIDIA PTX?

💡 Result:

Short answer: those flags tell the instruction whether the A / B operand’s shared‑memory descriptor (SS) — or the register‑sourced operand’s logical layout (RS) — is K‑major (i.e. the 8×2 / 2×(N/8) “core‑matrix” tile stride is along the K dimension) or not. Concretely:

  • SS variants (the .ss forms) take 64‑bit matrix descriptors for A and/or B; the descriptor’s “major‑K” bit(s) (a_is_k_major / b_is_k_major) mean “this SMEM tile is laid out K‑major” so the hardware will interpret LBO/SBO in the descriptor accordingly (core matrices are arranged with K as the strided direction). Use these bits when your shared‑memory tile is stored with K as the leading/strided dimension. [1][2]

  • RS variants (the .rs forms) source one operand from registers (or a thread‑local fragment) and the other from a descriptor; for RS the a_is_k_major / b_is_k_major semantic still denotes whether that operand’s logical tile is K‑major, but for the register‑sourced operand the layout interpretation is applied to the fragment/register mapping (i.e. whether the fragment is arranged as K‑major vs M/N‑major). In practice this means you must set the bit to match how your register fragment was produced (or how you packed it), otherwise the multiply/accumulate will use the wrong element ordering. [1][2]

References:

  • PTX WGMMA / wgmma.mma_async description (matrix descriptors, K‑major vs MN‑major and SS/RS usage). [1]
  • CUTLASS / WGMMA notes (explanation of descriptors, LBO/SBO and K‑major layout and the distinction of SS mode requiring descriptors). [2]

[1] NVIDIA PTX ISA (wgmma / wgmma.mma_async matrix descriptor & layout).
[2] CUTLASS / WGMMA tutorial and notes on shared‑memory descriptors and K‑major layout.


Fix inverted K‑major flags for RS WGMMA intrinsic

RS uses the same a_is_k_major / b_is_k_major semantics as SS; the call currently passes their negations. Update the args at tilelang/intrinsics/wgmma_macro_generator.py:291-295:

-                        self.a_transposed,
-                        not self.b_transposed,
+                        (not self.a_transposed),
+                        self.b_transposed,
🤖 Prompt for AI Agents
In tilelang/intrinsics/wgmma_macro_generator.py around lines 291-295 the K-major
flags for the RS WGMMA intrinsic are passed inverted relative to the SS
semantics; replace the current arguments so the a/b K-major booleans are the
negations of their current values (change self.a_transposed to not
self.a_transposed, and change not self.b_transposed to self.b_transposed) so the
call passes the correct a_is_k_major and b_is_k_major flags.
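
To summarize the convention the thread converges on, a small sketch (helper name illustrative; the same semantics apply to SS and RS):

def k_major_flags(a_transposed: bool, b_transposed: bool):
    # A (M x K) is K-major when not transposed; B is K-major when stored
    # N x K, i.e. when transposed, matching the SS call path quoted above.
    return (not a_transposed, b_transposed)

assert k_major_flags(False, True) == (True, True)  # row-major A, N x K B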

…e size calculations in wgmma macro generator. Enhanced handling of leading and stride byte offsets based on swizzle mode, improving clarity and performance in tensor core intrinsic emissions.
print(f"b_leading_byte_offset: {b_leading_byte_offset}")
print(f"b_stride_byte_offset: {b_stride_byte_offset}")
# exit()
@T.macro

Bug: Debugging Artifacts in Code

Debugging print statements and a commented-out exit() call were accidentally committed. These temporary artifacts emit debug output at runtime, polluting the console.


// 暂时注释掉 static_assert 来看调试输出
// static_assert(always_false_v<decltype(c)>,
// "wgmma_ss: No specialization available for given template parameters!");
};

Bug: Debug Code and Masked Errors in Template

The WgmmaSSImpl fallback template includes a debug printf statement that pollutes output, and its static_assert is commented out (a Chinese comment marks it as temporary debugging code), which can mask compilation errors for unsupported template parameters.


class GemmInst(IntEnum):
    MMA = 0
    WGMMMA = 1
    MFMA = 2

Bug: Enum Mismatch and Typo Cause Runtime Errors

The GemmInst Python enum has a value mismatch with its C++ counterpart, leading to runtime errors when C++ returns an undefined enum value (e.g., for MFMA or kUTCMMA). There's also a typo in the WGMMMA enum member, which should be WGMMA.
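
A minimal sketch of the corrected enum; the numeric values are assumptions to be checked against the C++ definition:

from enum import IntEnum

class GemmInst(IntEnum):
    # Values must mirror the C++ GemmInst enum one-to-one; the numbers
    # below are illustrative until verified against the C++ header.
    MMA = 0
    WGMMA = 1   # was misspelled WGMMMA
    MFMA = 2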


@coderabbitai bot left a comment

Actionable comments posted: 1

♻️ Duplicate comments (6)
src/target/codegen_cuda.cc (1)

1305-1362: WGMMA-SS: emit validated asm_code; drop ad‑hoc string (invalid arithmetic on descriptor, dead code)

asm_code from PrintWGMMAAssembly(...) is computed and then ignored. The custom wgmma_asm_code does pointer arithmetic on tl::GmmaDescriptor via uint64_t((desc_a) + (A_offset)), which is invalid. Unify with the RS path and stream asm_code.

Apply:

   const bool a_is_shared = true;
   this->PrintIndent();
   std::string asm_code = PrintWGMMAAssembly(
       shape, a_is_k_major, b_is_k_major, A_dtype, B_dtype, C_dtype, a_desc, A_offset,
       b_desc, B_offset, c_ref, c_offset, scale_out, scale_in_a, scale_in_b,
       a_is_shared, "", "", "", false);
-  auto [m, n, k] = tl::codegen::ptx::ParseMMAShape(shape);
-  std::string wgmma_asm_code = "tl::wgmma_ss<(AType), (BType), (CType), (M), (N), (K), (tnspA), (tnspB), (scaleA), (scaleB)>(uint64_t((desc_a) + (A_offset)), uint64_t((desc_b) + (B_offset)), ((uint32_t*)((C))), (scale_out));\n";
-  // replace patterns
-  tl::codegen::Replacer replacer;
-  replacer.register_rule("(AType)", tl::codegen::ptx::DTypeEnumToString(A_dtype));
-  replacer.register_rule("(BType)", tl::codegen::ptx::DTypeEnumToString(B_dtype));
-  replacer.register_rule("(CType)", tl::codegen::ptx::DTypeEnumToString(C_dtype));
-  replacer.register_rule("(M)", std::to_string(m));
-  replacer.register_rule("(N)", std::to_string(n));
-  replacer.register_rule("(K)", std::to_string(k));
-  replacer.register_rule("(tnspA)", a_is_k_major? "false": "true");
-  replacer.register_rule("(tnspB)", b_is_k_major? "false": "true");
-  replacer.register_rule("(scaleA)", scale_in_a? "1": "-1");
-  replacer.register_rule("(scaleB)", scale_in_b? "1": "-1");
-  replacer.register_rule("(desc_a)", a_desc);
-  replacer.register_rule("(A_offset)", A_offset);
-  replacer.register_rule("(desc_b)", b_desc);
-  replacer.register_rule("(B_offset)", B_offset);
-  replacer.register_rule("(C)", c_ref + " + " + c_offset);
-  replacer.register_rule("(scale_out)", scale_out ? "true" : "false");
-  wgmma_asm_code = replacer.rewrite(wgmma_asm_code);
-  this->stream << wgmma_asm_code;
+  this->stream << asm_code;
tilelang/intrinsics/wgmma_macro_generator.py (5)

159-169: Add validations: m_dim multiple-of-64 and k_dim divisibility by micro_size_k.

Prevents silent no-op loops and tails.

         m_dim = self.block_row_warps * self.warp_row_tiles
         warp_cols = self.warp_cols
         micro_size_k = self.micro_size_k
         k_dim, n_dim = self.chunk, self.block_col_warps * self.warp_col_tiles
         wgmma_prefix = self.wgmma_prefix
@@
-        assert k_dim >= micro_size_k, f"k_dim must be greater than or equal to {micro_size_k}, got k_dim: {k_dim}"
+        assert m_dim >= 64 and (m_dim % 64) == 0, f"m_dim must be a multiple of 64, got {m_dim}"
+        assert k_dim >= micro_size_k, f"k_dim must be ≥ {micro_size_k}, got {k_dim}"
+        assert (k_dim % micro_size_k) == 0, f"k_dim must be divisible by micro_size_k={micro_size_k}, got {k_dim}"

176-219: Fix LBO/SBO units, per-operand element sizes, and descriptor pointer modes (SS path).

Currently mixes A’s element size for B, uses swizzle_atom_size() then >>4 (double-divide by 16), and passes access_ptr("w"). This will generate incorrect descriptors and offsets, especially for int8/fp8/tf32.

-        elems_in_bytes = DataType(self.a_dtype).bits // 8
+        a_elems_in_bytes = DataType(self.a_dtype).bits // 8
+        b_elems_in_bytes = DataType(self.b_dtype).bits // 8
@@
-        a_leading_byte_offset = (8 * 8 * elems_in_bytes) if a_is_k_major else (8 * m_dim *
-                                                                               elems_in_bytes)
-        a_stride_byte_offset = (8 * k_dim * elems_in_bytes) if a_is_k_major else (8 * 8 *
-                                                                                  elems_in_bytes)
+        a_leading_byte_offset = (8 * 8 * a_elems_in_bytes) if a_is_k_major else (8 * m_dim *
+                                                                                 a_elems_in_bytes)
+        a_stride_byte_offset = (8 * k_dim * a_elems_in_bytes) if a_is_k_major else (8 * 8 *
+                                                                                    a_elems_in_bytes)
@@
-            else:
+            else:
                 # MN Major
                 # LBO represents the distance between two atoms along the M dimension
                 # SBO represents the distance between two atoms along the K dimension
-                a_leading_byte_offset = a_swizzle_mode.swizzle_atom_size()
-                a_stride_byte_offset = 8 * 64 * elems_in_bytes
+                a_leading_byte_offset = a_swizzle_mode.swizzle_byte_size()
+                a_stride_byte_offset = 8 * 64 * a_elems_in_bytes
@@
-        b_leading_byte_offset = (8 * 8 * elems_in_bytes) if b_is_k_major else (8 * n_dim *
-                                                                               elems_in_bytes)
-        b_stride_byte_offset = (8 * k_dim * elems_in_bytes) if b_is_k_major else (8 * 8 *
-                                                                                  elems_in_bytes)
+        b_leading_byte_offset = (8 * 8 * b_elems_in_bytes) if b_is_k_major else (8 * n_dim *
+                                                                                 b_elems_in_bytes)
+        b_stride_byte_offset = (8 * k_dim * b_elems_in_bytes) if b_is_k_major else (8 * 8 *
+                                                                                    b_elems_in_bytes)
         if not b_swizzle_mode.is_none():
@@
-            else:
-                # MN Major, K * N
-                # LBO represents the distance between two atoms along the N dimension
-                # SBO represents the distance between two atoms along the K dimension
-                b_n_axis_atoms = n_dim // (b_swizzle_mode.swizzle_byte_size() // elems_in_bytes)
-                if b_n_axis_atoms <= 1:
-                    b_leading_byte_offset = 0
-                else:
-                    b_leading_byte_offset = 8 * 8 * elems_in_bytes * k_dim
-
-                if b_n_axis_atoms <= 1:
-                    b_stride_byte_offset = 8 * elems_in_bytes * n_dim
-                else:
-                    b_stride_byte_offset = 8 * elems_in_bytes * (b_swizzle_mode.swizzle_byte_size() // elems_in_bytes)
+            else:
+                # MN Major (N × K): use swizzle byte size for LBO; SBO spans N in bytes.
+                b_leading_byte_offset = b_swizzle_mode.swizzle_byte_size()
+                b_stride_byte_offset = 8 * n_dim * b_elems_in_bytes
@@
-            desc_a = T.alloc_descriptor()
-            desc_b = T.alloc_descriptor()
-            T.initialize_descriptor(desc_a, A_buf.access_ptr("w"), a_swizzle_mode,
+            desc_a = T.alloc_descriptor()
+            desc_b = T.alloc_descriptor()
+            T.initialize_descriptor(desc_a, A_buf.access_ptr("r"), a_swizzle_mode,
                                     int(a_leading_byte_offset >> 4), int(a_stride_byte_offset >> 4))
-            T.initialize_descriptor(desc_b, B_buf.access_ptr("w"), b_swizzle_mode,
+            T.initialize_descriptor(desc_b, B_buf.access_ptr("r"), b_swizzle_mode,
                                     int(b_leading_byte_offset >> 4), int(b_stride_byte_offset >> 4))
@@
-                    B_offset = k_dim_offset if b_is_k_major else k_dim_offset * (b_swizzle_mode.swizzle_byte_size() // elems_in_bytes)
+                    B_offset = k_dim_offset if b_is_k_major else k_dim_offset * (b_swizzle_mode.swizzle_byte_size() // b_elems_in_bytes)
@@
-                                   accum_dtype_abbrv, desc_a.data, (A_offset * elems_in_bytes) >> 4,
-                                   desc_b.data, (B_offset * elems_in_bytes) >> 4, C_local_buf.data,
+                                   accum_dtype_abbrv, desc_a.data, (A_offset * a_elems_in_bytes) >> 4,
+                                   desc_b.data, (B_offset * b_elems_in_bytes) >> 4, C_local_buf.data,

Also applies to: 239-254
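
To make the units concrete, a worked example under the K-major fp16 case from the diff above (k_dim = 64 assumed; the descriptor stores both offsets in 16-byte units via >> 4):

a_elems_in_bytes = 2                                 # fp16
k_dim = 64
a_leading_byte_offset = 8 * 8 * a_elems_in_bytes     # 128 bytes between core matrices
a_stride_byte_offset = 8 * k_dim * a_elems_in_bytes  # 1024 bytes per 8-row stride
assert a_leading_byte_offset >> 4 == 8               # LBO encoded in 16-byte units
assert a_stride_byte_offset >> 4 == 64               # SBO encoded in 16-byte units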


278-279: Add RS-side divisibility assert.

Mirror the SS check to avoid dropped tails.

-        assert k_dim >= micro_size_k, f"k_dim must be greater than or equal to {micro_size_k}, got k_dim: {k_dim}"
+        assert k_dim >= micro_size_k, f"k_dim must be ≥ {micro_size_k}, got {k_dim}"
+        assert (k_dim % micro_size_k) == 0, f"k_dim must be divisible by micro_size_k={micro_size_k}, got {k_dim}"

280-309: Mirror per-operand sizes and descriptor pointer mode (RS path).

Use b_elems_in_bytes and access_ptr("r"); fix LBO/SBO units for swizzle.

-        elems_in_bytes = DataType(self.a_dtype).bits // 8
+        a_elems_in_bytes = DataType(self.a_dtype).bits // 8
+        b_elems_in_bytes = DataType(self.b_dtype).bits // 8
@@
-        b_leading_byte_offset = (8 * 8 * elems_in_bytes) if b_is_k_major else (8 * n_dim *
-                                                                               elems_in_bytes)
-        b_stride_byte_offset = (8 * k_dim * elems_in_bytes) if b_is_k_major else (8 * 8 *
-                                                                                  elems_in_bytes)
+        b_leading_byte_offset = (8 * 8 * b_elems_in_bytes) if b_is_k_major else (8 * n_dim *
+                                                                                 b_elems_in_bytes)
+        b_stride_byte_offset = (8 * k_dim * b_elems_in_bytes) if b_is_k_major else (8 * 8 *
+                                                                                    b_elems_in_bytes)
@@
-            else:
-                # MN Major
-                # LBO represents the distance between two atoms along the N dimension
-                # SBO represents the distance between two atoms along the K dimension
-                b_n_axis_atoms = n_dim // (b_swizzle_mode.swizzle_byte_size() // elems_in_bytes)
-                if b_n_axis_atoms <= 1:
-                    b_leading_byte_offset = 0
-                else:
-                    b_leading_byte_offset = 8 * 8 * elems_in_bytes * k_dim
-
-                if b_n_axis_atoms <= 1:
-                    b_stride_byte_offset = 8 * elems_in_bytes * n_dim
-                else:
-                    b_stride_byte_offset = 8 * elems_in_bytes * (b_swizzle_mode.swizzle_byte_size() // elems_in_bytes)
+            else:
+                # MN Major
+                b_leading_byte_offset = b_swizzle_mode.swizzle_byte_size()
+                b_stride_byte_offset = 8 * n_dim * b_elems_in_bytes
@@
-            T.initialize_descriptor(desc_b, B_buf.access_ptr("w"), b_swizzle_mode,
+            T.initialize_descriptor(desc_b, B_buf.access_ptr("r"), b_swizzle_mode,
                                     int(b_leading_byte_offset >> 4), int(b_stride_byte_offset >> 4))
@@
-                        (B_offset * elems_in_bytes) >> 4,
+                        (B_offset * b_elems_in_bytes) >> 4,

Also applies to: 313-314, 332-332


324-325: Fix inverted K-major flags for RS.

RS uses the same a/b K-major semantics as SS.

-                        self.a_transposed,
-                        not self.b_transposed,
+                        (not self.a_transposed),
+                        self.b_transposed,
🧹 Nitpick comments (6)
src/target/codegen_cuda.cc (1)

1363-1397: WGMMA-RS path looks good; minor naming/comment nit

Logic mirrors SS but correctly streams asm_code. Consider renaming A_layout/B_layout to a_is_k_major/b_is_k_major and fixing the arg comments (arg 0 is shape; dtype is not an arg) for clarity.

tilelang/intrinsics/wgmma_macro_generator.py (5)

104-109: Remove unused arg and make prefix computation explicit.

n_dim is unused; also guard inst_k with a sanity check (tf32 can be tricky).

-    def _initialize_wgmma_prefix(self, n_dim: int = 16):
+    def _initialize_wgmma_prefix(self):
         inst_m, inst_n = 64, self.block_col_warps * self.warp_col_tiles
-        # 256 bits per instruction
-        inst_k = 256 // DataType(self.a_dtype).bits
+        # k derived from input dtype; ensure it's one of the valid WGMMA shapes.
+        bits = DataType(self.a_dtype).bits
+        assert 256 % bits == 0, f"Unsupported dtype bits for WGMMA prefix: {bits}"
+        inst_k = 256 // bits
         self.wgmma_prefix = f"m{inst_m}n{inst_n}k{inst_k}"
-        self._initialize_wgmma_prefix(self.n_dim)
+        self._initialize_wgmma_prefix()

Please verify tf32’s effective width in your stack (if tvm “tf32” reports 19 bits, switch to an explicit dtype→k map).

Also applies to: 94-95
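
A worked sketch of the prefix derivation, assuming the 256-bits-per-K-slice rule above holds for the dtype:

def wgmma_prefix(inst_n: int, a_dtype_bits: int, inst_m: int = 64) -> str:
    assert 256 % a_dtype_bits == 0, "unsupported element width"
    return f"m{inst_m}n{inst_n}k{256 // a_dtype_bits}"

assert wgmma_prefix(128, 16) == "m64n128k16"   # fp16
assert wgmma_prefix(128, 8) == "m64n128k32"    # int8 / fp8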


221-234: Remove debug prints (noise in codegen path).

These spam stdout during TIR lowering; drop or guard behind a verbose flag.

-        print(f"a_leading_byte_offset: {a_leading_byte_offset >> 4}")
-        print(f"a_stride_byte_offset: {a_stride_byte_offset >> 4}")
-
-        print(f"b_swizzle_atom_size: {b_swizzle_mode.swizzle_atom_size()}")
-        print(f"b_swizzle_byte_size: {b_swizzle_mode.swizzle_byte_size()}")
-        print(f"m_dim: {m_dim}")
-        print(f"n_dim: {n_dim}")
-        print(f"k_dim: {k_dim}")
-        print(f"micro_size_k: {micro_size_k}")
-        print(f"a_leading_byte_offset: {a_leading_byte_offset}")
-        print(f"a_stride_byte_offset: {a_stride_byte_offset}")
-        print(f"b_leading_byte_offset: {b_leading_byte_offset}")
-        print(f"b_stride_byte_offset: {b_stride_byte_offset}")

342-347: Docstring mismatch: this is a load-layout helper.

Tweak wording to avoid confusion.

-        Create a layout function for storing MMA results into a fragment buffer.
+        Create a layout function for loading operand A into a fragment buffer.

364-367: Drop redundant import inside method.

is_fragment is already imported at module scope.

-        from tilelang.utils import is_fragment

143-143: Tidy exception messages per linter hint (TRY003).

Shorten or use a custom exception type; current messages are fine but flagged.

Also applies to: 384-384
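
One way to satisfy TRY003, sketched with a hypothetical exception type (the class name is illustrative, not from the codebase):

# Sketch: move the long message into the exception class itself.
class WgmmaShapeError(ValueError):
    def __init__(self, dim_name: str, value: int, minimum: int):
        super().__init__(f"{dim_name}={value} is below the minimum {minimum}")

# Call sites then stay short, e.g.:
# raise WgmmaShapeError("k_dim", k_dim, micro_size_k)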

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 51fcf15 and ce9f545.

📒 Files selected for processing (2)
  • src/target/codegen_cuda.cc (5 hunks)
  • tilelang/intrinsics/wgmma_macro_generator.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
src/target/codegen_cuda.cc (4)
src/transform/storage_rewrite.cc (6)
  • scope (674-678)
  • scope (674-674)
  • n (366-370)
  • n (366-366)
  • n (371-375)
  • n (371-371)
tilelang/language/tir/op.py (2)
  • ptx_wgmma_ss (1064-1103)
  • ptx_wgmma_rs (1106-1143)
src/target/ptx.cc (8)
  • PrintWGMMAAssembly (1186-1257)
  • PrintWGMMAAssembly (1186-1195)
  • ParseMMAShape (140-148)
  • ParseMMAShape (140-140)
  • DTypeEnumToString (107-109)
  • DTypeEnumToString (107-107)
  • DTypeEnumToString (111-113)
  • DTypeEnumToString (111-111)
tilelang/language/builtin.py (2)
  • initialize_descriptor (355-386)
  • increase_descriptor_offset (389-411)
tilelang/intrinsics/wgmma_macro_generator.py (8)
tilelang/utils/language.py (1)
  • is_fragment (68-78)
tilelang/layout/layout.py (3)
  • Layout (13-148)
  • is_equal (136-145)
  • inverse (125-134)
tilelang/layout/swizzle.py (3)
  • make_full_bank_swizzled_layout (39-59)
  • make_half_bank_swizzled_layout (64-84)
  • make_quarter_bank_swizzled_layout (89-109)
tilelang/layout/fragment.py (4)
  • is_equal (209-213)
  • Fragment (13-213)
  • replicate (146-160)
  • repeat (123-144)
tilelang/language/allocate.py (1)
  • alloc_descriptor (129-135)
tilelang/language/builtin.py (1)
  • initialize_descriptor (355-386)
tilelang/language/tir/op.py (2)
  • ptx_wgmma_ss (1064-1103)
  • ptx_wgmma_rs (1106-1143)
tilelang/intrinsics/mma_macro_generator.py (1)
  • get_store_index_map (159-165)
🪛 Ruff (0.12.2)
tilelang/intrinsics/wgmma_macro_generator.py

104-104: Unused method argument: n_dim

(ARG002)


143-143: Avoid specifying long messages outside the exception class

(TRY003)


384-384: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Cursor Bugbot
🔇 Additional comments (5)
src/target/codegen_cuda.cc (5)

898-901: Descriptor refs: correct special‑casing for local.descriptor

Returning the variable directly for descriptor buffers is the right behavior and aligns with how the descriptor is declared.


1721-1721: LGTM

No issues spotted in tl_shuffle_elect emission.


1845-1845: Scope guard update is correct

Skipping the unsupported-scope check for local.descriptor is necessary after introducing descriptor allocations.


1723-1742: Descriptor intrinsics emission LGTM — confirm offset width

initialize_descriptor templating is correct. For increase_descriptor_offset(...), confirm the API expects a 32-bit offset; repository search did not locate the declaration, so either switch to std::int32_t (or a wider type) and add an explicit cast to avoid narrowing, or confirm plain int is intended. Location: src/target/codegen_cuda.cc:1739-1742


1809-1811: Descriptor allocation: ensure header is available

src/target/codegen_cuda.cc already writes #include <tl_templates/cuda/gemm.h> into the generated decls, but I couldn't find a definition for tl::GmmaDescriptor in the repo. Ensure gemm.h (or a transitive include) defines tl::GmmaDescriptor; if it does not, add #include <tl_templates/cuda/common.h> in Finish() so the generated file includes the descriptor definition.

Comment on lines +83 to +85
warp_row_tiles: int = 8,
warp_col_tiles: int = 8,
chunk: int = 16,

⚠️ Potential issue

Default warp_row_tiles conflicts with your own assertion (must be ≥16).

Set the default to 16 to avoid constructor-time assertion failures via parent init.

-        warp_row_tiles: int = 8,
+        warp_row_tiles: int = 16,
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
warp_row_tiles: int = 8,
warp_col_tiles: int = 8,
chunk: int = 16,
warp_row_tiles: int = 16,
warp_col_tiles: int = 8,
chunk: int = 16,
🤖 Prompt for AI Agents
In tilelang/intrinsics/wgmma_macro_generator.py around lines 83 to 85, the
default value for warp_row_tiles is 8 which conflicts with the class assertion
requiring warp_row_tiles >= 16 and can trigger assertion failures during parent
initialization; change the default warp_row_tiles to 16 (update the parameter
default in the function/class signature) so it satisfies the assertion by
default and avoid constructor-time failures.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
src/op/gemm.cc (1)

780-795: TCGEN5MMA layout still passes old k-factor integers

makeGemmABLayoutSm100 now expects a boolean k_inner, but we still feed trans_A ? 1 : 2 / trans_B ? 2 : 1. Both expressions evaluate to true, so the CDNA/Sm100 shared-memory layouts always assume K-inner ordering, breaking transposed operands. Pass explicit booleans that preserve the original semantics.

-      results.Set(A, makeGemmABLayoutSm100(mat_stride, mat_continuous,
-                                           mat_continuous, A->dtype.bits(),
-                                           trans_A ? 1 : 2));
+      results.Set(A, makeGemmABLayoutSm100(mat_stride, mat_continuous,
+                                           mat_continuous, A->dtype.bits(),
+                                           !trans_A));
@@
-      results.Set(B,
-                  makeGemmABLayoutSm100(mat_stride, mat_continuous, continuity,
-                                        B->dtype.bits(), trans_B ? 2 : 1));
+      results.Set(B,
+                  makeGemmABLayoutSm100(mat_stride, mat_continuous, continuity,
+                                        B->dtype.bits(), trans_B));
♻️ Duplicate comments (4)
src/tl_templates/cuda/common.h (1)

254-263: operator+ should add 16-byte quanta, not raw bytes

The descriptor stores offsets in 16-byte units. Adding a raw byte count corrupts the address field. Apply the same >> 4 used in increase_descriptor_offset so that operator+ preserves the descriptor encoding.

-    ret.reg32_[0] = reg32_[0] + uint32_t(offset);
+    ret.reg32_[0] = reg32_[0] + uint32_t(offset >> 4);
tilelang/language/builtin.py (2)

382-383: Critical: Fix operator precedence to prevent AttributeError.

The condition has ambiguous operator precedence. Python evaluates and before or, so if descriptor is a BufferLoad (not a Buffer), the expression descriptor.shape[0] will be evaluated even after the isinstance(descriptor, Buffer) check fails, causing an AttributeError.

As per coding guidelines

Apply this diff to fix the precedence:

-    if isinstance(descriptor, Buffer) and len(descriptor.shape) != 1 or descriptor.shape[0] != 1:
+    if isinstance(descriptor, Buffer) and (len(descriptor.shape) != 1 or descriptor.shape[0] != 1):
         raise ValueError("Descriptor must be a 1D buffer of size 1.")

408-409: Critical: Fix operator precedence to prevent AttributeError.

Same operator precedence issue as in initialize_descriptor. The condition will incorrectly evaluate descriptor.shape[0] for BufferLoad instances.

As per coding guidelines

Apply this diff:

-    if isinstance(descriptor, Buffer) and len(descriptor.shape) != 1 or descriptor.shape[0] != 1:
+    if isinstance(descriptor, Buffer) and (len(descriptor.shape) != 1 or descriptor.shape[0] != 1):
         raise ValueError("Descriptor must be a 1D buffer of size 1.")
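
A small repro of the precedence trap, assuming a stand-in object without a shape attribute:

class Buffer:  # stand-in for tvm.tir.Buffer
    pass

desc = object()  # e.g. a BufferLoad reaching this check

# Parsed as: (isinstance(...) and len(...) != 1) or desc.shape[0] != 1,
# so the right-hand side still runs and raises AttributeError.
try:
    if isinstance(desc, Buffer) and len(desc.shape) != 1 or desc.shape[0] != 1:
        pass
except AttributeError as exc:
    print(exc)  # 'object' object has no attribute 'shape'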
src/target/codegen_cuda.cc (1)

1565-1592: Use validated assembly from PrintWGMMAAssembly instead of ad-hoc pattern replacement.

The code computes asm_code via PrintWGMMAAssembly (line 1567) but then ignores it, creating a separate wgmma_asm_code string with pattern replacement. The past review comment suggested using the validated asm_code directly, which is what the RS path does correctly (line 1628).

Apply this diff to use the validated assembly:

     const bool a_is_shared = true;
     this->PrintIndent();
     std::string asm_code = PrintWGMMAAssembly(
         shape, a_is_k_major, b_is_k_major, A_dtype, B_dtype, C_dtype, a_desc, A_offset,
         b_desc, B_offset, c_ref, c_offset, scale_out, scale_in_a, scale_in_b,
         a_is_shared, "", "", "", false);
-    auto [m, n, k] = tl::codegen::ptx::ParseMMAShape(shape);
-    std::string wgmma_asm_code = "tl::wgmma_ss<(AType), (BType), (CType), (M), (N), (K), (tnspA), (tnspB), (scaleA), (scaleB)>(uint64_t((desc_a) + (A_offset)), uint64_t((desc_b) + (B_offset)), ((uint32_t*)((C))), (scale_out));\n";
-    // replace patterns
-    tl::codegen::Replacer replacer;
-    replacer.register_rule("(AType)", tl::codegen::ptx::DTypeEnumToString(A_dtype));
-    replacer.register_rule("(BType)", tl::codegen::ptx::DTypeEnumToString(B_dtype));
-    replacer.register_rule("(CType)", tl::codegen::ptx::DTypeEnumToString(C_dtype));
-    replacer.register_rule("(M)", std::to_string(m));
-    replacer.register_rule("(N)", std::to_string(n));
-    replacer.register_rule("(K)", std::to_string(k));
-    replacer.register_rule("(tnspA)", a_is_k_major? "false": "true");
-    replacer.register_rule("(tnspB)", b_is_k_major? "false": "true");
-    replacer.register_rule("(scaleA)", scale_in_a? "1": "-1");
-    replacer.register_rule("(scaleB)", scale_in_b? "1": "-1");
-    replacer.register_rule("(desc_a)", a_desc);
-    replacer.register_rule("(A_offset)", A_offset);
-    replacer.register_rule("(desc_b)", b_desc);
-    replacer.register_rule("(B_offset)", B_offset);
-    replacer.register_rule("(C)", c_ref + " + " + c_offset);
-    replacer.register_rule("(scale_out)", scale_out ? "true" : "false");
-    wgmma_asm_code = replacer.rewrite(wgmma_asm_code);
-    this->stream << wgmma_asm_code;
+    this->stream << asm_code;
🧹 Nitpick comments (1)
src/layout/gemm_layouts.cc (1)

741-765: Update error message to reference k_inner instead of kfactor.

Line 764's error message still references "kfactor=" which is inconsistent with the new boolean parameter name.

Apply this diff:

     ICHECK(0) << "Unsupported layout for Hopper with stride=" << mat_stride
               << ", continuous=" << mat_continuous
-              << ", element_size=" << element_size << ", kfactor=" << kfactor;
+              << ", element_size=" << element_size << ", k_inner=" << k_inner;
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between ce9f545 and 70699a9.

📒 Files selected for processing (16)
  • src/layout/gemm_layouts.cc (4 hunks)
  • src/layout/layout.cc (4 hunks)
  • src/layout/layout.h (2 hunks)
  • src/op/builtin.cc (2 hunks)
  • src/op/builtin.h (2 hunks)
  • src/op/gemm.cc (7 hunks)
  • src/op/gemm_py.cc (4 hunks)
  • src/op/gemm_py.h (1 hunks)
  • src/target/codegen_cuda.cc (5 hunks)
  • src/tl_templates/cuda/common.h (4 hunks)
  • src/tl_templates/cuda/gemm.h (1 hunks)
  • src/transform/storage_rewrite.cc (2 hunks)
  • tilelang/language/__init__.py (1 hunks)
  • tilelang/language/allocate.py (1 hunks)
  • tilelang/language/builtin.py (2 hunks)
  • tilelang/layout/fragment.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
  • src/tl_templates/cuda/gemm.h
  • src/transform/storage_rewrite.cc
🧰 Additional context used
🧬 Code graph analysis (11)
tilelang/layout/fragment.py (1)
tilelang/layout/layout.py (4)
  • get_input_shape (59-68)
  • get_output_shape (70-79)
  • index (48-57)
  • is_equal (136-145)
tilelang/language/allocate.py (2)
src/transform/storage_rewrite.cc (4)
  • dtype (696-702)
  • dtype (696-696)
  • scope (674-678)
  • scope (674-674)
tilelang/language/ast/ir.py (1)
  • alloc_buffer (441-508)
src/op/gemm.cc (2)
tilelang/tileop/gemm/gemm_base.py (4)
  • trans_A (46-47)
  • trans_B (50-51)
  • A (67-68)
  • B (71-72)
src/layout/gemm_layouts.cc (2)
  • makeGemmABLayout (720-739)
  • makeGemmABLayout (720-721)
tilelang/language/builtin.py (3)
src/op/builtin.h (1)
  • tvm (13-442)
tilelang/language/ast/ir.py (1)
  • evaluate (1319-1331)
tilelang/language/tir/op.py (1)
  • call_intrin (119-144)
src/tl_templates/cuda/common.h (2)
src/target/ptx.h (1)
  • DataType (45-91)
tilelang/language/builtin.py (2)
  • initialize_descriptor (360-391)
  • increase_descriptor_offset (394-416)
src/layout/layout.h (1)
src/layout/gemm_layouts.cc (6)
  • makeGemmABLayoutHopper (741-765)
  • makeGemmABLayoutHopper (741-742)
  • makeGemmABLayoutSm100 (767-787)
  • makeGemmABLayoutSm100 (767-768)
  • makeGemmABLayoutCDNA (789-792)
  • makeGemmABLayoutCDNA (789-790)
src/op/builtin.cc (2)
tilelang/language/tir/op.py (2)
  • ptx_wgmma_ss (1064-1103)
  • ptx_wgmma_rs (1106-1143)
tilelang/language/builtin.py (2)
  • initialize_descriptor (360-391)
  • increase_descriptor_offset (394-416)
src/layout/layout.cc (1)
src/layout/gemm_layouts.cc (10)
  • makeGemmABLayout (720-739)
  • makeGemmABLayout (720-721)
  • makeGemmABLayoutHopper (741-765)
  • makeGemmABLayoutHopper (741-742)
  • makeFullBankSwizzleLayout (421-437)
  • makeFullBankSwizzleLayout (421-421)
  • makeHalfBankSwizzleLayout (402-418)
  • makeHalfBankSwizzleLayout (402-402)
  • makeQuarterBankSwizzleLayout (382-399)
  • makeQuarterBankSwizzleLayout (382-383)
src/op/builtin.h (2)
tilelang/language/tir/op.py (2)
  • ptx_wgmma_ss (1064-1103)
  • ptx_wgmma_rs (1106-1143)
tilelang/language/builtin.py (2)
  • initialize_descriptor (360-391)
  • increase_descriptor_offset (394-416)
src/target/codegen_cuda.cc (4)
src/transform/storage_rewrite.cc (6)
  • scope (674-678)
  • scope (674-674)
  • n (366-370)
  • n (366-366)
  • n (371-375)
  • n (371-371)
tilelang/language/tir/op.py (2)
  • ptx_wgmma_ss (1064-1103)
  • ptx_wgmma_rs (1106-1143)
src/target/ptx.cc (8)
  • PrintWGMMAAssembly (1186-1257)
  • PrintWGMMAAssembly (1186-1195)
  • ParseMMAShape (140-148)
  • ParseMMAShape (140-140)
  • DTypeEnumToString (107-109)
  • DTypeEnumToString (107-107)
  • DTypeEnumToString (111-113)
  • DTypeEnumToString (111-111)
tilelang/language/builtin.py (2)
  • initialize_descriptor (360-391)
  • increase_descriptor_offset (394-416)
tilelang/language/__init__.py (1)
tilelang/language/allocate.py (1)
  • alloc_descriptor (158-164)
🪛 Ruff (0.13.3)
tilelang/language/builtin.py

380-380: Avoid specifying long messages outside the exception class

(TRY003)


382-382: Parenthesize a and b expressions when chaining and and or together, to make the precedence clear

Parenthesize the and subexpression

(RUF021)


383-383: Avoid specifying long messages outside the exception class

(TRY003)


406-406: Avoid specifying long messages outside the exception class

(TRY003)


408-408: Parenthesize a and b expressions when chaining and and or together, to make the precedence clear

Parenthesize the and subexpression

(RUF021)


409-409: Avoid specifying long messages outside the exception class

(TRY003)

tilelang/language/__init__.py

47-47: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)

🔇 Additional comments (14)
src/op/gemm_py.cc (3)

108-110: LGTM: Compiler appeasement for unreachable path.

The unreachable return statement after ICHECK(0) is a standard pattern to satisfy compiler requirements about control flow paths returning a value.


283-290: LGTM: New public FFI binding for GetGemmInst.

The new tl.GemmPyGemmInst FFI binding correctly exposes the GetGemmInst method, allowing runtime instruction selection logic from Python.


230-232: FFI signature verified: all lower methods now include layout_map as the first argument.

src/op/builtin.h (2)

219-239: LGTM: New WGMMA intrinsic declarations.

The new ptx_wgmma_ss and ptx_wgmma_rs declarations are properly documented with detailed parameter descriptions and correctly exported via TVM_DLL.


423-439: LGTM: New descriptor management intrinsics.

The initialize_descriptor and increase_descriptor_offset declarations are well-documented and correctly positioned after tl_shuffle_elect as noted in the enriched summary.

src/target/codegen_cuda.cc (4)

1071-1074: LGTM: Descriptor scope handling.

Treating "local.descriptor" scope the same as "local.var" is correct for descriptor buffer references.


1593-1628: LGTM: Correct WGMMA-RS emission.

The RS path correctly uses the validated asm_code from PrintWGMMAAssembly and writes it directly to the stream, consistent with the past review guidance.


1953-1973: LGTM: Descriptor operation emission.

The initialize_descriptor and increase_descriptor_offset handlers correctly emit template instantiations with the provided parameters.


2116-2117: LGTM: Descriptor allocation.

Declaring tl::GmmaDescriptor vid; for "local.descriptor" scope is correct for WGMMA descriptor buffers.

src/layout/gemm_layouts.cc (2)

578-587: LGTM: Boolean k_inner parameter simplifies API.

The change from int kfactor to bool k_inner makes the API more intuitive. The condition if (k_inner) replacing if (kfactor == 2) is clearer about intent.


720-739: LGTM: Well-documented signature change with clear semantics.

The function documentation has been updated to explain the boolean k_inner parameter's meaning for different element sizes. The implementation correctly uses !k_inner and k_inner for conditional paths.

src/layout/layout.cc (3)

487-491: LGTM: New equality check for Layout objects.

The tl.Layout_is_equal FFI binding correctly exposes the IsEqual method for Python-side layout comparison.


500-504: LGTM: New equality check for Fragment objects.

The tl.Fragment_is_equal FFI binding correctly exposes the IsEqual method for Python-side fragment comparison.


521-550: LGTM: Enhanced swizzled layout construction APIs.

The updated make_swizzled_layout with k_inner and allow_pad parameters, along with the new specialized constructors (make_wgmma_swizzled_layout, make_full_bank_swizzled_layout, etc.), provide a cleaner public API surface with better dispatch logic.

Comment on lines +303 to +312
template <int layout_type = 0, int leading_byte_offset = 0,
int stride_byte_offset = 0, typename T>
TL_DEVICE void initialize_descriptor(GmmaDescriptor &descriptor,
T *start_address) {
descriptor.bitfield.start_address_ =
cute::cast_smem_ptr_to_uint(start_address) >> 4;
descriptor.bitfield.layout_type_ = layout_type;
descriptor.bitfield.base_offset_ = 0;
descriptor.bitfield.leading_byte_offset_ = leading_byte_offset;
descriptor.bitfield.stride_byte_offset_ = stride_byte_offset;

⚠️ Potential issue | 🔴 Critical

Descriptor fields must be encoded in 16-byte quanta

start_address_ is encoded as addr >> 4, but leading_byte_offset_ and stride_byte_offset_ are written with raw byte counts. The PTX descriptor spec strips the lower 4 bits for all three fields, so encoding plain bytes produces values 16× too large, breaking every descriptor consumed by WGMMA/UTCMMA. Shift the offsets by 4 (and consider asserting 16-byte alignment) before storing them.

-  descriptor.bitfield.leading_byte_offset_ = leading_byte_offset;
-  descriptor.bitfield.stride_byte_offset_ = stride_byte_offset;
+  descriptor.bitfield.leading_byte_offset_ = leading_byte_offset >> 4;
+  descriptor.bitfield.stride_byte_offset_ = stride_byte_offset >> 4;
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
template <int layout_type = 0, int leading_byte_offset = 0,
int stride_byte_offset = 0, typename T>
TL_DEVICE void initialize_descriptor(GmmaDescriptor &descriptor,
T *start_address) {
descriptor.bitfield.start_address_ =
cute::cast_smem_ptr_to_uint(start_address) >> 4;
descriptor.bitfield.layout_type_ = layout_type;
descriptor.bitfield.base_offset_ = 0;
descriptor.bitfield.leading_byte_offset_ = leading_byte_offset;
descriptor.bitfield.stride_byte_offset_ = stride_byte_offset;
template <int layout_type = 0, int leading_byte_offset = 0,
int stride_byte_offset = 0, typename T>
TL_DEVICE void initialize_descriptor(GmmaDescriptor &descriptor,
T *start_address) {
descriptor.bitfield.start_address_ =
cute::cast_smem_ptr_to_uint(start_address) >> 4;
descriptor.bitfield.layout_type_ = layout_type;
descriptor.bitfield.base_offset_ = 0;
descriptor.bitfield.leading_byte_offset_ = leading_byte_offset >> 4;
descriptor.bitfield.stride_byte_offset_ = stride_byte_offset >> 4;
🤖 Prompt for AI Agents
In src/tl_templates/cuda/common.h around lines 303 to 312, the
initialize_descriptor writes leading_byte_offset_ and stride_byte_offset_ as raw
bytes while the PTX descriptor expects 16-byte quanta like start_address_ (addr
>> 4); change the writes to store (leading_byte_offset >> 4) and
(stride_byte_offset >> 4) and (optionally) add an assertion or static_assert
that both offsets and start_address are 16-byte aligned before shifting to catch
misalignment.
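
A sketch combining the shift fix with the alignment guards requested above (GmmaDescriptor, TL_DEVICE, and cute::cast_smem_ptr_to_uint are taken from the snippet under review; treat the exact checks as illustrative):

template <int layout_type = 0, int leading_byte_offset = 0,
          int stride_byte_offset = 0, typename T>
TL_DEVICE void initialize_descriptor(GmmaDescriptor &descriptor,
                                     T *start_address) {
  static_assert(layout_type >= 0 && layout_type <= 3,
                "layout_type must be in [0, 3]");
  static_assert((leading_byte_offset & 0xF) == 0,
                "leading_byte_offset must be a multiple of 16 bytes");
  static_assert((stride_byte_offset & 0xF) == 0,
                "stride_byte_offset must be a multiple of 16 bytes");
  uint32_t addr = cute::cast_smem_ptr_to_uint(start_address);
  descriptor.bitfield.start_address_ = addr >> 4;
  descriptor.bitfield.layout_type_ = layout_type;
  descriptor.bitfield.base_offset_ = 0;
  // Encode both offsets in 16-byte quanta, matching start_address_.
  descriptor.bitfield.leading_byte_offset_ = leading_byte_offset >> 4;
  descriptor.bitfield.stride_byte_offset_ = stride_byte_offset >> 4;
}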

…r improved clarity and consistency. Update includes necessary changes in error messages for Hopper and Sm100 layouts. Additionally, include a new header for CUTE utilities in common.h.
@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
src/layout/gemm_layouts.cc (2)

741-765: Fix boolean k_inner arguments in gemm_sp call sites
In src/op/gemm_sp.cc at lines 240 and 254, calls to makeGemmABLayoutHopper still pass trans_A ? 1 : 2 and trans_B ? 2 : 1 (both convert to true). Update them to pass boolean flags that preserve the old kfactor semantics: replace trans_A ? 1 : 2 with !trans_A and trans_B ? 2 : 1 with trans_B.


767-787: Convert k_inner arguments to bool in gemm.cc

  • src/op/gemm.cc:783-785: replace trans_A ? 1 : 2 with !trans_A so k_inner is true when A is not transposed
  • src/op/gemm.cc:793-795: replace trans_B ? 2 : 1 with trans_B so k_inner matches original intent
- makeGemmABLayoutSm100(mat_stride, mat_continuous, mat_continuous, A->dtype.bits(), trans_A ? 1 : 2);
+ makeGemmABLayoutSm100(mat_stride, mat_continuous, mat_continuous, A->dtype.bits(), !trans_A);

- makeGemmABLayoutSm100(mat_stride, mat_continuous, continuity, B->dtype.bits(), trans_B ? 2 : 1);
+ makeGemmABLayoutSm100(mat_stride, mat_continuous, continuity, B->dtype.bits(), trans_B);
♻️ Duplicate comments (3)
src/tl_templates/cuda/common.h (3)

304-320: Validate descriptor parameters before writing hardware fields.

These functions write descriptor bitfields with no checks. Previous reviews requested assertions for layout_type range (0-3), 16-byte alignment of addresses/offsets, and overflow guards in increase_descriptor_offset.

Consider adding:

  • Static or runtime assertions for layout_type ∈ [0, 3]
  • Alignment checks: (uintptr_t)start_address & 0xF == 0
  • Verification that offsets are multiples of 16 and fit bitfield widths
  • Overflow guard when adding offset >> 4 to descriptor.reg32_[0]

260-265: Critical encoding issue persists: operator+ must shift offset by 4.

The descriptor encodes addresses/offsets in 16-byte quanta (4 LSBs stripped), but operator+ at line 262 adds raw bytes without shifting. This was flagged in previous reviews and remains unaddressed.

Apply this diff:

-    ret.reg32_[0] = reg32_[0] + uint32_t(offset);
+    ret.reg32_[0] = reg32_[0] + uint32_t(offset >> 4);

Note the inconsistency: increase_descriptor_offset at line 319 correctly shifts >> 4, but this operator does not.


304-314: Critical encoding issue persists: offset parameters must be shifted by 4.

Lines 312-313 write leading_byte_offset and stride_byte_offset as raw values, but the PTX descriptor expects 16-byte quanta like start_address_ (shifted >> 4 at line 309). This was flagged in previous reviews and remains unaddressed.

Apply this diff:

-  descriptor.bitfield.leading_byte_offset_ = leading_byte_offset;
-  descriptor.bitfield.stride_byte_offset_ = stride_byte_offset;
+  descriptor.bitfield.leading_byte_offset_ = leading_byte_offset >> 4;
+  descriptor.bitfield.stride_byte_offset_ = stride_byte_offset >> 4;
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 70699a9 and 2dbaccc.

📒 Files selected for processing (2)
  • src/layout/gemm_layouts.cc (6 hunks)
  • src/tl_templates/cuda/common.h (5 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
src/tl_templates/cuda/common.h (3)
src/target/ptx.h (1)
  • DataType (45-91)
src/tl_templates/cuda/ldsm.h (12)
  • void (7-14)
  • void (16-23)
  • void (25-33)
  • void (35-42)
  • void (44-52)
  • void (54-62)
  • void (64-70)
  • void (72-79)
  • void (81-89)
  • void (91-98)
  • void (100-108)
  • void (110-119)
tilelang/language/builtin.py (2)
  • initialize_descriptor (360-391)
  • increase_descriptor_offset (394-416)
🔇 Additional comments (5)
src/tl_templates/cuda/common.h (1)

8-8: LGTM!

The cute library includes and using declaration are appropriate for the new descriptor functionality.

Also applies to: 17-17

src/layout/gemm_layouts.cc (4)

181-181: Minor cleanup.

The blank line addition appears intentional. No concerns.


690-719: LGTM! Well-documented parameter change.

The documentation clearly explains the new k_inner boolean parameter and its impact on layout selection for different data types. The examples for fp64 and int8 are helpful.


720-739: Approve: all callers updated to bool k_inner
Verified every call to makeGemmABLayout (layout.cc, gemm_layouts.cc, op/gemm.cc) now passes a boolean for k_inner; no further changes needed.


578-587: All callers updated to use boolean arguments
Verified in src/op/gemm.cc (lines 671, 684) that both is_a and k_inner parameters are passed as booleans; no int arguments remain.

Comment on lines +181 to +205
enum class DataType : int {
kInt4 = 0,
kUInt4 = 1,
kInt8 = 2,
kUInt8 = 3,
kInt16 = 4,
kUInt16 = 5,
kInt32 = 6,
kUInt32 = 7,
kInt64 = 8,
kUInt64 = 9,
kFloat8_e4m3 = 10,
kFloat8_e5m2 = 11,
kFloat16 = 12,
kBFloat16 = 13,
kFloat16x2 = 14,
kFloat32 = 15,
kTensorFloat32 = 16,
kFloat64 = 17,
kBit1 = 18,
kBit8 = 19,
kBit16 = 20,
kBit32 = 21,
kBit64 = 22
};

🛠️ Refactor suggestion | 🟠 Major

Eliminate DataType duplication.

The DataType enum is duplicated in src/target/ptx.h (lines 44-90). Maintain a single definition in ptx.h and reference it here via include/using to avoid maintenance drift.

🤖 Prompt for AI Agents
In src/tl_templates/cuda/common.h around lines 181 to 205, remove the duplicated
DataType enum block and instead include the single canonical definition from
src/target/ptx.h; add an #include for "src/target/ptx.h" at the top of this file
(or the appropriate relative include path) and either reference the enum as the
qualified name from ptx.h or add a single using declaration (e.g., using
::DataType; or using target::DataType;) so this file reuses the ptx.h definition
rather than redefining it.
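
A minimal sketch of the deduplication (the include path and namespace are assumptions; match whatever ptx.h actually declares):

// common.h: reuse the canonical enum instead of redefining it.
#include "target/ptx.h"

using DataType = tl::codegen::ptx::DataType;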

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
src/op/gemm.cc (1)

145-154: Bounds-check optional args to avoid out-of-bounds access.

args[16], [17], [18] are accessed unconditionally; if not present, UB/crash.

-  node->mbarptr = args[16];
-  if (node->mbarptr.as<CallNode>()) {
-    node->mbar = vmap[GetVarFromAccessPtr(node->mbarptr)];
-  } else {
-    node->mbar = std::nullopt;
-  }
-  node->C_coords = Array<PrimExpr>(
-      {args[17].as<PrimExpr>().value(), args[18].as<PrimExpr>().value()});
+  if (args.size() > 16) {
+    node->mbarptr = args[16];
+    if (node->mbarptr.as<CallNode>()) {
+      node->mbar = vmap[GetVarFromAccessPtr(node->mbarptr)];
+    } else {
+      node->mbar = std::nullopt;
+    }
+  } else {
+    node->mbarptr = PrimExpr();
+    node->mbar = std::nullopt;
+  }
+  ICHECK(args.size() > 18)
+      << "Gemm expects C_coords provided as args[17], args[18]";
+  node->C_coords = Array<PrimExpr>({args[17].as<PrimExpr>().value(),
+                                    args[18].as<PrimExpr>().value()});
♻️ Duplicate comments (6)
tilelang/intrinsics/wgmma_macro_generator.py (5)

139-149: Remove debug prints.

These pollute logs; drop them or gate via a logger.

-        print(f"_determinate_swizzle_mode mat_stride: {mat_stride}, mat_continuous: {mat_continuous}, element_size: {element_size}")
@@
-        print(f"a_leading_byte_offset: {a_leading_byte_offset >> 4}")
-        print(f"a_stride_byte_offset: {a_stride_byte_offset >> 4}")
-        print(f"b_leading_byte_offset: {b_leading_byte_offset >> 4}")
-        print(f"b_stride_byte_offset: {b_stride_byte_offset >> 4}")
-
-        print(f"b_swizzle_atom_size: {b_swizzle_mode.swizzle_atom_size()}")
-        print(f"b_swizzle_byte_size: {b_swizzle_mode.swizzle_byte_size()}")

Also applies to: 241-248


176-178: Add k-dimension divisibility asserts to prevent silent tails.

Ensure k_dim % micro_size_k == 0 in SS and RS.

-        assert k_dim >= micro_size_k, f"k_dim must be greater than or equal to {micro_size_k}, got k_dim: {k_dim}"
+        assert k_dim >= micro_size_k, f"k_dim must be >= {micro_size_k}, got {k_dim}"
+        assert (k_dim % micro_size_k) == 0, f"k_dim must be divisible by micro_size_k={micro_size_k}, got {k_dim}"
@@
-        assert k_dim >= micro_size_k, f"k_dim must be greater than or equal to {micro_size_k}, got k_dim: {k_dim}"
+        assert k_dim >= micro_size_k, f"k_dim must be >= {micro_size_k}, got {k_dim}"
+        assert (k_dim % micro_size_k) == 0, f"k_dim must be divisible by micro_size_k={micro_size_k}, got {k_dim}"

Also applies to: 295-296


184-217: Descriptor math uses wrong element size and pointer mode; fix per-operand bytes and offsets.

Use separate a/b byte sizes, correct LBO/SBO units, and read-only access_ptr; scale A/B offsets with respective sizes.

-        elems_in_bits = DataType(self.a_dtype).bits
-        elems_in_bytes = elems_in_bits // 8
-        
-        a_swizzle_atom_elems = a_swizzle_mode.swizzle_byte_size() // elems_in_bytes
-        b_swizzle_atom_elems = n_dim if b_swizzle_mode.is_none() else b_swizzle_mode.swizzle_byte_size() // elems_in_bytes
+        a_elems_in_bytes = DataType(self.a_dtype).bits // 8
+        b_elems_in_bytes = DataType(self.b_dtype).bits // 8
+
+        a_swizzle_atom_elems = a_swizzle_mode.swizzle_byte_size() // a_elems_in_bytes
+        b_swizzle_atom_elems = (
+            n_dim if b_swizzle_mode.is_none()
+            else b_swizzle_mode.swizzle_byte_size() // b_elems_in_bytes
+        )
@@
-        a_leading_byte_offset = (8 * 8 * elems_in_bytes) if a_is_k_major else (8 * m_dim *
-                                                                               elems_in_bytes)
-        a_stride_byte_offset = (8 * k_dim * elems_in_bytes) if a_is_k_major else (8 * 8 *
-                                                                                  elems_in_bytes)
+        a_leading_byte_offset = (8 * 8 * a_elems_in_bytes) if a_is_k_major else (8 * m_dim * a_elems_in_bytes)
+        a_stride_byte_offset  = (8 * k_dim * a_elems_in_bytes) if a_is_k_major else (8 * 8 * a_elems_in_bytes)
@@
-            if a_is_k_major:
-                a_leading_byte_offset = 16
-                a_stride_byte_offset = 8 * a_swizzle_mode.swizzle_byte_size()
+            if a_is_k_major:
+                a_leading_byte_offset = 16
+                a_stride_byte_offset = 8 * a_swizzle_mode.swizzle_byte_size()
             else:
@@
-                if a_m_axis_atoms <= 1:
-                    a_leading_byte_offset = 0
-                else:
-                    a_leading_byte_offset = 8 * a_swizzle_mode.swizzle_atom_size() * (a_swizzle_mode.swizzle_byte_size() // elems_in_bytes)
+                if a_m_axis_atoms <= 1:
+                    a_leading_byte_offset = 0
+                else:
+                    a_leading_byte_offset = a_swizzle_mode.swizzle_byte_size()
@@
-                if a_m_axis_atoms <= 1:
-                    a_stride_byte_offset = 8 * elems_in_bytes * m_dim
-                else:
-                    a_stride_byte_offset = 8 * elems_in_bytes * a_swizzle_atom_elems
+                if a_m_axis_atoms <= 1:
+                    a_stride_byte_offset = 8 * a_elems_in_bytes * m_dim
+                else:
+                    a_stride_byte_offset = 8 * a_elems_in_bytes * a_swizzle_atom_elems
@@
-        b_leading_byte_offset = (8 * 8 * elems_in_bytes) if b_is_k_major else (8 * n_dim *
-                                                                               elems_in_bytes)
-        b_stride_byte_offset = (8 * k_dim * elems_in_bytes) if b_is_k_major else (
-            0 if n_dim == 8 else (8 * 8 * elems_in_bytes)
+        b_leading_byte_offset = (8 * 8 * b_elems_in_bytes) if b_is_k_major else (8 * n_dim * b_elems_in_bytes)
+        b_stride_byte_offset  = (8 * k_dim * b_elems_in_bytes) if b_is_k_major else (
+            0 if n_dim == 8 else (8 * 8 * b_elems_in_bytes)
         )
@@
-            if b_is_k_major:
-                b_leading_byte_offset = 16
-                b_stride_byte_offset = 8 * b_swizzle_mode.swizzle_byte_size()
+            if b_is_k_major:
+                b_leading_byte_offset = 16
+                b_stride_byte_offset  = 8 * b_swizzle_mode.swizzle_byte_size()
             else:
@@
-                if b_n_axis_atoms <= 1:
-                    b_leading_byte_offset = 0
-                else:
-                    b_leading_byte_offset = 8 * 8 * elems_in_bytes * k_dim
+                if b_n_axis_atoms <= 1:
+                    b_leading_byte_offset = 0
+                else:
+                    b_leading_byte_offset = b_swizzle_mode.swizzle_byte_size()
                 if b_n_axis_atoms <= 1:
-                    b_stride_byte_offset = 8 * elems_in_bytes * n_dim
+                    b_stride_byte_offset = 8 * b_elems_in_bytes * n_dim
                 else:
-                    b_stride_byte_offset = 8 * elems_in_bytes * b_swizzle_atom_elems
+                    b_stride_byte_offset = 8 * b_elems_in_bytes * b_swizzle_atom_elems
@@
-            desc_a = T.alloc_descriptor()
-            desc_b = T.alloc_descriptor()
-            T.initialize_descriptor(desc_a, A_buf.access_ptr("w"), a_swizzle_mode,
+            desc_a = T.alloc_descriptor()
+            desc_b = T.alloc_descriptor()
+            T.initialize_descriptor(desc_a, A_buf.access_ptr("r"), a_swizzle_mode,
                                     int(a_leading_byte_offset >> 4), int(a_stride_byte_offset >> 4))
-            T.initialize_descriptor(desc_b, B_buf.access_ptr("w"), b_swizzle_mode,
+            T.initialize_descriptor(desc_b, B_buf.access_ptr("r"), b_swizzle_mode,
                                     int(b_leading_byte_offset >> 4), int(b_stride_byte_offset >> 4))
@@
-                    T.ptx_wgmma_ss(accum_dtype, wgmma_prefix, a_is_k_major,
+                    T.ptx_wgmma_ss(accum_dtype, wgmma_prefix, a_is_k_major,
                                    b_is_k_major, a_dtype_abbrv, b_dtype_abbrv,
-                                   accum_dtype_abbrv, desc_a.data, (A_offset * elems_in_bytes) >> 4,
-                                   desc_b.data, (B_offset * elems_in_bytes) >> 4, C_local_buf.data,
+                                   accum_dtype_abbrv, desc_a.data, (A_offset * a_elems_in_bytes) >> 4,
+                                   desc_b.data, (B_offset * b_elems_in_bytes) >> 4, C_local_buf.data,
                                    C_offset, scale_out, scale_in_a, scale_in_b)

Also applies to: 217-241, 258-261, 267-271


303-307: RS path: mirror per-operand sizes, use read-only descriptors, and fix K-major flags.

Correct B descriptor math/units, pointer mode, and flags; scale B offset with B’s element size.

-        elems_in_bytes = DataType(self.a_dtype).bits // 8
+        a_elems_in_bytes = DataType(self.a_dtype).bits // 8
+        b_elems_in_bytes = DataType(self.b_dtype).bits // 8
@@
-        b_leading_byte_offset = (8 * 8 * elems_in_bytes) if b_is_k_major else (8 * n_dim *
-                                                                               elems_in_bytes)
-        b_stride_byte_offset = (8 * k_dim * elems_in_bytes) if b_is_k_major else (8 * 8 *
-                                                                                  elems_in_bytes)
+        b_leading_byte_offset = (8 * 8 * b_elems_in_bytes) if b_is_k_major else (8 * n_dim * b_elems_in_bytes)
+        b_stride_byte_offset  = (8 * k_dim * b_elems_in_bytes) if b_is_k_major else (8 * 8 * b_elems_in_bytes)
@@
-                b_n_axis_atoms = n_dim // (b_swizzle_mode.swizzle_byte_size() // elems_in_bytes)
+                b_n_axis_atoms = n_dim // (b_swizzle_mode.swizzle_byte_size() // b_elems_in_bytes)
@@
-                    b_leading_byte_offset = 8 * b_swizzle_mode.swizzle_atom_size() * (b_swizzle_mode.swizzle_byte_size() // elems_in_bytes)
+                    b_leading_byte_offset = b_swizzle_mode.swizzle_byte_size()
@@
-                    b_stride_byte_offset = 8 * elems_in_bytes * n_dim
+                    b_stride_byte_offset = 8 * b_elems_in_bytes * n_dim
                 else:
-                    b_stride_byte_offset = 8 * elems_in_bytes * (b_swizzle_mode.swizzle_byte_size() // elems_in_bytes)
+                    b_stride_byte_offset = 8 * b_elems_in_bytes * (b_swizzle_mode.swizzle_byte_size() // b_elems_in_bytes)
@@
-            T.initialize_descriptor(desc_b, B_buf.access_ptr("w"), b_swizzle_mode,
+            T.initialize_descriptor(desc_b, B_buf.access_ptr("r"), b_swizzle_mode,
                                     int(b_leading_byte_offset >> 4), int(b_stride_byte_offset >> 4))
@@
-                        self.a_transposed,
-                        not self.b_transposed,
+                        (not self.a_transposed),
+                        self.b_transposed,
@@
-                        (B_offset * elems_in_bytes) >> 4,
+                        (B_offset * b_elems_in_bytes) >> 4,

Also applies to: 316-326, 330-352, 341-343, 349-349


84-85: Default warp_row_tiles conflicts with your assertion; set to 16.

Avoid constructor-time assertion failures.

-        warp_row_tiles: int = 8,
+        warp_row_tiles: int = 16,
src/op/gemm.cc (1)

112-115: kPack doc vs. validation mismatch; enforce a single rule (1 only).

Doc says “kPack must be 1” but runtime allows 1 or 2. Align with doc/PR intent by enforcing 1.

-  if (args.size() > 14) {
-    node->kPack = args[14].as<IntImm>().value()->value;
-    if (node->kPack != 1 && node->kPack != 2) {
-      ICHECK(false) << "kPack must be 1 or 2";
-    }
-  }
+  if (args.size() > 14) {
+    node->kPack = args[14].as<IntImm>().value()->value;
+    if (node->kPack != 1) {
+      ICHECK(false) << "kPack must be 1";
+    }
+  }

Also applies to: 136-141

🧹 Nitpick comments (10)
src/op/gemm.cc (3)

103-111: Update constructor arg docs to include mbarptr and C_coords.

Docs omit mbarptr (descriptor) and C_coords now required later.

- *     [Aptr, Bptr, Cptr, trans_A (Bool), trans_B (Bool),
+ *     [Aptr, Bptr, Cptr, trans_A (Bool), trans_B (Bool),
  *      M (Int), N (Int), K (Int), policy (Int), clear_accum (Bool),
  *      stride_A (Int), stride_B (Int), offset_A (Int), offset_B (Int),
- *      (optional) kPack (Int), (optional) wg_wait (Int)]
+ *      (optional) kPack (Int), (optional) wg_wait (Int),
+ *      (optional) mbarptr (Handle), C_i (PrimExpr), C_j (PrimExpr)]

481-492: Harden arch parsing to avoid std::stoi exceptions.

Guard empty/malformed “sm_” strings.

-  if (arch.rfind("sm_", 0) == 0) {
-    arch_int = std::stoi(arch.substr(3));
-  } else {
+  if (arch.rfind("sm_", 0) == 0) {
+    std::string num = arch.substr(3);
+    if (!num.empty() && std::isdigit(num[0])) {
+      arch_int = std::stoi(num);
+    } else {
+      arch_int = 0;
+    }
+  } else {
     arch_int = 0;
   }

760-761: Remove noisy LOG(INFO) in hot path (use VLOG or drop).

Unconditional INFO logs during layout inference are noisy.

-    LOG(INFO) << "gemm_inst: " << (int)gemm_inst << ", trans_B: " << trans_B;
+    // VLOG(1) << "gemm_inst: " << (int)gemm_inst << ", trans_B: " << trans_B;
tilelang/layout/__init__.py (1)

6-13: Drop unnecessary # noqa: F401 comments (RUF100).

Ruff flags these as unused. Remove or switch to explicit all exports.

-from .swizzle import (
-    make_swizzled_layout,  # noqa: F401
-    make_wgmma_swizzled_layout,  # noqa: F401
-    make_full_bank_swizzled_layout,  # noqa: F401
-    make_half_bank_swizzled_layout,  # noqa: F401
-    make_quarter_bank_swizzled_layout,  # noqa: F401
-    make_linear_layout,  # noqa: F401
-)
+from .swizzle import (
+    make_swizzled_layout,
+    make_wgmma_swizzled_layout,
+    make_full_bank_swizzled_layout,
+    make_half_bank_swizzled_layout,
+    make_quarter_bank_swizzled_layout,
+    make_linear_layout,
+)
tilelang/layout/swizzle.py (1)

10-18: Type Optional properly and remove debug prints; trim exception messages.

  • Use Optional[int] for continuity.
  • Remove print in make_wgmma_swizzled_layout.
  • Shorten ValueError messages.
+from typing import Optional
@@
-def make_wgmma_swizzled_layout(buffer: tvm.tir.Buffer,
-                               continuity: int = None,
+def make_wgmma_swizzled_layout(buffer: tvm.tir.Buffer,
+                               continuity: Optional[int] = None,
                                k_major: bool = True):
@@
-    print(f"make_wgmma_swizzled_layout: {buffer.shape[0]}, {buffer.shape[1]}, {continuity}, {tvm.DataType(buffer.dtype).bits}, {k_major}")
     return _ffi_api.make_wgmma_swizzled_layout(
@@
-    else:
-        raise ValueError(f"Invalid arguments: {args}")
+    else:
+        raise ValueError("Invalid arguments")
@@
-    else:
-        raise ValueError(f"Invalid arguments: {args}")
+    else:
+        raise ValueError("Invalid arguments")
@@
-    else:
-        raise ValueError(f"Invalid arguments: {args}")
+    else:
+        raise ValueError("Invalid arguments")
@@
-    else:
-        raise ValueError(f"Invalid arguments: {args}")
+    else:
+        raise ValueError("Invalid arguments")

Also applies to: 22-35, 48-60, 73-86, 98-110, 120-130

tilelang/tileop/gemm/gemm_wgmma.py (2)

13-61: GemmWGMMA layout inference looks good; minor nits.

  • Good: k_major/continuity mirror Hopper path.
  • Consider shorter error text.
-        else:
-            raise ValueError(
-                f"Unsupported gemm combination for wgmma, A: {self.A.scope()}, B: {self.B.scope()}")
+        else:
+            raise ValueError("Unsupported WGMMA layout combination")

Also, confirm that from tilelang import tvm as tvm is intentional rather than a plain import tvm.


65-127: Lowering path is clean; minor error-msg nit.

-        raise ValueError(
-            f"Unsupported gemm combination for wgmma, A: {self.A.scope()}, B: {self.B.scope()}")
+        raise ValueError("Unsupported WGMMA lowering combination")
tilelang/intrinsics/wgmma_macro_generator.py (2)

140-145: Collapse trivial condition (SIM114).

Combine NONE/linear checks to satisfy linter.

-        if layout is None:
-            return SwizzleMode.NONE
-        elif layout.is_equal(make_linear_layout(mat_stride, mat_continuous)):
+        if layout is None or layout.is_equal(make_linear_layout(mat_stride, mat_continuous)):
             return SwizzleMode.NONE

105-110: Remove unused method arg n_dim.

Not used; simplify signature.

-    def _initialize_wgmma_prefix(self, n_dim: int = 16):
+    def _initialize_wgmma_prefix(self):
         inst_m, inst_n = 64, self.block_col_warps * self.warp_col_tiles
src/layout/gemm_layouts.cc (1)

388-388: Remove or gate debug LOG statement.

This LOG(INFO) in makeQuarterBankSwizzleLayout will generate output on every call, potentially flooding logs in production. Debug logging in hot paths should either be removed before release or gated behind a debug flag.

Consider one of these approaches:

-  LOG(INFO) << "makeQuarterBankSwizzleLayout: " << stride << ", " << continuous << ", " << element_size;
+  DLOG(INFO) << "makeQuarterBankSwizzleLayout: " << stride << ", " << continuous << ", " << element_size;

Or remove entirely if not needed for debugging.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 2dbaccc and d2db013.

📒 Files selected for processing (7)
  • src/layout/gemm_layouts.cc (7 hunks)
  • src/layout/layout.cc (3 hunks)
  • src/op/gemm.cc (7 hunks)
  • tilelang/intrinsics/wgmma_macro_generator.py (1 hunks)
  • tilelang/layout/__init__.py (1 hunks)
  • tilelang/layout/swizzle.py (1 hunks)
  • tilelang/tileop/gemm/gemm_wgmma.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (6)
tilelang/layout/__init__.py (1)
tilelang/layout/swizzle.py (6)
  • make_swizzled_layout (10-18)
  • make_wgmma_swizzled_layout (22-35)
  • make_full_bank_swizzled_layout (40-60)
  • make_half_bank_swizzled_layout (65-85)
  • make_quarter_bank_swizzled_layout (90-110)
  • make_linear_layout (112-130)
tilelang/layout/swizzle.py (1)
tilelang/language/ast/ir.py (1)
  • buffer (93-161)
src/op/gemm.cc (2)
tilelang/tileop/gemm/gemm_base.py (4)
  • trans_A (46-47)
  • trans_B (50-51)
  • A (67-68)
  • B (71-72)
src/layout/gemm_layouts.cc (4)
  • makeGemmABLayout (721-740)
  • makeGemmABLayout (721-722)
  • makeGemmABLayoutHopper (742-773)
  • makeGemmABLayoutHopper (742-743)
src/layout/layout.cc (1)
src/layout/gemm_layouts.cc (12)
  • makeGemmABLayout (721-740)
  • makeGemmABLayout (721-722)
  • makeGemmABLayoutHopper (742-773)
  • makeGemmABLayoutHopper (742-743)
  • makeFullBankSwizzleLayout (422-438)
  • makeFullBankSwizzleLayout (422-422)
  • makeHalfBankSwizzleLayout (403-419)
  • makeHalfBankSwizzleLayout (403-403)
  • makeQuarterBankSwizzleLayout (382-400)
  • makeQuarterBankSwizzleLayout (382-383)
  • makeGemmLayoutLinear (492-496)
  • makeGemmLayoutLinear (492-492)
tilelang/tileop/gemm/gemm_wgmma.py (6)
tilelang/tileop/gemm/gemm_base.py (20)
  • GemmBase (12-120)
  • infer_layout (15-16)
  • policy (119-120)
  • M (34-35)
  • N (38-39)
  • in_dtype (54-56)
  • accum_dtype (59-60)
  • trans_A (46-47)
  • trans_B (50-51)
  • chunk (63-64)
  • is_gemm_ss (21-22)
  • K (42-43)
  • A (67-68)
  • B (71-72)
  • C (75-76)
  • is_gemm_rs (27-28)
  • lower (18-19)
  • clear_accum (107-108)
  • is_gemm_sr (24-25)
  • is_gemm_rr (30-31)
tilelang/layout/swizzle.py (1)
  • make_wgmma_swizzled_layout (22-35)
tilelang/intrinsics/wgmma_macro_generator.py (6)
  • TensorCoreIntrinEmitter (64-525)
  • make_mma_store_layout (471-525)
  • make_mma_load_layout (359-469)
  • _assign_a_shared_layout (97-99)
  • _assign_b_shared_layout (101-103)
  • wgmma (153-273)
tilelang/utils/language.py (2)
  • is_shared (25-39)
  • is_fragment (68-78)
tilelang/transform/simplify.py (1)
  • _Simplify (30-49)
tilelang/tileop/gemm/gemm_mma.py (7)
  • infer_layout (15-58)
  • is_gemm_ss (204-205)
  • is_gemm_rs (210-211)
  • lower (60-202)
  • _gemm_ssr (95-120)
  • is_gemm_sr (207-208)
  • is_gemm_rr (213-214)
tilelang/intrinsics/wgmma_macro_generator.py (5)
tilelang/utils/language.py (1)
  • is_fragment (68-78)
tilelang/layout/swizzle.py (4)
  • make_full_bank_swizzled_layout (40-60)
  • make_half_bank_swizzled_layout (65-85)
  • make_quarter_bank_swizzled_layout (90-110)
  • make_linear_layout (112-130)
tilelang/language/builtin.py (1)
  • initialize_descriptor (360-391)
tilelang/language/tir/op.py (2)
  • ptx_wgmma_ss (1064-1103)
  • ptx_wgmma_rs (1106-1143)
tilelang/intrinsics/mma_macro_generator.py (1)
  • get_store_index_map (159-165)
🪛 GitHub Actions: CI Test on AMD
tilelang/intrinsics/wgmma_macro_generator.py

[error] 142-145: SIM114 Combine if branches using logical or operator

🪛 GitHub Actions: CI Test on Metal
tilelang/intrinsics/wgmma_macro_generator.py

[error] 142-142: SIM114 Combine if branches using logical or operator

🪛 Ruff (0.13.3)
tilelang/layout/__init__.py

7-7: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


8-8: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


9-9: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


10-10: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


11-11: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


12-12: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)

tilelang/layout/swizzle.py

23-23: PEP 484 prohibits implicit Optional

Convert to Optional[T]

(RUF013)


55-55: Avoid specifying long messages outside the exception class

(TRY003)


80-80: Avoid specifying long messages outside the exception class

(TRY003)


105-105: Avoid specifying long messages outside the exception class

(TRY003)


126-126: Avoid specifying long messages outside the exception class

(TRY003)

tilelang/tileop/gemm/gemm_wgmma.py

62-63: Avoid specifying long messages outside the exception class

(TRY003)


125-126: Avoid specifying long messages outside the exception class

(TRY003)

tilelang/intrinsics/wgmma_macro_generator.py

105-105: Unused method argument: n_dim

(ARG002)


151-151: Avoid specifying long messages outside the exception class

(TRY003)


401-401: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: format-check
  • GitHub Check: bot-task
🔇 Additional comments (6)
src/layout/gemm_layouts.cc (2)

709-740: LGTM: k_inner refactor is well-documented.

The updated documentation clearly explains the semantics of the k_inner parameter for both fp64 and int8 cases, with inline comments reinforcing the layout selection logic. The implementation correctly uses k_inner to select between KxN and NxK layouts.


580-588: All makeGemmVoltaABLayout call sites now pass a boolean k_inner. Only two calls remain (in src/op/gemm.cc lines 671 and 685), each using an explicit bool argument—no numeric k-factor usages.

src/layout/layout.cc (4)

487-491: LGTM: Layout equality check.

The FFI binding correctly delegates to LayoutNode::IsEqual() for comparing two Layout objects.


500-504: LGTM: Fragment equality check.

The FFI binding correctly delegates to FragmentNode::IsEqual() for comparing two Fragment objects.


532-554: LGTM: New layout constructor bindings.

The new FFI bindings (make_wgmma_swizzled_layout, make_full_bank_swizzled_layout, make_half_bank_swizzled_layout, make_quarter_bank_swizzled_layout, make_linear_layout) correctly expose the corresponding C++ functions with appropriate parameter passing.


521-531: Dispatch logic verified: all existing call sites rely on the default allow_pad=true, preserving previous behavior; explicit allow_pad=false now correctly routes to makeGemmABLayoutHopper.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 3

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
src/op/gemm.cc (2)

145-153: Potential args out-of-bounds access (args[16..18])

mbarptr = args[16] and C_coords from args[17], args[18] are read without size checks. If frontend omits these, this will crash.

-  node->mbarptr = args[16];
+  ICHECK_GT(args.size(), 18) << "Gemm expects mbar pointer and C_coords at indices 16..18";
+  node->mbarptr = args[16];
   if (node->mbarptr.as<CallNode>()) {
     node->mbar = vmap[GetVarFromAccessPtr(node->mbarptr)];
   } else {
     node->mbar = std::nullopt;
   }
-  node->C_coords = Array<PrimExpr>(
-      {args[17].as<PrimExpr>().value(), args[18].as<PrimExpr>().value()});
+  node->C_coords = Array<PrimExpr>({args[17], args[18]});

784-787: Wrong argument type passed to makeGemmABLayoutSm100 (bool API coerced from int)

makeGemmABLayoutSm100(..., bool k_inner) now takes a bool, but calls still pass trans_A ? 1 : 2 and trans_B ? 2 : 1. Any non‑zero coerces to true, breaking K‑inner semantics.

-      results.Set(A, makeGemmABLayoutSm100(mat_stride, mat_continuous,
-                                           mat_continuous, A->dtype.bits(),
-                                           trans_A ? 1 : 2));
+      results.Set(A, makeGemmABLayoutSm100(mat_stride, mat_continuous,
+                                           mat_continuous, A->dtype.bits(),
+                                           !trans_A));
@@
-      results.Set(B,
-                  makeGemmABLayoutSm100(mat_stride, mat_continuous, continuity,
-                                        B->dtype.bits(), trans_B ? 2 : 1));
+      results.Set(B,
+                  makeGemmABLayoutSm100(mat_stride, mat_continuous, continuity,
+                                        B->dtype.bits(), trans_B));

Also applies to: 794-796

♻️ Duplicate comments (15)
src/tl_templates/cuda/instruction/wgmma.h (1)

15-23: Remove debug code before merging.

The fallback template contains debug printf statements and a commented-out static_assert that should be removed or uncommented before merging to production. This was previously flagged.
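
One conventional replacement for the printf fallback, sketched with a dependent-false static_assert (the template parameters follow the wgmma_ss<...> call shape shown later in this review; the exact signature is an assumption):

template <class...> inline constexpr bool wgmma_always_false_v = false;

template <typename AType, typename BType, typename CType, int M, int N, int K,
          bool tnspA, bool tnspB, int scaleA, int scaleB>
TL_DEVICE void wgmma_ss(uint64_t desc_a, uint64_t desc_b, uint32_t *c,
                        bool scale_out) {
  // Fails at compile time only when this fallback is actually instantiated,
  // so supported specializations are unaffected.
  static_assert(wgmma_always_false_v<AType>,
                "wgmma_ss: unsupported dtype/shape combination");
}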

src/op/gemm.cc (1)

112-115: Doc vs. code mismatch for kPack

Note says “kPack must be 1”, but constructor accepts 1 or 2 (Line 139) and errors “kPack must be 1 or 2”. Align doc or code.

- * @note If `kPack` is provided it must be 1; otherwise the constructor
- *       fails with an ICHECK (runtime assertion). No other validation is
- *       performed here.
+ * @note If `kPack` is provided it must be 1 or 2; otherwise the constructor
+ *       fails with an ICHECK (runtime assertion).

Or, if only 1 is intended now:

-    if (node->kPack != 1 && node->kPack != 2) {
-      ICHECK(false) << "kPack must be 1 or 2";
+    if (node->kPack != 1) {
+      ICHECK(false) << "kPack must be 1";
     }
tilelang/language/builtin.py (2)

382-383: Fix operator precedence in validation (and/or)

and binds tighter than or, so descriptor.shape[0] may be accessed even when not a Buffer. Parenthesize the shape checks.

-    if isinstance(descriptor, Buffer) and len(descriptor.shape) != 1 or descriptor.shape[0] != 1:
+    if isinstance(descriptor, Buffer) and (len(descriptor.shape) != 1 or descriptor.shape[0] != 1):

408-409: Fix operator precedence in validation (and/or)

Same issue here—wrap the and side.

-    if isinstance(descriptor, Buffer) and len(descriptor.shape) != 1 or descriptor.shape[0] != 1:
+    if isinstance(descriptor, Buffer) and (len(descriptor.shape) != 1 or descriptor.shape[0] != 1):
src/target/codegen_cuda.cc (1)

1537-1599: WGMMA-SS emission: uses ad-hoc string, ignores validated asm_code, and does invalid descriptor arithmetic

  • asm_code from PrintWGMMAAssembly is computed then ignored.
  • The fallback string does uint64_t((desc_a) + (A_offset)) which is invalid on a descriptor struct.
  • RS path already emits the validated asm_code.

Fix by emitting asm_code and removing the replacer block.

-    const bool a_is_shared = true;
-    this->PrintIndent();
-    std::string asm_code = PrintWGMMAAssembly(
+    const bool a_is_shared = true;
+    this->PrintIndent();
+    std::string asm_code = PrintWGMMAAssembly(
         shape, a_is_k_major, b_is_k_major, A_dtype, B_dtype, C_dtype, a_desc,
         A_offset, b_desc, B_offset, c_ref, c_offset, scale_out, scale_in_a,
         scale_in_b, a_is_shared, "", "", "", false);
-    auto [m, n, k] = tl::codegen::ptx::ParseMMAShape(shape);
-    std::string wgmma_asm_code =
-        "tl::wgmma_ss<(AType), (BType), (CType), (M), (N), (K), (tnspA), "
-        "(tnspB), (scaleA), (scaleB)>(uint64_t((desc_a) + (A_offset)), "
-        "uint64_t((desc_b) + (B_offset)), ((uint32_t*)((C))), (scale_out));\n";
-    // replace patterns
-    tl::codegen::Replacer replacer;
-    replacer.register_rule("(AType)",
-                           tl::codegen::ptx::DTypeEnumToString(A_dtype));
-    replacer.register_rule("(BType)",
-                           tl::codegen::ptx::DTypeEnumToString(B_dtype));
-    replacer.register_rule("(CType)",
-                           tl::codegen::ptx::DTypeEnumToString(C_dtype));
-    replacer.register_rule("(M)", std::to_string(m));
-    replacer.register_rule("(N)", std::to_string(n));
-    replacer.register_rule("(K)", std::to_string(k));
-    replacer.register_rule("(tnspA)", a_is_k_major ? "false" : "true");
-    replacer.register_rule("(tnspB)", b_is_k_major ? "false" : "true");
-    replacer.register_rule("(scaleA)", scale_in_a ? "1" : "-1");
-    replacer.register_rule("(scaleB)", scale_in_b ? "1" : "-1");
-    replacer.register_rule("(desc_a)", a_desc);
-    replacer.register_rule("(A_offset)", A_offset);
-    replacer.register_rule("(desc_b)", b_desc);
-    replacer.register_rule("(B_offset)", B_offset);
-    replacer.register_rule("(C)", c_ref + " + " + c_offset);
-    replacer.register_rule("(scale_out)", scale_out ? "true" : "false");
-    wgmma_asm_code = replacer.rewrite(wgmma_asm_code);
-    this->stream << wgmma_asm_code;
+    this->stream << asm_code;
src/target/ptx.h (2)

154-162: Complete and align WGMMA parameter documentation with the signature

The comment documents only a subset of parameters and uses outdated names (A_layout/B_layout vs a_is_k_major/b_is_k_major). Please document all parameters and align names with the signature.

 /*!
  * \brief Print WGMMA assembly string given parameters.
  * \param shape The shape string mMnNkK
- * \param A_layout The layout of multiplicand A, can be either "row" or "col".
- * \param B_layout The layout of multiplicand B, can be either "row" or "col".
- * \param A_dtype The data type of multiplicand A.
- * \param B_dtype The data type of multiplicand B.
- * \param C_dtype The data type of multiplicand C.
+ * \param a_is_k_major Whether A is K‑major (true) or MN‑major (false).
+ * \param b_is_k_major Whether B is K‑major (true) or MN‑major (false).
+ * \param A_dtype Data type of A (e.g., ".f16", "e4m3", ".s8").
+ * \param B_dtype Data type of B.
+ * \param C_dtype Data type of accumulator/output C.
+ * \param a_desc Descriptor for A (SS) or base pointer (RS).
+ * \param A_offset Offset for A (in 16‑byte quanta when descriptor is used).
+ * \param b_desc Descriptor for B (SS/RS) or base pointer (SS).
+ * \param B_offset Offset for B (in 16‑byte quanta when descriptor is used).
+ * \param c_ptr Pointer to accumulator C.
+ * \param c_offset Offset into C (in elements).
+ * \param scale_out Scale‑out flag (true keeps/accumulates; false clears).
+ * \param scale_in_a Scale‑in flag for A (true 1, false −1).
+ * \param scale_in_b Scale‑in flag for B (true 1, false −1).
+ * \param a_is_shared Whether A comes from shared memory (SS vs RS selection).
+ * \param metadata Pointer to metadata buffer (sparse only).
+ * \param metadata_offset Offset into metadata.
+ * \param sparsity_selector Sparsity selector (sparse only).
+ * \param sparse Whether sparse WGMMA is used.
  */

164-174: Pass booleans by value (avoid const references to primitives)

Use plain bool for primitive params to match existing APIs and avoid unnecessary references. Update declaration/definition and callers accordingly.

-PrintWGMMAAssembly(const std::string &shape, const bool &a_is_k_major,
-                   const bool &b_is_k_major, const std::string &A_dtype,
+PrintWGMMAAssembly(const std::string &shape, bool a_is_k_major,
+                   bool b_is_k_major, const std::string &A_dtype,
                    const std::string &B_dtype, const std::string &C_dtype,
                    const std::string &a_desc, const std::string &A_offset,
                    const std::string &b_desc, const std::string &B_offset,
                    const std::string &c_ptr, const std::string &c_offset,
-                   const bool &scale_out, const bool &scale_in_a,
-                   const bool &scale_in_b, const bool &a_is_shared,
+                   bool scale_out, bool scale_in_a,
+                   bool scale_in_b, bool a_is_shared,
                    const std::string &metadata,
                    const std::string &metadata_offset,
                    const std::string &sparsity_selector, bool sparse);
tilelang/intrinsics/wgmma_macro_generator.py (4)

84-85: Default contradicts assertion (warp_row_tiles)

Set default to 16 to satisfy the class invariant.

-        warp_row_tiles: int = 8,
+        warp_row_tiles: int = 16,

170-171: Enforce divisibility to avoid dropping K tails (SS path)

Loop uses k_dim // micro_size_k; add a divisibility assert.

-        assert k_dim >= micro_size_k, f"k_dim must be greater than or equal to {micro_size_k}, got k_dim: {k_dim}"
+        assert k_dim >= micro_size_k, f"k_dim must be >= {micro_size_k}, got {k_dim}"
+        assert (k_dim % micro_size_k) == 0, f"k_dim must be divisible by micro_size_k={micro_size_k}, got k_dim: {k_dim}"
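
A quick illustration of the tail that floor division drops (values are illustrative):

#include <cstdio>

int main() {
  const int micro_size_k = 16;
  const int k_dim = 40; // not a multiple of micro_size_k
  int covered = 0;
  for (int ki = 0; ki < k_dim / micro_size_k; ++ki) // mirrors k_dim // micro_size_k
    covered += micro_size_k;
  // Prints: covered 32 of 40 K elements; 8 silently dropped
  std::printf("covered %d of %d K elements; %d silently dropped\n",
              covered, k_dim, k_dim - covered);
  return 0;
}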

178-184: Fix per‑operand element sizes, descriptor pointer mode, and offset scaling (SS)

  • Use separate byte sizes for A/B; don’t reuse A’s for B.
  • Descriptors are read, not written; use access_ptr("r").
  • Scale A/B offsets with their own elem sizes.
-        elems_in_bits = DataType(self.a_dtype).bits
-        elems_in_bytes = elems_in_bits // 8
-
-        a_swizzle_atom_elems = a_swizzle_mode.swizzle_byte_size() // elems_in_bytes
-        b_swizzle_atom_elems = n_dim if b_swizzle_mode.is_none(
-        ) else b_swizzle_mode.swizzle_byte_size() // elems_in_bytes
+        elems_in_bits = DataType(self.a_dtype).bits
+        a_elems_in_bytes = elems_in_bits // 8
+        b_elems_in_bytes = DataType(self.b_dtype).bits // 8
+
+        a_swizzle_atom_elems = a_swizzle_mode.swizzle_byte_size() // a_elems_in_bytes
+        b_swizzle_atom_elems = n_dim if b_swizzle_mode.is_none() else \
+            (b_swizzle_mode.swizzle_byte_size() // b_elems_in_bytes)
@@
-        a_leading_byte_offset = (8 * 8 * elems_in_bytes) if a_is_k_major else (8 * m_dim *
-                                                                               elems_in_bytes)
-        a_stride_byte_offset = (8 * k_dim * elems_in_bytes) if a_is_k_major else (8 * 8 *
-                                                                                  elems_in_bytes)
+        a_leading_byte_offset = (8 * 8 * a_elems_in_bytes) if a_is_k_major else (8 * m_dim * a_elems_in_bytes)
+        a_stride_byte_offset = (8 * k_dim * a_elems_in_bytes) if a_is_k_major else (8 * 8 * a_elems_in_bytes)
@@
-        b_leading_byte_offset = (8 * 8 * elems_in_bytes) if b_is_k_major else (8 * n_dim *
-                                                                               elems_in_bytes)
-        b_stride_byte_offset = (8 * k_dim *
-                                elems_in_bytes) if b_is_k_major else (0 if n_dim == 8 else
-                                                                      (8 * 8 * elems_in_bytes))
+        b_leading_byte_offset = (8 * 8 * b_elems_in_bytes) if b_is_k_major else (8 * n_dim * b_elems_in_bytes)
+        b_stride_byte_offset = (8 * k_dim * b_elems_in_bytes) if b_is_k_major else (0 if n_dim == 8 else (8 * 8 * b_elems_in_bytes))
@@
-            T.initialize_descriptor(desc_a, A_buf.access_ptr("w"), a_swizzle_mode,
+            T.initialize_descriptor(desc_a, A_buf.access_ptr("r"), a_swizzle_mode,
                                     int(a_leading_byte_offset >> 4), int(a_stride_byte_offset >> 4))
-            T.initialize_descriptor(desc_b, B_buf.access_ptr("w"), b_swizzle_mode,
+            T.initialize_descriptor(desc_b, B_buf.access_ptr("r"), b_swizzle_mode,
                                     int(b_leading_byte_offset >> 4), int(b_stride_byte_offset >> 4))
@@
-                    T.ptx_wgmma_ss(accum_dtype, wgmma_prefix, a_is_k_major, b_is_k_major,
-                                   a_dtype_abbrv, b_dtype_abbrv, accum_dtype_abbrv, desc_a.data,
-                                   (A_offset * elems_in_bytes) >> 4, desc_b.data,
-                                   (B_offset * elems_in_bytes) >> 4, C_local_buf.data, C_offset,
+                    T.ptx_wgmma_ss(accum_dtype, wgmma_prefix, a_is_k_major, b_is_k_major,
+                                   a_dtype_abbrv, b_dtype_abbrv, accum_dtype_abbrv, desc_a.data,
+                                   (A_offset * a_elems_in_bytes) >> 4, desc_b.data,
+                                   (B_offset * b_elems_in_bytes) >> 4, C_local_buf.data, C_offset,
                                    scale_out, scale_in_a, scale_in_b)

Also applies to: 185-212, 213-237, 247-250, 260-264
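
To make the encoding concrete, a constexpr check of the K-major formulas above for fp16 A with k_dim = 64 (illustrative values; the >> 4 matches the 16-byte units the descriptor expects):

constexpr int a_elems_in_bytes = 2; // fp16
constexpr int k_dim = 64;
constexpr int a_leading_byte_offset = 8 * 8 * a_elems_in_bytes;    // 128 B
constexpr int a_stride_byte_offset = 8 * k_dim * a_elems_in_bytes; // 1024 B
static_assert((a_leading_byte_offset >> 4) == 8, "LBO encoded in 16B units");
static_assert((a_stride_byte_offset >> 4) == 64, "SBO encoded in 16B units");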


288-289: RS path: divisibility, per‑operand sizes, pointer mode, K‑major flags, and K‑major swizzle SBO

  • Add K‑divisibility assert.
  • Use per‑operand byte sizes; don’t reuse A’s for B.
  • Descriptor is read‑only: use access_ptr("r").
  • Fix K‑major flags (match SS semantics).
  • When B is K‑major with swizzle, set SBO to 8 * swizzle_byte_size().
-        assert k_dim >= micro_size_k, f"k_dim must be greater than or equal to {micro_size_k}, got k_dim: {k_dim}"
+        assert k_dim >= micro_size_k, f"k_dim must be >= {micro_size_k}, got {k_dim}"
+        assert (k_dim % micro_size_k) == 0, f"k_dim must be divisible by micro_size_k={micro_size_k}, got k_dim: {k_dim}"
@@
-        elems_in_bytes = DataType(self.a_dtype).bits // 8
+        a_elems_in_bytes = DataType(self.a_dtype).bits // 8
+        b_elems_in_bytes = DataType(self.b_dtype).bits // 8
@@
-        b_leading_byte_offset = (8 * 8 * elems_in_bytes) if b_is_k_major else (8 * n_dim *
-                                                                               elems_in_bytes)
-        b_stride_byte_offset = (8 * k_dim * elems_in_bytes) if b_is_k_major else (8 * 8 *
-                                                                                  elems_in_bytes)
+        b_leading_byte_offset = (8 * 8 * b_elems_in_bytes) if b_is_k_major else (8 * n_dim * b_elems_in_bytes)
+        b_stride_byte_offset = (8 * k_dim * b_elems_in_bytes) if b_is_k_major else (8 * 8 * b_elems_in_bytes)
@@
-            if b_is_k_major:
-                b_leading_byte_offset = 16
+            if b_is_k_major:
+                b_leading_byte_offset = 16
+                b_stride_byte_offset = 8 * b_swizzle_mode.swizzle_byte_size()
@@
-            T.initialize_descriptor(desc_b, B_buf.access_ptr("w"), b_swizzle_mode,
+            T.initialize_descriptor(desc_b, B_buf.access_ptr("r"), b_swizzle_mode,
                                     int(b_leading_byte_offset >> 4), int(b_stride_byte_offset >> 4))
@@
-                        self.a_transposed,
-                        not self.b_transposed,
+                        (not self.a_transposed),
+                        self.b_transposed,
@@
-                        (B_offset * elems_in_bytes) >> 4,
+                        (B_offset * b_elems_in_bytes) >> 4,

Also applies to: 290-321, 325-327, 333-350

src/tl_templates/cuda/common.h (3)

303-314: Add compile‑time validation for descriptor fields

Guard against invalid layout_type and mis‑encoded offsets (template args).

 template <int layout_type = 0, int leading_byte_offset = 0,
           int stride_byte_offset = 0, typename T>
 TL_DEVICE void initialize_descriptor(GmmaDescriptor &descriptor,
                                      T *start_address) {
+  static_assert(layout_type >= 0 && layout_type <= 3, "layout_type must be in [0,3]");
+  static_assert((leading_byte_offset >= 0) && (stride_byte_offset >= 0), "offsets must be non-negative");
+  // Expect offsets encoded in 16B quanta (match Python-side >>4)
+  static_assert((leading_byte_offset & 0xFFFF0000) == 0, "leading_byte_offset too large for bitfield");
+  static_assert((stride_byte_offset & 0xFFFF0000) == 0, "stride_byte_offset too large for bitfield");
   descriptor.bitfield.start_address_ =
       cute::cast_smem_ptr_to_uint(start_address) >> 4;
   descriptor.bitfield.layout_type_ = layout_type;
   descriptor.bitfield.base_offset_ = 0;
   descriptor.bitfield.leading_byte_offset_ = leading_byte_offset;
   descriptor.bitfield.stride_byte_offset_ = stride_byte_offset;
 }

171-205: Avoid DataType duplication (single source of truth)

This enum duplicates src/target/ptx.h. Prefer including/reusing the canonical definition to prevent drift.


254-265: Unify descriptor offset semantics (operator+)

operator+ adds raw units while increase_descriptor_offset accepts bytes (>>4). Make operator+ accept bytes too for consistency.

   template <typename T>
   CUTE_HOST_DEVICE constexpr GmmaDescriptor operator+(const T &offset) const {
     GmmaDescriptor ret;
-    ret.reg32_[0] = reg32_[0] + uint32_t(offset);
+    ret.reg32_[0] = reg32_[0] + uint32_t(offset >> 4);
     ret.reg32_[1] = reg32_[1];
     return ret;
   }

Consider documenting that offsets are byte counts.
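
A tiny host-side model of the proposed byte-count semantics (Desc is a stand-in for the descriptor's low word, not the real struct):

#include <cassert>
#include <cstdint>

struct Desc { uint32_t reg32_0; }; // stand-in: start address in 16B units

// Byte-based advance, matching increase_descriptor_offset's >> 4.
constexpr Desc advance_bytes(Desc d, uint32_t offset_in_bytes) {
  return Desc{d.reg32_0 + (offset_in_bytes >> 4)};
}

int main() {
  Desc d{100};
  assert(advance_bytes(d, 256).reg32_0 == 116); // 256 bytes == 16 encoded units
  return 0;
}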

tilelang/tileop/gemm/__init__.py (1)

27-42: Fix enum typo: WGMMMA → WGMMA (dispatch bug)

The extra “M” mismatches the C++/FFI side and breaks selection.

 class GemmInst(IntEnum):
     MMA = 0
-    WGMMMA = 1
+    WGMMA = 1
     MFMA = 2
 
     def is_mma(self) -> bool:
         return self == GemmInst.MMA
 
     def is_wgmma(self) -> bool:
-        return self == GemmInst.WGMMMA
+        return self == GemmInst.WGMMA
🧹 Nitpick comments (5)
tilelang/language/builtin.py (1)

360-365: Type hints don’t match accepted inputs

Functions accept Buffer or BufferLoad, but annotations say Buffer / PrimExpr. Align hints to avoid confusion and improve IDE/static checks.

-def initialize_descriptor(descriptor: Buffer,
+def initialize_descriptor(descriptor: Union[Buffer, BufferLoad],
                           start_address: PrimExpr,
                           layout_type_: int = 0,
                           leading_byte_offset: int = 0,
                           stride_byte_offset: int = 0) -> PrimExpr:
@@
-def increase_descriptor_offset(descriptor: PrimExpr, offset: PrimExpr) -> PrimExpr:
+def increase_descriptor_offset(descriptor: Union[Buffer, BufferLoad], offset: PrimExpr) -> PrimExpr:

Also applies to: 394-405

tilelang/layout/swizzle.py (2)

22-35: Use Optional typing and validate continuity derivation

  • Annotate continuity as Optional[int] per PEP 484.
  • Minor: assert the single-arg form is a Buffer and 2D before using shape.
-from typing import Union, Any
+from typing import Union, Any, Optional
@@
-def make_wgmma_swizzled_layout(buffer: tvm.tir.Buffer,
-                               continuity: int = None,
+def make_wgmma_swizzled_layout(buffer: tvm.tir.Buffer,
+                               continuity: Optional[int] = None,
                                k_major: bool = True):
     assert len(buffer.shape) == 2
-    if continuity is None:
+    if continuity is None:
         continuity = int(buffer.shape[1])

47-60: Guard the single-argument overloads

When len(args) == 1, ensure it’s a 2D Buffer before reading .shape.

-    if len(args) == 1:
-        buffer = args[0]
+    if len(args) == 1:
+        buffer = args[0]
+        assert isinstance(buffer, tvm.tir.Buffer) and len(buffer.shape) == 2

Apply similarly to half/quarter/linear helpers.

Also applies to: 72-85, 97-110, 120-130

tilelang/intrinsics/wgmma_macro_generator.py (2)

105-105: Unused parameter n_dim

Remove or prefix with underscore to satisfy linters.

-    def _initialize_wgmma_prefix(self, n_dim: int = 16):
+    def _initialize_wgmma_prefix(self, _n_dim: int = 16):

376-376: Redundant import

is_fragment is already imported at file top; remove the inner import.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between d2db013 and fef8d2a.

📒 Files selected for processing (16)
  • .clang-tidy (1 hunks)
  • src/layout/gemm_layouts.cc (6 hunks)
  • src/layout/layout.cc (3 hunks)
  • src/layout/layout.h (2 hunks)
  • src/op/builtin.h (2 hunks)
  • src/op/gemm.cc (7 hunks)
  • src/target/codegen_cuda.cc (5 hunks)
  • src/target/ptx.cc (9 hunks)
  • src/target/ptx.h (2 hunks)
  • src/tl_templates/cuda/common.h (4 hunks)
  • src/tl_templates/cuda/instruction/wgmma.h (1 hunks)
  • src/transform/storage_rewrite.cc (2 hunks)
  • tilelang/intrinsics/wgmma_macro_generator.py (1 hunks)
  • tilelang/language/builtin.py (2 hunks)
  • tilelang/layout/swizzle.py (1 hunks)
  • tilelang/tileop/gemm/__init__.py (3 hunks)
✅ Files skipped from review due to trivial changes (1)
  • src/target/ptx.cc
🚧 Files skipped from review as they are similar to previous changes (2)
  • src/transform/storage_rewrite.cc
  • src/op/builtin.h
🧰 Additional context used
🧬 Code graph analysis (10)
src/layout/layout.cc (1)
src/layout/gemm_layouts.cc (12)
  • makeGemmABLayout (720-739)
  • makeGemmABLayout (720-721)
  • makeGemmABLayoutHopper (741-766)
  • makeGemmABLayoutHopper (741-742)
  • makeFullBankSwizzleLayout (421-437)
  • makeFullBankSwizzleLayout (421-421)
  • makeHalfBankSwizzleLayout (402-418)
  • makeHalfBankSwizzleLayout (402-402)
  • makeQuarterBankSwizzleLayout (382-399)
  • makeQuarterBankSwizzleLayout (382-383)
  • makeGemmLayoutLinear (491-495)
  • makeGemmLayoutLinear (491-491)
src/target/codegen_cuda.cc (4)
src/transform/storage_rewrite.cc (6)
  • scope (674-679)
  • scope (674-674)
  • n (366-370)
  • n (366-366)
  • n (371-375)
  • n (371-371)
tilelang/language/tir/op.py (2)
  • ptx_wgmma_ss (1064-1103)
  • ptx_wgmma_rs (1106-1143)
src/target/ptx.cc (8)
  • PrintWGMMAAssembly (1188-1261)
  • PrintWGMMAAssembly (1189-1199)
  • ParseMMAShape (142-150)
  • ParseMMAShape (142-142)
  • DTypeEnumToString (108-110)
  • DTypeEnumToString (108-108)
  • DTypeEnumToString (112-115)
  • DTypeEnumToString (112-112)
tilelang/language/builtin.py (2)
  • initialize_descriptor (360-391)
  • increase_descriptor_offset (394-416)
src/tl_templates/cuda/common.h (2)
src/target/ptx.h (1)
  • DataType (45-90)
tilelang/language/builtin.py (2)
  • initialize_descriptor (360-391)
  • increase_descriptor_offset (394-416)
src/tl_templates/cuda/instruction/wgmma.h (1)
src/tl_templates/cuda/common.h (5)
  • tl (171-265)
  • DataType (180-225)
  • int (97-100)
  • int (135-142)
  • uint32_t (118-120)
src/layout/layout.h (1)
src/layout/gemm_layouts.cc (6)
  • makeGemmABLayoutHopper (741-766)
  • makeGemmABLayoutHopper (741-742)
  • makeGemmABLayoutSm100 (768-788)
  • makeGemmABLayoutSm100 (768-769)
  • makeGemmABLayoutCDNA (790-793)
  • makeGemmABLayoutCDNA (790-791)
tilelang/intrinsics/wgmma_macro_generator.py (6)
tilelang/utils/language.py (1)
  • is_fragment (68-78)
tilelang/layout/layout.py (3)
  • Layout (13-148)
  • is_equal (136-145)
  • inverse (125-134)
tilelang/layout/swizzle.py (4)
  • make_full_bank_swizzled_layout (39-59)
  • make_half_bank_swizzled_layout (64-84)
  • make_quarter_bank_swizzled_layout (89-109)
  • make_linear_layout (112-130)
tilelang/layout/fragment.py (4)
  • is_equal (209-213)
  • Fragment (13-213)
  • replicate (146-160)
  • repeat (123-144)
tilelang/language/builtin.py (1)
  • initialize_descriptor (360-391)
tilelang/intrinsics/mma_macro_generator.py (1)
  • get_store_index_map (159-165)
src/target/ptx.h (4)
src/tl_templates/cuda/common.h (3)
  • DataType (180-225)
  • int (97-100)
  • int (135-142)
src/target/ptx.cc (10)
  • DTypeFromString (56-106)
  • DTypeFromString (56-56)
  • DTypeEnumToString (108-110)
  • DTypeEnumToString (108-108)
  • DTypeEnumToString (112-115)
  • DTypeEnumToString (112-112)
  • ParseMMAShape (142-150)
  • ParseMMAShape (142-142)
  • PrintWGMMAAssembly (1188-1261)
  • PrintWGMMAAssembly (1189-1199)
src/transform/storage_rewrite.cc (2)
  • dtype (697-703)
  • dtype (697-697)
src/target/codegen_hip.cc (1)
  • _rules (75-75)
src/op/gemm.cc (2)
tilelang/tileop/gemm/gemm_base.py (4)
  • trans_A (46-47)
  • trans_B (50-51)
  • A (67-68)
  • B (71-72)
src/layout/gemm_layouts.cc (4)
  • makeGemmABLayout (720-739)
  • makeGemmABLayout (720-721)
  • makeGemmABLayoutHopper (741-766)
  • makeGemmABLayoutHopper (741-742)
tilelang/language/builtin.py (3)
src/op/builtin.h (1)
  • tvm (13-442)
tilelang/language/ast/ir.py (1)
  • evaluate (1319-1331)
tilelang/language/tir/op.py (1)
  • call_intrin (119-144)
tilelang/layout/swizzle.py (3)
tilelang/language/ast/ir.py (1)
  • buffer (93-161)
src/op/builtin.h (1)
  • tvm (13-442)
src/layout/swizzle.h (1)
  • tvm (12-70)
🪛 Ruff (0.13.3)
tilelang/tileop/gemm/__init__.py

120-120: Avoid specifying long messages outside the exception class

(TRY003)

tilelang/intrinsics/wgmma_macro_generator.py

105-105: Unused method argument: n_dim

(ARG002)


145-145: Avoid specifying long messages outside the exception class

(TRY003)


396-396: Avoid specifying long messages outside the exception class

(TRY003)

tilelang/language/builtin.py

380-380: Avoid specifying long messages outside the exception class

(TRY003)


382-382: Parenthesize a and b expressions when chaining and and or together, to make the precedence clear

Parenthesize the and subexpression

(RUF021)


383-383: Avoid specifying long messages outside the exception class

(TRY003)


406-406: Avoid specifying long messages outside the exception class

(TRY003)


408-408: Parenthesize a and b expressions when chaining and and or together, to make the precedence clear

Parenthesize the and subexpression

(RUF021)


409-409: Avoid specifying long messages outside the exception class

(TRY003)

tilelang/layout/swizzle.py

23-23: PEP 484 prohibits implicit Optional

Convert to Optional[T]

(RUF013)


54-54: Avoid specifying long messages outside the exception class

(TRY003)


79-79: Avoid specifying long messages outside the exception class

(TRY003)


104-104: Avoid specifying long messages outside the exception class

(TRY003)


126-126: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: build-test-nvidia
  • GitHub Check: build-test-metal
  • GitHub Check: build-test-amd
🔇 Additional comments (14)
.clang-tidy (1)

49-49: Re-evaluate disabling of clang-diagnostic-tautological-constant-compare in .clang-tidy:49. No instances of this diagnostic were found in a full clang-tidy scan—please confirm it’s actually triggered in your build configuration or remove it if unnecessary.
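
For context, a minimal pattern in the tautological-compare family that this check surfaces (illustrative only, not from this repo):

#include <cstdint>

bool never_negative(uint32_t v) {
  // An unsigned value is never < 0, so the comparison below is always
  // true; clang diagnoses such constant-result comparisons under the
  // -Wtautological-* warnings that this clang-tidy check wraps.
  return v >= 0;
}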

src/tl_templates/cuda/instruction/wgmma.h (1)

626-629: Clarify or remove unused WgmmaOutputRegs trait

  • No references to WgmmaOutputRegs were found in the codebase. If it isn’t used, delete it; otherwise clarify whether value is meant to represent per-thread register usage (e.g. 4 for M64N8K16 F32) versus total registers across the warp group, and update the formula or documentation accordingly.
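
For reference, M64N8 with F32 accumulators holds 64 * 8 = 512 floats; spread over the 128 threads of a warp group that is 512 / 128 = 4 registers per thread, matching the four {%0..%3} output operands in the M64N8K16 F32 specializations. The trait's formula (M * N * 32) / (32 * 8) instead evaluates to (64 * 8 * 32) / 256 = 64, which matches neither the per-thread count (4) nor the warp-group total (512), hence the request to clarify or remove it.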
src/op/gemm.cc (1)

585-591: New gemm_rs path selection looks good

Selecting rs/sr/ss by fragment/shared scope is consistent. No issues.

src/target/codegen_cuda.cc (2)

2123-2159: Descriptor allocation handling looks correct

local.descriptor mapped to tl::GmmaDescriptor avoids generic PrintStorageScope and is consistent with GetBufferRef handling.


1960-1980: Descriptor intrinsics emission looks good

Templated initialize_descriptor and typed increase_descriptor_offset<int> generation matches the Python builtins.
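
For reference, a sketch of the call shapes the codegen emits (signatures assumed from the common.h excerpts above and the Python builtins; the include path and buffer size are illustrative):

#include <cstdint>
#include "tl_templates/cuda/common.h" // GmmaDescriptor and descriptor helpers

__global__ void sketch() {
  __shared__ uint32_t A_shared[1024];
  tl::GmmaDescriptor desc_a;
  // <layout_type, leading_byte_offset, stride_byte_offset> arrive already
  // encoded in 16-byte units (the >> 4 happens on the Python side).
  tl::initialize_descriptor<1, 8, 64>(desc_a, A_shared);
  // Runtime byte advance; codegen emits the <int> instantiation.
  tl::increase_descriptor_offset<int>(desc_a, 256);
}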

src/layout/gemm_layouts.cc (1)

578-587: kfactor→k_inner migration looks correct

Branching and error messages updated coherently across targets; no issues spotted.

Also applies to: 720-739, 741-766, 768-787

src/layout/layout.cc (8)

487-491: LGTM!

The equality check for Layout is correctly implemented, delegating to the existing IsEqual method.


500-504: LGTM!

The equality check for Fragment follows the same pattern as Layout_is_equal and correctly delegates to the existing IsEqual method.


521-531: API simplification for common case.

The function passes continuous for both mat_continuous and continuity parameters to the underlying layout constructors. This is a reasonable simplification for the common case where these values are equal. For finer control, users can call tl.make_wgmma_swizzled_layout which exposes both parameters separately.

The addition of the allow_pad parameter with a default value of true maintains backward compatibility while providing control over which layout constructor to use.
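
At the C++ binding level the distinction looks like this (signatures as declared in layout.h; the numeric values are illustrative):

// Common case exposed by make_swizzled_layout: continuity == mat_continuous.
Layout a_layout = makeGemmABLayoutHopper(/*mat_stride=*/64,
                                         /*mat_continuous=*/64,
                                         /*continuity=*/64,
                                         /*element_size=*/16,
                                         /*k_inner=*/true);
// Fine-grained case exposed by make_wgmma_swizzled_layout: continuity may
// differ from mat_continuous, e.g. when only part of the N extent is
// contiguous in shared memory.
Layout b_layout = makeGemmABLayoutHopper(64, 64, /*continuity=*/32, 16, true);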


532-537: LGTM!

The new make_wgmma_swizzled_layout function correctly exposes all parameters (mat_continuous and continuity separately) for fine-grained control over WGMMA layout construction, properly delegating to makeGemmABLayoutHopper.


538-541: LGTM!

The binding correctly delegates to makeFullBankSwizzleLayout with the appropriate parameters.


542-545: LGTM!

The binding correctly delegates to makeHalfBankSwizzleLayout with the appropriate parameters.


546-550: LGTM!

The binding correctly delegates to makeQuarterBankSwizzleLayout with the appropriate parameters.


551-553: LGTM!

The new make_linear_layout binding correctly delegates to makeGemmLayoutLinear for creating linear (non-swizzled) layouts.

Comment on lines +169 to +176
     int element_size, bool k_inner = true);
 Layout makeGemmABLayoutHopper(int mat_stride, int mat_continuous,
-                              int continuity, int element_size, int kfactor);
+                              int continuity, int element_size,
+                              bool k_inner = true);
 Layout makeGemmABLayoutSm100(int mat_stride, int mat_continuous, int continuity,
-                             int element_size, int kfactor);
+                             int element_size, bool k_inner = true);
 Layout makeGemmABLayoutCDNA(int stride, int continuous, int element_size,
-                            int kfactor);
+                            int kPack);

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

Audit callers after k_inner bool API change

APIs now take bool k_inner. Verify all call sites pass booleans, not ternaries like trans_A ? 1 : 2. Sm100 calls in src/op/gemm.cc still pass ints.

Also applies to: 171-173, 174-176, 185-186


🏁 Script executed:

#!/bin/bash
# Find suspicious non-bool last-arg usages of Gemm AB layout constructors.
rg -nPU -C2 '\bmakeGemmABLayout(?:Hopper|Sm100)?\s*\('
# Heuristic: show calls where 5th arg looks like a ternary yielding 1/2
rg -nPU '\bmakeGemmABLayout(?:Hopper|Sm100)?\s*\((?:[^,]*,){4}\s*[^),]*\?\s*[12]\s*:\s*[12]\s*\)' -n -C2

Length of output: 8230


Fix non-boolean arguments to makeGemmABLayout* calls
Several call sites still pass ints (e.g., trans_A ? 1 : 2, trans_B ? 2 : 1) for the new bool k_inner parameter. Update:

  • src/op/gemm.cc: lines 784–786, 794–795
  • src/op/gemm_sp.cc: lines 240–242, 254–255
    to pass a boolean (e.g., trans_A or !trans_A) instead.

Comment on lines +342 to +646
// ================================= BF16 x BF16 -> F32
// =================================

// M64N8K16 BF16->F32
template <bool tnspA, bool tnspB, int scaleA, int scaleB>
struct WgmmaSSImpl<DataType::kBFloat16, DataType::kBFloat16, DataType::kFloat32,
64, 8, 16, tnspA, tnspB, scaleA, scaleB> {
TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c,
bool scale_out) {
asm volatile("{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %6, 0;\n"
"wgmma.mma_async.sync.aligned.m64n8k16.f32.bf16.bf16 "
"{%0, %1, %2, %3}, %4, %5, p, %7, %8, %9, %10;\n"
"}\n"
: "+r"(c[0]), "+r"(c[1]), "+r"(c[2]), "+r"(c[3])
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)),
"n"(int32_t(scaleA)), "n"(int32_t(scaleB)),
"n"(int32_t(tnspA)), "n"(int32_t(tnspB)));
}
};

// M64N16K16 BF16->F32
template <bool tnspA, bool tnspB, int scaleA, int scaleB>
struct WgmmaSSImpl<DataType::kBFloat16, DataType::kBFloat16, DataType::kFloat32,
64, 16, 16, tnspA, tnspB, scaleA, scaleB> {
TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c,
bool scale_out) {
asm volatile(
"{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %10, 0;\n"
"wgmma.mma_async.sync.aligned.m64n16k16.f32.bf16.bf16 "
"{%0, %1, %2, %3, %4, %5, %6, %7}, %8, %9, p, %11, %12, %13, %14;\n"
"}\n"
: "+r"(c[0]), "+r"(c[1]), "+r"(c[2]), "+r"(c[3]), "+r"(c[4]),
"+r"(c[5]), "+r"(c[6]), "+r"(c[7])
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)),
"n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)),
"n"(int32_t(tnspB)));
}
};

// ================================= TF32 x TF32 -> F32
// =================================

// M64N8K8 TF32->F32
template <bool tnspA, bool tnspB, int scaleA, int scaleB>
struct WgmmaSSImpl<DataType::kTensorFloat32, DataType::kTensorFloat32,
DataType::kFloat32, 64, 8, 8, tnspA, tnspB, scaleA, scaleB> {
TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c,
bool scale_out) {
asm volatile("{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %6, 0;\n"
"wgmma.mma_async.sync.aligned.m64n8k8.f32.tf32.tf32 "
"{%0, %1, %2, %3}, %4, %5, p, %7, %8, %9, %10;\n"
"}\n"
: "+r"(c[0]), "+r"(c[1]), "+r"(c[2]), "+r"(c[3])
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)),
"n"(int32_t(scaleA)), "n"(int32_t(scaleB)),
"n"(int32_t(tnspA)), "n"(int32_t(tnspB)));
}
};

// M64N16K8 TF32->F32
template <bool tnspA, bool tnspB, int scaleA, int scaleB>
struct WgmmaSSImpl<DataType::kTensorFloat32, DataType::kTensorFloat32,
DataType::kFloat32, 64, 16, 8, tnspA, tnspB, scaleA,
scaleB> {
TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c,
bool scale_out) {
asm volatile(
"{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %10, 0;\n"
"wgmma.mma_async.sync.aligned.m64n16k8.f32.tf32.tf32 "
"{%0, %1, %2, %3, %4, %5, %6, %7}, %8, %9, p, %11, %12, %13, %14;\n"
"}\n"
: "+r"(c[0]), "+r"(c[1]), "+r"(c[2]), "+r"(c[3]), "+r"(c[4]),
"+r"(c[5]), "+r"(c[6]), "+r"(c[7])
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)),
"n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)),
"n"(int32_t(tnspB)));
}
};

// ================================= INT8 x INT8 -> INT32
// =================================

// M64N8K32 S8->S32
template <bool tnspA, bool tnspB, int scaleA, int scaleB>
struct WgmmaSSImpl<DataType::kInt8, DataType::kInt8, DataType::kInt32, 64, 8,
32, tnspA, tnspB, scaleA, scaleB> {
TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c,
bool scale_out) {
asm volatile("{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %4, 0;\n"
"wgmma.mma_async.sync.aligned.m64n8k32.s32.s8.s8 "
"{%0, %1}, %2, %3, p, %5, %6, %7, %8;\n"
"}\n"
: "+r"(c[0]), "+r"(c[1])
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)),
"n"(int32_t(scaleA)), "n"(int32_t(scaleB)),
"n"(int32_t(tnspA)), "n"(int32_t(tnspB)));
}
};

// M64N16K32 S8->S32
template <bool tnspA, bool tnspB, int scaleA, int scaleB>
struct WgmmaSSImpl<DataType::kInt8, DataType::kInt8, DataType::kInt32, 64, 16,
32, tnspA, tnspB, scaleA, scaleB> {
TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c,
bool scale_out) {
asm volatile("{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %6, 0;\n"
"wgmma.mma_async.sync.aligned.m64n16k32.s32.s8.s8 "
"{%0, %1, %2, %3}, %4, %5, p, %7, %8, %9, %10;\n"
"}\n"
: "+r"(c[0]), "+r"(c[1]), "+r"(c[2]), "+r"(c[3])
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)),
"n"(int32_t(scaleA)), "n"(int32_t(scaleB)),
"n"(int32_t(tnspA)), "n"(int32_t(tnspB)));
}
};

// ================================= FP8 x FP8 -> F16/F32
// =================================

// M64N8K32 E4M3->F16
template <bool tnspA, bool tnspB, int scaleA, int scaleB>
struct WgmmaSSImpl<DataType::kFloat8_e4m3, DataType::kFloat8_e4m3,
DataType::kFloat16, 64, 8, 32, tnspA, tnspB, scaleA,
scaleB> {
TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c,
bool scale_out) {
asm volatile("{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %4, 0;\n"
"wgmma.mma_async.sync.aligned.m64n8k32.f16.e4m3.e4m3 "
"{%0, %1}, %2, %3, p, %5, %6, %7, %8;\n"
"}\n"
: "+r"(c[0]), "+r"(c[1])
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)),
"n"(int32_t(scaleA)), "n"(int32_t(scaleB)),
"n"(int32_t(tnspA)), "n"(int32_t(tnspB)));
}
};

// M64N8K32 E4M3->F32
template <bool tnspA, bool tnspB, int scaleA, int scaleB>
struct WgmmaSSImpl<DataType::kFloat8_e4m3, DataType::kFloat8_e4m3,
DataType::kFloat32, 64, 8, 32, tnspA, tnspB, scaleA,
scaleB> {
TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c,
bool scale_out) {
asm volatile("{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %6, 0;\n"
"wgmma.mma_async.sync.aligned.m64n8k32.f32.e4m3.e4m3 "
"{%0, %1, %2, %3}, %4, %5, p, %7, %8, %9, %10;\n"
"}\n"
: "+r"(c[0]), "+r"(c[1]), "+r"(c[2]), "+r"(c[3])
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)),
"n"(int32_t(scaleA)), "n"(int32_t(scaleB)),
"n"(int32_t(tnspA)), "n"(int32_t(tnspB)));
}
};

// Function template that delegates to the class template specializations
template <DataType A_type, DataType B_type, DataType C_type, int M, int N,
int K, bool tnspA, bool tnspB, int scaleA = 1, int scaleB = 1>
TL_DEVICE void wgmma_ss(uint64_t desc_a, uint64_t desc_b, uint32_t *c,
bool scale_out) {
WgmmaSSImpl<A_type, B_type, C_type, M, N, K, tnspA, tnspB, scaleA,
scaleB>::execute(desc_a, desc_b, c, scale_out);
}

// ================================= Mixed Precision Support
// =================================

// Mixed precision: S8 x U8 -> S32
template <bool tnspA, bool tnspB, int scaleA, int scaleB>
struct WgmmaSSImpl<DataType::kInt8, DataType::kUInt8, DataType::kInt32, 64, 8,
32, tnspA, tnspB, scaleA, scaleB> {
TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c,
bool scale_out) {
asm volatile("{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %4, 0;\n"
"wgmma.mma_async.sync.aligned.m64n8k32.s32.s8.u8 "
"{%0, %1}, %2, %3, p, %5, %6, %7, %8;\n"
"}\n"
: "+r"(c[0]), "+r"(c[1])
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)),
"n"(int32_t(scaleA)), "n"(int32_t(scaleB)),
"n"(int32_t(tnspA)), "n"(int32_t(tnspB)));
}
};

// Mixed precision: U8 x S8 -> S32
template <bool tnspA, bool tnspB, int scaleA, int scaleB>
struct WgmmaSSImpl<DataType::kUInt8, DataType::kInt8, DataType::kInt32, 64, 8,
32, tnspA, tnspB, scaleA, scaleB> {
TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c,
bool scale_out) {
asm volatile("{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %4, 0;\n"
"wgmma.mma_async.sync.aligned.m64n8k32.s32.u8.s8 "
"{%0, %1}, %2, %3, p, %5, %6, %7, %8;\n"
"}\n"
: "+r"(c[0]), "+r"(c[1])
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)),
"n"(int32_t(scaleA)), "n"(int32_t(scaleB)),
"n"(int32_t(tnspA)), "n"(int32_t(tnspB)));
}
};

// Mixed precision: U8 x U8 -> S32
template <bool tnspA, bool tnspB, int scaleA, int scaleB>
struct WgmmaSSImpl<DataType::kUInt8, DataType::kUInt8, DataType::kInt32, 64, 8,
32, tnspA, tnspB, scaleA, scaleB> {
TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c,
bool scale_out) {
asm volatile("{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %4, 0;\n"
"wgmma.mma_async.sync.aligned.m64n8k32.s32.u8.u8 "
"{%0, %1}, %2, %3, p, %5, %6, %7, %8;\n"
"}\n"
: "+r"(c[0]), "+r"(c[1])
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)),
"n"(int32_t(scaleA)), "n"(int32_t(scaleB)),
"n"(int32_t(tnspA)), "n"(int32_t(tnspB)));
}
};

// Mixed precision FP8: E4M3 x E5M2 -> F16
template <bool tnspA, bool tnspB, int scaleA, int scaleB>
struct WgmmaSSImpl<DataType::kFloat8_e4m3, DataType::kFloat8_e5m2,
DataType::kFloat16, 64, 8, 32, tnspA, tnspB, scaleA,
scaleB> {
TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c,
bool scale_out) {
asm volatile("{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %4, 0;\n"
"wgmma.mma_async.sync.aligned.m64n8k32.f16.e4m3.e5m2 "
"{%0, %1}, %2, %3, p, %5, %6, %7, %8;\n"
"}\n"
: "+r"(c[0]), "+r"(c[1])
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)),
"n"(int32_t(scaleA)), "n"(int32_t(scaleB)),
"n"(int32_t(tnspA)), "n"(int32_t(tnspB)));
}
};

// Mixed precision FP8: E5M2 x E4M3 -> F16
template <bool tnspA, bool tnspB, int scaleA, int scaleB>
struct WgmmaSSImpl<DataType::kFloat8_e5m2, DataType::kFloat8_e4m3,
DataType::kFloat16, 64, 8, 32, tnspA, tnspB, scaleA,
scaleB> {
TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c,
bool scale_out) {
asm volatile("{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %4, 0;\n"
"wgmma.mma_async.sync.aligned.m64n8k32.f16.e5m2.e4m3 "
"{%0, %1}, %2, %3, p, %5, %6, %7, %8;\n"
"}\n"
: "+r"(c[0]), "+r"(c[1])
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)),
"n"(int32_t(scaleA)), "n"(int32_t(scaleB)),
"n"(int32_t(tnspA)), "n"(int32_t(tnspB)));
}
};

// ================================= Convenience Templates
// =================================

// Type trait to determine the number of output registers needed
template <DataType C_type, int M, int N> struct WgmmaOutputRegs {
static constexpr int value =
(M * N * (C_type == DataType::kFloat32 ? 32 : 16)) / (32 * 8);
};

// Type trait to get element size in bits
template <DataType dtype> struct ElementBits {
static constexpr int value =
(dtype == DataType::kFloat32 || dtype == DataType::kTensorFloat32 ||
dtype == DataType::kInt32)
? 32
: (dtype == DataType::kFloat16 || dtype == DataType::kBFloat16 ||
dtype == DataType::kInt16 || dtype == DataType::kUInt16)
? 16
: (dtype == DataType::kInt8 || dtype == DataType::kUInt8 ||
dtype == DataType::kFloat8_e4m3 || dtype == DataType::kFloat8_e5m2)
? 8
: (dtype == DataType::kInt4 || dtype == DataType::kUInt4) ? 4
: 8;
};


🛠️ Refactor suggestion | 🟠 Major

🧩 Analysis chain

Verify tile size coverage is intentional.

Several data type combinations only support small tile sizes:

  • BF16 → F32: only 64×8×16 and 64×16×16 (missing 64×32×16, 64×64×16, etc.)
  • TF32 → F32: only 64×8×8 and 64×16×8
  • INT8 → INT32: only 64×8×32 and 64×16×32
  • FP8 variants: only 64×8×32

In contrast, F16 → F16 has full coverage up to 64×256×16. If users attempt unsupported tile sizes, they'll hit the fallback template with only a printf warning (since static_assert is commented out). Confirm whether:

  1. This limited coverage matches hardware capabilities (NVIDIA Hopper/Ada only supports certain tile configs per data type)
  2. The fallback should fail at compile time for missing specializations
  3. Additional specializations should be added

Run this script to check if other parts of the codebase rely on unsupported tile sizes:


🏁 Script executed:

#!/bin/bash
# Description: Search for wgmma_ss calls that might use unsupported tile configurations

# Search for wgmma_ss template instantiations
rg -nP 'wgmma_ss\s*<' -A2 -B2

# Search for BF16, TF32, INT8, FP8 usage patterns
rg -nP '(kBFloat16|kTensorFloat32|kInt8|kFloat8)' --type=cpp -C3

Length of output: 43439


Align WgmmaSSImpl specializations with supported tile configurations

Codegen registers dense WGMMA shapes for BF16, TF32, INT8 and FP8 with N up to 256, but WgmmaSSImpl only defines N=8,16 (and N=8 for FP8). Instantiations for N>16 will silently hit the fallback. Either add missing specializations for the full shape set or restrict codegen to supported shapes and re-enable compile-time asserts for unsupported cases.

🤖 Prompt for AI Agents
In src/tl_templates/cuda/instruction/wgmma.h around lines 342–646, the review
notes that codegen emits WMMA shapes with N up to 256 but WgmmaSSImpl only
provides specializations for N=8 and N=16 (and N=8 for FP8), so instantiations
for larger N silently fall back; fix by either (A) adding WgmmaSSImpl
specializations that match every tile configuration codegen emits (for each
DataType combination and N values up to 256, mirroring the existing patterns for
M/K variants and ensuring the asm constraints/operand counts match each N), or
(B) change the codegen to only emit the N values already supported (8,16) and
re-enable/restore compile-time static_asserts that reject unsupported shapes so
generation fails fast; choose one approach and implement it consistently across
BF16, TF32, INT8, FP8 and mixed-precision variants.

Comment on lines +632 to +645
template <DataType dtype> struct ElementBits {
static constexpr int value =
(dtype == DataType::kFloat32 || dtype == DataType::kTensorFloat32 ||
dtype == DataType::kInt32)
? 32
: (dtype == DataType::kFloat16 || dtype == DataType::kBFloat16 ||
dtype == DataType::kInt16 || dtype == DataType::kUInt16)
? 16
: (dtype == DataType::kInt8 || dtype == DataType::kUInt8 ||
dtype == DataType::kFloat8_e4m3 || dtype == DataType::kFloat8_e5m2)
? 8
: (dtype == DataType::kInt4 || dtype == DataType::kUInt4) ? 4
: 8;
};

⚠️ Potential issue | 🟡 Minor

Handle kBit* types in ElementBits trait.

The ElementBits trait's default case returns 8 bits, but DataType includes kBit16, kBit32, and kBit64 variants (from common.h lines 170-224) that should return their respective bit widths, not the default 8.

Apply this diff to handle kBit* types correctly:

 template <DataType dtype> struct ElementBits {
   static constexpr int value =
       (dtype == DataType::kFloat32 || dtype == DataType::kTensorFloat32 ||
        dtype == DataType::kInt32)
           ? 32
       : (dtype == DataType::kFloat16 || dtype == DataType::kBFloat16 ||
-         dtype == DataType::kInt16 || dtype == DataType::kUInt16)
+         dtype == DataType::kInt16 || dtype == DataType::kUInt16 ||
+         dtype == DataType::kBit16)
           ? 16
       : (dtype == DataType::kInt8 || dtype == DataType::kUInt8 ||
-         dtype == DataType::kFloat8_e4m3 || dtype == DataType::kFloat8_e5m2)
+         dtype == DataType::kFloat8_e4m3 || dtype == DataType::kFloat8_e5m2 ||
+         dtype == DataType::kBit8)
           ? 8
       : (dtype == DataType::kInt4 || dtype == DataType::kUInt4) ? 4
+      : (dtype == DataType::kBit32) ? 32
+      : (dtype == DataType::kBit64) ? 64
+      : (dtype == DataType::kBit1) ? 1
                                                                 : 8;
 };
📝 Committable suggestion


Suggested change
template <DataType dtype> struct ElementBits {
static constexpr int value =
(dtype == DataType::kFloat32 || dtype == DataType::kTensorFloat32 ||
dtype == DataType::kInt32)
? 32
: (dtype == DataType::kFloat16 || dtype == DataType::kBFloat16 ||
dtype == DataType::kInt16 || dtype == DataType::kUInt16)
? 16
: (dtype == DataType::kInt8 || dtype == DataType::kUInt8 ||
dtype == DataType::kFloat8_e4m3 || dtype == DataType::kFloat8_e5m2)
? 8
: (dtype == DataType::kInt4 || dtype == DataType::kUInt4) ? 4
: 8;
};
template <DataType dtype> struct ElementBits {
static constexpr int value =
(dtype == DataType::kFloat32 || dtype == DataType::kTensorFloat32 ||
dtype == DataType::kInt32)
? 32
: (dtype == DataType::kFloat16 || dtype == DataType::kBFloat16 ||
dtype == DataType::kInt16 || dtype == DataType::kUInt16 ||
dtype == DataType::kBit16)
? 16
: (dtype == DataType::kInt8 || dtype == DataType::kUInt8 ||
dtype == DataType::kFloat8_e4m3 || dtype == DataType::kFloat8_e5m2 ||
dtype == DataType::kBit8)
? 8
: (dtype == DataType::kInt4 || dtype == DataType::kUInt4)
? 4
: (dtype == DataType::kBit32)
? 32
: (dtype == DataType::kBit64)
? 64
: (dtype == DataType::kBit1)
? 1
: 8;
};
🤖 Prompt for AI Agents
In src/tl_templates/cuda/instruction/wgmma.h around lines 632 to 645, the
ElementBits trait currently falls back to 8 for unknown types but omits handling
DataType::kBit16, kBit32, and kBit64; update the conditional chain to explicitly
return 16 for kBit16, 32 for kBit32, and 64 for kBit64 (keeping the existing
mappings for float/int variants and retaining an 8-bit fallback) so those
bit-oriented DataType variants return their correct bit widths.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between bd9bd37 and ff3e04d.

📒 Files selected for processing (1)
  • examples/deepseek_v32/test_tilelang_example_deepseek_v32.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
examples/deepseek_v32/test_tilelang_example_deepseek_v32.py (1)
examples/deepseek_v32/sparse_mla_bwd.py (1)
  • test_sparse_mla_bwd (334-383)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: build-test-metal
  • GitHub Check: format-check
🔇 Additional comments (1)
examples/deepseek_v32/test_tilelang_example_deepseek_v32.py (1)

36-37: LGTM once the import is fixed.

Making the function call explicit with keyword arguments improves clarity and aligns with the updated APIs mentioned in the PR objectives. The use of SKV=4096 instead of the default 32768 is reasonable for testing purposes (smaller, faster execution).

Comment on lines 36 to 37
test_sparse_mla_bwd(
B=1, S=4096, SKV=4096, H=64, HKV=1, DQKV=576, DV=512, topk=2048, dtype=torch.bfloat16)

⚠️ Potential issue | 🔴 Critical

Missing import for torch.

The code references torch.bfloat16 on line 37, but the torch module is not imported in this file. This will cause a NameError at runtime when the test is executed.

Apply this diff to add the missing import:

 # ruff: noqa
+import torch
 import tilelang.testing
📝 Committable suggestion


Suggested change
test_sparse_mla_bwd(
B=1, S=4096, SKV=4096, H=64, HKV=1, DQKV=576, DV=512, topk=2048, dtype=torch.bfloat16)
# ruff: noqa
import torch
import tilelang.testing
🤖 Prompt for AI Agents
In examples/deepseek_v32/test_tilelang_example_deepseek_v32.py around lines 36
to 37, the test calls torch.bfloat16 but the torch module is not imported; add
an import statement for torch at the top of the file (e.g., import torch) so
torch.bfloat16 resolves, keeping import placement consistent with existing
imports.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between ff3e04d and c6ab014.

📒 Files selected for processing (1)
  • examples/deepseek_v32/test_tilelang_example_deepseek_v32.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
examples/deepseek_v32/test_tilelang_example_deepseek_v32.py (1)
examples/deepseek_v32/sparse_mla_bwd.py (1)
  • test_sparse_mla_bwd (334-383)
🪛 GitHub Actions: CI Test on AMD
examples/deepseek_v32/test_tilelang_example_deepseek_v32.py

[error] 1-1: Reformatted files. Please review and stage the changes.

🪛 GitHub Actions: CI Test on Metal
examples/deepseek_v32/test_tilelang_example_deepseek_v32.py

[error] 1-1: Code formatting check failed. Reformatted files. Please review and stage the changes. Changes not staged for commit: examples/deepseek_v32/test_tilelang_example_deepseek_v32.py

…checks

- Changed SKV parameter from 32768 to 8192 in sparse MLA backward and forward tests.
- Added check_correctness parameter to test functions for validation of outputs.
- Updated test cases to reflect new SKV values and correctness checks.
@LeiWang1999 LeiWang1999 merged commit a13cde2 into tile-ai:main Oct 9, 2025
3 of 4 checks passed
