Skip to content

Conversation

@LeiWang1999
Copy link
Member

@LeiWang1999 LeiWang1999 commented Nov 6, 2025

  • Introduced TILELANG_USE_GEMM_V1 environment variable to control the selection of GEMM version.
  • Added use_gemm_v1 method in the Environment class to determine if GEMM v1 should be used based on the environment variable.
  • Updated GEMM function assignment to default to v2, allowing for v1 to be forced via the new environment variable.

Summary by CodeRabbit

  • New Features

    • Added support for selecting GEMM kernel versions via TILELANG_USE_GEMM_V1 environment variable.
    • Added FP64 (float64) support for MMA operations with corresponding store layout optimizations.
  • Bug Fixes

    • Improved shared memory allocation and alignment handling for better memory efficiency.
  • Refactor

    • Refactored GEMM implementation to use unified buffer region handling across all backends.
    • Improved backend configuration system in CMake for better maintainability.

- Introduced `TILELANG_USE_GEMM_V1` environment variable to control the selection of GEMM version.
- Added `use_gemm_v1` method in the `Environment` class to determine if GEMM v1 should be used based on the environment variable.
- Updated GEMM function assignment to default to v2, allowing for v1 to be forced via the new environment variable.
@github-actions
Copy link

github-actions bot commented Nov 6, 2025

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run pre-commit run --all-files in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀

@coderabbitai
Copy link
Contributor

coderabbitai bot commented Nov 6, 2025

Walkthrough

This pull request introduces environment-driven GEMM kernel selection, refactors the GEMM infrastructure to normalize buffers to a unified BufferRegion representation, renames data members and methods to a consistent underscore-suffixed naming convention, adds region-aware utility functions, and enables FP64 MMA support while removing the BufferGemmCollector class.

Changes

Cohort / File(s) Summary
GEMM Kernel Selection
tilelang/env.py
Adds TILELANG_USE_GEMM_V1 environment variable and corresponding use_gemm_v1() reader method for conditional kernel selection.
GEMM Core Refactoring
tilelang/language/gemm.py, src/op/gemm.cc, src/op/gemm.h
Refactors GEMM to normalize A/B/C inputs to BufferRegion; centralizes logic via _gemm_impl(); renames public methods (AllowTCGEN5MMA → allowTcgen5Mma, GetGemmInst → getGemmInst, etc.) and data members (A/B/C → a_/b_/c_, trans_A → transA_, M/N/K → m_/n_/k_, etc.); adds region-based buffer handling and unified pointer/offset computation.
GemmPy Refactoring
src/op/gemm_py.cc, src/op/gemm_py.h
Renames public methods and data members following underscore convention; replaces direct pointer handling with region-based buffers (aRegion_/bRegion_/cRegion_); updates constructor and layout inference to use new member names and region normalization.
GemmSP Refactoring
src/op/gemm_sp.cc, src/op/gemm_sp.h
Renames ComputeWarpPartition to computeWarpPartition; updates all references to A/B/C/E buffers and M/N/K dimensions to underscore-suffixed variants (a_/b_/c_/e_, m_/n_/k_); adjusts Lower() and layout paths accordingly.
Intrinsics Region Support
tilelang/intrinsics/mfma_macro_generator.py, tilelang/intrinsics/mma_macro_generator.py, tilelang/intrinsics/mma_sm70_macro_generator.py, tilelang/intrinsics/wgmma_macro_generator.py
Extends ldmatrix and MMA signatures to accept `Buffer
FP64 MMA Support
tilelang/intrinsics/mma_macro_generator.py, src/tl_templates/cuda/instruction/mma.h, tilelang/intrinsics/mma_layout.py, tilelang/intrinsics/utils.py
Adds FP64 dtype abbreviation; introduces mma_store_index_map_fp64 and related layout functions; adds MmaDispatcher specialization for 8×8×4 FP64 DMMA.
Region & Buffer Utilities
tilelang/utils/language.py
Adds BufferLoad/BufferRegion support to existing scope checks; introduces to_buffer_region(), retrieve_shape(), retrieve_stride(), retrieve_offset(), retrieve_ptr(), retrive_ptr_from_buffer_region(), prim_expr_equal(), and is_full_region() helpers.
TILEOP/GEMM Wrappers
tilelang/tileop/gemm/__init__.py, tilelang/tileop/gemm/gemm_base.py, tilelang/tileop/gemm/gemm_mfma.py, tilelang/tileop/gemm/gemm_mma.py, tilelang/tileop/gemm/gemm_mma_sm70.py, tilelang/tileop/gemm/gemm_sp.py, tilelang/tileop/gemm/gemm_tcgen05.py, tilelang/tileop/gemm/gemm_wgmma.py
Introduces property-based backward-compatibility layer with internal lowerCamel fields; updates region accessors (ARegion/BRegion/CRegion); adds region-based operand handling across all GEMM variants.
Shared Memory & Layout
tilelang/layout/swizzle.py, tilelang/language/builtin.py, tilelang/intrinsics/tcgen05_macro_generator.py
Extends layout creators to accept `Buffer
Operator Infrastructure
src/op/operator.h, src/transform/lower_tile_op.cc, tilelang/utils/__init__.py
Removes buffer_var_gemm from LowerArgs and eliminates BufferGemmCollector class; exports new region/buffer utilities (retrieve_stride, retrieve_shape, retrive_ptr_from_buffer_region, is_full_region, to_buffer_region).
Code Generation & Transforms
src/target/codegen_hip.cc, src/transform/storage_rewrite.cc, src/transform/merge_shared_memory_allocations.cc, src/transform/lower_opaque_block.cc, src/tl_templates/cuda/reduce.h
Updates bfloat16x4 type mapping in HIP codegen; adds metadata merging in storage rewrite; enhances shared memory alignment marking; adds AccType trait for wider accumulator selection; introduces IsEffectivelyEmptyAnnotation helper.
CMake & Build Configuration
CMakeLists.txt, .gitignore
Introduces tilelang_define_backend_option macro and per-backend caching (USE_CUDA/ROCM/METAL) with user-override tracking; adds pre-commit cache ignore rules.
Test & Maintenance Files
examples/warp_specialize/example_warp_specialize_gemm_barrierpipe_stage2.py, maint/gemm_v2/correctness_evaluation.py, maint/gemm_v2/correctness_evaluation_sm70.py, testing/python/dynamic/test_tilelang_dynamic_symbolic.py, examples/linear_attention/example_linear_attn_fwd.py, examples/attention_sink/example_*.py, examples/deepseek_v32/test_tilelang_example_deepseek_v32.py
Adds disable_cache() calls and kernel source printing; adjusts GEMM defaults (num_stages 3→2) with boundary-based overrides; removes test parameterization (N_VALUES); simplifies loop bounds to inline scalar expressions; updates import styles in tests.

Sequence Diagram(s)

sequenceDiagram
    participant User
    participant TileLang Runtime
    participant Environment
    participant GEMM Layer
    
    User->>Environment: Check TILELANG_USE_GEMM_V1
    Environment-->>TileLang Runtime: Returns boolean (true/false)
    TileLang Runtime->>GEMM Layer: Resolve gemm to gemm_v1 or gemm_v2
    GEMM Layer->>GEMM Layer: Normalize A/B/C to BufferRegion via to_buffer_region()
    GEMM Layer->>GEMM Layer: Extract shape, stride, offset from region
    GEMM Layer->>GEMM Layer: Call _gemm_impl(op_key, A, B, C, ...)
    GEMM Layer-->>User: Execute selected GEMM kernel
Loading

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~75 minutes

Areas requiring extra attention:

  • BufferRegion normalization logic (src/op/gemm.cc, src/op/gemm_py.cc) — Validate NormalizeToBufferRegion and MakeAccessPtrFromRegion implementations for correctness across different buffer input forms
  • Method/field renaming consistency (src/op/gemm.h, src/op/gemm_py.h, src/op/gemm_sp.h) — Ensure all internal references and reflection bindings are updated consistently; verify backward-compatibility properties in TILEOP wrappers
  • Region-based intrinsics (tilelang/intrinsics/wgmma_macro_generator.py, tilelang/intrinsics/mma_macro_generator.py) — Validate pointer derivation from regions and offset calculations in ldmatrix/MMA paths
  • Environment-driven dispatch (tilelang/language/gemm.py, tilelang/env.py) — Verify TILELANG_USE_GEMM_V1 env var parsing and conditional routing between gemm_v1/gemm_v2 works as intended
  • FP64 MMA support (tilelang/intrinsics/mma_macro_generator.py) — Confirm FP64-specific tile sizing (8×8), store layouts, and micro-size adjustments are correct

Possibly related PRs

Suggested reviewers

  • chengyupku
  • xysmlx

Poem

🐰 With regions unified and buffers set free,
The GEMM now selects which kernel to be.
V1 or V2 dance at the env's command,
Normalization's grace spans memory's land.
Underscore suffixes march in neat parade—
A rabbit's refactor, cleanly made! 🌟

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 26.07% which is insufficient. The required threshold is 80.00%. You can run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (2 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title clearly and specifically describes the main change: adding an environment-controlled option for GEMM kernel selection, which aligns with the PR's primary objective of introducing TILELANG_USE_GEMM_V1.
✨ Finishing touches
  • 📝 Generate docstrings
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Post copyable unit tests in a comment

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

@LeiWang1999 LeiWang1999 marked this pull request as ready for review November 6, 2025 05:46
- Introduced `TILELANG_USE_GEMM_V1` environment variable to allow users to select between GEMM v1 and v2 implementations.
- Updated `gemm` function to default to v2 but switch to v1 if the environment variable is set to a truthy value.
- Added a method `use_gemm_v1` in the `Environment` class to facilitate this selection based on the environment variable.
Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 5

🧹 Nitpick comments (3)
tilelang/language/builtin.py (2)

460-491: Consider extracting stride calculation to reduce duplication.

The stride calculation logic (lines 469-478) is duplicated in the BufferRegion branch (lines 517-526). Additionally, the control flow is inconsistent: BufferLoad and BufferRegion branches return early, while Buffer and pointer branches fall through to a common return.

Consider extracting the stride calculation into a helper function:

def _compute_element_offset(indices_or_mins, buffer):
    """Compute element offset from indices/mins using buffer strides or row-major fallback."""
    if len(buffer.strides) == len(buffer.shape) and len(buffer.strides) > 0:
        elem_off = 0
        for idx, stride in zip(indices_or_mins, buffer.strides):
            elem_off = elem_off + idx * stride
    else:
        elem_off = 0
        stride_acc = 1
        for idx, dim in zip(reversed(indices_or_mins), reversed(buffer.shape)):
            elem_off = elem_off + idx * stride_acc
            stride_acc = stride_acc * dim
    return elem_off

Then use it in both branches:

elem_off = _compute_element_offset(buffer_or_ptr.indices, buf)

For consistency, consider restructuring all branches to either use early returns or fall through to a common return statement.


555-577: Consider narrowing the exception handling scope.

The blind exception catch on line 572 is flagged by static analysis. While this defensive pattern is reasonable for probing optional attributes during type inference, consider catching more specific exceptions like AttributeError:

-            except Exception:
+            except AttributeError:
                 inferred = None

This makes the code's intent clearer and avoids accidentally silencing unexpected errors.

tilelang/intrinsics/tcgen05_macro_generator.py (1)

249-278: Prefer TypeError for invalid type errors.

The access_ptr_from helper correctly handles Buffer, BufferLoad, and BufferRegion, but should raise TypeError instead of ValueError when encountering unsupported types (lines 266, 278).

Apply this diff:

                 else:
-                    raise ValueError(f"Unsupported index type: {type(indice)}")
+                    raise TypeError(f"Unsupported index type: {type(indice)}")
                 stride *= shape
             return buffer.access_ptr(access_type, offset=offset)
         else:
-            raise ValueError(f"Unsupported buffer type: {type(buffer_or_load_or_region)}")
+            raise TypeError(f"Unsupported buffer type: {type(buffer_or_load_or_region)}")
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 7089b00 and 03af3e7.

📒 Files selected for processing (25)
  • examples/warp_specialize/example_warp_specialize_gemm_barrierpipe_stage2.py (2 hunks)
  • src/op/gemm.cc (10 hunks)
  • src/op/gemm.h (3 hunks)
  • src/op/gemm_py.cc (7 hunks)
  • src/op/gemm_py.h (2 hunks)
  • src/op/gemm_sp.cc (4 hunks)
  • src/op/gemm_sp.h (3 hunks)
  • src/op/operator.h (0 hunks)
  • src/transform/lower_tile_op.cc (2 hunks)
  • tilelang/intrinsics/mfma_macro_generator.py (7 hunks)
  • tilelang/intrinsics/mma_macro_generator.py (9 hunks)
  • tilelang/intrinsics/mma_sm70_macro_generator.py (6 hunks)
  • tilelang/intrinsics/tcgen05_macro_generator.py (2 hunks)
  • tilelang/intrinsics/wgmma_macro_generator.py (4 hunks)
  • tilelang/language/builtin.py (3 hunks)
  • tilelang/language/gemm.py (5 hunks)
  • tilelang/layout/swizzle.py (5 hunks)
  • tilelang/tileop/gemm/__init__.py (1 hunks)
  • tilelang/tileop/gemm/gemm_base.py (2 hunks)
  • tilelang/tileop/gemm/gemm_mfma.py (4 hunks)
  • tilelang/tileop/gemm/gemm_mma.py (4 hunks)
  • tilelang/tileop/gemm/gemm_mma_sm70.py (3 hunks)
  • tilelang/tileop/gemm/gemm_tcgen05.py (1 hunks)
  • tilelang/tileop/gemm/gemm_wgmma.py (3 hunks)
  • tilelang/utils/language.py (2 hunks)
💤 Files with no reviewable changes (1)
  • src/op/operator.h
🧰 Additional context used
🧠 Learnings (2)
📚 Learning: 2025-09-15T10:51:06.985Z
Learnt from: botbw
Repo: tile-ai/tilelang PR: 691
File: src/tl_templates/cuda/gemm_sp_sm80.h:81-85
Timestamp: 2025-09-15T10:51:06.985Z
Learning: In CUTLASS tensor operation layouts, crosswise constants should be computed using sizeof(T) (bytes), not cutlass::sizeof_bits<T>::value (bits). This is the established pattern in the official CUTLASS codebase, as seen in default_mma_core_sparse_sm80.h.

Applied to files:

  • src/op/gemm.cc
📚 Learning: 2025-09-15T10:51:06.985Z
Learnt from: botbw
Repo: tile-ai/tilelang PR: 691
File: src/tl_templates/cuda/gemm_sp_sm80.h:81-85
Timestamp: 2025-09-15T10:51:06.985Z
Learning: In CUTLASS tensor operation layouts, crosswise constants should be computed using sizeof(T) (bytes), not cutlass::sizeof_bits<T>::value (bits). However, the layout template parameter should use sizeof_bits<T>::value (bits). This is the established pattern in the official CUTLASS codebase, as seen in default_mma_core_sparse_sm80.h where Crosswise uses sizeof(ElementA) but the layout template uses sizeof_bits<ElementA>::value.

Applied to files:

  • src/op/gemm.cc
🧬 Code graph analysis (23)
tilelang/tileop/gemm/gemm_tcgen05.py (1)
tilelang/tileop/gemm/gemm_base.py (2)
  • ARegion (79-80)
  • BRegion (83-84)
tilelang/intrinsics/mma_sm70_macro_generator.py (2)
tilelang/intrinsics/mfma_macro_generator.py (3)
  • _warp_ldmatrix_a (278-301)
  • ldmatrix_b (305-358)
  • ldmatrix_b (795-870)
tilelang/intrinsics/mma_macro_generator.py (6)
  • _warp_ldmatrix_a (249-285)
  • _warp_ldmatrix_a (808-896)
  • ldmatrix_b (289-376)
  • ldmatrix_b (900-1012)
  • mma_load_layout (223-224)
  • mma_load_layout (307-308)
tilelang/intrinsics/mfma_macro_generator.py (2)
tilelang/intrinsics/mma_macro_generator.py (6)
  • ldmatrix_a (207-287)
  • ldmatrix_a (794-898)
  • _warp_ldmatrix_a (249-285)
  • _warp_ldmatrix_a (808-896)
  • ldmatrix_b (289-376)
  • ldmatrix_b (900-1012)
tilelang/intrinsics/mma_sm70_macro_generator.py (3)
  • ldmatrix_a (190-236)
  • _warp_ldmatrix_a (220-234)
  • ldmatrix_b (238-290)
tilelang/tileop/gemm/gemm_mma.py (4)
tilelang/tileop/gemm/gemm_base.py (4)
  • ARegion (79-80)
  • A (67-68)
  • BRegion (83-84)
  • B (71-72)
tilelang/utils/language.py (1)
  • is_shared (47-62)
tilelang/intrinsics/mfma_macro_generator.py (2)
  • ldmatrix_b (305-358)
  • ldmatrix_b (795-870)
tilelang/intrinsics/mma_macro_generator.py (2)
  • ldmatrix_b (289-376)
  • ldmatrix_b (900-1012)
tilelang/language/gemm.py (4)
src/op/builtin.h (1)
  • tvm (13-575)
tilelang/utils/language.py (1)
  • get_buffer_region_from_load (164-186)
tilelang/env.py (2)
  • get (175-178)
  • use_gemm_v1 (281-287)
tilelang/primitives/gemm/__init__.py (1)
  • gemm (10-46)
tilelang/tileop/gemm/__init__.py (1)
tilelang/tileop/gemm/gemm_base.py (15)
  • A (67-68)
  • B (71-72)
  • C (75-76)
  • M (34-35)
  • N (38-39)
  • K (42-43)
  • trans_A (46-47)
  • trans_B (50-51)
  • stride_A (91-92)
  • stride_B (95-96)
  • offset_A (99-100)
  • offset_B (103-104)
  • clear_accum (107-108)
  • k_pack (111-112)
  • wg_wait (115-116)
examples/warp_specialize/example_warp_specialize_gemm_barrierpipe_stage2.py (1)
tilelang/env.py (1)
  • disable_cache (275-276)
src/op/gemm_py.cc (4)
tilelang/language/utils.py (1)
  • region (8-27)
tilelang/language/tir/op.py (1)
  • tvm_access_ptr (651-676)
src/op/gemm.cc (2)
  • NormalizeToBufferRegion (52-97)
  • NormalizeToBufferRegion (52-52)
src/op/operator.cc (2)
  • GetVarFromAccessPtr (74-81)
  • GetVarFromAccessPtr (74-74)
tilelang/tileop/gemm/gemm_mma_sm70.py (5)
tilelang/tileop/gemm/gemm_base.py (4)
  • ARegion (79-80)
  • A (67-68)
  • BRegion (83-84)
  • B (71-72)
tilelang/utils/language.py (1)
  • is_shared (47-62)
tilelang/intrinsics/mfma_macro_generator.py (2)
  • ldmatrix_b (305-358)
  • ldmatrix_b (795-870)
tilelang/intrinsics/mma_macro_generator.py (2)
  • ldmatrix_b (289-376)
  • ldmatrix_b (900-1012)
tilelang/intrinsics/mma_sm70_macro_generator.py (1)
  • ldmatrix_b (238-290)
src/op/gemm_sp.cc (4)
src/op/gemm.cc (2)
  • computeWarpPartition (222-402)
  • computeWarpPartition (222-223)
src/target/utils.cc (6)
  • TargetGetWarpSize (130-135)
  • TargetGetWarpSize (130-130)
  • TargetIsHopper (52-57)
  • TargetIsHopper (52-52)
  • TargetIsAmpere (45-50)
  • TargetIsAmpere (45-45)
src/op/operator.cc (2)
  • GetVarFromAccessPtr (74-81)
  • GetVarFromAccessPtr (74-74)
src/layout/gemm_layouts.cc (8)
  • makeGemmFragmentC (121-136)
  • makeGemmFragmentC (121-123)
  • makeGemmABLayoutHopper (741-766)
  • makeGemmABLayoutHopper (741-742)
  • makeGemmSparseFragmentC (138-157)
  • makeGemmSparseFragmentC (138-140)
  • makeGemmSparseAmpereABLayout (683-688)
  • makeGemmSparseAmpereABLayout (683-684)
tilelang/intrinsics/tcgen05_macro_generator.py (2)
src/transform/lower_tile_op.cc (8)
  • access_ptr (287-385)
  • access_ptr (288-290)
  • buffer (272-280)
  • buffer (272-272)
  • buffer (401-418)
  • buffer (401-401)
  • buffer (420-437)
  • buffer (420-420)
tilelang/intrinsics/wgmma_macro_generator.py (2)
  • _warp_mma (300-336)
  • _warp_mma (406-458)
src/op/gemm_sp.h (4)
src/op/gemm.cc (2)
  • computeWarpPartition (222-402)
  • computeWarpPartition (222-223)
src/op/gemm_sp.cc (2)
  • computeWarpPartition (21-63)
  • computeWarpPartition (21-25)
tilelang/tileop/gemm/__init__.py (2)
  • M (89-90)
  • N (93-94)
tilelang/tileop/gemm/gemm_base.py (2)
  • M (34-35)
  • N (38-39)
tilelang/tileop/gemm/gemm_wgmma.py (3)
tilelang/tileop/gemm/gemm_base.py (12)
  • A_base_offsets (152-154)
  • B_base_offsets (157-159)
  • C_base_offsets (162-164)
  • ARegion (79-80)
  • BRegion (83-84)
  • CRegion (87-88)
  • C (75-76)
  • clear_accum (107-108)
  • wg_wait (115-116)
  • is_gemm_ss (21-22)
  • is_gemm_rs (27-28)
  • A (67-68)
tilelang/intrinsics/wgmma_macro_generator.py (1)
  • wgmma (163-338)
tilelang/transform/simplify.py (1)
  • _Simplify (31-49)
tilelang/intrinsics/mma_macro_generator.py (2)
tilelang/intrinsics/mfma_macro_generator.py (4)
  • extract_thread_binding (226-253)
  • _warp_ldmatrix_a (278-301)
  • ldmatrix_b (305-358)
  • ldmatrix_b (795-870)
tilelang/intrinsics/mma_sm70_macro_generator.py (3)
  • extract_thread_binding (158-188)
  • _warp_ldmatrix_a (220-234)
  • ldmatrix_b (238-290)
tilelang/tileop/gemm/gemm_mfma.py (2)
tilelang/tileop/gemm/gemm_base.py (4)
  • ARegion (79-80)
  • A (67-68)
  • BRegion (83-84)
  • B (71-72)
tilelang/utils/language.py (1)
  • is_shared (47-62)
tilelang/utils/language.py (3)
src/transform/lower_tile_op.cc (2)
  • buffer (272-280)
  • buffer (272-272)
src/transform/flatten_buffer.cc (2)
  • buffer (295-321)
  • buffer (295-296)
src/transform/legalize_safe_memory_access.cc (12)
  • buffer (87-95)
  • buffer (87-87)
  • buffer (98-138)
  • buffer (98-99)
  • buffer (256-260)
  • buffer (256-256)
  • buffer (262-265)
  • buffer (262-262)
  • buffer (267-270)
  • buffer (267-267)
  • buffer (272-277)
  • buffer (272-272)
tilelang/intrinsics/wgmma_macro_generator.py (4)
src/transform/lower_tile_op.cc (8)
  • access_ptr (287-385)
  • access_ptr (288-290)
  • buffer (272-280)
  • buffer (272-272)
  • buffer (401-418)
  • buffer (401-401)
  • buffer (420-437)
  • buffer (420-420)
tilelang/language/tir/entry.py (1)
  • macro (66-117)
tilelang/intrinsics/tcgen05_macro_generator.py (1)
  • _warp_mma (281-338)
tilelang/language/builtin.py (1)
  • initialize_wgmma_descriptor (700-727)
tilelang/language/builtin.py (4)
src/op/builtin.h (1)
  • tvm (13-575)
src/transform/loop_vectorize.cc (2)
  • indices (157-189)
  • indices (157-157)
tilelang/language/tir/op.py (4)
  • call_intrin (120-145)
  • tvm_access_ptr (651-676)
  • address_of (464-480)
  • type_annotation (635-648)
tilelang/language/utils.py (1)
  • region (8-27)
src/op/gemm_py.h (3)
src/op/gemm.cc (8)
  • checkWgmma (434-484)
  • checkWgmma (434-434)
  • allowTcgen5Mma (186-193)
  • allowTcgen5Mma (186-186)
  • allowWgmma (195-203)
  • allowWgmma (195-195)
  • getGemmInst (205-220)
  • getGemmInst (205-205)
src/op/gemm_py.cc (8)
  • checkWgmma (255-305)
  • checkWgmma (255-255)
  • allowTcgen5Mma (186-193)
  • allowTcgen5Mma (186-186)
  • allowWgmma (195-203)
  • allowWgmma (195-195)
  • getGemmInst (205-223)
  • getGemmInst (205-205)
src/op/gemm.h (2)
  • RegisterReflection (35-41)
  • RegisterReflection (106-128)
tilelang/tileop/gemm/gemm_base.py (1)
tilelang/tileop/gemm/__init__.py (13)
  • N (93-94)
  • K (97-98)
  • trans_A (101-102)
  • trans_B (105-106)
  • B (69-70)
  • C (73-74)
  • stride_A (109-110)
  • stride_B (113-114)
  • offset_A (117-118)
  • offset_B (121-122)
  • clear_accum (125-126)
  • k_pack (129-130)
  • wg_wait (133-134)
src/op/gemm.h (3)
src/op/gemm.cc (10)
  • computeWarpPartition (222-402)
  • computeWarpPartition (222-223)
  • checkWgmma (434-484)
  • checkWgmma (434-434)
  • getGemmInst (205-220)
  • getGemmInst (205-205)
  • allowTcgen5Mma (186-193)
  • allowTcgen5Mma (186-186)
  • allowWgmma (195-203)
  • allowWgmma (195-195)
src/op/gemm_sp.cc (2)
  • computeWarpPartition (21-63)
  • computeWarpPartition (21-25)
src/op/gemm_py.h (1)
  • RegisterReflection (43-67)
tilelang/layout/swizzle.py (2)
src/transform/lower_tile_op.cc (6)
  • buffer (272-280)
  • buffer (272-272)
  • buffer (401-418)
  • buffer (401-401)
  • buffer (420-437)
  • buffer (420-420)
tilelang/language/ast/ir.py (1)
  • buffer (93-161)
src/op/gemm.cc (4)
src/op/gemm_py.cc (13)
  • strides (80-80)
  • NormalizeToBufferRegion (24-69)
  • NormalizeToBufferRegion (24-24)
  • allowTcgen5Mma (186-193)
  • allowTcgen5Mma (186-186)
  • allowWgmma (195-203)
  • allowWgmma (195-195)
  • checkWgmma (255-305)
  • checkWgmma (255-255)
  • getGemmInst (205-223)
  • getGemmInst (205-205)
  • MakeAccessPtrFromRegion (74-102)
  • MakeAccessPtrFromRegion (74-74)
src/op/gemm.h (5)
  • Gemm (144-149)
  • GemmWarpPolicy (59-83)
  • GemmWarpPolicy (64-68)
  • GemmWarpPolicy (70-74)
  • GemmWarpPolicy (76-82)
tilelang/ir.py (2)
  • Gemm (43-44)
  • GemmWarpPolicy (30-39)
src/op/operator.cc (2)
  • GetVarFromAccessPtr (74-81)
  • GetVarFromAccessPtr (74-74)
🪛 Ruff (0.14.3)
tilelang/intrinsics/mfma_macro_generator.py

6-6: Redefinition of unused tvm from line 2

Remove definition: tvm

(F811)


298-298: Ambiguous variable name: l

(E741)

tilelang/language/gemm.py

137-137: Avoid specifying long messages outside the exception class

(TRY003)


458-458: Avoid specifying long messages outside the exception class

(TRY003)

tilelang/intrinsics/tcgen05_macro_generator.py

266-266: Prefer TypeError exception for invalid type

(TRY004)


266-266: Avoid specifying long messages outside the exception class

(TRY003)


278-278: Prefer TypeError exception for invalid type

(TRY004)


278-278: Avoid specifying long messages outside the exception class

(TRY003)

tilelang/tileop/gemm/gemm_wgmma.py

94-94: Local variable A_offsets is assigned to but never used

Remove assignment to unused variable A_offsets

(F841)


95-95: Local variable B_offsets is assigned to but never used

Remove assignment to unused variable B_offsets

(F841)


96-96: Local variable C_offsets is assigned to but never used

Remove assignment to unused variable C_offsets

(F841)


102-102: Local variable C_region is assigned to but never used

Remove assignment to unused variable C_region

(F841)

tilelang/utils/language.py

30-30: Avoid specifying long messages outside the exception class

(TRY003)

tilelang/intrinsics/wgmma_macro_generator.py

266-266: Undefined name Union

(F821)


275-275: Undefined name tir

(F821)


277-277: Undefined name tir

(F821)


279-279: Undefined name tir

(F821)


282-282: Prefer TypeError exception for invalid type

(TRY004)


282-282: Avoid specifying long messages outside the exception class

(TRY003)


282-282: Undefined name indices

(F821)


294-294: Prefer TypeError exception for invalid type

(TRY004)


294-294: Avoid specifying long messages outside the exception class

(TRY003)


301-301: Unpacked variable tx is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)

tilelang/language/builtin.py

466-466: Avoid specifying long messages outside the exception class

(TRY003)


482-482: Avoid specifying long messages outside the exception class

(TRY003)


514-514: Avoid specifying long messages outside the exception class

(TRY003)


543-543: Avoid specifying long messages outside the exception class

(TRY003)


572-572: Do not catch blind exception: Exception

(BLE001)


575-576: Avoid specifying long messages outside the exception class

(TRY003)

tilelang/layout/swizzle.py

29-29: Avoid specifying long messages outside the exception class

(TRY003)


89-89: PEP 484 prohibits implicit Optional

Convert to Optional[T]

(RUF013)


106-106: PEP 484 prohibits implicit Optional

Convert to Optional[T]

(RUF013)

🔇 Additional comments (10)
tilelang/language/builtin.py (3)

11-11: LGTM!

The BufferRegion import is correctly added to support the new handling branch.


443-456: LGTM!

The docstring clearly documents the expanded accepted types and dtype inference behavior.


509-552: LGTM!

The BufferRegion handling correctly computes offsets from region mins and includes a nice enhancement to infer num_regs from fully static region extents. The fallback to requiring explicit num_regs for symbolic extents is appropriate.

tilelang/tileop/gemm/gemm_tcgen05.py (1)

111-112: LGTM: Region-based operands for TCGEN5.

The switch from self.A/self.B to self.ARegion/self.BRegion correctly aligns with the broader refactoring toward region-based addressing for shared-memory operands.

tilelang/tileop/gemm/gemm_mma.py (1)

86-89: LGTM: Region-based input selection for shared memory.

The conditional selection of ARegion/BRegion when buffers are in shared memory correctly supports strided/offset tiles while maintaining backward compatibility for non-shared buffers.

src/transform/lower_tile_op.cc (1)

643-644: LGTM: Simplified GEMM lowering path.

The removal of buffer_var_gemm_ from LowerArgs correctly reflects the elimination of the GEMM-buffer collection step, aligning with the broader move to region-based GEMM handling.

tilelang/tileop/gemm/gemm_mfma.py (1)

87-89: LGTM: Consistent region-based input handling.

The conditional selection of ARegion/BRegion matches the pattern used in gemm_mma.py, correctly supporting region-based addressing for shared-memory operands.

tilelang/intrinsics/mma_macro_generator.py (1)

236-246: LGTM: Clean BufferRegion support.

The logic to unwrap BufferRegion into the underlying buffer and base offsets is well-structured and correctly handles both Buffer and BufferRegion inputs uniformly.

tilelang/language/gemm.py (1)

493-493: LGTM: Environment-driven GEMM selection.

The conditional assignment of gemm based on _env.use_gemm_v1() correctly implements the environment-driven selection mechanism described in the PR objectives, defaulting to v2 while allowing v1 to be forced via TILELANG_USE_GEMM_V1.

tilelang/intrinsics/mma_sm70_macro_generator.py (1)

209-235: BufferRegion offset handling aligns with the existing pattern.

The base-offset extraction mirrors the mfma path, so the SM70 macro can now consume BufferRegion sources without surprises.

Comment on lines +89 to +92
PrimExpr offset = make_const(buf->shape[0].dtype(), 0);
for (int i = 0; i < ndim - 2; ++i) {
offset = offset + region->region[i]->min * strides[i];
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Fix pointer offset when rebuilding access_ptr

MakeAccessPtrFromRegion never adds the minima of the last two dimensions to offset. For a 2D region such as A[i0:i0+16, j0:j0+8], the loop runs zero times, so we always hand the kernels a pointer to A[0,0]. Every GEMM tile then reads/writes the wrong slice. Please accumulate the minima for all dimensions (same change is needed in gemm.cc).

-  for (int i = 0; i < ndim - 2; ++i) {
+  for (int i = 0; i < ndim; ++i) {
     offset = offset + region->region[i]->min * strides[i];
   }
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
PrimExpr offset = make_const(buf->shape[0].dtype(), 0);
for (int i = 0; i < ndim - 2; ++i) {
offset = offset + region->region[i]->min * strides[i];
}
PrimExpr offset = make_const(buf->shape[0].dtype(), 0);
for (int i = 0; i < ndim; ++i) {
offset = offset + region->region[i]->min * strides[i];
}

Comment on lines +89 to +92
ranges.push_back(Range(IntImm(extent->dtype, 0), extent));
}
return BufferRegion(buf, ranges);
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Fix pointer offset when rebuilding access_ptr

Identical to the gemm_py.cc case: skipping the last two dimensions in the offset sum makes every reconstructed pointer land at element (0,0) for typical 2D regions, so all GEMM tiles operate on the wrong data. Please include the minima of every dimension when computing offset.

-  for (int i = 0; i < ndim - 2; ++i) {
+  for (int i = 0; i < ndim; ++i) {
     offset = offset + region->region[i]->min * strides[i];
   }

Committable suggestion skipped: line range outside the PR's diff.

🤖 Prompt for AI Agents
In src/op/gemm.cc around lines 89 to 92, the pointer offset computation when
rebuilding access_ptr incorrectly skips the last two dimensions causing all
reconstructed pointers to point to element (0,0); fix it by including the minima
for every dimension in the offset calculation (i.e., do not skip the last two
dims), summing (index - min) * stride for each dimension so the computed offset
accounts for all dims’ minima, and use that full offset when constructing the
rebuilt access_ptr.

Comment on lines 264 to 310
thread_binding = self.get_thread_binding()

def retrive_ptr_from_buffer_region(buffer_or_load_or_region: Union[Buffer, BufferLoad, BufferRegion], access_type: str = "r") -> PrimExpr:
if isinstance(buffer_or_load_or_region, Buffer):
return buffer_or_load_or_region.access_ptr(access_type)
elif isinstance(buffer_or_load_or_region, BufferLoad):
buffer_load = buffer_or_load_or_region
offset, stride = 0, 1
buffer = buffer_load.buffer
for i, shape in enumerate(reversed(buffer.shape)):
indice = buffer_load.indices[len(buffer_load.indices) - i - 1]
if isinstance(indice, tir.IntImm):
offset += indice * stride
elif isinstance(indice, tir.PrimExpr):
offset += indice * stride
elif isinstance(indice, tir.Ramp):
offset += indice.base * stride
else:
raise ValueError(f"Unsupported index type: {type(indices)}")
stride *= shape
return buffer.access_ptr(access_type, offset=offset)
elif isinstance(buffer_or_load_or_region, BufferRegion):
buffer_region = buffer_or_load_or_region
buffer = buffer_region.buffer
offset, stride = 0, 1
for i, shape in enumerate(reversed(buffer.shape)):
offset += buffer_region.region[len(buffer_region.region) - i - 1].min * stride
stride *= shape
return buffer.access_ptr(access_type, offset=offset)
else:
raise ValueError(f"Unsupported buffer type: {type(buffer_or_load_or_region)}")

A_ptr = retrive_ptr_from_buffer_region(A_buf)
B_ptr = retrive_ptr_from_buffer_region(B_buf)

@T.macro
def _warp_mma(A_buf, B_buf, C_local_buf):
def _warp_mma():
tx, warp_n, warp_m = self.extract_thread_binding(thread_binding)

desc_a = T.alloc_wgmma_desc()
desc_b = T.alloc_wgmma_desc()
T.initialize_wgmma_descriptor(desc_a, A_buf.access_ptr("r"), a_swizzle_mode,
T.initialize_wgmma_descriptor(desc_a, A_ptr, a_swizzle_mode,
int(a_leading_byte_offset >> 4),
int(a_stride_byte_offset >> 4))
T.initialize_wgmma_descriptor(desc_b, B_buf.access_ptr("r"), b_swizzle_mode,
T.initialize_wgmma_descriptor(desc_b, B_ptr, b_swizzle_mode,
int(b_leading_byte_offset >> 4),
int(b_stride_byte_offset >> 4))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Fix NameError when resolving BufferLoad offsets

retrive_ptr_from_buffer_region calls tir.IntImm / tir.PrimExpr, but tir is never imported in this module. The first BufferLoad you run through this helper will raise a NameError, so the GEMM flow regresses for the very BufferRegion cases this PR is enabling. Please import tir (and Union, since it’s now part of the signature) so the helper can execute.

-from typing import Callable
+from typing import Callable, Union
-from tvm import DataType
+from tvm import DataType, tir-        if isinstance(indice, tir.IntImm):
+        if isinstance(indice, tir.IntImm):-        elif isinstance(indice, tir.PrimExpr):
+        elif isinstance(indice, tir.PrimExpr):

Committable suggestion skipped: line range outside the PR's diff.

🧰 Tools
🪛 Ruff (0.14.3)

266-266: Undefined name Union

(F821)


275-275: Undefined name tir

(F821)


277-277: Undefined name tir

(F821)


279-279: Undefined name tir

(F821)


282-282: Prefer TypeError exception for invalid type

(TRY004)


282-282: Avoid specifying long messages outside the exception class

(TRY003)


282-282: Undefined name indices

(F821)


294-294: Prefer TypeError exception for invalid type

(TRY004)


294-294: Avoid specifying long messages outside the exception class

(TRY003)


301-301: Unpacked variable tx is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)

🤖 Prompt for AI Agents
In tilelang/intrinsics/wgmma_macro_generator.py around lines 264 to 310, the
helper retrive_ptr_from_buffer_region references tir and Union but the module
never imports them; add the necessary imports (e.g., add "from typing import
Union" and "from tvm import tir" or "import tvm.tir as tir" at the top of the
file) so the type checks against tir.IntImm/tir.PrimExpr/tir.Ramp and the Union
annotation work without raising NameError.

- Updated `wgmma` and `wgmma_rs` methods in `TensorCoreIntrinEmitter` to accept `BufferRegion` parameters instead of `Buffer`.
- Adjusted related calls in `GemmWGMMA` to ensure compatibility with the new parameter types.
- Simplified buffer access logic for better clarity and maintainability.
Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 4

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
tilelang/intrinsics/wgmma_macro_generator.py (1)

371-371: Fix undefined B_buf reference.

Line 371 references B_buf which doesn't exist in wgmma_rs. The parameter is B_region, so this should be B_region.buffer or just B_region depending on what _determinate_swizzle_mode expects.

Apply this diff:

-        b_swizzle_mode = self._determinate_swizzle_mode(B_buf, self.b_shared_layout)
+        b_swizzle_mode = self._determinate_swizzle_mode(B_region.buffer, self.b_shared_layout)
♻️ Duplicate comments (2)
tilelang/tileop/gemm/gemm_wgmma.py (1)

90-102: Unused offset and region variables.

The variables A_offsets, B_offsets, C_offsets, and C_region are assigned but never used in the subsequent code. While the comments indicate these are placeholders for future offset handling through BufferRegion, they currently serve no purpose and create noise.

Based on past reviews, consider either:

  1. Removing these assignments until they're actually needed
  2. Adding a clear TODO comment explaining when/how they'll be used
  3. Commenting them out if they document the intended future API
tilelang/intrinsics/wgmma_macro_generator.py (1)

4-7: Import Union and tir to fix NameError.

The helper function retrive_ptr_from_buffer_region at line 266 uses Union in its type hint and references tir.IntImm, tir.PrimExpr, and tir.Ramp without importing them. This will cause a NameError at runtime.

Apply this diff to add the missing imports:

-from typing import Callable
+from typing import Callable, Union
-from tvm.tir import PrimExpr, Buffer, Var, IndexMap, BufferLoad, BufferRegion
+from tvm.tir import PrimExpr, Buffer, Var, IndexMap, BufferLoad, BufferRegion
+from tvm import tir

Based on past reviews.

🧹 Nitpick comments (2)
examples/warp_specialize/example_warp_specialize_gemm_barrierpipe_stage2.py (1)

55-55: Consider removing redundant disable_cache() call.

This call is redundant since disable_cache() is already invoked at the module level (line 4). While not harmful, removing one of these calls would reduce duplication.

Apply this diff if you prefer to keep only the function-level call:

-tilelang.disable_cache()
 # add decorator @tilelang.jit if you want to return a torch function
tilelang/intrinsics/wgmma_macro_generator.py (1)

266-266: Fix typo in helper function name.

The function name retrive_ptr_from_buffer_region has a typo—it should be retrieve_ptr_from_buffer_region.

When refactoring per the previous comments, correct the spelling:

-def retrive_ptr_from_buffer_region(buffer_or_load_or_region: Union[Buffer, BufferLoad, BufferRegion], access_type: str = "r") -> PrimExpr:
+def retrieve_ptr_from_buffer_region(buffer_or_load_or_region: Union[Buffer, BufferLoad, BufferRegion], access_type: str = "r") -> PrimExpr:

And update all call sites accordingly.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 03af3e7 and ca4416f.

📒 Files selected for processing (5)
  • examples/warp_specialize/example_warp_specialize_gemm_barrierpipe_stage2.py (2 hunks)
  • testing/python/dynamic/test_tilelang_dynamic_symbolic.py (1 hunks)
  • tilelang/intrinsics/mma_macro_generator.py (9 hunks)
  • tilelang/intrinsics/wgmma_macro_generator.py (8 hunks)
  • tilelang/tileop/gemm/gemm_wgmma.py (3 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • tilelang/intrinsics/mma_macro_generator.py
🧰 Additional context used
🧬 Code graph analysis (3)
examples/warp_specialize/example_warp_specialize_gemm_barrierpipe_stage2.py (1)
tilelang/env.py (1)
  • disable_cache (275-276)
tilelang/intrinsics/wgmma_macro_generator.py (4)
tilelang/utils/language.py (1)
  • is_fragment (107-118)
src/transform/lower_tile_op.cc (8)
  • access_ptr (287-385)
  • access_ptr (288-290)
  • buffer (272-280)
  • buffer (272-272)
  • buffer (401-418)
  • buffer (401-401)
  • buffer (420-437)
  • buffer (420-420)
tilelang/language/utils.py (1)
  • region (8-27)
tilelang/language/builtin.py (2)
  • initialize_wgmma_descriptor (700-727)
  • warpgroup_fence_operand (433-589)
tilelang/tileop/gemm/gemm_wgmma.py (3)
tilelang/tileop/gemm/gemm_base.py (11)
  • A_base_offsets (152-154)
  • B_base_offsets (157-159)
  • C_base_offsets (162-164)
  • ARegion (79-80)
  • BRegion (83-84)
  • CRegion (87-88)
  • clear_accum (107-108)
  • wg_wait (115-116)
  • is_gemm_ss (21-22)
  • is_gemm_rs (27-28)
  • A (67-68)
tilelang/intrinsics/wgmma_macro_generator.py (1)
  • wgmma (163-340)
tilelang/transform/simplify.py (1)
  • _Simplify (31-49)
🪛 Ruff (0.14.3)
tilelang/intrinsics/wgmma_macro_generator.py

266-266: Undefined name Union

(F821)


275-275: Undefined name tir

(F821)


277-277: Undefined name tir

(F821)


279-279: Undefined name tir

(F821)


282-282: Prefer TypeError exception for invalid type

(TRY004)


282-282: Avoid specifying long messages outside the exception class

(TRY003)


282-282: Undefined name indices

(F821)


294-294: Prefer TypeError exception for invalid type

(TRY004)


294-294: Avoid specifying long messages outside the exception class

(TRY003)


302-302: Unused function argument: A_region

(ARG001)


302-302: Unused function argument: B_region

(ARG001)


303-303: Unpacked variable tx is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


409-409: Undefined name retrive_ptr_from_buffer_region

(F821)


414-414: Unpacked variable tx is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


414-414: Unpacked variable warp_m is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


417-417: Undefined name B_buf

(F821)


419-419: Undefined name B_buf

(F821)


422-422: Undefined name B_buf

(F821)

tilelang/tileop/gemm/gemm_wgmma.py

94-94: Local variable A_offsets is assigned to but never used

Remove assignment to unused variable A_offsets

(F841)


95-95: Local variable B_offsets is assigned to but never used

Remove assignment to unused variable B_offsets

(F841)


96-96: Local variable C_offsets is assigned to but never used

Remove assignment to unused variable C_offsets

(F841)

🔇 Additional comments (3)
examples/warp_specialize/example_warp_specialize_gemm_barrierpipe_stage2.py (2)

4-4: Verify the necessity of module-level cache disabling.

Calling disable_cache() at module import time affects global state and will disable caching for any code that imports this example module. For a standalone example script, the function-level call at line 55 would typically be sufficient.

If the module-level call is intentional (e.g., to ensure cache is disabled regardless of how this example is used), consider adding a comment explaining the rationale.


60-61: Good addition for example clarity.

Printing the kernel source is helpful for users learning about kernel generation and debugging. This makes the generated CUDA code immediately visible during execution.

Note: Lines 78-80 contain commented-out code that performs similar kernel source printing, which is now redundant but can be left as-is or removed for cleanliness.

tilelang/intrinsics/wgmma_macro_generator.py (1)

302-302: Unused macro parameters should be prefixed with underscore.

The _warp_mma macro parameters A_region and B_region are marked as unused by static analysis. However, these are passed through from the outer scope and used via closure (the computed A_ptr, B_ptr, C_buf are captured). This is a false positive from the linter.

Comment on lines 516 to +518
if __name__ == "__main__":
tilelang.testing.main()
# tilelang.testing.main()
assert_tl_matmul_macro_correctness(128, 128, 128, "float16", "float16", "float16")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Test harness bypass breaks test discovery and coverage.

Replacing tilelang.testing.main() with a single direct assertion prevents the test framework from discovering and executing the other test functions defined in this module (test_assert_tl_matmul_block, test_assert_tl_matmul_block_all_dynamic, test_assert_tl_matmul_block_all_dynamic_with_pass_config). This significantly reduces test coverage.

If this change is intentional for debugging, consider reverting it before merging. Otherwise, restore the test harness:

 if __name__ == "__main__":
-    # tilelang.testing.main()
-    assert_tl_matmul_macro_correctness(128, 128, 128, "float16", "float16", "float16")
+    tilelang.testing.main()
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
if __name__ == "__main__":
tilelang.testing.main()
# tilelang.testing.main()
assert_tl_matmul_macro_correctness(128, 128, 128, "float16", "float16", "float16")
if __name__ == "__main__":
tilelang.testing.main()
🤖 Prompt for AI Agents
In testing/python/dynamic/test_tilelang_dynamic_symbolic.py around lines
516-518, the test entrypoint was changed to call
assert_tl_matmul_macro_correctness(...) directly which bypasses
tilelang.testing.main() and breaks test discovery/coverage; restore the test
harness by removing the direct assertion and calling tilelang.testing.main() (or
re-enable the original commented call), or if this direct call is only needed
for temporary debugging wrap it behind an explicit debug guard (e.g. env var) so
the normal tilelang.testing.main() runs by default.

elif isinstance(indice, tir.Ramp):
offset += indice.base * stride
else:
raise ValueError(f"Unsupported index type: {type(indices)}")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Fix undefined variable name.

Line 282 references indices which is undefined. The variable name should be indice (the loop variable from line 274).

Apply this diff:

-                        raise ValueError(f"Unsupported index type: {type(indices)}")
+                        raise ValueError(f"Unsupported index type: {type(indice)}")
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
raise ValueError(f"Unsupported index type: {type(indices)}")
raise ValueError(f"Unsupported index type: {type(indice)}")
🧰 Tools
🪛 Ruff (0.14.3)

282-282: Prefer TypeError exception for invalid type

(TRY004)


282-282: Avoid specifying long messages outside the exception class

(TRY003)


282-282: Undefined name indices

(F821)

🤖 Prompt for AI Agents
In tilelang/intrinsics/wgmma_macro_generator.py around line 282, the code raises
a ValueError referencing an undefined variable `indices`; change the reference
to the loop variable `indice` (from the for-loop that starts at line 274) so the
exception message uses the correct variable name.

…andling

- Updated `run_gemm`, `run_gemm_rs`, `run_gemm_sr`, and `run_gemm_rr` functions to set `num_stages` based on block dimensions, enhancing performance for larger matrices.
- Simplified calls to GEMM functions by removing redundant parameters and ensuring compatibility with BufferRegion.
- Introduced utility functions for converting between Buffer, BufferLoad, and BufferRegion, improving code clarity and maintainability.
- Enhanced error handling for full region checks in GEMM operations to ensure correctness in memory access.
- Cleaned up formatting and spacing in GEMM-related files for better readability.
- Standardized comments and code structure across various GEMM functions and macros.
- Enhanced error messages for clarity in buffer region checks.
- Removed redundant lines and improved overall code maintainability.
…unctionality

- Modified `N_VALUES` in `correctness_evaluation_sm70.py` to include only relevant sizes for tests.
- Updated test function call in `correctness_evaluation.py` to use `test_gemm_false_true` for better accuracy in testing.
- Refactored buffer handling in `mma_sm70_macro_generator.py` to improve clarity and consistency in shared buffer access.
- Enhanced `gemm_mma_sm70.py` to ensure full region checks for input and output buffers, improving correctness in GEMM operations.
- Removed unused variable `A_stride_last` in `mma_sm70_macro_generator.py` to streamline code.
- Adjusted function signature formatting in `swizzle.py` for better readability.
- Restored the return of `GemmWGMMA` in `__init__.py` for correct GEMM instantiation.
- Removed unused variable `B_buf` in `gemm_mma_sm70.py` to enhance code cleanliness.
- Improved function signature formatting in `language.py` for consistency.
Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 10

♻️ Duplicate comments (2)
src/op/gemm_py.cc (1)

95-102: Offset must include every dimension.

Skipping the last two dimensions when accumulating region->region[i]->min * strides[i] recreates the old bug: any 2D tile (the common case) rebuilds its pointer at (0,0), so every kernel reads/writes the wrong slice. Please sum minima across all axes when computing offset.

-  for (int i = 0; i < ndim - 2; ++i) {
+  for (int i = 0; i < ndim; ++i) {
     offset = offset + region->region[i]->min * strides[i];
   }
src/op/gemm.cc (1)

123-130: Same offset bug in GEMM path.

Identical to the issue flagged in gemm_py.cc: stopping the minima accumulation at ndim - 2 resets the pointer for typical 2D regions, so every tile operates on the wrong data. The offset must incorporate all dimensions.

-  for (int i = 0; i < ndim - 2; ++i) {
+  for (int i = 0; i < ndim; ++i) {
     offset = offset + region->region[i]->min * strides[i];
   }
🧹 Nitpick comments (12)
src/transform/lower_tile_op.cc (2)

445-452: Use Array for PTX op filtering to avoid type drift

The container holds RelaxExpr but you compare against op->op (Op). Make this Array for type clarity and safer equality.

Apply:

-    Array<RelaxExpr> ptx_instructions = {builtin::ptx_ldmatrix(),
-                                         builtin::mma_store()};
+    Array<Op> ptx_instructions = {builtin::ptx_ldmatrix(),
+                                  builtin::mma_store()};

642-646: Introduce a named constructor for LowerArgs and replace the aggregate‐initializer

Aggregate‐initializing LowerArgs at src/transform/lower_tile_op.cc (lines 642–646) is brittle—any change to the field order in operator.h will silently break this call.

  1. In src/op/operator.h, add a constructor to LowerArgs:
    struct LowerArgs {
      Target target;
      Range thread_bounds;
      Var thread_var;
      AddWorkspaceCallback AddWorkspace;
      LayoutMap layout_map;
      Map<Buffer, Buffer> buffer_remap;
    
      LowerArgs(Target t,
                Range bounds,
                Var var,
                AddWorkspaceCallback cb,
                LayoutMap lm,
                Map<Buffer, Buffer> remap)
          : target(t),
            thread_bounds(bounds),
            thread_var(var),
            AddWorkspace(cb),
            layout_map(lm),
            buffer_remap(remap) {}
    };
  2. In src/transform/lower_tile_op.cc, update the call site:
    - auto lowered =
    -     tile_op->Lower(LowerArgs{target_, thread_bounds, thread_var_->var,
    -                              callback, layout_map_, buffer_remap_},
    -                    analyzer_);
    + auto lowered =
    +     tile_op->Lower(LowerArgs(target_,
    +                              thread_bounds,
    +                              thread_var_->var,
    +                              callback,
    +                              layout_map_,
    +                              buffer_remap_),
    +                    analyzer_);

No other aggregate‐init sites or BufferGemmCollector remnants were found.

examples/warp_specialize/example_warp_specialize_gemm_barrierpipe_stage2.py (2)

57-57: Gate cache disabling behind an env/flag to avoid skewing profiling.

Disabling cache forces recompilation and can distort profiler results. Gate it with an env var so users can opt in.

Apply:

-    tilelang.disable_cache()
+    if os.getenv("TILELANG_DISABLE_CACHE", "0") == "1":
+        tilelang.disable_cache()

Add once near the imports:

import os  # for env-gated debug switches

63-64: Guard kernel source dump behind an env flag
The call to print(jit_kernel.get_kernel_source()) can flood stdout; wrap it with an environment check.
Apply:

-    print(jit_kernel.get_kernel_source())
+    if os.getenv("TILELANG_DEBUG_SRC", "0") == "1":
+        print(jit_kernel.get_kernel_source())
tilelang/layout/swizzle.py (4)

6-6: Import PrimExpr for accurate type hints and usage

Add PrimExpr to the tvm.tir import to support corrected annotations below.

-from tvm.tir import Buffer, BufferLoad, BufferRegion
+from tvm.tir import Buffer, BufferLoad, BufferRegion, PrimExpr

49-61: Confirm lane-aware element size for vector dtypes

tvm.DataType(dtype).bits returns per-lane bits. If dtypes with lanes (e.g., float16x8) can appear here, consider bits * lanes() or document that a scalar dtype is required and validated earlier.


94-100: Use Optional typing for continuity and allow PrimExpr

Fix implicit Optional (RUF013) and accept PrimExpr continuity.

-def make_wgmma_swizzled_layout(buffer: Buffer | BufferLoad | BufferRegion,
-                               continuity: int = None,
+from typing import Optional  # add near imports if not present
+
+def make_wgmma_swizzled_layout(buffer: Buffer | BufferLoad | BufferRegion,
+                               continuity: Optional[PrimExpr] = None,
                                k_major: bool = True):
@@
-    if continuity is None:
+    if continuity is None:
         continuity = continuous

110-117: Same typing fix for TCGEN05MMA continuity

Mirror the Optional[PrimExpr] change here.

-def make_tcgen05mma_swizzled_layout(buffer: Buffer | BufferLoad | BufferRegion,
-                                    continuity: int = None,
+def make_tcgen05mma_swizzled_layout(buffer: Buffer | BufferLoad | BufferRegion,
+                                    continuity: Optional[PrimExpr] = None,
                                     k_major: bool = True):
@@
-    if continuity is None:
+    if continuity is None:
         continuity = continuous
tilelang/utils/language.py (4)

245-245: Use TypeError for invalid type; shorten messages (TRY004, TRY003)

For type mismatches, raise TypeError and keep messages brief to satisfy linters.

-        raise ValueError(f"Unsupported retrieve_stride argument type: {type(obj)} for object {obj}")
+        raise TypeError(f"Unsupported type: {type(obj)}")
@@
-                raise ValueError(f"Unsupported index type: {type(indice)}")
+                raise TypeError(f"Unsupported index type: {type(indice)}")
@@
-        raise ValueError(f"Unsupported buffer type: {type(buffer_or_load_or_region)}")
+        raise TypeError(f"Unsupported buffer type: {type(buffer_or_load_or_region)}")
@@
-    raise ValueError(f"Unsupported retrieve_ptr argument type: {type(obj)} for object {obj}")
+    raise TypeError(f"Unsupported type: {type(obj)}")
@@
-    raise ValueError(f"Unsupported retrieve_offset argument type: {type(obj)} for object {obj}")
+    raise TypeError(f"Unsupported type: {type(obj)}")

Also applies to: 270-270, 282-282, 329-329, 349-349


214-232: Unify scalar BufferLoad handling with to_buffer_region fallback

Instead of erroring on scalar BufferLoad, derive a 1-sized region via to_buffer_region for consistency.

-    if isinstance(obj, tir.BufferLoad):
-        region = get_buffer_region_from_load(obj)
-        if region is None:
-            raise ValueError("Cannot retrieve shape from scalar BufferLoad without region")
-        return [r.extent for r in region.region]
+    if isinstance(obj, tir.BufferLoad):
+        region = to_buffer_region(obj)
+        return [r.extent for r in region.region]

45-61: Simplify boolean logic in is_shared

Minor readability cleanup; no behavior change.

-    conditions = [False]
-    conditions.append(buffer.scope() == "shared")
-    if allow_dynamic:
-        conditions.append(is_shared_dynamic(buffer))
-    return any(conditions)
+    return (buffer.scope() == "shared") or (allow_dynamic and is_shared_dynamic(buffer))

352-368: Prefer the single helper prim_expr_equal you defined

Optionally use prim_expr_equal in other helpers (e.g., is_full_region) for consistency; current usage is fine.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between ca4416f and c0c45d6.

📒 Files selected for processing (22)
  • .gitignore (1 hunks)
  • examples/warp_specialize/example_warp_specialize_gemm_barrierpipe_stage2.py (2 hunks)
  • maint/gemm_v2/correctness_evaluation.py (7 hunks)
  • maint/gemm_v2/correctness_evaluation_sm70.py (1 hunks)
  • src/op/gemm.cc (10 hunks)
  • src/op/gemm_py.cc (7 hunks)
  • src/op/gemm_sp.cc (4 hunks)
  • src/transform/lower_tile_op.cc (1 hunks)
  • tilelang/intrinsics/mfma_macro_generator.py (7 hunks)
  • tilelang/intrinsics/mma_macro_generator.py (9 hunks)
  • tilelang/intrinsics/mma_sm70_macro_generator.py (6 hunks)
  • tilelang/intrinsics/tcgen05_macro_generator.py (2 hunks)
  • tilelang/intrinsics/wgmma_macro_generator.py (9 hunks)
  • tilelang/language/builtin.py (3 hunks)
  • tilelang/language/gemm.py (5 hunks)
  • tilelang/layout/swizzle.py (5 hunks)
  • tilelang/tileop/gemm/__init__.py (2 hunks)
  • tilelang/tileop/gemm/gemm_mma.py (6 hunks)
  • tilelang/tileop/gemm/gemm_mma_sm70.py (4 hunks)
  • tilelang/tileop/gemm/gemm_wgmma.py (3 hunks)
  • tilelang/utils/__init__.py (1 hunks)
  • tilelang/utils/language.py (3 hunks)
✅ Files skipped from review due to trivial changes (1)
  • .gitignore
🚧 Files skipped from review as they are similar to previous changes (2)
  • tilelang/tileop/gemm/gemm_wgmma.py
  • tilelang/tileop/gemm/init.py
🧰 Additional context used
🧠 Learnings (2)
📚 Learning: 2025-09-15T10:51:06.985Z
Learnt from: botbw
Repo: tile-ai/tilelang PR: 691
File: src/tl_templates/cuda/gemm_sp_sm80.h:81-85
Timestamp: 2025-09-15T10:51:06.985Z
Learning: In CUTLASS tensor operation layouts, crosswise constants should be computed using sizeof(T) (bytes), not cutlass::sizeof_bits<T>::value (bits). This is the established pattern in the official CUTLASS codebase, as seen in default_mma_core_sparse_sm80.h.

Applied to files:

  • src/op/gemm.cc
📚 Learning: 2025-09-15T10:51:06.985Z
Learnt from: botbw
Repo: tile-ai/tilelang PR: 691
File: src/tl_templates/cuda/gemm_sp_sm80.h:81-85
Timestamp: 2025-09-15T10:51:06.985Z
Learning: In CUTLASS tensor operation layouts, crosswise constants should be computed using sizeof(T) (bytes), not cutlass::sizeof_bits<T>::value (bits). However, the layout template parameter should use sizeof_bits<T>::value (bits). This is the established pattern in the official CUTLASS codebase, as seen in default_mma_core_sparse_sm80.h where Crosswise uses sizeof(ElementA) but the layout template uses sizeof_bits<ElementA>::value.

Applied to files:

  • src/op/gemm.cc
🧬 Code graph analysis (17)
tilelang/language/gemm.py (2)
tilelang/utils/language.py (6)
  • to_buffer_region (187-211)
  • retrieve_shape (214-231)
  • retrieve_stride (234-252)
  • retrieve_ptr (285-329)
  • retrieve_offset (332-349)
  • prim_expr_equal (352-367)
tilelang/env.py (2)
  • get (175-178)
  • use_gemm_v1 (281-287)
tilelang/tileop/gemm/gemm_mma_sm70.py (4)
tilelang/utils/language.py (3)
  • is_shared (45-60)
  • is_fragment (105-116)
  • is_full_region (370-399)
tilelang/tileop/gemm/gemm_base.py (4)
  • ARegion (79-80)
  • BRegion (83-84)
  • CRegion (87-88)
  • is_gemm_rs (27-28)
tilelang/intrinsics/mma_sm70_macro_generator.py (2)
  • ldmatrix_b (234-284)
  • mma (286-327)
tilelang/tileop/gemm/gemm_mma.py (2)
  • _gemm_ssr (103-128)
  • is_gemm_rs (218-219)
tilelang/utils/__init__.py (1)
tilelang/utils/language.py (5)
  • retrieve_stride (234-252)
  • retrieve_shape (214-231)
  • retrive_ptr_from_buffer_region (255-282)
  • is_full_region (370-399)
  • to_buffer_region (187-211)
tilelang/intrinsics/mma_macro_generator.py (4)
tilelang/intrinsics/utils.py (1)
  • get_ldmatrix_offset (21-63)
tilelang/utils/language.py (2)
  • is_fragment (105-116)
  • to_buffer_region (187-211)
tilelang/intrinsics/mma_sm70_macro_generator.py (3)
  • extract_thread_binding (158-188)
  • _warp_ldmatrix_a (216-230)
  • ldmatrix_b (234-284)
tilelang/intrinsics/mfma_macro_generator.py (4)
  • extract_thread_binding (225-252)
  • _warp_ldmatrix_a (277-300)
  • ldmatrix_b (304-357)
  • ldmatrix_b (794-869)
examples/warp_specialize/example_warp_specialize_gemm_barrierpipe_stage2.py (1)
tilelang/env.py (1)
  • disable_cache (275-276)
tilelang/tileop/gemm/gemm_mma.py (3)
tilelang/utils/language.py (3)
  • is_shared (45-60)
  • is_fragment (105-116)
  • is_full_region (370-399)
tilelang/tileop/gemm/gemm_base.py (6)
  • ARegion (79-80)
  • BRegion (83-84)
  • CRegion (87-88)
  • is_gemm_sr (24-25)
  • is_gemm_rs (27-28)
  • is_gemm_rr (30-31)
tilelang/intrinsics/mma_macro_generator.py (6)
  • ldmatrix_b (285-368)
  • ldmatrix_b (892-1004)
  • mma (370-430)
  • mma (1006-1055)
  • mma (1060-1158)
  • mma (1163-1262)
tilelang/layout/swizzle.py (2)
src/layout/swizzle.h (1)
  • tvm (12-69)
tilelang/language/ast/ir.py (1)
  • buffer (93-161)
tilelang/language/builtin.py (2)
tilelang/language/ast/ir.py (1)
  • evaluate (1319-1331)
tilelang/language/tir/op.py (3)
  • tvm_access_ptr (651-676)
  • address_of (464-480)
  • type_annotation (635-648)
src/op/gemm_sp.cc (4)
src/op/gemm.cc (2)
  • computeWarpPartition (228-408)
  • computeWarpPartition (228-229)
src/op/operator.cc (2)
  • GetVarFromAccessPtr (74-81)
  • GetVarFromAccessPtr (74-74)
src/op/gemm_sp.h (4)
  • GemmSPWarpPolicy (28-52)
  • GemmSPWarpPolicy (33-37)
  • GemmSPWarpPolicy (39-43)
  • GemmSPWarpPolicy (45-51)
src/layout/gemm_layouts.cc (10)
  • makeGemmFragmentCHopper (176-187)
  • makeGemmFragmentCHopper (176-178)
  • makeGemmFragmentC (121-136)
  • makeGemmFragmentC (121-123)
  • makeGemmABLayoutHopper (741-766)
  • makeGemmABLayoutHopper (741-742)
  • makeGemmSparseFragmentC (138-157)
  • makeGemmSparseFragmentC (138-140)
  • makeGemmSparseAmpereABLayout (683-688)
  • makeGemmSparseAmpereABLayout (683-684)
tilelang/intrinsics/tcgen05_macro_generator.py (1)
src/transform/lower_tile_op.cc (8)
  • access_ptr (287-385)
  • access_ptr (288-290)
  • buffer (272-280)
  • buffer (272-272)
  • buffer (401-418)
  • buffer (401-401)
  • buffer (420-437)
  • buffer (420-420)
maint/gemm_v2/correctness_evaluation.py (2)
maint/gemm_v2/correctness_evaluation_sm70.py (1)
  • run_gemm_rs (180-210)
testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py (1)
  • run_gemm_rs (191-242)
src/op/gemm_py.cc (3)
src/op/gemm.cc (10)
  • NormalizeToBufferRegion (53-100)
  • NormalizeToBufferRegion (53-54)
  • allowTcgen5Mma (192-199)
  • allowTcgen5Mma (192-192)
  • allowWgmma (201-209)
  • allowWgmma (201-201)
  • checkWgmma (440-491)
  • checkWgmma (440-440)
  • getGemmInst (211-226)
  • getGemmInst (211-211)
tilelang/ir.py (1)
  • GemmWarpPolicy (30-39)
src/op/operator.cc (2)
  • GetVarFromAccessPtr (74-81)
  • GetVarFromAccessPtr (74-74)
tilelang/intrinsics/mma_sm70_macro_generator.py (3)
tilelang/utils/language.py (2)
  • is_fragment (105-116)
  • to_buffer_region (187-211)
tilelang/intrinsics/mfma_macro_generator.py (4)
  • _warp_ldmatrix_a (277-300)
  • ldmatrix_b (304-357)
  • ldmatrix_b (794-869)
  • _warp_ldmatrix_b (327-355)
tilelang/intrinsics/mma_macro_generator.py (8)
  • _warp_ldmatrix_a (244-281)
  • _warp_ldmatrix_a (800-888)
  • ldmatrix_b (285-368)
  • ldmatrix_b (892-1004)
  • mma_load_layout (223-224)
  • mma_load_layout (310-311)
  • _warp_ldmatrix_b (322-366)
  • _warp_ldmatrix_b (907-1002)
src/op/gemm.cc (4)
tilelang/language/utils.py (1)
  • region (8-27)
src/op/gemm_py.cc (13)
  • strides (85-85)
  • NormalizeToBufferRegion (25-72)
  • NormalizeToBufferRegion (25-26)
  • allowTcgen5Mma (192-199)
  • allowTcgen5Mma (192-192)
  • allowWgmma (201-209)
  • allowWgmma (201-201)
  • checkWgmma (261-312)
  • checkWgmma (261-261)
  • getGemmInst (211-229)
  • getGemmInst (211-211)
  • MakeAccessPtrFromRegion (78-108)
  • MakeAccessPtrFromRegion (78-79)
src/op/gemm.h (1)
  • Gemm (144-149)
src/op/operator.cc (2)
  • GetVarFromAccessPtr (74-81)
  • GetVarFromAccessPtr (74-74)
tilelang/utils/language.py (4)
src/transform/lower_tile_op.cc (8)
  • buffer (272-280)
  • buffer (272-272)
  • buffer (401-418)
  • buffer (401-401)
  • buffer (420-437)
  • buffer (420-420)
  • access_ptr (287-385)
  • access_ptr (288-290)
src/transform/flatten_buffer.cc (6)
  • buffer (295-321)
  • buffer (295-296)
  • region (335-361)
  • region (335-335)
  • buf (227-247)
  • buf (227-227)
src/transform/storage_rewrite.cc (9)
  • e (862-904)
  • e (862-862)
  • i (720-860)
  • buf (249-256)
  • buf (249-249)
  • buf (509-526)
  • buf (509-509)
  • buf (1740-1765)
  • buf (1740-1740)
tilelang/language/utils.py (1)
  • region (8-27)
tilelang/intrinsics/wgmma_macro_generator.py (2)
tilelang/utils/language.py (3)
  • is_fragment (105-116)
  • retrive_ptr_from_buffer_region (255-282)
  • is_full_region (370-399)
tilelang/language/builtin.py (2)
  • initialize_wgmma_descriptor (703-730)
  • warpgroup_fence_operand (433-592)
tilelang/intrinsics/mfma_macro_generator.py (2)
tilelang/intrinsics/mma_sm70_macro_generator.py (3)
  • ldmatrix_a (190-232)
  • _warp_ldmatrix_a (216-230)
  • ldmatrix_b (234-284)
tilelang/intrinsics/mma_macro_generator.py (6)
  • ldmatrix_a (207-283)
  • ldmatrix_a (786-890)
  • _warp_ldmatrix_a (244-281)
  • _warp_ldmatrix_a (800-888)
  • ldmatrix_b (285-368)
  • ldmatrix_b (892-1004)
🪛 Ruff (0.14.3)
tilelang/utils/__init__.py

13-13: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


14-14: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


15-15: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


16-16: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


17-17: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)

tilelang/layout/swizzle.py

28-29: Avoid specifying long messages outside the exception class

(TRY003)


94-94: PEP 484 prohibits implicit Optional

Convert to Optional[T]

(RUF013)


111-111: PEP 484 prohibits implicit Optional

Convert to Optional[T]

(RUF013)

tilelang/language/builtin.py

466-466: Avoid specifying long messages outside the exception class

(TRY003)


482-482: Avoid specifying long messages outside the exception class

(TRY003)


514-514: Avoid specifying long messages outside the exception class

(TRY003)


543-545: Avoid specifying long messages outside the exception class

(TRY003)


574-574: Do not catch blind exception: Exception

(BLE001)


577-579: Avoid specifying long messages outside the exception class

(TRY003)

tilelang/intrinsics/tcgen05_macro_generator.py

264-264: Prefer TypeError exception for invalid type

(TRY004)


264-264: Avoid specifying long messages outside the exception class

(TRY003)


276-276: Prefer TypeError exception for invalid type

(TRY004)


276-276: Avoid specifying long messages outside the exception class

(TRY003)

tilelang/utils/language.py

27-28: Avoid specifying long messages outside the exception class

(TRY003)


211-211: Avoid specifying long messages outside the exception class

(TRY003)


229-229: Avoid specifying long messages outside the exception class

(TRY003)


231-231: Avoid specifying long messages outside the exception class

(TRY003)


245-245: Prefer TypeError exception for invalid type

(TRY004)


245-245: Avoid specifying long messages outside the exception class

(TRY003)


270-270: Prefer TypeError exception for invalid type

(TRY004)


270-270: Avoid specifying long messages outside the exception class

(TRY003)


282-282: Prefer TypeError exception for invalid type

(TRY004)


282-282: Avoid specifying long messages outside the exception class

(TRY003)


329-329: Avoid specifying long messages outside the exception class

(TRY003)


349-349: Avoid specifying long messages outside the exception class

(TRY003)


384-384: Avoid specifying long messages outside the exception class

(TRY003)

tilelang/intrinsics/wgmma_macro_generator.py

274-274: Unpacked variable tx is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


386-386: Unpacked variable tx is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


386-386: Unpacked variable warp_m is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)

tilelang/intrinsics/mfma_macro_generator.py

297-297: Ambiguous variable name: l

(E741)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: Test for Python 3.12 with CUDA-12.8 (on self-hosted-nvidia)
  • GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
  • GitHub Check: Test for Python 3.12 with ROCm-6.3 (on self-hosted-amd)
🔇 Additional comments (1)
tilelang/layout/swizzle.py (1)

65-76: FFI should receive PrimExpr for dims, not Python ints

With the _get_stride_continuous fix, stride and continuous remain PrimExprs, which is the safer contract for TVM FFI. No further change needed in these call sites—just confirming the intent.

If any FFI expects raw ints, please confirm and we can adapt by tir.IntImm wrapping instead of int().

Also applies to: 80-89, 93-107, 110-124, 126-147, 150-171, 174-195, 198-215

import tilelang
import tilelang.language as T

tilelang.disable_cache()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion | 🟠 Major

🧩 Analysis chain

Avoid top-level side effects; keep a single cache toggle in main().

Calling tilelang.disable_cache() at import time affects anyone importing this module. It’s also duplicated at Line 57. Remove the top-level call and keep a single, gated call in main().

Apply this minimal change:

- tilelang.disable_cache()

Optional verification (ensures a top-level wrapper exists and avoids surprises at runtime):


🏁 Script executed:

#!/bin/bash
# Verify a public wrapper exists and locate its definition(s)
rg -nP -C2 'def\s+disable_cache\s*\('
rg -nP -C2 'disable_cache\s*=\s*.*\.disable_cache'

Length of output: 195


🏁 Script executed:

#!/bin/bash
set -e
# Locate tilelang package __init__.py
find . -type f -path "./tilelang/__init__.py" -print
# Show first 200 lines to check for disable_cache wrapper
sed -n '1,200p' tilelang/__init__.py

Length of output: 3702


🏁 Script executed:

#!/bin/bash
# Display lines 1–80 of the example to confirm top-level call and context
sed -n '1,80p' examples/warp_specialize/example_warp_specialize_gemm_barrierpipe_stage2.py

Length of output: 2815


Remove top-level disable_cache call
Delete the import-time tilelang.disable_cache() (line 4) and rely solely on the call inside main() to avoid unexpected side effects on import.
Optional: Gate the print(jit_kernel.get_kernel_source()) behind a flag to prevent unneeded verbose output.

🤖 Prompt for AI Agents
In examples/warp_specialize/example_warp_specialize_gemm_barrierpipe_stage2.py
around line 4, remove the top-level tilelang.disable_cache() call so importing
the module has no side effects and rely on the existing disable_cache()
invocation inside main(); additionally, optionally wrap or gate the
print(jit_kernel.get_kernel_source()) behind a verbosity or debug flag so kernel
source is only printed when explicitly requested.

Comment on lines +258 to +274
indice = buffer_load.indices[len(buffer_load.indices) - i - 1]
if isinstance(indice, (tvm.tir.IntImm, tvm.tir.PrimExpr)):
offset += indice * stride
elif isinstance(indice, tvm.tir.Ramp):
offset += indice.base * stride
else:
raise ValueError(f"Unsupported index type: {type(indice)}")
stride *= shape
return buffer.access_ptr(access_type, offset=offset)
elif isinstance(buffer_or_load_or_region, BufferRegion):
buffer_region = buffer_or_load_or_region
buffer = buffer_region.buffer
offset, stride = 0, 1
for i, shape in enumerate(reversed(buffer.shape)):
offset += buffer_region.region[len(buffer_region.region) - i - 1].min * stride
stride *= shape
return buffer.access_ptr(access_type, offset=offset)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Handle Ramp indices before the generic branch.

Ramp inherits from PrimExpr, so the first isinstance(..., PrimExpr) branch always fires. That feeds a vector expression into access_ptr, which is invalid for pointer offsets and breaks vectorized tiles. Please special-case Ramp before the generic branch and validate its stride.

-            if isinstance(indice, (tvm.tir.IntImm, tvm.tir.PrimExpr)):
-                offset += indice * stride
-            elif isinstance(indice, tvm.tir.Ramp):
+            if isinstance(indice, tvm.tir.Ramp):
+                if not isinstance(indice.stride, tvm.tir.IntImm) or indice.stride.value != 1:
+                    raise ValueError(f"Unsupported ramp stride: {indice.stride}")
+                offset += indice.base * stride
+            elif isinstance(indice, (tvm.tir.IntImm, tvm.tir.PrimExpr)):
                 offset += indice * stride
             else:
                 raise ValueError(f"Unsupported index type: {type(indice)}")
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
indice = buffer_load.indices[len(buffer_load.indices) - i - 1]
if isinstance(indice, (tvm.tir.IntImm, tvm.tir.PrimExpr)):
offset += indice * stride
elif isinstance(indice, tvm.tir.Ramp):
offset += indice.base * stride
else:
raise ValueError(f"Unsupported index type: {type(indice)}")
stride *= shape
return buffer.access_ptr(access_type, offset=offset)
elif isinstance(buffer_or_load_or_region, BufferRegion):
buffer_region = buffer_or_load_or_region
buffer = buffer_region.buffer
offset, stride = 0, 1
for i, shape in enumerate(reversed(buffer.shape)):
offset += buffer_region.region[len(buffer_region.region) - i - 1].min * stride
stride *= shape
return buffer.access_ptr(access_type, offset=offset)
indice = buffer_load.indices[len(buffer_load.indices) - i - 1]
if isinstance(indice, tvm.tir.Ramp):
if not isinstance(indice.stride, tvm.tir.IntImm) or indice.stride.value != 1:
raise ValueError(f"Unsupported ramp stride: {indice.stride}")
offset += indice.base * stride
elif isinstance(indice, (tvm.tir.IntImm, tvm.tir.PrimExpr)):
offset += indice * stride
else:
raise ValueError(f"Unsupported index type: {type(indice)}")
stride *= shape
return buffer.access_ptr(access_type, offset=offset)
elif isinstance(buffer_or_load_or_region, BufferRegion):
buffer_region = buffer_or_load_or_region
buffer = buffer_region.buffer
offset, stride = 0, 1
for i, shape in enumerate(reversed(buffer.shape)):
offset += buffer_region.region[len(buffer_region.region) - i - 1].min * stride
stride *= shape
return buffer.access_ptr(access_type, offset=offset)
🧰 Tools
🪛 Ruff (0.14.3)

264-264: Prefer TypeError exception for invalid type

(TRY004)


264-264: Avoid specifying long messages outside the exception class

(TRY003)

Comment on lines +191 to 193
a_swizzle_mode = self._determinate_swizzle_mode(A_region, self.a_shared_layout)
b_swizzle_mode = self._determinate_swizzle_mode(B_region, self.b_shared_layout)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Pass actual Buffer into _determinate_swizzle_mode
_determinate_swizzle_mode expects a Buffer. Feeding it A_region / B_region (BufferRegion) blows up as soon as the layout helpers touch .shape. This means the refactored wgmma path now crashes before issuing any instruction. Please pass region.buffer instead.

-        a_swizzle_mode = self._determinate_swizzle_mode(A_region, self.a_shared_layout)
-        b_swizzle_mode = self._determinate_swizzle_mode(B_region, self.b_shared_layout)
+        a_swizzle_mode = self._determinate_swizzle_mode(A_region.buffer, self.a_shared_layout)
+        b_swizzle_mode = self._determinate_swizzle_mode(B_region.buffer, self.b_shared_layout)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
a_swizzle_mode = self._determinate_swizzle_mode(A_region, self.a_shared_layout)
b_swizzle_mode = self._determinate_swizzle_mode(B_region, self.b_shared_layout)
a_swizzle_mode = self._determinate_swizzle_mode(A_region.buffer, self.a_shared_layout)
b_swizzle_mode = self._determinate_swizzle_mode(B_region.buffer, self.b_shared_layout)
🤖 Prompt for AI Agents
In tilelang/intrinsics/wgmma_macro_generator.py around lines 191-193, the code
calls _determinate_swizzle_mode(A_region, ...) and
_determinate_swizzle_mode(B_region, ...) passing BufferRegion objects;
_determinate_swizzle_mode expects a Buffer and later accesses .shape which
causes a crash. Change those calls to pass the underlying Buffer (e.g.,
A_region.buffer and B_region.buffer) so the helper receives the correct type and
the layout helpers can access .shape.

Comment on lines +342 to 343
b_swizzle_mode = self._determinate_swizzle_mode(B_region, self.b_shared_layout)
b_swizzle_atom_elems = n_dim if b_swizzle_mode.is_none(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Same BufferRegion-to-Buffer fix needed here
The RS path is subject to the identical crash: _determinate_swizzle_mode still sees a BufferRegion. Please use the underlying buffer.

-        b_swizzle_mode = self._determinate_swizzle_mode(B_region, self.b_shared_layout)
+        b_swizzle_mode = self._determinate_swizzle_mode(B_region.buffer, self.b_shared_layout)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
b_swizzle_mode = self._determinate_swizzle_mode(B_region, self.b_shared_layout)
b_swizzle_atom_elems = n_dim if b_swizzle_mode.is_none(
b_swizzle_mode = self._determinate_swizzle_mode(B_region.buffer, self.b_shared_layout)
b_swizzle_atom_elems = n_dim if b_swizzle_mode.is_none(
🤖 Prompt for AI Agents
In tilelang/intrinsics/wgmma_macro_generator.py around lines 342 to 343, the
call to self._determinate_swizzle_mode is being passed a BufferRegion (B_region)
which causes the same crash as earlier; change the call to pass the underlying
buffer object from the region (e.g., B_region.buffer or equivalent attribute
used elsewhere) so _determinate_swizzle_mode receives a Buffer, and ensure
related downstream uses (like computing b_swizzle_atom_elems) still reference
the region where appropriate.

Comment on lines 558 to 581
if dtype is None:
raise ValueError("dtype must be provided when passing a pointer expression.")
inferred = None
# Case 1: Pointer from Buffer.access_ptr -> tir.builtin.tvm_access_ptr
if isinstance(data_ptr, Call) and data_ptr.op.same_as(tir.builtin.tvm_access_ptr()):
# args[0] is a type annotation call; its dtype carries the element dtype
inferred = str(data_ptr.args[0].dtype)
# Case 2: Pointer from tir.address_of(BufferLoad(...))
elif isinstance(data_ptr, Call) and data_ptr.op.same_as(tir.builtin.address_of()):
# args[0] should be a BufferLoad; its dtype is the element dtype
inferred = str(data_ptr.args[0].dtype)
# Case 3: Typed pointer Var with PrimType element (typed TIR)
elif hasattr(data_ptr, "type_annotation") and data_ptr.type_annotation is not None:
try:
elem_ty = getattr(data_ptr.type_annotation, "element_type", None)
if elem_ty is not None and hasattr(elem_ty, "dtype"):
inferred = str(elem_ty.dtype)
except Exception:
inferred = None
if inferred is None:
raise ValueError(
"dtype must be provided when passing a pointer expression and cannot be inferred."
)
dtype = inferred
if num_regs is None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Fix dtype inference for access_ptr pointers
When buffer_or_ptr comes from buffer.access_ptr(...), data_ptr.op is tvm_access_ptr, but data_ptr.args[0].dtype is just "handle"/"object", not the underlying element dtype. We end up fencing with "handle" which breaks any call that relies on automatic dtype inference (e.g. passing a pointer without an explicit dtype now regresses). Please extract the actual element dtype from the type-annotation payload instead of the wrapper call. For example:

-            if isinstance(data_ptr, Call) and data_ptr.op.same_as(tir.builtin.tvm_access_ptr()):
-                inferred = str(data_ptr.args[0].dtype)
+            if isinstance(data_ptr, Call) and data_ptr.op.same_as(tir.builtin.tvm_access_ptr()):
+                type_anno = data_ptr.args[0]
+                if isinstance(type_anno, Call) and type_anno.op.same_as(tir.builtin.type_annotation()):
+                    inferred = str(type_anno.args[0])
+                elif hasattr(type_anno, "value"):
+                    inferred = str(type_anno.value)

(Adapt as needed for other concrete node types.)

Committable suggestion skipped: line range outside the PR's diff.

🧰 Tools
🪛 Ruff (0.14.3)

574-574: Do not catch blind exception: Exception

(BLE001)


577-579: Avoid specifying long messages outside the exception class

(TRY003)

Comment on lines +32 to +47
def _get_stride_continuous(
buffer_or_load_or_region: Buffer | BufferLoad | BufferRegion) -> tuple[int, int]:
"""
Get stride (last 2nd dimension) and continuous (last dimension) from Buffer, BufferLoad, or BufferRegion.
Args:
buffer_or_load_or_region: Can be Buffer, BufferLoad, or BufferRegion
Returns:
tuple: (stride, continuous) as integers
"""
_, shape, _ = _get_buffer_info(buffer_or_load_or_region)
stride = int(shape[-2])
continuous = int(shape[-1])
return stride, continuous

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Do not cast PrimExpr to int; handle rank>=2 and keep symbolic dims

Casting shape[-2]/shape[-1] to int breaks when shapes are symbolic; also no guard for 1D buffers. Return PrimExprs and validate rank.

-def _get_stride_continuous(
-        buffer_or_load_or_region: Buffer | BufferLoad | BufferRegion) -> tuple[int, int]:
+def _get_stride_continuous(
+        buffer_or_load_or_region: Buffer | BufferLoad | BufferRegion) -> tuple[PrimExpr, PrimExpr]:
@@
-    _, shape, _ = _get_buffer_info(buffer_or_load_or_region)
-    stride = int(shape[-2])
-    continuous = int(shape[-1])
-    return stride, continuous
+    _, shape, _ = _get_buffer_info(buffer_or_load_or_region)
+    if len(shape) < 2:
+        raise ValueError("Buffer rank must be >= 2 for swizzled layouts")
+    stride = shape[-2]
+    continuous = shape[-1]
+    return stride, continuous
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
def _get_stride_continuous(
buffer_or_load_or_region: Buffer | BufferLoad | BufferRegion) -> tuple[int, int]:
"""
Get stride (last 2nd dimension) and continuous (last dimension) from Buffer, BufferLoad, or BufferRegion.
Args:
buffer_or_load_or_region: Can be Buffer, BufferLoad, or BufferRegion
Returns:
tuple: (stride, continuous) as integers
"""
_, shape, _ = _get_buffer_info(buffer_or_load_or_region)
stride = int(shape[-2])
continuous = int(shape[-1])
return stride, continuous
def _get_stride_continuous(
buffer_or_load_or_region: Buffer | BufferLoad | BufferRegion) -> tuple[PrimExpr, PrimExpr]:
"""
Get stride (last 2nd dimension) and continuous (last dimension) from Buffer, BufferLoad, or BufferRegion.
Args:
buffer_or_load_or_region: Can be Buffer, BufferLoad, or BufferRegion
Returns:
tuple: (stride, continuous) as integers
"""
_, shape, _ = _get_buffer_info(buffer_or_load_or_region)
if len(shape) < 2:
raise ValueError("Buffer rank must be >= 2 for swizzled layouts")
stride = shape[-2]
continuous = shape[-1]
return stride, continuous
🤖 Prompt for AI Agents
In tilelang/layout/swizzle.py around lines 32 to 47, the function currently
casts shape[-2] and shape[-1] to int which breaks for symbolic PrimExpr dims and
doesn't guard against 1D buffers; change the function to (1) validate that the
buffer rank is >= 2 and raise a clear ValueError if not, (2) stop casting to int
and return the raw PrimExprs for stride and continuous (update the return type
to tuple[PrimExpr, PrimExpr] and import PrimExpr if needed), and (3) keep the
rest of the logic unchanged so callers receive symbolic dimensions when present.

Comment on lines 78 to 84
A_region = self.ARegion
B_region = self.BRegion
C_region = self.CRegion

A_buf = A_region.buffer
C_buf = C_region.buffer

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Call GemmBase region accessors

self.ARegion, self.BRegion, and self.CRegion are methods on GemmBase. Taking them without parentheses leaves bound method objects, so the very next line (A_region.buffer) raises 'function' object has no attribute "buffer"' at runtime. Invoke the accessors before using the regions.

Apply this diff:

-        A_region = self.ARegion
-        B_region = self.BRegion
-        C_region = self.CRegion
+        A_region = self.ARegion()
+        B_region = self.BRegion()
+        C_region = self.CRegion()
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
A_region = self.ARegion
B_region = self.BRegion
C_region = self.CRegion
A_buf = A_region.buffer
C_buf = C_region.buffer
A_region = self.ARegion()
B_region = self.BRegion()
C_region = self.CRegion()
A_buf = A_region.buffer
C_buf = C_region.buffer
🤖 Prompt for AI Agents
In tilelang/tileop/gemm/gemm_mma_sm70.py around lines 78 to 84, the code assigns
self.ARegion, self.BRegion, and self.CRegion (bound method objects) to variables
without calling them, so accessing A_region.buffer raises "'function' object has
no attribute \"buffer\""; fix by invoking the accessors (call self.ARegion(),
self.BRegion(), self.CRegion()) when assigning A_region, B_region, and C_region
so the variables hold region objects before using their .buffer attributes.

Comment on lines 86 to 95
# We use region for memory input to support strided gemm
# T.gemm(A_shared[0:128, :], B_shared, C_local)
A_region = self.ARegion
B_region = self.BRegion
C_region = self.CRegion

A_buf = A_region.buffer
B_buf = B_region.buffer
C_buf = C_region.buffer

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Invoke region getters before dereferencing

Just like on SM70, GemmBase.ARegion/BRegion/CRegion are plain methods. Capturing them without calling leaves bound method objects, so A_region.buffer, B_region.buffer, etc. explode with 'function' object has no attribute "buffer"'. This breaks every lowering path.

Please update to:

-        A_region = self.ARegion
-        B_region = self.BRegion
-        C_region = self.CRegion
+        A_region = self.ARegion()
+        B_region = self.BRegion()
+        C_region = self.CRegion()
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
# We use region for memory input to support strided gemm
# T.gemm(A_shared[0:128, :], B_shared, C_local)
A_region = self.ARegion
B_region = self.BRegion
C_region = self.CRegion
A_buf = A_region.buffer
B_buf = B_region.buffer
C_buf = C_region.buffer
# We use region for memory input to support strided gemm
# T.gemm(A_shared[0:128, :], B_shared, C_local)
A_region = self.ARegion()
B_region = self.BRegion()
C_region = self.CRegion()
A_buf = A_region.buffer
B_buf = B_region.buffer
C_buf = C_region.buffer
🤖 Prompt for AI Agents
In tilelang/tileop/gemm/gemm_mma.py around lines 86 to 95, the code captures
ARegion/BRegion/CRegion methods without invoking them so
A_region/B_region/C_region are bound function objects and accessing .buffer
fails; call the getters (e.g., A_region = self.ARegion()) for A, B, and C before
accessing .buffer (and then use A_region.buffer, B_region.buffer,
C_region.buffer) so the actual region objects are used.

Comment on lines +234 to +253
def retrieve_stride(obj: Buffer | BufferRegion | BufferLoad) -> list:
"""
Retrieve row-major strides for a buffer-like object based on its buffer.shape.
For BufferRegion and BufferLoad, uses the underlying buffer's `shape`.
"""
if isinstance(obj, tir.Buffer):
shape = obj.shape
elif isinstance(obj, (tir.BufferRegion, tir.BufferLoad)):
shape = obj.buffer.shape
else:
raise ValueError(f"Unsupported retrieve_stride argument type: {type(obj)} for object {obj}")

strides = []
stride = 1
for s in reversed(shape):
strides.insert(0, stride)
stride *= s
return strides

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Honor explicit Buffer.strides when available; fall back to row-major

Current implementation assumes row-major strides. Use buffer.strides if defined; otherwise compute from shape.

-def retrieve_stride(obj: Buffer | BufferRegion | BufferLoad) -> list:
+def retrieve_stride(obj: Buffer | BufferRegion | BufferLoad) -> list:
@@
-    if isinstance(obj, tir.Buffer):
-        shape = obj.shape
-    elif isinstance(obj, (tir.BufferRegion, tir.BufferLoad)):
-        shape = obj.buffer.shape
+    if isinstance(obj, tir.Buffer):
+        buf = obj
+    elif isinstance(obj, (tir.BufferRegion, tir.BufferLoad)):
+        buf = obj.buffer
     else:
         raise ValueError(f"Unsupported retrieve_stride argument type: {type(obj)} for object {obj}")
 
-    strides = []
-    stride = 1
-    for s in reversed(shape):
-        strides.insert(0, stride)
-        stride *= s
-    return strides
+    # Prefer explicit strides if provided
+    if buf.strides and len(buf.strides) == len(buf.shape):
+        return list(buf.strides)
+    # Fallback: derive row-major strides from shape
+    strides: list[PrimExpr] = []
+    stride: PrimExpr = tir.IntImm("int32", 1)
+    for s in reversed(buf.shape):
+        strides.insert(0, stride)
+        stride = stride * s
+    return strides
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
def retrieve_stride(obj: Buffer | BufferRegion | BufferLoad) -> list:
"""
Retrieve row-major strides for a buffer-like object based on its buffer.shape.
For BufferRegion and BufferLoad, uses the underlying buffer's `shape`.
"""
if isinstance(obj, tir.Buffer):
shape = obj.shape
elif isinstance(obj, (tir.BufferRegion, tir.BufferLoad)):
shape = obj.buffer.shape
else:
raise ValueError(f"Unsupported retrieve_stride argument type: {type(obj)} for object {obj}")
strides = []
stride = 1
for s in reversed(shape):
strides.insert(0, stride)
stride *= s
return strides
def retrieve_stride(obj: Buffer | BufferRegion | BufferLoad) -> list:
"""
Retrieve row-major strides for a buffer-like object based on its buffer.shape.
For BufferRegion and BufferLoad, uses the underlying buffer's `shape`.
"""
if isinstance(obj, tir.Buffer):
buf = obj
elif isinstance(obj, (tir.BufferRegion, tir.BufferLoad)):
buf = obj.buffer
else:
raise ValueError(f"Unsupported retrieve_stride argument type: {type(obj)} for object {obj}")
# Prefer explicit strides if provided
if buf.strides and len(buf.strides) == len(buf.shape):
return list(buf.strides)
# Fallback: derive row-major strides from shape
strides: list[PrimExpr] = []
stride: PrimExpr = tir.IntImm("int32", 1)
for s in reversed(buf.shape):
strides.insert(0, stride)
stride = stride * s
return strides
🧰 Tools
🪛 Ruff (0.14.3)

245-245: Prefer TypeError exception for invalid type

(TRY004)


245-245: Avoid specifying long messages outside the exception class

(TRY003)

🤖 Prompt for AI Agents
In tilelang/utils/language.py around lines 234 to 253, the function always
computes row-major strides from shape; change it to use an explicit
buffer.strides when present and only compute row-major strides as a fallback.
Concretely: for tir.Buffer use buffer.strides if buffer.strides is not
None/empty and convert it to a plain list of ints to return; for
tir.BufferRegion and tir.BufferLoad read obj.buffer.strides first and fall back
to computing from obj.buffer.shape if strides are missing; keep the existing
reversed-shape computation logic as the fallback and raise the same ValueError
for unsupported types.

Comment on lines +255 to +283
def retrive_ptr_from_buffer_region(buffer_or_load_or_region: Buffer | BufferLoad | BufferRegion,
access_type: str = "r") -> PrimExpr:
if isinstance(buffer_or_load_or_region, Buffer):
return buffer_or_load_or_region.access_ptr(access_type)
elif isinstance(buffer_or_load_or_region, BufferLoad):
buffer_load = buffer_or_load_or_region
offset, stride = 0, 1
buffer = buffer_load.buffer
for i, shape in enumerate(reversed(buffer.shape)):
indice = buffer_load.indices[len(buffer_load.indices) - i - 1]
if isinstance(indice, (tir.IntImm, tir.PrimExpr)):
offset += indice * stride
elif isinstance(indice, tir.Ramp):
offset += indice.base * stride
else:
raise ValueError(f"Unsupported index type: {type(indice)}")
stride *= shape
return buffer.access_ptr(access_type, offset=offset)
elif isinstance(buffer_or_load_or_region, BufferRegion):
buffer_region = buffer_or_load_or_region
buffer = buffer_region.buffer
offset, stride = 0, 1
for i, shape in enumerate(reversed(buffer.shape)):
offset += buffer_region.region[len(buffer_region.region) - i - 1].min * stride
stride *= shape
return buffer.access_ptr(access_type, offset=offset)
else:
raise ValueError(f"Unsupported buffer type: {type(buffer_or_load_or_region)}")

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion | 🟠 Major

🧩 Analysis chain

Fix function name typo and provide a compatibility alias

Rename retrive_ptr_from_buffer_region -> retrieve_ptr_from_buffer_region and keep an alias to avoid breaking callers.

-def retrive_ptr_from_buffer_region(buffer_or_load_or_region: Buffer | BufferLoad | BufferRegion,
-                                   access_type: str = "r") -> PrimExpr:
+def retrieve_ptr_from_buffer_region(buffer_or_load_or_region: Buffer | BufferLoad | BufferRegion,
+                                    access_type: str = "r") -> PrimExpr:
@@
-    else:
-        raise ValueError(f"Unsupported buffer type: {type(buffer_or_load_or_region)}")
+    else:
+        raise TypeError(f"Unsupported buffer type: {type(buffer_or_load_or_region)}")
+
+# Backward-compat alias (to be deprecated)
+retrive_ptr_from_buffer_region = retrieve_ptr_from_buffer_region

Run to find potential external/internal references to the old name:


🏁 Script executed:

#!/bin/bash
rg -nP '\bretrive_ptr_from_buffer_region\b' -C2

Length of output: 2188


Fix typo in function name, error type, and add backward-compat alias

Rename retrive_ptr_from_buffer_regionretrieve_ptr_from_buffer_region, change the final exception to TypeError, and add a backward-compat alias:

-def retrive_ptr_from_buffer_region(buffer_or_load_or_region: Buffer | BufferLoad | BufferRegion,
-                                   access_type: str = "r") -> PrimExpr:
+def retrieve_ptr_from_buffer_region(buffer_or_load_or_region: Buffer | BufferLoad | BufferRegion,
+                                    access_type: str = "r") -> PrimExpr:
@@
-    else:
-        raise ValueError(f"Unsupported buffer type: {type(buffer_or_load_or_region)}")
+    else:
+        raise TypeError(f"Unsupported buffer type: {type(buffer_or_load_or_region)}")
+
+# Backward-compat alias
+retrive_ptr_from_buffer_region = retrieve_ptr_from_buffer_region

This preserves the existing __init__.py import and all calls in wgmma_macro_generator.py.

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
def retrive_ptr_from_buffer_region(buffer_or_load_or_region: Buffer | BufferLoad | BufferRegion,
access_type: str = "r") -> PrimExpr:
if isinstance(buffer_or_load_or_region, Buffer):
return buffer_or_load_or_region.access_ptr(access_type)
elif isinstance(buffer_or_load_or_region, BufferLoad):
buffer_load = buffer_or_load_or_region
offset, stride = 0, 1
buffer = buffer_load.buffer
for i, shape in enumerate(reversed(buffer.shape)):
indice = buffer_load.indices[len(buffer_load.indices) - i - 1]
if isinstance(indice, (tir.IntImm, tir.PrimExpr)):
offset += indice * stride
elif isinstance(indice, tir.Ramp):
offset += indice.base * stride
else:
raise ValueError(f"Unsupported index type: {type(indice)}")
stride *= shape
return buffer.access_ptr(access_type, offset=offset)
elif isinstance(buffer_or_load_or_region, BufferRegion):
buffer_region = buffer_or_load_or_region
buffer = buffer_region.buffer
offset, stride = 0, 1
for i, shape in enumerate(reversed(buffer.shape)):
offset += buffer_region.region[len(buffer_region.region) - i - 1].min * stride
stride *= shape
return buffer.access_ptr(access_type, offset=offset)
else:
raise ValueError(f"Unsupported buffer type: {type(buffer_or_load_or_region)}")
def retrieve_ptr_from_buffer_region(buffer_or_load_or_region: Buffer | BufferLoad | BufferRegion,
access_type: str = "r") -> PrimExpr:
if isinstance(buffer_or_load_or_region, Buffer):
return buffer_or_load_or_region.access_ptr(access_type)
elif isinstance(buffer_or_load_or_region, BufferLoad):
buffer_load = buffer_or_load_or_region
offset, stride = 0, 1
buffer = buffer_load.buffer
for i, shape in enumerate(reversed(buffer.shape)):
indice = buffer_load.indices[len(buffer_load.indices) - i - 1]
if isinstance(indice, (tir.IntImm, tir.PrimExpr)):
offset += indice * stride
elif isinstance(indice, tir.Ramp):
offset += indice.base * stride
else:
raise ValueError(f"Unsupported index type: {type(indice)}")
stride *= shape
return buffer.access_ptr(access_type, offset=offset)
elif isinstance(buffer_or_load_or_region, BufferRegion):
buffer_region = buffer_or_load_or_region
buffer = buffer_region.buffer
offset, stride = 0, 1
for i, shape in enumerate(reversed(buffer.shape)):
offset += buffer_region.region[len(buffer_region.region) - i - 1].min * stride
stride *= shape
return buffer.access_ptr(access_type, offset=offset)
else:
raise TypeError(f"Unsupported buffer type: {type(buffer_or_load_or_region)}")
# Backward-compat alias
retrive_ptr_from_buffer_region = retrieve_ptr_from_buffer_region
🧰 Tools
🪛 Ruff (0.14.3)

270-270: Prefer TypeError exception for invalid type

(TRY004)


270-270: Avoid specifying long messages outside the exception class

(TRY003)


282-282: Prefer TypeError exception for invalid type

(TRY004)


282-282: Avoid specifying long messages outside the exception class

(TRY003)

🤖 Prompt for AI Agents
In tilelang/utils/language.py around lines 255 to 283, rename the function
retrive_ptr_from_buffer_region to retrieve_ptr_from_buffer_region, change the
final raised exception from ValueError to TypeError for unsupported input types,
and add a backward-compatible alias so retrive_ptr_from_buffer_region points to
the new retrieve_ptr_from_buffer_region function (preserving existing
imports/calls).

- Refactored `GemmNode` to streamline the decision-making process for GEMM instruction selection.
- Added support for FP64 inputs in the MMA dispatcher, enabling new tensor operations.
- Introduced a new layout function for FP64 in `mma_layout.py` to facilitate shared memory storage.
- Updated `TensorCoreIntrinEmitter` to handle FP64 data types, including adjustments for micro tile dimensions and loading mechanisms.
- Enhanced utility functions to accommodate FP64 index mapping for shared memory operations.
Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 0

♻️ Duplicate comments (1)
src/op/gemm.cc (1)

121-125: Fix pointer offset when rebuilding access_ptr.

Identical to the issue flagged in gemm_py.cc: skipping the last two dimensions in the offset sum causes every reconstructed pointer to land at element (0,0) for typical 2D regions, so all GEMM tiles operate on the wrong data. The loop must include the minima of every dimension when computing offset.

Apply this diff:

-  for (int i = 0; i < ndim - 2; ++i) {
+  for (int i = 0; i < ndim; ++i) {
     offset = offset + region->region[i]->min * strides[i];
   }
🧹 Nitpick comments (6)
tilelang/intrinsics/mma_macro_generator.py (3)

144-168: Consider adding assertions for FP64 path.

The FP64 micro-size initialization logic is correct, but unlike the non-FP64 path (lines 153-156), the FP64 path lacks assertions to validate warp_row_tiles and warp_col_tiles constraints. For FP64 (m8n8k4), you should verify:

  • warp_row_tiles >= 8 and warp_row_tiles % 8 == 0
  • warp_col_tiles >= 8 and warp_col_tiles % 8 == 0

Apply this diff to add validation:

     if k_dim == 4:
         # fp64 path: m_dim must be 8, n_dim 8
         assert m_dim == 8, f"For fp64 MMA, m_dim must be 8, got {m_dim}"
+        assert warp_row_tiles >= 8, f"warp_row_tiles must be >= 8 for fp64, got {warp_row_tiles}"
+        assert warp_row_tiles % 8 == 0, f"warp_row_tiles must be divisible by 8 for fp64, got {warp_row_tiles}"
+        assert warp_col_tiles >= 8, f"warp_col_tiles must be >= 8 for fp64, got {warp_col_tiles}"
+        assert warp_col_tiles % 8 == 0, f"warp_col_tiles must be divisible by 8 for fp64, got {warp_col_tiles}"
         self.n_dim = 8
         self.micro_size_y = 8
         self.warp_rows = warp_row_tiles // m_dim
         self.warp_cols = warp_col_tiles // 8

234-270: FP64 fast path implementation looks correct, but consider the unused parameter.

The FP64 fast path correctly implements direct per-lane loads since PTX ldmatrix doesn't support FP64. The index calculations (mi = tx // micro_size_k, mk = tx % micro_size_k) and handling of both transposed/non-transposed cases appear correct.

However, static analysis correctly identifies that the A_shared_buf parameter in the inner macro _warp_ld_a_fp64 (line 254) is unused. The macro only uses A_local_buf, ki, thread_binding, and rk, while the actual buffer access uses the captured A_buf and base offsets.

If the parameter is kept for API consistency with other macro variants, consider adding a comment explaining this. Otherwise, remove it from the signature.

Apply this diff to remove the unused parameter:

         @T.macro
         def _warp_ld_a_fp64(
             A_local_buf,
-            A_shared_buf,
             ki,
             thread_binding,
             rk=0,
         ):

And update the return statement:

-        return _warp_ld_a_fp64(A_local_buf, A_region, ki, thread_binding, rk)
+        return _warp_ld_a_fp64(A_local_buf, ki, thread_binding, rk)

351-387: FP64 fast path is correct, but has an unused parameter.

Similar to ldmatrix_a, the FP64 path for matrix B correctly implements direct per-lane loads. The index calculations and transpose handling are appropriate.

The same unused parameter issue exists: B_shared_buf in the inner macro _warp_ld_b_fp64 (line 371) is not used. Consider removing it or documenting why it's kept for API consistency.

Apply this diff:

         @T.macro
         def _warp_ld_b_fp64(
             B_local_buf,
-            B_shared_buf,
             ki,
             thread_binding,
             rk=0,
         ):

And update the return:

-        return _warp_ld_b_fp64(B_local_buf, B_region, ki, thread_binding, rk)
+        return _warp_ld_b_fp64(B_local_buf, ki, thread_binding, rk)
tilelang/intrinsics/mma_layout.py (1)

48-51: Logic is correct; consider adding documentation.

The math for the 8×8 FP64 layout mapping is correct:

  • 32 threads with 2 local elements each map to a 64-element (8×8) shared memory tile
  • Row assignment groups every 4 threads together (rows 0-7)
  • Column assignment interleaves within thread groups (cols 0-7)

Consider adding a docstring to explain the layout strategy and expected parameter ranges (thread_id: 0-31, local_id: 0-1).

Optional: Add docstring for clarity.

 def mma_store_32x2_to_shared_8x8_layout_fp64(thread_id, local_id):
+    """
+    Map thread_id and local_id to shared memory layout for FP64 MMA store.
+    
+    Maps 32 threads × 2 local elements to an 8×8 shared memory tile.
+    
+    Args:
+        thread_id: Thread index within warp (0-31)
+        local_id: Local element index (0-1)
+    
+    Returns:
+        tuple: (row, col) indices in the 8×8 shared memory layout
+    """
     row = thread_id // 4
     col = (thread_id % 4) * 2 + local_id
     return row, col
src/op/gemm.cc (2)

87-89: Minor: prefer .at() for consistency with gemm_py.cc.

Line 89 uses vmap[var] while the equivalent code in gemm_py.cc:89 uses vmap.at(var). Using .at() is more explicit about lookup semantics and provides clearer error messages when a key is missing.

Apply this diff:

-      Buffer buf = vmap[var];
+      Buffer buf = vmap.at(var);

159-164: Update stale documentation comment.

The comment at line 47 states that kPack "must be 1", but the code at line 161 accepts both 1 and 2. The comment should be updated to reflect the actual constraint.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between c0c45d6 and 7aeb963.

📒 Files selected for processing (5)
  • src/op/gemm.cc (9 hunks)
  • src/tl_templates/cuda/instruction/mma.h (1 hunks)
  • tilelang/intrinsics/mma_layout.py (1 hunks)
  • tilelang/intrinsics/mma_macro_generator.py (15 hunks)
  • tilelang/intrinsics/utils.py (2 hunks)
🧰 Additional context used
🧠 Learnings (2)
📚 Learning: 2025-09-15T10:51:06.985Z
Learnt from: botbw
Repo: tile-ai/tilelang PR: 691
File: src/tl_templates/cuda/gemm_sp_sm80.h:81-85
Timestamp: 2025-09-15T10:51:06.985Z
Learning: In CUTLASS tensor operation layouts, crosswise constants should be computed using sizeof(T) (bytes), not cutlass::sizeof_bits<T>::value (bits). This is the established pattern in the official CUTLASS codebase, as seen in default_mma_core_sparse_sm80.h.

Applied to files:

  • src/op/gemm.cc
📚 Learning: 2025-09-15T10:51:06.985Z
Learnt from: botbw
Repo: tile-ai/tilelang PR: 691
File: src/tl_templates/cuda/gemm_sp_sm80.h:81-85
Timestamp: 2025-09-15T10:51:06.985Z
Learning: In CUTLASS tensor operation layouts, crosswise constants should be computed using sizeof(T) (bytes), not cutlass::sizeof_bits<T>::value (bits). However, the layout template parameter should use sizeof_bits<T>::value (bits). This is the established pattern in the official CUTLASS codebase, as seen in default_mma_core_sparse_sm80.h where Crosswise uses sizeof(ElementA) but the layout template uses sizeof_bits<ElementA>::value.

Applied to files:

  • src/op/gemm.cc
🧬 Code graph analysis (3)
tilelang/intrinsics/utils.py (1)
tilelang/intrinsics/mma_layout.py (1)
  • mma_store_32x2_to_shared_8x8_layout_fp64 (48-51)
tilelang/intrinsics/mma_macro_generator.py (3)
tilelang/intrinsics/utils.py (3)
  • mma_store_index_map (82-83)
  • get_ldmatrix_offset (22-64)
  • mma_store_index_map_fp64 (86-87)
tilelang/utils/language.py (2)
  • is_fragment (105-116)
  • to_buffer_region (187-211)
tilelang/intrinsics/mma_sm70_macro_generator.py (4)
  • get_thread_binding (139-145)
  • extract_thread_binding (158-188)
  • _warp_ldmatrix_a (216-230)
  • ldmatrix_b (234-284)
src/op/gemm.cc (5)
tilelang/language/utils.py (1)
  • region (8-27)
tilelang/language/tir/op.py (1)
  • tvm_access_ptr (651-676)
src/op/gemm_py.cc (13)
  • strides (85-85)
  • NormalizeToBufferRegion (25-72)
  • NormalizeToBufferRegion (25-26)
  • allowTcgen5Mma (192-199)
  • allowTcgen5Mma (192-192)
  • allowWgmma (201-209)
  • allowWgmma (201-201)
  • getGemmInst (211-229)
  • getGemmInst (211-211)
  • MakeAccessPtrFromRegion (78-108)
  • MakeAccessPtrFromRegion (78-79)
  • GetArchInt (328-339)
  • GetArchInt (328-328)
src/op/gemm.h (5)
  • Gemm (144-149)
  • GemmWarpPolicy (59-83)
  • GemmWarpPolicy (64-68)
  • GemmWarpPolicy (70-74)
  • GemmWarpPolicy (76-82)
src/op/operator.cc (2)
  • GetVarFromAccessPtr (74-81)
  • GetVarFromAccessPtr (74-74)
🪛 Ruff (0.14.3)
tilelang/intrinsics/mma_macro_generator.py

254-254: Unused function argument: A_shared_buf

(ARG001)


371-371: Unused function argument: B_shared_buf

(ARG001)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
  • GitHub Check: Test for Python 3.12 with ROCm-6.3 (on self-hosted-amd)
  • GitHub Check: Test for Python 3.12 with CUDA-12.8 (on self-hosted-nvidia)
🔇 Additional comments (14)
src/tl_templates/cuda/instruction/mma.h (1)

139-141: LGTM! FP64 DMMA dispatcher implementation verified and safe.

The cute::SM80_8x8x4_F64F64F64F64_TN implementation is already in use throughout the codebase in src/tl_templates/cuda/gemm_mma.h (lines 55, 68, 80, 92, 100). The code addition follows the exact same established pattern with no new dependencies introduced. CuTe is managed as a git submodule and is a proven dependency.

tilelang/intrinsics/mma_macro_generator.py (8)

6-7: LGTM!

The import additions for BufferRegion and to_buffer_region are appropriate for the BufferRegion integration feature. These enable the methods to accept both Buffer and BufferRegion as inputs.

Also applies to: 13-13


44-44: LGTM!

The FP64 dtype abbreviation is correctly added and follows the established pattern for other data types.


83-87: LGTM!

The FP64 initialization correctly overrides M_DIM to 8 for the m8n8k4 MMA shape. The comment clearly explains the rationale.


126-128: LGTM!

The FP64 MMA prefix "m8n8k4" is correctly set for k_dim == 4, matching the FP64 matrix multiplication instruction shape.


186-191: LGTM!

The conditional selection of the FP64-specific store index map (mma_store_index_map_fp64) based on the accumulator dtype is correct. This accounts for the different memory layout of m8n8k4 FP64 MMA results.


296-343: LGTM!

The BufferRegion integration is well-implemented:

  • Properly normalizes inputs using to_buffer_region
  • Correctly extracts base offsets from the last two dimensions (region[-2].min, region[-1].min)
  • Updates both ldmatrix and fallback paths to use region-based addressing
  • Handles transposed and non-transposed layouts correctly

The multi-line buffer element assignments (lines 318-320, 336-341) improve readability for complex indexing expressions.


399-467: LGTM!

The BufferRegion integration for matrix B mirrors the approach in ldmatrix_a and is correctly implemented. Region-based addressing is properly applied to both the ldmatrix path (lines 440-442) and the fallback path (lines 460-465), with correct handling of transposed and non-transposed layouts.


753-754: LGTM!

The improved assertion message now includes the buffer name and actual scope, making debugging easier when a non-fragment buffer is incorrectly passed.

tilelang/intrinsics/utils.py (2)

11-11: LGTM: Import added correctly.

The import for the FP64 layout function is properly placed and formatted.


86-87: New FP64 wrapper function is unused; verify integration plan with mma_macro_generator.py.

The function is syntactically correct and properly delegates to mma_store_32x2_to_shared_8x8_layout_fp64, but it's not called anywhere in the codebase. The standard mma_store_index_map is actively used in mma_macro_generator.py (lines 558, 576), but there are no corresponding FP64 calls. Confirm whether mma_macro_generator.py should be updated to support FP64 variants, or if this wrapper is preparatory work not yet ready for integration.

src/op/gemm.cc (3)

192-224: LGTM: method renames and refactoring.

The method renames (allowTcgen5Mma, allowWgmma, getGemmInst, checkWgmma) follow consistent camelCase naming conventions, and the refactored implementation correctly uses the new region-based fields (a_, b_, c_, transA_, transB_, m_, n_, k_, etc.).


538-540: These pointers will have incorrect offsets until MakeAccessPtrFromRegion is fixed.

The calls to MakeAccessPtrFromRegion on these lines will produce incorrect offsets due to the bug flagged in the review comment for lines 121-125. Once that bug is fixed, these calls will work correctly.


683-904: LGTM: InferLayout refactoring.

The InferLayout method correctly uses all renamed fields (c_, a_, b_, m_, n_, k_, transA_, transB_, kPack_, etc.) consistently across all target architectures. The refactoring is complete and the logic appears sound.

…ling

- Reverted the GEMM function call in `correctness_evaluation.py` to the original implementation for consistency.
- Added a helper function in `merge_shared_memory_allocations.cc` to streamline the marking of shared variables under alignment scope.
- Enhanced the `VisitExpr_` methods to ensure proper handling of shared memory alignment for `BufferLoadNode` and `VarNode` types.
- Cleaned up commented-out test code in `correctness_evaluation.py` for better readability.
Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

♻️ Duplicate comments (2)
tilelang/intrinsics/wgmma_macro_generator.py (2)

191-192: Fix BufferRegion-to-Buffer mismatch (still unfixed from previous review).

_determinate_swizzle_mode expects a Buffer and accesses .shape (line 152), but you're passing A_region and B_region which are BufferRegion objects. This will raise an AttributeError at runtime.

Apply this diff to pass the underlying buffers:

-        a_swizzle_mode = self._determinate_swizzle_mode(A_region, self.a_shared_layout)
-        b_swizzle_mode = self._determinate_swizzle_mode(B_region, self.b_shared_layout)
+        a_swizzle_mode = self._determinate_swizzle_mode(A_region.buffer, self.a_shared_layout)
+        b_swizzle_mode = self._determinate_swizzle_mode(B_region.buffer, self.b_shared_layout)

Based on past review comments.


344-344: Fix BufferRegion-to-Buffer mismatch in wgmma_rs (still unfixed from previous review).

Same issue as in wgmma: _determinate_swizzle_mode expects a Buffer but receives a BufferRegion. This will crash when the method accesses .shape.

Apply this diff:

-        b_swizzle_mode = self._determinate_swizzle_mode(B_region, self.b_shared_layout)
+        b_swizzle_mode = self._determinate_swizzle_mode(B_region.buffer, self.b_shared_layout)

Based on past review comments.

🧹 Nitpick comments (4)
tilelang/intrinsics/wgmma_macro_generator.py (2)

274-274: Prefix unused variable with underscore.

The extracted tx variable is never used. Consider prefixing it with an underscore to indicate it's intentionally unused.

Apply this diff:

-            tx, warp_n, warp_m = self.extract_thread_binding(thread_binding)
+            _tx, warp_n, warp_m = self.extract_thread_binding(thread_binding)

388-388: Prefix unused variables with underscores.

The extracted tx and warp_m variables are never used in the macro. Consider prefixing them with underscores to indicate they're intentionally unused.

Apply this diff:

-            tx, warp_n, warp_m = self.extract_thread_binding(thread_binding)
+            _tx, warp_n, _warp_m = self.extract_thread_binding(thread_binding)
src/transform/merge_shared_memory_allocations.cc (1)

357-372: LGTM! Well-structured helper for alignment marking.

The helper correctly guards against null pointers and non-alignment scopes, properly extracts storage scope information, and applies appropriate alignment based on the target architecture.

Consider extracting the hardcoded alignment values as named constants for better maintainability:

+private:
+  static constexpr int kHopperAlignment = 1024;
+  static constexpr int kDefaultAlignment = 16;
+
   // Helper to record alignment for a shared/shared.dyn Var under alignment
   // scope
   void MarkSharedVarIfNeeded(const VarNode *op) {
     if (!op || !under_alignment_scope_)
       return;
     auto ptr_type = op->type_annotation.as<PointerTypeNode>();
     if (!ptr_type)
       return;
     auto scope = GetPtrStorageScope(tvm::ffi::GetRef<Var>(op));
     if (scope == "shared" || scope == "shared.dyn") {
       auto target = Target::Current();
       ICHECK(target.defined()) << "Target is not defined";
-      const int alignment = TargetIsHopper(target) ? 1024 : 16;
+      const int alignment = TargetIsHopper(target) ? kHopperAlignment : kDefaultAlignment;
       shmem_alignment_map_[op] = alignment;
     }
   }
maint/gemm_v2/correctness_evaluation.py (1)

726-735: Consider removing commented test code.

The commented-out test blocks at the end of the file appear to be development or debugging artifacts. If they're no longer needed, consider removing them to keep the codebase clean. If they serve a purpose (e.g., manual testing scenarios), consider moving them to a separate debug script or adding a comment explaining why they're preserved.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 7aeb963 and 27ba821.

📒 Files selected for processing (3)
  • maint/gemm_v2/correctness_evaluation.py (7 hunks)
  • src/transform/merge_shared_memory_allocations.cc (2 hunks)
  • tilelang/intrinsics/wgmma_macro_generator.py (10 hunks)
🧰 Additional context used
🧬 Code graph analysis (3)
src/transform/merge_shared_memory_allocations.cc (2)
tilelang/language/builtin.py (3)
  • tma_load (86-95)
  • initialize_wgmma_descriptor (703-730)
  • initialize_tcgen05_descriptor (733-764)
src/transform/storage_access.cc (4)
  • VisitExpr_ (40-62)
  • VisitExpr_ (40-40)
  • VisitExpr_ (314-453)
  • VisitExpr_ (314-314)
tilelang/intrinsics/wgmma_macro_generator.py (2)
tilelang/utils/language.py (3)
  • is_fragment (105-116)
  • retrive_ptr_from_buffer_region (255-282)
  • is_full_region (370-399)
tilelang/language/builtin.py (2)
  • initialize_wgmma_descriptor (703-730)
  • warpgroup_fence_operand (433-592)
maint/gemm_v2/correctness_evaluation.py (2)
maint/gemm_v2/correctness_evaluation_sm70.py (1)
  • run_gemm_rs (180-210)
testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py (3)
  • run_gemm_rs (191-242)
  • run_gemm_sr (326-377)
  • run_gemm_rr (465-516)
🪛 Ruff (0.14.3)
tilelang/intrinsics/wgmma_macro_generator.py

274-274: Unpacked variable tx is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


388-388: Unpacked variable tx is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


388-388: Unpacked variable warp_m is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: Test for Python 3.12 with ROCm-6.3 (on self-hosted-amd)
  • GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
🔇 Additional comments (8)
tilelang/intrinsics/wgmma_macro_generator.py (3)

7-8: LGTM: Region-based imports added correctly.

The imports properly support the BufferRegion refactoring, including the helper function retrive_ptr_from_buffer_region which is now imported from utils instead of being defined locally.


266-271: LGTM: Correct region-to-pointer/buffer extraction.

The pointer extraction and buffer setup properly utilize the imported helpers. The full-region assertion for the output fragment is appropriate.


380-385: LGTM: Correct region assertions and buffer extraction.

The full-region assertions for fragment inputs/outputs and the subsequent buffer/pointer extraction are properly implemented.

src/transform/merge_shared_memory_allocations.cc (3)

374-387: LGTM! Correct intrinsic updates for alignment scope detection.

The updated condition properly recognizes the normalized descriptor initialization intrinsics (initialize_wgmma_descriptor and initialize_tcgen05_descriptor) that require stricter shared memory alignment. This aligns with the refactoring to use more generic intrinsics across GEMM paths.


389-392: LGTM! Correct handling of direct variable references.

The visitor properly marks shared memory variables when they're directly referenced within an alignment scope before delegating to the base class.


394-402: LGTM! Proper handling of BufferLoad nodes.

The new visitor correctly handles the case where buffers are accessed via BufferLoad by marking the underlying data variable for alignment when within an alignment scope. The explanatory comment clearly describes the intent, and the implementation properly delegates to the base visitor.

maint/gemm_v2/correctness_evaluation.py (2)

49-49: LGTM! Kernel selection now respects environment variable.

The change from T.gemm_v2 to T.gemm in the basic matmul variant allows the system to select between GEMM v1 and v2 based on the TILELANG_USE_GEMM_V1 environment variable, which aligns with the PR objectives. The other matmul variants (rs, sr, rr) correctly continue to use T.gemm_v2 explicitly.


439-483: LGTM! Cleaner code by relying on defaults.

The wrapper functions have been appropriately simplified to rely on default parameter values. This makes the code more maintainable and ensures consistency with the updated defaults.

Comment on lines +105 to +109
num_stages=2,
num_threads=128,
):
if block_N >= 256 or block_M >= 256 or block_K >= 256:
num_stages = 0
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion | 🟠 Major

🧩 Analysis chain

Document the num_stages override logic and verify consistency.

Three concerns:

  1. Default change inconsistency: The num_stages default is changed from 3 to 2 here, but the relevant code snippets show that similar files (correctness_evaluation_sm70.py and test_tilelang_tilelibrary_gemm.py) still use the default value of 3. Verify this divergence is intentional.

  2. Undocumented conditional override: The logic that sets num_stages=0 when any block dimension is >= 256 completely disables pipelining for large blocks, but there's no comment explaining why. Add a comment describing the rationale (e.g., "Disable pipelining for large blocks to avoid resource constraints" or similar).

  3. Code duplication: This conditional logic is duplicated in run_gemm, run_gemm_rs, run_gemm_sr, and run_gemm_rr. Consider extracting it into a helper function.


🏁 Script executed:

#!/bin/bash
# Description: Check if other test files have inconsistent num_stages defaults

rg -n "num_stages\s*=\s*[0-9]" --type=py -g "*gemm*" -g "*correctness*" -C2

Length of output: 31867


Add documentation for num_stages logic and verify if default change is intentional.

The script output confirms all three concerns:

  1. Verified inconsistency: correctness_evaluation.py uses num_stages=2 (lines 105, 193, 279, 369), while correctness_evaluation_sm70.py and test_tilelang_tilelibrary_gemm.py use num_stages=3. Confirm whether this divergence is intentional.

  2. Verified lack of documentation: The conditional if block_N >= 256 or block_M >= 256 or block_K >= 256: num_stages = 0 appears identically at lines 108–109, 196–197, 282–283, and 372–373 with no explanatory comments. Add inline documentation describing the rationale (e.g., resource constraints, memory limitations, or performance tuning).

  3. Verified code duplication: This conditional block is duplicated across run_gemm, run_gemm_rs, run_gemm_sr, and run_gemm_rr. Extract into a helper function to reduce duplication and improve maintainability.

🤖 Prompt for AI Agents
maint/gemm_v2/correctness_evaluation.py lines 105-109: The num_stages default
and repeated conditional are inconsistent and undocumented; confirm whether
default num_stages=2 (here and at lines 193, 279, 369) should match other files
that use 3, add an inline comment explaining why we force num_stages=0 when any
block dimension >=256 (e.g., register/shared-memory/resource limits or
performance reasons), and remove duplication by extracting the conditional into
a single helper (e.g., determine_num_stages(block_M, block_N, block_K,
default_num_stages)) used by run_gemm, run_gemm_rs, run_gemm_sr, and run_gemm_rr
so all call the helper and the rationale is documented in the helper's
docstring.

- Updated GEMM and MMA classes to utilize BufferRegion for input and output buffers, improving memory management and supporting strided GEMM operations.
- Added checks to ensure full region compliance for input buffers, enhancing correctness in matrix multiplication.
- Implemented clear accumulation functionality to reset output buffers before accumulation, ensuring accurate results in GEMM operations.
…cture and function calls

- Updated import statements to directly reference modules instead of individual test functions, enhancing clarity.
- Modified function calls to use the new module structure for better organization and maintainability in testing examples.
- Updated the OnArrayDeclaration method to merge metadata for buffers that may appear in multiple Allocate statements, improving robustness against upstream transformations.
- Added logic to prefer concrete element data types and record extents when previously unknown, enhancing the handling of buffer declarations.
Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

♻️ Duplicate comments (2)
tilelang/tileop/gemm/gemm_mma_sm70.py (1)

78-85: Call region accessor methods before dereferencing.

self.ARegion, self.BRegion, and self.CRegion are methods on GemmBase. Taking them without parentheses leaves bound method objects, so the very next line (A_region.buffer) raises AttributeError: 'function' object has no attribute 'buffer' at runtime.

Apply this diff:

-        A_region = self.ARegion
-        B_region = self.BRegion
-        C_region = self.CRegion
+        A_region = self.ARegion()
+        B_region = self.BRegion()
+        C_region = self.CRegion()
tilelang/tileop/gemm/gemm_mma.py (1)

88-96: Invoke region getters before dereferencing.

Just like the other GEMM files, GemmBase.ARegion/BRegion/CRegion are methods. Capturing them without calling leaves bound method objects, so A_region.buffer, B_region.buffer, etc. will raise AttributeError: 'function' object has no attribute 'buffer' at runtime.

Apply this diff:

-        A_region = self.ARegion
-        B_region = self.BRegion
-        C_region = self.CRegion
+        A_region = self.ARegion()
+        B_region = self.BRegion()
+        C_region = self.CRegion()
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 27ba821 and 15035cd.

📒 Files selected for processing (6)
  • examples/deepseek_v32/test_tilelang_example_deepseek_v32.py (1 hunks)
  • src/transform/storage_rewrite.cc (1 hunks)
  • tilelang/intrinsics/mfma_macro_generator.py (7 hunks)
  • tilelang/tileop/gemm/gemm_mfma.py (6 hunks)
  • tilelang/tileop/gemm/gemm_mma.py (6 hunks)
  • tilelang/tileop/gemm/gemm_mma_sm70.py (4 hunks)
🧰 Additional context used
🧬 Code graph analysis (6)
src/transform/storage_rewrite.cc (1)
src/transform/merge_shared_memory_allocations.cc (2)
  • buffer (532-551)
  • buffer (532-532)
tilelang/tileop/gemm/gemm_mfma.py (4)
tilelang/utils/language.py (3)
  • is_shared (45-60)
  • is_fragment (105-116)
  • is_full_region (370-399)
tilelang/tileop/gemm/gemm_base.py (7)
  • ARegion (79-80)
  • BRegion (83-84)
  • CRegion (87-88)
  • clear_accum (107-108)
  • is_gemm_sr (24-25)
  • is_gemm_rs (27-28)
  • is_gemm_rr (30-31)
tilelang/language/fill.py (1)
  • clear (50-74)
tilelang/intrinsics/mfma_macro_generator.py (5)
  • ldmatrix_a (254-298)
  • ldmatrix_a (711-784)
  • ldmatrix_b (300-349)
  • ldmatrix_b (786-861)
  • mfma (351-394)
tilelang/intrinsics/mfma_macro_generator.py (3)
tilelang/utils/language.py (1)
  • to_buffer_region (187-211)
tilelang/intrinsics/mma_macro_generator.py (6)
  • ldmatrix_a (229-343)
  • ldmatrix_a (887-991)
  • _warp_ldmatrix_a (304-341)
  • _warp_ldmatrix_a (901-989)
  • ldmatrix_b (345-467)
  • ldmatrix_b (993-1105)
tilelang/intrinsics/mma_sm70_macro_generator.py (3)
  • ldmatrix_a (190-232)
  • _warp_ldmatrix_a (216-230)
  • ldmatrix_b (234-284)
tilelang/tileop/gemm/gemm_mma_sm70.py (4)
tilelang/utils/language.py (3)
  • is_shared (45-60)
  • is_fragment (105-116)
  • is_full_region (370-399)
tilelang/tileop/gemm/gemm_base.py (5)
  • ARegion (79-80)
  • BRegion (83-84)
  • CRegion (87-88)
  • clear_accum (107-108)
  • is_gemm_rs (27-28)
tilelang/language/fill.py (1)
  • clear (50-74)
tilelang/intrinsics/mma_sm70_macro_generator.py (3)
  • ldmatrix_a (190-232)
  • ldmatrix_b (234-284)
  • mma (286-327)
examples/deepseek_v32/test_tilelang_example_deepseek_v32.py (5)
examples/deepseek_v32/sparse_mla_fwd.py (2)
  • sparse_mla_fwd (15-174)
  • test_sparse_mla_fwd (253-299)
examples/deepseek_v32/sparse_mla_fwd_pipelined.py (2)
  • sparse_mla_fwd (18-311)
  • test_sparse_mla_fwd_pipelined (400-452)
examples/deepseek_v32/sparse_mla_bwd.py (2)
  • sparse_mla_bwd (283-320)
  • test_sparse_mla_bwd (334-384)
examples/deepseek_v32/topk_selector.py (1)
  • test_topk_selector (188-245)
examples/deepseek_v32/fp8_lighting_indexer.py (1)
  • test_fp8_lighting_indexer (260-302)
tilelang/tileop/gemm/gemm_mma.py (5)
tilelang/utils/language.py (3)
  • is_shared (45-60)
  • is_fragment (105-116)
  • is_full_region (370-399)
tilelang/tileop/gemm/gemm_base.py (7)
  • ARegion (79-80)
  • BRegion (83-84)
  • CRegion (87-88)
  • clear_accum (107-108)
  • is_gemm_sr (24-25)
  • is_gemm_rs (27-28)
  • is_gemm_rr (30-31)
tilelang/language/fill.py (1)
  • clear (50-74)
tilelang/intrinsics/mma_macro_generator.py (8)
  • ldmatrix_a (229-343)
  • ldmatrix_a (887-991)
  • ldmatrix_b (345-467)
  • ldmatrix_b (993-1105)
  • mma (469-529)
  • mma (1107-1156)
  • mma (1161-1259)
  • mma (1264-1363)
tilelang/tileop/gemm/gemm_mfma.py (7)
  • _gemm_ssr (107-133)
  • is_gemm_sr (224-225)
  • _gemm_srr (142-163)
  • is_gemm_rs (227-228)
  • _gemm_rsr (174-193)
  • _gemm_rsr (203-212)
  • is_gemm_rr (230-231)
🪛 Ruff (0.14.3)
tilelang/intrinsics/mfma_macro_generator.py

293-293: Ambiguous variable name: l

(E741)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
  • GitHub Check: Test for Python 3.12 with CUDA-12.8 (on self-hosted-nvidia)
  • GitHub Check: Test for Python 3.12 with ROCm-6.3 (on self-hosted-amd)
🔇 Additional comments (8)
examples/deepseek_v32/test_tilelang_example_deepseek_v32.py (2)

4-8: Clean refactor to module-level imports.

The change from direct function imports to module-level imports improves code clarity by making it explicit where each test function originates. The refactor is consistent across all five modules.


12-38: Function calls correctly updated to qualified access pattern.

All test function invocations have been consistently updated to use the module namespace (e.g., topk_selector.test_topk_selector()), aligning with the import style changes. The arguments passed to each function match the signatures verified in the codebase.

src/transform/storage_rewrite.cc (1)

1428-1451: LGTM! Merge logic supports multi-pass region-based buffer handling.

The merge logic correctly handles repeated buffer_var declarations by:

  • Promoting handle dtypes to concrete types when available
  • Updating extents if previously unknown
  • Combining declaration locations via bitwise OR

This aligns well with the broader refactor enabling region-based buffer representations across GEMM paths.

tilelang/tileop/gemm/gemm_mfma.py (3)

102-102: Good defensive check for fragment output.

The assertion ensures C_region is a full region before proceeding, preventing subtle bugs with partial region outputs.


115-116: Accumulator clearing logic is correctly gated.

The conditional T.clear(C_buf) when clear_accum is true provides flexibility for callers to control accumulator initialization, and is consistently applied across all GEMM paths.


139-139: Fragment input validation is appropriate.

The assertion that fragment inputs must be full regions (lines 139, 171, 199-200) prevents misuse and ensures the MMA intrinsics receive complete operands.

tilelang/intrinsics/mfma_macro_generator.py (2)

254-270: Region-based buffer handling is correctly implemented.

The updated signature A_shared_buf: Buffer | BufferRegion and the normalization via to_buffer_region() correctly supports both plain buffers and buffer regions, extracting base offsets for strided access. This maintains backward compatibility while enabling region-based GEMM operations.


287-296: Buffer indexing correctly applies region offsets.

The buffer accesses using A_buf[A_base0 + l + row, A_base1 + r + col] correctly incorporate the region's base offsets, enabling proper strided GEMM support.

Comment on lines +88 to +96
# We use region for memory input to support strided gemm
# T.gemm(A_shared[0:128, :], B_shared, C_local)
A_region = self.ARegion
B_region = self.BRegion
C_region = self.CRegion

A_buf = A_region.buffer
B_buf = B_region.buffer
C_buf = C_region.buffer
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Call region accessor methods before dereferencing.

self.ARegion, self.BRegion, and self.CRegion are methods (per GemmBase), not properties. Assigning them without parentheses captures bound method objects, so the next lines accessing .buffer will raise AttributeError: 'function' object has no attribute 'buffer' at runtime.

Apply this diff:

-        A_region = self.ARegion
-        B_region = self.BRegion
-        C_region = self.CRegion
+        A_region = self.ARegion()
+        B_region = self.BRegion()
+        C_region = self.CRegion()
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
# We use region for memory input to support strided gemm
# T.gemm(A_shared[0:128, :], B_shared, C_local)
A_region = self.ARegion
B_region = self.BRegion
C_region = self.CRegion
A_buf = A_region.buffer
B_buf = B_region.buffer
C_buf = C_region.buffer
# We use region for memory input to support strided gemm
# T.gemm(A_shared[0:128, :], B_shared, C_local)
A_region = self.ARegion()
B_region = self.BRegion()
C_region = self.CRegion()
A_buf = A_region.buffer
B_buf = B_region.buffer
C_buf = C_region.buffer
🤖 Prompt for AI Agents
In tilelang/tileop/gemm/gemm_mfma.py around lines 88 to 96, the code assigns
self.ARegion, self.BRegion, and self.CRegion (which are methods on GemmBase) to
variables and then tries to access .buffer, causing AttributeError; call the
accessor methods (e.g., self.ARegion(), self.BRegion(), self.CRegion()) before
reading .buffer so A_region = self.ARegion() etc., then use A_region.buffer,
B_region.buffer, C_region.buffer.

- Introduced a new abbreviation "bf16" for the bfloat16 data type in the mfma_macro_generator.py file, enhancing clarity and consistency in data type representation.
Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
tilelang/intrinsics/mfma_macro_generator.py (1)

712-862: Fix critical parameter mismatch in ldmatrix_b super() call and add BufferRegion support to preshuffle paths.

Line 800 passes 6 arguments to parent's ldmatrix_b which only accepts 4 parameters. The parent class signature is def ldmatrix_b(self, B_local_buf, B_shared_buf: Buffer | BufferRegion, ki, rk=0), but the child calls super().ldmatrix_b(B_local_buf, B_buf, ki, rk, pid_m, pid_n). This causes a TypeError when b_preshuffle=False and pid_m/pid_n are passed.

Additionally, the preshuffle custom implementations (lines 803–859) directly access buffer parameters without normalizing via to_buffer_region(). If BufferRegion objects are passed to these paths, indexing will fail. Apply the same normalization pattern used in the parent class (lines 313–315) to ensure compatibility with both Buffer and BufferRegion inputs.

🧹 Nitpick comments (1)
tilelang/intrinsics/mfma_macro_generator.py (1)

288-297: Consider more descriptive variable names for l and r.

The buffer access logic correctly incorporates base offsets from the BufferRegion. However, the variable names l and r are ambiguous. Based on context, they appear to represent logical row and column indices.

Consider renaming to more descriptive names like logical_row and logical_col, or at minimum row_idx and col_idx.

                         row, col = T.meta_var(reverse_index_map(tx, local_id))
-                        l, r = (rk * chunk + ki * (k_pack * micro_size_k),
-                                warp_m * warp_row_tiles + i * micro_size_x)
-                        A_local_buf[i * k_pack * local_size_a + local_id] = A_buf[A_base0 + l + row,
-                                                                                  A_base1 + r + col]
+                        logical_row, logical_col = (rk * chunk + ki * (k_pack * micro_size_k),
+                                                    warp_m * warp_row_tiles + i * micro_size_x)
+                        A_local_buf[i * k_pack * local_size_a + local_id] = A_buf[A_base0 + logical_row + row,
+                                                                                  A_base1 + logical_col + col]
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 15035cd and 71f4284.

📒 Files selected for processing (1)
  • tilelang/intrinsics/mfma_macro_generator.py (8 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
tilelang/intrinsics/mfma_macro_generator.py (2)
tilelang/utils/language.py (1)
  • to_buffer_region (187-211)
tilelang/intrinsics/mma_macro_generator.py (4)
  • ldmatrix_a (229-343)
  • ldmatrix_a (887-991)
  • ldmatrix_b (345-467)
  • ldmatrix_b (993-1105)
🪛 Ruff (0.14.3)
tilelang/intrinsics/mfma_macro_generator.py

294-294: Ambiguous variable name: l

(E741)

🔇 Additional comments (5)
tilelang/intrinsics/mfma_macro_generator.py (5)

5-5: LGTM! Necessary imports for BufferRegion support.

The added imports (BufferRegion and to_buffer_region) are essential for the new region-aware buffer handling introduced in this refactor.

Also applies to: 12-12


142-142: LGTM! Completes bfloat16 support.

The addition of the "bfloat16": "bf16" mapping ensures consistency with the class-level dtype_abbrv dictionary and properly supports bfloat16 operations.


301-301: LGTM! Consistent BufferRegion normalization.

The signature change and normalization logic for ldmatrix_b mirror the implementation in ldmatrix_a, ensuring consistent handling of Buffer | BufferRegion inputs across both methods.

Also applies to: 313-317


336-348: LGTM! Buffer access consistent with ldmatrix_a.

The buffer access patterns correctly add base offsets and are consistent with the ldmatrix_a implementation.


255-255: 2D buffer assumption verified and confirmed valid.

The type signature change to Buffer | BufferRegion and normalization logic are correctly implemented. Verification across 29 call sites confirms all ldmatrix_a invocations are GEMM/MFMA matrix operations, which inherently require 2D buffers. The region indexing via region[-2].min and region[-1].min safely accesses matrix dimensions. The to_buffer_region() conversion maintains backward compatibility while properly handling both input types.

…eneration

- Introduced a mapping function to normalize input data types to their corresponding scalar types, improving compatibility with MfmaTraits.
- Updated the mfma call generation to utilize the new mapping, streamlining the code and enhancing clarity.
- Removed outdated dtype mapping and replaced it with a more flexible approach to support additional data types like FP8.
Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 71f4284 and 005ffe9.

📒 Files selected for processing (1)
  • src/target/codegen_hip.cc (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
src/target/codegen_hip.cc (1)
src/target/codegen_hip.h (1)
  • enable_fp8_ (79-94)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
  • GitHub Check: Test for Python 3.12 with CUDA-12.8 (on self-hosted-nvidia)
  • GitHub Check: Test for Python 3.12 with ROCm-6.3 (on self-hosted-amd)

Comment on lines 949 to 955
std::unordered_map<std::string, std::string> acc_type_map = {
{"float32", "float"},
{"float64", "double"},
{"float16x4", "float16x4"},
{"bfloat16x4", "bfloat16x4"},
{"float32x4", "float32x4"},
{"float8_e4m3fnuzx4", "fp8_e4_4_t"},
{"float8_e4m3fnuzx8", "long"},
{"int32x4", "int32x4"},
{"float32x16", "float32x16"}};
std::string call_mfma_code = R"({
*((({C_dtype}*){c_ref}) + {c_bias}) = {mfma_buildin}(*((({A_dtype}*){a_ref}) + {a_bias}),
*((({B_dtype}*){b_ref}) + {b_bias}),
*((({C_dtype}*){c_ref}) + {c_bias}), 0, 0, 0);
})";
std::string mfma_buildin = "__builtin_amdgcn_mfma_" + prefix;
Replacer replacer;

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Include <unordered_map> to fix the compile failure

std::unordered_map is introduced here but the file does not include <unordered_map>, so this will not compile in a clean build. Please add the missing header near the other standard includes.

+#include <unordered_map>

Committable suggestion skipped: line range outside the PR's diff.

🤖 Prompt for AI Agents
In src/target/codegen_hip.cc around lines 949 to 955, the new use of
std::unordered_map requires including the <unordered_map> header; add #include
<unordered_map> alongside the other standard includes near the top of the file
so the code compiles in a clean build.

…dling in CodeGenTileLangHIP

- Introduced a macro to define backend options for CUDA, ROCM, and Metal, allowing user overrides and caching of settings.
- Updated logic to track user-selected backends and conditionally enable defaults based on environment variables.
- Refactored dtype handling in CodeGenTileLangHIP to streamline mfma call generation and improve clarity.
- Added support for bfloat16 in the mfma_macro_generator.py, enhancing data type representation consistency.
Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🧹 Nitpick comments (2)
tilelang/intrinsics/mfma_macro_generator.py (1)

257-257: LGTM: Region-based buffer access properly implemented.

The changes correctly normalize the input to BufferRegion and use base offsets in buffer accesses. The pattern is consistent with similar changes in mma_macro_generator.py and maintains backward compatibility by accepting both Buffer and BufferRegion.

Optional: Consider more descriptive variable names.

Static analysis flags l (line 296) as an ambiguous variable name. While the meaning is clear in context (left/right indices), consider using more descriptive names like row_idx and col_idx for improved readability.

Also applies to: 269-274, 290-291, 298-299

CMakeLists.txt (1)

157-177: Verify mutual exclusivity of backend selection.

The auto-selection logic (lines 157–177) may enable multiple backends simultaneously (e.g., both Metal and CUDA on macOS if the conditionals don't properly exclude each other). Check whether the build system supports or requires mutually exclusive backends. If mutual exclusivity is required, add validation logic to error when multiple backends are enabled.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 005ffe9 and 219b9e8.

📒 Files selected for processing (2)
  • CMakeLists.txt (4 hunks)
  • tilelang/intrinsics/mfma_macro_generator.py (9 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
tilelang/intrinsics/mfma_macro_generator.py (2)
tilelang/utils/language.py (1)
  • to_buffer_region (187-211)
tilelang/intrinsics/mma_macro_generator.py (4)
  • ldmatrix_a (229-343)
  • ldmatrix_a (887-991)
  • _warp_ldmatrix_a (304-341)
  • _warp_ldmatrix_a (901-989)
🪛 Ruff (0.14.3)
tilelang/intrinsics/mfma_macro_generator.py

296-296: Ambiguous variable name: l

(E741)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (6)
  • GitHub Check: Test for Python 3.12 with ROCm-6.3 (on self-hosted-amd)
  • GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
  • GitHub Check: Build wheels for Python 3.9 on ubuntu-latest with CUDA-12.1
  • GitHub Check: Build wheels for Python 3.9 on ubuntu-24.04-arm with CUDA-12.8
  • GitHub Check: Build SDist
  • GitHub Check: Build wheels for Python 3.9 on macos-latest with Metal
🔇 Additional comments (5)
tilelang/intrinsics/mfma_macro_generator.py (3)

5-5: LGTM: Necessary imports for BufferRegion support.

The added imports enable region-based buffer handling, which is consistent with the broader refactoring effort across the codebase.

Also applies to: 12-12


142-156: LGTM: bf16 support properly added.

The bfloat16 handling follows the established pattern for other data types and uses the appropriate MFMA suffix (bf16_1k) for the intrinsic.


303-303: LGTM: Consistent implementation across ldmatrix_b.

The changes to ldmatrix_b correctly mirror the pattern used in ldmatrix_a, ensuring consistent region-based buffer handling across both methods.

Also applies to: 315-320, 338-339, 349-350

CMakeLists.txt (2)

78-106: No issues found with CMake syntax on line 88.

The syntax if(DEFINED CACHE{${_backend_var}}) is valid CMake, introduced in CMake 3.14 to test cache entries specifically. This concern can be removed.


100-105: Type change to STRING is intentional and necessary for dual-purpose storage; no compatibility issues found in visible codebase.

The change from BOOL to STRING at line 127 is intentional and required by design. Line 205 demonstrates the dual purpose: cmake_path(GET CUDAToolkit_BIN_DIR PARENT_PATH USE_CUDA) reassigns USE_CUDA to a filesystem path after the CUDA toolkit is discovered. The variable must hold both boolean values (ON/OFF) and paths, necessitating STRING type.

All visible consumers are compatible:

  • Conditional checks (if(USE_CUDA) at lines 179, 186, 198) work correctly with STRING values "ON"/"OFF"
  • Function calls (find_rocm(${USE_ROCM}) at line 189) pass the value without type assumptions
  • Environment variable consumers (version_provider.py) are independent

The external TVM module (FindROCM at line 188) cannot be verified from the visible codebase, but the current usage pattern shows no issues.

…or.py

- Changed the representation of bfloat16 in CodeGenTileLangHIP from "bfloat16x4" to "bfloat16x4_vec" for improved clarity.
- Adjusted the mfma_suffix generation in mfma_macro_generator.py to remove the underscore before "bf16", aligning with HIP intrinsic requirements.
Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 0

🧹 Nitpick comments (1)
tilelang/intrinsics/mfma_macro_generator.py (1)

291-292: LGTM: Buffer access correctly applies region-based offsets.

The buffer indexing properly applies A_base0 and A_base1 offsets derived from the BufferRegion, enabling correct addressing for both transposed and non-transposed cases.

Optional: Consider more descriptive variable names. The variable l (used here but defined in unchanged code) is flagged by static analysis as ambiguous. Consider renaming to something more descriptive like row_offset or spatial_idx in a future refactor.

Also applies to: 299-300

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 219b9e8 and 683d479.

📒 Files selected for processing (2)
  • src/target/codegen_hip.cc (1 hunks)
  • tilelang/intrinsics/mfma_macro_generator.py (9 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
tilelang/intrinsics/mfma_macro_generator.py (3)
tilelang/utils/language.py (2)
  • is_fragment (105-116)
  • to_buffer_region (187-211)
tilelang/intrinsics/mma_sm70_macro_generator.py (3)
  • ldmatrix_a (190-232)
  • _warp_ldmatrix_a (216-230)
  • ldmatrix_b (234-284)
tilelang/intrinsics/mma_macro_generator.py (6)
  • ldmatrix_a (229-343)
  • ldmatrix_a (887-991)
  • _warp_ldmatrix_a (304-341)
  • _warp_ldmatrix_a (901-989)
  • ldmatrix_b (345-467)
  • ldmatrix_b (993-1105)
🪛 Ruff (0.14.3)
tilelang/intrinsics/mfma_macro_generator.py

297-297: Ambiguous variable name: l

(E741)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
  • GitHub Check: Quick Lint
  • GitHub Check: Build wheels for Python 3.9 on macos-latest with Metal
  • GitHub Check: Build wheels for Python 3.9 on ubuntu-24.04-arm with CUDA-12.8
  • GitHub Check: Build wheels for Python 3.9 on ubuntu-latest with CUDA-12.1
  • GitHub Check: Build SDist
🔇 Additional comments (7)
src/target/codegen_hip.cc (1)

931-931: ****

Verification confirms the change is correct. The type bfloat16x4_vec is defined in src/tl_templates/hip/common.h as a GCC vector attribute type, and it is used in src/tl_templates/hip/gemm.h for MFMA operations. The naming is consistent—only bfloat16x4_vec exists in the codebase; analogous vector types like float16x4_vec and float32x4_vec are not defined. The struct bfloat16x4 remains available for existing code that uses it directly, so this dtype_map change does not introduce breaking changes—it only affects the C++ type emitted in generated MFMA code.

tilelang/intrinsics/mfma_macro_generator.py (6)

5-5: LGTM: Import additions support BufferRegion functionality.

The new imports for BufferRegion and to_buffer_region are necessary for the region-based buffer addressing introduced in this file, consistent with similar changes across the codebase.

Also applies to: 12-12


142-142: LGTM: bf16 abbreviation mapping added.

The addition of the "bfloat16": "bf16" mapping is consistent with other dtype abbreviations and enables bf16 support in MFMA intrinsics.


154-156: LGTM: bf16 suffix formatting matches HIP intrinsic conventions.

The special suffix format for bf16 (bf16_1k without underscore separator) is correctly documented and aligns with HIP intrinsic naming requirements.


304-304: LGTM: Consistent BufferRegion normalization for B matrix.

The BufferRegion normalization in ldmatrix_b mirrors the implementation in ldmatrix_a, correctly extracting base offsets for region-based addressing. The same dimension assumptions apply (2D+ regions for matrix operations).

Also applies to: 316-320


339-340: LGTM: Buffer access correctly applies region-based offsets for B matrix.

The buffer indexing in ldmatrix_b properly applies B_base0 and B_base1 offsets, mirroring the correct implementation pattern from ldmatrix_a.

Also applies to: 350-351


258-258: Region dimension assumption verified—code is safe.

All callers of ldmatrix_a are GEMM kernels where the matrix A has at least 2 dimensions. The to_buffer_region function converts matrices to BufferRegions with full-region coverage, guaranteeing the region has at least 2 dimensions. The indexing pattern region[-2].min and region[-1].min is consistent and used throughout the codebase with no reported issues.

…r non-negative index checks to reduce log verbosity.
- Updated index handling in `example_gqa_sink_bwd_bhsd.py` and `example_mha_sink_bwd_bhsd.py` to eliminate unnecessary local allocations and streamline logic for determining start and end indices.
- Improved readability by using direct calculations instead of local variables for index bounds in pipelined loops.
Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 0

♻️ Duplicate comments (1)
tilelang/intrinsics/wgmma_macro_generator.py (1)

191-193: Pass buffers to _determinate_swizzle_mode.

_determinate_swizzle_mode operates on Buffer (it probes .shape via layout helpers). Feeding it BufferRegion will blow up as soon as the helper runs, so both the main WGMMA path and the RS fallback regress. Please unwrap the regions before calling the helper.

-        a_swizzle_mode = self._determinate_swizzle_mode(A_region, self.a_shared_layout)
-        b_swizzle_mode = self._determinate_swizzle_mode(B_region, self.b_shared_layout)
+        a_swizzle_mode = self._determinate_swizzle_mode(A_region.buffer, self.a_shared_layout)
+        b_swizzle_mode = self._determinate_swizzle_mode(B_region.buffer, self.b_shared_layout)
...
-        b_swizzle_mode = self._determinate_swizzle_mode(B_region, self.b_shared_layout)
+        b_swizzle_mode = self._determinate_swizzle_mode(B_region.buffer, self.b_shared_layout)

Also applies to: 352-353

🧹 Nitpick comments (1)
examples/attention_sink/example_mha_sink_bwd_bhsd.py (1)

80-86: Consider applying the same refactor pattern for consistency.

The forward pass still uses the older pattern with alloc_local and conditional branches, while the backward pass (lines 270-273) has been refactored to use an inline ternary expression. For consistency and maintainability, consider refactoring this section similarly:

-    end = T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N))
-    start = T.alloc_local([1], 'int32')
-    if window_size is not None:
-        start[0] = T.max(0, (bx * block_M - window_size) // block_N)
-    else:
-        start[0] = 0
-
-    for k in T.Pipelined(start[0], end, num_stages=num_stages):
+    end = T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N))
+    start = T.max(0, (bx * block_M - window_size) // block_N) if window_size is not None else 0
+
+    for k in T.Pipelined(start, end, num_stages=num_stages):
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 683d479 and 4a74b62.

📒 Files selected for processing (4)
  • examples/attention_sink/example_gqa_sink_bwd_bhsd.py (1 hunks)
  • examples/attention_sink/example_mha_sink_bwd_bhsd.py (1 hunks)
  • src/transform/legalize_negative_index.cc (1 hunks)
  • tilelang/intrinsics/wgmma_macro_generator.py (7 hunks)
🧰 Additional context used
🧬 Code graph analysis (3)
examples/attention_sink/example_mha_sink_bwd_bhsd.py (1)
tilelang/language/loop.py (1)
  • Pipelined (57-94)
tilelang/intrinsics/wgmma_macro_generator.py (2)
tilelang/utils/language.py (3)
  • is_fragment (105-116)
  • retrive_ptr_from_buffer_region (255-282)
  • is_full_region (370-399)
tilelang/language/builtin.py (2)
  • initialize_wgmma_descriptor (703-730)
  • warpgroup_fence_operand (433-592)
examples/attention_sink/example_gqa_sink_bwd_bhsd.py (1)
tilelang/language/loop.py (1)
  • Pipelined (57-94)
🪛 Ruff (0.14.3)
tilelang/intrinsics/wgmma_macro_generator.py

274-274: Unpacked variable tx is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


396-396: Unpacked variable tx is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


396-396: Unpacked variable warp_m is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (7)
  • GitHub Check: Test for Python 3.12 with ROCm-6.3 (on self-hosted-amd)
  • GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
  • GitHub Check: Test for Python 3.12 with CUDA-12.8 (on self-hosted-nvidia)
  • GitHub Check: Build wheels for Python 3.9 on ubuntu-24.04-arm with CUDA-12.8
  • GitHub Check: Build wheels for Python 3.9 on macos-latest with Metal
  • GitHub Check: Build SDist
  • GitHub Check: Build wheels for Python 3.9 on ubuntu-latest with CUDA-12.1
🔇 Additional comments (4)
src/transform/legalize_negative_index.cc (1)

53-55: Verify that suppressing this warning in release builds is intentional.

The change from LOG(WARNING) to DLOG(WARNING) means this diagnostic message will only appear in debug builds. If this warning helps identify potential indexing issues or aids in debugging production problems, consider keeping it as LOG(WARNING). However, if this warning is noisy and not actionable in production, the debug-only logging is appropriate.

examples/attention_sink/example_gqa_sink_bwd_bhsd.py (2)

84-85: LGTM: Clean refactor to scalar expression.

The change from local array allocation to a direct scalar expression simplifies the code while preserving the logic. The use of T.max(0, ...) correctly guards against negative start positions when window_size is large relative to bx * block_M.


87-87: LGTM: Correctly updated to use scalar start.

The update to pass the scalar start directly (instead of dereferencing start[0]) is the necessary complement to the refactor on lines 84-85. This aligns with the T.Pipelined signature expectation of a PrimExpr.

examples/attention_sink/example_mha_sink_bwd_bhsd.py (1)

270-273: Good simplification of loop bound computation.

The refactor consolidates the loop end calculation into a single inline expression, removing the local allocation overhead and improving readability. The logic is correct: it computes the appropriate upper bound for the sliding window case and falls back to the full sequence length when no window is specified.

- Simplified index handling in `example_gqa_sink_bwd_bhsd.py`, `example_gqa_sink_fwd_bhsd_wgmma_pipelined.py`, `example_mha_sink_bwd_bhsd.py`, `example_mha_sink_fwd_bhsd_wgmma_pipelined.py`, and `example_mha_sink_fwd_bhsd.py` by removing unnecessary local allocations for start and end indices.
- Enhanced readability by directly calculating index bounds for pipelined loops, improving overall code clarity.
Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

♻️ Duplicate comments (2)
tilelang/intrinsics/wgmma_macro_generator.py (2)

191-192: Pass region.buffer to _determinate_swizzle_mode.

As flagged in previous reviews, _determinate_swizzle_mode expects a Buffer (line 150) but receives BufferRegion objects, causing a crash when accessing .shape.

Apply this diff:

-        a_swizzle_mode = self._determinate_swizzle_mode(A_region, self.a_shared_layout)
-        b_swizzle_mode = self._determinate_swizzle_mode(B_region, self.b_shared_layout)
+        a_swizzle_mode = self._determinate_swizzle_mode(A_region.buffer, self.a_shared_layout)
+        b_swizzle_mode = self._determinate_swizzle_mode(B_region.buffer, self.b_shared_layout)

346-346: Pass B_region.buffer to _determinate_swizzle_mode.

Identical to the issue at lines 191-192: _determinate_swizzle_mode expects a Buffer but receives a BufferRegion, which will crash when accessing .shape.

Apply this diff:

-        b_swizzle_mode = self._determinate_swizzle_mode(B_region, self.b_shared_layout)
+        b_swizzle_mode = self._determinate_swizzle_mode(B_region.buffer, self.b_shared_layout)
🧹 Nitpick comments (3)
examples/attention_sink/example_mha_sink_bwd_bhsd.py (1)

81-84: Consider using T.floordiv for consistency.

The inline expression for start simplifies the code well by removing the local allocation. However, line 81-82 uses Python's // operator while line 266 uses T.floordiv for a similar floor division operation. Although // should work through PrimExpr operator overloading, using T.floordiv would be more consistent with the codebase pattern.

Apply this diff for consistency:

-            start = T.max(0,
-                          (bx * block_M - window_size) // block_N) if window_size is not None else 0
+            start = T.max(0,
+                          T.floordiv(bx * block_M - window_size, block_N)) if window_size is not None else 0
tilelang/intrinsics/wgmma_macro_generator.py (2)

274-274: Consider prefixing unused variable with underscore.

Variable tx is unpacked but never used in the macro body. If it's intentionally unused, prefix it with _ to indicate this.

Apply this diff:

-            tx, warp_n, warp_m = self.extract_thread_binding(thread_binding)
+            _tx, warp_n, warp_m = self.extract_thread_binding(thread_binding)

Based on static analysis.


390-390: Prefix unused variables with underscores.

Variables tx and warp_m are unpacked but never used in the wgmma_rs macro body. Prefix them with _ to indicate they're intentionally unused.

Apply this diff:

-            tx, warp_n, warp_m = self.extract_thread_binding(thread_binding)
+            _tx, warp_n, _warp_m = self.extract_thread_binding(thread_binding)

Based on static analysis.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 4a74b62 and c2e3f08.

📒 Files selected for processing (7)
  • examples/attention_sink/example_gqa_sink_bwd_bhsd.py (2 hunks)
  • examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py (1 hunks)
  • examples/attention_sink/example_mha_sink_bwd_bhsd.py (2 hunks)
  • examples/attention_sink/example_mha_sink_fwd_bhsd.py (1 hunks)
  • examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py (1 hunks)
  • examples/linear_attention/example_linear_attn_fwd.py (1 hunks)
  • tilelang/intrinsics/wgmma_macro_generator.py (7 hunks)
🧰 Additional context used
🧬 Code graph analysis (6)
examples/attention_sink/example_mha_sink_fwd_bhsd.py (1)
tilelang/language/loop.py (1)
  • Pipelined (57-94)
examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py (1)
tilelang/language/loop.py (1)
  • Pipelined (57-94)
examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py (1)
tilelang/language/loop.py (1)
  • Pipelined (57-94)
tilelang/intrinsics/wgmma_macro_generator.py (2)
tilelang/utils/language.py (3)
  • is_fragment (105-116)
  • retrive_ptr_from_buffer_region (255-282)
  • is_full_region (370-399)
tilelang/language/builtin.py (2)
  • initialize_wgmma_descriptor (703-730)
  • warpgroup_fence_operand (433-592)
examples/attention_sink/example_gqa_sink_bwd_bhsd.py (1)
tilelang/language/loop.py (1)
  • Pipelined (57-94)
examples/attention_sink/example_mha_sink_bwd_bhsd.py (1)
tilelang/language/loop.py (1)
  • Pipelined (57-94)
🪛 Ruff (0.14.3)
tilelang/intrinsics/wgmma_macro_generator.py

274-274: Unpacked variable tx is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


390-390: Unpacked variable tx is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


390-390: Unpacked variable warp_m is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)

🔇 Additional comments (12)
examples/attention_sink/example_mha_sink_fwd_bhsd.py (2)

165-166: LGTM! Cleaner scalar expression for start boundary.

The direct scalar computation is more readable and efficient than the previous array-based approach. The conditional logic correctly handles both full attention (start=0) and sliding window cases (clamping to non-negative block index).


168-168: LGTM! Correct usage of scalar start with Pipelined.

Passing the scalar start directly aligns with the Pipelined function signature, which expects tir.PrimExpr parameters. This eliminates the unnecessary array indirection.

examples/attention_sink/example_gqa_sink_bwd_bhsd.py (2)

84-87: LGTM! Clean refactor to scalar expression.

The refactor from per-iteration local allocation to a direct scalar computation simplifies the code and eliminates unnecessary overhead. The T.max(0, ...) correctly handles edge cases where the sliding window start might be negative, and the conditional expression properly adapts to the window_size parameter.


266-270: LGTM! Consistent backward pass refactor.

The refactor mirrors the forward pass changes, replacing the per-iteration local allocation with a direct scalar computation. The logic correctly computes the loop end boundary for both sliding window and full attention modes, and the change maintains consistency across the codebase.

examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py (1)

175-179: LGTM! Clean refactoring that simplifies loop bound computation.

The change replaces per-iteration local scalar tracking with a direct scalar expression, removing unnecessary allocation and indexing. The start computation correctly identifies the first block within the sliding window (or 0 for full attention), and the scalar value aligns with the Pipelined function signature which expects a PrimExpr scalar.

examples/attention_sink/example_mha_sink_bwd_bhsd.py (1)

267-270: LGTM! Clean simplification of loop bounds.

The inline expression for loop_ed effectively simplifies the logic by removing the local allocation and using a direct scalar value instead of array indexing (loop_ed vs loop_ed[0]). The ternary expression correctly computes the loop bound for sliding window attention in the backward pass, and the change improves code readability.

tilelang/intrinsics/wgmma_macro_generator.py (4)

7-8: LGTM on import additions.

The BufferRegion import and utility function imports are correct and properly support the refactoring to region-based parameters.


163-171: Signature refactoring looks good.

The method now correctly accepts BufferRegion parameters, and the fragment dispatch logic at line 170 properly delegates to wgmma_rs when A is a fragment.


266-270: Proper region-to-pointer conversion and validation.

The pointer extraction using retrive_ptr_from_buffer_region and the full-region assertion for the output accumulator are correctly implemented.


382-386: Correct region assertions and pointer/buffer extraction.

The full-region assertions for fragment inputs and the extraction of buffers/pointers from regions are properly implemented for the RS (register-to-shared) path.

examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py (2)

168-169: LGTM! Simplified start computation is correct.

The scalar computation correctly determines the starting block for the sliding window. The floor division may occasionally include one extra block at the boundary that gets fully masked out by the explicit bounds checking in MMA0, but this conservative approach is safe and doesn't affect correctness. The change eliminates unnecessary local allocation overhead.


172-172: LGTM! Correct usage of scalar start.

Passing start directly (instead of start[0]) is correct now that it's computed as a scalar PrimExpr. This aligns with the Pipelined function signature and eliminates unnecessary indexing.

def tl_fused_chunk_fwd(q, k, v):
B, S, H, D = q.shape
kernel = tl_fused_chunk_fwd_kernel(B, S, H, D, D)
print(kernel.get_kernel_source())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Make kernel source printing conditional to avoid cluttering benchmark output.

The print statement will execute every time tl_fused_chunk_fwd is called, including during the do_bench iterations at line 141. This will flood the output with kernel source text and make the benchmark results unreadable.

Consider one of these alternatives:

Option 1: Make it conditional on an environment variable or flag

-    print(kernel.get_kernel_source())
+    if os.environ.get('TILELANG_DEBUG_KERNEL_SOURCE'):
+        print(kernel.get_kernel_source())

Option 2: Print once before benchmarking in main()

 def tl_fused_chunk_fwd(q, k, v):
     B, S, H, D = q.shape
     kernel = tl_fused_chunk_fwd_kernel(B, S, H, D, D)
-    print(kernel.get_kernel_source())
     o = torch.zeros((B, S, H, D), device='cuda', dtype=torch.float32)
     h = kernel(q, k, v, o)
     return o, h

Then add in main() before benchmarking:

def main(B=1, S=512, H=16, D=128):
    # ... setup code ...
    
    # Print kernel source once for inspection
    kernel = tl_fused_chunk_fwd_kernel(B, S, H, D, D)
    print(kernel.get_kernel_source())
    
    o, h = tl_fused_chunk_fwd(q, k, v)
    # ... rest of main ...
🤖 Prompt for AI Agents
In examples/linear_attention/example_linear_attn_fwd.py around line 93, the
unconditional print(kernel.get_kernel_source()) inside tl_fused_chunk_fwd is
flooding benchmark output; remove that print from the hot path and instead
either (A) print the kernel source once before benchmarking in main() after
constructing the kernel, or (B) wrap the print in a conditional guard (e.g.,
check an environment variable or a debug flag) so it only emits when explicitly
enabled; update main() to print the kernel source once if you choose option A,
and ensure no prints occur inside the timed/iterative code paths.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 0

♻️ Duplicate comments (2)
tilelang/intrinsics/wgmma_macro_generator.py (2)

191-192: Pass .buffer to _determinate_swizzle_mode to avoid AttributeError.

_determinate_swizzle_mode expects a Buffer (line 150) and accesses .shape and other Buffer attributes. Passing BufferRegion directly will crash. This issue was flagged in previous reviews.

Apply this diff:

-        a_swizzle_mode = self._determinate_swizzle_mode(A_region, self.a_shared_layout)
-        b_swizzle_mode = self._determinate_swizzle_mode(B_region, self.b_shared_layout)
+        a_swizzle_mode = self._determinate_swizzle_mode(A_region.buffer, self.a_shared_layout)
+        b_swizzle_mode = self._determinate_swizzle_mode(B_region.buffer, self.b_shared_layout)

346-346: Pass .buffer to _determinate_swizzle_mode in wgmma_rs.

Same issue as lines 191-192: _determinate_swizzle_mode expects a Buffer, not a BufferRegion. This was flagged in previous reviews.

Apply this diff:

-        b_swizzle_mode = self._determinate_swizzle_mode(B_region, self.b_shared_layout)
+        b_swizzle_mode = self._determinate_swizzle_mode(B_region.buffer, self.b_shared_layout)
🧹 Nitpick comments (2)
tilelang/intrinsics/wgmma_macro_generator.py (2)

274-274: Prefix unused unpacked variables with underscore.

tx is unpacked but never used. Consider prefixing it with an underscore to indicate it's intentionally unused: _tx, warp_n, warp_m = ...


390-390: Prefix unused unpacked variables with underscore.

Both tx and warp_m are unpacked but never used in this macro. Consider prefixing them with underscores: _tx, warp_n, _warp_m = ...

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between f7fe22d and a6bab65.

📒 Files selected for processing (1)
  • tilelang/intrinsics/wgmma_macro_generator.py (7 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
tilelang/intrinsics/wgmma_macro_generator.py (2)
tilelang/utils/language.py (3)
  • is_fragment (105-116)
  • retrive_ptr_from_buffer_region (255-282)
  • is_full_region (370-399)
tilelang/language/builtin.py (1)
  • warpgroup_fence_operand (433-592)
🪛 Ruff (0.14.3)
tilelang/intrinsics/wgmma_macro_generator.py

274-274: Unpacked variable tx is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


390-390: Unpacked variable tx is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


390-390: Unpacked variable warp_m is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (7)
  • GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
  • GitHub Check: Test for Python 3.12 with ROCm-6.3 (on self-hosted-amd)
  • GitHub Check: Test for Python 3.12 with CUDA-12.8 (on self-hosted-nvidia)
  • GitHub Check: Build wheels for Python 3.9 on macos-latest with Metal
  • GitHub Check: Build wheels for Python 3.9 on ubuntu-latest with CUDA-12.1
  • GitHub Check: Build wheels for Python 3.9 on ubuntu-24.04-arm with CUDA-12.8
  • GitHub Check: Build SDist
🔇 Additional comments (6)
tilelang/intrinsics/wgmma_macro_generator.py (6)

7-8: LGTM: Imports support BufferRegion refactor.

The imports correctly add BufferRegion and the utility functions needed to work with region-based parameters. The retrive_ptr_from_buffer_region helper is now properly centralized in tilelang.utils (addressing previous review concerns).


164-171: LGTM: Signature refactored to BufferRegion.

The method signature correctly migrates to region-based parameters. The dispatch logic at lines 170-171 appropriately routes to wgmma_rs when A is a fragment.


266-270: LGTM: Region-to-pointer/buffer extraction is correct.

The code properly extracts pointers for shared memory operands (A, B) using retrive_ptr_from_buffer_region and ensures C is a full region before extracting its buffer for fragment operations.


273-315: LGTM: Macro signature correctly uses pointers for shared memory operands.

The inner _warp_mma macro signature at line 273 properly reflects the pointer-based access pattern for A and B (shared memory) while keeping C as a buffer (fragment). The descriptor initialization (lines 278-283), fence operations (lines 284, 313), and final invocation (line 315) are all consistent with this approach.


382-386: LGTM: wgmma_rs region extraction is correct.

The code properly distinguishes between fragment buffers (A, C) and shared memory pointers (B), with appropriate full-region assertions before buffer extraction.


389-438: LGTM: wgmma_rs macro correctly handles mixed buffer/pointer operands.

The macro signature (line 389) and implementation properly handle the asymmetric operand types: fragment buffers for A and C, shared memory pointer for B. The descriptor initialization (line 393), fences (lines 396-397, 435-436), and invocation (line 438) are all consistent.

- Removed outdated shared memory reduction logic from `reduce.cc`.
- Introduced fragment allocation and improved buffer handling in `reduce.py` to support shared and fragment scopes.
- Updated CUDA header to define a wider accumulator type for better numerical accuracy.
- Enhanced error handling for buffer scope validation in the reduction process.
- Updated the condition for identifying effectively empty annotations in unit loops to include cases where only the `pragma_unroll_explicit` hint is present.
- Introduced a new method, `IsEffectivelyEmptyAnnotation`, to encapsulate this logic, improving code clarity and maintainability.
Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
src/op/reduce.cc (1)

363-365: Restore shared-memory lowering before landing

We used to lower ReduceOp when either side lived in shared / shared.dyn. After this edit the code returns LOG(FATAL) for those scopes, so every existing kernel that still performs a block-level reduction through shared memory now hard-crashes instead of compiling. Please reinstate the shared-memory lowering (or offer an equivalent replacement) before merging—otherwise we break the current API surface.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between a6bab65 and 3f211ae.

📒 Files selected for processing (5)
  • src/op/reduce.cc (1 hunks)
  • src/tl_templates/cuda/reduce.h (1 hunks)
  • src/transform/lower_opaque_block.cc (3 hunks)
  • tilelang/language/reduce.py (2 hunks)
  • tilelang/language/v2/builder.py (1 hunks)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-03T06:24:11.411Z
Learnt from: Rachmanino
Repo: tile-ai/tilelang PR: 1175
File: src/op/math.cc:44-52
Timestamp: 2025-11-03T06:24:11.411Z
Learning: In tilelang's `src/op/math.cc`, the `infinity_op` function uses `std::numeric_limits<float>::infinity()` as a placeholder for all float types (including float64 and bfloat16). The codegen layer (PrintConst:Inf) handles the correct infinity value based on the dtype field of the FloatImm node, so the specific C++ template argument doesn't matter.

Applied to files:

  • src/tl_templates/cuda/reduce.h
🧬 Code graph analysis (2)
tilelang/language/reduce.py (3)
tilelang/language/copy.py (1)
  • copy (11-87)
tilelang/language/allocate.py (1)
  • alloc_fragment (59-70)
tilelang/utils/language.py (2)
  • is_shared (45-60)
  • is_fragment (105-116)
src/tl_templates/cuda/reduce.h (2)
src/tl_templates/cuda/common.h (5)
  • tl (195-289)
  • tl (518-520)
  • tl (521-523)
  • tl (538-618)
  • bfloat16_t (528-529)
src/tl_templates/hip/reduce.h (1)
  • tl (5-88)
🪛 Ruff (0.14.3)
tilelang/language/reduce.py

100-100: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (7)
  • GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
  • GitHub Check: Test for Python 3.12 with ROCm-6.3 (on self-hosted-amd)
  • GitHub Check: Test for Python 3.12 with CUDA-12.8 (on self-hosted-nvidia)
  • GitHub Check: Build wheels for Python 3.9 on ubuntu-24.04-arm with CUDA-12.8
  • GitHub Check: Build wheels for Python 3.9 on macos-latest with Metal
  • GitHub Check: Build wheels for Python 3.9 on ubuntu-latest with CUDA-12.1
  • GitHub Check: Build SDist
🔇 Additional comments (6)
tilelang/language/v2/builder.py (1)

228-249: No issues found. The change allowing silent Buffer evaluation is correct and intentional.

The modification treats tvm.tir.Buffer objects like None and str values in the eval() method—they're valid but no-op cases. This aligns with Python semantics where expression statements can contain just identifiers or other values without action. After unwrap_expr() converts local.var Buffers to BufferLoad nodes (line 229), any remaining Buffer objects are likely parameter buffers or results from other operations that shouldn't raise an error when evaluated as standalone expressions.

The change is consistent with the broader BufferRegion refactor evident in the codebase and poses no risk of suppressing legitimate errors since malformed expressions would be caught earlier in the evaluation chain.

src/transform/lower_opaque_block.cc (2)

122-122: LGTM: Consistent handling of unit loops with unroll pragmas.

The updated logic correctly treats loops with extent 1 and only pragma_unroll_explicit as unit loops that can be simplified away. Both call sites (line 122 for recording the loop variable, lines 138-139 for elimination) use the same helper, ensuring consistent behavior.

Also applies to: 138-139


154-169: LGTM: Well-designed helper for annotation filtering.

The helper function is correctly implemented with appropriate const-correctness. The logic properly identifies annotations that are either truly empty or contain only the unroll pragma, allowing unit-length loops produced by unroll directives to be optimized away.

src/tl_templates/cuda/reduce.h (3)

1-250: No action needed. SharedReduceWarp removal is not a breaking change.

The search confirms that SharedReduceWarp has zero usages anywhere in the codebase. CUDA and HIP versions are separate, independently-loaded implementations for different hardware backends—they're not meant to be identical. The CUDA version has been modernized with the removal of unused code, which is correct and safe. The HIP version can maintain its own code independently.


4-5: The original review comment is incorrect—<cstdint> is actively used in the file.

The file uses uint32_t (from <cstdint>) at lines 72 and 92 in the expression uint32_t(-1). However, <type_traits> does appear to be unused based on the absence of any std:: namespace usages in the file.

Likely an incorrect or invalid review comment.


9-19: The AccType trait is defined but not used in this file or elsewhere in the codebase.

While the trait's design is sound—promoting FP16/BF16 to float for accumulation improves numerical accuracy and follows standard CUDA practices—the trait struct remains unused. The reduction operators in this file do not reference it.

Either integrate AccType into the reduction implementations using typename AccType<T>::type, or remove it if it's planned for future work rather than current use.

Comment on lines 43 to 89
red_frag_in = alloc_fragment(buffer.shape, buffer.dtype)
red_frag_out = alloc_fragment(out.shape, out.dtype)

# rename buffers
IRBuilder.name(buffer.name + "_frag", red_frag_in)
IRBuilder.name(out.name + "_frag", red_frag_out)

copy(buffer, red_frag_in)
tir.call_intrin(
"handle",
tir.op.Op.get("tl.reduce"),
red_frag_in.access_ptr("r"),
red_frag_out.access_ptr("w"),
reduce_type,
dim,
clear,
)
copy(red_frag_out, out)
elif is_shared(buffer) and is_fragment(out):
red_frag_in = alloc_fragment(buffer.shape, buffer.dtype)
IRBuilder.name(buffer.name + "_frag", red_frag_in)

copy(buffer, red_frag_in)
tir.call_intrin(
"handle",
tir.op.Op.get("tl.reduce"),
red_frag_in.access_ptr("r"),
out.access_ptr("w"),
reduce_type,
dim,
clear,
)
elif is_fragment(buffer) and is_shared(out):
red_frag_out = alloc_fragment(out.shape, out.dtype)
IRBuilder.name(out.name + "_frag", red_frag_out)

tir.call_intrin(
"handle",
tir.op.Op.get("tl.reduce"),
buffer.access_ptr("r"),
red_frag_out.access_ptr("w"),
reduce_type,
dim,
clear,
)
copy(red_frag_out, out)
elif is_fragment(buffer) and is_fragment(out):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Preserve accumulation semantics when clear=False.

In both shared-output branches we allocate red_frag_out and pass it to tl.reduce without seeding it with the existing out contents. For clear=False, the intrinsic expects the destination pointer to contain the prior accumulator; feeding an uninitialized fragment drops the existing values and corrupts results. Please copy the shared buffer into the fragment before the reduce when clear is false.

         if is_shared(buffer) and is_shared(out):
             red_frag_in = alloc_fragment(buffer.shape, buffer.dtype)
             red_frag_out = alloc_fragment(out.shape, out.dtype)
@@
-            tir.call_intrin(
+            if not clear:
+                copy(out, red_frag_out)
+            tir.call_intrin(
                 "handle",
                 tir.op.Op.get("tl.reduce"),
                 red_frag_in.access_ptr("r"),
                 red_frag_out.access_ptr("w"),
                 reduce_type,
@@
             red_frag_out = alloc_fragment(out.shape, out.dtype)
             IRBuilder.name(out.name + "_frag", red_frag_out)
 
             tir.call_intrin(
+                if not clear:
+                    copy(out, red_frag_out)
                 "handle",
                 tir.op.Op.get("tl.reduce"),
                 buffer.access_ptr("r"),
                 red_frag_out.access_ptr("w"),
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
red_frag_in = alloc_fragment(buffer.shape, buffer.dtype)
red_frag_out = alloc_fragment(out.shape, out.dtype)
# rename buffers
IRBuilder.name(buffer.name + "_frag", red_frag_in)
IRBuilder.name(out.name + "_frag", red_frag_out)
copy(buffer, red_frag_in)
tir.call_intrin(
"handle",
tir.op.Op.get("tl.reduce"),
red_frag_in.access_ptr("r"),
red_frag_out.access_ptr("w"),
reduce_type,
dim,
clear,
)
copy(red_frag_out, out)
elif is_shared(buffer) and is_fragment(out):
red_frag_in = alloc_fragment(buffer.shape, buffer.dtype)
IRBuilder.name(buffer.name + "_frag", red_frag_in)
copy(buffer, red_frag_in)
tir.call_intrin(
"handle",
tir.op.Op.get("tl.reduce"),
red_frag_in.access_ptr("r"),
out.access_ptr("w"),
reduce_type,
dim,
clear,
)
elif is_fragment(buffer) and is_shared(out):
red_frag_out = alloc_fragment(out.shape, out.dtype)
IRBuilder.name(out.name + "_frag", red_frag_out)
tir.call_intrin(
"handle",
tir.op.Op.get("tl.reduce"),
buffer.access_ptr("r"),
red_frag_out.access_ptr("w"),
reduce_type,
dim,
clear,
)
copy(red_frag_out, out)
elif is_fragment(buffer) and is_fragment(out):
if is_shared(buffer) and is_shared(out):
red_frag_in = alloc_fragment(buffer.shape, buffer.dtype)
red_frag_out = alloc_fragment(out.shape, out.dtype)
# rename buffers
IRBuilder.name(buffer.name + "_frag", red_frag_in)
IRBuilder.name(out.name + "_frag", red_frag_out)
copy(buffer, red_frag_in)
if not clear:
copy(out, red_frag_out)
tir.call_intrin(
"handle",
tir.op.Op.get("tl.reduce"),
red_frag_in.access_ptr("r"),
red_frag_out.access_ptr("w"),
reduce_type,
dim,
clear,
)
copy(red_frag_out, out)
elif is_shared(buffer) and is_fragment(out):
red_frag_in = alloc_fragment(buffer.shape, buffer.dtype)
IRBuilder.name(buffer.name + "_frag", red_frag_in)
copy(buffer, red_frag_in)
tir.call_intrin(
"handle",
tir.op.Op.get("tl.reduce"),
red_frag_in.access_ptr("r"),
out.access_ptr("w"),
reduce_type,
dim,
clear,
)
elif is_fragment(buffer) and is_shared(out):
red_frag_out = alloc_fragment(out.shape, out.dtype)
IRBuilder.name(out.name + "_frag", red_frag_out)
if not clear:
copy(out, red_frag_out)
tir.call_intrin(
"handle",
tir.op.Op.get("tl.reduce"),
buffer.access_ptr("r"),
red_frag_out.access_ptr("w"),
reduce_type,
dim,
clear,
)
copy(red_frag_out, out)
elif is_fragment(buffer) and is_fragment(out):
🤖 Prompt for AI Agents
In tilelang/language/reduce.py around lines 43 to 89, the branches that allocate
a red_frag_out for shared output do not seed it with the existing shared buffer
when clear is False; change those branches so that immediately after allocating
red_frag_out (and naming it) you conditionally copy(out, red_frag_out) if not
clear, then call tir.call_intrin; this ensures the fragment contains the prior
accumulator before tl.reduce writes into it, preserving accumulation semantics.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

♻️ Duplicate comments (1)
CMakeLists.txt (1)

164-165: ROCm path handling lost when auto-selecting backend.

As flagged in the previous review, line 165 converts any non-empty $ENV{USE_ROCM} to ON, discarding path information. Since the documentation (line 71) states that USE_ROCM supports "ON/OFF/or ROCm SDK path", this loses the ability to use a specific ROCm SDK path if provided.

  elseif($ENV{USE_ROCM})
-   set(USE_ROCM ON)
+   set(USE_ROCM $ENV{USE_ROCM})

This preserves the environment value, allowing find_rocm() to receive either ON or a specific path as intended.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 3f211ae and 502d71f.

📒 Files selected for processing (2)
  • CMakeLists.txt (4 hunks)
  • examples/linear_attention/example_retention_fwd.py (0 hunks)
💤 Files with no reviewable changes (1)
  • examples/linear_attention/example_retention_fwd.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (6)
  • GitHub Check: Test for Python 3.12 with ROCm-6.3 (on self-hosted-amd)
  • GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
  • GitHub Check: Build wheels for Python 3.9 on macos-latest with Metal
  • GitHub Check: Build wheels for Python 3.9 on ubuntu-latest with CUDA-12.1
  • GitHub Check: Build SDist
  • GitHub Check: Build wheels for Python 3.9 on ubuntu-24.04-arm with CUDA-12.8
🔇 Additional comments (2)
CMakeLists.txt (2)

122-129: Re-application of backend settings post-TVM configuration is sound.

The logic correctly caches TileLang's backend preferences before including TVM's config.cmake, then restores them afterward to prevent TVM from overriding user-selected backends. This approach is robust and preserves user intent.


147-155: User selection detection logic is clear and comprehensive.

The logic correctly identifies whether any backend was explicitly selected by the user (either via command-line -DUSE_* or environment variable at configuration time). This gate prevents unintended auto-selection when user has explicitly chosen a backend.

Comment on lines +88 to +93
if(DEFINED CACHE{${_backend_var}})
get_property(_cache_type CACHE ${_backend_var} PROPERTY TYPE)
if(_cache_type STREQUAL "UNINITIALIZED")
set(_user_override ON)
endif()
endif()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Invalid CMake syntax: DEFINED CACHE{...} is not a recognized construct.

CMake does not support DEFINED CACHE{...} syntax. To check if a cache entry exists and its type, use get_property() instead. For example:

if(DEFINED ${_backend_var})
  get_property(_cache_type CACHE ${_backend_var} PROPERTY TYPE)
  if(_cache_type STREQUAL "UNINITIALIZED")
    set(_user_override ON)
  endif()
endif()

This code will likely fail during CMake configuration.

- if(DEFINED CACHE{${_backend_var}})
+ if(DEFINED ${_backend_var})
    get_property(_cache_type CACHE ${_backend_var} PROPERTY TYPE)
🤖 Prompt for AI Agents
In CMakeLists.txt around lines 88 to 93, the conditional uses invalid CMake
syntax `DEFINED CACHE{...}`; replace it by first checking the variable name
existence and/or directly using get_property to query the CACHE entry type.
Specifically, call get_property(CACHE_VAR_TYPE CACHE ${_backend_var} PROPERTY
TYPE RESULT_VARIABLE) or use if(DEFINED ${_backend_var}) before calling
get_property, then test if the returned type equals "UNINITIALIZED" and
set(_user_override ON) accordingly so CMake config will not fail.

@yyttt6
Copy link
Contributor

yyttt6 commented Nov 11, 2025

Benchmarked on an H100 server — observed significant speedups of gemm_v2 over gemm_v1 in example_mha_fwd_bshd, example_mha_fwd_bshd_wgmma_pipelined, and example_mha_fwd_bhsd_wgmma_pipelined, with example_mha_fwd_bshd achieving up to 1.36× acceleration .
Speedup

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants