[Refactor] Add kernel selection option for GEMM v1 in environment settings #1200
Conversation
- Introduced `TILELANG_USE_GEMM_V1` environment variable to control the selection of the GEMM version.
- Added a `use_gemm_v1` method in the `Environment` class to determine whether GEMM v1 should be used, based on the environment variable.
- Updated the GEMM function assignment to default to v2, allowing v1 to be forced via the new environment variable.
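As a rough illustration of how such a toggle can work: the sketch below is an assumption-laden stand-in, not TileLang's actual implementation — only the `TILELANG_USE_GEMM_V1` name and the `use_gemm_v1` method name come from the PR; the `Environment` class body and the set of truthy spellings are invented for the example.

```python
import os


class Environment:
    """Minimal stand-in for an environment-settings holder (illustrative only)."""

    TRUTHY = {"1", "true", "yes", "on"}

    def use_gemm_v1(self) -> bool:
        # Unset or falsy values mean "keep the default (v2)".
        return os.environ.get("TILELANG_USE_GEMM_V1", "").lower() in self.TRUTHY


def gemm_v1(*args):
    return "gemm_v1"


def gemm_v2(*args):
    return "gemm_v2"


# Default to v2; force v1 only when the environment variable is set to a truthy value.
gemm = gemm_v1 if Environment().use_gemm_v1() else gemm_v2
```

The selection happens once at resolution time, so flipping the variable after the assignment has no effect on the already-bound `gemm` name.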
👋 Hi! Thank you for contributing to the TileLang project. We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀
Walkthrough

This pull request introduces environment-driven GEMM kernel selection, refactors the GEMM infrastructure to normalize buffers to a unified BufferRegion representation, renames data members and methods to a consistent underscore-suffixed naming convention, adds region-aware utility functions, and enables FP64 MMA support while removing the BufferGemmCollector class.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant User
    participant TileLang Runtime
    participant Environment
    participant GEMM Layer
    User->>Environment: Check TILELANG_USE_GEMM_V1
    Environment-->>TileLang Runtime: Returns boolean (true/false)
    TileLang Runtime->>GEMM Layer: Resolve gemm to gemm_v1 or gemm_v2
    GEMM Layer->>GEMM Layer: Normalize A/B/C to BufferRegion via to_buffer_region()
    GEMM Layer->>GEMM Layer: Extract shape, stride, offset from region
    GEMM Layer->>GEMM Layer: Call _gemm_impl(op_key, A, B, C, ...)
    GEMM Layer-->>User: Execute selected GEMM kernel
```
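The normalization step in the diagram can be sketched with plain-Python stand-ins for the TVM objects. The dataclasses below are assumptions made for the illustration; only the `to_buffer_region()` name and the Buffer/BufferRegion concepts appear in the PR.

```python
from dataclasses import dataclass


@dataclass
class Buffer:
    shape: list  # per-dimension extents


@dataclass
class BufferRegion:
    buffer: Buffer
    region: list  # (min, extent) pairs, one per dimension


def to_buffer_region(x):
    """Pass regions through unchanged; wrap a bare buffer in a full-extent region."""
    if isinstance(x, BufferRegion):
        return x
    if isinstance(x, Buffer):
        return BufferRegion(x, [(0, dim) for dim in x.shape])
    raise TypeError(f"Unsupported GEMM operand: {type(x)}")
```

Normalizing every operand to one representation lets the downstream shape/stride/offset extraction handle whole buffers and sub-tiles through a single code path.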
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~75 minutes
- Introduced `TILELANG_USE_GEMM_V1` environment variable to allow users to select between the GEMM v1 and v2 implementations.
- Updated the `gemm` function to default to v2, switching to v1 if the environment variable is set to a truthy value.
- Added a `use_gemm_v1` method in the `Environment` class to facilitate this selection based on the environment variable.
Actionable comments posted: 5
🧹 Nitpick comments (3)
tilelang/language/builtin.py (2)
460-491: Consider extracting stride calculation to reduce duplication.

The stride calculation logic (lines 469-478) is duplicated in the BufferRegion branch (lines 517-526). Additionally, the control flow is inconsistent: the BufferLoad and BufferRegion branches return early, while the Buffer and pointer branches fall through to a common return.
Consider extracting the stride calculation into a helper function:
```python
def _compute_element_offset(indices_or_mins, buffer):
    """Compute element offset from indices/mins using buffer strides or row-major fallback."""
    if len(buffer.strides) == len(buffer.shape) and len(buffer.strides) > 0:
        elem_off = 0
        for idx, stride in zip(indices_or_mins, buffer.strides):
            elem_off = elem_off + idx * stride
    else:
        elem_off = 0
        stride_acc = 1
        for idx, dim in zip(reversed(indices_or_mins), reversed(buffer.shape)):
            elem_off = elem_off + idx * stride_acc
            stride_acc = stride_acc * dim
    return elem_off
```

Then use it in both branches:
```python
elem_off = _compute_element_offset(buffer_or_ptr.indices, buf)
```

For consistency, consider restructuring all branches to either use early returns or fall through to a common return statement.
555-577: Consider narrowing the exception handling scope.

The blind exception catch on line 572 is flagged by static analysis. While this defensive pattern is reasonable for probing optional attributes during type inference, consider catching a more specific exception such as `AttributeError`:

```diff
-        except Exception:
+        except AttributeError:
             inferred = None
```

This makes the code's intent clearer and avoids accidentally silencing unexpected errors.
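The difference matters in practice: a blind `except Exception` swallows real bugs, while `AttributeError` only covers the "attribute is absent" probe. A small illustration — the `num_regs` attribute name is reused from the review context, and the classes are invented for the example:

```python
def infer_num_regs(obj):
    """Probe an optional attribute; only a genuinely missing attribute falls back."""
    try:
        inferred = obj.num_regs
    except AttributeError:
        inferred = None  # attribute absent: use the fallback
    return inferred


class WithRegs:
    num_regs = 4


class Buggy:
    @property
    def num_regs(self):
        # A real bug inside the attribute: a blind `except Exception` would hide it,
        # but `except AttributeError` lets it propagate.
        raise ZeroDivisionError("real bug, should not be silenced")
```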
tilelang/intrinsics/tcgen05_macro_generator.py (1)
249-278: Prefer `TypeError` for invalid type errors.

The `access_ptr_from` helper correctly handles `Buffer`, `BufferLoad`, and `BufferRegion`, but should raise `TypeError` instead of `ValueError` when encountering unsupported types (lines 266, 278).

Apply this diff:

```diff
         else:
-            raise ValueError(f"Unsupported index type: {type(indice)}")
+            raise TypeError(f"Unsupported index type: {type(indice)}")
         stride *= shape
         return buffer.access_ptr(access_type, offset=offset)
     else:
-        raise ValueError(f"Unsupported buffer type: {type(buffer_or_load_or_region)}")
+        raise TypeError(f"Unsupported buffer type: {type(buffer_or_load_or_region)}")
```
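The distinction follows Python's own conventions: `TypeError` means "wrong kind of object", `ValueError` means "right kind, unacceptable value". A tiny sketch with a placeholder dispatch target (the real helper dispatches on TVM buffer types, not lists):

```python
def access_ptr_from(operand):
    """Dispatch on operand kind; an unsupported *type* is a TypeError."""
    if isinstance(operand, (list, tuple)):
        return id(operand)  # placeholder for buffer.access_ptr(...)
    raise TypeError(f"Unsupported buffer type: {type(operand)}")
```

Callers that already filter on `except TypeError` then behave consistently with built-in functions that receive an argument of the wrong type.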
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (25)
- examples/warp_specialize/example_warp_specialize_gemm_barrierpipe_stage2.py (2 hunks)
- src/op/gemm.cc (10 hunks)
- src/op/gemm.h (3 hunks)
- src/op/gemm_py.cc (7 hunks)
- src/op/gemm_py.h (2 hunks)
- src/op/gemm_sp.cc (4 hunks)
- src/op/gemm_sp.h (3 hunks)
- src/op/operator.h (0 hunks)
- src/transform/lower_tile_op.cc (2 hunks)
- tilelang/intrinsics/mfma_macro_generator.py (7 hunks)
- tilelang/intrinsics/mma_macro_generator.py (9 hunks)
- tilelang/intrinsics/mma_sm70_macro_generator.py (6 hunks)
- tilelang/intrinsics/tcgen05_macro_generator.py (2 hunks)
- tilelang/intrinsics/wgmma_macro_generator.py (4 hunks)
- tilelang/language/builtin.py (3 hunks)
- tilelang/language/gemm.py (5 hunks)
- tilelang/layout/swizzle.py (5 hunks)
- tilelang/tileop/gemm/__init__.py (1 hunks)
- tilelang/tileop/gemm/gemm_base.py (2 hunks)
- tilelang/tileop/gemm/gemm_mfma.py (4 hunks)
- tilelang/tileop/gemm/gemm_mma.py (4 hunks)
- tilelang/tileop/gemm/gemm_mma_sm70.py (3 hunks)
- tilelang/tileop/gemm/gemm_tcgen05.py (1 hunks)
- tilelang/tileop/gemm/gemm_wgmma.py (3 hunks)
- tilelang/utils/language.py (2 hunks)
💤 Files with no reviewable changes (1)
- src/op/operator.h
🧰 Additional context used
🧠 Learnings (2)
📚 Learning: 2025-09-15T10:51:06.985Z
Learnt from: botbw
Repo: tile-ai/tilelang PR: 691
File: src/tl_templates/cuda/gemm_sp_sm80.h:81-85
Timestamp: 2025-09-15T10:51:06.985Z
Learning: In CUTLASS tensor operation layouts, crosswise constants should be computed using sizeof(T) (bytes), not cutlass::sizeof_bits<T>::value (bits). This is the established pattern in the official CUTLASS codebase, as seen in default_mma_core_sparse_sm80.h.
Applied to files:
src/op/gemm.cc
📚 Learning: 2025-09-15T10:51:06.985Z
Learnt from: botbw
Repo: tile-ai/tilelang PR: 691
File: src/tl_templates/cuda/gemm_sp_sm80.h:81-85
Timestamp: 2025-09-15T10:51:06.985Z
Learning: In CUTLASS tensor operation layouts, crosswise constants should be computed using sizeof(T) (bytes), not cutlass::sizeof_bits<T>::value (bits). However, the layout template parameter should use sizeof_bits<T>::value (bits). This is the established pattern in the official CUTLASS codebase, as seen in default_mma_core_sparse_sm80.h where Crosswise uses sizeof(ElementA) but the layout template uses sizeof_bits<ElementA>::value.
Applied to files:
src/op/gemm.cc
🧬 Code graph analysis (23)
tilelang/tileop/gemm/gemm_tcgen05.py (1)
- tilelang/tileop/gemm/gemm_base.py (2): ARegion (79-80), BRegion (83-84)

tilelang/intrinsics/mma_sm70_macro_generator.py (2)
- tilelang/intrinsics/mfma_macro_generator.py (3): _warp_ldmatrix_a (278-301), ldmatrix_b (305-358), ldmatrix_b (795-870)
- tilelang/intrinsics/mma_macro_generator.py (6): _warp_ldmatrix_a (249-285), _warp_ldmatrix_a (808-896), ldmatrix_b (289-376), ldmatrix_b (900-1012), mma_load_layout (223-224), mma_load_layout (307-308)

tilelang/intrinsics/mfma_macro_generator.py (2)
- tilelang/intrinsics/mma_macro_generator.py (6): ldmatrix_a (207-287), ldmatrix_a (794-898), _warp_ldmatrix_a (249-285), _warp_ldmatrix_a (808-896), ldmatrix_b (289-376), ldmatrix_b (900-1012)
- tilelang/intrinsics/mma_sm70_macro_generator.py (3): ldmatrix_a (190-236), _warp_ldmatrix_a (220-234), ldmatrix_b (238-290)

tilelang/tileop/gemm/gemm_mma.py (4)
- tilelang/tileop/gemm/gemm_base.py (4): ARegion (79-80), A (67-68), BRegion (83-84), B (71-72)
- tilelang/utils/language.py (1): is_shared (47-62)
- tilelang/intrinsics/mfma_macro_generator.py (2): ldmatrix_b (305-358), ldmatrix_b (795-870)
- tilelang/intrinsics/mma_macro_generator.py (2): ldmatrix_b (289-376), ldmatrix_b (900-1012)
tilelang/language/gemm.py (4)
- src/op/builtin.h (1): tvm (13-575)
- tilelang/utils/language.py (1): get_buffer_region_from_load (164-186)
- tilelang/env.py (2): get (175-178), use_gemm_v1 (281-287)
- tilelang/primitives/gemm/__init__.py (1): gemm (10-46)

tilelang/tileop/gemm/__init__.py (1)
- tilelang/tileop/gemm/gemm_base.py (15): A (67-68), B (71-72), C (75-76), M (34-35), N (38-39), K (42-43), trans_A (46-47), trans_B (50-51), stride_A (91-92), stride_B (95-96), offset_A (99-100), offset_B (103-104), clear_accum (107-108), k_pack (111-112), wg_wait (115-116)

examples/warp_specialize/example_warp_specialize_gemm_barrierpipe_stage2.py (1)
- tilelang/env.py (1): disable_cache (275-276)

src/op/gemm_py.cc (4)
- tilelang/language/utils.py (1): region (8-27)
- tilelang/language/tir/op.py (1): tvm_access_ptr (651-676)
- src/op/gemm.cc (2): NormalizeToBufferRegion (52-97), NormalizeToBufferRegion (52-52)
- src/op/operator.cc (2): GetVarFromAccessPtr (74-81), GetVarFromAccessPtr (74-74)
tilelang/tileop/gemm/gemm_mma_sm70.py (5)
- tilelang/tileop/gemm/gemm_base.py (4): ARegion (79-80), A (67-68), BRegion (83-84), B (71-72)
- tilelang/utils/language.py (1): is_shared (47-62)
- tilelang/intrinsics/mfma_macro_generator.py (2): ldmatrix_b (305-358), ldmatrix_b (795-870)
- tilelang/intrinsics/mma_macro_generator.py (2): ldmatrix_b (289-376), ldmatrix_b (900-1012)
- tilelang/intrinsics/mma_sm70_macro_generator.py (1): ldmatrix_b (238-290)

src/op/gemm_sp.cc (4)
- src/op/gemm.cc (2): computeWarpPartition (222-402), computeWarpPartition (222-223)
- src/target/utils.cc (6): TargetGetWarpSize (130-135), TargetGetWarpSize (130-130), TargetIsHopper (52-57), TargetIsHopper (52-52), TargetIsAmpere (45-50), TargetIsAmpere (45-45)
- src/op/operator.cc (2): GetVarFromAccessPtr (74-81), GetVarFromAccessPtr (74-74)
- src/layout/gemm_layouts.cc (8): makeGemmFragmentC (121-136), makeGemmFragmentC (121-123), makeGemmABLayoutHopper (741-766), makeGemmABLayoutHopper (741-742), makeGemmSparseFragmentC (138-157), makeGemmSparseFragmentC (138-140), makeGemmSparseAmpereABLayout (683-688), makeGemmSparseAmpereABLayout (683-684)
tilelang/intrinsics/tcgen05_macro_generator.py (2)
- src/transform/lower_tile_op.cc (8): access_ptr (287-385), access_ptr (288-290), buffer (272-280), buffer (272-272), buffer (401-418), buffer (401-401), buffer (420-437), buffer (420-420)
- tilelang/intrinsics/wgmma_macro_generator.py (2): _warp_mma (300-336), _warp_mma (406-458)

src/op/gemm_sp.h (4)
- src/op/gemm.cc (2): computeWarpPartition (222-402), computeWarpPartition (222-223)
- src/op/gemm_sp.cc (2): computeWarpPartition (21-63), computeWarpPartition (21-25)
- tilelang/tileop/gemm/__init__.py (2): M (89-90), N (93-94)
- tilelang/tileop/gemm/gemm_base.py (2): M (34-35), N (38-39)

tilelang/tileop/gemm/gemm_wgmma.py (3)
- tilelang/tileop/gemm/gemm_base.py (12): A_base_offsets (152-154), B_base_offsets (157-159), C_base_offsets (162-164), ARegion (79-80), BRegion (83-84), CRegion (87-88), C (75-76), clear_accum (107-108), wg_wait (115-116), is_gemm_ss (21-22), is_gemm_rs (27-28), A (67-68)
- tilelang/intrinsics/wgmma_macro_generator.py (1): wgmma (163-338)
- tilelang/transform/simplify.py (1): _Simplify (31-49)

tilelang/intrinsics/mma_macro_generator.py (2)
- tilelang/intrinsics/mfma_macro_generator.py (4): extract_thread_binding (226-253), _warp_ldmatrix_a (278-301), ldmatrix_b (305-358), ldmatrix_b (795-870)
- tilelang/intrinsics/mma_sm70_macro_generator.py (3): extract_thread_binding (158-188), _warp_ldmatrix_a (220-234), ldmatrix_b (238-290)
tilelang/tileop/gemm/gemm_mfma.py (2)
- tilelang/tileop/gemm/gemm_base.py (4): ARegion (79-80), A (67-68), BRegion (83-84), B (71-72)
- tilelang/utils/language.py (1): is_shared (47-62)

tilelang/utils/language.py (3)
- src/transform/lower_tile_op.cc (2): buffer (272-280), buffer (272-272)
- src/transform/flatten_buffer.cc (2): buffer (295-321), buffer (295-296)
- src/transform/legalize_safe_memory_access.cc (12): buffer (87-95), buffer (87-87), buffer (98-138), buffer (98-99), buffer (256-260), buffer (256-256), buffer (262-265), buffer (262-262), buffer (267-270), buffer (267-267), buffer (272-277), buffer (272-272)

tilelang/intrinsics/wgmma_macro_generator.py (4)
- src/transform/lower_tile_op.cc (8): access_ptr (287-385), access_ptr (288-290), buffer (272-280), buffer (272-272), buffer (401-418), buffer (401-401), buffer (420-437), buffer (420-420)
- tilelang/language/tir/entry.py (1): macro (66-117)
- tilelang/intrinsics/tcgen05_macro_generator.py (1): _warp_mma (281-338)
- tilelang/language/builtin.py (1): initialize_wgmma_descriptor (700-727)
tilelang/language/builtin.py (4)
- src/op/builtin.h (1): tvm (13-575)
- src/transform/loop_vectorize.cc (2): indices (157-189), indices (157-157)
- tilelang/language/tir/op.py (4): call_intrin (120-145), tvm_access_ptr (651-676), address_of (464-480), type_annotation (635-648)
- tilelang/language/utils.py (1): region (8-27)

src/op/gemm_py.h (3)
- src/op/gemm.cc (8): checkWgmma (434-484), checkWgmma (434-434), allowTcgen5Mma (186-193), allowTcgen5Mma (186-186), allowWgmma (195-203), allowWgmma (195-195), getGemmInst (205-220), getGemmInst (205-205)
- src/op/gemm_py.cc (8): checkWgmma (255-305), checkWgmma (255-255), allowTcgen5Mma (186-193), allowTcgen5Mma (186-186), allowWgmma (195-203), allowWgmma (195-195), getGemmInst (205-223), getGemmInst (205-205)
- src/op/gemm.h (2): RegisterReflection (35-41), RegisterReflection (106-128)

tilelang/tileop/gemm/gemm_base.py (1)
- tilelang/tileop/gemm/__init__.py (13): N (93-94), K (97-98), trans_A (101-102), trans_B (105-106), B (69-70), C (73-74), stride_A (109-110), stride_B (113-114), offset_A (117-118), offset_B (121-122), clear_accum (125-126), k_pack (129-130), wg_wait (133-134)
src/op/gemm.h (3)
- src/op/gemm.cc (10): computeWarpPartition (222-402), computeWarpPartition (222-223), checkWgmma (434-484), checkWgmma (434-434), getGemmInst (205-220), getGemmInst (205-205), allowTcgen5Mma (186-193), allowTcgen5Mma (186-186), allowWgmma (195-203), allowWgmma (195-195)
- src/op/gemm_sp.cc (2): computeWarpPartition (21-63), computeWarpPartition (21-25)
- src/op/gemm_py.h (1): RegisterReflection (43-67)

tilelang/layout/swizzle.py (2)
- src/transform/lower_tile_op.cc (6): buffer (272-280), buffer (272-272), buffer (401-418), buffer (401-401), buffer (420-437), buffer (420-420)
- tilelang/language/ast/ir.py (1): buffer (93-161)

src/op/gemm.cc (4)
- src/op/gemm_py.cc (13): strides (80-80), NormalizeToBufferRegion (24-69), NormalizeToBufferRegion (24-24), allowTcgen5Mma (186-193), allowTcgen5Mma (186-186), allowWgmma (195-203), allowWgmma (195-195), checkWgmma (255-305), checkWgmma (255-255), getGemmInst (205-223), getGemmInst (205-205), MakeAccessPtrFromRegion (74-102), MakeAccessPtrFromRegion (74-74)
- src/op/gemm.h (5): Gemm (144-149), GemmWarpPolicy (59-83), GemmWarpPolicy (64-68), GemmWarpPolicy (70-74), GemmWarpPolicy (76-82)
- tilelang/ir.py (2): Gemm (43-44), GemmWarpPolicy (30-39)
- src/op/operator.cc (2): GetVarFromAccessPtr (74-81), GetVarFromAccessPtr (74-74)
🪛 Ruff (0.14.3)
tilelang/intrinsics/mfma_macro_generator.py
6-6: Redefinition of unused tvm from line 2
Remove definition: tvm
(F811)
298-298: Ambiguous variable name: l
(E741)
tilelang/language/gemm.py
137-137: Avoid specifying long messages outside the exception class
(TRY003)
458-458: Avoid specifying long messages outside the exception class
(TRY003)
tilelang/intrinsics/tcgen05_macro_generator.py
266-266: Prefer TypeError exception for invalid type
(TRY004)
266-266: Avoid specifying long messages outside the exception class
(TRY003)
278-278: Prefer TypeError exception for invalid type
(TRY004)
278-278: Avoid specifying long messages outside the exception class
(TRY003)
tilelang/tileop/gemm/gemm_wgmma.py
94-94: Local variable A_offsets is assigned to but never used
Remove assignment to unused variable A_offsets
(F841)
95-95: Local variable B_offsets is assigned to but never used
Remove assignment to unused variable B_offsets
(F841)
96-96: Local variable C_offsets is assigned to but never used
Remove assignment to unused variable C_offsets
(F841)
102-102: Local variable C_region is assigned to but never used
Remove assignment to unused variable C_region
(F841)
tilelang/utils/language.py
30-30: Avoid specifying long messages outside the exception class
(TRY003)
tilelang/intrinsics/wgmma_macro_generator.py
266-266: Undefined name Union
(F821)
275-275: Undefined name tir
(F821)
277-277: Undefined name tir
(F821)
279-279: Undefined name tir
(F821)
282-282: Prefer TypeError exception for invalid type
(TRY004)
282-282: Avoid specifying long messages outside the exception class
(TRY003)
282-282: Undefined name indices
(F821)
294-294: Prefer TypeError exception for invalid type
(TRY004)
294-294: Avoid specifying long messages outside the exception class
(TRY003)
301-301: Unpacked variable tx is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
tilelang/language/builtin.py
466-466: Avoid specifying long messages outside the exception class
(TRY003)
482-482: Avoid specifying long messages outside the exception class
(TRY003)
514-514: Avoid specifying long messages outside the exception class
(TRY003)
543-543: Avoid specifying long messages outside the exception class
(TRY003)
572-572: Do not catch blind exception: Exception
(BLE001)
575-576: Avoid specifying long messages outside the exception class
(TRY003)
tilelang/layout/swizzle.py
29-29: Avoid specifying long messages outside the exception class
(TRY003)
89-89: PEP 484 prohibits implicit Optional
Convert to Optional[T]
(RUF013)
106-106: PEP 484 prohibits implicit Optional
Convert to Optional[T]
(RUF013)
🔇 Additional comments (10)
tilelang/language/builtin.py (3)
11-11: LGTM!

The `BufferRegion` import is correctly added to support the new handling branch.

443-456: LGTM!

The docstring clearly documents the expanded accepted types and dtype inference behavior.

509-552: LGTM!

The BufferRegion handling correctly computes offsets from region mins and includes a nice enhancement to infer `num_regs` from fully static region extents. The fallback to requiring explicit `num_regs` for symbolic extents is appropriate.

tilelang/tileop/gemm/gemm_tcgen05.py (1)

111-112: LGTM: Region-based operands for TCGEN5.

The switch from `self.A`/`self.B` to `self.ARegion`/`self.BRegion` correctly aligns with the broader refactoring toward region-based addressing for shared-memory operands.

tilelang/tileop/gemm/gemm_mma.py (1)

86-89: LGTM: Region-based input selection for shared memory.

The conditional selection of `ARegion`/`BRegion` when buffers are in shared memory correctly supports strided/offset tiles while maintaining backward compatibility for non-shared buffers.

src/transform/lower_tile_op.cc (1)

643-644: LGTM: Simplified GEMM lowering path.

The removal of `buffer_var_gemm_` from `LowerArgs` correctly reflects the elimination of the GEMM-buffer collection step, aligning with the broader move to region-based GEMM handling.

tilelang/tileop/gemm/gemm_mfma.py (1)

87-89: LGTM: Consistent region-based input handling.

The conditional selection of `ARegion`/`BRegion` matches the pattern used in `gemm_mma.py`, correctly supporting region-based addressing for shared-memory operands.

tilelang/intrinsics/mma_macro_generator.py (1)

236-246: LGTM: Clean BufferRegion support.

The logic to unwrap `BufferRegion` into the underlying buffer and base offsets is well-structured and correctly handles both `Buffer` and `BufferRegion` inputs uniformly.

tilelang/language/gemm.py (1)

493-493: LGTM: Environment-driven GEMM selection.

The conditional assignment of `gemm` based on `_env.use_gemm_v1()` correctly implements the environment-driven selection mechanism described in the PR objectives, defaulting to v2 while allowing v1 to be forced via `TILELANG_USE_GEMM_V1`.

tilelang/intrinsics/mma_sm70_macro_generator.py (1)

209-235: BufferRegion offset handling aligns with the existing pattern.

The base-offset extraction mirrors the mfma path, so the SM70 macro can now consume BufferRegion sources without surprises.
```cpp
PrimExpr offset = make_const(buf->shape[0].dtype(), 0);
for (int i = 0; i < ndim - 2; ++i) {
  offset = offset + region->region[i]->min * strides[i];
}
```
Fix pointer offset when rebuilding access_ptr
MakeAccessPtrFromRegion never adds the minima of the last two dimensions to offset. For a 2D region such as A[i0:i0+16, j0:j0+8], the loop runs zero times, so we always hand the kernels a pointer to A[0,0]. Every GEMM tile then reads/writes the wrong slice. Please accumulate the minima for all dimensions (same change is needed in gemm.cc).
```diff
-  for (int i = 0; i < ndim - 2; ++i) {
+  for (int i = 0; i < ndim; ++i) {
     offset = offset + region->region[i]->min * strides[i];
   }
```

📝 Committable suggestion
```cpp
PrimExpr offset = make_const(buf->shape[0].dtype(), 0);
for (int i = 0; i < ndim; ++i) {
  offset = offset + region->region[i]->min * strides[i];
}
```
```cpp
    ranges.push_back(Range(IntImm(extent->dtype, 0), extent));
  }
  return BufferRegion(buf, ranges);
}
```
Fix pointer offset when rebuilding access_ptr
Identical to the gemm_py.cc case: skipping the last two dimensions in the offset sum makes every reconstructed pointer land at element (0,0) for typical 2D regions, so all GEMM tiles operate on the wrong data. Please include the minima of every dimension when computing offset.
```diff
-  for (int i = 0; i < ndim - 2; ++i) {
+  for (int i = 0; i < ndim; ++i) {
     offset = offset + region->region[i]->min * strides[i];
   }
```
}Committable suggestion skipped: line range outside the PR's diff.
🤖 Prompt for AI Agents
In src/op/gemm.cc around lines 89 to 92, the pointer offset computation when
rebuilding access_ptr incorrectly skips the last two dimensions causing all
reconstructed pointers to point to element (0,0); fix it by including the minima
for every dimension in the offset calculation (i.e., do not skip the last two
dims), summing (index - min) * stride for each dimension so the computed offset
accounts for all dims’ minima, and use that full offset when constructing the
rebuilt access_ptr.
```diff
 thread_binding = self.get_thread_binding()

+def retrive_ptr_from_buffer_region(buffer_or_load_or_region: Union[Buffer, BufferLoad, BufferRegion], access_type: str = "r") -> PrimExpr:
+    if isinstance(buffer_or_load_or_region, Buffer):
+        return buffer_or_load_or_region.access_ptr(access_type)
+    elif isinstance(buffer_or_load_or_region, BufferLoad):
+        buffer_load = buffer_or_load_or_region
+        offset, stride = 0, 1
+        buffer = buffer_load.buffer
+        for i, shape in enumerate(reversed(buffer.shape)):
+            indice = buffer_load.indices[len(buffer_load.indices) - i - 1]
+            if isinstance(indice, tir.IntImm):
+                offset += indice * stride
+            elif isinstance(indice, tir.PrimExpr):
+                offset += indice * stride
+            elif isinstance(indice, tir.Ramp):
+                offset += indice.base * stride
+            else:
+                raise ValueError(f"Unsupported index type: {type(indices)}")
+            stride *= shape
+        return buffer.access_ptr(access_type, offset=offset)
+    elif isinstance(buffer_or_load_or_region, BufferRegion):
+        buffer_region = buffer_or_load_or_region
+        buffer = buffer_region.buffer
+        offset, stride = 0, 1
+        for i, shape in enumerate(reversed(buffer.shape)):
+            offset += buffer_region.region[len(buffer_region.region) - i - 1].min * stride
+            stride *= shape
+        return buffer.access_ptr(access_type, offset=offset)
+    else:
+        raise ValueError(f"Unsupported buffer type: {type(buffer_or_load_or_region)}")

+A_ptr = retrive_ptr_from_buffer_region(A_buf)
+B_ptr = retrive_ptr_from_buffer_region(B_buf)

 @T.macro
-def _warp_mma(A_buf, B_buf, C_local_buf):
+def _warp_mma():
     tx, warp_n, warp_m = self.extract_thread_binding(thread_binding)

     desc_a = T.alloc_wgmma_desc()
     desc_b = T.alloc_wgmma_desc()
-    T.initialize_wgmma_descriptor(desc_a, A_buf.access_ptr("r"), a_swizzle_mode,
+    T.initialize_wgmma_descriptor(desc_a, A_ptr, a_swizzle_mode,
                                   int(a_leading_byte_offset >> 4),
                                   int(a_stride_byte_offset >> 4))
-    T.initialize_wgmma_descriptor(desc_b, B_buf.access_ptr("r"), b_swizzle_mode,
+    T.initialize_wgmma_descriptor(desc_b, B_ptr, b_swizzle_mode,
                                   int(b_leading_byte_offset >> 4),
                                   int(b_stride_byte_offset >> 4))
```
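The helper's index-to-offset arithmetic can be exercised in isolation. Here `Ramp` is a minimal stand-in for `tir.Ramp` (a vectorized index with a scalar `base`); everything else is plain Python, so this only models the loop structure of the diff above, not the TVM types.

```python
class Ramp:
    """Stand-in for tir.Ramp: only .base is consulted by the offset loop."""

    def __init__(self, base, lanes):
        self.base = base
        self.lanes = lanes


def flat_offset(indices, shape):
    """Row-major offset; a Ramp contributes its base, scalars contribute themselves."""
    offset, stride = 0, 1
    for i, dim in enumerate(reversed(shape)):
        idx = indices[len(indices) - i - 1]
        idx = idx.base if isinstance(idx, Ramp) else idx
        offset += idx * stride
        stride *= dim
    return offset
```

For a 16x64 buffer, index `[2, 3]` lands at `2*64 + 3`, and a vectorized inner index `Ramp(4, 8)` contributes only its base, landing at `2*64 + 4`.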
Fix NameError when resolving BufferLoad offsets
retrive_ptr_from_buffer_region calls tir.IntImm / tir.PrimExpr, but tir is never imported in this module. The first BufferLoad you run through this helper will raise a NameError, so the GEMM flow regresses for the very BufferRegion cases this PR is enabling. Please import tir (and Union, since it’s now part of the signature) so the helper can execute.
```diff
-from typing import Callable
+from typing import Callable, Union
-from tvm import DataType
+from tvm import DataType, tir
```

…

```diff
-    if isinstance(indice, tir.IntImm):
+    if isinstance(indice, tir.IntImm):
```

…

```diff
-    elif isinstance(indice, tir.PrimExpr):
+    elif isinstance(indice, tir.PrimExpr):
```
🧰 Tools
🪛 Ruff (0.14.3)
266-266: Undefined name Union
(F821)
275-275: Undefined name tir
(F821)
277-277: Undefined name tir
(F821)
279-279: Undefined name tir
(F821)
282-282: Prefer TypeError exception for invalid type
(TRY004)
282-282: Avoid specifying long messages outside the exception class
(TRY003)
282-282: Undefined name indices
(F821)
294-294: Prefer TypeError exception for invalid type
(TRY004)
294-294: Avoid specifying long messages outside the exception class
(TRY003)
301-301: Unpacked variable tx is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
🤖 Prompt for AI Agents
In tilelang/intrinsics/wgmma_macro_generator.py around lines 264 to 310, the
helper retrive_ptr_from_buffer_region references tir and Union but the module
never imports them; add the necessary imports (e.g., add "from typing import
Union" and "from tvm import tir" or "import tvm.tir as tir" at the top of the
file) so the type checks against tir.IntImm/tir.PrimExpr/tir.Ramp and the Union
annotation work without raising NameError.
- Updated `wgmma` and `wgmma_rs` methods in `TensorCoreIntrinEmitter` to accept `BufferRegion` parameters instead of `Buffer`.
- Adjusted related calls in `GemmWGMMA` to ensure compatibility with the new parameter types.
- Simplified buffer access logic for better clarity and maintainability.
Actionable comments posted: 4
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
tilelang/intrinsics/wgmma_macro_generator.py (1)
371-371: Fix undefined B_buf reference.

Line 371 references `B_buf`, which doesn't exist in `wgmma_rs`. The parameter is `B_region`, so this should be `B_region.buffer` or just `B_region`, depending on what `_determinate_swizzle_mode` expects.

Apply this diff:

```diff
-        b_swizzle_mode = self._determinate_swizzle_mode(B_buf, self.b_shared_layout)
+        b_swizzle_mode = self._determinate_swizzle_mode(B_region.buffer, self.b_shared_layout)
```
♻️ Duplicate comments (2)
tilelang/tileop/gemm/gemm_wgmma.py (1)
90-102: Unused offset and region variables.

The variables `A_offsets`, `B_offsets`, `C_offsets`, and `C_region` are assigned but never used in the subsequent code. While the comments indicate these are placeholders for future offset handling through BufferRegion, they currently serve no purpose and create noise.

Based on past reviews, consider either:
- Removing these assignments until they're actually needed
- Adding a clear TODO comment explaining when/how they'll be used
- Commenting them out if they document the intended future API
tilelang/intrinsics/wgmma_macro_generator.py (1)
4-7: Import Union and tir to fix NameError.

The helper function `retrive_ptr_from_buffer_region` at line 266 uses `Union` in its type hint and references `tir.IntImm`, `tir.PrimExpr`, and `tir.Ramp` without importing them. This will cause a `NameError` at runtime.

Apply this diff to add the missing imports:

```diff
-from typing import Callable
+from typing import Callable, Union
-from tvm.tir import PrimExpr, Buffer, Var, IndexMap, BufferLoad, BufferRegion
+from tvm.tir import PrimExpr, Buffer, Var, IndexMap, BufferLoad, BufferRegion
+from tvm import tir
```

Based on past reviews.
🧹 Nitpick comments (2)
examples/warp_specialize/example_warp_specialize_gemm_barrierpipe_stage2.py (1)
55-55: Consider removing redundant `disable_cache()` call.

This call is redundant since `disable_cache()` is already invoked at the module level (line 4). While not harmful, removing one of these calls would reduce duplication.

Apply this diff if you prefer to keep only the function-level call:

```diff
-tilelang.disable_cache()  # add decorator @tilelang.jit if you want to return a torch function
```

tilelang/intrinsics/wgmma_macro_generator.py (1)
266-266: Fix typo in helper function name.

The function name `retrive_ptr_from_buffer_region` has a typo; it should be `retrieve_ptr_from_buffer_region`.

When refactoring per the previous comments, correct the spelling:

```diff
-def retrive_ptr_from_buffer_region(buffer_or_load_or_region: Union[Buffer, BufferLoad, BufferRegion], access_type: str = "r") -> PrimExpr:
+def retrieve_ptr_from_buffer_region(buffer_or_load_or_region: Union[Buffer, BufferLoad, BufferRegion], access_type: str = "r") -> PrimExpr:
```

And update all call sites accordingly.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (5)
- examples/warp_specialize/example_warp_specialize_gemm_barrierpipe_stage2.py (2 hunks)
- testing/python/dynamic/test_tilelang_dynamic_symbolic.py (1 hunks)
- tilelang/intrinsics/mma_macro_generator.py (9 hunks)
- tilelang/intrinsics/wgmma_macro_generator.py (8 hunks)
- tilelang/tileop/gemm/gemm_wgmma.py (3 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- tilelang/intrinsics/mma_macro_generator.py
🧰 Additional context used
🧬 Code graph analysis (3)
examples/warp_specialize/example_warp_specialize_gemm_barrierpipe_stage2.py (1)
- tilelang/env.py (1): disable_cache (275-276)

tilelang/intrinsics/wgmma_macro_generator.py (4)
- tilelang/utils/language.py (1): is_fragment (107-118)
- src/transform/lower_tile_op.cc (8): access_ptr (287-385), access_ptr (288-290), buffer (272-280), buffer (272-272), buffer (401-418), buffer (401-401), buffer (420-437), buffer (420-420)
- tilelang/language/utils.py (1): region (8-27)
- tilelang/language/builtin.py (2): initialize_wgmma_descriptor (700-727), warpgroup_fence_operand (433-589)

tilelang/tileop/gemm/gemm_wgmma.py (3)
- tilelang/tileop/gemm/gemm_base.py (11): A_base_offsets (152-154), B_base_offsets (157-159), C_base_offsets (162-164), ARegion (79-80), BRegion (83-84), CRegion (87-88), clear_accum (107-108), wg_wait (115-116), is_gemm_ss (21-22), is_gemm_rs (27-28), A (67-68)
- tilelang/intrinsics/wgmma_macro_generator.py (1): wgmma (163-340)
- tilelang/transform/simplify.py (1): _Simplify (31-49)
🪛 Ruff (0.14.3)
tilelang/intrinsics/wgmma_macro_generator.py
266-266: Undefined name Union
(F821)
275-275: Undefined name tir
(F821)
277-277: Undefined name tir
(F821)
279-279: Undefined name tir
(F821)
282-282: Prefer TypeError exception for invalid type
(TRY004)
282-282: Avoid specifying long messages outside the exception class
(TRY003)
282-282: Undefined name indices
(F821)
294-294: Prefer TypeError exception for invalid type
(TRY004)
294-294: Avoid specifying long messages outside the exception class
(TRY003)
302-302: Unused function argument: A_region
(ARG001)
302-302: Unused function argument: B_region
(ARG001)
303-303: Unpacked variable tx is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
409-409: Undefined name retrive_ptr_from_buffer_region
(F821)
414-414: Unpacked variable tx is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
414-414: Unpacked variable warp_m is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
417-417: Undefined name B_buf
(F821)
419-419: Undefined name B_buf
(F821)
422-422: Undefined name B_buf
(F821)
tilelang/tileop/gemm/gemm_wgmma.py
94-94: Local variable A_offsets is assigned to but never used
Remove assignment to unused variable A_offsets
(F841)
95-95: Local variable B_offsets is assigned to but never used
Remove assignment to unused variable B_offsets
(F841)
96-96: Local variable C_offsets is assigned to but never used
Remove assignment to unused variable C_offsets
(F841)
🔇 Additional comments (3)
examples/warp_specialize/example_warp_specialize_gemm_barrierpipe_stage2.py (2)
4-4: Verify the necessity of module-level cache disabling. Calling `disable_cache()` at module import time affects global state and will disable caching for any code that imports this example module. For a standalone example script, the function-level call at line 55 would typically be sufficient. If the module-level call is intentional (e.g., to ensure cache is disabled regardless of how this example is used), consider adding a comment explaining the rationale.
60-61: Good addition for example clarity. Printing the kernel source is helpful for users learning about kernel generation and debugging. This makes the generated CUDA code immediately visible during execution.
Note: Lines 78-80 contain commented-out code that performs similar kernel source printing, which is now redundant but can be left as-is or removed for cleanliness.
tilelang/intrinsics/wgmma_macro_generator.py (1)
302-302: Unused macro parameters should be prefixed with underscore. The `_warp_mma` macro parameters `A_region` and `B_region` are marked as unused by static analysis. However, these are passed through from the outer scope and used via closure (the computed `A_ptr`, `B_ptr`, `C_buf` are captured). This is a false positive from the linter.
 if __name__ == "__main__":
-    tilelang.testing.main()
+    # tilelang.testing.main()
+    assert_tl_matmul_macro_correctness(128, 128, 128, "float16", "float16", "float16")
Test harness bypass breaks test discovery and coverage.
Replacing tilelang.testing.main() with a single direct assertion prevents the test framework from discovering and executing the other test functions defined in this module (test_assert_tl_matmul_block, test_assert_tl_matmul_block_all_dynamic, test_assert_tl_matmul_block_all_dynamic_with_pass_config). This significantly reduces test coverage.
If this change is intentional for debugging, consider reverting it before merging. Otherwise, restore the test harness:
if __name__ == "__main__":
- # tilelang.testing.main()
- assert_tl_matmul_macro_correctness(128, 128, 128, "float16", "float16", "float16")
+ tilelang.testing.main()
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
Suggested change:
if __name__ == "__main__":
    tilelang.testing.main()
🤖 Prompt for AI Agents
In testing/python/dynamic/test_tilelang_dynamic_symbolic.py around lines
516-518, the test entrypoint was changed to call
assert_tl_matmul_macro_correctness(...) directly which bypasses
tilelang.testing.main() and breaks test discovery/coverage; restore the test
harness by removing the direct assertion and calling tilelang.testing.main() (or
re-enable the original commented call), or if this direct call is only needed
for temporary debugging wrap it behind an explicit debug guard (e.g. env var) so
the normal tilelang.testing.main() runs by default.
elif isinstance(indice, tir.Ramp):
    offset += indice.base * stride
else:
    raise ValueError(f"Unsupported index type: {type(indices)}")
Fix undefined variable name.
Line 282 references indices which is undefined. The variable name should be indice (the loop variable from line 274).
Apply this diff:
- raise ValueError(f"Unsupported index type: {type(indices)}")
+ raise ValueError(f"Unsupported index type: {type(indice)}")
📝 Committable suggestion
+ raise ValueError(f"Unsupported index type: {type(indice)}")📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| raise ValueError(f"Unsupported index type: {type(indices)}") | |
| raise ValueError(f"Unsupported index type: {type(indice)}") |
🧰 Tools
🪛 Ruff (0.14.3)
282-282: Prefer TypeError exception for invalid type
(TRY004)
282-282: Avoid specifying long messages outside the exception class
(TRY003)
282-282: Undefined name indices
(F821)
🤖 Prompt for AI Agents
In tilelang/intrinsics/wgmma_macro_generator.py around line 282, the code raises
a ValueError referencing an undefined variable `indices`; change the reference
to the loop variable `indice` (from the for-loop that starts at line 274) so the
exception message uses the correct variable name.
…andling
- Updated `run_gemm`, `run_gemm_rs`, `run_gemm_sr`, and `run_gemm_rr` functions to set `num_stages` based on block dimensions, enhancing performance for larger matrices.
- Simplified calls to GEMM functions by removing redundant parameters and ensuring compatibility with BufferRegion.
- Introduced utility functions for converting between Buffer, BufferLoad, and BufferRegion, improving code clarity and maintainability.
- Enhanced error handling for full region checks in GEMM operations to ensure correctness in memory access.
- Cleaned up formatting and spacing in GEMM-related files for better readability.
- Standardized comments and code structure across various GEMM functions and macros.
- Enhanced error messages for clarity in buffer region checks.
- Removed redundant lines and improved overall code maintainability.
…unctionality
- Modified `N_VALUES` in `correctness_evaluation_sm70.py` to include only relevant sizes for tests.
- Updated test function call in `correctness_evaluation.py` to use `test_gemm_false_true` for better accuracy in testing.
- Refactored buffer handling in `mma_sm70_macro_generator.py` to improve clarity and consistency in shared buffer access.
- Enhanced `gemm_mma_sm70.py` to ensure full region checks for input and output buffers, improving correctness in GEMM operations.
- Removed unused variable `A_stride_last` in `mma_sm70_macro_generator.py` to streamline code.
- Adjusted function signature formatting in `swizzle.py` for better readability.
- Restored the return of `GemmWGMMA` in `__init__.py` for correct GEMM instantiation.
- Removed unused variable `B_buf` in `gemm_mma_sm70.py` to enhance code cleanliness.
- Improved function signature formatting in `language.py` for consistency.
Actionable comments posted: 10
♻️ Duplicate comments (2)
src/op/gemm_py.cc (1)
95-102: Offset must include every dimension. Skipping the last two dimensions when accumulating `region->region[i]->min * strides[i]` recreates the old bug: any 2D tile (the common case) rebuilds its pointer at (0,0), so every kernel reads/writes the wrong slice. Please sum minima across all axes when computing `offset`.
- for (int i = 0; i < ndim - 2; ++i) {
+ for (int i = 0; i < ndim; ++i) {
    offset = offset + region->region[i]->min * strides[i];
  }
src/op/gemm.cc (1)
123-130: Same offset bug in GEMM path. Identical to the issue flagged in gemm_py.cc: stopping the minima accumulation at ndim - 2 resets the pointer for typical 2D regions, so every tile operates on the wrong data. The offset must incorporate all dimensions.
- for (int i = 0; i < ndim - 2; ++i) {
+ for (int i = 0; i < ndim; ++i) {
    offset = offset + region->region[i]->min * strides[i];
  }
🧹 Nitpick comments (12)
src/transform/lower_tile_op.cc (2)
445-452: Use Array&lt;Op&gt; for PTX op filtering to avoid type drift. The container holds RelaxExpr but you compare against op->op (Op). Make this `Array<Op>` for type clarity and safer equality.
Apply:
- Array<RelaxExpr> ptx_instructions = {builtin::ptx_ldmatrix(),
-                                      builtin::mma_store()};
+ Array<Op> ptx_instructions = {builtin::ptx_ldmatrix(),
+                               builtin::mma_store()};
642-646: Introduce a named constructor for LowerArgs and replace the aggregate-initializer. Aggregate-initializing LowerArgs at src/transform/lower_tile_op.cc (lines 642-646) is brittle: any change to the field order in operator.h will silently break this call.
- In src/op/operator.h, add a constructor to LowerArgs:
struct LowerArgs {
  Target target;
  Range thread_bounds;
  Var thread_var;
  AddWorkspaceCallback AddWorkspace;
  LayoutMap layout_map;
  Map<Buffer, Buffer> buffer_remap;
  LowerArgs(Target t, Range bounds, Var var, AddWorkspaceCallback cb,
            LayoutMap lm, Map<Buffer, Buffer> remap)
      : target(t), thread_bounds(bounds), thread_var(var), AddWorkspace(cb),
        layout_map(lm), buffer_remap(remap) {}
};
- In src/transform/lower_tile_op.cc, update the call site:
- auto lowered =
-     tile_op->Lower(LowerArgs{target_, thread_bounds, thread_var_->var,
-                              callback, layout_map_, buffer_remap_},
-                    analyzer_);
+ auto lowered =
+     tile_op->Lower(LowerArgs(target_,
+                              thread_bounds,
+                              thread_var_->var,
+                              callback,
+                              layout_map_,
+                              buffer_remap_),
+                    analyzer_);
No other aggregate-init sites or BufferGemmCollector remnants were found.
examples/warp_specialize/example_warp_specialize_gemm_barrierpipe_stage2.py (2)
57-57: Gate cache disabling behind an env/flag to avoid skewing profiling. Disabling cache forces recompilation and can distort profiler results. Gate it with an env var so users can opt in.
Apply:
- tilelang.disable_cache()
+ if os.getenv("TILELANG_DISABLE_CACHE", "0") == "1":
+     tilelang.disable_cache()
Add once near the imports:
import os # for env-gated debug switches
63-64: Guard kernel source dump behind an env flag
The call to `print(jit_kernel.get_kernel_source())` can flood stdout; wrap it with an environment check.
Apply:
- print(jit_kernel.get_kernel_source())
+ if os.getenv("TILELANG_DEBUG_SRC", "0") == "1":
+     print(jit_kernel.get_kernel_source())
tilelang/layout/swizzle.py (4)
6-6: Import PrimExpr for accurate type hints and usage. Add PrimExpr to the tvm.tir import to support corrected annotations below.
-from tvm.tir import Buffer, BufferLoad, BufferRegion
+from tvm.tir import Buffer, BufferLoad, BufferRegion, PrimExpr
49-61: Confirm lane-aware element size for vector dtypes
`tvm.DataType(dtype).bits` returns per-lane bits. If dtypes with lanes (e.g., float16x8) can appear here, consider `bits * lanes()` or document that a scalar dtype is required and validated earlier.
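The per-lane vs. total distinction can be shown with a toy parser for dtype strings such as "float16" and "float16x8". This is a sketch for illustration only, not the `tvm.DataType` API:

```python
def dtype_bits(dtype: str):
    # Returns (per_lane_bits, lanes, total_bits) for strings like
    # "float16" or "float16x8". A toy parser, not tvm.DataType.
    base, _, lanes_str = dtype.partition("x")
    lanes = int(lanes_str) if lanes_str else 1
    per_lane = int("".join(ch for ch in base if ch.isdigit()))
    return per_lane, lanes, per_lane * lanes


print(dtype_bits("float16"))    # (16, 1, 16)
print(dtype_bits("float16x8"))  # (16, 8, 128)
```

An element-size computation that uses only the per-lane value silently underestimates vector dtypes by a factor of `lanes`.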
94-100: Use Optional typing for continuity and allow PrimExpr. Fix implicit Optional (RUF013) and accept PrimExpr continuity.
-def make_wgmma_swizzled_layout(buffer: Buffer | BufferLoad | BufferRegion,
-                               continuity: int = None,
+from typing import Optional  # add near imports if not present
+
+def make_wgmma_swizzled_layout(buffer: Buffer | BufferLoad | BufferRegion,
+                               continuity: Optional[PrimExpr] = None,
                                k_major: bool = True):
@@
-    if continuity is None:
+    if continuity is None:
         continuity = continuous
110-117: Same typing fix for TCGEN05MMA continuity. Mirror the Optional[PrimExpr] change here.
-def make_tcgen05mma_swizzled_layout(buffer: Buffer | BufferLoad | BufferRegion,
-                                    continuity: int = None,
+def make_tcgen05mma_swizzled_layout(buffer: Buffer | BufferLoad | BufferRegion,
+                                    continuity: Optional[PrimExpr] = None,
                                     k_major: bool = True):
@@
-    if continuity is None:
+    if continuity is None:
         continuity = continuous
tilelang/utils/language.py (4)
245-245: Use TypeError for invalid type; shorten messages (TRY004, TRY003). For type mismatches, raise `TypeError` and keep messages brief to satisfy linters.
- raise ValueError(f"Unsupported retrieve_stride argument type: {type(obj)} for object {obj}")
+ raise TypeError(f"Unsupported type: {type(obj)}")
@@
- raise ValueError(f"Unsupported index type: {type(indice)}")
+ raise TypeError(f"Unsupported index type: {type(indice)}")
@@
- raise ValueError(f"Unsupported buffer type: {type(buffer_or_load_or_region)}")
+ raise TypeError(f"Unsupported buffer type: {type(buffer_or_load_or_region)}")
@@
- raise ValueError(f"Unsupported retrieve_ptr argument type: {type(obj)} for object {obj}")
+ raise TypeError(f"Unsupported type: {type(obj)}")
@@
- raise ValueError(f"Unsupported retrieve_offset argument type: {type(obj)} for object {obj}")
+ raise TypeError(f"Unsupported type: {type(obj)}")
Also applies to: 270-270, 282-282, 329-329, 349-349
214-232: Unify scalar BufferLoad handling with `to_buffer_region` fallback. Instead of erroring on scalar `BufferLoad`, derive a 1-sized region via `to_buffer_region` for consistency.
-    if isinstance(obj, tir.BufferLoad):
-        region = get_buffer_region_from_load(obj)
-        if region is None:
-            raise ValueError("Cannot retrieve shape from scalar BufferLoad without region")
-        return [r.extent for r in region.region]
+    if isinstance(obj, tir.BufferLoad):
+        region = to_buffer_region(obj)
+        return [r.extent for r in region.region]
45-61: Simplify boolean logic in is_shared. Minor readability cleanup; no behavior change.
-    conditions = [False]
-    conditions.append(buffer.scope() == "shared")
-    if allow_dynamic:
-        conditions.append(is_shared_dynamic(buffer))
-    return any(conditions)
+    return (buffer.scope() == "shared") or (allow_dynamic and is_shared_dynamic(buffer))
352-368: Prefer the single helper `prim_expr_equal` you defined. Optionally use `prim_expr_equal` in other helpers (e.g., `is_full_region`) for consistency; current usage is fine.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (22)
- .gitignore (1 hunks)
- examples/warp_specialize/example_warp_specialize_gemm_barrierpipe_stage2.py (2 hunks)
- maint/gemm_v2/correctness_evaluation.py (7 hunks)
- maint/gemm_v2/correctness_evaluation_sm70.py (1 hunks)
- src/op/gemm.cc (10 hunks)
- src/op/gemm_py.cc (7 hunks)
- src/op/gemm_sp.cc (4 hunks)
- src/transform/lower_tile_op.cc (1 hunks)
- tilelang/intrinsics/mfma_macro_generator.py (7 hunks)
- tilelang/intrinsics/mma_macro_generator.py (9 hunks)
- tilelang/intrinsics/mma_sm70_macro_generator.py (6 hunks)
- tilelang/intrinsics/tcgen05_macro_generator.py (2 hunks)
- tilelang/intrinsics/wgmma_macro_generator.py (9 hunks)
- tilelang/language/builtin.py (3 hunks)
- tilelang/language/gemm.py (5 hunks)
- tilelang/layout/swizzle.py (5 hunks)
- tilelang/tileop/gemm/__init__.py (2 hunks)
- tilelang/tileop/gemm/gemm_mma.py (6 hunks)
- tilelang/tileop/gemm/gemm_mma_sm70.py (4 hunks)
- tilelang/tileop/gemm/gemm_wgmma.py (3 hunks)
- tilelang/utils/__init__.py (1 hunks)
- tilelang/utils/language.py (3 hunks)
✅ Files skipped from review due to trivial changes (1)
- .gitignore
🚧 Files skipped from review as they are similar to previous changes (2)
- tilelang/tileop/gemm/gemm_wgmma.py
- tilelang/tileop/gemm/__init__.py
🧰 Additional context used
🧠 Learnings (2)
📚 Learning: 2025-09-15T10:51:06.985Z
Learnt from: botbw
Repo: tile-ai/tilelang PR: 691
File: src/tl_templates/cuda/gemm_sp_sm80.h:81-85
Timestamp: 2025-09-15T10:51:06.985Z
Learning: In CUTLASS tensor operation layouts, crosswise constants should be computed using sizeof(T) (bytes), not cutlass::sizeof_bits<T>::value (bits). This is the established pattern in the official CUTLASS codebase, as seen in default_mma_core_sparse_sm80.h.
Applied to files:
src/op/gemm.cc
📚 Learning: 2025-09-15T10:51:06.985Z
Learnt from: botbw
Repo: tile-ai/tilelang PR: 691
File: src/tl_templates/cuda/gemm_sp_sm80.h:81-85
Timestamp: 2025-09-15T10:51:06.985Z
Learning: In CUTLASS tensor operation layouts, crosswise constants should be computed using sizeof(T) (bytes), not cutlass::sizeof_bits<T>::value (bits). However, the layout template parameter should use sizeof_bits<T>::value (bits). This is the established pattern in the official CUTLASS codebase, as seen in default_mma_core_sparse_sm80.h where Crosswise uses sizeof(ElementA) but the layout template uses sizeof_bits<ElementA>::value.
Applied to files:
src/op/gemm.cc
🧬 Code graph analysis (17)
tilelang/language/gemm.py (2)
- tilelang/utils/language.py (6): to_buffer_region (187-211), retrieve_shape (214-231), retrieve_stride (234-252), retrieve_ptr (285-329), retrieve_offset (332-349), prim_expr_equal (352-367)
- tilelang/env.py (2): get (175-178), use_gemm_v1 (281-287)
tilelang/tileop/gemm/gemm_mma_sm70.py (4)
- tilelang/utils/language.py (3): is_shared (45-60), is_fragment (105-116), is_full_region (370-399)
- tilelang/tileop/gemm/gemm_base.py (4): ARegion (79-80), BRegion (83-84), CRegion (87-88), is_gemm_rs (27-28)
- tilelang/intrinsics/mma_sm70_macro_generator.py (2): ldmatrix_b (234-284), mma (286-327)
- tilelang/tileop/gemm/gemm_mma.py (2): _gemm_ssr (103-128), is_gemm_rs (218-219)
tilelang/utils/__init__.py (1)
- tilelang/utils/language.py (5): retrieve_stride (234-252), retrieve_shape (214-231), retrive_ptr_from_buffer_region (255-282), is_full_region (370-399), to_buffer_region (187-211)
tilelang/intrinsics/mma_macro_generator.py (4)
- tilelang/intrinsics/utils.py (1): get_ldmatrix_offset (21-63)
- tilelang/utils/language.py (2): is_fragment (105-116), to_buffer_region (187-211)
- tilelang/intrinsics/mma_sm70_macro_generator.py (3): extract_thread_binding (158-188), _warp_ldmatrix_a (216-230), ldmatrix_b (234-284)
- tilelang/intrinsics/mfma_macro_generator.py (4): extract_thread_binding (225-252), _warp_ldmatrix_a (277-300), ldmatrix_b (304-357), ldmatrix_b (794-869)
examples/warp_specialize/example_warp_specialize_gemm_barrierpipe_stage2.py (1)
- tilelang/env.py (1): disable_cache (275-276)
tilelang/tileop/gemm/gemm_mma.py (3)
- tilelang/utils/language.py (3): is_shared (45-60), is_fragment (105-116), is_full_region (370-399)
- tilelang/tileop/gemm/gemm_base.py (6): ARegion (79-80), BRegion (83-84), CRegion (87-88), is_gemm_sr (24-25), is_gemm_rs (27-28), is_gemm_rr (30-31)
- tilelang/intrinsics/mma_macro_generator.py (6): ldmatrix_b (285-368), ldmatrix_b (892-1004), mma (370-430), mma (1006-1055), mma (1060-1158), mma (1163-1262)
tilelang/layout/swizzle.py (2)
- src/layout/swizzle.h (1): tvm (12-69)
- tilelang/language/ast/ir.py (1): buffer (93-161)
tilelang/language/builtin.py (2)
- tilelang/language/ast/ir.py (1): evaluate (1319-1331)
- tilelang/language/tir/op.py (3): tvm_access_ptr (651-676), address_of (464-480), type_annotation (635-648)
src/op/gemm_sp.cc (4)
- src/op/gemm.cc (2): computeWarpPartition (228-408), computeWarpPartition (228-229)
- src/op/operator.cc (2): GetVarFromAccessPtr (74-81), GetVarFromAccessPtr (74-74)
- src/op/gemm_sp.h (4): GemmSPWarpPolicy (28-52), GemmSPWarpPolicy (33-37), GemmSPWarpPolicy (39-43), GemmSPWarpPolicy (45-51)
- src/layout/gemm_layouts.cc (10): makeGemmFragmentCHopper (176-187), makeGemmFragmentCHopper (176-178), makeGemmFragmentC (121-136), makeGemmFragmentC (121-123), makeGemmABLayoutHopper (741-766), makeGemmABLayoutHopper (741-742), makeGemmSparseFragmentC (138-157), makeGemmSparseFragmentC (138-140), makeGemmSparseAmpereABLayout (683-688), makeGemmSparseAmpereABLayout (683-684)
tilelang/intrinsics/tcgen05_macro_generator.py (1)
- src/transform/lower_tile_op.cc (8): access_ptr (287-385), access_ptr (288-290), buffer (272-280), buffer (272-272), buffer (401-418), buffer (401-401), buffer (420-437), buffer (420-420)
maint/gemm_v2/correctness_evaluation.py (2)
- maint/gemm_v2/correctness_evaluation_sm70.py (1): run_gemm_rs (180-210)
- testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py (1): run_gemm_rs (191-242)
src/op/gemm_py.cc (3)
- src/op/gemm.cc (10): NormalizeToBufferRegion (53-100), NormalizeToBufferRegion (53-54), allowTcgen5Mma (192-199), allowTcgen5Mma (192-192), allowWgmma (201-209), allowWgmma (201-201), checkWgmma (440-491), checkWgmma (440-440), getGemmInst (211-226), getGemmInst (211-211)
- tilelang/ir.py (1): GemmWarpPolicy (30-39)
- src/op/operator.cc (2): GetVarFromAccessPtr (74-81), GetVarFromAccessPtr (74-74)
tilelang/intrinsics/mma_sm70_macro_generator.py (3)
- tilelang/utils/language.py (2): is_fragment (105-116), to_buffer_region (187-211)
- tilelang/intrinsics/mfma_macro_generator.py (4): _warp_ldmatrix_a (277-300), ldmatrix_b (304-357), ldmatrix_b (794-869), _warp_ldmatrix_b (327-355)
- tilelang/intrinsics/mma_macro_generator.py (8): _warp_ldmatrix_a (244-281), _warp_ldmatrix_a (800-888), ldmatrix_b (285-368), ldmatrix_b (892-1004), mma_load_layout (223-224), mma_load_layout (310-311), _warp_ldmatrix_b (322-366), _warp_ldmatrix_b (907-1002)
src/op/gemm.cc (4)
- tilelang/language/utils.py (1): region (8-27)
- src/op/gemm_py.cc (13): strides (85-85), NormalizeToBufferRegion (25-72), NormalizeToBufferRegion (25-26), allowTcgen5Mma (192-199), allowTcgen5Mma (192-192), allowWgmma (201-209), allowWgmma (201-201), checkWgmma (261-312), checkWgmma (261-261), getGemmInst (211-229), getGemmInst (211-211), MakeAccessPtrFromRegion (78-108), MakeAccessPtrFromRegion (78-79)
- src/op/gemm.h (1): Gemm (144-149)
- src/op/operator.cc (2): GetVarFromAccessPtr (74-81), GetVarFromAccessPtr (74-74)
tilelang/utils/language.py (4)
- src/transform/lower_tile_op.cc (8): buffer (272-280), buffer (272-272), buffer (401-418), buffer (401-401), buffer (420-437), buffer (420-420), access_ptr (287-385), access_ptr (288-290)
- src/transform/flatten_buffer.cc (6): buffer (295-321), buffer (295-296), region (335-361), region (335-335), buf (227-247), buf (227-227)
- src/transform/storage_rewrite.cc (9): e (862-904), e (862-862), i (720-860), buf (249-256), buf (249-249), buf (509-526), buf (509-509), buf (1740-1765), buf (1740-1740)
- tilelang/language/utils.py (1): region (8-27)
tilelang/intrinsics/wgmma_macro_generator.py (2)
- tilelang/utils/language.py (3): is_fragment (105-116), retrive_ptr_from_buffer_region (255-282), is_full_region (370-399)
- tilelang/language/builtin.py (2): initialize_wgmma_descriptor (703-730), warpgroup_fence_operand (433-592)
tilelang/intrinsics/mfma_macro_generator.py (2)
- tilelang/intrinsics/mma_sm70_macro_generator.py (3): ldmatrix_a (190-232), _warp_ldmatrix_a (216-230), ldmatrix_b (234-284)
- tilelang/intrinsics/mma_macro_generator.py (6): ldmatrix_a (207-283), ldmatrix_a (786-890), _warp_ldmatrix_a (244-281), _warp_ldmatrix_a (800-888), ldmatrix_b (285-368), ldmatrix_b (892-1004)
🪛 Ruff (0.14.3)
tilelang/utils/__init__.py
13-13: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
14-14: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
15-15: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
16-16: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
17-17: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
tilelang/layout/swizzle.py
28-29: Avoid specifying long messages outside the exception class
(TRY003)
94-94: PEP 484 prohibits implicit Optional
Convert to Optional[T]
(RUF013)
111-111: PEP 484 prohibits implicit Optional
Convert to Optional[T]
(RUF013)
tilelang/language/builtin.py
466-466: Avoid specifying long messages outside the exception class
(TRY003)
482-482: Avoid specifying long messages outside the exception class
(TRY003)
514-514: Avoid specifying long messages outside the exception class
(TRY003)
543-545: Avoid specifying long messages outside the exception class
(TRY003)
574-574: Do not catch blind exception: Exception
(BLE001)
577-579: Avoid specifying long messages outside the exception class
(TRY003)
tilelang/intrinsics/tcgen05_macro_generator.py
264-264: Prefer TypeError exception for invalid type
(TRY004)
264-264: Avoid specifying long messages outside the exception class
(TRY003)
276-276: Prefer TypeError exception for invalid type
(TRY004)
276-276: Avoid specifying long messages outside the exception class
(TRY003)
tilelang/utils/language.py
27-28: Avoid specifying long messages outside the exception class
(TRY003)
211-211: Avoid specifying long messages outside the exception class
(TRY003)
229-229: Avoid specifying long messages outside the exception class
(TRY003)
231-231: Avoid specifying long messages outside the exception class
(TRY003)
245-245: Prefer TypeError exception for invalid type
(TRY004)
245-245: Avoid specifying long messages outside the exception class
(TRY003)
270-270: Prefer TypeError exception for invalid type
(TRY004)
270-270: Avoid specifying long messages outside the exception class
(TRY003)
282-282: Prefer TypeError exception for invalid type
(TRY004)
282-282: Avoid specifying long messages outside the exception class
(TRY003)
329-329: Avoid specifying long messages outside the exception class
(TRY003)
349-349: Avoid specifying long messages outside the exception class
(TRY003)
384-384: Avoid specifying long messages outside the exception class
(TRY003)
tilelang/intrinsics/wgmma_macro_generator.py
274-274: Unpacked variable tx is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
386-386: Unpacked variable tx is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
386-386: Unpacked variable warp_m is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
tilelang/intrinsics/mfma_macro_generator.py
297-297: Ambiguous variable name: l
(E741)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: Test for Python 3.12 with CUDA-12.8 (on self-hosted-nvidia)
- GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
- GitHub Check: Test for Python 3.12 with ROCm-6.3 (on self-hosted-amd)
🔇 Additional comments (1)
tilelang/layout/swizzle.py (1)
65-76: FFI should receive PrimExpr for dims, not Python ints. With the `_get_stride_continuous` fix, `stride` and `continuous` remain PrimExprs, which is the safer contract for TVM FFI. No further change needed in these call sites; just confirming the intent. If any FFI expects raw ints, please confirm and we can adapt by wrapping in `tir.IntImm` instead of calling `int()`.
Also applies to: 80-89, 93-107, 110-124, 126-147, 150-171, 174-195, 198-215
| import tilelang | ||
| import tilelang.language as T | ||
|
|
||
| tilelang.disable_cache() |
🛠️ Refactor suggestion | 🟠 Major
🧩 Analysis chain
Avoid top-level side effects; keep a single cache toggle in main().
Calling tilelang.disable_cache() at import time affects anyone importing this module. It’s also duplicated at Line 57. Remove the top-level call and keep a single, gated call in main().
Apply this minimal change:
- tilelang.disable_cache()
Optional verification (ensures a top-level wrapper exists and avoids surprises at runtime):
🏁 Script executed:
#!/bin/bash
# Verify a public wrapper exists and locate its definition(s)
rg -nP -C2 'def\s+disable_cache\s*\('
rg -nP -C2 'disable_cache\s*=\s*.*\.disable_cache'
Length of output: 195
🏁 Script executed:
#!/bin/bash
set -e
# Locate tilelang package __init__.py
find . -type f -path "./tilelang/__init__.py" -print
# Show first 200 lines to check for disable_cache wrapper
sed -n '1,200p' tilelang/__init__.py
Length of output: 3702
🏁 Script executed:
#!/bin/bash
# Display lines 1–80 of the example to confirm top-level call and context
sed -n '1,80p' examples/warp_specialize/example_warp_specialize_gemm_barrierpipe_stage2.py
Length of output: 2815
Remove top-level disable_cache call
Delete the import-time tilelang.disable_cache() (line 4) and rely solely on the call inside main() to avoid unexpected side effects on import.
Optional: Gate the print(jit_kernel.get_kernel_source()) behind a flag to prevent unneeded verbose output.
🤖 Prompt for AI Agents
In examples/warp_specialize/example_warp_specialize_gemm_barrierpipe_stage2.py
around line 4, remove the top-level tilelang.disable_cache() call so importing
the module has no side effects and rely on the existing disable_cache()
invocation inside main(); additionally, optionally wrap or gate the
print(jit_kernel.get_kernel_source()) behind a verbosity or debug flag so kernel
source is only printed when explicitly requested.
indice = buffer_load.indices[len(buffer_load.indices) - i - 1]
if isinstance(indice, (tvm.tir.IntImm, tvm.tir.PrimExpr)):
    offset += indice * stride
elif isinstance(indice, tvm.tir.Ramp):
    offset += indice.base * stride
else:
    raise ValueError(f"Unsupported index type: {type(indice)}")
stride *= shape
return buffer.access_ptr(access_type, offset=offset)
elif isinstance(buffer_or_load_or_region, BufferRegion):
    buffer_region = buffer_or_load_or_region
    buffer = buffer_region.buffer
    offset, stride = 0, 1
    for i, shape in enumerate(reversed(buffer.shape)):
        offset += buffer_region.region[len(buffer_region.region) - i - 1].min * stride
        stride *= shape
    return buffer.access_ptr(access_type, offset=offset)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Handle Ramp indices before the generic branch.
Ramp inherits from PrimExpr, so the first isinstance(..., PrimExpr) branch always fires. That feeds a vector expression into access_ptr, which is invalid for pointer offsets and breaks vectorized tiles. Please special-case Ramp before the generic branch and validate its stride.
- if isinstance(indice, (tvm.tir.IntImm, tvm.tir.PrimExpr)):
- offset += indice * stride
- elif isinstance(indice, tvm.tir.Ramp):
+ if isinstance(indice, tvm.tir.Ramp):
+ if not isinstance(indice.stride, tvm.tir.IntImm) or indice.stride.value != 1:
+ raise ValueError(f"Unsupported ramp stride: {indice.stride}")
+ offset += indice.base * stride
+ elif isinstance(indice, (tvm.tir.IntImm, tvm.tir.PrimExpr)):
offset += indice * stride
else:
raise ValueError(f"Unsupported index type: {type(indice)}")
📝 Committable suggestion
Suggested change:
indice = buffer_load.indices[len(buffer_load.indices) - i - 1]
if isinstance(indice, tvm.tir.Ramp):
    if not isinstance(indice.stride, tvm.tir.IntImm) or indice.stride.value != 1:
        raise ValueError(f"Unsupported ramp stride: {indice.stride}")
    offset += indice.base * stride
elif isinstance(indice, (tvm.tir.IntImm, tvm.tir.PrimExpr)):
    offset += indice * stride
else:
    raise ValueError(f"Unsupported index type: {type(indice)}")
stride *= shape
return buffer.access_ptr(access_type, offset=offset)
elif isinstance(buffer_or_load_or_region, BufferRegion):
    buffer_region = buffer_or_load_or_region
    buffer = buffer_region.buffer
    offset, stride = 0, 1
    for i, shape in enumerate(reversed(buffer.shape)):
        offset += buffer_region.region[len(buffer_region.region) - i - 1].min * stride
        stride *= shape
    return buffer.access_ptr(access_type, offset=offset)
🧰 Tools
🪛 Ruff (0.14.3)
264-264: Prefer TypeError exception for invalid type
(TRY004)
264-264: Avoid specifying long messages outside the exception class
(TRY003)
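The subclass-shadowing pitfall behind this fix can be reproduced without TVM. The classes below are hypothetical stand-ins for `tvm.tir.PrimExpr` and `tvm.tir.Ramp`; the point is only the `isinstance` ordering:

```python
# Minimal stand-ins: Ramp subclasses PrimExpr, so an isinstance check
# against the base class also matches every Ramp instance.
class PrimExpr:
    pass

class Ramp(PrimExpr):
    def __init__(self, base, stride):
        self.base = base
        self.stride = stride

def classify_wrong(index):
    # Generic branch first: a Ramp is silently treated as a scalar expression.
    if isinstance(index, PrimExpr):
        return "scalar"
    elif isinstance(index, Ramp):   # unreachable
        return "ramp"
    raise ValueError(f"Unsupported index type: {type(index)}")

def classify_fixed(index):
    # Most-specific subclass first, as the review suggests, with stride validation.
    if isinstance(index, Ramp):
        if index.stride != 1:
            raise ValueError(f"Unsupported ramp stride: {index.stride}")
        return "ramp"
    elif isinstance(index, PrimExpr):
        return "scalar"
    raise ValueError(f"Unsupported index type: {type(index)}")

r = Ramp(base=16, stride=1)
print(classify_wrong(r))  # -> scalar: the Ramp branch never fires
print(classify_fixed(r))  # -> ramp
```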
a_swizzle_mode = self._determinate_swizzle_mode(A_region, self.a_shared_layout)
b_swizzle_mode = self._determinate_swizzle_mode(B_region, self.b_shared_layout)
Pass actual Buffer into _determinate_swizzle_mode
_determinate_swizzle_mode expects a Buffer. Feeding it A_region / B_region (BufferRegion) blows up as soon as the layout helpers touch .shape. This means the refactored wgmma path now crashes before issuing any instruction. Please pass region.buffer instead.
- a_swizzle_mode = self._determinate_swizzle_mode(A_region, self.a_shared_layout)
- b_swizzle_mode = self._determinate_swizzle_mode(B_region, self.b_shared_layout)
+ a_swizzle_mode = self._determinate_swizzle_mode(A_region.buffer, self.a_shared_layout)
+ b_swizzle_mode = self._determinate_swizzle_mode(B_region.buffer, self.b_shared_layout)
🤖 Prompt for AI Agents
In tilelang/intrinsics/wgmma_macro_generator.py around lines 191-193, the code
calls _determinate_swizzle_mode(A_region, ...) and
_determinate_swizzle_mode(B_region, ...) passing BufferRegion objects;
_determinate_swizzle_mode expects a Buffer and later accesses .shape which
causes a crash. Change those calls to pass the underlying Buffer (e.g.,
A_region.buffer and B_region.buffer) so the helper receives the correct type and
the layout helpers can access .shape.
b_swizzle_mode = self._determinate_swizzle_mode(B_region, self.b_shared_layout)
b_swizzle_atom_elems = n_dim if b_swizzle_mode.is_none(
Same BufferRegion-to-Buffer fix needed here
The RS path is subject to the identical crash: _determinate_swizzle_mode still sees a BufferRegion. Please use the underlying buffer.
- b_swizzle_mode = self._determinate_swizzle_mode(B_region, self.b_shared_layout)
+ b_swizzle_mode = self._determinate_swizzle_mode(B_region.buffer, self.b_shared_layout)
🤖 Prompt for AI Agents
In tilelang/intrinsics/wgmma_macro_generator.py around lines 342 to 343, the
call to self._determinate_swizzle_mode is being passed a BufferRegion (B_region)
which causes the same crash as earlier; change the call to pass the underlying
buffer object from the region (e.g., B_region.buffer or equivalent attribute
used elsewhere) so _determinate_swizzle_mode receives a Buffer, and ensure
related downstream uses (like computing b_swizzle_atom_elems) still reference
the region where appropriate.
if dtype is None:
    raise ValueError("dtype must be provided when passing a pointer expression.")
inferred = None
# Case 1: Pointer from Buffer.access_ptr -> tir.builtin.tvm_access_ptr
if isinstance(data_ptr, Call) and data_ptr.op.same_as(tir.builtin.tvm_access_ptr()):
    # args[0] is a type annotation call; its dtype carries the element dtype
    inferred = str(data_ptr.args[0].dtype)
# Case 2: Pointer from tir.address_of(BufferLoad(...))
elif isinstance(data_ptr, Call) and data_ptr.op.same_as(tir.builtin.address_of()):
    # args[0] should be a BufferLoad; its dtype is the element dtype
    inferred = str(data_ptr.args[0].dtype)
# Case 3: Typed pointer Var with PrimType element (typed TIR)
elif hasattr(data_ptr, "type_annotation") and data_ptr.type_annotation is not None:
    try:
        elem_ty = getattr(data_ptr.type_annotation, "element_type", None)
        if elem_ty is not None and hasattr(elem_ty, "dtype"):
            inferred = str(elem_ty.dtype)
    except Exception:
        inferred = None
if inferred is None:
    raise ValueError(
        "dtype must be provided when passing a pointer expression and cannot be inferred."
    )
dtype = inferred
if num_regs is None:
Fix dtype inference for access_ptr pointers
When buffer_or_ptr comes from buffer.access_ptr(...), data_ptr.op is tvm_access_ptr, but data_ptr.args[0].dtype is just "handle"/"object", not the underlying element dtype. We end up inferring "handle", which breaks any call that relies on automatic dtype inference (e.g. passing a pointer without an explicit dtype now regresses). Please extract the actual element dtype from the type-annotation payload instead of the wrapper call. For example:
- if isinstance(data_ptr, Call) and data_ptr.op.same_as(tir.builtin.tvm_access_ptr()):
- inferred = str(data_ptr.args[0].dtype)
+ if isinstance(data_ptr, Call) and data_ptr.op.same_as(tir.builtin.tvm_access_ptr()):
+ type_anno = data_ptr.args[0]
+ if isinstance(type_anno, Call) and type_anno.op.same_as(tir.builtin.type_annotation()):
+ inferred = str(type_anno.args[0])
+ elif hasattr(type_anno, "value"):
+        inferred = str(type_anno.value)

(Adapt as needed for other concrete node types.)
Committable suggestion skipped: line range outside the PR's diff.
🧰 Tools
🪛 Ruff (0.14.3)
574-574: Do not catch blind exception: Exception
(BLE001)
577-579: Avoid specifying long messages outside the exception class
(TRY003)
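The suggested fix unwraps the annotation payload instead of reading the wrapper's own dtype. A toy model of that nesting (the node classes here are hypothetical stand-ins, not real `tvm.tir` types):

```python
# Toy node model: a call node whose own dtype is the opaque "handle",
# while the element dtype lives one level down in its first argument.
class TypeAnnotation:
    def __init__(self, elem_dtype):
        self.args = [elem_dtype]
        self.dtype = "handle"

class AccessPtrCall:
    def __init__(self, elem_dtype):
        self.args = [TypeAnnotation(elem_dtype)]
        self.dtype = "handle"

ptr = AccessPtrCall("float16")

# Buggy inference: reads the wrapper's dtype and gets "handle".
buggy = str(ptr.args[0].dtype)

# Fixed inference: unwrap the annotation and read its payload.
fixed = str(ptr.args[0].args[0])

print(buggy, fixed)  # handle float16
```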
def _get_stride_continuous(
        buffer_or_load_or_region: Buffer | BufferLoad | BufferRegion) -> tuple[int, int]:
    """
    Get stride (last 2nd dimension) and continuous (last dimension) from Buffer, BufferLoad, or BufferRegion.

    Args:
        buffer_or_load_or_region: Can be Buffer, BufferLoad, or BufferRegion

    Returns:
        tuple: (stride, continuous) as integers
    """
    _, shape, _ = _get_buffer_info(buffer_or_load_or_region)
    stride = int(shape[-2])
    continuous = int(shape[-1])
    return stride, continuous
Do not cast PrimExpr to int; handle rank>=2 and keep symbolic dims
Casting shape[-2]/shape[-1] to int breaks when shapes are symbolic; also no guard for 1D buffers. Return PrimExprs and validate rank.
-def _get_stride_continuous(
- buffer_or_load_or_region: Buffer | BufferLoad | BufferRegion) -> tuple[int, int]:
+def _get_stride_continuous(
+ buffer_or_load_or_region: Buffer | BufferLoad | BufferRegion) -> tuple[PrimExpr, PrimExpr]:
@@
- _, shape, _ = _get_buffer_info(buffer_or_load_or_region)
- stride = int(shape[-2])
- continuous = int(shape[-1])
- return stride, continuous
+ _, shape, _ = _get_buffer_info(buffer_or_load_or_region)
+ if len(shape) < 2:
+ raise ValueError("Buffer rank must be >= 2 for swizzled layouts")
+ stride = shape[-2]
+ continuous = shape[-1]
+    return stride, continuous
🤖 Prompt for AI Agents
In tilelang/layout/swizzle.py around lines 32 to 47, the function currently
casts shape[-2] and shape[-1] to int which breaks for symbolic PrimExpr dims and
doesn't guard against 1D buffers; change the function to (1) validate that the
buffer rank is >= 2 and raise a clear ValueError if not, (2) stop casting to int
and return the raw PrimExprs for stride and continuous (update the return type
to tuple[PrimExpr, PrimExpr] and import PrimExpr if needed), and (3) keep the
rest of the logic unchanged so callers receive symbolic dimensions when present.
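Why `int(shape[-2])` is fragile can be shown with a stand-in symbolic dimension. The `Sym` class below is hypothetical; the assumption is that, like a `tvm.tir.Var`, it has no integer conversion, so `int()` raises:

```python
class Sym:
    # Stand-in for a symbolic extent (like a tvm.tir.Var): no __int__,
    # so int(Sym(...)) raises TypeError.
    def __init__(self, name):
        self.name = name

def stride_continuous_buggy(shape):
    # The pre-fix behavior: force both extents to int.
    return int(shape[-2]), int(shape[-1])

def stride_continuous_fixed(shape):
    # Validate rank and return the raw extents, symbolic or not.
    if len(shape) < 2:
        raise ValueError("Buffer rank must be >= 2 for swizzled layouts")
    return shape[-2], shape[-1]

static_shape = (128, 64)
dynamic_shape = (Sym("m"), 64)

print(stride_continuous_fixed(static_shape))   # (128, 64)
s, c = stride_continuous_fixed(dynamic_shape)  # keeps the Sym intact
print(s.name, c)                               # m 64

try:
    stride_continuous_buggy(dynamic_shape)
except TypeError:
    print("int() cast fails on symbolic dims")
```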
A_region = self.ARegion
B_region = self.BRegion
C_region = self.CRegion

A_buf = A_region.buffer
C_buf = C_region.buffer
Call GemmBase region accessors
self.ARegion, self.BRegion, and self.CRegion are methods on GemmBase. Capturing them without parentheses leaves bound method objects, so the very next line (A_region.buffer) raises `'function' object has no attribute 'buffer'` at runtime. Invoke the accessors before using the regions.
Apply this diff:
- A_region = self.ARegion
- B_region = self.BRegion
- C_region = self.CRegion
+ A_region = self.ARegion()
+ B_region = self.BRegion()
+ C_region = self.CRegion()
🤖 Prompt for AI Agents
In tilelang/tileop/gemm/gemm_mma_sm70.py around lines 78 to 84, the code assigns
self.ARegion, self.BRegion, and self.CRegion (bound method objects) to variables
without calling them, so accessing A_region.buffer raises "'function' object has
no attribute \"buffer\""; fix by invoking the accessors (call self.ARegion(),
self.BRegion(), self.CRegion()) when assigning A_region, B_region, and C_region
so the variables hold region objects before using their .buffer attributes.
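The failure mode flagged for both SM70 and the generic MMA path is plain Python, independent of TileLang. A minimal sketch with hypothetical stand-in classes:

```python
class Region:
    def __init__(self, name):
        self.buffer = name

class GemmBase:
    def ARegion(self):
        # Accessor method, as in the real GemmBase.
        return Region("A_shared")

g = GemmBase()

a = g.ARegion                 # bound method object, not a Region
print(hasattr(a, "buffer"))   # False
try:
    _ = a.buffer              # raises AttributeError
except AttributeError:
    print("bound method has no .buffer")

a = g.ARegion()               # call the accessor first
print(a.buffer)               # A_shared
```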
# We use region for memory input to support strided gemm
# T.gemm(A_shared[0:128, :], B_shared, C_local)
A_region = self.ARegion
B_region = self.BRegion
C_region = self.CRegion

A_buf = A_region.buffer
B_buf = B_region.buffer
C_buf = C_region.buffer
Invoke region getters before dereferencing
Just like on SM70, GemmBase.ARegion/BRegion/CRegion are plain methods. Capturing them without calling leaves bound method objects, so A_region.buffer, B_region.buffer, etc. fail with `'function' object has no attribute 'buffer'`. This breaks every lowering path.
Please update to:
- A_region = self.ARegion
- B_region = self.BRegion
- C_region = self.CRegion
+ A_region = self.ARegion()
+ B_region = self.BRegion()
+ C_region = self.CRegion()
🤖 Prompt for AI Agents
In tilelang/tileop/gemm/gemm_mma.py around lines 86 to 95, the code captures
ARegion/BRegion/CRegion methods without invoking them so
A_region/B_region/C_region are bound function objects and accessing .buffer
fails; call the getters (e.g., A_region = self.ARegion()) for A, B, and C before
accessing .buffer (and then use A_region.buffer, B_region.buffer,
C_region.buffer) so the actual region objects are used.
def retrieve_stride(obj: Buffer | BufferRegion | BufferLoad) -> list:
    """
    Retrieve row-major strides for a buffer-like object based on its buffer.shape.
    For BufferRegion and BufferLoad, uses the underlying buffer's `shape`.
    """
    if isinstance(obj, tir.Buffer):
        shape = obj.shape
    elif isinstance(obj, (tir.BufferRegion, tir.BufferLoad)):
        shape = obj.buffer.shape
    else:
        raise ValueError(f"Unsupported retrieve_stride argument type: {type(obj)} for object {obj}")

    strides = []
    stride = 1
    for s in reversed(shape):
        strides.insert(0, stride)
        stride *= s
    return strides
Honor explicit Buffer.strides when available; fall back to row-major
Current implementation assumes row-major strides. Use buffer.strides if defined; otherwise compute from shape.
-def retrieve_stride(obj: Buffer | BufferRegion | BufferLoad) -> list:
+def retrieve_stride(obj: Buffer | BufferRegion | BufferLoad) -> list:
@@
- if isinstance(obj, tir.Buffer):
- shape = obj.shape
- elif isinstance(obj, (tir.BufferRegion, tir.BufferLoad)):
- shape = obj.buffer.shape
+ if isinstance(obj, tir.Buffer):
+ buf = obj
+ elif isinstance(obj, (tir.BufferRegion, tir.BufferLoad)):
+ buf = obj.buffer
else:
raise ValueError(f"Unsupported retrieve_stride argument type: {type(obj)} for object {obj}")
- strides = []
- stride = 1
- for s in reversed(shape):
- strides.insert(0, stride)
- stride *= s
- return strides
+ # Prefer explicit strides if provided
+ if buf.strides and len(buf.strides) == len(buf.shape):
+ return list(buf.strides)
+ # Fallback: derive row-major strides from shape
+ strides: list[PrimExpr] = []
+ stride: PrimExpr = tir.IntImm("int32", 1)
+ for s in reversed(buf.shape):
+ strides.insert(0, stride)
+ stride = stride * s
+    return strides
🧰 Tools
🪛 Ruff (0.14.3)
245-245: Prefer TypeError exception for invalid type
(TRY004)
245-245: Avoid specifying long messages outside the exception class
(TRY003)
🤖 Prompt for AI Agents
In tilelang/utils/language.py around lines 234 to 253, the function always
computes row-major strides from shape; change it to use an explicit
buffer.strides when present and only compute row-major strides as a fallback.
Concretely: for tir.Buffer use buffer.strides if buffer.strides is not
None/empty and convert it to a plain list of ints to return; for
tir.BufferRegion and tir.BufferLoad read obj.buffer.strides first and fall back
to computing from obj.buffer.shape if strides are missing; keep the existing
reversed-shape computation logic as the fallback and raise the same ValueError
for unsupported types.
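The fallback computation in the suggestion is the standard row-major stride recurrence. Here it is in isolation, with the explicit-strides preference modeled as a plain argument (a sketch; the real code reads `buf.strides`):

```python
def row_major_strides(shape):
    # Innermost dimension has stride 1; each outer stride is the product
    # of all inner extents.
    strides = []
    stride = 1
    for s in reversed(shape):
        strides.insert(0, stride)
        stride *= s
    return strides

def retrieve_stride(shape, explicit_strides=None):
    # Prefer explicit strides when present and rank-complete, mirroring the
    # suggested buf.strides check; otherwise fall back to row-major.
    if explicit_strides and len(explicit_strides) == len(shape):
        return list(explicit_strides)
    return row_major_strides(shape)

print(retrieve_stride((2, 3, 4)))              # [12, 4, 1]
print(retrieve_stride((2, 3, 4), (24, 8, 2)))  # [24, 8, 2] (padded layout)
```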
def retrive_ptr_from_buffer_region(buffer_or_load_or_region: Buffer | BufferLoad | BufferRegion,
                                   access_type: str = "r") -> PrimExpr:
    if isinstance(buffer_or_load_or_region, Buffer):
        return buffer_or_load_or_region.access_ptr(access_type)
    elif isinstance(buffer_or_load_or_region, BufferLoad):
        buffer_load = buffer_or_load_or_region
        offset, stride = 0, 1
        buffer = buffer_load.buffer
        for i, shape in enumerate(reversed(buffer.shape)):
            indice = buffer_load.indices[len(buffer_load.indices) - i - 1]
            if isinstance(indice, (tir.IntImm, tir.PrimExpr)):
                offset += indice * stride
            elif isinstance(indice, tir.Ramp):
                offset += indice.base * stride
            else:
                raise ValueError(f"Unsupported index type: {type(indice)}")
            stride *= shape
        return buffer.access_ptr(access_type, offset=offset)
    elif isinstance(buffer_or_load_or_region, BufferRegion):
        buffer_region = buffer_or_load_or_region
        buffer = buffer_region.buffer
        offset, stride = 0, 1
        for i, shape in enumerate(reversed(buffer.shape)):
            offset += buffer_region.region[len(buffer_region.region) - i - 1].min * stride
            stride *= shape
        return buffer.access_ptr(access_type, offset=offset)
    else:
        raise ValueError(f"Unsupported buffer type: {type(buffer_or_load_or_region)}")
🛠️ Refactor suggestion | 🟠 Major
🧩 Analysis chain
Fix function name typo and provide a compatibility alias
Rename retrive_ptr_from_buffer_region -> retrieve_ptr_from_buffer_region and keep an alias to avoid breaking callers.
-def retrive_ptr_from_buffer_region(buffer_or_load_or_region: Buffer | BufferLoad | BufferRegion,
- access_type: str = "r") -> PrimExpr:
+def retrieve_ptr_from_buffer_region(buffer_or_load_or_region: Buffer | BufferLoad | BufferRegion,
+ access_type: str = "r") -> PrimExpr:
@@
- else:
- raise ValueError(f"Unsupported buffer type: {type(buffer_or_load_or_region)}")
+ else:
+ raise TypeError(f"Unsupported buffer type: {type(buffer_or_load_or_region)}")
+
+# Backward-compat alias (to be deprecated)
+retrive_ptr_from_buffer_region = retrieve_ptr_from_buffer_region

Run to find potential external/internal references to the old name:
🏁 Script executed:
#!/bin/bash
rg -nP '\bretrive_ptr_from_buffer_region\b' -C2

Length of output: 2188
Fix typo in function name, error type, and add backward-compat alias
Rename retrive_ptr_from_buffer_region → retrieve_ptr_from_buffer_region, change the final exception to TypeError, and add a backward-compat alias:
-def retrive_ptr_from_buffer_region(buffer_or_load_or_region: Buffer | BufferLoad | BufferRegion,
- access_type: str = "r") -> PrimExpr:
+def retrieve_ptr_from_buffer_region(buffer_or_load_or_region: Buffer | BufferLoad | BufferRegion,
+ access_type: str = "r") -> PrimExpr:
@@
- else:
- raise ValueError(f"Unsupported buffer type: {type(buffer_or_load_or_region)}")
+ else:
+ raise TypeError(f"Unsupported buffer type: {type(buffer_or_load_or_region)}")
+
+# Backward-compat alias
+retrive_ptr_from_buffer_region = retrieve_ptr_from_buffer_region

This preserves the existing __init__.py import and all calls in wgmma_macro_generator.py.
🧰 Tools
🪛 Ruff (0.14.3)
270-270: Prefer TypeError exception for invalid type
(TRY004)
270-270: Avoid specifying long messages outside the exception class
(TRY003)
282-282: Prefer TypeError exception for invalid type
(TRY004)
282-282: Avoid specifying long messages outside the exception class
(TRY003)
🤖 Prompt for AI Agents
In tilelang/utils/language.py around lines 255 to 283, rename the function
retrive_ptr_from_buffer_region to retrieve_ptr_from_buffer_region, change the
final raised exception from ValueError to TypeError for unsupported input types,
and add a backward-compatible alias so retrive_ptr_from_buffer_region points to
the new retrieve_ptr_from_buffer_region function (preserving existing
imports/calls).
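The rename-with-alias pattern the review asks for keeps old call sites working because a Python function is just a name binding. A minimal sketch (the stub body below is hypothetical, only the aliasing mechanics matter):

```python
def retrieve_ptr_from_buffer_region(obj, access_type="r"):
    # Corrected spelling; real body elided, this stub just echoes its inputs.
    return ("ptr", obj, access_type)

# Backward-compat alias: the old misspelled name stays importable and
# resolves to the very same function object.
retrive_ptr_from_buffer_region = retrieve_ptr_from_buffer_region

print(retrive_ptr_from_buffer_region is retrieve_ptr_from_buffer_region)  # True
print(retrive_ptr_from_buffer_region("buf"))  # ('ptr', 'buf', 'r')
```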
- Refactored `GemmNode` to streamline the decision-making process for GEMM instruction selection.
- Added support for FP64 inputs in the MMA dispatcher, enabling new tensor operations.
- Introduced a new layout function for FP64 in `mma_layout.py` to facilitate shared memory storage.
- Updated `TensorCoreIntrinEmitter` to handle FP64 data types, including adjustments for micro tile dimensions and loading mechanisms.
- Enhanced utility functions to accommodate FP64 index mapping for shared memory operations.
Actionable comments posted: 0
♻️ Duplicate comments (1)
src/op/gemm.cc (1)
121-125: Fix pointer offset when rebuilding access_ptr.

Identical to the issue flagged in gemm_py.cc: skipping the last two dimensions in the offset sum causes every reconstructed pointer to land at element (0, 0) for typical 2D regions, so all GEMM tiles operate on the wrong data. The loop must include the minima of every dimension when computing offset.

Apply this diff:

-  for (int i = 0; i < ndim - 2; ++i) {
+  for (int i = 0; i < ndim; ++i) {
     offset = offset + region->region[i]->min * strides[i];
   }
🧹 Nitpick comments (6)
tilelang/intrinsics/mma_macro_generator.py (3)
144-168: Consider adding assertions for FP64 path.

The FP64 micro-size initialization logic is correct, but unlike the non-FP64 path (lines 153-156), the FP64 path lacks assertions to validate warp_row_tiles and warp_col_tiles constraints. For FP64 (m8n8k4), you should verify:

- warp_row_tiles >= 8 and warp_row_tiles % 8 == 0
- warp_col_tiles >= 8 and warp_col_tiles % 8 == 0

Apply this diff to add validation:

     if k_dim == 4:
         # fp64 path: m_dim must be 8, n_dim 8
         assert m_dim == 8, f"For fp64 MMA, m_dim must be 8, got {m_dim}"
+        assert warp_row_tiles >= 8, f"warp_row_tiles must be >= 8 for fp64, got {warp_row_tiles}"
+        assert warp_row_tiles % 8 == 0, f"warp_row_tiles must be divisible by 8 for fp64, got {warp_row_tiles}"
+        assert warp_col_tiles >= 8, f"warp_col_tiles must be >= 8 for fp64, got {warp_col_tiles}"
+        assert warp_col_tiles % 8 == 0, f"warp_col_tiles must be divisible by 8 for fp64, got {warp_col_tiles}"
         self.n_dim = 8
         self.micro_size_y = 8
         self.warp_rows = warp_row_tiles // m_dim
         self.warp_cols = warp_col_tiles // 8
234-270: FP64 fast path implementation looks correct, but consider the unused parameter.

The FP64 fast path correctly implements direct per-lane loads since PTX ldmatrix doesn't support FP64. The index calculations (mi = tx // micro_size_k, mk = tx % micro_size_k) and handling of both transposed/non-transposed cases appear correct.

However, static analysis correctly identifies that the A_shared_buf parameter in the inner macro _warp_ld_a_fp64 (line 254) is unused. The macro only uses A_local_buf, ki, thread_binding, and rk, while the actual buffer access uses the captured A_buf and base offsets.

If the parameter is kept for API consistency with other macro variants, consider adding a comment explaining this. Otherwise, remove it from the signature.

Apply this diff to remove the unused parameter:

     @T.macro
     def _warp_ld_a_fp64(
         A_local_buf,
-        A_shared_buf,
         ki,
         thread_binding,
         rk=0,
     ):

And update the return statement:

-    return _warp_ld_a_fp64(A_local_buf, A_region, ki, thread_binding, rk)
+    return _warp_ld_a_fp64(A_local_buf, ki, thread_binding, rk)
351-387: FP64 fast path is correct, but has an unused parameter.

Similar to ldmatrix_a, the FP64 path for matrix B correctly implements direct per-lane loads. The index calculations and transpose handling are appropriate.

The same unused parameter issue exists: B_shared_buf in the inner macro _warp_ld_b_fp64 (line 371) is not used. Consider removing it or documenting why it's kept for API consistency.

Apply this diff:

     @T.macro
     def _warp_ld_b_fp64(
         B_local_buf,
-        B_shared_buf,
         ki,
         thread_binding,
         rk=0,
     ):

And update the return:

-    return _warp_ld_b_fp64(B_local_buf, B_region, ki, thread_binding, rk)
+    return _warp_ld_b_fp64(B_local_buf, ki, thread_binding, rk)

tilelang/intrinsics/mma_layout.py (1)
48-51: Logic is correct; consider adding documentation.

The math for the 8×8 FP64 layout mapping is correct:

- 32 threads with 2 local elements each map to a 64-element (8×8) shared memory tile
- Row assignment groups every 4 threads together (rows 0-7)
- Column assignment interleaves within thread groups (cols 0-7)

Consider adding a docstring to explain the layout strategy and expected parameter ranges (thread_id: 0-31, local_id: 0-1).

Optional: Add docstring for clarity.

 def mma_store_32x2_to_shared_8x8_layout_fp64(thread_id, local_id):
+    """
+    Map thread_id and local_id to shared memory layout for FP64 MMA store.
+
+    Maps 32 threads × 2 local elements to an 8×8 shared memory tile.
+
+    Args:
+        thread_id: Thread index within warp (0-31)
+        local_id: Local element index (0-1)
+
+    Returns:
+        tuple: (row, col) indices in the 8×8 shared memory layout
+    """
     row = thread_id // 4
     col = (thread_id % 4) * 2 + local_id
     return row, col

src/op/gemm.cc (2)
87-89: Minor: prefer `.at()` for consistency with gemm_py.cc.

Line 89 uses `vmap[var]` while the equivalent code in `gemm_py.cc:89` uses `vmap.at(var)`. Using `.at()` is more explicit about lookup semantics and provides clearer error messages when a key is missing.

Apply this diff:

```diff
-  Buffer buf = vmap[var];
+  Buffer buf = vmap.at(var);
```
159-164: Update stale documentation comment.

The comment at line 47 states that kPack "must be 1", but the code at line 161 accepts both 1 and 2. The comment should be updated to reflect the actual constraint.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (5)
- src/op/gemm.cc (9 hunks)
- src/tl_templates/cuda/instruction/mma.h (1 hunks)
- tilelang/intrinsics/mma_layout.py (1 hunks)
- tilelang/intrinsics/mma_macro_generator.py (15 hunks)
- tilelang/intrinsics/utils.py (2 hunks)
🧰 Additional context used
🧠 Learnings (2)
📚 Learning: 2025-09-15T10:51:06.985Z
Learnt from: botbw
Repo: tile-ai/tilelang PR: 691
File: src/tl_templates/cuda/gemm_sp_sm80.h:81-85
Timestamp: 2025-09-15T10:51:06.985Z
Learning: In CUTLASS tensor operation layouts, crosswise constants should be computed using sizeof(T) (bytes), not cutlass::sizeof_bits<T>::value (bits). This is the established pattern in the official CUTLASS codebase, as seen in default_mma_core_sparse_sm80.h.
Applied to files:
src/op/gemm.cc
📚 Learning: 2025-09-15T10:51:06.985Z
Learnt from: botbw
Repo: tile-ai/tilelang PR: 691
File: src/tl_templates/cuda/gemm_sp_sm80.h:81-85
Timestamp: 2025-09-15T10:51:06.985Z
Learning: In CUTLASS tensor operation layouts, crosswise constants should be computed using sizeof(T) (bytes), not cutlass::sizeof_bits<T>::value (bits). However, the layout template parameter should use sizeof_bits<T>::value (bits). This is the established pattern in the official CUTLASS codebase, as seen in default_mma_core_sparse_sm80.h where Crosswise uses sizeof(ElementA) but the layout template uses sizeof_bits<ElementA>::value.
Applied to files:
src/op/gemm.cc
🧬 Code graph analysis (3)
tilelang/intrinsics/utils.py (1)
tilelang/intrinsics/mma_layout.py (1)
- mma_store_32x2_to_shared_8x8_layout_fp64 (48-51)
tilelang/intrinsics/mma_macro_generator.py (3)
tilelang/intrinsics/utils.py (3)
- mma_store_index_map (82-83), get_ldmatrix_offset (22-64), mma_store_index_map_fp64 (86-87)
tilelang/utils/language.py (2)
- is_fragment (105-116), to_buffer_region (187-211)
tilelang/intrinsics/mma_sm70_macro_generator.py (4)
- get_thread_binding (139-145), extract_thread_binding (158-188), _warp_ldmatrix_a (216-230), ldmatrix_b (234-284)
src/op/gemm.cc (5)
tilelang/language/utils.py (1)
- region (8-27)
tilelang/language/tir/op.py (1)
- tvm_access_ptr (651-676)
src/op/gemm_py.cc (13)
- strides (85-85), NormalizeToBufferRegion (25-72), NormalizeToBufferRegion (25-26), allowTcgen5Mma (192-199), allowTcgen5Mma (192-192), allowWgmma (201-209), allowWgmma (201-201), getGemmInst (211-229), getGemmInst (211-211), MakeAccessPtrFromRegion (78-108), MakeAccessPtrFromRegion (78-79), GetArchInt (328-339), GetArchInt (328-328)
src/op/gemm.h (5)
- Gemm (144-149), GemmWarpPolicy (59-83), GemmWarpPolicy (64-68), GemmWarpPolicy (70-74), GemmWarpPolicy (76-82)
src/op/operator.cc (2)
- GetVarFromAccessPtr (74-81), GetVarFromAccessPtr (74-74)
🪛 Ruff (0.14.3)
tilelang/intrinsics/mma_macro_generator.py
254-254: Unused function argument: A_shared_buf
(ARG001)
371-371: Unused function argument: B_shared_buf
(ARG001)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
- GitHub Check: Test for Python 3.12 with ROCm-6.3 (on self-hosted-amd)
- GitHub Check: Test for Python 3.12 with CUDA-12.8 (on self-hosted-nvidia)
🔇 Additional comments (14)
src/tl_templates/cuda/instruction/mma.h (1)
139-141: LGTM! FP64 DMMA dispatcher implementation verified and safe.

The `cute::SM80_8x8x4_F64F64F64F64_TN` implementation is already in use throughout the codebase in `src/tl_templates/cuda/gemm_mma.h` (lines 55, 68, 80, 92, 100). The code addition follows the exact same established pattern with no new dependencies introduced. CuTe is managed as a git submodule and is a proven dependency.

tilelang/intrinsics/mma_macro_generator.py (8)
6-7: LGTM!

The import additions for `BufferRegion` and `to_buffer_region` are appropriate for the BufferRegion integration feature. These enable the methods to accept both `Buffer` and `BufferRegion` as inputs.

Also applies to: 13-13
44-44: LGTM!

The FP64 dtype abbreviation is correctly added and follows the established pattern for other data types.
83-87: LGTM!

The FP64 initialization correctly overrides `M_DIM` to 8 for the m8n8k4 MMA shape. The comment clearly explains the rationale.
126-128: LGTM!

The FP64 MMA prefix "m8n8k4" is correctly set for k_dim == 4, matching the FP64 matrix multiplication instruction shape.
186-191: LGTM!

The conditional selection of the FP64-specific store index map (`mma_store_index_map_fp64`) based on the accumulator dtype is correct. This accounts for the different memory layout of m8n8k4 FP64 MMA results.
296-343: LGTM!

The BufferRegion integration is well-implemented:
- Properly normalizes inputs using `to_buffer_region`
- Correctly extracts base offsets from the last two dimensions (`region[-2].min`, `region[-1].min`)
- Updates both ldmatrix and fallback paths to use region-based addressing
- Handles transposed and non-transposed layouts correctly

The multi-line buffer element assignments (lines 318-320, 336-341) improve readability for complex indexing expressions.
399-467: LGTM!

The BufferRegion integration for matrix B mirrors the approach in `ldmatrix_a` and is correctly implemented. Region-based addressing is properly applied to both the ldmatrix path (lines 440-442) and the fallback path (lines 460-465), with correct handling of transposed and non-transposed layouts.
753-754: LGTM!

The improved assertion message now includes the buffer name and actual scope, making debugging easier when a non-fragment buffer is incorrectly passed.
tilelang/intrinsics/utils.py (2)
11-11: LGTM: Import added correctly.

The import for the FP64 layout function is properly placed and formatted.
86-87: New FP64 wrapper function is unused; verify integration plan with mma_macro_generator.py.

The function is syntactically correct and properly delegates to `mma_store_32x2_to_shared_8x8_layout_fp64`, but it's not called anywhere in the codebase. The standard `mma_store_index_map` is actively used in `mma_macro_generator.py` (lines 558, 576), but there are no corresponding FP64 calls. Confirm whether `mma_macro_generator.py` should be updated to support FP64 variants, or if this wrapper is preparatory work not yet ready for integration.
192-224: LGTM: method renames and refactoring.

The method renames (`allowTcgen5Mma`, `allowWgmma`, `getGemmInst`, `checkWgmma`) follow consistent camelCase naming conventions, and the refactored implementation correctly uses the new region-based fields (`a_`, `b_`, `c_`, `transA_`, `transB_`, `m_`, `n_`, `k_`, etc.).
538-540: These pointers will have incorrect offsets until MakeAccessPtrFromRegion is fixed.

The calls to `MakeAccessPtrFromRegion` on these lines will produce incorrect offsets due to the bug flagged in the review comment for lines 121-125. Once that bug is fixed, these calls will work correctly.
683-904: LGTM: InferLayout refactoring.

The InferLayout method correctly uses all renamed fields (`c_`, `a_`, `b_`, `m_`, `n_`, `k_`, `transA_`, `transB_`, `kPack_`, etc.) consistently across all target architectures. The refactoring is complete and the logic appears sound.
…ling
- Reverted the GEMM function call in `correctness_evaluation.py` to the original implementation for consistency.
- Added a helper function in `merge_shared_memory_allocations.cc` to streamline the marking of shared variables under alignment scope.
- Enhanced the `VisitExpr_` methods to ensure proper handling of shared memory alignment for `BufferLoadNode` and `VarNode` types.
- Cleaned up commented-out test code in `correctness_evaluation.py` for better readability.
Actionable comments posted: 1
♻️ Duplicate comments (2)
tilelang/intrinsics/wgmma_macro_generator.py (2)
191-192: Fix BufferRegion-to-Buffer mismatch (still unfixed from previous review).
`_determinate_swizzle_mode` expects a `Buffer` and accesses `.shape` (line 152), but you're passing `A_region` and `B_region`, which are `BufferRegion` objects. This will raise an `AttributeError` at runtime. Apply this diff to pass the underlying buffers:

```diff
-        a_swizzle_mode = self._determinate_swizzle_mode(A_region, self.a_shared_layout)
-        b_swizzle_mode = self._determinate_swizzle_mode(B_region, self.b_shared_layout)
+        a_swizzle_mode = self._determinate_swizzle_mode(A_region.buffer, self.a_shared_layout)
+        b_swizzle_mode = self._determinate_swizzle_mode(B_region.buffer, self.b_shared_layout)
```

Based on past review comments.
344-344: Fix BufferRegion-to-Buffer mismatch in wgmma_rs (still unfixed from previous review).

Same issue as in `wgmma`: `_determinate_swizzle_mode` expects a `Buffer` but receives a `BufferRegion`. This will crash when the method accesses `.shape`. Apply this diff:

```diff
-        b_swizzle_mode = self._determinate_swizzle_mode(B_region, self.b_shared_layout)
+        b_swizzle_mode = self._determinate_swizzle_mode(B_region.buffer, self.b_shared_layout)
```

Based on past review comments.
🧹 Nitpick comments (4)
tilelang/intrinsics/wgmma_macro_generator.py (2)
274-274: Prefix unused variable with underscore.

The extracted `tx` variable is never used. Consider prefixing it with an underscore to indicate it's intentionally unused. Apply this diff:

```diff
-        tx, warp_n, warp_m = self.extract_thread_binding(thread_binding)
+        _tx, warp_n, warp_m = self.extract_thread_binding(thread_binding)
```
388-388: Prefix unused variables with underscores.

The extracted `tx` and `warp_m` variables are never used in the macro. Consider prefixing them with underscores to indicate they're intentionally unused. Apply this diff:

```diff
-        tx, warp_n, warp_m = self.extract_thread_binding(thread_binding)
+        _tx, warp_n, _warp_m = self.extract_thread_binding(thread_binding)
```

src/transform/merge_shared_memory_allocations.cc (1)
357-372: LGTM! Well-structured helper for alignment marking.

The helper correctly guards against null pointers and non-alignment scopes, properly extracts storage scope information, and applies appropriate alignment based on the target architecture.
Consider extracting the hardcoded alignment values as named constants for better maintainability:
```diff
+private:
+  static constexpr int kHopperAlignment = 1024;
+  static constexpr int kDefaultAlignment = 16;
+
   // Helper to record alignment for a shared/shared.dyn Var under alignment
   // scope
   void MarkSharedVarIfNeeded(const VarNode *op) {
     if (!op || !under_alignment_scope_)
       return;
     auto ptr_type = op->type_annotation.as<PointerTypeNode>();
     if (!ptr_type)
       return;
     auto scope = GetPtrStorageScope(tvm::ffi::GetRef<Var>(op));
     if (scope == "shared" || scope == "shared.dyn") {
       auto target = Target::Current();
       ICHECK(target.defined()) << "Target is not defined";
-      const int alignment = TargetIsHopper(target) ? 1024 : 16;
+      const int alignment =
+          TargetIsHopper(target) ? kHopperAlignment : kDefaultAlignment;
       shmem_alignment_map_[op] = alignment;
     }
   }
```

maint/gemm_v2/correctness_evaluation.py (1)
726-735: Consider removing commented test code.

The commented-out test blocks at the end of the file appear to be development or debugging artifacts. If they're no longer needed, consider removing them to keep the codebase clean. If they serve a purpose (e.g., manual testing scenarios), consider moving them to a separate debug script or adding a comment explaining why they're preserved.
📜 Review details
📒 Files selected for processing (3)
- maint/gemm_v2/correctness_evaluation.py (7 hunks)
- src/transform/merge_shared_memory_allocations.cc (2 hunks)
- tilelang/intrinsics/wgmma_macro_generator.py (10 hunks)
🧰 Additional context used
🧬 Code graph analysis (3)
src/transform/merge_shared_memory_allocations.cc (2)
tilelang/language/builtin.py (3)
- tma_load (86-95), initialize_wgmma_descriptor (703-730), initialize_tcgen05_descriptor (733-764)
src/transform/storage_access.cc (4)
- VisitExpr_ (40-62), VisitExpr_ (40-40), VisitExpr_ (314-453), VisitExpr_ (314-314)
tilelang/intrinsics/wgmma_macro_generator.py (2)
tilelang/utils/language.py (3)
- is_fragment (105-116), retrive_ptr_from_buffer_region (255-282), is_full_region (370-399)
tilelang/language/builtin.py (2)
- initialize_wgmma_descriptor (703-730), warpgroup_fence_operand (433-592)
maint/gemm_v2/correctness_evaluation.py (2)
maint/gemm_v2/correctness_evaluation_sm70.py (1)
- run_gemm_rs (180-210)
testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py (3)
- run_gemm_rs (191-242), run_gemm_sr (326-377), run_gemm_rr (465-516)
🪛 Ruff (0.14.3)
tilelang/intrinsics/wgmma_macro_generator.py
274-274: Unpacked variable tx is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
388-388: Unpacked variable tx is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
388-388: Unpacked variable warp_m is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: Test for Python 3.12 with ROCm-6.3 (on self-hosted-amd)
- GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
🔇 Additional comments (8)
tilelang/intrinsics/wgmma_macro_generator.py (3)
7-8: LGTM: Region-based imports added correctly.

The imports properly support the BufferRegion refactoring, including the helper function `retrive_ptr_from_buffer_region`, which is now imported from utils instead of being defined locally.
266-271: LGTM: Correct region-to-pointer/buffer extraction.

The pointer extraction and buffer setup properly utilize the imported helpers. The full-region assertion for the output fragment is appropriate.
380-385: LGTM: Correct region assertions and buffer extraction.

The full-region assertions for fragment inputs/outputs and the subsequent buffer/pointer extraction are properly implemented.
src/transform/merge_shared_memory_allocations.cc (3)
374-387: LGTM! Correct intrinsic updates for alignment scope detection.

The updated condition properly recognizes the normalized descriptor initialization intrinsics (`initialize_wgmma_descriptor` and `initialize_tcgen05_descriptor`) that require stricter shared memory alignment. This aligns with the refactoring to use more generic intrinsics across GEMM paths.
389-392: LGTM! Correct handling of direct variable references.

The visitor properly marks shared memory variables when they're directly referenced within an alignment scope before delegating to the base class.
394-402: LGTM! Proper handling of BufferLoad nodes.

The new visitor correctly handles the case where buffers are accessed via `BufferLoad` by marking the underlying data variable for alignment when within an alignment scope. The explanatory comment clearly describes the intent, and the implementation properly delegates to the base visitor.

maint/gemm_v2/correctness_evaluation.py (2)
49-49: LGTM! Kernel selection now respects environment variable.

The change from `T.gemm_v2` to `T.gemm` in the basic matmul variant allows the system to select between GEMM v1 and v2 based on the `TILELANG_USE_GEMM_V1` environment variable, which aligns with the PR objectives. The other matmul variants (rs, sr, rr) correctly continue to use `T.gemm_v2` explicitly.
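For reference, the selection mechanism described in the PR amounts to the following sketch (the names `use_gemm_v1`, `gemm_v1`, and `gemm_v2` mirror the PR description; the real `Environment` class in TileLang may parse the variable differently):

```python
import os

def use_gemm_v1() -> bool:
    # TILELANG_USE_GEMM_V1 defaults to off, so `gemm` resolves to the v2
    # implementation unless the user explicitly opts back into v1.
    return os.environ.get("TILELANG_USE_GEMM_V1", "0").lower() in ("1", "true", "on")

def gemm_v1(*args, **kwargs):  # stand-in for the v1 kernel entry point
    return "v1"

def gemm_v2(*args, **kwargs):  # stand-in for the v2 kernel entry point
    return "v2"

# Resolved once at import time; flipping the env var afterwards has no
# effect until the module is re-imported.
gemm = gemm_v1 if use_gemm_v1() else gemm_v2
```

One consequence of resolving the symbol at import time is that tests exercising both kernels must set the variable before importing the package, which is why explicit `T.gemm_v2` call sites remain useful.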
439-483: LGTM! Cleaner code by relying on defaults.

The wrapper functions have been appropriately simplified to rely on default parameter values. This makes the code more maintainable and ensures consistency with the updated defaults.
```python
    num_stages=2,
    num_threads=128,
):
    if block_N >= 256 or block_M >= 256 or block_K >= 256:
        num_stages = 0
```
🛠️ Refactor suggestion | 🟠 Major
🧩 Analysis chain
Document the num_stages override logic and verify consistency.
Three concerns:
-
Default change inconsistency: The
`num_stages` default is changed from 3 to 2 here, but the relevant code snippets show that similar files (`correctness_evaluation_sm70.py` and `test_tilelang_tilelibrary_gemm.py`) still use the default value of 3. Verify this divergence is intentional.
Undocumented conditional override: The logic that sets
num_stages=0when any block dimension is >= 256 completely disables pipelining for large blocks, but there's no comment explaining why. Add a comment describing the rationale (e.g., "Disable pipelining for large blocks to avoid resource constraints" or similar). -
Code duplication: This conditional logic is duplicated in
run_gemm,run_gemm_rs,run_gemm_sr, andrun_gemm_rr. Consider extracting it into a helper function.
🏁 Script executed:
#!/bin/bash
# Description: Check if other test files have inconsistent num_stages defaults
rg -n "num_stages\s*=\s*[0-9]" --type=py -g "*gemm*" -g "*correctness*" -C2Length of output: 31867
Add documentation for num_stages logic and verify if default change is intentional.
The script output confirms all three concerns:
-
Verified inconsistency:
correctness_evaluation.pyusesnum_stages=2(lines 105, 193, 279, 369), whilecorrectness_evaluation_sm70.pyandtest_tilelang_tilelibrary_gemm.pyusenum_stages=3. Confirm whether this divergence is intentional. -
Verified lack of documentation: The conditional
if block_N >= 256 or block_M >= 256 or block_K >= 256: num_stages = 0appears identically at lines 108–109, 196–197, 282–283, and 372–373 with no explanatory comments. Add inline documentation describing the rationale (e.g., resource constraints, memory limitations, or performance tuning). -
Verified code duplication: This conditional block is duplicated across
run_gemm,run_gemm_rs,run_gemm_sr, andrun_gemm_rr. Extract into a helper function to reduce duplication and improve maintainability.
🤖 Prompt for AI Agents
maint/gemm_v2/correctness_evaluation.py lines 105-109: The num_stages default
and repeated conditional are inconsistent and undocumented; confirm whether
default num_stages=2 (here and at lines 193, 279, 369) should match other files
that use 3, add an inline comment explaining why we force num_stages=0 when any
block dimension >=256 (e.g., register/shared-memory/resource limits or
performance reasons), and remove duplication by extracting the conditional into
a single helper (e.g., determine_num_stages(block_M, block_N, block_K,
default_num_stages)) used by run_gemm, run_gemm_rs, run_gemm_sr, and run_gemm_rr
so all call the helper and the rationale is documented in the helper's
docstring.
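A minimal version of the suggested helper (the name `determine_num_stages` comes from the agent prompt above, not from the codebase; the 256 threshold and default of 2 are taken from the file under review) could look like:

```python
def determine_num_stages(block_M: int, block_N: int, block_K: int,
                         default_num_stages: int = 2) -> int:
    """Pick the software-pipelining depth for a GEMM tile.

    Pipelining is disabled (0 stages) once any tile dimension reaches 256,
    where the extra shared-memory buffering for staged copies would exceed
    per-block resource limits (assumed rationale; the original code gives
    no comment).
    """
    if block_M >= 256 or block_N >= 256 or block_K >= 256:
        return 0
    return default_num_stages
```

Each `run_gemm*` wrapper would then call `num_stages = determine_num_stages(block_M, block_N, block_K, num_stages)` instead of repeating the conditional four times.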
- Updated GEMM and MMA classes to utilize BufferRegion for input and output buffers, improving memory management and supporting strided GEMM operations.
- Added checks to ensure full region compliance for input buffers, enhancing correctness in matrix multiplication.
- Implemented clear accumulation functionality to reset output buffers before accumulation, ensuring accurate results in GEMM operations.

…cture and function calls
- Updated import statements to directly reference modules instead of individual test functions, enhancing clarity.
- Modified function calls to use the new module structure for better organization and maintainability in testing examples.

- Updated the OnArrayDeclaration method to merge metadata for buffers that may appear in multiple Allocate statements, improving robustness against upstream transformations.
- Added logic to prefer concrete element data types and record extents when previously unknown, enhancing the handling of buffer declarations.
Actionable comments posted: 1
♻️ Duplicate comments (2)
tilelang/tileop/gemm/gemm_mma_sm70.py (1)
78-85: Call region accessor methods before dereferencing.
`self.ARegion`, `self.BRegion`, and `self.CRegion` are methods on `GemmBase`. Taking them without parentheses leaves bound method objects, so the very next line (`A_region.buffer`) raises `AttributeError: 'function' object has no attribute 'buffer'` at runtime. Apply this diff:

```diff
-        A_region = self.ARegion
-        B_region = self.BRegion
-        C_region = self.CRegion
+        A_region = self.ARegion()
+        B_region = self.BRegion()
+        C_region = self.CRegion()
```

tilelang/tileop/gemm/gemm_mma.py (1)
88-96: Invoke region getters before dereferencing.

Just like the other GEMM files, `GemmBase.ARegion`/`BRegion`/`CRegion` are methods. Capturing them without calling leaves bound method objects, so `A_region.buffer`, `B_region.buffer`, etc. will raise `AttributeError: 'function' object has no attribute 'buffer'` at runtime. Apply this diff:

```diff
-        A_region = self.ARegion
-        B_region = self.BRegion
-        C_region = self.CRegion
+        A_region = self.ARegion()
+        B_region = self.BRegion()
+        C_region = self.CRegion()
```
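The failure mode is plain Python attribute semantics rather than anything TileLang-specific; a toy class (stand-ins only, not the real `GemmBase`/`BufferRegion`) reproduces it:

```python
class Region:
    def __init__(self, buffer):
        self.buffer = buffer

class GemmBase:
    def ARegion(self):
        # The real method returns a tir.BufferRegion; a stub suffices here.
        return Region("A_buf")

g = GemmBase()

bound = g.ARegion            # a bound method object, NOT a Region
assert not hasattr(bound, "buffer")

region = g.ARegion()         # calling it yields the actual region
assert region.buffer == "A_buf"
```

Declaring the accessors as `@property` on `GemmBase` would make the parenthesis-free spelling valid and prevent this class of bug from recurring.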
📜 Review details
📒 Files selected for processing (6)
- examples/deepseek_v32/test_tilelang_example_deepseek_v32.py (1 hunks)
- src/transform/storage_rewrite.cc (1 hunks)
- tilelang/intrinsics/mfma_macro_generator.py (7 hunks)
- tilelang/tileop/gemm/gemm_mfma.py (6 hunks)
- tilelang/tileop/gemm/gemm_mma.py (6 hunks)
- tilelang/tileop/gemm/gemm_mma_sm70.py (4 hunks)
🧰 Additional context used
🧬 Code graph analysis (6)
src/transform/storage_rewrite.cc (1)
src/transform/merge_shared_memory_allocations.cc (2)
- buffer (532-551), buffer (532-532)
tilelang/tileop/gemm/gemm_mfma.py (4)
tilelang/utils/language.py (3)
- is_shared (45-60), is_fragment (105-116), is_full_region (370-399)
tilelang/tileop/gemm/gemm_base.py (7)
- ARegion (79-80), BRegion (83-84), CRegion (87-88), clear_accum (107-108), is_gemm_sr (24-25), is_gemm_rs (27-28), is_gemm_rr (30-31)
tilelang/language/fill.py (1)
- clear (50-74)
tilelang/intrinsics/mfma_macro_generator.py (5)
- ldmatrix_a (254-298), ldmatrix_a (711-784), ldmatrix_b (300-349), ldmatrix_b (786-861), mfma (351-394)
tilelang/intrinsics/mfma_macro_generator.py (3)
tilelang/utils/language.py (1)
- to_buffer_region (187-211)
tilelang/intrinsics/mma_macro_generator.py (6)
- ldmatrix_a (229-343), ldmatrix_a (887-991), _warp_ldmatrix_a (304-341), _warp_ldmatrix_a (901-989), ldmatrix_b (345-467), ldmatrix_b (993-1105)
tilelang/intrinsics/mma_sm70_macro_generator.py (3)
- ldmatrix_a (190-232), _warp_ldmatrix_a (216-230), ldmatrix_b (234-284)
tilelang/tileop/gemm/gemm_mma_sm70.py (4)
tilelang/utils/language.py (3)
- is_shared (45-60), is_fragment (105-116), is_full_region (370-399)
tilelang/tileop/gemm/gemm_base.py (5)
- ARegion (79-80), BRegion (83-84), CRegion (87-88), clear_accum (107-108), is_gemm_rs (27-28)
tilelang/language/fill.py (1)
- clear (50-74)
tilelang/intrinsics/mma_sm70_macro_generator.py (3)
- ldmatrix_a (190-232), ldmatrix_b (234-284), mma (286-327)
examples/deepseek_v32/test_tilelang_example_deepseek_v32.py (5)
examples/deepseek_v32/sparse_mla_fwd.py (2)
- sparse_mla_fwd (15-174), test_sparse_mla_fwd (253-299)
examples/deepseek_v32/sparse_mla_fwd_pipelined.py (2)
- sparse_mla_fwd (18-311), test_sparse_mla_fwd_pipelined (400-452)
examples/deepseek_v32/sparse_mla_bwd.py (2)
- sparse_mla_bwd (283-320), test_sparse_mla_bwd (334-384)
examples/deepseek_v32/topk_selector.py (1)
- test_topk_selector (188-245)
examples/deepseek_v32/fp8_lighting_indexer.py (1)
- test_fp8_lighting_indexer (260-302)
tilelang/tileop/gemm/gemm_mma.py (5)
tilelang/utils/language.py (3)
- is_shared (45-60), is_fragment (105-116), is_full_region (370-399)
tilelang/tileop/gemm/gemm_base.py (7)
- ARegion (79-80), BRegion (83-84), CRegion (87-88), clear_accum (107-108), is_gemm_sr (24-25), is_gemm_rs (27-28), is_gemm_rr (30-31)
tilelang/language/fill.py (1)
- clear (50-74)
tilelang/intrinsics/mma_macro_generator.py (8)
- ldmatrix_a (229-343), ldmatrix_a (887-991), ldmatrix_b (345-467), ldmatrix_b (993-1105), mma (469-529), mma (1107-1156), mma (1161-1259), mma (1264-1363)
tilelang/tileop/gemm/gemm_mfma.py (7)
- _gemm_ssr (107-133), is_gemm_sr (224-225), _gemm_srr (142-163), is_gemm_rs (227-228), _gemm_rsr (174-193), _gemm_rsr (203-212), is_gemm_rr (230-231)
🪛 Ruff (0.14.3)
tilelang/intrinsics/mfma_macro_generator.py
293-293: Ambiguous variable name: l
(E741)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
- GitHub Check: Test for Python 3.12 with CUDA-12.8 (on self-hosted-nvidia)
- GitHub Check: Test for Python 3.12 with ROCm-6.3 (on self-hosted-amd)
🔇 Additional comments (8)
examples/deepseek_v32/test_tilelang_example_deepseek_v32.py (2)
4-8: Clean refactor to module-level imports.

The change from direct function imports to module-level imports improves code clarity by making it explicit where each test function originates. The refactor is consistent across all five modules.
12-38: Function calls correctly updated to qualified access pattern.

All test function invocations have been consistently updated to use the module namespace (e.g., `topk_selector.test_topk_selector()`), aligning with the import style changes. The arguments passed to each function match the signatures verified in the codebase.

src/transform/storage_rewrite.cc (1)
1428-1451: LGTM! Merge logic supports multi-pass region-based buffer handling.

The merge logic correctly handles repeated buffer_var declarations by:
- Promoting handle dtypes to concrete types when available
- Updating extents if previously unknown
- Combining declaration locations via bitwise OR
This aligns well with the broader refactor enabling region-based buffer representations across GEMM paths.
tilelang/tileop/gemm/gemm_mfma.py (3)
102-102: Good defensive check for fragment output.

The assertion ensures C_region is a full region before proceeding, preventing subtle bugs with partial region outputs.
115-116: Accumulator clearing logic is correctly gated.

The conditional `T.clear(C_buf)` when `clear_accum` is true provides flexibility for callers to control accumulator initialization, and is consistently applied across all GEMM paths.
139-139: Fragment input validation is appropriate.

The assertion that fragment inputs must be full regions (lines 139, 171, 199-200) prevents misuse and ensures the MMA intrinsics receive complete operands.
tilelang/intrinsics/mfma_macro_generator.py (2)
254-270: Region-based buffer handling is correctly implemented.

The updated signature `A_shared_buf: Buffer | BufferRegion` and the normalization via `to_buffer_region()` correctly support both plain buffers and buffer regions, extracting base offsets for strided access. This maintains backward compatibility while enabling region-based GEMM operations.
287-296: Buffer indexing correctly applies region offsets.

The buffer accesses using `A_buf[A_base0 + l + row, A_base1 + r + col]` correctly incorporate the region's base offsets, enabling proper strided GEMM support.
```python
# We use region for memory input to support strided gemm
# T.gemm(A_shared[0:128, :], B_shared, C_local)
A_region = self.ARegion
B_region = self.BRegion
C_region = self.CRegion

A_buf = A_region.buffer
B_buf = B_region.buffer
C_buf = C_region.buffer
```
Call region accessor methods before dereferencing.
self.ARegion, self.BRegion, and self.CRegion are methods (per GemmBase), not properties. Assigning them without parentheses captures bound method objects, so the next lines accessing .buffer will raise AttributeError: 'function' object has no attribute 'buffer' at runtime.
Apply this diff:
```diff
-        A_region = self.ARegion
-        B_region = self.BRegion
-        C_region = self.CRegion
+        A_region = self.ARegion()
+        B_region = self.BRegion()
+        C_region = self.CRegion()
```
🤖 Prompt for AI Agents
In tilelang/tileop/gemm/gemm_mfma.py around lines 88 to 96, the code assigns
self.ARegion, self.BRegion, and self.CRegion (which are methods on GemmBase) to
variables and then tries to access .buffer, causing AttributeError; call the
accessor methods (e.g., self.ARegion(), self.BRegion(), self.CRegion()) before
reading .buffer so A_region = self.ARegion() etc., then use A_region.buffer,
B_region.buffer, C_region.buffer.
- Introduced a new abbreviation "bf16" for the bfloat16 data type in the mfma_macro_generator.py file, enhancing clarity and consistency in data type representation.
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
tilelang/intrinsics/mfma_macro_generator.py (1)
712-862: Fix critical parameter mismatch in `ldmatrix_b` super() call and add BufferRegion support to preshuffle paths.

Line 800 passes 6 arguments to the parent's `ldmatrix_b`, which only accepts 4 parameters. The parent class signature is `def ldmatrix_b(self, B_local_buf, B_shared_buf: Buffer | BufferRegion, ki, rk=0)`, but the child calls `super().ldmatrix_b(B_local_buf, B_buf, ki, rk, pid_m, pid_n)`. This causes a TypeError when `b_preshuffle=False` and `pid_m`/`pid_n` are passed.

Additionally, the preshuffle custom implementations (lines 803–859) directly access buffer parameters without normalizing via `to_buffer_region()`. If `BufferRegion` objects are passed to these paths, indexing will fail. Apply the same normalization pattern used in the parent class (lines 313–315) to ensure compatibility with both `Buffer` and `BufferRegion` inputs.
to_buffer_region(). IfBufferRegionobjects are passed to these paths, indexing will fail. Apply the same normalization pattern used in the parent class (lines 313–315) to ensure compatibility with bothBufferandBufferRegioninputs.
🧹 Nitpick comments (1)
tilelang/intrinsics/mfma_macro_generator.py (1)
288-297: Consider more descriptive variable names for `l` and `r`.

The buffer access logic correctly incorporates base offsets from the BufferRegion. However, the variable names `l` and `r` are ambiguous. Based on context, they appear to represent logical row and column indices. Consider renaming to more descriptive names like `logical_row` and `logical_col`, or at minimum `row_idx` and `col_idx`:

```diff
         row, col = T.meta_var(reverse_index_map(tx, local_id))
-        l, r = (rk * chunk + ki * (k_pack * micro_size_k),
-                warp_m * warp_row_tiles + i * micro_size_x)
-        A_local_buf[i * k_pack * local_size_a + local_id] = A_buf[A_base0 + l + row,
-                                                                  A_base1 + r + col]
+        logical_row, logical_col = (rk * chunk + ki * (k_pack * micro_size_k),
+                                    warp_m * warp_row_tiles + i * micro_size_x)
+        A_local_buf[i * k_pack * local_size_a + local_id] = A_buf[A_base0 + logical_row + row,
+                                                                  A_base1 + logical_col + col]
```
📜 Review details
📒 Files selected for processing (1)
- tilelang/intrinsics/mfma_macro_generator.py (8 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
tilelang/intrinsics/mfma_macro_generator.py (2)
tilelang/utils/language.py (1)
- to_buffer_region (187-211)
tilelang/intrinsics/mma_macro_generator.py (4)
- ldmatrix_a (229-343), ldmatrix_a (887-991), ldmatrix_b (345-467), ldmatrix_b (993-1105)
🪛 Ruff (0.14.3)
tilelang/intrinsics/mfma_macro_generator.py
294-294: Ambiguous variable name: l
(E741)
🔇 Additional comments (5)
tilelang/intrinsics/mfma_macro_generator.py (5)
5-5: LGTM! Necessary imports for BufferRegion support.

The added imports (`BufferRegion` and `to_buffer_region`) are essential for the new region-aware buffer handling introduced in this refactor. Also applies to: 12-12

142-142: LGTM! Completes bfloat16 support.

The addition of the `"bfloat16": "bf16"` mapping ensures consistency with the class-level `dtype_abbrv` dictionary and properly supports bfloat16 operations.

301-301: LGTM! Consistent BufferRegion normalization.

The signature change and normalization logic for `ldmatrix_b` mirror the implementation in `ldmatrix_a`, ensuring consistent handling of `Buffer | BufferRegion` inputs across both methods. Also applies to: 313-317

336-348: LGTM! Buffer access consistent with ldmatrix_a.

The buffer access patterns correctly add base offsets and are consistent with the `ldmatrix_a` implementation.

255-255: 2D buffer assumption verified and confirmed valid.

The type signature change to `Buffer | BufferRegion` and normalization logic are correctly implemented. Verification across 29 call sites confirms all `ldmatrix_a` invocations are GEMM/MFMA matrix operations, which inherently require 2D buffers. The region indexing via `region[-2].min` and `region[-1].min` safely accesses matrix dimensions. The `to_buffer_region()` conversion maintains backward compatibility while properly handling both input types.
…eneration

- Introduced a mapping function to normalize input data types to their corresponding scalar types, improving compatibility with MfmaTraits.
- Updated the mfma call generation to utilize the new mapping, streamlining the code and enhancing clarity.
- Removed outdated dtype mapping and replaced it with a more flexible approach to support additional data types like FP8.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 1
📒 Files selected for processing (1)
`src/target/codegen_hip.cc` (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
src/target/codegen_hip.cc (1)
src/target/codegen_hip.h (1)
`enable_fp8_` (79-94)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
- GitHub Check: Test for Python 3.12 with CUDA-12.8 (on self-hosted-nvidia)
- GitHub Check: Test for Python 3.12 with ROCm-6.3 (on self-hosted-amd)
src/target/codegen_hip.cc
Outdated

```cpp
std::unordered_map<std::string, std::string> acc_type_map = {
    {"float32", "float"},
    {"float64", "double"},
    {"float16x4", "float16x4"},
    {"bfloat16x4", "bfloat16x4"},
    {"float32x4", "float32x4"},
    {"float8_e4m3fnuzx4", "fp8_e4_4_t"},
    {"float8_e4m3fnuzx8", "long"},
    {"int32x4", "int32x4"},
    {"float32x16", "float32x16"}};
std::string call_mfma_code = R"({
  *((({C_dtype}*){c_ref}) + {c_bias}) = {mfma_buildin}(*((({A_dtype}*){a_ref}) + {a_bias}),
                *((({B_dtype}*){b_ref}) + {b_bias}),
                *((({C_dtype}*){c_ref}) + {c_bias}), 0, 0, 0);
})";
std::string mfma_buildin = "__builtin_amdgcn_mfma_" + prefix;
Replacer replacer;
```
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Include <unordered_map> to fix the compile failure
std::unordered_map is introduced here but the file does not include <unordered_map>, so this will not compile in a clean build. Please add the missing header near the other standard includes.
```diff
+#include <unordered_map>
```

Committable suggestion skipped: line range outside the PR's diff.
🤖 Prompt for AI Agents
In src/target/codegen_hip.cc around lines 949 to 955, the new use of
std::unordered_map requires including the <unordered_map> header; add #include
<unordered_map> alongside the other standard includes near the top of the file
so the code compiles in a clean build.
…dling in CodeGenTileLangHIP

- Introduced a macro to define backend options for CUDA, ROCM, and Metal, allowing user overrides and caching of settings.
- Updated logic to track user-selected backends and conditionally enable defaults based on environment variables.
- Refactored dtype handling in CodeGenTileLangHIP to streamline mfma call generation and improve clarity.
- Added support for bfloat16 in mfma_macro_generator.py, enhancing data type representation consistency.
Actionable comments posted: 1
🧹 Nitpick comments (2)
tilelang/intrinsics/mfma_macro_generator.py (1)
257-257: LGTM: Region-based buffer access properly implemented.

The changes correctly normalize the input to `BufferRegion` and use base offsets in buffer accesses. The pattern is consistent with similar changes in `mma_macro_generator.py` and maintains backward compatibility by accepting both `Buffer` and `BufferRegion`.

Optional: Consider more descriptive variable names. Static analysis flags `l` (line 296) as an ambiguous variable name. While the meaning is clear in context (left/right indices), consider using more descriptive names like `row_idx` and `col_idx` for improved readability. Also applies to: 269-274, 290-291, 298-299
CMakeLists.txt (1)
157-177: Verify mutual exclusivity of backend selection.The auto-selection logic (lines 157–177) may enable multiple backends simultaneously (e.g., both Metal and CUDA on macOS if the conditionals don't properly exclude each other). Check whether the build system supports or requires mutually exclusive backends. If mutual exclusivity is required, add validation logic to error when multiple backends are enabled.
📒 Files selected for processing (2)
`CMakeLists.txt` (4 hunks)
`tilelang/intrinsics/mfma_macro_generator.py` (9 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
tilelang/intrinsics/mfma_macro_generator.py (2)
tilelang/utils/language.py (1)
`to_buffer_region` (187-211)

tilelang/intrinsics/mma_macro_generator.py (4)

`ldmatrix_a` (229-343), `ldmatrix_a` (887-991), `_warp_ldmatrix_a` (304-341), `_warp_ldmatrix_a` (901-989)
🪛 Ruff (0.14.3)
tilelang/intrinsics/mfma_macro_generator.py
296-296: Ambiguous variable name: l
(E741)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (6)
- GitHub Check: Test for Python 3.12 with ROCm-6.3 (on self-hosted-amd)
- GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
- GitHub Check: Build wheels for Python 3.9 on ubuntu-latest with CUDA-12.1
- GitHub Check: Build wheels for Python 3.9 on ubuntu-24.04-arm with CUDA-12.8
- GitHub Check: Build SDist
- GitHub Check: Build wheels for Python 3.9 on macos-latest with Metal
🔇 Additional comments (5)
tilelang/intrinsics/mfma_macro_generator.py (3)
5-5: LGTM: Necessary imports for BufferRegion support.

The added imports enable region-based buffer handling, which is consistent with the broader refactoring effort across the codebase. Also applies to: 12-12

142-156: LGTM: bf16 support properly added.

The bfloat16 handling follows the established pattern for other data types and uses the appropriate MFMA suffix (`bf16_1k`) for the intrinsic.

303-303: LGTM: Consistent implementation across ldmatrix_b.

The changes to `ldmatrix_b` correctly mirror the pattern used in `ldmatrix_a`, ensuring consistent region-based buffer handling across both methods. Also applies to: 315-320, 338-339, 349-350

CMakeLists.txt (2)

78-106: No issues found with CMake syntax on line 88.

The syntax `if(DEFINED CACHE{${_backend_var}})` is valid CMake, introduced in CMake 3.14 to test cache entries specifically. This concern can be removed.

100-105: Type change to STRING is intentional and necessary for dual-purpose storage; no compatibility issues found in the visible codebase.

The change from BOOL to STRING at line 127 is intentional and required by design. Line 205 demonstrates the dual purpose: `cmake_path(GET CUDAToolkit_BIN_DIR PARENT_PATH USE_CUDA)` reassigns `USE_CUDA` to a filesystem path after the CUDA toolkit is discovered. The variable must hold both boolean values (ON/OFF) and paths, necessitating the STRING type.

All visible consumers are compatible:

- Conditional checks (`if(USE_CUDA)` at lines 179, 186, 198) work correctly with STRING values "ON"/"OFF"
- Function calls (`find_rocm(${USE_ROCM})` at line 189) pass the value without type assumptions
- Environment variable consumers (version_provider.py) are independent

The external TVM module (`FindROCM` at line 188) cannot be verified from the visible codebase, but the current usage pattern shows no issues.
…or.py

- Changed the representation of bfloat16 in CodeGenTileLangHIP from "bfloat16x4" to "bfloat16x4_vec" for improved clarity.
- Adjusted the mfma_suffix generation in mfma_macro_generator.py to remove the underscore before "bf16", aligning with HIP intrinsic requirements.
Actionable comments posted: 0
🧹 Nitpick comments (1)
tilelang/intrinsics/mfma_macro_generator.py (1)
291-292: LGTM: Buffer access correctly applies region-based offsets.

The buffer indexing properly applies `A_base0` and `A_base1` offsets derived from the BufferRegion, enabling correct addressing for both transposed and non-transposed cases.

Optional: Consider more descriptive variable names. The variable `l` (used here but defined in unchanged code) is flagged by static analysis as ambiguous. Consider renaming it to something more descriptive like `row_offset` or `spatial_idx` in a future refactor. Also applies to: 299-300
📒 Files selected for processing (2)
`src/target/codegen_hip.cc` (1 hunks)
`tilelang/intrinsics/mfma_macro_generator.py` (9 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
tilelang/intrinsics/mfma_macro_generator.py (3)
tilelang/utils/language.py (2)
`is_fragment` (105-116), `to_buffer_region` (187-211)

tilelang/intrinsics/mma_sm70_macro_generator.py (3)

`ldmatrix_a` (190-232), `_warp_ldmatrix_a` (216-230), `ldmatrix_b` (234-284)

tilelang/intrinsics/mma_macro_generator.py (6)

`ldmatrix_a` (229-343), `ldmatrix_a` (887-991), `_warp_ldmatrix_a` (304-341), `_warp_ldmatrix_a` (901-989), `ldmatrix_b` (345-467), `ldmatrix_b` (993-1105)
🪛 Ruff (0.14.3)
tilelang/intrinsics/mfma_macro_generator.py
297-297: Ambiguous variable name: l
(E741)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
- GitHub Check: Quick Lint
- GitHub Check: Build wheels for Python 3.9 on macos-latest with Metal
- GitHub Check: Build wheels for Python 3.9 on ubuntu-24.04-arm with CUDA-12.8
- GitHub Check: Build wheels for Python 3.9 on ubuntu-latest with CUDA-12.1
- GitHub Check: Build SDist
🔇 Additional comments (7)
src/target/codegen_hip.cc (1)
931-931: Verification confirms the change is correct.

The type `bfloat16x4_vec` is defined in `src/tl_templates/hip/common.h` as a GCC vector attribute type, and it is used in `src/tl_templates/hip/gemm.h` for MFMA operations. The naming is consistent: only `bfloat16x4_vec` exists in the codebase; analogous vector types like `float16x4_vec` and `float32x4_vec` are not defined. The struct `bfloat16x4` remains available for existing code that uses it directly, so this dtype_map change does not introduce breaking changes; it only affects the C++ type emitted in generated MFMA code.

tilelang/intrinsics/mfma_macro_generator.py (6)

5-5: LGTM: Import additions support BufferRegion functionality.

The new imports for `BufferRegion` and `to_buffer_region` are necessary for the region-based buffer addressing introduced in this file, consistent with similar changes across the codebase. Also applies to: 12-12

142-142: LGTM: bf16 abbreviation mapping added.

The addition of the `"bfloat16": "bf16"` mapping is consistent with other dtype abbreviations and enables bf16 support in MFMA intrinsics.

154-156: LGTM: bf16 suffix formatting matches HIP intrinsic conventions.

The special suffix format for bf16 (`bf16_1k` without an underscore separator) is correctly documented and aligns with HIP intrinsic naming requirements.

304-304: LGTM: Consistent BufferRegion normalization for B matrix.

The BufferRegion normalization in `ldmatrix_b` mirrors the implementation in `ldmatrix_a`, correctly extracting base offsets for region-based addressing. The same dimension assumptions apply (2D+ regions for matrix operations). Also applies to: 316-320

339-340: LGTM: Buffer access correctly applies region-based offsets for B matrix.

The buffer indexing in `ldmatrix_b` properly applies `B_base0` and `B_base1` offsets, mirroring the correct implementation pattern from `ldmatrix_a`. Also applies to: 350-351

258-258: Region dimension assumption verified; code is safe.

All callers of `ldmatrix_a` are GEMM kernels where the matrix A has at least 2 dimensions. The `to_buffer_region` function converts matrices to BufferRegions with full-region coverage, guaranteeing the region has at least 2 dimensions. The indexing pattern `region[-2].min` and `region[-1].min` is consistent and used throughout the codebase with no reported issues.
…r non-negative index checks to reduce log verbosity.
- Updated index handling in `example_gqa_sink_bwd_bhsd.py` and `example_mha_sink_bwd_bhsd.py` to eliminate unnecessary local allocations and streamline logic for determining start and end indices.
- Improved readability by using direct calculations instead of local variables for index bounds in pipelined loops.
Actionable comments posted: 0
♻️ Duplicate comments (1)
tilelang/intrinsics/wgmma_macro_generator.py (1)
191-193: Pass buffers to_determinate_swizzle_mode.
_determinate_swizzle_modeoperates onBuffer(it probes.shapevia layout helpers). Feeding itBufferRegionwill blow up as soon as the helper runs, so both the main WGMMA path and the RS fallback regress. Please unwrap the regions before calling the helper.- a_swizzle_mode = self._determinate_swizzle_mode(A_region, self.a_shared_layout) - b_swizzle_mode = self._determinate_swizzle_mode(B_region, self.b_shared_layout) + a_swizzle_mode = self._determinate_swizzle_mode(A_region.buffer, self.a_shared_layout) + b_swizzle_mode = self._determinate_swizzle_mode(B_region.buffer, self.b_shared_layout) ... - b_swizzle_mode = self._determinate_swizzle_mode(B_region, self.b_shared_layout) + b_swizzle_mode = self._determinate_swizzle_mode(B_region.buffer, self.b_shared_layout)Also applies to: 352-353
🧹 Nitpick comments (1)
examples/attention_sink/example_mha_sink_bwd_bhsd.py (1)
80-86: Consider applying the same refactor pattern for consistency.

The forward pass still uses the older pattern with `alloc_local` and conditional branches, while the backward pass (lines 270-273) has been refactored to use an inline ternary expression. For consistency and maintainability, consider refactoring this section similarly:

```diff
-        end = T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N))
-        start = T.alloc_local([1], 'int32')
-        if window_size is not None:
-            start[0] = T.max(0, (bx * block_M - window_size) // block_N)
-        else:
-            start[0] = 0
-
-        for k in T.Pipelined(start[0], end, num_stages=num_stages):
+        end = T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N))
+        start = T.max(0, (bx * block_M - window_size) // block_N) if window_size is not None else 0
+
+        for k in T.Pipelined(start, end, num_stages=num_stages):
```
📒 Files selected for processing (4)
`examples/attention_sink/example_gqa_sink_bwd_bhsd.py` (1 hunks)
`examples/attention_sink/example_mha_sink_bwd_bhsd.py` (1 hunks)
`src/transform/legalize_negative_index.cc` (1 hunks)
`tilelang/intrinsics/wgmma_macro_generator.py` (7 hunks)
🧰 Additional context used
🧬 Code graph analysis (3)
examples/attention_sink/example_mha_sink_bwd_bhsd.py (1)
tilelang/language/loop.py (1)
Pipelined(57-94)
tilelang/intrinsics/wgmma_macro_generator.py (2)
tilelang/utils/language.py (3)
`is_fragment` (105-116), `retrive_ptr_from_buffer_region` (255-282), `is_full_region` (370-399)

tilelang/language/builtin.py (2)

`initialize_wgmma_descriptor` (703-730), `warpgroup_fence_operand` (433-592)
examples/attention_sink/example_gqa_sink_bwd_bhsd.py (1)
tilelang/language/loop.py (1)
Pipelined(57-94)
🪛 Ruff (0.14.3)
tilelang/intrinsics/wgmma_macro_generator.py
274-274: Unpacked variable tx is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
396-396: Unpacked variable tx is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
396-396: Unpacked variable warp_m is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (7)
- GitHub Check: Test for Python 3.12 with ROCm-6.3 (on self-hosted-amd)
- GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
- GitHub Check: Test for Python 3.12 with CUDA-12.8 (on self-hosted-nvidia)
- GitHub Check: Build wheels for Python 3.9 on ubuntu-24.04-arm with CUDA-12.8
- GitHub Check: Build wheels for Python 3.9 on macos-latest with Metal
- GitHub Check: Build SDist
- GitHub Check: Build wheels for Python 3.9 on ubuntu-latest with CUDA-12.1
🔇 Additional comments (4)
src/transform/legalize_negative_index.cc (1)
53-55: Verify that suppressing this warning in release builds is intentional.

The change from `LOG(WARNING)` to `DLOG(WARNING)` means this diagnostic message will only appear in debug builds. If this warning helps identify potential indexing issues or aids in debugging production problems, consider keeping it as `LOG(WARNING)`. However, if this warning is noisy and not actionable in production, the debug-only logging is appropriate.

examples/attention_sink/example_gqa_sink_bwd_bhsd.py (2)

84-85: LGTM: Clean refactor to scalar expression.

The change from local array allocation to a direct scalar expression simplifies the code while preserving the logic. The use of `T.max(0, ...)` correctly guards against negative start positions when `window_size` is large relative to `bx * block_M`.

87-87: LGTM: Correctly updated to use scalar start.

The update to pass the scalar `start` directly (instead of dereferencing `start[0]`) is the necessary complement to the refactor on lines 84-85. This aligns with the `T.Pipelined` signature expectation of a `PrimExpr`.

examples/attention_sink/example_mha_sink_bwd_bhsd.py (1)

270-273: Good simplification of loop bounds.

The inline expression for `loop_ed` effectively simplifies the logic by removing the local allocation and using a direct scalar value instead of array indexing (`loop_ed` vs `loop_ed[0]`). The ternary expression correctly computes the loop bound for sliding window attention in the backward pass, and the change improves code readability.
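The start/end bound arithmetic discussed in these comments can be sanity-checked with plain integers. The sketch below mirrors the formulas from the examples, with `T.max`/`T.ceildiv`/`//` replaced by their Python equivalents; the concrete tile sizes are illustrative, not taken from the examples:

```python
# Plain-Python check of the sliding-window loop bounds from the attention-sink
# examples. block_M/block_N/seq_len values used below are illustrative only.
import math


def window_bounds(bx, block_M, block_N, seq_len, window_size=None):
    """Return (start, end) block indices for the pipelined K loop of row-block bx."""
    # end: never read past the sequence, never past the causal diagonal
    end = min(math.ceil(seq_len / block_N), math.ceil((bx + 1) * block_M / block_N))
    # start: first block inside the sliding window, clamped to 0 (or 0 if no window)
    start = max(0, (bx * block_M - window_size) // block_N) if window_size is not None else 0
    return start, end
```

For example, with `block_M=64`, `block_N=32`, `seq_len=512`, and `window_size=128`, row-block 0 scans blocks [0, 2) while row-block 4 scans [4, 10); without a window the same row-block scans [0, 10). The `max(0, ...)` clamp is what makes the expression safe when `bx * block_M < window_size`.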
- Simplified index handling in `example_gqa_sink_bwd_bhsd.py`, `example_gqa_sink_fwd_bhsd_wgmma_pipelined.py`, `example_mha_sink_bwd_bhsd.py`, `example_mha_sink_fwd_bhsd_wgmma_pipelined.py`, and `example_mha_sink_fwd_bhsd.py` by removing unnecessary local allocations for start and end indices.
- Enhanced readability by directly calculating index bounds for pipelined loops, improving overall code clarity.
Actionable comments posted: 1
♻️ Duplicate comments (2)
tilelang/intrinsics/wgmma_macro_generator.py (2)
191-192: Pass `region.buffer` to `_determinate_swizzle_mode`.

As flagged in previous reviews, `_determinate_swizzle_mode` expects a `Buffer` (line 150) but receives `BufferRegion` objects, causing a crash when accessing `.shape`.

Apply this diff:

```diff
-        a_swizzle_mode = self._determinate_swizzle_mode(A_region, self.a_shared_layout)
-        b_swizzle_mode = self._determinate_swizzle_mode(B_region, self.b_shared_layout)
+        a_swizzle_mode = self._determinate_swizzle_mode(A_region.buffer, self.a_shared_layout)
+        b_swizzle_mode = self._determinate_swizzle_mode(B_region.buffer, self.b_shared_layout)
```
346-346: Pass `B_region.buffer` to `_determinate_swizzle_mode` in wgmma_rs.

Identical to the issue at lines 191-192: `_determinate_swizzle_mode` expects a `Buffer` but receives a `BufferRegion`, which will crash when accessing `.shape`.

Apply this diff:

```diff
-        b_swizzle_mode = self._determinate_swizzle_mode(B_region, self.b_shared_layout)
+        b_swizzle_mode = self._determinate_swizzle_mode(B_region.buffer, self.b_shared_layout)
```
🧹 Nitpick comments (3)
examples/attention_sink/example_mha_sink_bwd_bhsd.py (1)
81-84: Consider using T.floordiv for consistency.

The inline expression for `start` simplifies the code well by removing the local allocation. However, lines 81-82 use Python's `//` operator while line 266 uses `T.floordiv` for a similar floor division operation. Although `//` should work through PrimExpr operator overloading, using `T.floordiv` would be more consistent with the codebase pattern.

Apply this diff for consistency:

```diff
-        start = T.max(0,
-                      (bx * block_M - window_size) // block_N) if window_size is not None else 0
+        start = T.max(0,
+                      T.floordiv(bx * block_M - window_size, block_N)) if window_size is not None else 0
```

tilelang/intrinsics/wgmma_macro_generator.py (2)
274-274: Consider prefixing unused variable with underscore.

Variable `tx` is unpacked but never used in the macro body. If it's intentionally unused, prefix it with `_` to indicate this.

Apply this diff:

```diff
-        tx, warp_n, warp_m = self.extract_thread_binding(thread_binding)
+        _tx, warp_n, warp_m = self.extract_thread_binding(thread_binding)
```

Based on static analysis.
390-390: Prefix unused variables with underscores.

Variables `tx` and `warp_m` are unpacked but never used in the `wgmma_rs` macro body. Prefix them with `_` to indicate they're intentionally unused.

Apply this diff:

```diff
-        tx, warp_n, warp_m = self.extract_thread_binding(thread_binding)
+        _tx, warp_n, _warp_m = self.extract_thread_binding(thread_binding)
```

Based on static analysis.
📒 Files selected for processing (7)
`examples/attention_sink/example_gqa_sink_bwd_bhsd.py` (2 hunks)
`examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py` (1 hunks)
`examples/attention_sink/example_mha_sink_bwd_bhsd.py` (2 hunks)
`examples/attention_sink/example_mha_sink_fwd_bhsd.py` (1 hunks)
`examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py` (1 hunks)
`examples/linear_attention/example_linear_attn_fwd.py` (1 hunks)
`tilelang/intrinsics/wgmma_macro_generator.py` (7 hunks)
🧰 Additional context used
🧬 Code graph analysis (6)
examples/attention_sink/example_mha_sink_fwd_bhsd.py (1)
tilelang/language/loop.py (1)
Pipelined(57-94)
examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py (1)
tilelang/language/loop.py (1)
Pipelined(57-94)
examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py (1)
tilelang/language/loop.py (1)
Pipelined(57-94)
tilelang/intrinsics/wgmma_macro_generator.py (2)
tilelang/utils/language.py (3)
`is_fragment` (105-116), `retrive_ptr_from_buffer_region` (255-282), `is_full_region` (370-399)

tilelang/language/builtin.py (2)

`initialize_wgmma_descriptor` (703-730), `warpgroup_fence_operand` (433-592)
examples/attention_sink/example_gqa_sink_bwd_bhsd.py (1)
tilelang/language/loop.py (1)
Pipelined(57-94)
examples/attention_sink/example_mha_sink_bwd_bhsd.py (1)
tilelang/language/loop.py (1)
Pipelined(57-94)
🪛 Ruff (0.14.3)
tilelang/intrinsics/wgmma_macro_generator.py
274-274: Unpacked variable tx is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
390-390: Unpacked variable tx is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
390-390: Unpacked variable warp_m is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
🔇 Additional comments (12)
examples/attention_sink/example_mha_sink_fwd_bhsd.py (2)
165-166: LGTM! Cleaner scalar expression for start boundary.

The direct scalar computation is more readable and efficient than the previous array-based approach. The conditional logic correctly handles both full attention (start=0) and sliding window cases (clamping to a non-negative block index).

168-168: LGTM! Correct usage of scalar start with Pipelined.

Passing the scalar `start` directly aligns with the `Pipelined` function signature, which expects `tir.PrimExpr` parameters. This eliminates the unnecessary array indirection.

examples/attention_sink/example_gqa_sink_bwd_bhsd.py (2)

84-87: LGTM! Clean refactor to scalar expression.

The refactor from per-iteration local allocation to a direct scalar computation simplifies the code and eliminates unnecessary overhead. The `T.max(0, ...)` correctly handles edge cases where the sliding window start might be negative, and the conditional expression properly adapts to the `window_size` parameter.

266-270: LGTM! Consistent backward pass refactor.

The refactor mirrors the forward pass changes, replacing the per-iteration local allocation with a direct scalar computation. The logic correctly computes the loop end boundary for both sliding window and full attention modes, and the change maintains consistency across the codebase.

examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py (1)

175-179: LGTM! Clean refactoring that simplifies loop bound computation.

The change replaces per-iteration local scalar tracking with a direct scalar expression, removing unnecessary allocation and indexing. The start computation correctly identifies the first block within the sliding window (or 0 for full attention), and the scalar value aligns with the `Pipelined` function signature, which expects a `PrimExpr` scalar.

examples/attention_sink/example_mha_sink_bwd_bhsd.py (1)

267-270: LGTM! Clean simplification of loop bounds.

The inline expression for `loop_ed` effectively simplifies the logic by removing the local allocation and using a direct scalar value instead of array indexing (`loop_ed` vs `loop_ed[0]`). The ternary expression correctly computes the loop bound for sliding window attention in the backward pass, and the change improves code readability.

tilelang/intrinsics/wgmma_macro_generator.py (4)

7-8: LGTM on import additions.

The BufferRegion import and utility function imports are correct and properly support the refactoring to region-based parameters.

163-171: Signature refactoring looks good.

The method now correctly accepts BufferRegion parameters, and the fragment dispatch logic at line 170 properly delegates to `wgmma_rs` when A is a fragment.

266-270: Proper region-to-pointer conversion and validation.

The pointer extraction using `retrive_ptr_from_buffer_region` and the full-region assertion for the output accumulator are correctly implemented.

382-386: Correct region assertions and pointer/buffer extraction.

The full-region assertions for fragment inputs and the extraction of buffers/pointers from regions are properly implemented for the RS (register-to-shared) path.

examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py (2)

168-169: LGTM! Simplified start computation is correct.

The scalar computation correctly determines the starting block for the sliding window. The floor division may occasionally include one extra block at the boundary that gets fully masked out by the explicit bounds checking in `MMA0`, but this conservative approach is safe and doesn't affect correctness. The change eliminates unnecessary local allocation overhead.

172-172: LGTM! Correct usage of scalar start.

Passing `start` directly (instead of `start[0]`) is correct now that it's computed as a scalar `PrimExpr`. This aligns with the `Pipelined` function signature and eliminates unnecessary indexing.
```python
def tl_fused_chunk_fwd(q, k, v):
    B, S, H, D = q.shape
    kernel = tl_fused_chunk_fwd_kernel(B, S, H, D, D)
    print(kernel.get_kernel_source())
```
Make kernel source printing conditional to avoid cluttering benchmark output.
The print statement will execute every time tl_fused_chunk_fwd is called, including during the do_bench iterations at line 141. This will flood the output with kernel source text and make the benchmark results unreadable.
Consider one of these alternatives:
Option 1: Make it conditional on an environment variable or flag
- print(kernel.get_kernel_source())
+ if os.environ.get('TILELANG_DEBUG_KERNEL_SOURCE'):
+ print(kernel.get_kernel_source())Option 2: Print once before benchmarking in main()
```diff
 def tl_fused_chunk_fwd(q, k, v):
     B, S, H, D = q.shape
     kernel = tl_fused_chunk_fwd_kernel(B, S, H, D, D)
-    print(kernel.get_kernel_source())
     o = torch.zeros((B, S, H, D), device='cuda', dtype=torch.float32)
     h = kernel(q, k, v, o)
     return o, h
```

Then add in main() before benchmarking:
```python
def main(B=1, S=512, H=16, D=128):
    # ... setup code ...
    # Print kernel source once for inspection
    kernel = tl_fused_chunk_fwd_kernel(B, S, H, D, D)
    print(kernel.get_kernel_source())
    o, h = tl_fused_chunk_fwd(q, k, v)
    # ... rest of main ...
```

🤖 Prompt for AI Agents
In examples/linear_attention/example_linear_attn_fwd.py around line 93, the
unconditional print(kernel.get_kernel_source()) inside tl_fused_chunk_fwd is
flooding benchmark output; remove that print from the hot path and instead
either (A) print the kernel source once before benchmarking in main() after
constructing the kernel, or (B) wrap the print in a conditional guard (e.g.,
check an environment variable or a debug flag) so it only emits when explicitly
enabled; update main() to print the kernel source once if you choose option A,
and ensure no prints occur inside the timed/iterative code paths.
Actionable comments posted: 0
♻️ Duplicate comments (2)
tilelang/intrinsics/wgmma_macro_generator.py (2)
191-192: Pass `.buffer` to `_determinate_swizzle_mode` to avoid AttributeError.

`_determinate_swizzle_mode` expects a `Buffer` (line 150) and accesses `.shape` and other Buffer attributes. Passing `BufferRegion` directly will crash. This issue was flagged in previous reviews.

Apply this diff:

```diff
-        a_swizzle_mode = self._determinate_swizzle_mode(A_region, self.a_shared_layout)
-        b_swizzle_mode = self._determinate_swizzle_mode(B_region, self.b_shared_layout)
+        a_swizzle_mode = self._determinate_swizzle_mode(A_region.buffer, self.a_shared_layout)
+        b_swizzle_mode = self._determinate_swizzle_mode(B_region.buffer, self.b_shared_layout)
```
346-346: Pass `.buffer` to `_determinate_swizzle_mode` in wgmma_rs.

Same issue as lines 191-192: `_determinate_swizzle_mode` expects a `Buffer`, not a `BufferRegion`. This was flagged in previous reviews.

Apply this diff:

```diff
-        b_swizzle_mode = self._determinate_swizzle_mode(B_region, self.b_shared_layout)
+        b_swizzle_mode = self._determinate_swizzle_mode(B_region.buffer, self.b_shared_layout)
```
🧹 Nitpick comments (2)
tilelang/intrinsics/wgmma_macro_generator.py (2)
274-274: Prefix unused unpacked variables with underscore.
`tx` is unpacked but never used. Consider prefixing it with an underscore to indicate it's intentionally unused: `_tx, warp_n, warp_m = ...`
390-390: Prefix unused unpacked variables with underscore.Both
txandwarp_mare unpacked but never used in this macro. Consider prefixing them with underscores:_tx, warp_n, _warp_m = ...
📒 Files selected for processing (1)
tilelang/intrinsics/wgmma_macro_generator.py (7 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
tilelang/intrinsics/wgmma_macro_generator.py (2)
tilelang/utils/language.py (3)
`is_fragment` (105-116), `retrive_ptr_from_buffer_region` (255-282), `is_full_region` (370-399)
tilelang/language/builtin.py (1)
`warpgroup_fence_operand` (433-592)
🪛 Ruff (0.14.3)
tilelang/intrinsics/wgmma_macro_generator.py
274-274: Unpacked variable tx is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
390-390: Unpacked variable tx is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
390-390: Unpacked variable warp_m is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (7)
- GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
- GitHub Check: Test for Python 3.12 with ROCm-6.3 (on self-hosted-amd)
- GitHub Check: Test for Python 3.12 with CUDA-12.8 (on self-hosted-nvidia)
- GitHub Check: Build wheels for Python 3.9 on macos-latest with Metal
- GitHub Check: Build wheels for Python 3.9 on ubuntu-latest with CUDA-12.1
- GitHub Check: Build wheels for Python 3.9 on ubuntu-24.04-arm with CUDA-12.8
- GitHub Check: Build SDist
🔇 Additional comments (6)
tilelang/intrinsics/wgmma_macro_generator.py (6)
7-8: LGTM: Imports support BufferRegion refactor.
The imports correctly add `BufferRegion` and the utility functions needed to work with region-based parameters. The `retrive_ptr_from_buffer_region` helper is now properly centralized in `tilelang.utils` (addressing previous review concerns).

164-171: LGTM: Signature refactored to BufferRegion.
The method signature correctly migrates to region-based parameters. The dispatch logic at lines 170-171 appropriately routes to `wgmma_rs` when A is a fragment.

266-270: LGTM: Region-to-pointer/buffer extraction is correct.
The code properly extracts pointers for shared memory operands (A, B) using `retrive_ptr_from_buffer_region` and ensures C is a full region before extracting its buffer for fragment operations.

273-315: LGTM: Macro signature correctly uses pointers for shared memory operands.
The inner `_warp_mma` macro signature at line 273 properly reflects the pointer-based access pattern for A and B (shared memory) while keeping C as a buffer (fragment). The descriptor initialization (lines 278-283), fence operations (lines 284, 313), and final invocation (line 315) are all consistent with this approach.

382-386: LGTM: wgmma_rs region extraction is correct.
The code properly distinguishes between fragment buffers (A, C) and shared memory pointers (B), with appropriate full-region assertions before buffer extraction.

389-438: LGTM: wgmma_rs macro correctly handles mixed buffer/pointer operands.
The macro signature (line 389) and implementation properly handle the asymmetric operand types: fragment buffers for A and C, shared memory pointer for B. The descriptor initialization (line 393), fences (lines 396-397, 435-436), and invocation (line 438) are all consistent.
- Removed outdated shared memory reduction logic from `reduce.cc`.
- Introduced fragment allocation and improved buffer handling in `reduce.py` to support shared and fragment scopes.
- Updated CUDA header to define a wider accumulator type for better numerical accuracy.
- Enhanced error handling for buffer scope validation in the reduction process.
- Updated the condition for identifying effectively empty annotations in unit loops to include cases where only the `pragma_unroll_explicit` hint is present.
- Introduced a new method, `IsEffectivelyEmptyAnnotation`, to encapsulate this logic, improving code clarity and maintainability.
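To make the annotation check concrete, here is a hypothetical Python paraphrase of the C++ `IsEffectivelyEmptyAnnotation` helper described above; the dict-based signature and snake_case name are assumptions for illustration, not the actual TVM/TileLang API.

```python
def is_effectively_empty_annotation(annotations: dict) -> bool:
    """Hypothetical paraphrase of the C++ helper described above.

    A unit-extent loop's annotations are "effectively empty" when they
    are truly empty, or contain only the pragma_unroll_explicit hint,
    so the loop can be simplified away.
    """
    if not annotations:
        return True
    return set(annotations) == {"pragma_unroll_explicit"}


print(is_effectively_empty_annotation({}))                             # True
print(is_effectively_empty_annotation({"pragma_unroll_explicit": 1}))  # True
print(is_effectively_empty_annotation({"pragma_import_c": "..."}))     # False
```

Any other annotation key keeps the loop alive, since eliding it could drop a hint a later pass depends on.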
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
src/op/reduce.cc (1)
363-365: Restore shared-memory lowering before landing
We used to lower `ReduceOp` when either side lived in `shared`/`shared.dyn`. After this edit the code hits `LOG(FATAL)` for those scopes, so every existing kernel that still performs a block-level reduction through shared memory now hard-crashes instead of compiling. Please reinstate the shared-memory lowering (or offer an equivalent replacement) before merging; otherwise we break the current API surface.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (5)
src/op/reduce.cc (1 hunks)
src/tl_templates/cuda/reduce.h (1 hunks)
src/transform/lower_opaque_block.cc (3 hunks)
tilelang/language/reduce.py (2 hunks)
tilelang/language/v2/builder.py (1 hunks)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-03T06:24:11.411Z
Learnt from: Rachmanino
Repo: tile-ai/tilelang PR: 1175
File: src/op/math.cc:44-52
Timestamp: 2025-11-03T06:24:11.411Z
Learning: In tilelang's `src/op/math.cc`, the `infinity_op` function uses `std::numeric_limits<float>::infinity()` as a placeholder for all float types (including float64 and bfloat16). The codegen layer (PrintConst:Inf) handles the correct infinity value based on the dtype field of the FloatImm node, so the specific C++ template argument doesn't matter.
Applied to files:
src/tl_templates/cuda/reduce.h
🧬 Code graph analysis (2)
tilelang/language/reduce.py (3)
tilelang/language/copy.py (1)
`copy` (11-87)
tilelang/language/allocate.py (1)
`alloc_fragment` (59-70)
tilelang/utils/language.py (2)
`is_shared` (45-60), `is_fragment` (105-116)
src/tl_templates/cuda/reduce.h (2)
src/tl_templates/cuda/common.h (5)
`tl` (195-289), `tl` (518-520), `tl` (521-523), `tl` (538-618), `bfloat16_t` (528-529)
src/tl_templates/hip/reduce.h (1)
`tl` (5-88)
🪛 Ruff (0.14.3)
tilelang/language/reduce.py
100-100: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (7)
- GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
- GitHub Check: Test for Python 3.12 with ROCm-6.3 (on self-hosted-amd)
- GitHub Check: Test for Python 3.12 with CUDA-12.8 (on self-hosted-nvidia)
- GitHub Check: Build wheels for Python 3.9 on ubuntu-24.04-arm with CUDA-12.8
- GitHub Check: Build wheels for Python 3.9 on macos-latest with Metal
- GitHub Check: Build wheels for Python 3.9 on ubuntu-latest with CUDA-12.1
- GitHub Check: Build SDist
🔇 Additional comments (6)
tilelang/language/v2/builder.py (1)
228-249: No issues found. The change allowing silent Buffer evaluation is correct and intentional.
The modification treats `tvm.tir.Buffer` objects like `None` and `str` values in the `eval()` method: they're valid but no-op cases. This aligns with Python semantics, where expression statements can contain just identifiers or other values without action. After `unwrap_expr()` converts `local.var` Buffers to `BufferLoad` nodes (line 229), any remaining Buffer objects are likely parameter buffers or results from other operations that shouldn't raise an error when evaluated as standalone expressions.
The change is consistent with the broader BufferRegion refactor evident in the codebase and poses no risk of suppressing legitimate errors, since malformed expressions would be caught earlier in the evaluation chain.
src/transform/lower_opaque_block.cc (2)
122-122: LGTM: Consistent handling of unit loops with unroll pragmas.
The updated logic correctly treats loops with extent 1 and only `pragma_unroll_explicit` as unit loops that can be simplified away. Both call sites (line 122 for recording the loop variable, lines 138-139 for elimination) use the same helper, ensuring consistent behavior.
Also applies to: 138-139

154-169: LGTM: Well-designed helper for annotation filtering.
The helper function is correctly implemented with appropriate const-correctness. The logic properly identifies annotations that are either truly empty or contain only the unroll pragma, allowing unit-length loops produced by unroll directives to be optimized away.
src/tl_templates/cuda/reduce.h (3)
1-250: No action needed. SharedReduceWarp removal is not a breaking change.
The search confirms that `SharedReduceWarp` has zero usages anywhere in the codebase. CUDA and HIP versions are separate, independently-loaded implementations for different hardware backends; they're not meant to be identical. The CUDA version has been modernized with the removal of unused code, which is correct and safe. The HIP version can maintain its own code independently.

4-5: The original review comment is incorrect: `<cstdint>` is actively used in the file.
The file uses `uint32_t` (from `<cstdint>`) at lines 72 and 92 in the expression `uint32_t(-1)`. However, `<type_traits>` does appear to be unused, based on the absence of any `std::` namespace usages in the file.
Likely an incorrect or invalid review comment.

9-19: The `AccType` trait is defined but not used in this file or elsewhere in the codebase.
While the trait's design is sound (promoting FP16/BF16 to float for accumulation improves numerical accuracy and follows standard CUDA practices), the trait struct remains unused. The reduction operators in this file do not reference it.
Either integrate `AccType` into the reduction implementations using `typename AccType<T>::type`, or remove it if it's planned for future work rather than current use.
tilelang/language/reduce.py (lines 43-89):

```python
    red_frag_in = alloc_fragment(buffer.shape, buffer.dtype)
    red_frag_out = alloc_fragment(out.shape, out.dtype)

    # rename buffers
    IRBuilder.name(buffer.name + "_frag", red_frag_in)
    IRBuilder.name(out.name + "_frag", red_frag_out)

    copy(buffer, red_frag_in)
    tir.call_intrin(
        "handle",
        tir.op.Op.get("tl.reduce"),
        red_frag_in.access_ptr("r"),
        red_frag_out.access_ptr("w"),
        reduce_type,
        dim,
        clear,
    )
    copy(red_frag_out, out)
elif is_shared(buffer) and is_fragment(out):
    red_frag_in = alloc_fragment(buffer.shape, buffer.dtype)
    IRBuilder.name(buffer.name + "_frag", red_frag_in)

    copy(buffer, red_frag_in)
    tir.call_intrin(
        "handle",
        tir.op.Op.get("tl.reduce"),
        red_frag_in.access_ptr("r"),
        out.access_ptr("w"),
        reduce_type,
        dim,
        clear,
    )
elif is_fragment(buffer) and is_shared(out):
    red_frag_out = alloc_fragment(out.shape, out.dtype)
    IRBuilder.name(out.name + "_frag", red_frag_out)

    tir.call_intrin(
        "handle",
        tir.op.Op.get("tl.reduce"),
        buffer.access_ptr("r"),
        red_frag_out.access_ptr("w"),
        reduce_type,
        dim,
        clear,
    )
    copy(red_frag_out, out)
elif is_fragment(buffer) and is_fragment(out):
```
Preserve accumulation semantics when `clear=False`.
In both shared-output branches we allocate `red_frag_out` and pass it to `tl.reduce` without seeding it with the existing `out` contents. For `clear=False`, the intrinsic expects the destination pointer to contain the prior accumulator; feeding an uninitialized fragment drops the existing values and corrupts results. Please copy the shared buffer into the fragment before the reduce when `clear` is false.
```diff
 if is_shared(buffer) and is_shared(out):
     red_frag_in = alloc_fragment(buffer.shape, buffer.dtype)
     red_frag_out = alloc_fragment(out.shape, out.dtype)
@@
-    tir.call_intrin(
+    if not clear:
+        copy(out, red_frag_out)
+    tir.call_intrin(
         "handle",
         tir.op.Op.get("tl.reduce"),
         red_frag_in.access_ptr("r"),
         red_frag_out.access_ptr("w"),
         reduce_type,
@@
     red_frag_out = alloc_fragment(out.shape, out.dtype)
     IRBuilder.name(out.name + "_frag", red_frag_out)
-    tir.call_intrin(
+    if not clear:
+        copy(out, red_frag_out)
+    tir.call_intrin(
         "handle",
         tir.op.Op.get("tl.reduce"),
         buffer.access_ptr("r"),
         red_frag_out.access_ptr("w"),
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
if is_shared(buffer) and is_shared(out):
    red_frag_in = alloc_fragment(buffer.shape, buffer.dtype)
    red_frag_out = alloc_fragment(out.shape, out.dtype)
    # rename buffers
    IRBuilder.name(buffer.name + "_frag", red_frag_in)
    IRBuilder.name(out.name + "_frag", red_frag_out)
    copy(buffer, red_frag_in)
    if not clear:
        copy(out, red_frag_out)
    tir.call_intrin(
        "handle",
        tir.op.Op.get("tl.reduce"),
        red_frag_in.access_ptr("r"),
        red_frag_out.access_ptr("w"),
        reduce_type,
        dim,
        clear,
    )
    copy(red_frag_out, out)
elif is_shared(buffer) and is_fragment(out):
    red_frag_in = alloc_fragment(buffer.shape, buffer.dtype)
    IRBuilder.name(buffer.name + "_frag", red_frag_in)
    copy(buffer, red_frag_in)
    tir.call_intrin(
        "handle",
        tir.op.Op.get("tl.reduce"),
        red_frag_in.access_ptr("r"),
        out.access_ptr("w"),
        reduce_type,
        dim,
        clear,
    )
elif is_fragment(buffer) and is_shared(out):
    red_frag_out = alloc_fragment(out.shape, out.dtype)
    IRBuilder.name(out.name + "_frag", red_frag_out)
    if not clear:
        copy(out, red_frag_out)
    tir.call_intrin(
        "handle",
        tir.op.Op.get("tl.reduce"),
        buffer.access_ptr("r"),
        red_frag_out.access_ptr("w"),
        reduce_type,
        dim,
        clear,
    )
    copy(red_frag_out, out)
elif is_fragment(buffer) and is_fragment(out):
```
🤖 Prompt for AI Agents
In tilelang/language/reduce.py around lines 43 to 89, the branches that allocate
a red_frag_out for shared output do not seed it with the existing shared buffer
when clear is False; change those branches so that immediately after allocating
red_frag_out (and naming it) you conditionally copy(out, red_frag_out) if not
clear, then call tir.call_intrin; this ensures the fragment contains the prior
accumulator before tl.reduce writes into it, preserving accumulation semantics.
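To illustrate why the seeding matters, here is a minimal pure-Python model of the accumulation contract: with `clear=False` the destination is expected to hold the prior accumulator, and new values are folded into it. The function name and list-based buffers are illustrative only, not the tilelang API.

```python
def reduce_sum(src, dst, clear):
    # Seed the accumulator from dst when not clearing; otherwise start at zero.
    acc = [0.0] * len(dst) if clear else list(dst)
    for row in src:  # column-wise sum over rows
        for j, v in enumerate(row):
            acc[j] += v
    dst[:] = acc


src = [[1.0, 2.0], [3.0, 4.0]]

dst = [10.0, 20.0]
reduce_sum(src, dst, clear=False)
print(dst)  # [14.0, 26.0]: prior accumulator preserved

dst = [10.0, 20.0]
reduce_sum(src, dst, clear=True)
print(dst)  # [4.0, 6.0]: accumulator reset before reducing
```

An unseeded staging fragment behaves like `clear=True` regardless of the flag, which is exactly the corruption the suggested `copy(out, red_frag_out)` guard avoids.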
Actionable comments posted: 1
♻️ Duplicate comments (1)
CMakeLists.txt (1)
164-165: ROCm path handling lost when auto-selecting backend.
As flagged in the previous review, line 165 converts any non-empty `$ENV{USE_ROCM}` to `ON`, discarding path information. Since the documentation (line 71) states that `USE_ROCM` supports "ON/OFF/or ROCm SDK path", this loses the ability to use a specific ROCm SDK path if provided.

```diff
 elseif($ENV{USE_ROCM})
-  set(USE_ROCM ON)
+  set(USE_ROCM $ENV{USE_ROCM})
```

This preserves the environment value, allowing `find_rocm()` to receive either `ON` or a specific path as intended.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
CMakeLists.txt (4 hunks)
examples/linear_attention/example_retention_fwd.py (0 hunks)
💤 Files with no reviewable changes (1)
- examples/linear_attention/example_retention_fwd.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (6)
- GitHub Check: Test for Python 3.12 with ROCm-6.3 (on self-hosted-amd)
- GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
- GitHub Check: Build wheels for Python 3.9 on macos-latest with Metal
- GitHub Check: Build wheels for Python 3.9 on ubuntu-latest with CUDA-12.1
- GitHub Check: Build SDist
- GitHub Check: Build wheels for Python 3.9 on ubuntu-24.04-arm with CUDA-12.8
🔇 Additional comments (2)
CMakeLists.txt (2)
122-129: Re-application of backend settings post-TVM configuration is sound.
The logic correctly caches TileLang's backend preferences before including TVM's config.cmake, then restores them afterward to prevent TVM from overriding user-selected backends. This approach is robust and preserves user intent.

147-155: User selection detection logic is clear and comprehensive.
The logic correctly identifies whether any backend was explicitly selected by the user (either via command-line `-DUSE_*` or environment variable at configuration time). This gate prevents unintended auto-selection when the user has explicitly chosen a backend.
```cmake
if(DEFINED CACHE{${_backend_var}})
  get_property(_cache_type CACHE ${_backend_var} PROPERTY TYPE)
  if(_cache_type STREQUAL "UNINITIALIZED")
    set(_user_override ON)
  endif()
endif()
```
Invalid CMake syntax: `DEFINED CACHE{...}` is not a recognized construct.
CMake does not support `DEFINED CACHE{...}` syntax. To check whether a cache entry exists and query its type, use `get_property()` instead. For example:

```cmake
if(DEFINED ${_backend_var})
  get_property(_cache_type CACHE ${_backend_var} PROPERTY TYPE)
  if(_cache_type STREQUAL "UNINITIALIZED")
    set(_user_override ON)
  endif()
endif()
```

This code will likely fail during CMake configuration.

```diff
-if(DEFINED CACHE{${_backend_var}})
+if(DEFINED ${_backend_var})
   get_property(_cache_type CACHE ${_backend_var} PROPERTY TYPE)
```

🤖 Prompt for AI Agents
get_property(_cache_type CACHE ${_backend_var} PROPERTY TYPE)🤖 Prompt for AI Agents
In CMakeLists.txt around lines 88 to 93, the conditional uses invalid CMake
syntax `DEFINED CACHE{...}`; replace it by first checking the variable name
existence and/or directly using get_property to query the CACHE entry type.
Specifically, call get_property(CACHE_VAR_TYPE CACHE ${_backend_var} PROPERTY
TYPE RESULT_VARIABLE) or use if(DEFINED ${_backend_var}) before calling
get_property, then test if the returned type equals "UNINITIALIZED" and
set(_user_override ON) accordingly so CMake config will not fail.

- Introduced `TILELANG_USE_GEMM_V1` environment variable to control the selection of the GEMM version.
- Added `use_gemm_v1` method in the `Environment` class to determine if GEMM v1 should be used based on the environment variable.

Summary by CodeRabbit

New Features
- GEMM kernel selection can be forced to v1 via the `TILELANG_USE_GEMM_V1` environment variable.

Bug Fixes

Refactor
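A minimal sketch of how an environment-driven toggle like this typically works. The accepted truthy spellings, the default fallback, and the placeholder kernels are assumptions for illustration, not tilelang's actual `Environment` implementation.

```python
import os

def use_gemm_v1() -> bool:
    # Assumed truthy spellings; the real Environment class may parse differently.
    return os.environ.get("TILELANG_USE_GEMM_V1", "0").lower() in ("1", "true", "on")

def gemm_v1(*args, **kwargs):  # placeholder kernel for illustration
    return "gemm_v1"

def gemm_v2(*args, **kwargs):  # placeholder kernel for illustration
    return "gemm_v2"

# Default to v2; allow v1 to be forced via the environment variable.
gemm = gemm_v1 if use_gemm_v1() else gemm_v2

os.environ["TILELANG_USE_GEMM_V1"] = "1"
print(use_gemm_v1())  # True
```

Note that the `gemm` binding is resolved once at import time, so the variable must be set before the module is loaded to take effect.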