[Refactor] Refactor Pass InjectFenceProxy and expose some warp group primitives in frontend
#977
- annotate proxy fence injector with context comments for async/generic detection
- add compiler internals doc covering the pass mechanics and link it in docs index
- repair fence proxy test by fixing descriptor init usage and fence counter logic
Caution: Review failed. The pull request is closed.
Walkthrough
Adds an InjectFenceProxy transform (including TMA store arrive/wait lowering and proxy-kind tracking), introduces warpgroup synchronization intrinsics and codegen, updates the WGMMA macro to use them, alters pipeline child handling, adds docs/tests, and adds debug logging around the pass.
Sequence Diagram(s)
sequenceDiagram
autonumber
actor Dev as Developer
participant TL as TileLang OptimizeForTarget
participant P1 as TMAStoreSyncInjector
participant P2 as ProxyFenceInjector
participant IR as IRModule
participant CG as CUDA CodeGen
participant CU as CUDA Device (cute)
Dev->>TL: Lower module
TL->>IR: Input IR
TL->>P1: Apply TMA store sync rewrite
P1-->>TL: IR' (tma_store_arrive/wait)
TL->>P2: Apply InjectFenceProxy
P2-->>IR: IR'' (fence.proxy.async on generic→async)
TL-->>Dev: Lowered IR''
Dev->>CG: Generate CUDA
CG->>CU: Emit tl::warpgroup_arrive / commit_batch / wait<N>()
CG-->>Dev: CUDA kernel
sequenceDiagram
autonumber
participant K as Kernel
participant TL as tl::intrinsics
participant CU as cute::arch (SM90)
K->>TL: warpgroup_arrive()
TL->>CU: cute::warpgroup_arrive()
rect rgba(230,240,255,0.5)
K->>K: MMA compute region (wgmma)
end
K->>TL: warpgroup_commit_batch()
TL->>CU: cute::warpgroup_commit_batch()
K->>TL: warpgroup_wait<N>()
TL->>CU: cute::warpgroup_wait<N>()
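The generic→async rule from the walkthrough can be illustrated with a small Python model. This is only a sketch under stated assumptions, not the pass itself: `inject_fences` and the `(name, kind)` tuples are hypothetical stand-ins for the C++ `ProxyFenceInjector` and its `ProxyKind` tracking, and the statement list is flat, whereas the real pass walks structured TIR.

```python
# Hypothetical model: each statement is a (name, proxy_kind) pair, where
# proxy_kind is one of "generic", "async", "neutral", or "unknown".
FENCE = ("fence_proxy_async", "neutral")


def inject_fences(stmts):
    """Insert a fence wherever a generic statement is directly followed by an async one."""
    out = []
    prev_kind = "unknown"
    for name, kind in stmts:
        if prev_kind == "generic" and kind == "async":
            out.append(FENCE)
        out.append((name, kind))
        # An existing fence is neutral, so it resets the tracked state and
        # prevents a second fence from being added after it.
        prev_kind = kind
    return out
```

For the sequence `initialize_descriptor`, shared-memory store, `ptx_wgmma_ss`, the model places one fence directly before the WGMMA call and leaves an already-fenced sequence unchanged.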
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~75 minutes
Actionable comments posted: 2
🧹 Nitpick comments (1)
docs/compiler_internals/inject_fence_proxy.md (1)
19-25: Specify a code-fence language for the timeline block.
Markdown lint (MD040) flags this fence because it lacks a language tag. Please annotate it (e.g., ```text) so the docs build stays clean.
-```
+```text
 generic initialize_descriptor → generic shared-store → async wgmma
        │                              │                     │
        └─ generic proxy               ┴─ generic proxy      ┴─ async proxy
                                       │  fence inserted here ↑
                                       └──────────────────────────────┘
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (12)
- docs/compiler_internals/inject_fence_proxy.md (1 hunks)
- docs/index.md (2 hunks)
- src/op/builtin.cc (1 hunks)
- src/op/builtin.h (1 hunks)
- src/target/codegen_cuda.cc (1 hunks)
- src/tl_templates/cuda/intrin.h (2 hunks)
- src/transform/inject_fence_proxy.cc (1 hunks)
- src/transform/inject_pipeline.cc (2 hunks)
- testing/python/transform/test_tilelang_transform_inject_fence_proxy.py (1 hunks)
- tilelang/engine/phase.py (1 hunks)
- tilelang/intrinsics/wgmma_macro_generator.py (2 hunks)
- tilelang/language/builtin.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (9)
tilelang/engine/phase.py (2)
- src/transform/inject_fence_proxy.cc (2): InjectFenceProxy (305-313), InjectFenceProxy (305-305)
- tilelang/transform/__init__.py (1): InjectFenceProxy (230-238)
tilelang/language/builtin.py (2)
- tilelang/language/tir/op.py (1): call_intrin (119-144)
- src/tl_templates/cuda/intrin.h (1): warpgroup_wait (13-15)
src/target/codegen_cuda.cc (2)
- tilelang/language/builtin.py (3): warpgroup_arrive (252-258), warpgroup_commit_batch (261-267), warpgroup_wait (270-280)
- src/tl_templates/cuda/intrin.h (1): warpgroup_wait (13-15)
src/op/builtin.h (2)
- tilelang/language/builtin.py (3): warpgroup_arrive (252-258), warpgroup_commit_batch (261-267), warpgroup_wait (270-280)
- src/tl_templates/cuda/intrin.h (1): warpgroup_wait (13-15)
testing/python/transform/test_tilelang_transform_inject_fence_proxy.py (3)
- tilelang/language/tir/op.py (3): ptx_cp_async (1244-1273), call_intrin (119-144), ptx_wgmma_ss (1064-1103)
- tilelang/language/builtin.py (3): fence_proxy_async (82-91), warpgroup_arrive (252-258), initialize_descriptor (391-422)
- src/transform/inject_fence_proxy.cc (24): InjectFenceProxy (305-313), InjectFenceProxy (305-305), stmt (138-138), stmt (138-138), stmt (279-281), stmt (279-279), stmt (283-292), stmt (283-283), op (140-154), op (140-140), op (171-196), op (171-171), op (198-218), op (198-198), op (220-224), op (220-220), op (226-235), op (226-226), op (237-243), op (237-237), op (245-250), op (245-245), op (252-262), op (252-252)
tilelang/intrinsics/wgmma_macro_generator.py (3)
- tilelang/language/allocate.py (1): alloc_descriptor (158-164)
- tilelang/language/builtin.py (4): initialize_descriptor (391-422), warpgroup_arrive (252-258), warpgroup_commit_batch (261-267), warpgroup_wait (270-280)
- src/tl_templates/cuda/intrin.h (1): warpgroup_wait (13-15)
src/transform/inject_fence_proxy.cc (2)
- tilelang/language/builtin.py (5): fence_proxy_async (82-91), tma_load (70-79), tma_store_arrive (94-103), tma_store_wait (106-115), initialize_descriptor (391-422)
- tilelang/language/tir/op.py (6): ptx_wgmma_ss (1064-1103), ptx_wgmma_rs (1106-1143), ptx_cp_async (1244-1273), ptx_cp_async_barrier (1592-1606), ptx_cp_async_bulk (1276-1310), ptx_ldmatrix (1205-1241)
src/op/builtin.cc (2)
- tilelang/language/builtin.py (3): warpgroup_arrive (252-258), warpgroup_commit_batch (261-267), warpgroup_wait (270-280)
- src/tl_templates/cuda/intrin.h (1): warpgroup_wait (13-15)
src/tl_templates/cuda/intrin.h (3)
- src/op/builtin.h (1): tl (22-473)
- src/tl_templates/cuda/gemm_sm90.h (1): tl (231-385)
- tilelang/language/builtin.py (3): warpgroup_arrive (252-258), warpgroup_commit_batch (261-267), warpgroup_wait (270-280)
🪛 markdownlint-cli2 (0.18.1)
docs/compiler_internals/inject_fence_proxy.md
19-19: Fenced code blocks should have a language specified
(MD040, fenced-code-language)
🔇 Additional comments (5)
tilelang/language/builtin.py (1)
252-280: Warpgroup intrinsics wired cleanly.
The Python wrappers mirror the new TL ops and keep the handle signatures consistent, so the frontend API stays aligned with the backend changes.
src/target/codegen_cuda.cc (1)
1377-1385: CUDA emission matches the new intrinsics.
The dedicated branches map straight to the cute warpgroup helpers and preserve the compile-time MMA count, keeping codegen in lockstep with the intrinsic surface.
src/op/builtin.h (1)
337-359: Header exports look complete.
Good to see the warpgroup arrive/commit/wait ops surfaced alongside the existing intrinsics so downstream components can rely on the shared registry.
src/tl_templates/cuda/intrin.h (1)
10-15: Device wrappers line up with cute helpers.
Including the SM90 GMMA header and delegating through cute keeps the TL device API tidy while exposing the new warpgroup primitives.
testing/python/transform/test_tilelang_transform_inject_fence_proxy.py (1)
56-224: Great coverage for the new fence logic.
The new cases validate no double-fencing, proxy hints, TMA store sync insertion, and the warpgroup sequencing windows, which is exactly the behavior we need confidence in.
Stmt VisitStmt_(const ForNode *op) final { return VisitSingleBody(op); }
Stmt VisitStmt_(const LetStmtNode *op) final { return VisitSingleBody(op); }
Stmt VisitStmt_(const AssertStmtNode *op) final {
  return VisitSingleBody(op);
}
Stmt VisitStmt_(const WhileNode *op) final { return VisitSingleBody(op); }

template <typename NodeType> Stmt VisitSingleBody(const NodeType *op) {
  Stmt stmt = StmtMutator::VisitStmt_(op);
  const auto *node = stmt.as<NodeType>();
  ProxyKind body_kind = GetProxyKind(node->body);
  SetProxyKind(stmt, body_kind);
  return stmt;
}
Propagate proxy kind through Allocate (and similar) nodes.
ProxyFenceInjector never overrides VisitStmt_(const AllocateNode*), so an Allocate winds up with ProxyKind::kUnknown. When a generic op precedes an Allocate whose body starts with an async op, NeedsFence() sees Unknown → Async and skips the fence. That recreates the very race this pass is supposed to eliminate—most Hopper kernels wrap async copies inside shared-memory alloc scopes, so the fence silently goes missing.
Propagate the body’s proxy kind through Allocate (and related single-body nodes) the same way you already do for For, Let, etc.
 Stmt VisitStmt_(const ForNode *op) final { return VisitSingleBody(op); }
 Stmt VisitStmt_(const LetStmtNode *op) final { return VisitSingleBody(op); }
 Stmt VisitStmt_(const AssertStmtNode *op) final {
   return VisitSingleBody(op);
 }
 Stmt VisitStmt_(const WhileNode *op) final { return VisitSingleBody(op); }
+Stmt VisitStmt_(const AllocateNode *op) final { return VisitSingleBody(op); }
+Stmt VisitStmt_(const AllocateConstNode *op) final {
+  return VisitSingleBody(op);
+}
+Stmt VisitStmt_(const DeclBufferNode *op) final {
+  return VisitSingleBody(op);
+}
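The failure mode described in this comment can be reproduced with a toy Python model. It is a sketch under stated assumptions rather than the pass's logic: `Op`, `Wrapper`, `proxy_kind`, and `needs_fence` are hypothetical names, and each wrapper holds a single statement as its body, whereas a real Allocate body is usually a whole sequence.

```python
from dataclasses import dataclass


@dataclass
class Op:
    kind: str  # "generic" or "async"


@dataclass
class Wrapper:
    node_type: str  # e.g. "For", "Allocate"
    body: object    # an Op or another Wrapper


def proxy_kind(stmt, propagating):
    """Proxy kind of a statement; wrappers outside `propagating` report "unknown"."""
    if isinstance(stmt, Op):
        return stmt.kind
    if stmt.node_type in propagating:
        return proxy_kind(stmt.body, propagating)
    return "unknown"


def needs_fence(prev, nxt, propagating):
    """A fence is needed only on a known generic -> async transition."""
    return (proxy_kind(prev, propagating) == "generic"
            and proxy_kind(nxt, propagating) == "async")


# Node types covered before and after the suggested change.
BEFORE = {"For", "LetStmt", "AssertStmt", "While"}
AFTER = BEFORE | {"Allocate", "AllocateConst", "DeclBuffer"}
```

With `BEFORE`, a generic store followed by `Wrapper("Allocate", Op("async"))` reports no fence because the wrapper's kind is unknown. With `AFTER`, the same pair reports that a fence is needed.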
print("Before injectFenceProxy")
print(mod)
mod = tilelang.transform.InjectFenceProxy()(mod)
print("After InjectFenceProxy")
print(mod)
Remove unconditional debug prints
Dumping the entire IR module every compile will spam stdout/stderr and severely degrade usability. Please drop these prints or gate them behind an explicit debug flag/logger before merging.
-print("Before injectFenceProxy")
-print(mod)
 mod = tilelang.transform.InjectFenceProxy()(mod)
-print("After InjectFenceProxy")
-print(mod)
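If the IR dumps are worth keeping for debugging, one way to gate them is sketched below. This is an illustration, not existing TileLang code: the environment variable `TILELANG_DEBUG_IR`, the logger name, and the helper `run_pass_with_optional_dump` are all assumptions introduced here.

```python
import logging
import os

# Logger name is an assumption; any module-level logger would do.
logger = logging.getLogger("tilelang.engine.phase")

# Hypothetical opt-in switch; not an existing TileLang setting.
DEBUG_IR_ENV = "TILELANG_DEBUG_IR"


def run_pass_with_optional_dump(pass_fn, mod, name):
    """Apply `pass_fn` to `mod`, logging the IR before and after only when opted in."""
    dump = os.environ.get(DEBUG_IR_ENV) == "1"
    if dump:
        logger.debug("Before %s:\n%s", name, mod)
    mod = pass_fn(mod)
    if dump:
        logger.debug("After %s:\n%s", name, mod)
    return mod
```

The call site would then read `mod = run_pass_with_optional_dump(tilelang.transform.InjectFenceProxy(), mod, "InjectFenceProxy")`. Nothing reaches stdout unless both the flag is set and the logger is configured at debug level.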
Actionable comments posted: 1
🧹 Nitpick comments (3)
src/transform/inject_fence_proxy.cc (1)
76-107: Consider explicitly classifying the new warpgroup operations.
The IsAsyncIntrinsic function correctly identifies WGMMA operations (ptx_wgmma_ss, ptx_wgmma_rs) as async. However, the newly added warpgroup synchronization operations (warpgroup_arrive, warpgroup_commit_batch, warpgroup_wait) are not explicitly classified in either IsAsyncIntrinsic or IsKnownGeneric.
Currently, they will fall through to the conservative default at line 212 (treating unknown externs as async), which is safe but not explicit. For better maintainability and clarity, consider adding explicit classification for these operations, likely as neutral barriers since they are synchronization primitives:
 bool IsAsyncIntrinsic(const CallNode *call) {
   if (call == nullptr) {
     return false;
   }
   // TileLang async intrinsics
   if (call->op.same_as(tma_load()) || call->op.same_as(tma_load_im2col()) ||
       call->op.same_as(tma_store()) || call->op.same_as(tma_store_arrive()) ||
       call->op.same_as(tma_store_wait()) ||
       call->op.same_as(ptx_cp_async_barrier_noinc()) ||
       call->op.same_as(ptx_wgmma_ss()) || call->op.same_as(ptx_wgmma_rs())) {
     return true;
   }
   // PTX async copy intrinsics
   if (call->op.same_as(builtin::ptx_cp_async()) ||
       call->op.same_as(builtin::ptx_cp_async_barrier()) ||
       call->op.same_as(builtin::ptx_cp_async_bulk())) {
     return true;
   }
   return false;
 }
And add a helper for neutral operations:
+// Operations that act as synchronization barriers and reset proxy state
+bool IsNeutralIntrinsic(const CallNode *call) {
+  if (call == nullptr) {
+    return false;
+  }
+  return call->op.same_as(warpgroup_arrive()) ||
+         call->op.same_as(warpgroup_commit_batch()) ||
+         call->op.same_as(warpgroup_wait());
+}
Then update the VisitStmt_(const EvaluateNode *op) to check this:
 if (const auto *call = evaluate->value.as<CallNode>()) {
   if (IsFenceCall(call)) {
     kind = ProxyKind::kNeutral;
+  } else if (IsNeutralIntrinsic(call)) {
+    kind = ProxyKind::kNeutral;
   } else if (IsAsyncIntrinsic(call)) {
testing/python/transform/test_tilelang_transform_inject_fence_proxy.py (2)
56-88: LGTM! Test correctly validates no redundant fence insertion.
The test verifies that the pass doesn't add a second fence when one is already present between async and generic operations, which is the expected behavior.
Consider extracting the fence counting logic into a helper function to reduce duplication across tests:
def _count_intrinsic_calls(stmt, intrinsic_name):
    count = 0

    def visit(node):
        nonlocal count
        if isinstance(node, tir.Evaluate) and isinstance(node.value, tir.Call):
            name = getattr(node.value.op, "name", None)
            if name == intrinsic_name:
                count += 1

    tir.stmt_functor.post_order_visit(stmt, visit)
    return count
Then simplify line 87 to:
assert _count_intrinsic_calls(mod["main"].body, "tl.fence_proxy_async") == 1
90-121: LGTM! Test correctly validates proxy hint override behavior.
The test confirms that proxy_hint="neutral" scopes prevent automatic fence insertion, giving users manual control over fence placement when needed.
The _has_fence visitor follows the same pattern as _count_fences in the previous test. Consider using the helper function suggested above with a zero-count check:
assert _count_intrinsic_calls(mod["main"].body, "tl.fence_proxy_async") == 0
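Returning to the classification nitpick for src/transform/inject_fence_proxy.cc: the proposed precedence (fence, then neutral barriers, then async, then known generic, then a conservative async default) can be written out as a short Python sketch. The async and neutral names come from the review's snippets. The contents of `KNOWN_GENERIC_OPS` and the function `classify` itself are assumptions made for illustration.

```python
ASYNC_OPS = {
    "tma_load", "tma_load_im2col", "tma_store", "tma_store_arrive",
    "tma_store_wait", "ptx_cp_async_barrier_noinc", "ptx_wgmma_ss",
    "ptx_wgmma_rs", "ptx_cp_async", "ptx_cp_async_barrier",
    "ptx_cp_async_bulk",
}
NEUTRAL_OPS = {"warpgroup_arrive", "warpgroup_commit_batch", "warpgroup_wait"}
# Assumption: the actual IsKnownGeneric list is not shown in this thread.
KNOWN_GENERIC_OPS = {"initialize_descriptor"}


def classify(op_name):
    """Map an intrinsic name to a proxy kind, checking explicit classes first."""
    if op_name == "fence_proxy_async":
        return "neutral"
    if op_name in NEUTRAL_OPS:
        return "neutral"
    if op_name in ASYNC_OPS:
        return "async"
    if op_name in KNOWN_GENERIC_OPS:
        return "generic"
    # Conservative default: anything unrecognized is treated as async.
    return "async"
```

Without the `NEUTRAL_OPS` branch, the three warpgroup ops would still land on the final `return "async"`, which is the implicit behavior the nitpick asks to make explicit.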
🔇 Additional comments (12)
src/tl_templates/cuda/intrin.h (1)
5-16: LGTM!
The new warpgroup synchronization wrappers are correctly implemented:
- The include for mma_sm90_gmma.hpp is necessary for the cute library functions
- All three wrappers (warpgroup_arrive, warpgroup_commit_batch, warpgroup_wait) correctly delegate to their cute counterparts
- The templated warpgroup_wait<NumMma>() properly forwards the template parameter
- All functions are appropriately marked TL_DEVICE and guarded by the SM90+ architecture check
src/target/codegen_cuda.cc (1)
1377-1385: LGTM!
The codegen for the three warpgroup operations is correctly implemented:
- warpgroup_arrive and warpgroup_commit_batch emit simple external calls with no arguments
- warpgroup_wait correctly extracts the num_mma argument from op->args[0] and emits it as a template parameter
- The implementation follows the established pattern for similar intrinsics in this file
src/op/builtin.h (1)
337-359: LGTM!
The three new Op declarations are properly implemented:
- All follow the established pattern for Op declarations in this file
- Doxygen documentation is clear and appropriate
- The TVM_DLL export macro is correctly applied
- Logical placement between no_set_max_nreg and wait_wgmma maintains file organization
docs/compiler_internals/inject_fence_proxy.md (1)
1-36: LGTM!
This is excellent documentation for the InjectFenceProxy pass:
- Clear explanation of the problem (Hopper generic vs. async proxy paths)
- Comprehensive description of the pass behavior (state tracking, TMA store lowering, fence injection)
- Appropriate technical detail on intrinsic coverage
- Helpful usage example with code
- Useful extension guidance for future maintainers
tilelang/language/builtin.py (1)
252-280: LGTM!
The three new Python wrappers are correctly implemented:
- All follow the established pattern of using tir.call_intrin("handle", tir.op.Op.get(...))
- The warpgroup_wait function correctly accepts and forwards the num_mma: int parameter
- Docstrings are clear and appropriate
- Type hints are provided where applicable
src/transform/inject_fence_proxy.cc (3)
126-155: LGTM!
The TMAStoreSyncInjector correctly rewrites TMA store sequences:
- Properly identifies tma_store intrinsics in evaluate nodes
- Constructs the required synchronization sequence (store → arrive → wait)
- Uses SeqStmt to maintain the ordering
- The Apply method follows the established pattern for TIR passes
159-301: LGTM!
The ProxyFenceInjector is well-implemented with comprehensive statement coverage:
- The SeqStmtNode visitor correctly tracks proxy state across statements and injects fences at generic-to-async transitions
- Statement classification is appropriate:
  - EvaluateNode: classifies calls based on intrinsic type, with conservative async default for unknowns
  - BufferStoreNode: correctly marked as generic
  - IfThenElseNode: properly combines proxy kinds from branches
  - Other control flow nodes appropriately propagate body kinds
- The proxy state management through proxy_map_ is clean and efficient
- The MakeFenceStmt helper correctly creates neutral fence statements
305-313: LGTM!
The pass composition is clean and correct:
- Properly composes TMAStoreSyncInjector and ProxyFenceInjector in the right order (TMA synchronization rewriting must happen before fence injection)
- Uses the standard CreatePrimFuncPass API appropriately
- Pass metadata (priority 0, name "tl.InjectFenceProxy") is correctly specified
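The store → arrive → wait rewrite that runs first can be pictured with a few lines of Python over a flat list of intrinsic names. This is a hedged sketch: `inject_tma_store_sync` is a hypothetical stand-in for `TMAStoreSyncInjector`, which operates on TIR evaluate nodes rather than strings.

```python
def inject_tma_store_sync(ops):
    """Follow every tma_store with the arrive/wait handshake, preserving order."""
    out = []
    for op in ops:
        out.append(op)
        if op == "tma_store":
            out.extend(["tma_store_arrive", "tma_store_wait"])
    return out
```

Running this step before fence injection means the inserted `tma_store_arrive` and `tma_store_wait` calls are already in the sequence when proxy kinds are assigned. Both appear in the async list quoted earlier in this review.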
testing/python/transform/test_tilelang_transform_inject_fence_proxy.py (4)
123-152: LGTM! Test correctly validates TMA store synchronization injection.
The test confirms that tl.tma_store is automatically rewritten to include tma_store_arrive and tma_store_wait handshakes, matching the behavior described in the PR objectives and implemented in TMAStoreSyncInjector::Apply (src/transform/inject_fence_proxy.cc).
154-186: LGTM! Test correctly validates wgmma async marking and fence ordering.
The test confirms that ptx_wgmma_ss operations are treated as async and that a fence_proxy_async is inserted before them when transitioning from generic operations. The ordering verification using order.index() is appropriate and will fail clearly if the expected intrinsics are not present.
227-227: LGTM! Standard test runner pattern.
The change to use tilelang.testing.main() enables standard test discovery and execution, allowing individual tests to be run via command line arguments or IDE test runners.
188-224: Verify and tighten the fence_count assertion
The test uses assert fence_count >= 1, which may mask redundant fence insertions. Determine the exact number of tl.fence_proxy_async calls in this scenario (e.g., 1) and update the assertion to:
assert fence_count == <expected_count>
Run the test with the correct PYTHONPATH or via pytest to load the tilelang package before adjusting.
print("Before injectFenceProxy")
print(mod)
mod = tilelang.transform.InjectFenceProxy()(mod)
print("After InjectFenceProxy")
print(mod)
Remove debug prints from the lowering pipeline.
These unconditional print calls dump entire IRModules during every lowering run, which will flood stdout and noticeably slow compilation for any Hopper target. Please drop them or gate behind an explicit debug flag.
-print("Before injectFenceProxy")
-print(mod)
 mod = tilelang.transform.InjectFenceProxy()(mod)
-print("After InjectFenceProxy")
-print(mod)
local test passed :)
This pull request introduces a new TIR pass tl.InjectFenceProxy for NVIDIA Hopper (SM90+) GPUs and adds support for new warpgroup synchronization intrinsics. The main goal is to ensure correct ordering between generic and asynchronous memory operations by automatically injecting fence.proxy.async instructions and handling TMA store synchronization. The pass is now documented, and the codebase has been refactored for clarity and extensibility. Additionally, new intrinsics for warpgroup synchronization are registered and lowered to CUDA code.
TIR Fence Proxy Pass and Synchronization Improvements
- Added tl.InjectFenceProxy to automatically insert fence.proxy.async instructions when control flow switches from generic memory operations to asynchronous proxy operations, preventing race conditions and undefined behavior on NVIDIA Hopper (SM90+). The pass is conservative and treats unknown extern calls as async.
- Refactored the ProxyMarker and InjectFenceProxy classes with new ProxyKind tracking and a more robust analysis of IR nodes, supporting structured control flow and easy extension for new intrinsics.
- Rewrote tma_store intrinsics into the required arrive/wait handshake to ensure proper synchronization for TMA stores.
- Added docs/compiler_internals/inject_fence_proxy.md, explaining why fences are needed, what the pass does, coverage of intrinsics, usage, and extension instructions.
Warpgroup Synchronization Intrinsics
- Registered warpgroup_arrive, warpgroup_commit_batch, and warpgroup_wait in src/op/builtin.cc and documented them in src/op/builtin.h. These ops enable explicit warpgroup synchronization for WGMMA sequences.
- Mapped the new intrinsics to cute library calls in src/tl_templates/cuda/intrin.h and added lowering logic to the CUDA codegen in src/target/codegen_cuda.cc.
Miscellaneous
- Changes to src/transform/inject_pipeline.cc.