[TileOp] Implement WGMMA for T.gemm_v2 #813
Conversation
- Added support for the WGMMA intrinsic in the TileLang framework, enabling efficient matrix multiplication on newer architectures.
- Refactored GEMM layout functions to accept a boolean parameter for K dimension handling, improving flexibility in layout generation.
- Updated layout inference logic to accommodate new WGMMA configurations and ensure compatibility with existing GEMM operations.
- Enhanced Python bindings for layout functions, allowing for better integration and usability in user-defined operations.
- Improved documentation for layout functions and GEMM operations to clarify usage and parameters.

These changes enhance the performance and usability of GEMM operations, particularly for advanced architectures, while maintaining backward compatibility with existing implementations.
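For orientation, here is a minimal sketch of the kind of TileLang GEMM kernel this backend accelerates, following the project's documented `T.gemm` kernel style. The tile sizes and problem sizes are illustrative, and the `T.gemm_v2` entry point targeted by this PR is assumed to be wired analogously:

```python
import tilelang.language as T

M = N = K = 1024
block_M, block_N, block_K = 64, 64, 64

@T.prim_func
def gemm(A: T.Tensor((M, K), "float16"),
         B: T.Tensor((K, N), "float16"),
         C: T.Tensor((M, N), "float32")):
    with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
        A_shared = T.alloc_shared((block_M, block_K), "float16")
        B_shared = T.alloc_shared((block_K, block_N), "float16")
        C_local = T.alloc_fragment((block_M, block_N), "float32")
        T.clear(C_local)
        for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=2):
            T.copy(A[by * block_M, k * block_K], A_shared)
            T.copy(B[k * block_K, bx * block_N], B_shared)
            # On Hopper this lowers to the WGMMA path added in this PR;
            # on older CUDA architectures it falls back to MMA.
            T.gemm(A_shared, B_shared, C_local)
        T.copy(C_local, C[by * block_M, bx * block_N])
```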
…bility

- Improved code formatting across multiple files for better readability, including consistent indentation and line breaks.
- Updated layout function signatures to enhance clarity, particularly in `gemm_layouts.cc`, `layout.cc`, and `layout.h`.
- Refactored lambda functions in `builtin.cc` and `gemm_py.cc` for improved structure and maintainability.
- Enhanced comments and documentation in layout-related files to clarify usage and parameters.

These changes contribute to a cleaner codebase and improved maintainability of layout functions in the TileLang framework.
👋 Hi! Thank you for contributing to the TileLang project. Please remember to run the project's formatting checks before pushing. We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀
Caution: Review failed. The pull request is closed.

Note: Other AI code review bot(s) detected. CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough: Replace integer K-factor parameters with boolean `k_inner` flags across the GEMM layout APIs, and add WGMMA support: new builtins, PTX codegen, descriptor handling, and architecture-aware GEMM instruction selection.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    autonumber
    participant User
    participant GemmPy as GemmPy (tileop)
    participant FFI as _ffi_api.GemmPyGemmInst
    participant Impl as Impl (GemmMMA / GemmWGMMA)
    participant Lower as TIR Lowering
    User->>GemmPy: infer_layout(target, thread_nums)
    GemmPy->>FFI: GemmPyGemmInst(self, thread_nums, target)
    FFI-->>GemmPy: GemmInst (MMA/WGMMA/MFMA)
    GemmPy->>Impl: select implementation class
    GemmPy->>Impl: infer_layout(...)
    Impl-->>GemmPy: A/B/C layouts
    User->>GemmPy: lower(layout_map, target, thread_nums, thread_var)
    GemmPy->>Impl: lower(layout_map, target, thread_nums, thread_var)
    Impl->>Lower: build PrimFunc(s)
    Lower-->>User: Stmt / PrimFunc
```
```mermaid
sequenceDiagram
    autonumber
    participant TIR as TIR (tilelang.tir.op)
    participant Builtin as tl Builtins
    participant CGen as CUDA Codegen
    participant PTX as PTX Utils
    TIR->>Builtin: call tl.ptx_wgmma_ss/rs(...)
    Builtin-->>CGen: CallNode(op=ptx_wgmma_ss/rs, args)
    CGen->>PTX: PrintWGMMAAssembly(shape, layouts, dtypes, descs, offsets, ...)
    PTX-->>CGen: PTX assembly string
    CGen-->>TIR: Emit inline asm
```
Estimated code review effort: 🎯 5 (Critical) | ⏱️ ~120 minutes
Pre-merge checks and finishing touches: ❌ Failed checks (1 warning), ✅ Passed checks (2 passed)

📜 Recent review details: Configuration used: CodeRabbit UI; Review profile: CHILL; Plan: Pro
📒 Files selected for processing (1)
…GMMA

- Introduced new TileLang builtins `initialize_descriptor` and `increase_descriptor_offset` to facilitate descriptor management for WGMMA operations.
- Updated `builtin.cc` and `builtin.h` to define and document the new builtins, enhancing the framework's capabilities for descriptor handling.
- Modified `codegen_cuda.cc` and `ptx.cc` to integrate the new builtins into the code generation process, ensuring proper assembly generation for WGMMA operations.
- Enhanced the `GemmWGMMA` class to utilize the new descriptor functionalities, improving the efficiency of matrix multiplication operations.
- Updated related tests and documentation to reflect the new features and ensure comprehensive coverage.

These changes enhance the TileLang framework's support for advanced matrix operations on newer architectures, improving performance and usability.
- Enhanced code formatting across multiple files for better readability, including consistent indentation and line breaks.
- Updated function signatures and comments in `builtin.h`, `codegen_cuda.cc`, and `ptx.cc` to improve clarity.
- Refactored descriptor initialization and offset manipulation functions in `builtin.py` and `wgmma_macro_generator.py` for improved structure.
- Cleaned up unnecessary whitespace and improved alignment in `common.h` and `allocate.py`.

These changes contribute to a cleaner and more maintainable codebase in the TileLang framework.
- Updated the subproject commit for `cutlass` to indicate a dirty state.
- Refactored the `UpdateAnalyzer` function in `layout.cc` to call `LayoutNode::getVarMap()` instead of `getVarMap()`, improving clarity and ensuring proper context for variable mapping.

These changes enhance the maintainability and clarity of the layout handling in the TileLang framework.
Actionable comments posted: 6
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
tilelang/language/customize.py (1)
193-197: Bug: extent computation uses list max instead of elementwise max
`max(src_extent, dst_extent)` returns one of the lists chosen lexicographically, not the per-dimension maximum. This can mis-size regions and corrupt memory ops. Apply:
```diff
-        extent = max(src_extent, dst_extent)
+        # elementwise extent unification
+        extent = [T.max(a, b) for a, b in zip(src_extent, dst_extent)]
```
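A plain-Python illustration of the pitfall (the fix above uses `T.max` because the extents are PrimExprs, but the comparison semantics are the same):

```python
src_extent = [2, 8]
dst_extent = [4, 4]

# Lexicographic comparison: [4, 4] > [2, 8] because 4 > 2 at index 0,
# so max() returns one whole list and silently drops the 8.
print(max(src_extent, dst_extent))                          # [4, 4]

# Elementwise maximum, which is what the copy logic actually needs.
print([max(a, b) for a, b in zip(src_extent, dst_extent)])  # [4, 8]
```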
🧹 Nitpick comments (32)
tilelang/language/__init__.py (1)
45-46: Export looks good; clean up unused noqa.
- Re-exporting alloc_descriptor is correct and matches allocate.py.
- Ruff flags line 45 for an unused `# noqa: F401` (RUF100). Drop it here (others weren't flagged in this diff). Apply:

```diff
-    alloc_descriptor,  # noqa: F401
+    alloc_descriptor,
```

tilelang/layout/layout.py (3)
92-94: Add a docstring/return type for `get_forward_index`, for parity with the other getters. Minor consistency nit. Consider:

```diff
-    def get_forward_index(self):
-        return self.index
+    def get_forward_index(self) -> PrimExpr | list[PrimExpr]:
+        """Return the computed forward index expression(s)."""
+        return self.index
```
136-146: API parity: provide Pythonic equality too. Keep `is_equal`, but also implement `__eq__` delegating to it for ergonomic comparisons; keep `__hash__ = None` to avoid hashing mutable objects.

```diff
     def is_equal(self, other: "Layout") -> bool:
         """
         Check if the current layout is equal to another layout.
         """
         return _ffi_api.Layout_is_equal(self, other)

+    def __eq__(self, other: object) -> bool:
+        if not isinstance(other, Layout):
+            return NotImplemented
+        return self.is_equal(other)
+
+    __hash__ = None
```
147-148: `__repr__`: avoid huge dumps when vars/index grow. Consider truncating sequences for readability in logs.

```diff
-        return f"Layout<{self.get_input_shape()}->{self.get_output_shape()}, {self.get_forward_vars()} -> {self.get_forward_index()}>"
+        fv = self.get_forward_vars()
+        fi = self.get_forward_index()
+        def _short(x):
+            s = str(x)
+            return s if len(s) <= 120 else s[:117] + "..."
+        return f"Layout<{self.get_input_shape()}->{self.get_output_shape()}, {_short(fv)} -> {_short(fi)}>"
```

src/op/builtin.h (2)
164-185: Fix the WGMMA doc comments to match the actual RS/SS operand forms. The current comment for `ptx_wgmma_rs` refers to `A_descriptor`; the RS variant uses `A_buf` (a regular pointer/buffer) per the Python wrappers. Also, `B_offset` is documented as Var (should be PrimExpr). Update the comments to avoid API confusion.

```diff
- * void ptx_wgmma_rs(StringImm accum_dtype, StringImm wgmma_prefix, bool
- * trans_a, bool trans_b, StringImm a_dtype_abbrv, StringImm b_dtype_abbrv,
- * StringImm accum_dtype_abbrv, Var A_descriptor, PrimExpr A_offset, Var
- * B_descriptor, Var B_offset, Var C_data, Var C_offset, bool scale_out, bool
+ * void ptx_wgmma_rs(StringImm accum_dtype, StringImm wgmma_prefix, bool
+ * trans_a, bool trans_b, StringImm a_dtype_abbrv, StringImm b_dtype_abbrv,
+ * StringImm accum_dtype_abbrv, Var A_buf, PrimExpr A_offset, Var
+ * B_descriptor, PrimExpr B_offset, Var C_data, PrimExpr C_offset, bool scale_out, bool
  * scale_in_a, bool scale_in_b);
```
344-361: Descriptor intrinsic docs: correct the operation wording. `increase_descriptor_offset` increments the offset; the block comment says "setting the start address." Align the wording to avoid misuse.

```diff
- * \brief tilelang intrinsic for setting the start address of a descriptor
- * buffer for wgmma/utcmma.
+ * \brief tilelang intrinsic for increasing the offset of a descriptor
+ * buffer for wgmma/utcmma.
```

tilelang/language/tir/op.py (2)
1064-1104: Nit: the docstring says "wmma" but this is WGMMA; also consider briefly documenting the operand order. Purely cosmetic; helps future maintainers and avoids confusion with WMMA.

```diff
-    """TVM intrinsic for ptx tensor core wmma instructions
+    """TVM intrinsic for PTX warp-group MMA (WGMMA) instructions
+
+    Operand order: prefix, trans_a, trans_b, a/b/accum dtype abbrvs,
+    A_desc, A_offset, B_desc, B_offset, C_data, C_offset, scale_out, scale_in_a, scale_in_b.
```
1106-1144: Nit: same WGMMA terminology/doc tweak as above; otherwise the wrapper matches the builtin. LGTM functionally; mirrors the 15-arg registration.

```diff
-    return call_intrin(
+    # PTX warp-group MMA (WGMMA) RS variant: A from registers, B from descriptor.
+    return call_intrin(
```

tilelang/layout/__init__.py (1)
6-13: Remove the unused `# noqa: F401` directives or enable F401 in the Ruff config. Ruff flags these as unused (RUF100). Either drop them or configure Ruff to honor F401.

```diff
-from .swizzle import (
-    make_swizzled_layout,  # noqa: F401
-    make_wgmma_swizzled_layout,  # noqa: F401
-    make_full_bank_swizzled_layout,  # noqa: F401
-    make_half_bank_swizzled_layout,  # noqa: F401
-    make_quarter_bank_swizzled_layout,  # noqa: F401
-)
-from .gemm_sp import make_metadata_layout  # noqa: F401
+from .swizzle import (
+    make_swizzled_layout,
+    make_wgmma_swizzled_layout,
+    make_full_bank_swizzled_layout,
+    make_half_bank_swizzled_layout,
+    make_quarter_bank_swizzled_layout,
+)
+from .gemm_sp import make_metadata_layout
```

tilelang/tileop/gemm/gemm_mma.py (2)
60-77: Use `layout_map` (or underscore the arg) to avoid ARG002 and keep parity with WGMMA. Mirror WGMMA: when provided, feed the A/B shared layouts into the emitter; otherwise rename the arg to `_layout_map` to appease linters.

```diff
     def lower(self, layout_map: dict, target: Target, thread_nums: int, thread_var: tir.Var):
@@
         mma_emitter = TensorCoreIntrinEmitter(
@@
             thread_var=thread_var,
         )
+
+        # Optional: honor externally inferred layouts if present (parity with WGMMA)
+        if self.A in layout_map:
+            mma_emitter._assign_a_shared_layout(layout_map[self.A])
+        if self.B in layout_map:
+            mma_emitter._assign_b_shared_layout(layout_map[self.B])
```
90-91: Replace the assert with explicit validation (and check divisibility). Python asserts can be stripped with -O; raise a ValueError and also enforce `block_K % micro_size_k == 0` to match the loop step.

```diff
-        assert block_K >= micro_size_k, f"block_K ({block_K}) must be >= micro_size_k ({micro_size_k})"
+        if block_K < micro_size_k or (block_K % micro_size_k) != 0:
+            raise ValueError(
+                f"Invalid K tile: block_K={block_K}, micro_size_k={micro_size_k} "
+                "(must be >= and divisible).")
```
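A quick demonstration of why an explicit raise is preferred: `assert` statements are compiled out under `python -O`, so the guard silently disappears (the values below are hypothetical):

```python
# $ python -c "assert 32 >= 64, 'bad tile'"      -> AssertionError: bad tile
# $ python -O -c "assert 32 >= 64, 'bad tile'"   -> no error at all

def validate_k_tile(block_K: int, micro_size_k: int) -> None:
    # Survives -O, and also enforces divisibility to match the loop step.
    if block_K < micro_size_k or block_K % micro_size_k != 0:
        raise ValueError(
            f"Invalid K tile: block_K={block_K}, micro_size_k={micro_size_k}")

validate_k_tile(64, 16)    # fine
# validate_k_tile(32, 64)  # raises ValueError
```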
src/op/gemm.cc (1)

45-48: Documentation inconsistency between comment and implementation. The documentation states that `kPack` must be 1, but the implementation at lines 71-73 allows values of both 1 and 2. This creates confusion about the actual requirements. Either update the documentation to reflect that kPack can be 1 or 2, or enforce the restriction that it must be 1:

```diff
- * @note If `kPack` is provided it must be 1; otherwise the constructor
- * fails with an ICHECK (runtime assertion). No other validation is
- * performed here.
+ * @note If `kPack` is provided it must be 1 or 2; otherwise the constructor
+ * fails with an ICHECK (runtime assertion). No other validation is
+ * performed here.
```
375-375: Consider moving error messages to exception classes. While not critical, defining error messages inside exception classes improves maintainability and reusability. Consider creating custom exception classes:

```python
class InvalidDescriptorTypeError(TypeError):
    """Raised when descriptor is not a Buffer or BufferLoad."""

    def __init__(self):
        super().__init__("Descriptor must be a tvm.tir.Buffer or tvm.tir.BufferLoad.")


class InvalidDescriptorShapeError(ValueError):
    """Raised when descriptor is not a 1D buffer of size 1."""

    def __init__(self):
        super().__init__("Descriptor must be a 1D buffer of size 1.")
```

Also applies to: 378-378, 401-401, 404-404
tilelang/layout/swizzle.py (2)
23-23: Use an `Optional` type annotation for the nullable parameter. The `continuity` parameter can be `None` but isn't typed as `Optional`. Add a proper type annotation:

```diff
+from typing import Optional
+
 def make_wgmma_swizzled_layout(buffer: tvm.tir.Buffer,
-                               continuity: int = None,
+                               continuity: Optional[int] = None,
                                k_major: bool = True):
```
54-54: Consider more descriptive error messages. The messages could be more helpful by describing what arguments are expected. Improve them:

```diff
-        raise ValueError(f"Invalid arguments: {args}")
+        raise ValueError(f"Expected either a single buffer or (stride, continuous, element_size), got {len(args)} arguments: {args}")
```

Also applies to: 79-79, 104-104
src/layout/gemm_layouts.cc (2)
573-574: Parameter name inconsistency with the function documentation. The parameter name `k_inner` doesn't match the documentation's description, which refers to whether the "K dimension is in the inner loop". Consider renaming to `k_is_inner` or `is_k_inner` for better clarity.
532-541: Missing implementation of the k_major semantics in the Volta layout. The function signature was updated to use `bool k_inner`, but the implementation still uses `k_inner` directly as a boolean flag without considering the k-major semantics that the rest of the codebase expects. Based on the pattern in `makeGemmABLayout` and `makeGemmABLayoutHopper`, you may want to verify that this implementation correctly handles the k-major/k-inner semantics.

tilelang/tileop/gemm/gemm_wgmma.py (1)
110-111: Verify the RS implementation for A in fragment and B in shared. The `_gemm_rsr` function name and docstring mention loading data from "shared buffers A_shared and B_shared", but the RS variant should have A in registers/fragments. The docstring appears to be copied from the SS variant. Apply this diff to fix it:

```diff
         @T.prim_func
         def _gemm_rsr() -> None:
             """
-            The inner macro that loads data from shared buffers A_shared and
-            B_shared into local fragments, then issues Tensor Core mma ops,
-            accumulating into C_local.
+            The inner macro that uses data from local fragment A_local and
+            shared buffer B_shared, then issues Tensor Core mma ops,
+            accumulating into C_local.
             """
             mma_emitter.wgmma(A_local, B_shared, C_local, clear_accum)
```

tilelang/intrinsics/wgmma_macro_generator.py (3)
104-104: Remove the unused parameter `n_dim` from the method signature. The method `_initialize_wgmma_prefix` has an unused parameter `n_dim=16` that shadows the instance variable `self.n_dim`. Apply this diff:

```diff
-    def _initialize_wgmma_prefix(self, n_dim: int = 16):
+    def _initialize_wgmma_prefix(self):
         inst_m, inst_n = 64, self.block_col_warps * self.warp_col_tiles
```
133-143: Extract swizzle mode detection into a centralized utility. The `_determinate_swizzle_mode` method performs layout equality checks against multiple swizzle patterns. This logic could be refactored into a centralized layout utility to avoid duplication if similar detection is needed elsewhere. Would you like me to help create a centralized swizzle-mode detection utility that could be reused across the codebase?
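One possible shape for such a utility, sketched below; the enum values and function name are hypothetical, while `is_equal` and the three layout constructors are existing APIs referenced elsewhere in this review:

```python
from enum import IntEnum

from tilelang.layout import (
    make_full_bank_swizzled_layout,
    make_half_bank_swizzled_layout,
    make_quarter_bank_swizzled_layout,
)

class SwizzleMode(IntEnum):
    NONE = 0
    QUARTER_BANK = 1
    HALF_BANK = 2
    FULL_BANK = 3

def detect_swizzle_mode(layout, buffer) -> SwizzleMode:
    """Return which known swizzle pattern `layout` matches for `buffer`."""
    candidates = [
        (SwizzleMode.FULL_BANK, make_full_bank_swizzled_layout(buffer)),
        (SwizzleMode.HALF_BANK, make_half_bank_swizzled_layout(buffer)),
        (SwizzleMode.QUARTER_BANK, make_quarter_bank_swizzled_layout(buffer)),
    ]
    for mode, candidate in candidates:
        if layout.is_equal(candidate):
            return mode
    return SwizzleMode.NONE
```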
345-353: Improve error-message specificity for unsupported dtypes. The message could be more informative by including the actual bit width. Apply this diff:

```diff
         else:
-            raise ValueError(f"Unsupported dtype {dtype}")
+            raise ValueError(f"Unsupported dtype {dtype} with {dtype_bits} bits for MMA load layout")
```

src/layout/layout.cc (1)
495-505: Parameter naming inconsistency in swizzled layout creation. The `make_swizzled_layout` function uses `allow_pad` to choose between `makeGemmABLayout` and `makeGemmABLayoutHopper`. The name `allow_pad` doesn't clearly convey that it selects between different hardware layout strategies (standard vs Hopper). Consider renaming to be more descriptive:

```diff
       .def("tl.make_swizzled_layout",
            [](int stride, int continuous, int element_size, bool k_inner,
-              bool allow_pad = true) {
-             if (allow_pad) {
+              bool use_hopper_layout = false) {
+             if (!use_hopper_layout) {
                return makeGemmABLayout(stride, continuous, continuous,
                                        element_size, k_inner);
              } else {
                return makeGemmABLayoutHopper(stride, continuous, continuous,
                                              element_size, k_inner);
              }
            })
```

tilelang/tileop/gemm/__init__.py (2)
82-98: Add parameter validation for `thread_nums`. While the FFI call handles the selection logic, it would be good to validate that `thread_nums` is positive before passing it to the FFI.

```diff
     def _select_gemm_instruction(self, thread_nums: int, target: Target) -> GemmInst:
         """Select the appropriate GEMM instruction based on target and thread configuration.

         The selection logic follows this priority:
         1. WGMMA for Hopper architecture with sufficient matrix size and warp count
         2. MFMA for CDNA (AMD) architecture
         3. MMA for CUDA architecture
         4. Fallback to MMA for other cases

         Args:
             thread_nums: Number of threads in the block
             target: Target architecture

         Returns:
             GemmInst: The selected GEMM instruction type
         """
+        if thread_nums <= 0:
+            raise ValueError(f"thread_nums must be positive, got {thread_nums}")
         return GemmInst(_ffi_api.GemmPyGemmInst(self, int(thread_nums), target))
```
118-118: Consider a more specific exception message that names the available implementations.

```diff
-            raise NotImplementedError("MFMA is not implemented")
+            raise NotImplementedError("MFMA is not implemented. Available implementations: MMA, WGMMA")
```

src/target/ptx.cc (2)
1053-1168: Complex operand generation, but missing `a_is_shared` validation for the register path. The function generates WGMMA operands correctly but doesn't validate that, when `a_is_shared` is false, the operation is actually supported for register-based A operands. Consider adding validation:

```diff
 inline std::tuple<std::string, std::string, std::string, std::string>
 GetWGMMAOperands(int m, int n, int k, ptx::DataType dtype_a,
                  ptx::DataType dtype_b, ptx::DataType dtype_c, bool sparse,
                  bool a_is_shared) {
+  // WGMMA with a register-based A operand has limitations
+  if (!a_is_shared) {
+    // Add any specific validation for register-based A operands if needed,
+    // based on NVIDIA documentation
+  }
   std::stringstream templates, inputs, outputs, predicate;
```
1263-1266: Consider extracting the predicate setup as a constant string for better maintainability.

```diff
+  constexpr const char* PREDICATE_SETUP =
+      "{.reg .pred p;\n"
+      "setp.ne.b32 p, {predicate}, 0;\n";
+
   std::string asm_code = R"(
   {
     __asm__ __volatile__(
-      "{.reg .pred p;\n"
-      "setp.ne.b32 p, {predicate}, 0;\n"
+      ")" PREDICATE_SETUP R"(
       "wgmma.mma_async{.sparse}.sync.aligned{.shape}{.dtype}{.atype}{.btype}"
       "{templates};\n}"
       : {outputs}
       : {inputs});
   }
)";
```

tilelang/language/customize.py (2)
160-179: Deduplicate `get_extent`; the current helper misses `BufferLoad` and can diverge. You re-implement `get_extent` but omit `BufferLoad` handling (supported in tilelang/language/copy.py). Prefer reusing the canonical helper to avoid drift. Apply:

```diff
@@
-    def get_extent(data):
-        """
-        Return the inferred extent (shape) of a buffer-like object.
-        ...
-        """
-        if isinstance(data, Var) and T.has_let_value(data):
-            data = T.get_let_value(data)
-        if isinstance(data, Buffer):
-            return data.shape
-        elif isinstance(data, BufferRegion):
-            return [x.extent for x in data.region]
-        else:
-            return None
-
-    src_extent = get_extent(value)
-    dst_extent = get_extent(dst)
+    src_extent = _get_extent(value)
+    dst_extent = _get_extent(dst)
```

Also add the import near the top of this file:

```python
from tilelang.language.copy import get_extent as _get_extent
```
82-105: Honor provided extents in `buffer_region_to_tile_region`. The `extents` parameter is only asserted but otherwise ignored, which prevents aligning extents when mixing BufferRegion with BufferLoad. Apply:

```diff
@@
-    return region(T.BufferLoad(buffer_region.buffer, mins), access_type, *region_extents)
+    # Override trailing dims with requested extents if provided
+    if extents:
+        region_extents = list(region_extents)
+        region_extents[-len(extents):] = extents
+    return region(T.BufferLoad(buffer_region.buffer, mins), access_type, *region_extents)
```

tilelang/language/allocate.py (1)
129-136: Add type hints for API clarity; tighten the signature and return type. Apply:

```diff
-def alloc_descriptor(dtype="uint64", scope="local.descriptor"):
+def alloc_descriptor(dtype: str = "uint64", scope: str = "local.descriptor") -> T.Buffer:
     """Allocate a descriptor buffer for wgmma and utcmma.

     Returns:
         T.Buffer: A TVM buffer object allocated as a descriptor
     """
     return T.alloc_buffer([1], dtype, scope=scope)
```

src/layout/layout.h (3)
163-167: Prefer an explicit enum over a boolean for K placement. A boolean is easy to misuse. Consider a strongly typed enum (e.g., `enum class KPlacement { Inner, Outer };`) to make call sites self-documenting and prevent silent int→bool coercions. Example:

```cpp
enum class KPlacement { Inner, Outer };

Layout makeGemmABLayout(int mat_stride, int mat_continuous, int continuity,
                        int element_size,
                        KPlacement k_place = KPlacement::Inner);
```
168-169: CDNA kPack rename verified; code updated. Fix the lingering doc/comments: `makeGemmABLayoutCDNA` (declaration/definition) and its callers use `int kPack`, and the Python bindings expose "kPack".
- Update remaining comment references to the old name:
- src/op/gemm_py.h — comment referencing "k_pack" (around line ~30).
- src/op/gemm_sp.h — similar comment (around line ~25).
177-178: Volta layout: no functional change required — callers pass explicit k_inner; remove or flip default for clarity.
makeGemmVoltaABLayout(..., bool k_inner = true) is declared in src/layout/layout.h; call sites pass explicit values (src/op/gemm.cc:479, 492).
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (30)
- src/layout/gemm_layouts.cc (4 hunks)
- src/layout/layout.cc (4 hunks)
- src/layout/layout.h (2 hunks)
- src/op/builtin.cc (2 hunks)
- src/op/builtin.h (2 hunks)
- src/op/gemm.cc (7 hunks)
- src/op/gemm_py.cc (4 hunks)
- src/op/gemm_py.h (1 hunks)
- src/target/codegen_cuda.cc (5 hunks)
- src/target/ptx.cc (6 hunks)
- src/target/ptx.h (1 hunks)
- src/tl_templates/cuda/common.h (4 hunks)
- src/transform/lower_device_storage_access_info.cc (1 hunks)
- src/transform/storage_rewrite.cc (2 hunks)
- tilelang/intrinsics/wgmma_macro_generator.py (1 hunks)
- tilelang/language/__init__.py (4 hunks)
- tilelang/language/allocate.py (1 hunks)
- tilelang/language/ast/ir.py (2 hunks)
- tilelang/language/builtin.py (2 hunks)
- tilelang/language/customize.py (10 hunks)
- tilelang/language/tir/ir.py (1 hunks)
- tilelang/language/tir/op.py (1 hunks)
- tilelang/layout/__init__.py (1 hunks)
- tilelang/layout/fragment.py (1 hunks)
- tilelang/layout/layout.py (2 hunks)
- tilelang/layout/swizzle.py (1 hunks)
- tilelang/tileop/gemm/__init__.py (3 hunks)
- tilelang/tileop/gemm/gemm_base.py (2 hunks)
- tilelang/tileop/gemm/gemm_mma.py (2 hunks)
- tilelang/tileop/gemm/gemm_wgmma.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (23)
tilelang/language/tir/ir.py (2)
- tilelang/language/tir/op.py (2): ptx_wgmma_ss (1064-1103), ptx_wgmma_rs (1106-1143)
- tilelang/language/ast/ir.py (1): _dtype_forward (1876-1884)

tilelang/language/allocate.py (2)
- src/transform/storage_rewrite.cc (4): dtype (696-702), dtype (696-696), scope (674-678), scope (674-674)
- tilelang/language/ast/ir.py (1): alloc_buffer (441-508)

src/transform/lower_device_storage_access_info.cc (1)
- src/transform/storage_rewrite.cc (2): scope (674-678), scope (674-674)

tilelang/layout/layout.py (1)
- tilelang/layout/fragment.py (1): is_equal (209-213)

tilelang/language/ast/ir.py (2)
- tilelang/language/tir/op.py (2): ptx_wgmma_ss (1064-1103), ptx_wgmma_rs (1106-1143)
- tilelang/language/tir/ir.py (1): _dtype_forward (156-164)

tilelang/layout/swizzle.py (2)
- tilelang/language/ast/ir.py (1): buffer (93-161)
- src/layout/swizzle.h (1): tvm (12-70)

tilelang/tileop/gemm/gemm_wgmma.py (6)
- tilelang/tileop/gemm/gemm_base.py (17): GemmBase (12-120), infer_layout (15-16), policy (119-120), M (34-35), N (38-39), in_dtype (54-56), accum_dtype (59-60), trans_A (46-47), trans_B (50-51), chunk (63-64), is_gemm_ss (21-22), K (42-43), A (67-68), B (71-72), C (75-76), lower (18-19), clear_accum (107-108)
- tilelang/layout/swizzle.py (1): make_wgmma_swizzled_layout (22-34)
- tilelang/intrinsics/wgmma_macro_generator.py (6): TensorCoreIntrinEmitter (63-477), make_mma_store_layout (423-477), make_mma_load_layout (311-421), _assign_a_shared_layout (96-98), _assign_b_shared_layout (100-102), wgmma (145-233)
- tilelang/utils/language.py (2): is_shared (25-39), is_fragment (68-78)
- tilelang/transform/simplify.py (1): _Simplify (30-49)
- tilelang/tileop/gemm/gemm_mma.py (3): infer_layout (15-58), is_gemm_ss (204-205), lower (60-202)

tilelang/intrinsics/wgmma_macro_generator.py (6)
- tilelang/utils/language.py (1): is_fragment (68-78)
- tilelang/layout/swizzle.py (3): make_full_bank_swizzled_layout (39-59), make_half_bank_swizzled_layout (64-84), make_quarter_bank_swizzled_layout (89-109)
- tilelang/layout/fragment.py (4): is_equal (209-213), Fragment (13-213), replicate (146-160), repeat (123-144)
- tilelang/language/builtin.py (1): initialize_descriptor (355-386)
- tilelang/language/tir/op.py (2): ptx_wgmma_ss (1064-1103), ptx_wgmma_rs (1106-1143)
- tilelang/intrinsics/mma_macro_generator.py (1): get_store_index_map (159-165)

src/target/codegen_cuda.cc (4)
- src/transform/storage_rewrite.cc (2): scope (674-678), scope (674-674)
- tilelang/language/tir/op.py (2): ptx_wgmma_ss (1064-1103), ptx_wgmma_rs (1106-1143)
- src/target/ptx.cc (2): PrintWGMMAAssembly (1235-1306), PrintWGMMAAssembly (1235-1244)
- tilelang/language/builtin.py (2): initialize_descriptor (355-386), increase_descriptor_offset (389-411)

src/op/builtin.h (2)
- tilelang/language/tir/op.py (2): ptx_wgmma_ss (1064-1103), ptx_wgmma_rs (1106-1143)
- tilelang/language/builtin.py (2): initialize_descriptor (355-386), increase_descriptor_offset (389-411)

src/tl_templates/cuda/common.h (3)
- src/tl_templates/cuda/copy_sm90.h (1): void (255-258)
- src/tl_templates/cuda/ldsm.h (12): void (7-14), void (16-23), void (25-33), void (35-42), void (44-52), void (54-62), void (64-70), void (72-79), void (81-89), void (91-98), void (100-108), void (110-119)
- tilelang/language/builtin.py (2): initialize_descriptor (355-386), increase_descriptor_offset (389-411)

src/target/ptx.h (1)
- src/target/ptx.cc (2): PrintWGMMAAssembly (1235-1306), PrintWGMMAAssembly (1235-1244)

tilelang/language/builtin.py (3)
- src/op/builtin.h (1): tvm (13-363)
- tilelang/language/ast/ir.py (1): evaluate (1319-1331)
- tilelang/language/tir/op.py (1): call_intrin (119-144)

src/op/gemm.cc (2)
- tilelang/tileop/gemm/gemm_base.py (4): trans_A (46-47), trans_B (50-51), A (67-68), B (71-72)
- src/layout/gemm_layouts.cc (2): makeGemmABLayout (573-592), makeGemmABLayout (573-574)

tilelang/language/customize.py (1)
- tilelang/language/copy.py (1): get_extent (105-118)

src/op/builtin.cc (2)
- tilelang/language/tir/op.py (2): ptx_wgmma_ss (1064-1103), ptx_wgmma_rs (1106-1143)
- tilelang/language/builtin.py (2): initialize_descriptor (355-386), increase_descriptor_offset (389-411)

tilelang/tileop/gemm/gemm_mma.py (3)
- tilelang/tileop/gemm/gemm_wgmma.py (1): lower (64-125)
- tilelang/tileop/gemm/__init__.py (1): lower (76-80)
- tilelang/tileop/gemm/gemm_base.py (1): lower (18-19)

tilelang/language/__init__.py (1)
- tilelang/language/allocate.py (1): alloc_descriptor (129-135)

tilelang/layout/__init__.py (2)
- tilelang/layout/swizzle.py (5): make_swizzled_layout (10-18), make_wgmma_swizzled_layout (22-34), make_full_bank_swizzled_layout (39-59), make_half_bank_swizzled_layout (64-84), make_quarter_bank_swizzled_layout (89-109)
- tilelang/layout/gemm_sp.py (1): make_metadata_layout (98-109)

tilelang/tileop/gemm/__init__.py (4)
- tilelang/ir.py (1): GemmWarpPolicy (30-39)
- tilelang/tileop/gemm/gemm_mma.py (3): GemmMMA (13-214), lower (60-202), infer_layout (15-58)
- tilelang/tileop/gemm/gemm_wgmma.py (3): GemmWGMMA (13-137), lower (64-125), infer_layout (15-62)
- tilelang/tileop/gemm/gemm_base.py (2): lower (18-19), infer_layout (15-16)

tilelang/layout/fragment.py (1)
- tilelang/layout/layout.py (3): get_input_shape (59-68), get_output_shape (70-79), is_equal (136-145)

src/layout/layout.h (1)
- src/layout/gemm_layouts.cc (4): makeGemmABLayoutHopper (594-615), makeGemmABLayoutHopper (594-595), makeGemmABLayoutCDNA (617-625), makeGemmABLayoutCDNA (617-618)

src/layout/layout.cc (1)
- src/layout/gemm_layouts.cc (10): makeGemmABLayout (573-592), makeGemmABLayout (573-574), makeGemmABLayoutHopper (594-615), makeGemmABLayoutHopper (594-595), makeFullBankSwizzleLayout (375-391), makeFullBankSwizzleLayout (375-375), makeHalfBankSwizzleLayout (356-372), makeHalfBankSwizzleLayout (356-356), makeQuarterBankSwizzleLayout (336-353), makeQuarterBankSwizzleLayout (336-337)
🪛 Ruff (0.12.2)
tilelang/layout/swizzle.py
23-23: PEP 484 prohibits implicit Optional
Convert to Optional[T]
(RUF013)
54-54: Avoid specifying long messages outside the exception class
(TRY003)
79-79: Avoid specifying long messages outside the exception class
(TRY003)
104-104: Avoid specifying long messages outside the exception class
(TRY003)
tilelang/tileop/gemm/gemm_wgmma.py
61-62: Avoid specifying long messages outside the exception class
(TRY003)
124-125: Avoid specifying long messages outside the exception class
(TRY003)
tilelang/intrinsics/wgmma_macro_generator.py
104-104: Unused method argument: n_dim
(ARG002)
143-143: Avoid specifying long messages outside the exception class
(TRY003)
353-353: Avoid specifying long messages outside the exception class
(TRY003)
tilelang/language/builtin.py
375-375: Avoid specifying long messages outside the exception class
(TRY003)
377-377: Parenthesize `a and b` expressions when chaining `and` and `or` together, to make the precedence clear
Parenthesize the `and` subexpression
(RUF021)
378-378: Avoid specifying long messages outside the exception class
(TRY003)
401-401: Avoid specifying long messages outside the exception class
(TRY003)
403-403: Parenthesize `a and b` expressions when chaining `and` and `or` together, to make the precedence clear
Parenthesize the `and` subexpression
(RUF021)
404-404: Avoid specifying long messages outside the exception class
(TRY003)
tilelang/tileop/gemm/gemm_mma.py
60-60: Unused method argument: layout_map
(ARG002)
tilelang/language/__init__.py
45-45: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
tilelang/layout/__init__.py
7-7: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
8-8: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
9-9: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
10-10: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
11-11: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
13-13: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
tilelang/tileop/gemm/__init__.py
120-120: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Cursor Bugbot
🔇 Additional comments (43)
tilelang/tileop/gemm/gemm_base.py (2)
8-8: Add the `PrimExpr` import for the updated `clear_accum` type annotation. The import is correctly added to support the type-annotation change.

107-108: Update the type annotation from `bool` to `PrimExpr`. The change aligns with the WGMMA backend requirements, where `clear_accum` can be an expression rather than just a boolean value. This provides more flexibility for the accumulation control logic.

src/op/gemm_py.h (1)
112-113: Clean private-section reorganization. The reorganization of private members with explicit `private:` labeling improves code organization and readability. The `GemmInst` enum and `GetGemmInst` method remain appropriately private while allowing FFI access through wrapper functions.

src/transform/lower_device_storage_access_info.cc (1)
47-48: Correct exclusion of descriptor scope from storage lowering. The addition of `scope.tag != ".descriptor"` properly excludes descriptor buffers from the generic storage-access lowering path, which is necessary for the WGMMA descriptor handling. This change aligns with the corresponding exclusion in `storage_rewrite.cc` and allows descriptor buffers to maintain their special semantics.

tilelang/language/tir/ir.py (1)
294-295: Add WGMMA PTX intrinsic wrappers. The new `ptx_wgmma_ss` and `ptx_wgmma_rs` wrappers correctly follow the established pattern using the `_dtype_forward` decorator, providing consistent API access to the underlying PTX WGMMA intrinsics for the shared-shared and register-shared variants.

tilelang/layout/fragment.py (2)
207-207: Enhanced Fragment representation with shape information. The updated `__repr__` method now includes input and output shapes, providing more comprehensive debugging information. The use of `get_input_shape()` and `get_output_shape()` follows the established pattern from the base Layout class.
209-213: Add a Fragment equality-check method. The new `is_equal` method provides a proper way to compare Fragment instances by delegating to the FFI implementation. This aligns with the corresponding method in the base Layout class and enables proper equality testing for Fragment objects.

src/transform/storage_rewrite.cc (2)
677-678: Exclude `.descriptor` from "special tagged memory": OK. This prevents special-merge and memory-info constraints from applying to descriptor buffers, which is desirable.

847-849: Skip the memory-info path for `.descriptor` in NewAllocTagMerged: OK, but add a test. The change correctly treats descriptor allocations neutrally (GetMemoryInfo is not invoked for `.descriptor`). Add a minimal IR test that ensures:
.descriptor). Add a minimal IR test that ensures:
- local.descriptor buffers are not grouped into special merges
- no GetMemoryInfo lookups occur for
.descriptorscopesRelevant files: src/transform/storage_rewrite.cc (≈lines 846–850), src/transform/lower_device_storage_access_info.cc (≈line 48).
tilelang/language/ast/ir.py (1)
1895-1896: Approve: WGMMA wrappers wired and registered. Wrapper bindings exist in tilelang/language/ast/ir.py (ptx_wgmma_ss/rs), the TIR op wrappers call `_tvm_op.Op.get("tl.ptx_wgmma_ss/rs")` in tilelang/language/tir/op.py, and the builtins are registered in src/op/builtin.cc/.h (TIR_DEFINE_TL_BUILTIN + declarations) and referenced by codegen in src/target/codegen_cuda.cc; no issues found. Add a small CI check to assert builtin registration if desired.
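A minimal registration smoke test along the suggested lines (the test name is illustrative; `tvm.ir.Op.get` raises for unregistered ops, so reaching the assertion already implies success):

```python
import tvm

def test_wgmma_builtins_registered():
    for name in ("tl.ptx_wgmma_ss", "tl.ptx_wgmma_rs"):
        # Op.get raises an error for unregistered op names.
        assert tvm.ir.Op.get(name) is not None
```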
src/op/builtin.cc (2)
174-183: Descriptor intrinsics API is consistent with the wrappers. Five inputs for initialize and two for increase align with the Python layer; an opaque effect is appropriate given the side effects. LGTM.

83-92: WGMMA intrinsics: 15-arg arity verified; codegen enforces 15 operands. The Python wrappers define ptx_wgmma_ss/rs, the builtin registers `.set_num_inputs(15)`, and src/target/codegen_cuda.cc contains `ICHECK_EQ(op->args.size(), 15U)` for both handlers; operand alignment confirmed.

src/op/gemm_py.cc (3)
109-111: Unreachable-return safety: fine as-is. This satisfies compilers post-ICHECK; harmless.

255-257: Unreachable-return safety: fine as-is. Consistent with the ICHECK case; no functional impact.

284-290: FFI reflection for GemmInst selection: LGTM. A simple, side-effect-free binding; useful for runtime dispatch.
src/op/gemm.cc (1)
481-481: LGTM! Consistent layout-API migration to boolean flags. The transition from numeric codes (1/2) to a boolean `k_inner` parameter across all backends (Volta, Ampere/Turing/SM120, Hopper) improves API clarity and type safety.
Also applies to: 507-507, 520-520, 545-545, 547-547, 563-563, 565-565
src/target/codegen_cuda.cc (5)
898-898: Including descriptor scope in the pointer-access logic. The addition of the `"local.descriptor"` scope alongside `"local.var"` is correct for treating descriptors as local objects.
1788-1789: LGTM! Proper allocation of descriptor storage. The allocation of `tl::GmmaDescriptor` for descriptor scope is correctly implemented.
1305-1340: Check for additional error conditions in the WGMMA intrinsics. The WGMMA SS intrinsic implementation looks correct, but consider adding validation for the descriptor parameters. Should we verify that the descriptors are properly initialized before use? Consider adding checks to ensure `a_desc` and `b_desc` are valid descriptor handles.
1701-1713: LGTM! Descriptor operations correctly implemented. The implementation of the `initialize_descriptor` and `increase_descriptor_offset` intrinsics is correct and properly forwards all parameters to the TL template functions. Also applies to: 1714-1721
1823-1823: Correct handling of descriptor scope in the allocation check. The condition properly excludes `local.descriptor` from the unsupported-scope error path.

tilelang/language/builtin.py (1)
355-386: LGTM! Well-documented descriptor-initialization function. The `initialize_descriptor` function is well implemented, with proper type checking, parameter documentation, and error handling.

tilelang/layout/swizzle.py (2)
10-18: LGTM! Well-structured swizzle layout functions. The updated `make_swizzled_layout` and new `make_wgmma_swizzled_layout` functions are properly implemented, with clear parameter forwarding to the FFI API. Also applies to: 22-34
39-109: LGTM! Flexible bank-swizzled layout helpers. The three bank-swizzled layout functions (full/half/quarter) are well implemented, with flexible argument handling that supports both buffer objects and explicit parameters.
src/tl_templates/cuda/common.h (1)
307-367: LGTM! Well-structured descriptor union implementation. The `GmmaDescriptor` union is properly designed (a sketch follows the list), with:

- Multiple access patterns via `desc_`, `reg32_[]`, and `reg16_[]`
- A clear bitfield layout matching the CUDA WGMMA descriptor format
- Proper copy and move semantics
- A convenient arithmetic operator for offset adjustments
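For orientation, a sketch of the union's shape; the field widths follow the common GMMA descriptor convention and are illustrative only, and the authoritative definition lives in `src/tl_templates/cuda/common.h`:

```cpp
#include <cstdint>

union GmmaDescriptorSketch {
  uint64_t desc_;      // whole-descriptor view
  uint32_t reg32_[2];  // 32-bit register view (offset bumps touch reg32_[0])
  uint16_t reg16_[4];  // 16-bit register view
  struct {
    uint16_t start_address_ : 14, : 2;        // shared-memory address >> 4
    uint16_t leading_byte_offset_ : 14, : 2;  // in 16-byte units
    uint16_t stride_byte_offset_ : 14, : 2;   // in 16-byte units
    uint8_t : 1, base_offset_ : 3, : 4;
    uint8_t : 6, layout_type_ : 2;            // swizzle-mode selector
  } bitfield;
};
```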
src/layout/layout.cc (2)
461-478: LGTM! Equality-check methods properly implemented. The new `Layout_is_equal` and `Fragment_is_equal` FFI bindings correctly expose the underlying equality-check functionality with proper node casting.
506-511: LGTM! WGMMA swizzled layout properly wired. `make_wgmma_swizzled_layout` correctly passes the continuity parameter through separately from mat_continuous, which is essential for WGMMA's layout requirements.

tilelang/tileop/gemm/gemm_wgmma.py (1)
36-37: Verify the continuity calculation for k-major layouts. In tilelang/tileop/gemm/gemm_wgmma.py:36-37,50 the current code is `a_continuity = self.M if a_is_k_major else 4 * self.K // m_warp` and `b_continuity = self.N if b_is_k_major else 4 * self.K // n_warp`. Confirm whether the k-major branch should intentionally use `self.M`/`self.N` (instead of a K-derived continuity) and whether the `4 *` factor is correct; if intentional, add an inline comment explaining the rationale and a unit test, otherwise correct the formula to derive continuity from K.

tilelang/intrinsics/wgmma_macro_generator.py (1)
220-231: Use 64-bit arithmetic for shared-memory offset calculations. A_offset / B_offset in tilelang/intrinsics/wgmma_macro_generator.py (lines 220-231) perform multiplications like `i * 64 * A_buf.shape[-1]` that can overflow 32 bits; ensure the intermediate arithmetic uses 64 bits (cast the PrimExprs to int64, or use an explicit int64 cast / bounds check) before multiplying by `elems_in_bytes` and passing to ptx_wgmma_ss.
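One way to place the cast, sketched with TVM's public `tir.Cast`; the helper name and the exact stride expression are illustrative:

```python
from tvm import tir

def byte_offset_i64(i, row_stride, elems_in_bytes: int):
    """Compute i * 64 * row_stride * elems_in_bytes with 64-bit intermediates."""
    def i64(e):
        return tir.Cast("int64", e)
    # Casting before the multiplies keeps every intermediate in int64,
    # so large tiles cannot wrap a 32-bit product.
    return i64(i) * tir.const(64, "int64") * i64(row_stride) * tir.const(
        elems_in_bytes, "int64")
```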
tilelang/tileop/gemm/__init__.py (6)

1-1: LGTM! The addition of the `IntEnum` import is appropriate for the new `GemmInst` class. The import is necessary for creating the strongly typed enumeration used to represent the different GEMM instruction types.
10-11: LGTM! The import additions align with the new WGMMA support. The imports for `GemmWGMMA` and `_ffi_api` are necessary for the new architecture-aware dispatch mechanism.
21-24: Good addition of the `layout_map` parameter to support WGMMA requirements. The parameter addition correctly passes the layout information through to the lower-level implementation.

27-42: LGTM! A well-designed enumeration for GEMM instruction types. The `GemmInst` enumeration provides a clean abstraction for the different GEMM implementations, with convenient helper methods for type checking.

71-81: The architecture-aware dispatch implementation looks good. The new `infer_layout` and `lower` methods properly delegate to architecture-specific implementations using a clean dispatch pattern.

100-120: LGTM! Clean dispatch with proper error handling. The method provides a clear mapping from instruction types to implementation classes, with appropriate error handling for unsupported types.
src/target/ptx.cc (7)
149-156: LGTM! A useful helper function for integer type checking. It provides a convenient way to check whether a PTX data type represents an integer, which is needed for WGMMA scale-input validation.

188-197: LGTM! An appropriate layout-type conversion helper. The function correctly maps boolean values to layout types for WGMMA assembly generation.

278-291: Well-structured WGMMA configuration struct. The struct follows the same pattern as `MMAConfig` and provides proper equality comparison for configuration validation.

293-720: Comprehensive WGMMA configuration table. The extensive table covers all supported WGMMA operations, including dense and sparse variants across multiple data types. This aligns well with the PR objective of supporting int8, fp8, and tf32 data types.

859-879: Solid configuration-validation function. It properly validates WGMMA configurations by checking data-type compatibility and searching for matching configurations in the valid list.

1119-1121: A fatal error for sparse WGMMA is reasonable. Since sparse WGMMA is not yet implemented, failing fast with a clear error message is the appropriate approach.

1235-1306: Well-structured WGMMA assembly generation. The function properly generates PTX assembly for WGMMA operations, with support for various configurations including scale parameters and layout options. The template-replacement approach is clean and maintainable.
tilelang/language/allocate.py (1)
129-136: The new descriptor allocator looks good. Minimal, and aligned with the ".descriptor" scope handling in codegen. LGTM.
```cpp
Downcast<PrimFunc>((*f)(GetRef<GemmPy>(this), T.layout_map, T.target,
                        T.thread_bounds, T.thread_var));
ICHECK(prim_func->attrs.defined());
```
🛠️ Refactor suggestion
🧩 Analysis chain
FFI lower signature now includes layout_map; ensure all providers updated.
The C++ call passes T.layout_map. Verify the Python FFI target “tl.gemm_py.lower” accepts (gemm_py, layout_map, target, thread_bounds, thread_var) everywhere to avoid runtime errors.
🏁 Script executed:
```bash
#!/bin/bash
# Confirm the new signature at Python side
rg -nP 'def\s+lower\((self|gemm_py)[^)]*layout_map' -C2
# Find the FFI registration or dispatcher for "tl.gemm_py.lower"
rg -n 'gemm_py\.lower' -S
# Check base class signatures to avoid LSP mismatches
rg -nP 'class\s+GemmBase\b.*?\n' -n tilelang/tileop/gemm -C2
rg -nP 'def\s+lower\(' tilelang/tileop/gemm -S
```

Length of output: 2364
Add layout_map to GemmBase.lower

Change tilelang/tileop/gemm/gemm_base.py:18 from

```python
def lower(self, target, thread_nums, thread_var)
```

to

```python
def lower(self, layout_map: dict, target: Target, thread_nums: int, thread_var: tir.Var)
```

so the base signature matches the FFI tl.gemm_py.lower and the existing subclass implementations.
src/target/ptx.h (outdated)
```cpp
/*!
 * \brief Print WGMMA assembly string given parameters.
 * \param shape The shape string mMnNkK
 * \param A_layout The layout of multiplicand A, can be either "row" or "col".
 * \param B_layout The layout of multiplicand B, can be either "row" or "col".
 * \param A_dtype The data type of multiplicand A.
 * \param B_dtype The data type of multiplicand B.
 * \param C_dtype The data type of multiplicand C.
 */
std::string PrintWGMMAAssembly(
    const std::string &shape, const bool &A_layout, const bool &B_layout,
    const std::string &A_dtype, const std::string &B_dtype,
    const std::string &C_dtype, const std::string &a_desc,
    const std::string &A_offset, const std::string &b_desc,
    const std::string &B_offset, const std::string &c_ptr,
    const std::string &c_offset, const bool &scale_out, const bool &scale_in_a,
    const bool &scale_in_b, const bool &a_is_shared,
    const std::string &metadata, const std::string &metadata_offset,
    const std::string &sparsity_selector, bool sparse);
```
Documentation incomplete and misaligned parameters.
The documentation comment only describes the first 6 parameters, but the function signature has 17 parameters. This incomplete documentation makes it difficult to understand the purpose of the remaining parameters.
Please update the documentation to include all parameters:
```diff
 /*!
  * \brief Print WGMMA assembly string given parameters.
  * \param shape The shape string mMnNkK
  * \param A_layout The layout of multiplicand A, can be either "row" or "col".
  * \param B_layout The layout of multiplicand B, can be either "row" or "col".
  * \param A_dtype The data type of multiplicand A.
  * \param B_dtype The data type of multiplicand B.
  * \param C_dtype The data type of multiplicand C.
+ * \param a_desc Descriptor for operand A (shared memory descriptor or register pointer).
+ * \param A_offset Offset for operand A.
+ * \param b_desc Descriptor for operand B.
+ * \param B_offset Offset for operand B.
+ * \param c_ptr Pointer to accumulator C.
+ * \param c_offset Offset for accumulator C.
+ * \param scale_out Scaling output flag.
+ * \param scale_in_a Scaling input A flag.
+ * \param scale_in_b Scaling input B flag.
+ * \param a_is_shared Whether operand A is in shared memory.
+ * \param metadata Pointer to metadata buffer (for sparse operations).
+ * \param metadata_offset Offset in metadata buffer.
+ * \param sparsity_selector Sparsity selector for sparse operations.
+ * \param sparse Whether this is a sparse WGMMA operation.
 */
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```cpp
/*!
 * \brief Print WGMMA assembly string given parameters.
 * \param shape The shape string mMnNkK
 * \param A_layout The layout of multiplicand A, can be either "row" or "col".
 * \param B_layout The layout of multiplicand B, can be either "row" or "col".
 * \param A_dtype The data type of multiplicand A.
 * \param B_dtype The data type of multiplicand B.
 * \param C_dtype The data type of multiplicand C.
 * \param a_desc Descriptor for operand A (shared memory descriptor or register pointer).
 * \param A_offset Offset for operand A.
 * \param b_desc Descriptor for operand B.
 * \param B_offset Offset for operand B.
 * \param c_ptr Pointer to accumulator C.
 * \param c_offset Offset for accumulator C.
 * \param scale_out Scaling output flag.
 * \param scale_in_a Scaling input A flag.
 * \param scale_in_b Scaling input B flag.
 * \param a_is_shared Whether operand A is in shared memory.
 * \param metadata Pointer to metadata buffer (for sparse operations).
 * \param metadata_offset Offset in metadata buffer.
 * \param sparsity_selector Sparsity selector for sparse operations.
 * \param sparse Whether this is a sparse WGMMA operation.
 */
std::string PrintWGMMAAssembly(
    const std::string &shape, const bool &A_layout, const bool &B_layout,
    const std::string &A_dtype, const std::string &B_dtype,
    const std::string &C_dtype, const std::string &a_desc,
    const std::string &A_offset, const std::string &b_desc,
    const std::string &B_offset, const std::string &c_ptr,
    const std::string &c_offset, const bool &scale_out, const bool &scale_in_a,
    const bool &scale_in_b, const bool &a_is_shared,
    const std::string &metadata, const std::string &metadata_offset,
    const std::string &sparsity_selector, bool sparse);
```
🤖 Prompt for AI Agents
In src/target/ptx.h around lines 68 to 86, the function comment documents only
the first six parameters while the declaration has 17; update the block comment
to include @param entries for every parameter in the signature (A_layout,
B_layout — explain what true/false means or that they represent "row"/"col";
a_desc and b_desc — describe descriptor strings for A/B; A_offset and B_offset —
offset expressions or registers for A/B; c_ptr and c_offset — pointer and offset
used for C output; scale_out, scale_in_a, scale_in_b — boolean flags controlling
scaling behavior; a_is_shared — whether A is in shared memory; metadata and
metadata_offset — metadata pointer/string and its offset; sparsity_selector —
sparsity control selector; sparse — whether sparse mode is enabled), keeping
each description concise and aligned with the existing doc style.
src/target/ptx.h (outdated)
```cpp
    const std::string &shape, const bool &A_layout, const bool &B_layout,
    const std::string &A_dtype, const std::string &B_dtype,
```
🛠️ Refactor suggestion
Inconsistent use of const references for boolean parameters.
The function signature uses const bool & for boolean parameters, which is unnecessary overhead for primitive types. This differs from the existing PrintMMAAssembly function that passes booleans by value.
Apply this diff to pass booleans by value:
```diff
-std::string PrintWGMMAAssembly(
-    const std::string &shape, const bool &A_layout, const bool &B_layout,
+std::string PrintWGMMAAssembly(
+    const std::string &shape, bool A_layout, bool B_layout,
     const std::string &A_dtype, const std::string &B_dtype,
     const std::string &C_dtype, const std::string &a_desc,
     const std::string &A_offset, const std::string &b_desc,
     const std::string &B_offset, const std::string &c_ptr,
-    const std::string &c_offset, const bool &scale_out, const bool &scale_in_a,
-    const bool &scale_in_b, const bool &a_is_shared,
+    const std::string &c_offset, bool scale_out, bool scale_in_a,
+    bool scale_in_b, bool a_is_shared,
     const std::string &metadata, const std::string &metadata_offset,
     const std::string &sparsity_selector, bool sparse);
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```cpp
std::string PrintWGMMAAssembly(
    const std::string &shape, bool A_layout, bool B_layout,
    const std::string &A_dtype, const std::string &B_dtype,
    const std::string &C_dtype, const std::string &a_desc,
    const std::string &A_offset, const std::string &b_desc,
    const std::string &B_offset, const std::string &c_ptr,
    const std::string &c_offset, bool scale_out, bool scale_in_a,
    bool scale_in_b, bool a_is_shared,
    const std::string &metadata, const std::string &metadata_offset,
    const std::string &sparsity_selector, bool sparse);
```
🤖 Prompt for AI Agents
In src/target/ptx.h around lines 78-79, the function signature uses `const bool
&A_layout` and `const bool &B_layout` which is unnecessary for primitive types;
change both parameters to pass by value (`bool A_layout, bool B_layout`) and
update any matching function declarations/definitions and callers to use value
semantics so signatures remain consistent with PrintMMAAssembly.
```cpp
template <int layout_type = 0, int leading_byte_offset = 0,
          int stride_byte_offset = 0, typename T>
TL_DEVICE void initialize_descriptor(GmmaDescriptor &descriptor,
                                     T *start_address) {
  descriptor.bitfield.start_address_ =
      cute::cast_smem_ptr_to_uint(start_address) >> 4;
  descriptor.bitfield.layout_type_ = layout_type;
  descriptor.bitfield.base_offset_ = 0;
  descriptor.bitfield.leading_byte_offset_ = leading_byte_offset;
  descriptor.bitfield.stride_byte_offset_ = stride_byte_offset;
}

template <typename T>
TL_DEVICE void increase_descriptor_offset(GmmaDescriptor &descriptor,
                                          T offset) {
  descriptor.reg32_[0] += (offset >> 4);
}
```
🛠️ Refactor suggestion
🧩 Analysis chain
Consider validating descriptor parameters.
The initialize_descriptor and increase_descriptor_offset functions directly manipulate hardware descriptor fields without validation. Consider adding assertions or checks for:
- Valid layout_type values (0-3)
- Byte offset alignment requirements
- Address alignment to 16-byte boundaries
🏁 Script executed:
#!/bin/bash
# Check if there are any existing validation patterns for descriptor initialization
ast-grep --pattern 'initialize_descriptor($$$) {
$$$
assert($$$)
$$$
}'
# Look for any documentation or comments about descriptor constraints
rg -A 3 -B 3 "descriptor.*align|align.*descriptor" --type cpp --type hLength of output: 333
Validate GMMA descriptor parameters before writing hardware fields
These functions write descriptor bitfields with no checks; src/target/cuda.h documents a 64‑byte alignment requirement for tensor map descriptors. Add assertions/guards in src/tl_templates/cuda/common.h (initialize_descriptor and increase_descriptor_offset, ~lines 405–421):
- Assert layout_type is in the supported range (0–3).
- Assert start_address is 16‑byte aligned (e.g., (uintptr_t)start_address & 0xF == 0).
- Assert leading_byte_offset, stride_byte_offset, and the offset passed to increase_descriptor_offset are multiples of 16 (since code uses >> 4) and fit the descriptor bitfield widths; also guard against reg32_[0] overflow when adding (offset >> 4).
Files to edit: src/tl_templates/cuda/common.h (initialize_descriptor, increase_descriptor_offset). Reference: src/target/cuda.h (tensor map descriptor alignment comment).
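A sketch of what those guards could look like; the wrapper name is hypothetical, while `TL_DEVICE`, `GmmaDescriptor`, and `initialize_descriptor` are the existing definitions in common.h:

```cpp
#include <cassert>
#include <cstdint>

template <int layout_type = 0, int leading_byte_offset = 0,
          int stride_byte_offset = 0, typename T>
TL_DEVICE void initialize_descriptor_checked(GmmaDescriptor &descriptor,
                                             T *start_address) {
  static_assert(layout_type >= 0 && layout_type <= 3,
                "layout_type must be in 0..3");
  static_assert(leading_byte_offset % 16 == 0 && stride_byte_offset % 16 == 0,
                "byte offsets must be multiples of 16 (code shifts by 4)");
  // Device-side check: the start address must be 16-byte aligned.
  assert((reinterpret_cast<uintptr_t>(start_address) & 0xF) == 0);
  initialize_descriptor<layout_type, leading_byte_offset, stride_byte_offset>(
      descriptor, start_address);
}
```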
🤖 Prompt for AI Agents
In src/tl_templates/cuda/common.h around lines 405 to 421, add validation before
writing descriptor fields: assert layout_type is within 0..3; assert
start_address is 16-byte aligned ((uintptr_t)start_address & 0xF == 0); assert
leading_byte_offset and stride_byte_offset are multiples of 16 and within the
bitfield width limits documented for GmmaDescriptor; in
increase_descriptor_offset assert offset is a multiple of 16, compute delta =
offset >> 4 and check that descriptor.reg32_[0] + delta will not overflow the
destination bitfield (or reg32_) before applying it; if any check fails, either
trigger a fast-fail assertion or return/handle the error path so invalid values
are never written to hardware fields.
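To make the requested guards concrete, here is a minimal host-side sketch (Python, illustration only; the real checks belong in the CUDA header). The 14-bit field widths and 16-byte granularity are taken from this thread, not from verified hardware docs:

```python
# Host-side model of the guards requested above (illustration only).
def validate_descriptor_params(layout_type: int, start_address: int,
                               leading_byte_offset: int,
                               stride_byte_offset: int) -> None:
    assert 0 <= layout_type <= 3, "layout_type must be in [0, 3]"
    assert start_address % 16 == 0, "start address must be 16-byte aligned"
    for name, off in (("leading_byte_offset", leading_byte_offset),
                      ("stride_byte_offset", stride_byte_offset)):
        assert off % 16 == 0, f"{name} must be a multiple of 16"
        assert (off >> 4) <= 0x3FFF, f"{name} exceeds the 14-bit descriptor field"

validate_descriptor_params(layout_type=1, start_address=0x1000,
                           leading_byte_offset=16, stride_byte_offset=1024)
```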
| if not isinstance(descriptor, (BufferLoad, Buffer)): | ||
| raise TypeError("Descriptor must be a tvm.tir.Buffer or tvm.tir.BufferLoad.") | ||
|
|
||
| if isinstance(descriptor, Buffer) and len(descriptor.shape) != 1 or descriptor.shape[0] != 1: |
Fix operator precedence ambiguity.
The condition mixes `and` and `or` without parentheses, so the `or` operand is evaluated whenever the `and` clause is false, which can lead to unexpected behavior.
Apply this diff to fix the operator precedence:

```diff
- if isinstance(descriptor, Buffer) and len(descriptor.shape) != 1 or descriptor.shape[0] != 1:
+ if isinstance(descriptor, Buffer) and (len(descriptor.shape) != 1 or descriptor.shape[0] != 1):
```

The same issue exists in `increase_descriptor_offset`:

```diff
- if isinstance(descriptor, Buffer) and len(descriptor.shape) != 1 or descriptor.shape[0] != 1:
+ if isinstance(descriptor, Buffer) and (len(descriptor.shape) != 1 or descriptor.shape[0] != 1):
```

Also applies to: 403-403
🧰 Tools
🪛 Ruff (0.12.2)
377-377: Parenthesize a and b expressions when chaining and and or together, to make the precedence clear
Parenthesize the and subexpression
(RUF021)
🤖 Prompt for AI Agents
In tilelang/language/builtin.py around lines 377 and 403, the conditional mixes
"and" and "or" without parentheses causing ambiguous precedence; wrap the shape
checks in parentheses so the intent is clear — e.g. make the condition check
"isinstance(descriptor, Buffer) and (len(descriptor.shape) != 1 or
descriptor.shape[0] != 1)" (and apply the same parenthesized grouping in
increase_descriptor_offset at line 403) to ensure the logical grouping is
correct.
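A quick standalone illustration of why the parentheses matter (toy stand-in classes, not the real TVM types): Python binds `and` tighter than `or`, so the unparenthesized form always evaluates the right-hand operand.

```python
class Buffer:  # stand-in for tvm.tir.Buffer (illustration only)
    def __init__(self, shape):
        self.shape = shape

class BufferLoad:  # stand-in without a .shape attribute
    pass

desc = BufferLoad()

# Unparenthesized: parsed as (isinstance(...) and len(...) != 1) or desc.shape[0] != 1,
# so the `or` operand runs whenever the `and` clause is False -> AttributeError here.
try:
    isinstance(desc, Buffer) and len(desc.shape) != 1 or desc.shape[0] != 1
except AttributeError as exc:
    print("unparenthesized form raised:", exc)

# Parenthesized: short-circuits after the isinstance check, as intended.
ok = isinstance(desc, Buffer) and (len(desc.shape) != 1 or desc.shape[0] != 1)
print("parenthesized form:", ok)  # False
```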
Actionable comments posted: 6
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
src/target/ptx.h (1)
29-31: Missing headers for Replacer (compile blocker). Replacer uses std::vector and std::pair, but <vector> (and ideally <utility>) aren't included.
Apply:

```diff
 #include <string>
 #include <tuple>
+#include <vector>
+#include <utility>
```

src/target/ptx.cc (1)
26-31: Missing <sstream> (compile blocker). std::stringstream is used but the header isn't included.
Apply:

```diff
 #include <algorithm>
+#include <sstream>
 #include <string>
 #include <tuple>
 #include <utility>
 #include <vector>
```
♻️ Duplicate comments (5)
src/target/ptx.h (2)
156-164: WGMMA Doxygen is incomplete and mismatched with the signature. Only the first 6 params are documented; the names refer to A/B_layout strings while the API uses a_is_k_major/b_is_k_major and many more args.
Recommend adding @param docs for all parameters and aligning names/semantics with the signature (a_is_k_major, b_is_k_major, a_desc, A_offset, b_desc, B_offset, c_ptr, c_offset, scale_out, scale_in_a, scale_in_b, a_is_shared, metadata, metadata_offset, sparsity_selector, sparse). Want a ready-to-apply doc patch?
165-174: Pass booleans by value, not const references. Primitive bools shouldn't be passed by const&. Keep consistent with other APIs.
Apply:

```diff
-std::string PrintWGMMAAssembly(
-    const std::string &shape, const bool &a_is_k_major, const bool &b_is_k_major,
+std::string PrintWGMMAAssembly(
+    const std::string &shape, bool a_is_k_major, bool b_is_k_major,
     const std::string &A_dtype, const std::string &B_dtype,
     const std::string &C_dtype, const std::string &a_desc,
     const std::string &A_offset, const std::string &b_desc,
     const std::string &B_offset, const std::string &c_ptr,
-    const std::string &c_offset, const bool &scale_out, const bool &scale_in_a,
-    const bool &scale_in_b, const bool &a_is_shared,
+    const std::string &c_offset, bool scale_out, bool scale_in_a,
+    bool scale_in_b, bool a_is_shared,
     const std::string &metadata, const std::string &metadata_offset,
     const std::string &sparsity_selector, bool sparse);
```

src/tl_templates/cuda/common.h (3)
438-454: Add basic parameter validation (layout_type range). Follow-up to the earlier review on descriptor validation.
Add: assert(0 <= layout_type && layout_type <= 3).

450-454: increase_descriptor_offset risks overflow/carry into other fields. Directly adding to reg32_[0] can spill into adjacent fields; enforce 16B granularity and update the bitfield safely.
Apply:

```diff
 TL_DEVICE void increase_descriptor_offset(GmmaDescriptor &descriptor, T offset) {
-  descriptor.reg32_[0] += (offset >> 4);
+  assert((offset % 16) == 0);
+  uint32_t delta = uint32_t(offset >> 4);
+  uint32_t new_sa = uint32_t(descriptor.bitfield.start_address_) + delta;
+  // 14-bit field
+  assert((new_sa & ~0x3FFFu) == 0);
+  descriptor.bitfield.start_address_ = uint16_t(new_sa & 0x3FFFu);
 }
```

438-448: initialize_descriptor: missing >>4 for byte-offset fields; no validation. Bitfields exclude the 4 LSBs; the current code writes raw byte offsets. It also lacks alignment/range checks.
Apply:

```diff
 TL_DEVICE void initialize_descriptor(GmmaDescriptor &descriptor, T *start_address) {
-  descriptor.bitfield.start_address_ =
-      cute::cast_smem_ptr_to_uint(start_address) >> 4;
+  auto sa = cute::cast_smem_ptr_to_uint(start_address);
+  // 16B alignment and 14-bit range
+  assert((sa & 0xF) == 0);
+  descriptor.bitfield.start_address_ = uint16_t((sa >> 4) & 0x3FFF);
   descriptor.bitfield.layout_type_ = layout_type;
   descriptor.bitfield.base_offset_ = 0;
-  descriptor.bitfield.leading_byte_offset_ = leading_byte_offset;
-  descriptor.bitfield.stride_byte_offset_ = stride_byte_offset;
+  assert((leading_byte_offset % 16) == 0);
+  assert((stride_byte_offset % 16) == 0);
+  descriptor.bitfield.leading_byte_offset_ =
+      uint16_t((leading_byte_offset >> 4) & 0x3FFF);
+  descriptor.bitfield.stride_byte_offset_ =
+      uint16_t((stride_byte_offset >> 4) & 0x3FFF);
 }
```

Also assert layout_type in [0, 3].
🧹 Nitpick comments (14)
tilelang/language/tir/op.py (1)
1106-1144: Add a docstring for ptx_wgmma_rs; clarify the variant. Mirror the ss variant doc and note that A is a buffer (not a descriptor).
Apply this diff:

```diff
 def ptx_wgmma_rs(
     dtype,
     wgmma_prefix,
     a_is_k_major,
     b_is_k_major,
     a_dtype_abbrv,
     b_dtype_abbrv,
     accum_dtype_abbrv,
     A_buf,
     A_offset,
     B_desc,
     B_offset,
     C_data,
     C_offset,
     scale_out,
     scale_in_a,
     scale_in_b,
 ):
-
-
+    """PTX WGMMA (warp-group MMA) intrinsic wrapper.
+
+    Variant: rs (A uses buffer pointer, B uses descriptor).
+    See NVIDIA PTX ISA — Warpgroup matrix instructions (WGMMA).
+    https://docs.nvidia.com/cuda/parallel-thread-execution/index.html
+    """
```

Optional: consider renaming A_buf to A_data or A_ptr for consistency with C_data (only if no keyword callers rely on A_buf).
src/target/ptx.h (1)
45-69: Two DataType enums exist (ptx::DataType here and tl::DataType in common.h). Risk of drift and conversion friction.
Add explicit conversion helpers and static_assert ordinal equivalence in one place, or consolidate to a single enum declared in a shared header.
src/tl_templates/cuda/common.h (1)
18-19: Redundant SMEM pointer casters. Both cute::cast_smem_ptr_to_uint and local helpers (cast_smem_ptr_to_int/smem_ptr_to_uint) exist.
Pick one idiom and remove duplicates to avoid confusion.
src/tl_templates/cuda/instruction/wgmma.h (1)
22-35: Don't rely on device printf; fail fast when unspecialized. Keep printf only under a debug guard; restore a compile-time error otherwise.
Apply:

```diff
-  ) {
-    printf("DEBUG: WgmmaSSImpl fallback - A_type=%d (kFloat16=%d), B_type=%d, C_type=%d, M=%d, N=%d, K=%d, tnspA=%d, tnspB=%d, scaleA=%d, scaleB=%d\n",
-           (int)A_type, (int)DataType::kFloat16, (int)B_type, (int)C_type, M, N, K, (int)tnspA, (int)tnspB, scaleA, scaleB);
-    // temporarily commented out the static_assert to inspect debug output
-    // static_assert(always_false_v<decltype(c)>,
-    //               "wgmma_ss: No specialization available for given template parameters!");
-  };
+  ) {
+#if defined(TL_DEBUG_WGMMA_FALLBACK)
+    printf("WGMMA fallback A=%d B=%d C=%d M=%d N=%d K=%d tnspA=%d tnspB=%d scaleA=%d scaleB=%d\n",
+           (int)A_type, (int)B_type, (int)C_type, M, N, K, (int)tnspA, (int)tnspB, scaleA, scaleB);
+#else
+    static_assert(always_false_v<decltype(c)>,
+                  "wgmma_ss: No specialization available for given template parameters");
+#endif
+  };
```
171-177: Pass bool by value in LayoutTypeFromBool. Avoid const& for primitives.
Apply:

```diff
-LayoutType LayoutTypeFromBool(const bool &layout) {
+LayoutType LayoutTypeFromBool(bool layout) {
```
1186-1195: Unnecessary const& on bools in PrintWGMMAAssembly. Match the header change; pass by value.
Apply:

```diff
-std::string PrintWGMMAAssembly(
-    const std::string &shape, const bool &a_is_k_major, const bool &b_is_k_major,
+std::string PrintWGMMAAssembly(
+    const std::string &shape, bool a_is_k_major, bool b_is_k_major,
     const std::string &A_dtype, const std::string &B_dtype,
     const std::string &C_dtype, const std::string &a_desc,
     const std::string &A_offset, const std::string &b_desc,
     const std::string &B_offset, const std::string &c_ptr,
-    const std::string &c_offset, const bool &scale_out, const bool &scale_in_a,
-    const bool &scale_in_b, const bool &a_is_shared,
+    const std::string &c_offset, bool scale_out, bool scale_in_a,
+    bool scale_in_b, bool a_is_shared,
     const std::string &metadata, const std::string &metadata_offset,
     const std::string &sparsity_selector, bool sparse) {
```
1306-1317: Fix argument docs for the tl::ptx_wgmma_ss branch. The comments don't match the actual args (dtype is the call's return type, not in op->args). Please correct to avoid future regressions.

```diff
-  // arg 0: dtype
-  // arg 1: shape
-  // arg 2: A_layout
-  // arg 3: B_layout
-  // arg 4: A_dtype
-  // arg 5: B_dtype
-  // arg 6: C_dtype
-  // arg 7: multiplicand_a
-  // arg 8: multiplicand_b
-  // arg 9: accumulator
-  // arg 10: saturate
+  // arg 0: wgmma_prefix (shape string, e.g. "m64n128k32")
+  // arg 1: a_is_k_major (bool)
+  // arg 2: b_is_k_major (bool)
+  // arg 3: a_dtype_abbrv (string)
+  // arg 4: b_dtype_abbrv (string)
+  // arg 5: accum_dtype_abbrv (string)
+  // arg 6: A_desc (descriptor, shared path)
+  // arg 7: A_offset (byte offset)
+  // arg 8: B_desc (descriptor)
+  // arg 9: B_offset (byte offset)
+  // arg 10: C_data (ptr)
+  // arg 11: C_offset (byte offset)
+  // arg 12: scale_out (bool)
+  // arg 13: scale_in_a (bool)
+  // arg 14: scale_in_b (bool)
```
1364-1375: Align RS branch docs with the actual args. Mirror the corrected SS docs for RS; the current comments mention "dtype/saturate", which don't exist here.

```diff
-  // arg 0: dtype
-  // arg 1: shape
-  // arg 2: A_layout
-  // arg 3: B_layout
-  // arg 4: A_dtype
-  // arg 5: B_dtype
-  // arg 6: C_dtype
-  // arg 7: multiplicand_a
-  // arg 8: multiplicand_b
-  // arg 9: accumulator
-  // arg 10: saturate
+  // arg 0: wgmma_prefix (shape)
+  // arg 1: a_is_k_major (bool)
+  // arg 2: b_is_k_major (bool)
+  // arg 3: a_dtype_abbrv (string)
+  // arg 4: b_dtype_abbrv (string)
+  // arg 5: accum_dtype_abbrv (string)
+  // arg 6: A_buf (global pointer, non-shared path)
+  // arg 7: A_offset (byte offset)
+  // arg 8: B_desc (descriptor)
+  // arg 9: B_offset (byte offset)
+  // arg 10: C_data (ptr)
+  // arg 11: C_offset (byte offset)
+  // arg 12: scale_out (bool)
+  // arg 13: scale_in_a (bool)
+  // arg 14: scale_in_b (bool)
```
1724-1743: Descriptor intrinsics emission — consider an explicit offset type and statement-only usage.

- The template parameter for increase_descriptor_offset should be explicit-width to avoid ABI surprises on different toolchains.
- These are side-effect calls; ensure they're only used in EvaluateNode contexts.

```diff
-  os << "tl::increase_descriptor_offset<int>(" << PrintExpr(descriptor)
+  os << "tl::increase_descriptor_offset<int32_t>(" << PrintExpr(descriptor)
      << ", " << PrintExpr(offset) << ")";
```

Would you like me to scan the call sites to confirm they're only used as statements?
tilelang/intrinsics/wgmma_macro_generator.py (5)
216-219: Use read access for descriptor base pointers. These descriptors are read by the instruction; use access_ptr("r"), not "w". Covered in the earlier diffs.
Also applies to: 283-283
104-109: Remove the unused n_dim parameter. _initialize_wgmma_prefix(self, n_dim: int = 16) doesn't use n_dim (Ruff ARG002). Simplify:

```diff
-    def _initialize_wgmma_prefix(self, n_dim: int = 16):
+    def _initialize_wgmma_prefix(self):
@@
-        self._initialize_wgmma_prefix(self.n_dim)
+        self._initialize_wgmma_prefix()
```

Also applies to: 94-95
333-333: Drop the duplicate import of is_fragment. It is already imported at the top; remove the inner import.
- from tilelang.utils import is_fragment
312-317: Fix the docstring: this is a load layout for operand A, not a store layout.

```diff
-        Create a layout function for storing MMA results into a fragment buffer.
-        This layout is used in conjunction with `inverse_mma_store_layout` to
-        map fragment indices to threads and local indices.
+        Create a layout describing how to load MMA operand A from a fragment buffer.
+        This layout is used in conjunction with `inverse_mma_load_layout` to
+        map fragment indices to threads and per-thread local indices.
```
41-59: Clarify swizzle_atom_size() semantics or compute it from bytes to avoid confusion.
If you keep this helper, consider defining it as swizzle_byte_size() // 16 (16-byte atoms), to match descriptor-unit terminology; the current // 16 on raw sizes is non-obvious. Usage has been corrected in earlier diffs, but this improves readability:

```diff
 def swizzle_atom_size(self) -> int:
-    if self.is_swizzle_32b():
-        return 32 // 16
-    elif self.is_swizzle_64b():
-        return 64 // 16
-    elif self.is_swizzle_128b():
-        return 128 // 16
-    else:
-        return 1
+    # number of 16-byte atoms in the swizzle size (32B→2, 64B→4, 128B→8)
+    return self.swizzle_byte_size() // 16
```
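A small sketch of the suggested helper semantics, with byte sizes taken from the 32B/64B/128B swizzle modes named above (the "none" → 16-byte mapping is an assumption matching the helper's fallback of 1):

```python
# Assumed byte sizes for the swizzle modes discussed in this thread.
SWIZZLE_BYTES = {"none": 16, "32B": 32, "64B": 64, "128B": 128}

def swizzle_atom_size(mode: str) -> int:
    # Number of 16-byte atoms spanned by the swizzle: 32B -> 2, 64B -> 4, 128B -> 8.
    return SWIZZLE_BYTES[mode] // 16

for mode in ("none", "32B", "64B", "128B"):
    print(mode, swizzle_atom_size(mode))
```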
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (9)
- src/op/builtin.h (2 hunks)
- src/target/codegen_cuda.cc (5 hunks)
- src/target/ptx.cc (9 hunks)
- src/target/ptx.h (2 hunks)
- src/tl_templates/cuda/common.h (5 hunks)
- src/tl_templates/cuda/gemm.h (1 hunk)
- src/tl_templates/cuda/instruction/wgmma.h (1 hunk)
- tilelang/intrinsics/wgmma_macro_generator.py (1 hunk)
- tilelang/language/tir/op.py (1 hunk)
🚧 Files skipped from review as they are similar to previous changes (1)
- src/op/builtin.h
🧰 Additional context used
🧬 Code graph analysis (5)
src/tl_templates/cuda/instruction/wgmma.h (2)
- src/op/builtin.h (1): tl (22-362)
- src/tl_templates/cuda/common.h (5): tl (306-400), DataType (315-360), int (97-100), int (135-142), uint32_t (118-120)

src/target/codegen_cuda.cc (4)
- src/transform/storage_rewrite.cc (6): scope (674-678), scope (674-674), n (366-370), n (366-366), n (371-375), n (371-371)
- tilelang/language/tir/op.py (2): ptx_wgmma_ss (1064-1103), ptx_wgmma_rs (1106-1143)
- src/target/ptx.cc (8): PrintWGMMAAssembly (1186-1257), PrintWGMMAAssembly (1186-1195), ParseMMAShape (140-148), ParseMMAShape (140-140), DTypeEnumToString (107-109), DTypeEnumToString (107-107), DTypeEnumToString (111-113), DTypeEnumToString (111-111)
- tilelang/language/builtin.py (2): initialize_descriptor (355-386), increase_descriptor_offset (389-411)

src/tl_templates/cuda/common.h (3)
- src/target/ptx.h (1): DataType (45-91)
- src/tl_templates/cuda/ldsm.h (12): void (7-14), void (16-23), void (25-33), void (35-42), void (44-52), void (54-62), void (64-70), void (72-79), void (81-89), void (91-98), void (100-108), void (110-119)
- tilelang/language/builtin.py (2): initialize_descriptor (355-386), increase_descriptor_offset (389-411)

tilelang/intrinsics/wgmma_macro_generator.py (6)
- tilelang/utils/language.py (1): is_fragment (68-78)
- tilelang/layout/swizzle.py (3): make_full_bank_swizzled_layout (39-59), make_half_bank_swizzled_layout (64-84), make_quarter_bank_swizzled_layout (89-109)
- tilelang/layout/fragment.py (3): Fragment (13-213), replicate (146-160), repeat (123-144)
- tilelang/language/builtin.py (1): initialize_descriptor (355-386)
- tilelang/language/tir/op.py (3): ptx_wgmma_ss (1064-1103), ptx_wgmma_rs (1106-1143), any (1774-1790)
- tilelang/intrinsics/mma_macro_generator.py (1): get_store_index_map (159-165)

src/target/ptx.h (1)
- src/target/ptx.cc (10): DTypeFromString (54-104), DTypeFromString (54-54), DTypeEnumToString (107-109), DTypeEnumToString (107-107), DTypeEnumToString (111-113), DTypeEnumToString (111-111), ParseMMAShape (140-148), ParseMMAShape (140-140), PrintWGMMAAssembly (1186-1257), PrintWGMMAAssembly (1186-1195)
🪛 Ruff (0.12.2)
tilelang/intrinsics/wgmma_macro_generator.py
104-104: Unused method argument: n_dim
(ARG002)
143-143: Avoid specifying long messages outside the exception class
(TRY003)
353-353: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Cursor Bugbot
🔇 Additional comments (9)
tilelang/language/tir/op.py (1)
1064-1104: Change the docstring from WMMA → WGMMA and update the link.
File: tilelang/language/tir/op.py (ptx_wgmma_ss). The docstring currently says WMMA; it should reference WGMMA (warp-group). Verified tl.ptx_wgmma_* ops are registered and expect 15 args.

```diff
 def ptx_wgmma_ss(
 @@
-    """TVM intrinsic for ptx tensor core wmma instructions
-    https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-for-wmma
-    """
+    """PTX WGMMA (warp-group MMA) intrinsic wrapper.
+
+    Variant: ss (A uses descriptor, B uses descriptor).
+    See NVIDIA PTX ISA — Warpgroup matrix instructions (WGMMA).
+    https://docs.nvidia.com/cuda/parallel-thread-execution/index.html
+    """
```

src/tl_templates/cuda/instruction/wgmma.h (2)
417-452: INT8 coverage seems partial (only N=8/16). valid_wgmma_configs includes N up to 256. If codegen may request N=32/64/… you'll hit the fallback at compile time.
Would you like me to generate the missing specializations (or a meta-emitter) for s8->s32 shapes N∈{32,64,96,128,192,256}?
457-491: FP8 E4M3→F16/F32: only N=8 variants present. The configs list N up to 256. Confirm the dispatcher never asks for larger N, or add the remaining specializations.
I can auto-generate these blocks to match the config table.
src/tl_templates/cuda/gemm.h (1)
5-6: Include order LGTM. Including wgmma.h in the sm90+ path before gemm_sm90.h is appropriate.
src/target/ptx.cc (1)
1222-1256: WGMMA operand builder: the sparse-path abort is fine; ensure scale_in use matches the dtype. Scale inputs are emitted only for non-integer types; that matches the spec. No action needed.
Please confirm codegen never requests sparse WGMMA yet; otherwise return a diagnostic instead of LOG(FATAL) at higher layers.
src/tl_templates/cuda/common.h (1)
315-339: Enum duplication with ptx::DataType — parity unverified. No matches found for tvm::tl::codegen::ptx::DataType; ensure src/tl_templates/cuda/common.h (DataType, lines 315-339) has identical ordinals/values to the ptx definition. To verify locally, run:

```sh
rg -nP '\benum class DataType\b' -S
rg -nP 'ptx::DataType|namespace\s+ptx' -S
```

src/target/codegen_cuda.cc (2)
898-901: Local descriptor deref handling in GetBufferRef — LGTM. Returning the scalar vid for "local.descriptor" (same as "local.var") avoids bogus []/casts on tl::GmmaDescriptor. This matches the Python-side Buffer(shape=[1]) abstraction.

1810-1812: Ensure tl::GmmaDescriptor is visible from generated TUs. src/target/codegen_cuda.cc already emits "#include <tl_templates/cuda/gemm.h>" in Finish() (decl_stream at src/target/codegen_cuda.cc:186). No declaration of tl::GmmaDescriptor was found in the repo — confirm that gemm.h (or an included header like common.h) actually defines tl::GmmaDescriptor; if it does not, add an explicit include for the header that declares it in Finish().
tilelang/intrinsics/wgmma_macro_generator.py (1)
104-109: Validate inst_n in wgmma_prefix against supported shapes. Hardware supports specific N sizes (e.g., 8/16/32/64/128); inst_n = block_col_warps * warp_col_tiles may produce unsupported values. Add a guard or normalize.
```cpp
// Decay to a uint64_t
CUTE_HOST_DEVICE constexpr operator uint64_t() const noexcept {
  return desc_;
}
template <typename T>
CUTE_HOST_DEVICE constexpr GmmaDescriptor operator+(const T &offset) const {
  GmmaDescriptor ret;
  ret.reg32_[0] = reg32_[0] + uint32_t(offset);
  ret.reg32_[1] = reg32_[1];
  return ret;
}
```
operator+ adds raw bytes; descriptor fields are in 16‑byte quanta.
The descriptor encodes addresses/offsets with 4 LSBs stripped. Adding raw offset likely mis-encodes.
Apply:
- template <typename T>
- CUTE_HOST_DEVICE constexpr GmmaDescriptor operator+(const T &offset) const {
+ template <typename T>
+ CUTE_HOST_DEVICE constexpr GmmaDescriptor operator+(const T &offset) const {
GmmaDescriptor ret;
- ret.reg32_[0] = reg32_[0] + uint32_t(offset);
+ ret.reg32_[0] = reg32_[0] + uint32_t(offset >> 4);
ret.reg32_[1] = reg32_[1];
return ret;
}And document that offset is in bytes.
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```cpp
// Decay to a uint64_t
CUTE_HOST_DEVICE constexpr operator uint64_t() const noexcept {
  return desc_;
}
template <typename T>
CUTE_HOST_DEVICE constexpr GmmaDescriptor operator+(const T &offset) const {
  GmmaDescriptor ret;
  ret.reg32_[0] = reg32_[0] + uint32_t(offset >> 4);
  ret.reg32_[1] = reg32_[1];
  return ret;
}
```
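For intuition, a small worked example of the encoding (plain Python arithmetic mirroring the descriptor convention described above, not the CUDA code):

```python
# Descriptor-style encoding drops the 4 LSBs: byte offsets are stored in
# 16-byte units, so advancing by `n` bytes must add `n >> 4` to the field.
encode = lambda byte_addr: byte_addr >> 4

base = 0x2000                       # a 16-byte-aligned shared-memory offset
enc = encode(base)                  # 0x200
enc_plus_1k = enc + (1024 >> 4)     # advance by 1024 bytes -> +64 units
assert enc_plus_1k == encode(base + 1024)

# Adding raw bytes instead (enc + 1024) would advance by 1024 * 16 bytes,
# sixteen times too far -- exactly the mis-encoding flagged above.
print(hex(enc), hex(enc_plus_1k))
```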
```python
        assert k_dim >= micro_size_k, f"k_dim must be greater than or equal to {micro_size_k}, got k_dim: {k_dim}"
```
Add divisibility assertion for K micro-tiles.
You iterate range(k_dim // micro_size_k) but don’t enforce divisibility; tails would be silently dropped. Add:
```diff
-    assert k_dim >= micro_size_k, f"k_dim must be greater than or equal to {micro_size_k}, got k_dim: {k_dim}"
+    assert k_dim >= micro_size_k, f"k_dim must be greater than or equal to {micro_size_k}, got k_dim: {k_dim}"
+    assert (k_dim % micro_size_k) == 0, f"k_dim must be divisible by micro_size_k={micro_size_k}, got k_dim: {k_dim}"
```

Apply in both SS and RS.
Also applies to: 255-256
🤖 Prompt for AI Agents
In tilelang/intrinsics/wgmma_macro_generator.py around lines 168-169 (and
similarly around lines 255-256), you assert k_dim >= micro_size_k but do not
require k_dim to be divisible by micro_size_k, causing any tail to be silently
dropped when iterating range(k_dim // micro_size_k); add an assertion ensuring
k_dim % micro_size_k == 0 with a clear error message (e.g., "k_dim must be
divisible by micro_size_k, got k_dim: {k_dim}, micro_size_k: {micro_size_k}") in
both the SS and RS sections so the code fails fast instead of silently
truncating tails.
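A tiny numeric check of the failure mode described above, under assumed toy sizes:

```python
# Assumed sizes for illustration; k_dim=48 is deliberately NOT a multiple of 32.
k_dim, micro_size_k = 48, 32

iters = k_dim // micro_size_k            # 1 iteration
covered = iters * micro_size_k           # only 32 of 48 K elements touched
print(f"covered {covered} of {k_dim} K elements")  # the 16-element tail is dropped

# The suggested guard turns the silent truncation into a fast failure:
assert k_dim % micro_size_k == 0, (
    f"k_dim must be divisible by micro_size_k={micro_size_k}, got k_dim: {k_dim}")
```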
```python
        assert k_dim >= micro_size_k, f"k_dim must be greater than or equal to {micro_size_k}, got k_dim: {k_dim}"
```
Mirror the fixes in the RS path (per-operand sizes, swizzle LBO units, pointer mode).
Do the same corrections for RS:

- Use separate a_elems_in_bytes / b_elems_in_bytes.
- Use swizzle_byte_size() for the LBO and keep the >> 4.
- Use access_ptr("r") for descriptors.

Also add the divisibility assert (see the separate comment). Apply:
- assert k_dim >= micro_size_k, f"k_dim must be greater than or equal to {micro_size_k}, got k_dim: {k_dim}"
+ assert k_dim >= micro_size_k, f"k_dim must be greater than or equal to {micro_size_k}, got k_dim: {k_dim}"
+ assert (k_dim % micro_size_k) == 0, f"k_dim must be divisible by micro_size_k={micro_size_k}, got k_dim: {k_dim}"
@@
- elems_in_bytes = DataType(self.a_dtype).bits // 8
+ a_elems_in_bytes = DataType(self.a_dtype).bits // 8
+ b_elems_in_bytes = DataType(self.b_dtype).bits // 8
@@
- b_leading_byte_offset = (8 * 8 * elems_in_bytes) if b_is_k_major else (8 * n_dim *
- elems_in_bytes)
- b_stride_byte_offset = (8 * k_dim * elems_in_bytes) if b_is_k_major else (8 * 8 *
- elems_in_bytes)
+ b_leading_byte_offset = (8 * 8 * b_elems_in_bytes) if b_is_k_major else (8 * n_dim *
+ b_elems_in_bytes)
+ b_stride_byte_offset = (8 * k_dim * b_elems_in_bytes) if b_is_k_major else (8 * 8 *
+ b_elems_in_bytes)
if not b_swizzle_mode.is_none():
# swizzle mode doesn't require LBO/SBO to be 1
# https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-warpgroup-level-leading-dimension-byte-offset
if b_is_k_major:
b_leading_byte_offset = 16
else:
# MN Major
# LBO represents the distance between two atoms along the N dimension
# SBO represents the distance between two atoms along the K dimension
- b_leading_byte_offset = b_swizzle_mode.swizzle_atom_size()
- b_stride_byte_offset = 8 * n_dim * elems_in_bytes
+ b_leading_byte_offset = b_swizzle_mode.swizzle_byte_size()
+ b_stride_byte_offset = 8 * n_dim * b_elems_in_bytes
@@
- T.initialize_descriptor(desc_b, B_buf.access_ptr("w"), b_swizzle_mode,
+ T.initialize_descriptor(desc_b, B_buf.access_ptr("r"), b_swizzle_mode,
int(b_leading_byte_offset >> 4), int(b_stride_byte_offset >> 4))
@@
- (B_offset * elems_in_bytes) >> 4,
+                (B_offset * b_elems_in_bytes) >> 4,

Also applies to: 257-278, 283-284, 301-302
🤖 Prompt for AI Agents
In tilelang/intrinsics/wgmma_macro_generator.py around lines 255-256 (and
similarly apply changes to ranges 257-278, 283-284, 301-302), the RS path
currently mirrors the MS fixes incorrectly: replace any single shared
elems_in_bytes with separate a_elems_in_bytes and b_elems_in_bytes computed from
their per‑operand element sizes; for LBO swizzle sizes call
swizzle_byte_size(...) and retain the bitshift >> 4 as before; change descriptor
pointer creation to use access_ptr("r") instead of the current mode; and add the
divisibility assert that ensures k_dim (or relevant dim) is divisible by
micro_size_k (matching the other path). Ensure these edits are applied
consistently across the listed line ranges.
```python
                    accum_dtype,
                    wgmma_prefix,
                    self.a_transposed,
                    not self.b_transposed,
                    a_dtype_abbrv,
```
🧩 Analysis chain
Likely wrong K‑major flags passed to RS intrinsic (inverted vs SS).
ptx_wgmma_rs takes the same a_is_k_major/b_is_k_major semantics as SS. SS passes not self.a_transposed and self.b_transposed, but RS passes their negations. Align them:
- self.a_transposed,
- not self.b_transposed,
+ (not self.a_transposed),
+ self.b_transposed,

Please confirm against the PTX WGMMA docs for RS.
🌐 Web query:
What are the `a_is_k_major` / `b_is_k_major` semantics for `wgmma.mma_async` RS vs SS variants in NVIDIA PTX?
💡 Result:
Short answer: those flags tell the instruction whether the A / B operand’s shared‑memory descriptor (SS) — or the register‑sourced operand’s logical layout (RS) — is K‑major (i.e. the 8×2 / 2×(N/8) “core‑matrix” tile stride is along the K dimension) or not. Concretely:
-
SS variants (the .ss forms) take 64‑bit matrix descriptors for A and/or B; the descriptor’s “major‑K” bit(s) (a_is_k_major / b_is_k_major) mean “this SMEM tile is laid out K‑major” so the hardware will interpret LBO/SBO in the descriptor accordingly (core matrices are arranged with K as the strided direction). Use these bits when your shared‑memory tile is stored with K as the leading/strided dimension. [1][2]
-
RS variants (the .rs forms) source one operand from registers (or a thread‑local fragment) and the other from a descriptor; for RS the a_is_k_major / b_is_k_major semantic still denotes whether that operand’s logical tile is K‑major, but for the register‑sourced operand the layout interpretation is applied to the fragment/register mapping (i.e. whether the fragment is arranged as K‑major vs M/N‑major). In practice this means you must set the bit to match how your register fragment was produced (or how you packed it), otherwise the multiply/accumulate will use the wrong element ordering. [1][2]
References:
- PTX WGMMA / wgmma.mma_async description (matrix descriptors, K‑major vs MN‑major and SS/RS usage). [1]
- CUTLASS / WGMMA notes (explanation of descriptors, LBO/SBO and K‑major layout and the distinction of SS mode requiring descriptors). [2]
[1] NVIDIA PTX ISA (wgmma / wgmma.mma_async matrix descriptor & layout).
[2] CUTLASS / WGMMA tutorial and notes on shared‑memory descriptors and K‑major layout.
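Putting the two answers together, a minimal sketch of the flag mapping this review converges on (same semantics for SS and RS); the convention that untransposed A is M×K and untransposed B is K×N is taken from the generator code quoted in this thread:

```python
def k_major_flags(a_transposed: bool, b_transposed: bool):
    # Untransposed A (M x K) keeps K contiguous -> K-major.
    a_is_k_major = not a_transposed
    # Untransposed B (K x N) keeps N contiguous -> K-major only when transposed.
    b_is_k_major = b_transposed
    return a_is_k_major, b_is_k_major

print(k_major_flags(False, False))  # (True, False)
```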
Fix inverted K‑major flags for RS WGMMA intrinsic
RS uses the same a_is_k_major / b_is_k_major semantics as SS; the call currently passes their negations. Update the args at tilelang/intrinsics/wgmma_macro_generator.py:291-295:
- self.a_transposed,
- not self.b_transposed,
+ (not self.a_transposed),
+ self.b_transposed,

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
                    accum_dtype,
                    wgmma_prefix,
                    (not self.a_transposed),
                    self.b_transposed,
                    a_dtype_abbrv,
```
🤖 Prompt for AI Agents
In tilelang/intrinsics/wgmma_macro_generator.py around lines 291-295 the K-major
flags for the RS WGMMA intrinsic are passed inverted relative to the SS
semantics; replace the current arguments so the a/b K-major booleans are the
negations of their current values (change self.a_transposed to not
self.a_transposed, and change not self.b_transposed to self.b_transposed) so the
call passes the correct a_is_k_major and b_is_k_major flags.
…e size calculations in wgmma macro generator. Enhanced handling of leading and stride byte offsets based on swizzle mode, improving clarity and performance in tensor core intrinsic emissions.
```python
        print(f"b_leading_byte_offset: {b_leading_byte_offset}")
        print(f"b_stride_byte_offset: {b_stride_byte_offset}")
        # exit()

        @T.macro
```
```cpp
    // temporarily commented out the static_assert to inspect debug output
    // static_assert(always_false_v<decltype(c)>,
    //               "wgmma_ss: No specialization available for given template parameters!");
  };
```
```python
class GemmInst(IntEnum):
    MMA = 0
    WGMMMA = 1
    MFMA = 2
```
Actionable comments posted: 1
♻️ Duplicate comments (6)
src/target/codegen_cuda.cc (1)
1305-1362: WGMMA-SS: emit validated asm_code; drop ad‑hoc string (invalid arithmetic on descriptor, dead code)
asm_code from PrintWGMMAAssembly(...) is computed and then ignored. The custom wgmma_asm_code does pointer arithmetic on tl::GmmaDescriptor via uint64_t((desc_a) + (A_offset)), which is invalid. Unify with the RS path and stream asm_code.
Apply:

```diff
   const bool a_is_shared = true;
   this->PrintIndent();
   std::string asm_code = PrintWGMMAAssembly(
       shape, a_is_k_major, b_is_k_major, A_dtype, B_dtype, C_dtype, a_desc,
       A_offset, b_desc, B_offset, c_ref, c_offset, scale_out, scale_in_a,
       scale_in_b, a_is_shared, "", "", "", false);
-  auto [m, n, k] = tl::codegen::ptx::ParseMMAShape(shape);
-  std::string wgmma_asm_code = "tl::wgmma_ss<(AType), (BType), (CType), (M), (N), (K), (tnspA), (tnspB), (scaleA), (scaleB)>(uint64_t((desc_a) + (A_offset)), uint64_t((desc_b) + (B_offset)), ((uint32_t*)((C))), (scale_out));\n";
-  // replace patterns
-  tl::codegen::Replacer replacer;
-  replacer.register_rule("(AType)", tl::codegen::ptx::DTypeEnumToString(A_dtype));
-  replacer.register_rule("(BType)", tl::codegen::ptx::DTypeEnumToString(B_dtype));
-  replacer.register_rule("(CType)", tl::codegen::ptx::DTypeEnumToString(C_dtype));
-  replacer.register_rule("(M)", std::to_string(m));
-  replacer.register_rule("(N)", std::to_string(n));
-  replacer.register_rule("(K)", std::to_string(k));
-  replacer.register_rule("(tnspA)", a_is_k_major ? "false" : "true");
-  replacer.register_rule("(tnspB)", b_is_k_major ? "false" : "true");
-  replacer.register_rule("(scaleA)", scale_in_a ? "1" : "-1");
-  replacer.register_rule("(scaleB)", scale_in_b ? "1" : "-1");
-  replacer.register_rule("(desc_a)", a_desc);
-  replacer.register_rule("(A_offset)", A_offset);
-  replacer.register_rule("(desc_b)", b_desc);
-  replacer.register_rule("(B_offset)", B_offset);
-  replacer.register_rule("(C)", c_ref + " + " + c_offset);
-  replacer.register_rule("(scale_out)", scale_out ? "true" : "false");
-  wgmma_asm_code = replacer.rewrite(wgmma_asm_code);
-  this->stream << wgmma_asm_code;
+  this->stream << asm_code;
```
const bool a_is_shared = true; this->PrintIndent(); std::string asm_code = PrintWGMMAAssembly( shape, a_is_k_major, b_is_k_major, A_dtype, B_dtype, C_dtype, a_desc, A_offset, b_desc, B_offset, c_ref, c_offset, scale_out, scale_in_a, scale_in_b, a_is_shared, "", "", "", false); - auto [m, n, k] = tl::codegen::ptx::ParseMMAShape(shape); - std::string wgmma_asm_code = "tl::wgmma_ss<(AType), (BType), (CType), (M), (N), (K), (tnspA), (tnspB), (scaleA), (scaleB)>(uint64_t((desc_a) + (A_offset)), uint64_t((desc_b) + (B_offset)), ((uint32_t*)((C))), (scale_out));\n"; - // replace patterns - tl::codegen::Replacer replacer; - replacer.register_rule("(AType)", tl::codegen::ptx::DTypeEnumToString(A_dtype)); - replacer.register_rule("(BType)", tl::codegen::ptx::DTypeEnumToString(B_dtype)); - replacer.register_rule("(CType)", tl::codegen::ptx::DTypeEnumToString(C_dtype)); - replacer.register_rule("(M)", std::to_string(m)); - replacer.register_rule("(N)", std::to_string(n)); - replacer.register_rule("(K)", std::to_string(k)); - replacer.register_rule("(tnspA)", a_is_k_major? "false": "true"); - replacer.register_rule("(tnspB)", b_is_k_major? "false": "true"); - replacer.register_rule("(scaleA)", scale_in_a? "1": "-1"); - replacer.register_rule("(scaleB)", scale_in_b? "1": "-1"); - replacer.register_rule("(desc_a)", a_desc); - replacer.register_rule("(A_offset)", A_offset); - replacer.register_rule("(desc_b)", b_desc); - replacer.register_rule("(B_offset)", B_offset); - replacer.register_rule("(C)", c_ref + " + " + c_offset); - replacer.register_rule("(scale_out)", scale_out ? "true" : "false"); - wgmma_asm_code = replacer.rewrite(wgmma_asm_code); - this->stream << wgmma_asm_code; + this->stream << asm_code;tilelang/intrinsics/wgmma_macro_generator.py (5)
159-169: Add validations: m_dim multiple-of-64 and k_dim divisibility by micro_size_k.
Prevents silent no-op loops and dropped tails.

```diff
 m_dim = self.block_row_warps * self.warp_row_tiles
 warp_cols = self.warp_cols
 micro_size_k = self.micro_size_k
 k_dim, n_dim = self.chunk, self.block_col_warps * self.warp_col_tiles
 wgmma_prefix = self.wgmma_prefix
 @@
-        assert k_dim >= micro_size_k, f"k_dim must be greater than or equal to {micro_size_k}, got k_dim: {k_dim}"
+        assert m_dim >= 64 and (m_dim % 64) == 0, f"m_dim must be a multiple of 64, got {m_dim}"
+        assert k_dim >= micro_size_k, f"k_dim must be ≥ {micro_size_k}, got {k_dim}"
+        assert (k_dim % micro_size_k) == 0, f"k_dim must be divisible by micro_size_k={micro_size_k}, got {k_dim}"
```
176-219: Fix LBO/SBO units, per-operand element sizes, and descriptor pointer modes (SS path).
The code currently mixes A's element size into B's computation, uses swizzle_atom_size() and then >> 4 (a double divide by 16), and passes access_ptr("w"). This will generate incorrect descriptors and offsets, especially for int8/fp8/tf32.

```diff
-elems_in_bytes = DataType(self.a_dtype).bits // 8
+a_elems_in_bytes = DataType(self.a_dtype).bits // 8
+b_elems_in_bytes = DataType(self.b_dtype).bits // 8
@@
-a_leading_byte_offset = (8 * 8 * elems_in_bytes) if a_is_k_major else (8 * m_dim * elems_in_bytes)
-a_stride_byte_offset = (8 * k_dim * elems_in_bytes) if a_is_k_major else (8 * 8 * elems_in_bytes)
+a_leading_byte_offset = (8 * 8 * a_elems_in_bytes) if a_is_k_major else (8 * m_dim * a_elems_in_bytes)
+a_stride_byte_offset = (8 * k_dim * a_elems_in_bytes) if a_is_k_major else (8 * 8 * a_elems_in_bytes)
@@
 else:
     # MN Major
     # LBO represents the distance between two atoms along the M dimension
     # SBO represents the distance between two atoms along the K dimension
-    a_leading_byte_offset = a_swizzle_mode.swizzle_atom_size()
-    a_stride_byte_offset = 8 * 64 * elems_in_bytes
+    a_leading_byte_offset = a_swizzle_mode.swizzle_byte_size()
+    a_stride_byte_offset = 8 * 64 * a_elems_in_bytes
@@
-b_leading_byte_offset = (8 * 8 * elems_in_bytes) if b_is_k_major else (8 * n_dim * elems_in_bytes)
-b_stride_byte_offset = (8 * k_dim * elems_in_bytes) if b_is_k_major else (8 * 8 * elems_in_bytes)
+b_leading_byte_offset = (8 * 8 * b_elems_in_bytes) if b_is_k_major else (8 * n_dim * b_elems_in_bytes)
+b_stride_byte_offset = (8 * k_dim * b_elems_in_bytes) if b_is_k_major else (8 * 8 * b_elems_in_bytes)
 if not b_swizzle_mode.is_none():
@@
-else:
-    # MN Major, K * N
-    # LBO represents the distance between two atoms along the N dimension
-    # SBO represents the distance between two atoms along the K dimension
-    b_n_axis_atoms = n_dim // (b_swizzle_mode.swizzle_byte_size() // elems_in_bytes)
-    if b_n_axis_atoms <= 1:
-        b_leading_byte_offset = 0
-    else:
-        b_leading_byte_offset = 8 * 8 * elems_in_bytes * k_dim
-    if b_n_axis_atoms <= 1:
-        b_stride_byte_offset = 8 * elems_in_bytes * n_dim
-    else:
-        b_stride_byte_offset = 8 * elems_in_bytes * (b_swizzle_mode.swizzle_byte_size() // elems_in_bytes)
+else:
+    # MN Major (N × K): use swizzle byte size for LBO; SBO spans N in bytes.
+    b_leading_byte_offset = b_swizzle_mode.swizzle_byte_size()
+    b_stride_byte_offset = 8 * n_dim * b_elems_in_bytes
@@
 desc_a = T.alloc_descriptor()
 desc_b = T.alloc_descriptor()
-T.initialize_descriptor(desc_a, A_buf.access_ptr("w"), a_swizzle_mode, int(a_leading_byte_offset >> 4), int(a_stride_byte_offset >> 4))
-T.initialize_descriptor(desc_b, B_buf.access_ptr("w"), b_swizzle_mode, int(b_leading_byte_offset >> 4), int(b_stride_byte_offset >> 4))
+T.initialize_descriptor(desc_a, A_buf.access_ptr("r"), a_swizzle_mode, int(a_leading_byte_offset >> 4), int(a_stride_byte_offset >> 4))
+T.initialize_descriptor(desc_b, B_buf.access_ptr("r"), b_swizzle_mode, int(b_leading_byte_offset >> 4), int(b_stride_byte_offset >> 4))
@@
-B_offset = k_dim_offset if b_is_k_major else k_dim_offset * (b_swizzle_mode.swizzle_byte_size() // elems_in_bytes)
+B_offset = k_dim_offset if b_is_k_major else k_dim_offset * (b_swizzle_mode.swizzle_byte_size() // b_elems_in_bytes)
@@
-accum_dtype_abbrv, desc_a.data, (A_offset * elems_in_bytes) >> 4,
-desc_b.data, (B_offset * elems_in_bytes) >> 4, C_local_buf.data,
+accum_dtype_abbrv, desc_a.data, (A_offset * a_elems_in_bytes) >> 4,
+desc_b.data, (B_offset * b_elems_in_bytes) >> 4, C_local_buf.data,
```

Also applies to: 239-254
278-279: Add the RS-side divisibility assert.
Mirror the SS check to avoid dropped tails.

```diff
-assert k_dim >= micro_size_k, f"k_dim must be greater than or equal to {micro_size_k}, got k_dim: {k_dim}"
+assert k_dim >= micro_size_k, f"k_dim must be ≥ {micro_size_k}, got {k_dim}"
+assert (k_dim % micro_size_k) == 0, f"k_dim must be divisible by micro_size_k={micro_size_k}, got {k_dim}"
```
280-309: Mirror per-operand sizes and descriptor pointer mode (RS path).
Use b_elems_in_bytes and access_ptr("r"); fix the LBO/SBO units for swizzle.

```diff
-elems_in_bytes = DataType(self.a_dtype).bits // 8
+a_elems_in_bytes = DataType(self.a_dtype).bits // 8
+b_elems_in_bytes = DataType(self.b_dtype).bits // 8
@@
-b_leading_byte_offset = (8 * 8 * elems_in_bytes) if b_is_k_major else (8 * n_dim * elems_in_bytes)
-b_stride_byte_offset = (8 * k_dim * elems_in_bytes) if b_is_k_major else (8 * 8 * elems_in_bytes)
+b_leading_byte_offset = (8 * 8 * b_elems_in_bytes) if b_is_k_major else (8 * n_dim * b_elems_in_bytes)
+b_stride_byte_offset = (8 * k_dim * b_elems_in_bytes) if b_is_k_major else (8 * 8 * b_elems_in_bytes)
@@
-else:
-    # MN Major
-    # LBO represents the distance between two atoms along the N dimension
-    # SBO represents the distance between two atoms along the K dimension
-    b_n_axis_atoms = n_dim // (b_swizzle_mode.swizzle_byte_size() // elems_in_bytes)
-    if b_n_axis_atoms <= 1:
-        b_leading_byte_offset = 0
-    else:
-        b_leading_byte_offset = 8 * 8 * elems_in_bytes * k_dim
-    if b_n_axis_atoms <= 1:
-        b_stride_byte_offset = 8 * elems_in_bytes * n_dim
-    else:
-        b_stride_byte_offset = 8 * elems_in_bytes * (b_swizzle_mode.swizzle_byte_size() // elems_in_bytes)
+else:
+    # MN Major
+    b_leading_byte_offset = b_swizzle_mode.swizzle_byte_size()
+    b_stride_byte_offset = 8 * n_dim * b_elems_in_bytes
@@
-T.initialize_descriptor(desc_b, B_buf.access_ptr("w"), b_swizzle_mode, int(b_leading_byte_offset >> 4), int(b_stride_byte_offset >> 4))
+T.initialize_descriptor(desc_b, B_buf.access_ptr("r"), b_swizzle_mode, int(b_leading_byte_offset >> 4), int(b_stride_byte_offset >> 4))
@@
-(B_offset * elems_in_bytes) >> 4,
+(B_offset * b_elems_in_bytes) >> 4,
```

Also applies to: 313-314, 332-332
324-325: Fix inverted K-major flags for RS.
RS uses the same a/b K-major semantics as SS.

```diff
-            self.a_transposed,
-            not self.b_transposed,
+            (not self.a_transposed),
+            self.b_transposed,
```
🧹 Nitpick comments (6)
src/target/codegen_cuda.cc (1)
1363-1397: WGMMA-RS path looks good; minor naming/comment nit.
Logic mirrors SS but correctly streams asm_code. Consider renaming A_layout/B_layout to a_is_k_major/b_is_k_major and fixing the arg comments (arg 0 is the shape; dtype is not an arg) for clarity.

tilelang/intrinsics/wgmma_macro_generator.py (5)
104-109: Remove the unused arg and make prefix computation explicit.
n_dim is unused; also guard inst_k with a sanity check (tf32 can be tricky).

```diff
-    def _initialize_wgmma_prefix(self, n_dim: int = 16):
+    def _initialize_wgmma_prefix(self):
         inst_m, inst_n = 64, self.block_col_warps * self.warp_col_tiles
-        # 256 bits per instruction
-        inst_k = 256 // DataType(self.a_dtype).bits
+        # k derived from input dtype; ensure it's one of the valid WGMMA shapes.
+        bits = DataType(self.a_dtype).bits
+        assert 256 % bits == 0, f"Unsupported dtype bits for WGMMA prefix: {bits}"
+        inst_k = 256 // bits
         self.wgmma_prefix = f"m{inst_m}n{inst_n}k{inst_k}"
@@
-        self._initialize_wgmma_prefix(self.n_dim)
+        self._initialize_wgmma_prefix()
```

Please verify tf32's effective width in your stack (if tvm "tf32" reports 19 bits, switch to an explicit dtype→k map).
Also applies to: 94-95
221-234: Remove debug prints (noise in the codegen path).
These spam stdout during TIR lowering; drop them or guard behind a verbose flag.

```diff
-        print(f"a_leading_byte_offset: {a_leading_byte_offset >> 4}")
-        print(f"a_stride_byte_offset: {a_stride_byte_offset >> 4}")
-
-        print(f"b_swizzle_atom_size: {b_swizzle_mode.swizzle_atom_size()}")
-        print(f"b_swizzle_byte_size: {b_swizzle_mode.swizzle_byte_size()}")
-        print(f"m_dim: {m_dim}")
-        print(f"n_dim: {n_dim}")
-        print(f"k_dim: {k_dim}")
-        print(f"micro_size_k: {micro_size_k}")
-        print(f"a_leading_byte_offset: {a_leading_byte_offset}")
-        print(f"a_stride_byte_offset: {a_stride_byte_offset}")
-        print(f"b_leading_byte_offset: {b_leading_byte_offset}")
-        print(f"b_stride_byte_offset: {b_stride_byte_offset}")
```
342-347: Docstring mismatch: this is a load-layout helper.
Tweak the wording to avoid confusion.

```diff
-        Create a layout function for storing MMA results into a fragment buffer.
+        Create a layout function for loading operand A into a fragment buffer.
```
364-367: Drop the redundant import inside the method.
is_fragment is already imported at module scope.

```diff
-        from tilelang.utils import is_fragment
```
143-143: Tidy exception messages per the linter hint (TRY003). Shorten them or use a custom exception type; the current messages are fine but flagged.
Also applies to: 384-384
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- src/target/codegen_cuda.cc (5 hunks)
- tilelang/intrinsics/wgmma_macro_generator.py (1 hunk)
🧰 Additional context used
🧬 Code graph analysis (2)
src/target/codegen_cuda.cc (4)
- src/transform/storage_rewrite.cc (6): scope (674-678), scope (674-674), n (366-370), n (366-366), n (371-375), n (371-371)
- tilelang/language/tir/op.py (2): ptx_wgmma_ss (1064-1103), ptx_wgmma_rs (1106-1143)
- src/target/ptx.cc (8): PrintWGMMAAssembly (1186-1257), PrintWGMMAAssembly (1186-1195), ParseMMAShape (140-148), ParseMMAShape (140-140), DTypeEnumToString (107-109), DTypeEnumToString (107-107), DTypeEnumToString (111-113), DTypeEnumToString (111-111)
- tilelang/language/builtin.py (2): initialize_descriptor (355-386), increase_descriptor_offset (389-411)

tilelang/intrinsics/wgmma_macro_generator.py (8)
- tilelang/utils/language.py (1): is_fragment (68-78)
- tilelang/layout/layout.py (3): Layout (13-148), is_equal (136-145), inverse (125-134)
- tilelang/layout/swizzle.py (3): make_full_bank_swizzled_layout (39-59), make_half_bank_swizzled_layout (64-84), make_quarter_bank_swizzled_layout (89-109)
- tilelang/layout/fragment.py (4): is_equal (209-213), Fragment (13-213), replicate (146-160), repeat (123-144)
- tilelang/language/allocate.py (1): alloc_descriptor (129-135)
- tilelang/language/builtin.py (1): initialize_descriptor (355-386)
- tilelang/language/tir/op.py (2): ptx_wgmma_ss (1064-1103), ptx_wgmma_rs (1106-1143)
- tilelang/intrinsics/mma_macro_generator.py (1): get_store_index_map (159-165)
🪛 Ruff (0.12.2)
tilelang/intrinsics/wgmma_macro_generator.py
104-104: Unused method argument: n_dim
(ARG002)
143-143: Avoid specifying long messages outside the exception class
(TRY003)
384-384: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Cursor Bugbot
🔇 Additional comments (5)
src/target/codegen_cuda.cc (5)
898-901: Descriptor refs: correct special-casing for local.descriptor.
Returning the variable directly for descriptor buffers is the right behavior and aligns with how the descriptor is declared.

1721-1721: LGTM. No issues spotted in tl_shuffle_elect emission.

1845-1845: Scope guard update is correct.
Skipping the unsupported-scope check for local.descriptor is necessary after introducing descriptor allocations.

1723-1742: Descriptor intrinsics emission LGTM — confirm the offset width.
initialize_descriptor templating is correct. For increase_descriptor_offset(...), confirm the API expects a 32-bit offset; repository search did not locate the declaration, so either switch to std::int32_t (or a wider type) and add an explicit cast to avoid narrowing, or confirm plain int is intended. Location: src/target/codegen_cuda.cc:1739-1742

1809-1811: Descriptor allocation: ensure the header is available.
src/target/codegen_cuda.cc already writes #include <tl_templates/cuda/gemm.h> into the generated decls, but I couldn't find a definition of tl::GmmaDescriptor in the repo — confirm that gemm.h (or an included header like common.h) actually defines tl::GmmaDescriptor; if it does not, add an explicit include for the header that declares it in Finish().
```python
        warp_row_tiles: int = 8,
        warp_col_tiles: int = 8,
        chunk: int = 16,
```
Default warp_row_tiles conflicts with your own assertion (must be ≥16).
Set the default to 16 to avoid constructor-time assertion failures via parent init.
- warp_row_tiles: int = 8,
+ warp_row_tiles: int = 16,

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
        warp_row_tiles: int = 16,
        warp_col_tiles: int = 8,
        chunk: int = 16,
```
🤖 Prompt for AI Agents
In tilelang/intrinsics/wgmma_macro_generator.py around lines 83 to 85, the
default value for warp_row_tiles is 8 which conflicts with the class assertion
requiring warp_row_tiles >= 16 and can trigger assertion failures during parent
initialization; change the default warp_row_tiles to 16 (update the parameter
default in the function/class signature) so it satisfies the assertion by
default and avoid constructor-time failures.
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
src/op/gemm.cc (1)
780-795: TCGEN5MMA layout still passes old k-factor integers
makeGemmABLayoutSm100 now expects a boolean k_inner, but we still feed trans_A ? 1 : 2 / trans_B ? 2 : 1. Both expressions evaluate to true, so the CDNA/Sm100 shared-memory layouts always assume K-inner ordering, breaking transposed operands. Pass explicit booleans that preserve the original semantics.

```diff
-      results.Set(A, makeGemmABLayoutSm100(mat_stride, mat_continuous,
-                                           mat_continuous, A->dtype.bits(),
-                                           trans_A ? 1 : 2));
+      results.Set(A, makeGemmABLayoutSm100(mat_stride, mat_continuous,
+                                           mat_continuous, A->dtype.bits(),
+                                           !trans_A));
@@
-      results.Set(B,
-                  makeGemmABLayoutSm100(mat_stride, mat_continuous, continuity,
-                                        B->dtype.bits(), trans_B ? 2 : 1));
+      results.Set(B,
+                  makeGemmABLayoutSm100(mat_stride, mat_continuous, continuity,
+                                        B->dtype.bits(), trans_B));
```
♻️ Duplicate comments (4)
src/tl_templates/cuda/common.h (1)
254-263: operator+ should add 16-byte quanta, not raw bytes.
The descriptor stores offsets in 16-byte units, so adding a raw byte count corrupts the address field. Apply the same >> 4 used in increase_descriptor_offset so that operator+ preserves the descriptor encoding.

```diff
-    ret.reg32_[0] = reg32_[0] + uint32_t(offset);
+    ret.reg32_[0] = reg32_[0] + uint32_t(offset >> 4);
```

tilelang/language/builtin.py (2)
382-383: Critical: Fix operator precedence to prevent an AttributeError.
The condition has ambiguous operator precedence. Python binds and tighter than or, so if descriptor is a BufferLoad (not a Buffer), descriptor.shape[0] is still evaluated after the isinstance(descriptor, Buffer) check fails, raising an AttributeError.
As per coding guidelines, apply this diff to fix the precedence:

```diff
-    if isinstance(descriptor, Buffer) and len(descriptor.shape) != 1 or descriptor.shape[0] != 1:
+    if isinstance(descriptor, Buffer) and (len(descriptor.shape) != 1 or descriptor.shape[0] != 1):
         raise ValueError("Descriptor must be a 1D buffer of size 1.")
```

408-409: Critical: Fix operator precedence to prevent an AttributeError.
Same operator precedence issue as in initialize_descriptor; the condition will incorrectly evaluate descriptor.shape[0] for BufferLoad instances.
Apply this diff:

```diff
-    if isinstance(descriptor, Buffer) and len(descriptor.shape) != 1 or descriptor.shape[0] != 1:
+    if isinstance(descriptor, Buffer) and (len(descriptor.shape) != 1 or descriptor.shape[0] != 1):
         raise ValueError("Descriptor must be a 1D buffer of size 1.")
```
- if isinstance(descriptor, Buffer) and len(descriptor.shape) != 1 or descriptor.shape[0] != 1: + if isinstance(descriptor, Buffer) and (len(descriptor.shape) != 1 or descriptor.shape[0] != 1): raise ValueError("Descriptor must be a 1D buffer of size 1.")src/target/codegen_cuda.cc (1)
1565-1592: Use the validated assembly from PrintWGMMAAssembly instead of ad-hoc pattern replacement.
The code computes asm_code via PrintWGMMAAssembly (line 1567) but then ignores it, creating a separate wgmma_asm_code string with pattern replacement. The past review comment suggested using the validated asm_code directly, which is what the RS path does correctly (line 1628).
Apply this diff to use the validated assembly:

```diff
   const bool a_is_shared = true;
   this->PrintIndent();
   std::string asm_code = PrintWGMMAAssembly(
       shape, a_is_k_major, b_is_k_major, A_dtype, B_dtype, C_dtype, a_desc,
       A_offset, b_desc, B_offset, c_ref, c_offset, scale_out, scale_in_a,
       scale_in_b, a_is_shared, "", "", "", false);
-  auto [m, n, k] = tl::codegen::ptx::ParseMMAShape(shape);
-  std::string wgmma_asm_code = "tl::wgmma_ss<(AType), (BType), (CType), (M), (N), (K), (tnspA), (tnspB), (scaleA), (scaleB)>(uint64_t((desc_a) + (A_offset)), uint64_t((desc_b) + (B_offset)), ((uint32_t*)((C))), (scale_out));\n";
-  // replace patterns
-  tl::codegen::Replacer replacer;
-  replacer.register_rule("(AType)", tl::codegen::ptx::DTypeEnumToString(A_dtype));
-  replacer.register_rule("(BType)", tl::codegen::ptx::DTypeEnumToString(B_dtype));
-  replacer.register_rule("(CType)", tl::codegen::ptx::DTypeEnumToString(C_dtype));
-  replacer.register_rule("(M)", std::to_string(m));
-  replacer.register_rule("(N)", std::to_string(n));
-  replacer.register_rule("(K)", std::to_string(k));
-  replacer.register_rule("(tnspA)", a_is_k_major ? "false" : "true");
-  replacer.register_rule("(tnspB)", b_is_k_major ? "false" : "true");
-  replacer.register_rule("(scaleA)", scale_in_a ? "1" : "-1");
-  replacer.register_rule("(scaleB)", scale_in_b ? "1" : "-1");
-  replacer.register_rule("(desc_a)", a_desc);
-  replacer.register_rule("(A_offset)", A_offset);
-  replacer.register_rule("(desc_b)", b_desc);
-  replacer.register_rule("(B_offset)", B_offset);
-  replacer.register_rule("(C)", c_ref + " + " + c_offset);
-  replacer.register_rule("(scale_out)", scale_out ? "true" : "false");
-  wgmma_asm_code = replacer.rewrite(wgmma_asm_code);
-  this->stream << wgmma_asm_code;
+  this->stream << asm_code;
```
🧹 Nitpick comments (1)
src/layout/gemm_layouts.cc (1)
741-765: Update the error message to reference k_inner instead of kfactor.
Line 764's error message still references "kfactor=", which is inconsistent with the new boolean parameter name.
Apply this diff:

```diff
   ICHECK(0) << "Unsupported layout for Hopper with stride=" << mat_stride
             << ", continuous=" << mat_continuous
-            << ", element_size=" << element_size << ", kfactor=" << kfactor;
+            << ", element_size=" << element_size << ", k_inner=" << k_inner;
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (16)
- src/layout/gemm_layouts.cc (4 hunks)
- src/layout/layout.cc (4 hunks)
- src/layout/layout.h (2 hunks)
- src/op/builtin.cc (2 hunks)
- src/op/builtin.h (2 hunks)
- src/op/gemm.cc (7 hunks)
- src/op/gemm_py.cc (4 hunks)
- src/op/gemm_py.h (1 hunk)
- src/target/codegen_cuda.cc (5 hunks)
- src/tl_templates/cuda/common.h (4 hunks)
- src/tl_templates/cuda/gemm.h (1 hunk)
- src/transform/storage_rewrite.cc (2 hunks)
- tilelang/language/__init__.py (1 hunk)
- tilelang/language/allocate.py (1 hunk)
- tilelang/language/builtin.py (2 hunks)
- tilelang/layout/fragment.py (1 hunk)
🚧 Files skipped from review as they are similar to previous changes (2)
- src/tl_templates/cuda/gemm.h
- src/transform/storage_rewrite.cc
🧰 Additional context used
🧬 Code graph analysis (11)
tilelang/layout/fragment.py (1)
- tilelang/layout/layout.py (4): get_input_shape (59-68), get_output_shape (70-79), index (48-57), is_equal (136-145)

tilelang/language/allocate.py (2)
- src/transform/storage_rewrite.cc (4): dtype (696-702), dtype (696-696), scope (674-678), scope (674-674)
- tilelang/language/ast/ir.py (1): alloc_buffer (441-508)

src/op/gemm.cc (2)
- tilelang/tileop/gemm/gemm_base.py (4): trans_A (46-47), trans_B (50-51), A (67-68), B (71-72)
- src/layout/gemm_layouts.cc (2): makeGemmABLayout (720-739), makeGemmABLayout (720-721)

tilelang/language/builtin.py (3)
- src/op/builtin.h (1): tvm (13-442)
- tilelang/language/ast/ir.py (1): evaluate (1319-1331)
- tilelang/language/tir/op.py (1): call_intrin (119-144)

src/tl_templates/cuda/common.h (2)
- src/target/ptx.h (1): DataType (45-91)
- tilelang/language/builtin.py (2): initialize_descriptor (360-391), increase_descriptor_offset (394-416)

src/layout/layout.h (1)
- src/layout/gemm_layouts.cc (6): makeGemmABLayoutHopper (741-765), makeGemmABLayoutHopper (741-742), makeGemmABLayoutSm100 (767-787), makeGemmABLayoutSm100 (767-768), makeGemmABLayoutCDNA (789-792), makeGemmABLayoutCDNA (789-790)

src/op/builtin.cc (2)
- tilelang/language/tir/op.py (2): ptx_wgmma_ss (1064-1103), ptx_wgmma_rs (1106-1143)
- tilelang/language/builtin.py (2): initialize_descriptor (360-391), increase_descriptor_offset (394-416)

src/layout/layout.cc (1)
- src/layout/gemm_layouts.cc (10): makeGemmABLayout (720-739), makeGemmABLayout (720-721), makeGemmABLayoutHopper (741-765), makeGemmABLayoutHopper (741-742), makeFullBankSwizzleLayout (421-437), makeFullBankSwizzleLayout (421-421), makeHalfBankSwizzleLayout (402-418), makeHalfBankSwizzleLayout (402-402), makeQuarterBankSwizzleLayout (382-399), makeQuarterBankSwizzleLayout (382-383)

src/op/builtin.h (2)
- tilelang/language/tir/op.py (2): ptx_wgmma_ss (1064-1103), ptx_wgmma_rs (1106-1143)
- tilelang/language/builtin.py (2): initialize_descriptor (360-391), increase_descriptor_offset (394-416)

src/target/codegen_cuda.cc (4)
- src/transform/storage_rewrite.cc (6): scope (674-678), scope (674-674), n (366-370), n (366-366), n (371-375), n (371-371)
- tilelang/language/tir/op.py (2): ptx_wgmma_ss (1064-1103), ptx_wgmma_rs (1106-1143)
- src/target/ptx.cc (8): PrintWGMMAAssembly (1186-1257), PrintWGMMAAssembly (1186-1195), ParseMMAShape (140-148), ParseMMAShape (140-140), DTypeEnumToString (107-109), DTypeEnumToString (107-107), DTypeEnumToString (111-113), DTypeEnumToString (111-111)
- tilelang/language/builtin.py (2): initialize_descriptor (360-391), increase_descriptor_offset (394-416)

tilelang/language/__init__.py (1)
- tilelang/language/allocate.py (1): alloc_descriptor (158-164)
🪛 Ruff (0.13.3)
tilelang/language/builtin.py
380-380: Avoid specifying long messages outside the exception class
(TRY003)
382-382: Parenthesize a and b expressions when chaining and and or together, to make the precedence clear
Parenthesize the and subexpression
(RUF021)
383-383: Avoid specifying long messages outside the exception class
(TRY003)
406-406: Avoid specifying long messages outside the exception class
(TRY003)
408-408: Parenthesize a and b expressions when chaining and and or together, to make the precedence clear
Parenthesize the and subexpression
(RUF021)
409-409: Avoid specifying long messages outside the exception class
(TRY003)
tilelang/language/__init__.py
47-47: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
🔇 Additional comments (14)
src/op/gemm_py.cc (3)
108-110: LGTM: Compiler appeasement for unreachable path. The unreachable return statement after ICHECK(0) is a standard pattern to satisfy compiler requirements about control-flow paths returning a value.

283-290: LGTM: New public FFI binding for GetGemmInst. The new tl.GemmPyGemmInst FFI binding correctly exposes the GetGemmInst method, allowing runtime instruction-selection logic from Python.

230-232: FFI signature verified. All lower methods now include layout_map as the first argument.

src/op/builtin.h (2)

219-239: LGTM: New WGMMA intrinsic declarations. The new ptx_wgmma_ss and ptx_wgmma_rs declarations are properly documented with detailed parameter descriptions and correctly exported via TVM_DLL.

423-439: LGTM: New descriptor management intrinsics. The initialize_descriptor and increase_descriptor_offset declarations are well documented and correctly positioned after tl_shuffle_elect, as noted in the enriched summary.

src/target/codegen_cuda.cc (4)

1071-1074: LGTM: Descriptor scope handling. Treating "local.descriptor" scope the same as "local.var" is correct for descriptor buffer references.

1593-1628: LGTM: Correct WGMMA-RS emission. The RS path correctly uses the validated asm_code from PrintWGMMAAssembly and writes it directly to the stream, consistent with the past review guidance.

1953-1973: LGTM: Descriptor operation emission. The initialize_descriptor and increase_descriptor_offset handlers correctly emit template instantiations with the provided parameters.

2116-2117: LGTM: Descriptor allocation. Declaring tl::GmmaDescriptor vid; for "local.descriptor" scope is correct for WGMMA descriptor buffers.

src/layout/gemm_layouts.cc (2)

578-587: LGTM: Boolean k_inner parameter simplifies the API. The change from int kfactor to bool k_inner makes the API more intuitive. The condition if (k_inner) replacing if (kfactor == 2) is clearer about intent.

720-739: LGTM: Well-documented signature change with clear semantics. The function documentation has been updated to explain the boolean k_inner parameter's meaning for different element sizes. The implementation correctly uses !k_inner and k_inner for the conditional paths.

src/layout/layout.cc (3)

487-491: LGTM: New equality check for Layout objects. The tl.Layout_is_equal FFI binding correctly exposes the IsEqual method for Python-side layout comparison.

500-504: LGTM: New equality check for Fragment objects. The tl.Fragment_is_equal FFI binding correctly exposes the IsEqual method for Python-side fragment comparison.

521-550: LGTM: Enhanced swizzled layout construction APIs. The updated make_swizzled_layout with k_inner and allow_pad parameters, along with the new specialized constructors (make_wgmma_swizzled_layout, make_full_bank_swizzled_layout, etc.), provides a cleaner public API surface with better dispatch logic.
template <int layout_type = 0, int leading_byte_offset = 0,
          int stride_byte_offset = 0, typename T>
TL_DEVICE void initialize_descriptor(GmmaDescriptor &descriptor,
                                     T *start_address) {
  descriptor.bitfield.start_address_ =
      cute::cast_smem_ptr_to_uint(start_address) >> 4;
  descriptor.bitfield.layout_type_ = layout_type;
  descriptor.bitfield.base_offset_ = 0;
  descriptor.bitfield.leading_byte_offset_ = leading_byte_offset;
  descriptor.bitfield.stride_byte_offset_ = stride_byte_offset;
}
Descriptor fields must be encoded in 16-byte quanta
start_address_ is encoded as addr >> 4, but leading_byte_offset_ and stride_byte_offset_ are written with raw byte counts. The PTX descriptor spec strips the lower 4 bits for all three fields, so encoding plain bytes produces values 16× too large, breaking every descriptor consumed by WGMMA/UTCMMA. Shift the offsets by 4 (and consider asserting 16-byte alignment) before storing them.
- descriptor.bitfield.leading_byte_offset_ = leading_byte_offset;
- descriptor.bitfield.stride_byte_offset_ = stride_byte_offset;
+ descriptor.bitfield.leading_byte_offset_ = leading_byte_offset >> 4;
+ descriptor.bitfield.stride_byte_offset_ = stride_byte_offset >> 4;

📝 Committable suggestion

‼️ IMPORTANT: Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

template <int layout_type = 0, int leading_byte_offset = 0,
          int stride_byte_offset = 0, typename T>
TL_DEVICE void initialize_descriptor(GmmaDescriptor &descriptor,
                                     T *start_address) {
  descriptor.bitfield.start_address_ =
      cute::cast_smem_ptr_to_uint(start_address) >> 4;
  descriptor.bitfield.layout_type_ = layout_type;
  descriptor.bitfield.base_offset_ = 0;
  descriptor.bitfield.leading_byte_offset_ = leading_byte_offset >> 4;
  descriptor.bitfield.stride_byte_offset_ = stride_byte_offset >> 4;
}
🤖 Prompt for AI Agents
In src/tl_templates/cuda/common.h around lines 303 to 312, the
initialize_descriptor writes leading_byte_offset_ and stride_byte_offset_ as raw
bytes while the PTX descriptor expects 16-byte quanta like start_address_ (addr
>> 4); change the writes to store (leading_byte_offset >> 4) and
(stride_byte_offset >> 4) and (optionally) add an assertion or static_assert
that both offsets and start_address are 16-byte aligned before shifting to catch
misalignment.
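To make the quanta rule concrete, here is a minimal sketch of a checked initializer, assuming the CUTLASS-style GmmaDescriptor bitfield from common.h; the _checked name and the static_asserts are illustrative additions, not existing code.

template <int layout_type = 0, int leading_byte_offset = 0,
          int stride_byte_offset = 0, typename T>
TL_DEVICE void initialize_descriptor_checked(GmmaDescriptor &descriptor,
                                             T *start_address) {
  // All three address-like fields are stored in 16-byte quanta,
  // i.e. the value written to the bitfield is (bytes >> 4).
  static_assert((leading_byte_offset & 0xF) == 0,
                "leading_byte_offset must be a multiple of 16 bytes");
  static_assert((stride_byte_offset & 0xF) == 0,
                "stride_byte_offset must be a multiple of 16 bytes");
  descriptor.bitfield.start_address_ =
      cute::cast_smem_ptr_to_uint(start_address) >> 4;  // bytes -> 16B units
  descriptor.bitfield.layout_type_ = layout_type;
  descriptor.bitfield.base_offset_ = 0;
  descriptor.bitfield.leading_byte_offset_ = leading_byte_offset >> 4;
  descriptor.bitfield.stride_byte_offset_ = stride_byte_offset >> 4;
}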
…r improved clarity and consistency. Update includes necessary changes in error messages for Hopper and Sm100 layouts. Additionally, include a new header for CUTE utilities in common.h.
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
src/layout/gemm_layouts.cc (2)
741-765: Fix boolean k_inner arguments in gemm_sp call sites
In src/op/gemm_sp.cc at lines 240 and 254, calls to makeGemmABLayoutHopper still pass trans_A ? 1 : 2 and trans_B ? 2 : 1 (both convert to true). Update them to pass the boolean flags directly, e.g. replace trans_A ? 1 : 2 with trans_A and trans_B ? 2 : 1 with trans_B.
767-787: Convert k_inner arguments to bool in gemm.cc
- src/op/gemm.cc:783-785: replace trans_A ? 1 : 2 with !trans_A so k_inner is true when A is not transposed
- src/op/gemm.cc:793-795: replace trans_B ? 2 : 1 with trans_B so k_inner matches the original intent

- makeGemmABLayoutSm100(mat_stride, mat_continuous, mat_continuous, A->dtype.bits(), trans_A ? 1 : 2);
+ makeGemmABLayoutSm100(mat_stride, mat_continuous, mat_continuous, A->dtype.bits(), !trans_A);
- makeGemmABLayoutSm100(mat_stride, mat_continuous, continuity, B->dtype.bits(), trans_B ? 2 : 1);
+ makeGemmABLayoutSm100(mat_stride, mat_continuous, continuity, B->dtype.bits(), trans_B);
♻️ Duplicate comments (3)
src/tl_templates/cuda/common.h (3)
304-320: Validate descriptor parameters before writing hardware fields. These functions write descriptor bitfields with no checks. Previous reviews requested assertions for the layout_type range (0-3), 16-byte alignment of addresses/offsets, and overflow guards in increase_descriptor_offset.

Consider adding:
- Static or runtime assertions for layout_type ∈ [0, 3]
- Alignment checks: ((uintptr_t)start_address & 0xF) == 0
- Verification that offsets are multiples of 16 and fit the bitfield widths
- An overflow guard when adding offset >> 4 to descriptor.reg32_[0]
260-265: Critical encoding issue persists: operator+ must shift offset by 4. The descriptor encodes addresses/offsets in 16-byte quanta (4 LSBs stripped), but operator+ at line 262 adds raw bytes without shifting. This was flagged in previous reviews and remains unaddressed.

Apply this diff:

- ret.reg32_[0] = reg32_[0] + uint32_t(offset);
+ ret.reg32_[0] = reg32_[0] + uint32_t(offset >> 4);

Note the inconsistency: increase_descriptor_offset at line 319 correctly shifts >> 4, but this operator does not.
304-314: Critical encoding issue persists: offset parameters must be shifted by 4. Lines 312-313 write leading_byte_offset and stride_byte_offset as raw values, but the PTX descriptor expects 16-byte quanta like start_address_ (shifted >> 4 at line 309). This was flagged in previous reviews and remains unaddressed.

Apply this diff:

- descriptor.bitfield.leading_byte_offset_ = leading_byte_offset;
- descriptor.bitfield.stride_byte_offset_ = stride_byte_offset;
+ descriptor.bitfield.leading_byte_offset_ = leading_byte_offset >> 4;
+ descriptor.bitfield.stride_byte_offset_ = stride_byte_offset >> 4;
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- src/layout/gemm_layouts.cc (6 hunks)
- src/tl_templates/cuda/common.h (5 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
src/tl_templates/cuda/common.h (3)
- src/target/ptx.h (1): DataType (45-91)
- src/tl_templates/cuda/ldsm.h (12): void (7-14), void (16-23), void (25-33), void (35-42), void (44-52), void (54-62), void (64-70), void (72-79), void (81-89), void (91-98), void (100-108), void (110-119)
- tilelang/language/builtin.py (2): initialize_descriptor (360-391), increase_descriptor_offset (394-416)
🔇 Additional comments (5)
src/tl_templates/cuda/common.h (1)
8-8: LGTM! The cute library includes and the using declaration are appropriate for the new descriptor functionality.
Also applies to: 17-17
src/layout/gemm_layouts.cc (4)
181-181: Minor cleanup. The blank-line addition appears intentional. No concerns.

690-719: LGTM! Well-documented parameter change. The documentation clearly explains the new k_inner boolean parameter and its impact on layout selection for different data types. The examples for fp64 and int8 are helpful.
720-739: Approve: all callers updated to bool k_inner
Verified every call to makeGemmABLayout (layout.cc, gemm_layouts.cc, op/gemm.cc) now passes a boolean for k_inner; no further changes needed.
578-587: All callers updated to use boolean arguments
Verified in src/op/gemm.cc (lines 671, 684) that both is_a and k_inner parameters are passed as booleans; no int arguments remain.
enum class DataType : int {
  kInt4 = 0,
  kUInt4 = 1,
  kInt8 = 2,
  kUInt8 = 3,
  kInt16 = 4,
  kUInt16 = 5,
  kInt32 = 6,
  kUInt32 = 7,
  kInt64 = 8,
  kUInt64 = 9,
  kFloat8_e4m3 = 10,
  kFloat8_e5m2 = 11,
  kFloat16 = 12,
  kBFloat16 = 13,
  kFloat16x2 = 14,
  kFloat32 = 15,
  kTensorFloat32 = 16,
  kFloat64 = 17,
  kBit1 = 18,
  kBit8 = 19,
  kBit16 = 20,
  kBit32 = 21,
  kBit64 = 22
};
🛠️ Refactor suggestion | 🟠 Major
Eliminate DataType duplication.
The DataType enum is duplicated in src/target/ptx.h (lines 44-90). Maintain a single definition in ptx.h and reference it here via include/using to avoid maintenance drift.
🤖 Prompt for AI Agents
In src/tl_templates/cuda/common.h around lines 181 to 205, remove the duplicated
DataType enum block and instead include the single canonical definition from
src/target/ptx.h; add an #include for "src/target/ptx.h" at the top of this file
(or the appropriate relative include path) and either reference the enum as the
qualified name from ptx.h or add a single using declaration (e.g., using
::DataType; or using target::DataType;) so this file reuses the ptx.h definition
rather than redefining it.
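A minimal sketch of the suggested deduplication; the include path and alias are assumptions to adjust to the actual build setup, and whether the device-side template header can include the host-side ptx.h at all is a question the suggestion leaves open:

// src/tl_templates/cuda/common.h (sketch)
#include "target/ptx.h"                        // path is an assumption
namespace tl {
using DataType = tl::codegen::ptx::DataType;   // alias instead of a second enum
}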
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
src/op/gemm.cc (1)
145-154: Bounds-check optional args to avoid out-of-bounds access. args[16], [17], [18] are accessed unconditionally; if not present, UB/crash.

- node->mbarptr = args[16];
- if (node->mbarptr.as<CallNode>()) {
-   node->mbar = vmap[GetVarFromAccessPtr(node->mbarptr)];
- } else {
-   node->mbar = std::nullopt;
- }
- node->C_coords = Array<PrimExpr>(
-     {args[17].as<PrimExpr>().value(), args[18].as<PrimExpr>().value()});
+ if (args.size() > 16) {
+   node->mbarptr = args[16];
+   if (node->mbarptr.as<CallNode>()) {
+     node->mbar = vmap[GetVarFromAccessPtr(node->mbarptr)];
+   } else {
+     node->mbar = std::nullopt;
+   }
+ } else {
+   node->mbarptr = PrimExpr();
+   node->mbar = std::nullopt;
+ }
+ ICHECK(args.size() > 18)
+     << "Gemm expects C_coords provided as args[17], args[18]";
+ node->C_coords = Array<PrimExpr>({args[17].as<PrimExpr>().value(),
+                                   args[18].as<PrimExpr>().value()});
♻️ Duplicate comments (6)
tilelang/intrinsics/wgmma_macro_generator.py (5)
139-149: Remove debug prints. These pollute logs; drop them or gate them via a logger.

- print(f"_determinate_swizzle_mode mat_stride: {mat_stride}, mat_continuous: {mat_continuous}, element_size: {element_size}")
@@
- print(f"a_leading_byte_offset: {a_leading_byte_offset >> 4}")
- print(f"a_stride_byte_offset: {a_stride_byte_offset >> 4}")
- print(f"b_leading_byte_offset: {b_leading_byte_offset >> 4}")
- print(f"b_stride_byte_offset: {b_stride_byte_offset >> 4}")
-
- print(f"b_swizzle_atom_size: {b_swizzle_mode.swizzle_atom_size()}")
- print(f"b_swizzle_byte_size: {b_swizzle_mode.swizzle_byte_size()}")

Also applies to: 241-248
176-178: Add k-dimension divisibility asserts to prevent silent tails. Ensure k_dim % micro_size_k == 0 in SS and RS.

- assert k_dim >= micro_size_k, f"k_dim must be greater than or equal to {micro_size_k}, got k_dim: {k_dim}"
+ assert k_dim >= micro_size_k, f"k_dim must be >= {micro_size_k}, got {k_dim}"
+ assert (k_dim % micro_size_k) == 0, f"k_dim must be divisible by micro_size_k={micro_size_k}, got {k_dim}"
@@
- assert k_dim >= micro_size_k, f"k_dim must be greater than or equal to {micro_size_k}, got k_dim: {k_dim}"
+ assert k_dim >= micro_size_k, f"k_dim must be >= {micro_size_k}, got {k_dim}"
+ assert (k_dim % micro_size_k) == 0, f"k_dim must be divisible by micro_size_k={micro_size_k}, got {k_dim}"

Also applies to: 295-296
184-217: Descriptor math uses wrong element size and pointer mode; fix per-operand bytes and offsets. Use separate a/b byte sizes, correct LBO/SBO units, and read-only access_ptr; scale A/B offsets with their respective sizes.

- elems_in_bits = DataType(self.a_dtype).bits
- elems_in_bytes = elems_in_bits // 8
-
- a_swizzle_atom_elems = a_swizzle_mode.swizzle_byte_size() // elems_in_bytes
- b_swizzle_atom_elems = n_dim if b_swizzle_mode.is_none() else b_swizzle_mode.swizzle_byte_size() // elems_in_bytes
+ a_elems_in_bytes = DataType(self.a_dtype).bits // 8
+ b_elems_in_bytes = DataType(self.b_dtype).bits // 8
+
+ a_swizzle_atom_elems = a_swizzle_mode.swizzle_byte_size() // a_elems_in_bytes
+ b_swizzle_atom_elems = (
+     n_dim if b_swizzle_mode.is_none()
+     else b_swizzle_mode.swizzle_byte_size() // b_elems_in_bytes
+ )
@@
- a_leading_byte_offset = (8 * 8 * elems_in_bytes) if a_is_k_major else (8 * m_dim * elems_in_bytes)
- a_stride_byte_offset = (8 * k_dim * elems_in_bytes) if a_is_k_major else (8 * 8 * elems_in_bytes)
+ a_leading_byte_offset = (8 * 8 * a_elems_in_bytes) if a_is_k_major else (8 * m_dim * a_elems_in_bytes)
+ a_stride_byte_offset = (8 * k_dim * a_elems_in_bytes) if a_is_k_major else (8 * 8 * a_elems_in_bytes)
@@
  if a_is_k_major:
      a_leading_byte_offset = 16
      a_stride_byte_offset = 8 * a_swizzle_mode.swizzle_byte_size()
  else:
@@
  if a_m_axis_atoms <= 1:
      a_leading_byte_offset = 0
  else:
-     a_leading_byte_offset = 8 * a_swizzle_mode.swizzle_atom_size() * (a_swizzle_mode.swizzle_byte_size() // elems_in_bytes)
+     a_leading_byte_offset = a_swizzle_mode.swizzle_byte_size()
@@
  if a_m_axis_atoms <= 1:
-     a_stride_byte_offset = 8 * elems_in_bytes * m_dim
+     a_stride_byte_offset = 8 * a_elems_in_bytes * m_dim
  else:
-     a_stride_byte_offset = 8 * elems_in_bytes * a_swizzle_atom_elems
+     a_stride_byte_offset = 8 * a_elems_in_bytes * a_swizzle_atom_elems
@@
- b_leading_byte_offset = (8 * 8 * elems_in_bytes) if b_is_k_major else (8 * n_dim * elems_in_bytes)
- b_stride_byte_offset = (8 * k_dim * elems_in_bytes) if b_is_k_major else (
-     0 if n_dim == 8 else (8 * 8 * elems_in_bytes)
- )
+ b_leading_byte_offset = (8 * 8 * b_elems_in_bytes) if b_is_k_major else (8 * n_dim * b_elems_in_bytes)
+ b_stride_byte_offset = (8 * k_dim * b_elems_in_bytes) if b_is_k_major else (
+     0 if n_dim == 8 else (8 * 8 * b_elems_in_bytes)
+ )
@@
  if b_is_k_major:
      b_leading_byte_offset = 16
      b_stride_byte_offset = 8 * b_swizzle_mode.swizzle_byte_size()
  else:
@@
  if b_n_axis_atoms <= 1:
      b_leading_byte_offset = 0
  else:
-     b_leading_byte_offset = 8 * 8 * elems_in_bytes * k_dim
+     b_leading_byte_offset = b_swizzle_mode.swizzle_byte_size()
  if b_n_axis_atoms <= 1:
-     b_stride_byte_offset = 8 * elems_in_bytes * n_dim
+     b_stride_byte_offset = 8 * b_elems_in_bytes * n_dim
  else:
-     b_stride_byte_offset = 8 * elems_in_bytes * b_swizzle_atom_elems
+     b_stride_byte_offset = 8 * b_elems_in_bytes * b_swizzle_atom_elems
@@
  desc_a = T.alloc_descriptor()
  desc_b = T.alloc_descriptor()
- T.initialize_descriptor(desc_a, A_buf.access_ptr("w"), a_swizzle_mode,
+ T.initialize_descriptor(desc_a, A_buf.access_ptr("r"), a_swizzle_mode,
                          int(a_leading_byte_offset >> 4), int(a_stride_byte_offset >> 4))
- T.initialize_descriptor(desc_b, B_buf.access_ptr("w"), b_swizzle_mode,
+ T.initialize_descriptor(desc_b, B_buf.access_ptr("r"), b_swizzle_mode,
                          int(b_leading_byte_offset >> 4), int(b_stride_byte_offset >> 4))
@@
- T.ptx_wgmma_ss(accum_dtype, wgmma_prefix, a_is_k_major, b_is_k_major,
-                a_dtype_abbrv, b_dtype_abbrv, accum_dtype_abbrv,
-                desc_a.data, (A_offset * elems_in_bytes) >> 4,
-                desc_b.data, (B_offset * elems_in_bytes) >> 4,
-                C_local_buf.data, C_offset, scale_out, scale_in_a, scale_in_b)
+ T.ptx_wgmma_ss(accum_dtype, wgmma_prefix, a_is_k_major, b_is_k_major,
+                a_dtype_abbrv, b_dtype_abbrv, accum_dtype_abbrv,
+                desc_a.data, (A_offset * a_elems_in_bytes) >> 4,
+                desc_b.data, (B_offset * b_elems_in_bytes) >> 4,
+                C_local_buf.data, C_offset, scale_out, scale_in_a, scale_in_b)

Also applies to: 217-241, 258-261, 267-271
303-307: RS path: mirror per-operand sizes, use read-only descriptors, and fix K-major flags. Correct the B descriptor math/units, pointer mode, and flags; scale the B offset with B’s element size.

- elems_in_bytes = DataType(self.a_dtype).bits // 8
+ a_elems_in_bytes = DataType(self.a_dtype).bits // 8
+ b_elems_in_bytes = DataType(self.b_dtype).bits // 8
@@
- b_leading_byte_offset = (8 * 8 * elems_in_bytes) if b_is_k_major else (8 * n_dim * elems_in_bytes)
- b_stride_byte_offset = (8 * k_dim * elems_in_bytes) if b_is_k_major else (8 * 8 * elems_in_bytes)
+ b_leading_byte_offset = (8 * 8 * b_elems_in_bytes) if b_is_k_major else (8 * n_dim * b_elems_in_bytes)
+ b_stride_byte_offset = (8 * k_dim * b_elems_in_bytes) if b_is_k_major else (8 * 8 * b_elems_in_bytes)
@@
- b_n_axis_atoms = n_dim // (b_swizzle_mode.swizzle_byte_size() // elems_in_bytes)
+ b_n_axis_atoms = n_dim // (b_swizzle_mode.swizzle_byte_size() // b_elems_in_bytes)
@@
- b_leading_byte_offset = 8 * b_swizzle_mode.swizzle_atom_size() * (b_swizzle_mode.swizzle_byte_size() // elems_in_bytes)
+ b_leading_byte_offset = b_swizzle_mode.swizzle_byte_size()
@@
- b_stride_byte_offset = 8 * elems_in_bytes * n_dim
+ b_stride_byte_offset = 8 * b_elems_in_bytes * n_dim
  else:
- b_stride_byte_offset = 8 * elems_in_bytes * (b_swizzle_mode.swizzle_byte_size() // elems_in_bytes)
+ b_stride_byte_offset = 8 * b_elems_in_bytes * (b_swizzle_mode.swizzle_byte_size() // b_elems_in_bytes)
@@
- T.initialize_descriptor(desc_b, B_buf.access_ptr("w"), b_swizzle_mode,
+ T.initialize_descriptor(desc_b, B_buf.access_ptr("r"), b_swizzle_mode,
                          int(b_leading_byte_offset >> 4), int(b_stride_byte_offset >> 4))
@@
- self.a_transposed,
- not self.b_transposed,
+ (not self.a_transposed),
+ self.b_transposed,
@@
- (B_offset * elems_in_bytes) >> 4,
+ (B_offset * b_elems_in_bytes) >> 4,

Also applies to: 316-326, 330-352, 341-343, 349-349
84-85: Default warp_row_tiles conflicts with your assertion; set it to 16. Avoid constructor-time assertion failures.

- warp_row_tiles: int = 8,
+ warp_row_tiles: int = 16,

src/op/gemm.cc (1)
112-115: kPack doc vs. validation mismatch; enforce a single rule (1 only). The doc says “kPack must be 1” but the runtime allows 1 or 2. Align with the doc/PR intent by enforcing 1.

- if (args.size() > 14) {
-   node->kPack = args[14].as<IntImm>().value()->value;
-   if (node->kPack != 1 && node->kPack != 2) {
-     ICHECK(false) << "kPack must be 1 or 2";
-   }
- }
+ if (args.size() > 14) {
+   node->kPack = args[14].as<IntImm>().value()->value;
+   if (node->kPack != 1) {
+     ICHECK(false) << "kPack must be 1";
+   }
+ }

Also applies to: 136-141
🧹 Nitpick comments (10)
src/op/gemm.cc (3)
103-111: Update constructor arg docs to include mbarptr and C_coords. The docs omit mbarptr (descriptor) and C_coords, which are now required later.

- * [Aptr, Bptr, Cptr, trans_A (Bool), trans_B (Bool),
+ * [Aptr, Bptr, Cptr, trans_A (Bool), trans_B (Bool),
  *  M (Int), N (Int), K (Int), policy (Int), clear_accum (Bool),
  *  stride_A (Int), stride_B (Int), offset_A (Int), offset_B (Int),
- *  (optional) kPack (Int), (optional) wg_wait (Int)]
+ *  (optional) kPack (Int), (optional) wg_wait (Int),
+ *  (optional) mbarptr (Handle), C_i (PrimExpr), C_j (PrimExpr)]
481-492: Harden arch parsing to avoid std::stoi exceptions. Guard empty/malformed “sm_” strings.

- if (arch.rfind("sm_", 0) == 0) {
-   arch_int = std::stoi(arch.substr(3));
- } else {
+ if (arch.rfind("sm_", 0) == 0) {
+   std::string num = arch.substr(3);
+   if (!num.empty() && std::isdigit(num[0])) {
+     arch_int = std::stoi(num);
+   } else {
+     arch_int = 0;
+   }
+ } else {
    arch_int = 0;
  }
760-761: Remove noisy LOG(INFO) in the hot path (use VLOG or drop it). Unconditional INFO logs during layout inference are noisy.

- LOG(INFO) << "gemm_inst: " << (int)gemm_inst << ", trans_B: " << trans_B;
+ // VLOG(1) << "gemm_inst: " << (int)gemm_inst << ", trans_B: " << trans_B;
6-13: Drop unnecessary# noqa: F401comments (RUF100).Ruff flags these as unused. Remove or switch to explicit all exports.
-from .swizzle import ( - make_swizzled_layout, # noqa: F401 - make_wgmma_swizzled_layout, # noqa: F401 - make_full_bank_swizzled_layout, # noqa: F401 - make_half_bank_swizzled_layout, # noqa: F401 - make_quarter_bank_swizzled_layout, # noqa: F401 - make_linear_layout, # noqa: F401 -) +from .swizzle import ( + make_swizzled_layout, + make_wgmma_swizzled_layout, + make_full_bank_swizzled_layout, + make_half_bank_swizzled_layout, + make_quarter_bank_swizzled_layout, + make_linear_layout, +)tilelang/layout/swizzle.py (1)
10-18: Type Optional properly and remove debug prints; trim exception messages.
- Use Optional[int] for continuity.
- Remove print in make_wgmma_swizzled_layout.
- Shorten ValueError messages.
+from typing import Optional @@ -def make_wgmma_swizzled_layout(buffer: tvm.tir.Buffer, - continuity: int = None, +def make_wgmma_swizzled_layout(buffer: tvm.tir.Buffer, + continuity: Optional[int] = None, k_major: bool = True): @@ - print(f"make_wgmma_swizzled_layout: {buffer.shape[0]}, {buffer.shape[1]}, {continuity}, {tvm.DataType(buffer.dtype).bits}, {k_major}") return _ffi_api.make_wgmma_swizzled_layout( @@ - else: - raise ValueError(f"Invalid arguments: {args}") + else: + raise ValueError("Invalid arguments") @@ - else: - raise ValueError(f"Invalid arguments: {args}") + else: + raise ValueError("Invalid arguments") @@ - else: - raise ValueError(f"Invalid arguments: {args}") + else: + raise ValueError("Invalid arguments") @@ - else: - raise ValueError(f"Invalid arguments: {args}") + else: + raise ValueError("Invalid arguments")Also applies to: 22-35, 48-60, 73-86, 98-110, 120-130
tilelang/tileop/gemm/gemm_wgmma.py (2)
13-61: GemmWGMMA layout inference looks good; minor nits.
- Good: k_major/continuity mirror Hopper path.
- Consider shorter error text.
- else:
-     raise ValueError(
-         f"Unsupported gemm combination for wgmma, A: {self.A.scope()}, B: {self.B.scope()}")
+ else:
+     raise ValueError("Unsupported WGMMA layout combination")

Also, confirm from tilelang import tvm as tvm is intentional vs import tvm.

65-127: Lowering path is clean; minor error-msg nit.

- raise ValueError(
-     f"Unsupported gemm combination for wgmma, A: {self.A.scope()}, B: {self.B.scope()}")
+ raise ValueError("Unsupported WGMMA lowering combination")

tilelang/intrinsics/wgmma_macro_generator.py (2)
140-145: Collapse trivial condition (SIM114). Combine the NONE/linear checks to satisfy the linter.

- if layout is None:
-     return SwizzleMode.NONE
- elif layout.is_equal(make_linear_layout(mat_stride, mat_continuous)):
+ if layout is None or layout.is_equal(make_linear_layout(mat_stride, mat_continuous)):
      return SwizzleMode.NONE
105-110: Remove unused method arg n_dim. Not used; simplify the signature.

- def _initialize_wgmma_prefix(self, n_dim: int = 16):
+ def _initialize_wgmma_prefix(self):
      inst_m, inst_n = 64, self.block_col_warps * self.warp_col_tiles
388-388: Remove or gate the debug LOG statement. This LOG(INFO) in makeQuarterBankSwizzleLayout will generate output on every call, potentially flooding logs in production. Debug logging in hot paths should either be removed before release or gated behind a debug flag.

Consider one of these approaches:

- LOG(INFO) << "makeQuarterBankSwizzleLayout: " << stride << ", " << continuous << ", " << element_size;
+ DLOG(INFO) << "makeQuarterBankSwizzleLayout: " << stride << ", " << continuous << ", " << element_size;

Or remove it entirely if not needed for debugging.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (7)
src/layout/gemm_layouts.cc(7 hunks)src/layout/layout.cc(3 hunks)src/op/gemm.cc(7 hunks)tilelang/intrinsics/wgmma_macro_generator.py(1 hunks)tilelang/layout/__init__.py(1 hunks)tilelang/layout/swizzle.py(1 hunks)tilelang/tileop/gemm/gemm_wgmma.py(1 hunks)
🧰 Additional context used
🧬 Code graph analysis (6)
tilelang/layout/__init__.py (1)
- tilelang/layout/swizzle.py (6): make_swizzled_layout (10-18), make_wgmma_swizzled_layout (22-35), make_full_bank_swizzled_layout (40-60), make_half_bank_swizzled_layout (65-85), make_quarter_bank_swizzled_layout (90-110), make_linear_layout (112-130)

tilelang/layout/swizzle.py (1)
- tilelang/language/ast/ir.py (1): buffer (93-161)

src/op/gemm.cc (2)
- tilelang/tileop/gemm/gemm_base.py (4): trans_A (46-47), trans_B (50-51), A (67-68), B (71-72)
- src/layout/gemm_layouts.cc (4): makeGemmABLayout (721-740), makeGemmABLayout (721-722), makeGemmABLayoutHopper (742-773), makeGemmABLayoutHopper (742-743)

src/layout/layout.cc (1)
- src/layout/gemm_layouts.cc (12): makeGemmABLayout (721-740), makeGemmABLayout (721-722), makeGemmABLayoutHopper (742-773), makeGemmABLayoutHopper (742-743), makeFullBankSwizzleLayout (422-438), makeFullBankSwizzleLayout (422-422), makeHalfBankSwizzleLayout (403-419), makeHalfBankSwizzleLayout (403-403), makeQuarterBankSwizzleLayout (382-400), makeQuarterBankSwizzleLayout (382-383), makeGemmLayoutLinear (492-496), makeGemmLayoutLinear (492-492)

tilelang/tileop/gemm/gemm_wgmma.py (6)
- tilelang/tileop/gemm/gemm_base.py (20): GemmBase (12-120), infer_layout (15-16), policy (119-120), M (34-35), N (38-39), in_dtype (54-56), accum_dtype (59-60), trans_A (46-47), trans_B (50-51), chunk (63-64), is_gemm_ss (21-22), K (42-43), A (67-68), B (71-72), C (75-76), is_gemm_rs (27-28), lower (18-19), clear_accum (107-108), is_gemm_sr (24-25), is_gemm_rr (30-31)
- tilelang/layout/swizzle.py (1): make_wgmma_swizzled_layout (22-35)
- tilelang/intrinsics/wgmma_macro_generator.py (6): TensorCoreIntrinEmitter (64-525), make_mma_store_layout (471-525), make_mma_load_layout (359-469), _assign_a_shared_layout (97-99), _assign_b_shared_layout (101-103), wgmma (153-273)
- tilelang/utils/language.py (2): is_shared (25-39), is_fragment (68-78)
- tilelang/transform/simplify.py (1): _Simplify (30-49)
- tilelang/tileop/gemm/gemm_mma.py (7): infer_layout (15-58), is_gemm_ss (204-205), is_gemm_rs (210-211), lower (60-202), _gemm_ssr (95-120), is_gemm_sr (207-208), is_gemm_rr (213-214)

tilelang/intrinsics/wgmma_macro_generator.py (5)
- tilelang/utils/language.py (1): is_fragment (68-78)
- tilelang/layout/swizzle.py (4): make_full_bank_swizzled_layout (40-60), make_half_bank_swizzled_layout (65-85), make_quarter_bank_swizzled_layout (90-110), make_linear_layout (112-130)
- tilelang/language/builtin.py (1): initialize_descriptor (360-391)
- tilelang/language/tir/op.py (2): ptx_wgmma_ss (1064-1103), ptx_wgmma_rs (1106-1143)
- tilelang/intrinsics/mma_macro_generator.py (1): get_store_index_map (159-165)
🪛 GitHub Actions: CI Test on AMD
tilelang/intrinsics/wgmma_macro_generator.py
[error] 142-145: SIM114 Combine if branches using logical or operator
🪛 GitHub Actions: CI Test on Metal
tilelang/intrinsics/wgmma_macro_generator.py
[error] 142-142: SIM114 Combine if branches using logical or operator
🪛 Ruff (0.13.3)
tilelang/layout/__init__.py
7-7: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
8-8: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
9-9: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
10-10: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
11-11: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
12-12: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
tilelang/layout/swizzle.py
23-23: PEP 484 prohibits implicit Optional
Convert to Optional[T]
(RUF013)
55-55: Avoid specifying long messages outside the exception class
(TRY003)
80-80: Avoid specifying long messages outside the exception class
(TRY003)
105-105: Avoid specifying long messages outside the exception class
(TRY003)
126-126: Avoid specifying long messages outside the exception class
(TRY003)
tilelang/tileop/gemm/gemm_wgmma.py
62-63: Avoid specifying long messages outside the exception class
(TRY003)
125-126: Avoid specifying long messages outside the exception class
(TRY003)
tilelang/intrinsics/wgmma_macro_generator.py
105-105: Unused method argument: n_dim
(ARG002)
151-151: Avoid specifying long messages outside the exception class
(TRY003)
401-401: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: format-check
- GitHub Check: bot-task
🔇 Additional comments (6)
src/layout/gemm_layouts.cc (2)
709-740: LGTM: k_inner refactor is well-documented. The updated documentation clearly explains the semantics of the k_inner parameter for both fp64 and int8 cases, with inline comments reinforcing the layout selection logic. The implementation correctly uses k_inner to select between KxN and NxK layouts.
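For readers following along, a sketch of the dispatch this describes; the branch mirrors the documented semantics, while the two helper names are hypothetical stand-ins for the concrete layout constructors:

Layout makeGemmABLayout(int mat_stride, int mat_continuous, int continuity,
                        int element_size, bool k_inner) {
  if (k_inner) {
    // K is the innermost (contiguous) dimension: pick the K x N style layout.
    return makeKInnerLayout(mat_stride, mat_continuous, element_size);  // hypothetical
  }
  // MN-major: the M/N dimension is innermost, N x K style layout.
  return makeKOuterLayout(mat_stride, mat_continuous, element_size);    // hypothetical
}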
580-588: All makeGemmVoltaABLayout call sites now pass a boolean k_inner. Only two calls remain (in src/op/gemm.cc lines 671 and 685), each using an explicit bool argument; no numeric k-factor usages.

src/layout/layout.cc (4)
487-491: LGTM: Layout equality check. The FFI binding correctly delegates to LayoutNode::IsEqual() for comparing two Layout objects.

500-504: LGTM: Fragment equality check. The FFI binding correctly delegates to FragmentNode::IsEqual() for comparing two Fragment objects.

532-554: LGTM: New layout constructor bindings. The new FFI bindings (make_wgmma_swizzled_layout, make_full_bank_swizzled_layout, make_half_bank_swizzled_layout, make_quarter_bank_swizzled_layout, make_linear_layout) correctly expose the corresponding C++ functions with appropriate parameter passing.

521-531: Dispatch logic verified: all existing call sites rely on the default allow_pad=true, preserving previous behavior; explicit allow_pad=false now correctly routes to makeGemmABLayoutHopper.
Actionable comments posted: 3
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
src/op/gemm.cc (2)
145-153: Potential args out-of-bounds access (args[16..18])
mbarptr = args[16] and C_coords from args[17], args[18] are read without size checks. If the frontend omits these, this will crash.

- node->mbarptr = args[16];
+ ICHECK_GT(args.size(), 18) << "Gemm expects mbar pointer and C_coords at indices 16..18";
+ node->mbarptr = args[16];
  if (node->mbarptr.as<CallNode>()) {
    node->mbar = vmap[GetVarFromAccessPtr(node->mbarptr)];
  } else {
    node->mbar = std::nullopt;
  }
- node->C_coords = Array<PrimExpr>(
-     {args[17].as<PrimExpr>().value(), args[18].as<PrimExpr>().value()});
+ node->C_coords = Array<PrimExpr>({args[17], args[18]});
784-787: Wrong argument type passed to makeGemmABLayoutSm100 (bool API coerced from int)
makeGemmABLayoutSm100(..., bool k_inner) now takes a bool, but calls still pass trans_A ? 1 : 2 and trans_B ? 2 : 1. Any non-zero value coerces to true, breaking K-inner semantics.

- results.Set(A, makeGemmABLayoutSm100(mat_stride, mat_continuous,
-                                      mat_continuous, A->dtype.bits(),
-                                      trans_A ? 1 : 2));
+ results.Set(A, makeGemmABLayoutSm100(mat_stride, mat_continuous,
+                                      mat_continuous, A->dtype.bits(),
+                                      !trans_A));
@@
- results.Set(B,
-             makeGemmABLayoutSm100(mat_stride, mat_continuous, continuity,
-                                   B->dtype.bits(), trans_B ? 2 : 1));
+ results.Set(B,
+             makeGemmABLayoutSm100(mat_stride, mat_continuous, continuity,
+                                   B->dtype.bits(), trans_B));

Also applies to: 794-796
♻️ Duplicate comments (15)
src/tl_templates/cuda/instruction/wgmma.h (1)
15-23: Remove debug code before merging. The fallback template contains debug printf statements and a commented-out static_assert that should be removed or uncommented before merging to production. This was previously flagged.
src/op/gemm.cc (1)
112-115: Doc vs. code mismatch for kPack. The note says “kPack must be 1”, but the constructor accepts 1 or 2 (line 139) and errors “kPack must be 1 or 2”. Align doc or code.

- * @note If `kPack` is provided it must be 1; otherwise the constructor
- * fails with an ICHECK (runtime assertion). No other validation is
- * performed here.
+ * @note If `kPack` is provided it must be 1 or 2; otherwise the constructor
+ * fails with an ICHECK (runtime assertion).

Or, if only 1 is intended now:

- if (node->kPack != 1 && node->kPack != 2) {
-   ICHECK(false) << "kPack must be 1 or 2";
+ if (node->kPack != 1) {
+   ICHECK(false) << "kPack must be 1";
  }

tilelang/language/builtin.py (2)
382-383: Fix operator precedence in validation (and/or)
and binds tighter than or, so descriptor.shape[0] may be accessed even when the value is not a Buffer. Parenthesize the shape checks.

- if isinstance(descriptor, Buffer) and len(descriptor.shape) != 1 or descriptor.shape[0] != 1:
+ if isinstance(descriptor, Buffer) and (len(descriptor.shape) != 1 or descriptor.shape[0] != 1):
408-409: Fix operator precedence in validation (and/or). Same issue here; wrap the and side.

- if isinstance(descriptor, Buffer) and len(descriptor.shape) != 1 or descriptor.shape[0] != 1:
+ if isinstance(descriptor, Buffer) and (len(descriptor.shape) != 1 or descriptor.shape[0] != 1):

src/target/codegen_cuda.cc (1)
1537-1599: WGMMA-SS emission: uses ad-hoc string, ignores validated asm_code, and does invalid descriptor arithmetic
- asm_code from PrintWGMMAAssembly is computed then ignored.
- The fallback string does uint64_t((desc_a) + (A_offset)), which is invalid on a descriptor struct.
- The RS path already emits the validated asm_code.

Fix by emitting asm_code and removing the replacer block.

- const bool a_is_shared = true;
- this->PrintIndent();
- std::string asm_code = PrintWGMMAAssembly(
+ const bool a_is_shared = true;
+ this->PrintIndent();
+ std::string asm_code = PrintWGMMAAssembly(
      shape, a_is_k_major, b_is_k_major, A_dtype, B_dtype, C_dtype, a_desc,
      A_offset, b_desc, B_offset, c_ref, c_offset, scale_out, scale_in_a,
      scale_in_b, a_is_shared, "", "", "", false);
- auto [m, n, k] = tl::codegen::ptx::ParseMMAShape(shape);
- std::string wgmma_asm_code =
-     "tl::wgmma_ss<(AType), (BType), (CType), (M), (N), (K), (tnspA), "
-     "(tnspB), (scaleA), (scaleB)>(uint64_t((desc_a) + (A_offset)), "
-     "uint64_t((desc_b) + (B_offset)), ((uint32_t*)((C))), (scale_out));\n";
- // replace patterns
- tl::codegen::Replacer replacer;
- replacer.register_rule("(AType)", tl::codegen::ptx::DTypeEnumToString(A_dtype));
- replacer.register_rule("(BType)", tl::codegen::ptx::DTypeEnumToString(B_dtype));
- replacer.register_rule("(CType)", tl::codegen::ptx::DTypeEnumToString(C_dtype));
- replacer.register_rule("(M)", std::to_string(m));
- replacer.register_rule("(N)", std::to_string(n));
- replacer.register_rule("(K)", std::to_string(k));
- replacer.register_rule("(tnspA)", a_is_k_major ? "false" : "true");
- replacer.register_rule("(tnspB)", b_is_k_major ? "false" : "true");
- replacer.register_rule("(scaleA)", scale_in_a ? "1" : "-1");
- replacer.register_rule("(scaleB)", scale_in_b ? "1" : "-1");
- replacer.register_rule("(desc_a)", a_desc);
- replacer.register_rule("(A_offset)", A_offset);
- replacer.register_rule("(desc_b)", b_desc);
- replacer.register_rule("(B_offset)", B_offset);
- replacer.register_rule("(C)", c_ref + " + " + c_offset);
- replacer.register_rule("(scale_out)", scale_out ? "true" : "false");
- wgmma_asm_code = replacer.rewrite(wgmma_asm_code);
- this->stream << wgmma_asm_code;
+ this->stream << asm_code;
154-162: Complete and align WGMMA parameter documentation with the signature. The comment documents only a subset of parameters and uses outdated names (A_layout/B_layout vs a_is_k_major/b_is_k_major). Please document all parameters and align names with the signature.

/*!
 * \brief Print WGMMA assembly string given parameters.
 * \param shape The shape string mMnNkK
- * \param A_layout The layout of multiplicand A, can be either "row" or "col".
- * \param B_layout The layout of multiplicand B, can be either "row" or "col".
- * \param A_dtype The data type of multiplicand A.
- * \param B_dtype The data type of multiplicand B.
- * \param C_dtype The data type of multiplicand C.
+ * \param a_is_k_major Whether A is K-major (true) or MN-major (false).
+ * \param b_is_k_major Whether B is K-major (true) or MN-major (false).
+ * \param A_dtype Data type of A (e.g., ".f16", "e4m3", ".s8").
+ * \param B_dtype Data type of B.
+ * \param C_dtype Data type of accumulator/output C.
+ * \param a_desc Descriptor for A (SS) or base pointer (RS).
+ * \param A_offset Offset for A (in 16-byte quanta when a descriptor is used).
+ * \param b_desc Descriptor for B (SS/RS).
+ * \param B_offset Offset for B (in 16-byte quanta when a descriptor is used).
+ * \param c_ptr Pointer to accumulator C.
+ * \param c_offset Offset into C (in elements).
+ * \param scale_out Scale-out flag (true keeps/accumulates; false clears).
+ * \param scale_in_a Scale-in flag for A (true 1, false -1).
+ * \param scale_in_b Scale-in flag for B (true 1, false -1).
+ * \param a_is_shared Whether A comes from shared memory (SS vs RS selection).
+ * \param metadata Pointer to metadata buffer (sparse only).
+ * \param metadata_offset Offset into metadata.
+ * \param sparsity_selector Sparsity selector (sparse only).
+ * \param sparse Whether sparse WGMMA is used.
 */
164-174: Pass booleans by value (avoid const references to primitives). Use plain bool for primitive params to match existing APIs and avoid unnecessary references. Update declaration/definition and callers accordingly.

-PrintWGMMAAssembly(const std::string &shape, const bool &a_is_k_major,
-                   const bool &b_is_k_major, const std::string &A_dtype,
+PrintWGMMAAssembly(const std::string &shape, bool a_is_k_major,
+                   bool b_is_k_major, const std::string &A_dtype,
                    const std::string &B_dtype, const std::string &C_dtype,
                    const std::string &a_desc, const std::string &A_offset,
                    const std::string &b_desc, const std::string &B_offset,
                    const std::string &c_ptr, const std::string &c_offset,
-                   const bool &scale_out, const bool &scale_in_a,
-                   const bool &scale_in_b, const bool &a_is_shared,
+                   bool scale_out, bool scale_in_a,
+                   bool scale_in_b, bool a_is_shared,
                    const std::string &metadata, const std::string &metadata_offset,
                    const std::string &sparsity_selector, bool sparse);
84-85: Default contradicts assertion (warp_row_tiles). Set the default to 16 to satisfy the class invariant.

- warp_row_tiles: int = 8,
+ warp_row_tiles: int = 16,
170-171: Enforce divisibility to avoid dropping K tails (SS path). The loop uses k_dim // micro_size_k; add a divisibility assert.

- assert k_dim >= micro_size_k, f"k_dim must be greater than or equal to {micro_size_k}, got k_dim: {k_dim}"
+ assert k_dim >= micro_size_k, f"k_dim must be >= {micro_size_k}, got {k_dim}"
+ assert (k_dim % micro_size_k) == 0, f"k_dim must be divisible by micro_size_k={micro_size_k}, got k_dim: {k_dim}"
178-184: Fix per‑operand element sizes, descriptor pointer mode, and offset scaling (SS)
- Use separate byte sizes for A/B; don’t reuse A’s for B.
- Descriptors are read from; use access_ptr("r").
- Scale A/B offsets with their own element sizes.
- elems_in_bits = DataType(self.a_dtype).bits - elems_in_bytes = elems_in_bits // 8 - - a_swizzle_atom_elems = a_swizzle_mode.swizzle_byte_size() // elems_in_bytes - b_swizzle_atom_elems = n_dim if b_swizzle_mode.is_none( - ) else b_swizzle_mode.swizzle_byte_size() // elems_in_bytes + elems_in_bits = DataType(self.a_dtype).bits + a_elems_in_bytes = elems_in_bits // 8 + b_elems_in_bytes = DataType(self.b_dtype).bits // 8 + + a_swizzle_atom_elems = a_swizzle_mode.swizzle_byte_size() // a_elems_in_bytes + b_swizzle_atom_elems = n_dim if b_swizzle_mode.is_none() else \ + (b_swizzle_mode.swizzle_byte_size() // b_elems_in_bytes) @@ - a_leading_byte_offset = (8 * 8 * elems_in_bytes) if a_is_k_major else (8 * m_dim * - elems_in_bytes) - a_stride_byte_offset = (8 * k_dim * elems_in_bytes) if a_is_k_major else (8 * 8 * - elems_in_bytes) + a_leading_byte_offset = (8 * 8 * a_elems_in_bytes) if a_is_k_major else (8 * m_dim * a_elems_in_bytes) + a_stride_byte_offset = (8 * k_dim * a_elems_in_bytes) if a_is_k_major else (8 * 8 * a_elems_in_bytes) @@ - b_leading_byte_offset = (8 * 8 * elems_in_bytes) if b_is_k_major else (8 * n_dim * - elems_in_bytes) - b_stride_byte_offset = (8 * k_dim * - elems_in_bytes) if b_is_k_major else (0 if n_dim == 8 else - (8 * 8 * elems_in_bytes)) + b_leading_byte_offset = (8 * 8 * b_elems_in_bytes) if b_is_k_major else (8 * n_dim * b_elems_in_bytes) + b_stride_byte_offset = (8 * k_dim * b_elems_in_bytes) if b_is_k_major else (0 if n_dim == 8 else (8 * 8 * b_elems_in_bytes)) @@ - T.initialize_descriptor(desc_a, A_buf.access_ptr("w"), a_swizzle_mode, + T.initialize_descriptor(desc_a, A_buf.access_ptr("r"), a_swizzle_mode, int(a_leading_byte_offset >> 4), int(a_stride_byte_offset >> 4)) - T.initialize_descriptor(desc_b, B_buf.access_ptr("w"), b_swizzle_mode, + T.initialize_descriptor(desc_b, B_buf.access_ptr("r"), b_swizzle_mode, int(b_leading_byte_offset >> 4), int(b_stride_byte_offset >> 4)) @@ - T.ptx_wgmma_ss(accum_dtype, wgmma_prefix, a_is_k_major, b_is_k_major, - a_dtype_abbrv, b_dtype_abbrv, accum_dtype_abbrv, desc_a.data, - (A_offset * elems_in_bytes) >> 4, desc_b.data, - (B_offset * elems_in_bytes) >> 4, C_local_buf.data, C_offset, + T.ptx_wgmma_ss(accum_dtype, wgmma_prefix, a_is_k_major, b_is_k_major, + a_dtype_abbrv, b_dtype_abbrv, accum_dtype_abbrv, desc_a.data, + (A_offset * a_elems_in_bytes) >> 4, desc_b.data, + (B_offset * b_elems_in_bytes) >> 4, C_local_buf.data, C_offset, scale_out, scale_in_a, scale_in_b)Also applies to: 185-212, 213-237, 247-250, 260-264
288-289: RS path: divisibility, per‑operand sizes, pointer mode, K‑major flags, and K‑major swizzle SBO
- Add K‑divisibility assert.
- Use per‑operand byte sizes; don’t reuse A’s for B.
- Descriptor is read-only: use access_ptr("r").
- Fix K-major flags (match SS semantics).
- When B is K‑major with swizzle, set SBO to
8 * swizzle_byte_size().- assert k_dim >= micro_size_k, f"k_dim must be greater than or equal to {micro_size_k}, got k_dim: {k_dim}" + assert k_dim >= micro_size_k, f"k_dim must be >= {micro_size_k}, got {k_dim}" + assert (k_dim % micro_size_k) == 0, f"k_dim must be divisible by micro_size_k={micro_size_k}, got k_dim: {k_dim}" @@ - elems_in_bytes = DataType(self.a_dtype).bits // 8 + a_elems_in_bytes = DataType(self.a_dtype).bits // 8 + b_elems_in_bytes = DataType(self.b_dtype).bits // 8 @@ - b_leading_byte_offset = (8 * 8 * elems_in_bytes) if b_is_k_major else (8 * n_dim * - elems_in_bytes) - b_stride_byte_offset = (8 * k_dim * elems_in_bytes) if b_is_k_major else (8 * 8 * - elems_in_bytes) + b_leading_byte_offset = (8 * 8 * b_elems_in_bytes) if b_is_k_major else (8 * n_dim * b_elems_in_bytes) + b_stride_byte_offset = (8 * k_dim * b_elems_in_bytes) if b_is_k_major else (8 * 8 * b_elems_in_bytes) @@ - if b_is_k_major: - b_leading_byte_offset = 16 + if b_is_k_major: + b_leading_byte_offset = 16 + b_stride_byte_offset = 8 * b_swizzle_mode.swizzle_byte_size() @@ - T.initialize_descriptor(desc_b, B_buf.access_ptr("w"), b_swizzle_mode, + T.initialize_descriptor(desc_b, B_buf.access_ptr("r"), b_swizzle_mode, int(b_leading_byte_offset >> 4), int(b_stride_byte_offset >> 4)) @@ - self.a_transposed, - not self.b_transposed, + (not self.a_transposed), + self.b_transposed, @@ - (B_offset * elems_in_bytes) >> 4, + (B_offset * b_elems_in_bytes) >> 4,Also applies to: 290-321, 325-327, 333-350
src/tl_templates/cuda/common.h (3)
303-314: Add compile‑time validation for descriptor fieldsGuard against invalid layout_type and mis‑encoded offsets (template args).
template <int layout_type = 0, int leading_byte_offset = 0, int stride_byte_offset = 0, typename T> TL_DEVICE void initialize_descriptor(GmmaDescriptor &descriptor, T *start_address) { + static_assert(layout_type >= 0 && layout_type <= 3, "layout_type must be in [0,3]"); + static_assert((leading_byte_offset >= 0) && (stride_byte_offset >= 0), "offsets must be non-negative"); + // Expect offsets encoded in 16B quanta (match Python-side >>4) + static_assert((leading_byte_offset & 0xFFFF0000) == 0, "leading_byte_offset too large for bitfield"); + static_assert((stride_byte_offset & 0xFFFF0000) == 0, "stride_byte_offset too large for bitfield"); descriptor.bitfield.start_address_ = cute::cast_smem_ptr_to_uint(start_address) >> 4; descriptor.bitfield.layout_type_ = layout_type; descriptor.bitfield.base_offset_ = 0; descriptor.bitfield.leading_byte_offset_ = leading_byte_offset; descriptor.bitfield.stride_byte_offset_ = stride_byte_offset; }
171-205: Avoid DataType duplication (single source of truth). This enum duplicates src/target/ptx.h. Prefer including/reusing the canonical definition to prevent drift.
254-265: Unify descriptor offset semantics (operator+). operator+ adds raw units while increase_descriptor_offset accepts bytes (>>4). Make operator+ accept bytes too for consistency.

template <typename T>
CUTE_HOST_DEVICE constexpr GmmaDescriptor operator+(const T &offset) const {
  GmmaDescriptor ret;
- ret.reg32_[0] = reg32_[0] + uint32_t(offset);
+ ret.reg32_[0] = reg32_[0] + uint32_t(offset >> 4);
  ret.reg32_[1] = reg32_[1];
  return ret;
}

Consider documenting that offsets are byte counts.
tilelang/tileop/gemm/__init__.py (1)
27-42: Fix enum typo: WGMMMA → WGMMA (dispatch bug). The extra “M” mismatches the C++/FFI side and breaks selection.

class GemmInst(IntEnum):
    MMA = 0
-   WGMMMA = 1
+   WGMMA = 1
    MFMA = 2

    def is_mma(self) -> bool:
        return self == GemmInst.MMA

    def is_wgmma(self) -> bool:
-       return self == GemmInst.WGMMMA
+       return self == GemmInst.WGMMA
🧹 Nitpick comments (5)
tilelang/language/builtin.py (1)
360-365: Type hints don’t match accepted inputs. Functions accept Buffer or BufferLoad, but the annotations say Buffer/PrimExpr. Align hints to avoid confusion and improve IDE/static checks.

-def initialize_descriptor(descriptor: Buffer,
+def initialize_descriptor(descriptor: Union[Buffer, BufferLoad],
                           start_address: PrimExpr,
                           layout_type_: int = 0,
                           leading_byte_offset: int = 0,
                           stride_byte_offset: int = 0) -> PrimExpr:
@@
-def increase_descriptor_offset(descriptor: PrimExpr, offset: PrimExpr) -> PrimExpr:
+def increase_descriptor_offset(descriptor: Union[Buffer, BufferLoad], offset: PrimExpr) -> PrimExpr:

Also applies to: 394-405
tilelang/layout/swizzle.py (2)
22-35: Use Optional typing and validate continuity derivation
- Annotate continuity as Optional[int] per PEP 484.
- Minor: assert the single-arg form is a Buffer and 2D before using .shape.

-from typing import Union, Any
+from typing import Union, Any, Optional
@@
-def make_wgmma_swizzled_layout(buffer: tvm.tir.Buffer,
-                               continuity: int = None,
+def make_wgmma_swizzled_layout(buffer: tvm.tir.Buffer,
+                               continuity: Optional[int] = None,
                                k_major: bool = True):

47-60: Guard the single-argument overloads. When len(args) == 1, ensure it’s a 2D Buffer before reading .shape.

- if len(args) == 1:
-     buffer = args[0]
+ if len(args) == 1:
+     buffer = args[0]
+     assert isinstance(buffer, tvm.tir.Buffer) and len(buffer.shape) == 2

Apply similarly to half/quarter/linear helpers.
Also applies to: 72-85, 97-110, 120-130
tilelang/intrinsics/wgmma_macro_generator.py (2)
105-105: Unused parameter n_dim. Remove it or prefix it with an underscore to satisfy linters.

- def _initialize_wgmma_prefix(self, n_dim: int = 16):
+ def _initialize_wgmma_prefix(self, _n_dim: int = 16):
      inst_m, inst_n = 64, self.block_col_warps * self.warp_col_tiles
376-376: Redundant import
is_fragment is already imported at the top of the file; remove the inner import.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (16)
- .clang-tidy (1 hunk)
- src/layout/gemm_layouts.cc (6 hunks)
- src/layout/layout.cc (3 hunks)
- src/layout/layout.h (2 hunks)
- src/op/builtin.h (2 hunks)
- src/op/gemm.cc (7 hunks)
- src/target/codegen_cuda.cc (5 hunks)
- src/target/ptx.cc (9 hunks)
- src/target/ptx.h (2 hunks)
- src/tl_templates/cuda/common.h (4 hunks)
- src/tl_templates/cuda/instruction/wgmma.h (1 hunk)
- src/transform/storage_rewrite.cc (2 hunks)
- tilelang/intrinsics/wgmma_macro_generator.py (1 hunk)
- tilelang/language/builtin.py (2 hunks)
- tilelang/layout/swizzle.py (1 hunk)
- tilelang/tileop/gemm/__init__.py (3 hunks)
✅ Files skipped from review due to trivial changes (1)
- src/target/ptx.cc
🚧 Files skipped from review as they are similar to previous changes (2)
- src/transform/storage_rewrite.cc
- src/op/builtin.h
🧰 Additional context used
🧬 Code graph analysis (10)
src/layout/layout.cc (1)
- src/layout/gemm_layouts.cc (12): makeGemmABLayout (720-739), makeGemmABLayout (720-721), makeGemmABLayoutHopper (741-766), makeGemmABLayoutHopper (741-742), makeFullBankSwizzleLayout (421-437), makeFullBankSwizzleLayout (421-421), makeHalfBankSwizzleLayout (402-418), makeHalfBankSwizzleLayout (402-402), makeQuarterBankSwizzleLayout (382-399), makeQuarterBankSwizzleLayout (382-383), makeGemmLayoutLinear (491-495), makeGemmLayoutLinear (491-491)

src/target/codegen_cuda.cc (4)
- src/transform/storage_rewrite.cc (6): scope (674-679), scope (674-674), n (366-370), n (366-366), n (371-375), n (371-371)
- tilelang/language/tir/op.py (2): ptx_wgmma_ss (1064-1103), ptx_wgmma_rs (1106-1143)
- src/target/ptx.cc (8): PrintWGMMAAssembly (1188-1261), PrintWGMMAAssembly (1189-1199), ParseMMAShape (142-150), ParseMMAShape (142-142), DTypeEnumToString (108-110), DTypeEnumToString (108-108), DTypeEnumToString (112-115), DTypeEnumToString (112-112)
- tilelang/language/builtin.py (2): initialize_descriptor (360-391), increase_descriptor_offset (394-416)

src/tl_templates/cuda/common.h (2)
- src/target/ptx.h (1): DataType (45-90)
- tilelang/language/builtin.py (2): initialize_descriptor (360-391), increase_descriptor_offset (394-416)

src/tl_templates/cuda/instruction/wgmma.h (1)
- src/tl_templates/cuda/common.h (5): tl (171-265), DataType (180-225), int (97-100), int (135-142), uint32_t (118-120)

src/layout/layout.h (1)
- src/layout/gemm_layouts.cc (6): makeGemmABLayoutHopper (741-766), makeGemmABLayoutHopper (741-742), makeGemmABLayoutSm100 (768-788), makeGemmABLayoutSm100 (768-769), makeGemmABLayoutCDNA (790-793), makeGemmABLayoutCDNA (790-791)

tilelang/intrinsics/wgmma_macro_generator.py (6)
- tilelang/utils/language.py (1): is_fragment (68-78)
- tilelang/layout/layout.py (3): Layout (13-148), is_equal (136-145), inverse (125-134)
- tilelang/layout/swizzle.py (4): make_full_bank_swizzled_layout (39-59), make_half_bank_swizzled_layout (64-84), make_quarter_bank_swizzled_layout (89-109), make_linear_layout (112-130)
- tilelang/layout/fragment.py (4): is_equal (209-213), Fragment (13-213), replicate (146-160), repeat (123-144)
- tilelang/language/builtin.py (1): initialize_descriptor (360-391)
- tilelang/intrinsics/mma_macro_generator.py (1): get_store_index_map (159-165)

src/target/ptx.h (4)
- src/tl_templates/cuda/common.h (3): DataType (180-225), int (97-100), int (135-142)
- src/target/ptx.cc (10): DTypeFromString (56-106), DTypeFromString (56-56), DTypeEnumToString (108-110), DTypeEnumToString (108-108), DTypeEnumToString (112-115), DTypeEnumToString (112-112), ParseMMAShape (142-150), ParseMMAShape (142-142), PrintWGMMAAssembly (1188-1261), PrintWGMMAAssembly (1189-1199)
- src/transform/storage_rewrite.cc (2): dtype (697-703), dtype (697-697)
- src/target/codegen_hip.cc (1): _rules (75-75)

src/op/gemm.cc (2)
- tilelang/tileop/gemm/gemm_base.py (4): trans_A (46-47), trans_B (50-51), A (67-68), B (71-72)
- src/layout/gemm_layouts.cc (4): makeGemmABLayout (720-739), makeGemmABLayout (720-721), makeGemmABLayoutHopper (741-766), makeGemmABLayoutHopper (741-742)

tilelang/language/builtin.py (3)
- src/op/builtin.h (1): tvm (13-442)
- tilelang/language/ast/ir.py (1): evaluate (1319-1331)
- tilelang/language/tir/op.py (1): call_intrin (119-144)

tilelang/layout/swizzle.py (3)
- tilelang/language/ast/ir.py (1): buffer (93-161)
- src/op/builtin.h (1): tvm (13-442)
- src/layout/swizzle.h (1): tvm (12-70)
🪛 Ruff (0.13.3)
tilelang/tileop/gemm/__init__.py
120-120: Avoid specifying long messages outside the exception class
(TRY003)
tilelang/intrinsics/wgmma_macro_generator.py
105-105: Unused method argument: n_dim
(ARG002)
145-145: Avoid specifying long messages outside the exception class
(TRY003)
396-396: Avoid specifying long messages outside the exception class
(TRY003)
tilelang/language/builtin.py
380-380: Avoid specifying long messages outside the exception class
(TRY003)
382-382: Parenthesize a and b expressions when chaining and and or together, to make the precedence clear
Parenthesize the and subexpression
(RUF021)
383-383: Avoid specifying long messages outside the exception class
(TRY003)
406-406: Avoid specifying long messages outside the exception class
(TRY003)
408-408: Parenthesize a and b expressions when chaining and and or together, to make the precedence clear
Parenthesize the and subexpression
(RUF021)
409-409: Avoid specifying long messages outside the exception class
(TRY003)
tilelang/layout/swizzle.py
23-23: PEP 484 prohibits implicit Optional
Convert to Optional[T]
(RUF013)
54-54: Avoid specifying long messages outside the exception class
(TRY003)
79-79: Avoid specifying long messages outside the exception class
(TRY003)
104-104: Avoid specifying long messages outside the exception class
(TRY003)
126-126: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: build-test-nvidia
- GitHub Check: build-test-metal
- GitHub Check: build-test-amd
🔇 Additional comments (14)
.clang-tidy (1)
49-49: Re-evaluate disabling of clang-diagnostic-tautological-constant-compare in .clang-tidy:49. No instances of this diagnostic were found in a full clang-tidy scan; please confirm it is actually triggered in your build configuration or remove it if unnecessary.

src/tl_templates/cuda/instruction/wgmma.h (1)
626-629: Clarify or remove unusedWgmmaOutputRegstrait
- No references to
WgmmaOutputRegswere found in the codebase. If it isn’t used, delete it; otherwise clarify whethervalueis meant to represent per-thread register usage (e.g. 4 for M64N8K16 F32) versus total registers across the warp group, and update the formula or documentation accordingly.src/op/gemm.cc (1)
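As a quick sanity check on the per-thread reading of that trait (a standalone sketch, not TileLang code): an M64N8 F32 accumulator holds 64 x 8 = 512 f32 values spread over a 128-thread warp group, i.e. 4 registers per thread.

```cpp
#include <cstdio>

int main() {
  // WGMMA accumulators are distributed across a warp group of 128 threads
  // (4 warps x 32 lanes); each f32 accumulator element occupies one register.
  constexpr int M = 64, N = 8;
  constexpr int warp_group_threads = 4 * 32;
  constexpr int total_f32 = M * N;                            // 512 values
  constexpr int per_thread = total_f32 / warp_group_threads;  // 4 registers
  static_assert(per_thread == 4, "M64N8 F32 should need 4 regs per thread");
  printf("per-thread f32 accumulator registers: %d\n", per_thread);
  return 0;
}
```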
src/op/gemm.cc (1)

585-591: New gemm_rs path selection looks good.
Selecting rs/sr/ss by fragment/shared scope is consistent. No issues.

src/target/codegen_cuda.cc (2)

2123-2159: Descriptor allocation handling looks correct.
local.descriptor mapped to tl::GmmaDescriptor avoids the generic PrintStorageScope path and is consistent with the GetBufferRef handling.

1960-1980: Descriptor intrinsics emission looks good.
Templated initialize_descriptor and typed increase_descriptor_offset<int> generation matches the Python builtins.

src/layout/gemm_layouts.cc (1)

578-587: kfactor → k_inner migration looks correct.
Branching and error messages are updated coherently across targets; no issues spotted.

Also applies to: 720-739, 741-766, 768-787
src/layout/layout.cc (8)

487-491: LGTM. The equality check for Layout is correctly implemented, delegating to the existing IsEqual method.

500-504: LGTM. The equality check for Fragment follows the same pattern as Layout's _is_equal and correctly delegates to the existing IsEqual method.

521-531: API simplification for the common case. The function passes continuous for both the mat_continuous and continuity parameters of the underlying layout constructors. This is a reasonable simplification for the common case where the two values are equal; for finer control, users can call tl.make_wgmma_swizzled_layout, which exposes both parameters separately. The new allow_pad parameter with a default value of true maintains backward compatibility while providing control over which layout constructor is used.

532-537: LGTM. The new make_wgmma_swizzled_layout function correctly exposes all parameters (mat_continuous and continuity separately) for fine-grained control over WGMMA layout construction, properly delegating to makeGemmABLayoutHopper.

538-541: LGTM. The binding correctly delegates to makeFullBankSwizzleLayout with the appropriate parameters.

542-545: LGTM. The binding correctly delegates to makeHalfBankSwizzleLayout with the appropriate parameters.

546-550: LGTM. The binding correctly delegates to makeQuarterBankSwizzleLayout with the appropriate parameters.

551-553: LGTM. The new make_linear_layout binding correctly delegates to makeGemmLayoutLinear for creating linear (non-swizzled) layouts.
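For orientation, a hedged sketch of how these bindings might be exercised from Python; the helper names come from tilelang/layout/swizzle.py as listed above, but the buffer argument and call pattern are assumptions, not verified API:

```python
# Hypothetical usage sketch; only the function names are taken from the PR,
# the argument (a shared-memory buffer) is an assumption.
from tilelang.layout.swizzle import (
    make_full_bank_swizzled_layout,
    make_linear_layout,
)

def pick_layout(shared_buf, want_swizzle: bool):
    # Full 128B-bank swizzle for WGMMA operand tiles, or a plain
    # row-major layout when swizzling is not wanted.
    if want_swizzle:
        return make_full_bank_swizzled_layout(shared_buf)
    return make_linear_layout(shared_buf)
```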
   int element_size, bool k_inner = true);
 Layout makeGemmABLayoutHopper(int mat_stride, int mat_continuous,
-                              int continuity, int element_size, int kfactor);
+                              int continuity, int element_size,
+                              bool k_inner = true);
 Layout makeGemmABLayoutSm100(int mat_stride, int mat_continuous, int continuity,
-                             int element_size, int kfactor);
+                             int element_size, bool k_inner = true);
 Layout makeGemmABLayoutCDNA(int stride, int continuous, int element_size,
-                            int kfactor);
+                            int kPack);
🧩 Analysis chain
Audit callers after k_inner bool API change
The APIs now take bool k_inner. Verify that all call sites pass booleans rather than `? 1 : 2` ternaries. The Sm100 calls in src/op/gemm.cc still pass ints.
Also applies to: 171-173, 174-176, 185-186
🏁 Script executed:
#!/bin/bash
# Find suspicious non-bool last-arg usages of Gemm AB layout constructors.
rg -nPU -C2 '\bmakeGemmABLayout(?:Hopper|Sm100)?\s*\('
# Heuristic: show calls where 5th arg looks like a ternary yielding 1/2
rg -nPU '\bmakeGemmABLayout(?:Hopper|Sm100)?\s*\((?:[^,]*,){4}\s*[^),]*\?\s*[12]\s*:\s*[12]\s*\)' -n -C2

Length of output: 8230
Fix non-boolean arguments to makeGemmABLayout* calls

Several call sites still pass ints (e.g., trans_A ? 1 : 2, trans_B ? 2 : 1) for the new bool k_inner parameter. Update:
- src/op/gemm.cc: lines 784–786, 794–795
- src/op/gemm_sp.cc: lines 240–242, 254–255
to pass a boolean (e.g., trans_A or !trans_A) instead.
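A minimal sketch of the requested call-site fix; the stub signature mirrors the header above, but the surrounding variables and the mapping of the old 1/2 encoding to k_inner are assumptions:

```cpp
#include <iostream>
#include <string>

// Stub standing in for the real constructor; the real function returns a
// Layout, a string is enough to illustrate the call-site change.
std::string makeGemmABLayout(int stride, int continuous, int continuity,
                             int element_size, bool k_inner = true) {
  return k_inner ? "K-inner layout" : "K-outer layout";
}

int main() {
  bool trans_A = false;
  // Old call sites encoded the axis as an int: trans_A ? 1 : 2. Assuming
  // kfactor == 2 meant "K is the innermost dimension", the boolean
  // equivalent is k_inner = !trans_A; passing the old ternary would
  // silently convert both 1 and 2 to true.
  std::cout << makeGemmABLayout(64, 64, 64, 16, /*k_inner=*/!trans_A) << "\n";
  return 0;
}
```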
| // ================================= BF16 x BF16 -> F32 | ||
| // ================================= | ||
|
|
||
| // M64N8K16 BF16->F32 | ||
| template <bool tnspA, bool tnspB, int scaleA, int scaleB> | ||
| struct WgmmaSSImpl<DataType::kBFloat16, DataType::kBFloat16, DataType::kFloat32, | ||
| 64, 8, 16, tnspA, tnspB, scaleA, scaleB> { | ||
| TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c, | ||
| bool scale_out) { | ||
| asm volatile("{\n" | ||
| ".reg .pred p;\n" | ||
| "setp.ne.b32 p, %6, 0;\n" | ||
| "wgmma.mma_async.sync.aligned.m64n8k16.f32.bf16.bf16 " | ||
| "{%0, %1, %2, %3}, %4, %5, p, %7, %8, %9, %10;\n" | ||
| "}\n" | ||
| : "+r"(c[0]), "+r"(c[1]), "+r"(c[2]), "+r"(c[3]) | ||
| : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)), | ||
| "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), | ||
| "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); | ||
| } | ||
| }; | ||
|
|
||
| // M64N16K16 BF16->F32 | ||
| template <bool tnspA, bool tnspB, int scaleA, int scaleB> | ||
| struct WgmmaSSImpl<DataType::kBFloat16, DataType::kBFloat16, DataType::kFloat32, | ||
| 64, 16, 16, tnspA, tnspB, scaleA, scaleB> { | ||
| TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c, | ||
| bool scale_out) { | ||
| asm volatile( | ||
| "{\n" | ||
| ".reg .pred p;\n" | ||
| "setp.ne.b32 p, %10, 0;\n" | ||
| "wgmma.mma_async.sync.aligned.m64n16k16.f32.bf16.bf16 " | ||
| "{%0, %1, %2, %3, %4, %5, %6, %7}, %8, %9, p, %11, %12, %13, %14;\n" | ||
| "}\n" | ||
| : "+r"(c[0]), "+r"(c[1]), "+r"(c[2]), "+r"(c[3]), "+r"(c[4]), | ||
| "+r"(c[5]), "+r"(c[6]), "+r"(c[7]) | ||
| : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)), | ||
| "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), | ||
| "n"(int32_t(tnspB))); | ||
| } | ||
| }; | ||
|
|
||
| // ================================= TF32 x TF32 -> F32 | ||
| // ================================= | ||
|
|
||
| // M64N8K8 TF32->F32 | ||
| template <bool tnspA, bool tnspB, int scaleA, int scaleB> | ||
| struct WgmmaSSImpl<DataType::kTensorFloat32, DataType::kTensorFloat32, | ||
| DataType::kFloat32, 64, 8, 8, tnspA, tnspB, scaleA, scaleB> { | ||
| TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c, | ||
| bool scale_out) { | ||
| asm volatile("{\n" | ||
| ".reg .pred p;\n" | ||
| "setp.ne.b32 p, %6, 0;\n" | ||
| "wgmma.mma_async.sync.aligned.m64n8k8.f32.tf32.tf32 " | ||
| "{%0, %1, %2, %3}, %4, %5, p, %7, %8, %9, %10;\n" | ||
| "}\n" | ||
| : "+r"(c[0]), "+r"(c[1]), "+r"(c[2]), "+r"(c[3]) | ||
| : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)), | ||
| "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), | ||
| "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); | ||
| } | ||
| }; | ||
|
|
||
| // M64N16K8 TF32->F32 | ||
| template <bool tnspA, bool tnspB, int scaleA, int scaleB> | ||
| struct WgmmaSSImpl<DataType::kTensorFloat32, DataType::kTensorFloat32, | ||
| DataType::kFloat32, 64, 16, 8, tnspA, tnspB, scaleA, | ||
| scaleB> { | ||
| TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c, | ||
| bool scale_out) { | ||
| asm volatile( | ||
| "{\n" | ||
| ".reg .pred p;\n" | ||
| "setp.ne.b32 p, %10, 0;\n" | ||
| "wgmma.mma_async.sync.aligned.m64n16k8.f32.tf32.tf32 " | ||
| "{%0, %1, %2, %3, %4, %5, %6, %7}, %8, %9, p, %11, %12, %13, %14;\n" | ||
| "}\n" | ||
| : "+r"(c[0]), "+r"(c[1]), "+r"(c[2]), "+r"(c[3]), "+r"(c[4]), | ||
| "+r"(c[5]), "+r"(c[6]), "+r"(c[7]) | ||
| : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)), | ||
| "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), | ||
| "n"(int32_t(tnspB))); | ||
| } | ||
| }; | ||
|
|
||
| // ================================= INT8 x INT8 -> INT32 | ||
| // ================================= | ||
|
|
||
| // M64N8K32 S8->S32 | ||
| template <bool tnspA, bool tnspB, int scaleA, int scaleB> | ||
| struct WgmmaSSImpl<DataType::kInt8, DataType::kInt8, DataType::kInt32, 64, 8, | ||
| 32, tnspA, tnspB, scaleA, scaleB> { | ||
| TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c, | ||
| bool scale_out) { | ||
| asm volatile("{\n" | ||
| ".reg .pred p;\n" | ||
| "setp.ne.b32 p, %4, 0;\n" | ||
| "wgmma.mma_async.sync.aligned.m64n8k32.s32.s8.s8 " | ||
| "{%0, %1}, %2, %3, p, %5, %6, %7, %8;\n" | ||
| "}\n" | ||
| : "+r"(c[0]), "+r"(c[1]) | ||
| : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)), | ||
| "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), | ||
| "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); | ||
| } | ||
| }; | ||
|
|
||
| // M64N16K32 S8->S32 | ||
| template <bool tnspA, bool tnspB, int scaleA, int scaleB> | ||
| struct WgmmaSSImpl<DataType::kInt8, DataType::kInt8, DataType::kInt32, 64, 16, | ||
| 32, tnspA, tnspB, scaleA, scaleB> { | ||
| TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c, | ||
| bool scale_out) { | ||
| asm volatile("{\n" | ||
| ".reg .pred p;\n" | ||
| "setp.ne.b32 p, %6, 0;\n" | ||
| "wgmma.mma_async.sync.aligned.m64n16k32.s32.s8.s8 " | ||
| "{%0, %1, %2, %3}, %4, %5, p, %7, %8, %9, %10;\n" | ||
| "}\n" | ||
| : "+r"(c[0]), "+r"(c[1]), "+r"(c[2]), "+r"(c[3]) | ||
| : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)), | ||
| "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), | ||
| "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); | ||
| } | ||
| }; | ||
|
|
||
| // ================================= FP8 x FP8 -> F16/F32 | ||
| // ================================= | ||
|
|
||
| // M64N8K32 E4M3->F16 | ||
| template <bool tnspA, bool tnspB, int scaleA, int scaleB> | ||
| struct WgmmaSSImpl<DataType::kFloat8_e4m3, DataType::kFloat8_e4m3, | ||
| DataType::kFloat16, 64, 8, 32, tnspA, tnspB, scaleA, | ||
| scaleB> { | ||
| TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c, | ||
| bool scale_out) { | ||
| asm volatile("{\n" | ||
| ".reg .pred p;\n" | ||
| "setp.ne.b32 p, %4, 0;\n" | ||
| "wgmma.mma_async.sync.aligned.m64n8k32.f16.e4m3.e4m3 " | ||
| "{%0, %1}, %2, %3, p, %5, %6, %7, %8;\n" | ||
| "}\n" | ||
| : "+r"(c[0]), "+r"(c[1]) | ||
| : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)), | ||
| "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), | ||
| "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); | ||
| } | ||
| }; | ||
|
|
||
| // M64N8K32 E4M3->F32 | ||
| template <bool tnspA, bool tnspB, int scaleA, int scaleB> | ||
| struct WgmmaSSImpl<DataType::kFloat8_e4m3, DataType::kFloat8_e4m3, | ||
| DataType::kFloat32, 64, 8, 32, tnspA, tnspB, scaleA, | ||
| scaleB> { | ||
| TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c, | ||
| bool scale_out) { | ||
| asm volatile("{\n" | ||
| ".reg .pred p;\n" | ||
| "setp.ne.b32 p, %6, 0;\n" | ||
| "wgmma.mma_async.sync.aligned.m64n8k32.f32.e4m3.e4m3 " | ||
| "{%0, %1, %2, %3}, %4, %5, p, %7, %8, %9, %10;\n" | ||
| "}\n" | ||
| : "+r"(c[0]), "+r"(c[1]), "+r"(c[2]), "+r"(c[3]) | ||
| : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)), | ||
| "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), | ||
| "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); | ||
| } | ||
| }; | ||
|
|
||
| // Function template that delegates to the class template | ||
| template <DataType A_type, DataType B_type, DataType C_type, int M, int N, | ||
| int K, bool tnspA, bool tnspB, int scaleA = 1, int scaleB = 1> | ||
| TL_DEVICE void wgmma_ss(uint64_t desc_a, uint64_t desc_b, uint32_t *c, | ||
| bool scale_out) { | ||
| WgmmaSSImpl<A_type, B_type, C_type, M, N, K, tnspA, tnspB, scaleA, | ||
| scaleB>::execute(desc_a, desc_b, c, scale_out); | ||
| } | ||
|
|
||
| // ================================= Mixed Precision Support | ||
| // ================================= | ||
|
|
||
| // Mixed precision: S8 x U8 -> S32 | ||
| template <bool tnspA, bool tnspB, int scaleA, int scaleB> | ||
| struct WgmmaSSImpl<DataType::kInt8, DataType::kUInt8, DataType::kInt32, 64, 8, | ||
| 32, tnspA, tnspB, scaleA, scaleB> { | ||
| TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c, | ||
| bool scale_out) { | ||
| asm volatile("{\n" | ||
| ".reg .pred p;\n" | ||
| "setp.ne.b32 p, %4, 0;\n" | ||
| "wgmma.mma_async.sync.aligned.m64n8k32.s32.s8.u8 " | ||
| "{%0, %1}, %2, %3, p, %5, %6, %7, %8;\n" | ||
| "}\n" | ||
| : "+r"(c[0]), "+r"(c[1]) | ||
| : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)), | ||
| "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), | ||
| "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); | ||
| } | ||
| }; | ||
|
|
||
| // Mixed precision: U8 x S8 -> S32 | ||
| template <bool tnspA, bool tnspB, int scaleA, int scaleB> | ||
| struct WgmmaSSImpl<DataType::kUInt8, DataType::kInt8, DataType::kInt32, 64, 8, | ||
| 32, tnspA, tnspB, scaleA, scaleB> { | ||
| TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c, | ||
| bool scale_out) { | ||
| asm volatile("{\n" | ||
| ".reg .pred p;\n" | ||
| "setp.ne.b32 p, %4, 0;\n" | ||
| "wgmma.mma_async.sync.aligned.m64n8k32.s32.u8.s8 " | ||
| "{%0, %1}, %2, %3, p, %5, %6, %7, %8;\n" | ||
| "}\n" | ||
| : "+r"(c[0]), "+r"(c[1]) | ||
| : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)), | ||
| "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), | ||
| "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); | ||
| } | ||
| }; | ||
|
|
||
| // Mixed precision: U8 x U8 -> S32 | ||
| template <bool tnspA, bool tnspB, int scaleA, int scaleB> | ||
| struct WgmmaSSImpl<DataType::kUInt8, DataType::kUInt8, DataType::kInt32, 64, 8, | ||
| 32, tnspA, tnspB, scaleA, scaleB> { | ||
| TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c, | ||
| bool scale_out) { | ||
| asm volatile("{\n" | ||
| ".reg .pred p;\n" | ||
| "setp.ne.b32 p, %4, 0;\n" | ||
| "wgmma.mma_async.sync.aligned.m64n8k32.s32.u8.u8 " | ||
| "{%0, %1}, %2, %3, p, %5, %6, %7, %8;\n" | ||
| "}\n" | ||
| : "+r"(c[0]), "+r"(c[1]) | ||
| : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)), | ||
| "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), | ||
| "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); | ||
| } | ||
| }; | ||
|
|
||
| // Mixed precision FP8: E4M3 x E5M2 -> F16 | ||
| template <bool tnspA, bool tnspB, int scaleA, int scaleB> | ||
| struct WgmmaSSImpl<DataType::kFloat8_e4m3, DataType::kFloat8_e5m2, | ||
| DataType::kFloat16, 64, 8, 32, tnspA, tnspB, scaleA, | ||
| scaleB> { | ||
| TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c, | ||
| bool scale_out) { | ||
| asm volatile("{\n" | ||
| ".reg .pred p;\n" | ||
| "setp.ne.b32 p, %4, 0;\n" | ||
| "wgmma.mma_async.sync.aligned.m64n8k32.f16.e4m3.e5m2 " | ||
| "{%0, %1}, %2, %3, p, %5, %6, %7, %8;\n" | ||
| "}\n" | ||
| : "+r"(c[0]), "+r"(c[1]) | ||
| : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)), | ||
| "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), | ||
| "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); | ||
| } | ||
| }; | ||
|
|
||
| // Mixed precision FP8: E5M2 x E4M3 -> F16 | ||
| template <bool tnspA, bool tnspB, int scaleA, int scaleB> | ||
| struct WgmmaSSImpl<DataType::kFloat8_e5m2, DataType::kFloat8_e4m3, | ||
| DataType::kFloat16, 64, 8, 32, tnspA, tnspB, scaleA, | ||
| scaleB> { | ||
| TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c, | ||
| bool scale_out) { | ||
| asm volatile("{\n" | ||
| ".reg .pred p;\n" | ||
| "setp.ne.b32 p, %4, 0;\n" | ||
| "wgmma.mma_async.sync.aligned.m64n8k32.f16.e5m2.e4m3 " | ||
| "{%0, %1}, %2, %3, p, %5, %6, %7, %8;\n" | ||
| "}\n" | ||
| : "+r"(c[0]), "+r"(c[1]) | ||
| : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)), | ||
| "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), | ||
| "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); | ||
| } | ||
| }; | ||
|
|
||
| // ================================= Convenience Templates | ||
| // ================================= | ||
|
|
||
| // Type trait to determine the number of output registers needed | ||
| template <DataType C_type, int M, int N> struct WgmmaOutputRegs { | ||
| static constexpr int value = | ||
| (M * N * (C_type == DataType::kFloat32 ? 32 : 16)) / (32 * 8); | ||
| }; | ||
|
|
||
| // Type trait to get element size in bits | ||
| template <DataType dtype> struct ElementBits { | ||
| static constexpr int value = | ||
| (dtype == DataType::kFloat32 || dtype == DataType::kTensorFloat32 || | ||
| dtype == DataType::kInt32) | ||
| ? 32 | ||
| : (dtype == DataType::kFloat16 || dtype == DataType::kBFloat16 || | ||
| dtype == DataType::kInt16 || dtype == DataType::kUInt16) | ||
| ? 16 | ||
| : (dtype == DataType::kInt8 || dtype == DataType::kUInt8 || | ||
| dtype == DataType::kFloat8_e4m3 || dtype == DataType::kFloat8_e5m2) | ||
| ? 8 | ||
| : (dtype == DataType::kInt4 || dtype == DataType::kUInt4) ? 4 | ||
| : 8; | ||
| }; | ||
|
|
🛠️ Refactor suggestion | 🟠 Major
🧩 Analysis chain
Verify tile size coverage is intentional.
Several data type combinations only support small tile sizes:
- BF16 → F32: only 64×8×16 and 64×16×16 (missing 64×32×16, 64×64×16, etc.)
- TF32 → F32: only 64×8×8 and 64×16×8
- INT8 → INT32: only 64×8×32 and 64×16×32
- FP8 variants: only 64×8×32
In contrast, F16 → F16 has full coverage up to 64×256×16. If users attempt unsupported tile sizes, they'll hit the fallback template with only a printf warning (since static_assert is commented out). Confirm whether:
- This limited coverage matches hardware capabilities (NVIDIA Hopper/Ada only supports certain tile configs per data type)
- The fallback should fail at compile time for missing specializations
- Additional specializations should be added
Run this script to check if other parts of the codebase rely on unsupported tile sizes:
🏁 Script executed:
#!/bin/bash
# Description: Search for wgmma_ss calls that might use unsupported tile configurations
# Search for wgmma_ss template instantiations
rg -nP 'wgmma_ss\s*<' -A2 -B2
# Search for BF16, TF32, INT8, FP8 usage patterns
rg -nP '(kBFloat16|kTensorFloat32|kInt8|kFloat8)' --type=cpp -C3

Length of output: 43439
Align WgmmaSSImpl specializations with supported tile configurations
Codegen registers dense WGMMA shapes for BF16, TF32, INT8, and FP8 with N up to 256, but WgmmaSSImpl only defines N=8 and N=16 (and only N=8 for FP8). Instantiations for N>16 will silently hit the fallback. Either add the missing specializations for the full shape set, or restrict codegen to the supported shapes and re-enable the compile-time asserts for unsupported cases.
🤖 Prompt for AI Agents
In src/tl_templates/cuda/instruction/wgmma.h around lines 342–646, the review
notes that codegen emits WMMA shapes with N up to 256 but WgmmaSSImpl only
provides specializations for N=8 and N=16 (and N=8 for FP8), so instantiations
for larger N silently fall back; fix by either (A) adding WgmmaSSImpl
specializations that match every tile configuration codegen emits (for each
DataType combination and N values up to 256, mirroring the existing patterns for
M/K variants and ensuring the asm constraints/operand counts match each N), or
(B) change the codegen to only emit the N values already supported (8,16) and
re-enable/restore compile-time static_asserts that reject unsupported shapes so
generation fails fast; choose one approach and implement it consistently across
BF16, TF32, INT8, FP8 and mixed-precision variants.
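If option (B) is taken, a minimal sketch of the compile-time guard; the names mirror the header, but this is illustrative, not the committed code:

```cpp
#include <cstdint>

// Dependent-false helper so the assert only fires when the fallback is
// actually instantiated, not when the template is merely parsed.
template <int> struct dependent_false { static constexpr bool value = false; };

// Primary template: any (M, N, K) combination without an explicit
// specialization fails at compile time instead of warning at runtime.
template <int M, int N, int K>
struct WgmmaSSImplSketch {
  static void execute(uint64_t, uint64_t, uint32_t *, bool) {
    static_assert(dependent_false<N>::value,
                  "unsupported WGMMA tile shape: add a specialization or "
                  "restrict codegen to the supported N values");
  }
};

// Example specialization for a supported shape (body elided).
template <> struct WgmmaSSImplSketch<64, 8, 16> {
  static void execute(uint64_t, uint64_t, uint32_t *, bool) { /* inline asm */ }
};

int main() {
  uint32_t c[4] = {};
  WgmmaSSImplSketch<64, 8, 16>::execute(0, 0, c, true);      // compiles
  // WgmmaSSImplSketch<64, 32, 16>::execute(0, 0, c, true);  // would not
  return 0;
}
```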
| template <DataType dtype> struct ElementBits { | ||
| static constexpr int value = | ||
| (dtype == DataType::kFloat32 || dtype == DataType::kTensorFloat32 || | ||
| dtype == DataType::kInt32) | ||
| ? 32 | ||
| : (dtype == DataType::kFloat16 || dtype == DataType::kBFloat16 || | ||
| dtype == DataType::kInt16 || dtype == DataType::kUInt16) | ||
| ? 16 | ||
| : (dtype == DataType::kInt8 || dtype == DataType::kUInt8 || | ||
| dtype == DataType::kFloat8_e4m3 || dtype == DataType::kFloat8_e5m2) | ||
| ? 8 | ||
| : (dtype == DataType::kInt4 || dtype == DataType::kUInt4) ? 4 | ||
| : 8; | ||
| }; |
Handle kBit* types in the ElementBits trait.
The ElementBits trait's default case returns 8 bits, but DataType includes kBit16, kBit32, and kBit64 variants (from common.h lines 170-224) that should return their respective bit widths, not the default 8.
Apply this diff to handle kBit* types correctly:
template <DataType dtype> struct ElementBits {
static constexpr int value =
(dtype == DataType::kFloat32 || dtype == DataType::kTensorFloat32 ||
dtype == DataType::kInt32)
? 32
: (dtype == DataType::kFloat16 || dtype == DataType::kBFloat16 ||
- dtype == DataType::kInt16 || dtype == DataType::kUInt16)
+ dtype == DataType::kInt16 || dtype == DataType::kUInt16 ||
+ dtype == DataType::kBit16)
? 16
: (dtype == DataType::kInt8 || dtype == DataType::kUInt8 ||
- dtype == DataType::kFloat8_e4m3 || dtype == DataType::kFloat8_e5m2)
+ dtype == DataType::kFloat8_e4m3 || dtype == DataType::kFloat8_e5m2 ||
+ dtype == DataType::kBit8)
? 8
: (dtype == DataType::kInt4 || dtype == DataType::kUInt4) ? 4
+ : (dtype == DataType::kBit32) ? 32
+ : (dtype == DataType::kBit64) ? 64
+ : (dtype == DataType::kBit1) ? 1
: 8;
};📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
Suggested replacement (kBit* branches added):

template <DataType dtype> struct ElementBits {
  static constexpr int value =
      (dtype == DataType::kFloat32 || dtype == DataType::kTensorFloat32 ||
       dtype == DataType::kInt32)
          ? 32
      : (dtype == DataType::kFloat16 || dtype == DataType::kBFloat16 ||
         dtype == DataType::kInt16 || dtype == DataType::kUInt16 ||
         dtype == DataType::kBit16)
          ? 16
      : (dtype == DataType::kInt8 || dtype == DataType::kUInt8 ||
         dtype == DataType::kFloat8_e4m3 || dtype == DataType::kFloat8_e5m2 ||
         dtype == DataType::kBit8)
          ? 8
      : (dtype == DataType::kInt4 || dtype == DataType::kUInt4) ? 4
      : (dtype == DataType::kBit32)                             ? 32
      : (dtype == DataType::kBit64)                             ? 64
      : (dtype == DataType::kBit1)                              ? 1
                                                                : 8;
};
🤖 Prompt for AI Agents
In src/tl_templates/cuda/instruction/wgmma.h around lines 632 to 645, the
ElementBits trait currently falls back to 8 for unknown types but omits handling
DataType::kBit16, kBit32, and kBit64; update the conditional chain to explicitly
return 16 for kBit16, 32 for kBit32, and 64 for kBit64 (keeping the existing
mappings for float/int variants and retaining an 8-bit fallback) so those
bit-oriented DataType variants return their correct bit widths.
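A quick compile-time check of the intended widths (a standalone sketch with a trimmed stand-in enum, not the real common.h):

```cpp
// Trimmed stand-in for the DataType enum from common.h; only the members
// needed for the check are included.
enum class DataType { kFloat32, kFloat16, kInt8, kBit16, kBit32, kBit64 };

template <DataType dtype> struct ElementBitsSketch {
  static constexpr int value =
      (dtype == DataType::kFloat32 || dtype == DataType::kBit32)   ? 32
      : (dtype == DataType::kFloat16 || dtype == DataType::kBit16) ? 16
      : (dtype == DataType::kBit64)                                ? 64
                                                                   : 8;
};

// With the kBit* branches in place, these pass; with the 8-bit fallback
// swallowing them, the first two would fail.
static_assert(ElementBitsSketch<DataType::kBit16>::value == 16, "kBit16");
static_assert(ElementBitsSketch<DataType::kBit32>::value == 32, "kBit32");
static_assert(ElementBitsSketch<DataType::kBit64>::value == 64, "kBit64");

int main() { return 0; }
```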
Actionable comments posted: 1
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
examples/deepseek_v32/test_tilelang_example_deepseek_v32.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
examples/deepseek_v32/test_tilelang_example_deepseek_v32.py (1)
examples/deepseek_v32/sparse_mla_bwd.py (1)
test_sparse_mla_bwd(334-383)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: build-test-metal
- GitHub Check: format-check
🔇 Additional comments (1)
examples/deepseek_v32/test_tilelang_example_deepseek_v32.py (1)
36-37: LGTM once the import is fixed.
Making the function call explicit with keyword arguments improves clarity and aligns with the updated APIs mentioned in the PR objectives. The use of SKV=4096 instead of the default 32768 is reasonable for testing purposes (smaller, faster execution).
test_sparse_mla_bwd(
    B=1, S=4096, SKV=4096, H=64, HKV=1, DQKV=576, DV=512, topk=2048, dtype=torch.bfloat16)
Missing import for torch.
The code references torch.bfloat16 on line 37, but the torch module is not imported in this file. This will cause a NameError at runtime when the test is executed.
Apply this diff to add the missing import:
# ruff: noqa
+import torch
import tilelang.testing📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
Resulting file head with the suggestion applied:

# ruff: noqa
import torch
import tilelang.testing

test_sparse_mla_bwd(
    B=1, S=4096, SKV=4096, H=64, HKV=1, DQKV=576, DV=512, topk=2048, dtype=torch.bfloat16)
🤖 Prompt for AI Agents
In examples/deepseek_v32/test_tilelang_example_deepseek_v32.py around lines 36
to 37, the test calls torch.bfloat16 but the torch module is not imported; add
an import statement for torch at the top of the file (e.g., import torch) so
torch.bfloat16 resolves, keeping import placement consistent with existing
imports.
Actionable comments posted: 1
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
examples/deepseek_v32/test_tilelang_example_deepseek_v32.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
examples/deepseek_v32/test_tilelang_example_deepseek_v32.py (1)
examples/deepseek_v32/sparse_mla_bwd.py (1)
test_sparse_mla_bwd(334-383)
🪛 GitHub Actions: CI Test on AMD
examples/deepseek_v32/test_tilelang_example_deepseek_v32.py
[error] 1-1: Reformatted files. Please review and stage the changes.
🪛 GitHub Actions: CI Test on Metal
examples/deepseek_v32/test_tilelang_example_deepseek_v32.py
[error] 1-1: Code formatting check failed. Reformatted files. Please review and stage the changes. Changes not staged for commit: examples/deepseek_v32/test_tilelang_example_deepseek_v32.py
…checks

- Changed SKV parameter from 32768 to 8192 in sparse MLA backward and forward tests.
- Added check_correctness parameter to test functions for validation of outputs.
- Updated test cases to reflect new SKV values and correctness checks.
TODO Items

- T.alloc_descriptor to create a static descriptor, allowing only start_address updates within tiled GEMM (see the sketch below).
- Support for int8, fp8, tf32.
- Support for gemm_rs.
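For the descriptor item, a rough sketch of the intended flow; T.alloc_descriptor does not exist yet, so its name and the argument shapes here are placeholders for the design being described, while initialize_descriptor and increase_descriptor_offset are the builtins already added in this PR:

```python
# Hypothetical sketch of the planned descriptor reuse inside a tiled GEMM.
# `T.alloc_descriptor()` is the TODO item, not current API; the offset
# argument to `increase_descriptor_offset` is likewise an assumption.
desc_a = T.alloc_descriptor()              # build the static descriptor once
T.initialize_descriptor(desc_a, A_shared)  # encode swizzle/layout once

for ko in T.serial(K // block_K):
    # Only the start address should move between K-iterations; the swizzle
    # mode and stride fields of the descriptor stay fixed.
    T.increase_descriptor_offset(desc_a, block_K)
    ...
```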
New Features
API Changes
Chores