Commit ddfaac3

[Refactor] Refactor Pass InjectFenceProxy and expose some warp group primitives in frontend (#977)
* InjectFenceProxy docs and tests
  - annotate proxy fence injector with context comments for async/generic detection
  - add compiler internals doc covering the pass mechanics and link it in docs index
  - repair fence proxy test by fixing descriptor init usage and fence counter logic
* do not consider call_extern as async
* doc update
* reduce test size for sparse mla
1 parent 77e31e5 commit ddfaac3

File tree

13 files changed

+639
-145
lines changed

Lines changed: 113 additions & 0 deletions
@@ -0,0 +1,113 @@
# InjectFenceProxy Pass

`tl.InjectFenceProxy` is a TIR-level transform that keeps the GPU proxy state consistent on NVIDIA Hopper (SM90+) by inserting `fence.proxy.async` instructions when control flow switches from generic memory operations to asynchronous proxy operations.
## Why Fences Are Needed

Hopper separates memory instructions into generic and asynchronous proxy paths. When an asynchronous instruction (for example, `cp.async` or `tma.load`) issues after generic traffic (like `ldmatrix` or plain buffer stores), the hardware requires a `fence.proxy.async` to guarantee ordering. Missing fences can lead to race conditions or undefined behaviour.
## What the Pass Does

- Walks every statement in the `PrimFunc`, tracking whether it behaves as a **generic**, **async**, or **neutral** proxy (neutral statements, such as an explicit fence, reset the state).
- Automatically lowers `tma_store` intrinsics into the required `arrive`/`wait` handshake so that TMA stores participate correctly in synchronization.
- Injects an explicit `fence.proxy.async` whenever a generic statement is followed by an async statement without an intervening neutral barrier.

The pass is conservative: unknown extern calls are treated as async so that the fence is inserted rather than accidentally omitted.
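The scan can be pictured as a small state machine. The sketch below is a toy model of the transition rule, not the real TIR pass; the statement names and the `classify` helper are illustrative:

```python
# Toy model of the InjectFenceProxy scan: classify each statement as
# "generic", "async", or "neutral", and emit a fence on every
# generic -> async transition. Names here are illustrative only.
KINDS = {
    "initialize_descriptor": "generic",
    "shared_store": "generic",
    "ldmatrix": "generic",
    "wgmma": "async",
    "cp.async": "async",
    "fence.proxy.async": "neutral",
}

def classify(stmt):
    # mirror the conservative rule above: unknown calls count as async
    return KINDS.get(stmt, "async")

def inject_fences(stmts):
    out, prev = [], None
    for s in stmts:
        kind = classify(s)
        if prev == "generic" and kind == "async":
            out.append("fence.proxy.async")
        out.append(s)
        # a neutral statement (an explicit fence) resets the tracked state
        prev = None if kind == "neutral" else kind
    return out
```

For example, `inject_fences(["initialize_descriptor", "shared_store", "wgmma"])` inserts exactly one fence, immediately before the `wgmma`; a sequence that already contains an explicit fence at the transition point is left unchanged.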
### Timeline View

```
generic initialize_descriptor → generic shared-store → async wgmma
          │                          │                     │
          └─ generic proxy           ┴─ generic proxy      ┴─ async proxy
                                     │ fence inserted here ↑
                                     └─────────────────────┘
```
The proxy tracker scans the sequence from left to right. The moment it detects a transition from generic to async (between the shared-memory store and the `wgmma` above), it synthesizes a `fence.proxy.async` to reset the hardware proxy state before the async path runs.
## Coverage of Intrinsics

The tracker understands the TileLang intrinsics for TMA load/store, shared-memory MMA (`wgmma`), and TVM/PTX async copy intrinsics (`cp.async` variants). Generic operations currently include `ldmatrix`, `stmatrix`, and descriptor initialization. Other IR nodes (loops, blocks, attributes) receive a proxy kind derived from their bodies so that the analysis survives structured control flow.
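As a rough picture of that classification (the name spellings below are illustrative; the authoritative lists live in `src/transform/inject_fence_proxy.cc`):

```python
# Hypothetical classifier mirroring the coverage described above.
# The intrinsic name spellings are illustrative, not the real ones.
ASYNC_OPS = {"tma_load", "tma_store", "wgmma", "cp.async"}
GENERIC_OPS = {"ldmatrix", "stmatrix", "initialize_descriptor"}

def proxy_kind(op_name):
    if op_name in ASYNC_OPS:
        return "async"
    if op_name in GENERIC_OPS:
        return "generic"
    return "unknown"  # compound nodes derive a kind from their bodies instead

# proxy_kind("wgmma") -> "async"; proxy_kind("ldmatrix") -> "generic"
```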
## Usage

The pass is part of the default TileLang lowering pipeline. To apply it manually:
```python
import tvm
from tvm import IRModule

from tilelang import tl

mod = IRModule({"main": prim_func})
with tvm.transform.PassContext():
    mod = tl.transform.InjectFenceProxy()(mod)
```
## End-to-End Example

Before the pass:
```python
@T.prim_func
def kernel():
    with T.Kernel(1):
        desc = T.decl_buffer((1,), "uint64", scope="local.descriptor")
        smem = T.decl_buffer((128,), "float16", scope="shared")
        T.initialize_descriptor(desc, T.uint64(0), 2, 1, 32)
        smem[0] = T.float16(0)
        T.ptx_wgmma_ss(
            "float16",
            "m64n64k16",
            T.bool(True),
            T.bool(True),
            "fp16",
            "fp16",
            "fp16",
            desc.data,
            T.int32(0),
            desc.data,
            T.int32(0),
            smem.data,
            T.int32(0),
            T.bool(True),
            1,
            1,
        )
```
After `tl.transform.InjectFenceProxy`:
```python
@T.prim_func
def kernel():
    with T.Kernel(1):
        desc = T.decl_buffer((1,), "uint64", scope="local.descriptor")
        smem = T.decl_buffer((128,), "float16", scope="shared")
        T.initialize_descriptor(desc, T.uint64(0), 2, 1, 32)
        smem[0] = T.float16(0)
        T.fence_proxy_async()
        T.ptx_wgmma_ss(
            "float16",
            "m64n64k16",
            T.bool(True),
            T.bool(True),
            "fp16",
            "fp16",
            "fp16",
            desc.data,
            T.int32(0),
            desc.data,
            T.int32(0),
            smem.data,
            T.int32(0),
            T.bool(True),
            1,
            1,
        )
```
The only change is the `fence_proxy_async` between the generic descriptor setup / shared-memory write and the async `wgmma`. In larger kernels the pass performs the same operation across nested blocks, loops, and conditional branches.
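One way to picture how the analysis crosses structured control flow is to give each compound statement the proxy kinds of its first and last effectful children; a minimal sketch under that assumption (not the pass's real data structures):

```python
# Toy recursion: a nested sequence (modeling a loop or block body) exposes
# the kind of its first and last leaf, so a generic -> async transition is
# still visible across the structural boundary. Illustrative only.
def boundary_kinds(stmt):
    if isinstance(stmt, str):  # leaf: already a proxy kind
        return stmt, stmt
    first, _ = boundary_kinds(stmt[0])
    _, last = boundary_kinds(stmt[-1])
    return first, last

# A loop ending in a generic store, followed by an async wgmma: the
# boundary pair ("generic", "async") tells the tracker a fence is due.
kinds = boundary_kinds([["generic", "generic"], "async"])
```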
## Extending the Pass

If you introduce a new intrinsic that behaves like an async proxy, add it to `IsAsyncIntrinsic` in `src/transform/inject_fence_proxy.cc`. Likewise, extend `IsKnownGeneric` for additional generic operations. When adding new neutral barriers, make sure they set the proxy kind to `kNeutral` so the state resets correctly.

docs/index.md

Lines changed: 2 additions & 1 deletion
```diff
@@ -40,6 +40,7 @@ deeplearning_operators/deepseek_mla
 :caption: COMPILER INTERNALS

 compiler_internals/letstmt_inline
+compiler_internals/inject_fence_proxy
 :::

 :::{toctree}
@@ -54,4 +55,4 @@ autoapi/tilelang/index
 :caption: Privacy

 privacy
-:::
+:::
```

examples/deepseek_v32/test_tilelang_example_deepseek_v32.py

Lines changed: 3 additions & 3 deletions
```diff
@@ -21,22 +21,22 @@ def test_example_fp8_lighting_indexer():
 def test_example_sparse_mla_fwd():
     # small shapes for testing
     test_sparse_mla_fwd(
-        S=1024, SKV=2048, H=128, HKV=1, DQK=576, DV=512, topk=256, check_correctness=False)
+        S=256, SKV=1024, H=64, HKV=1, DQK=576, DV=512, topk=256, check_correctness=False)


 @tilelang.testing.requires_cuda
 @tilelang.testing.requires_cuda_compute_version_ge(9, 0)
 def test_example_sparse_mla_fwd_pipelined():
     # small shapes for testing
     test_sparse_mla_fwd_pipelined(
-        S=1024, SKV=2048, H=128, HKV=1, DQK=576, DV=512, topk=256, check_correctness=False)
+        S=256, SKV=1024, H=64, HKV=1, DQK=576, DV=512, topk=256, check_correctness=False)


 @tilelang.testing.requires_cuda
 @tilelang.testing.requires_cuda_compute_version_ge(9, 0)
 def test_example_sparse_mla_bwd():
     test_sparse_mla_bwd(
-        S=1024, SKV=2048, H=64, HKV=1, DQKV=576, DV=512, topk=256, check_correctness=False)
+        S=256, SKV=1024, H=64, HKV=1, DQKV=576, DV=512, topk=256, check_correctness=False)


 if __name__ == "__main__":
```

src/op/builtin.cc

Lines changed: 15 additions & 0 deletions
```diff
@@ -203,6 +203,21 @@ TIR_DEFINE_TL_BUILTIN(no_set_max_nreg)
     .set_attr<TCallEffectKind>("TCallEffectKind",
                                Integer(CallEffectKind::kOpaque));

+TIR_DEFINE_TL_BUILTIN(warpgroup_arrive)
+    .set_num_inputs(0)
+    .set_attr<TCallEffectKind>("TCallEffectKind",
+                               Integer(CallEffectKind::kOpaque));
+
+TIR_DEFINE_TL_BUILTIN(warpgroup_commit_batch)
+    .set_num_inputs(0)
+    .set_attr<TCallEffectKind>("TCallEffectKind",
+                               Integer(CallEffectKind::kOpaque));
+
+TIR_DEFINE_TL_BUILTIN(warpgroup_wait)
+    .set_num_inputs(1)
+    .set_attr<TCallEffectKind>("TCallEffectKind",
+                               Integer(CallEffectKind::kOpaque));
+
 TIR_DEFINE_TL_BUILTIN(wait_wgmma)
     .set_num_inputs(1)
     .set_attr<TCallEffectKind>("TCallEffectKind",
```

src/op/builtin.h

Lines changed: 24 additions & 0 deletions
```diff
@@ -334,6 +334,30 @@ TVM_DLL const Op &set_max_nreg();
  */
 TVM_DLL const Op &no_set_max_nreg();

+/*!
+ * \brief Arrive at a warpgroup fence for WGMMA sequences
+ *
+ * warpgroup_arrive()
+ *
+ */
+TVM_DLL const Op &warpgroup_arrive();
+
+/*!
+ * \brief Commit the current warpgroup batch for WGMMA sequences
+ *
+ * warpgroup_commit_batch()
+ *
+ */
+TVM_DLL const Op &warpgroup_commit_batch();
+
+/*!
+ * \brief Wait for the warpgroup batch identified by num_mma
+ *
+ * warpgroup_wait(num_mma)
+ *
+ */
+TVM_DLL const Op &warpgroup_wait();
+
 /*!
  * \brief Wait the previous wgmma to finish
  *
```

src/target/codegen_cuda.cc

Lines changed: 9 additions & 0 deletions
```diff
@@ -1374,6 +1374,15 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
     print_extern_call_stmt("tl::tma_store_arrive");
   } else if (op->op.same_as(tl::tma_store_wait())) {
     print_extern_call_stmt("tl::tma_store_wait<0>");
+  } else if (op->op.same_as(tl::warpgroup_arrive())) {
+    print_extern_call_stmt("tl::warpgroup_arrive");
+  } else if (op->op.same_as(tl::warpgroup_commit_batch())) {
+    print_extern_call_stmt("tl::warpgroup_commit_batch");
+  } else if (op->op.same_as(tl::warpgroup_wait())) {
+    this->PrintIndent();
+    int num_mma = Downcast<IntImm>(op->args[0])->value;
+    this->stream << "tl::warpgroup_wait<" << std::to_string(num_mma)
+                 << ">();\n";
   } else if (op->op.same_as(tl::set_max_nreg())) {
     this->PrintIndent();
     int nreg = Downcast<IntImm>(op->args[0])->value;
```

src/tl_templates/cuda/intrin.h

Lines changed: 10 additions & 1 deletion
```diff
@@ -2,9 +2,18 @@

 #if __CUDA_ARCH_LIST__ >= 900
 #include "cute/arch/cluster_sm90.hpp"
+#include "cute/arch/mma_sm90_gmma.hpp"
 #include "cutlass/cutlass.h"

 namespace tl {
+
+TL_DEVICE void warpgroup_arrive() { cute::warpgroup_arrive(); }
+TL_DEVICE void warpgroup_commit_batch() { cute::warpgroup_commit_batch(); }
+
+template <int NumMma> TL_DEVICE void warpgroup_wait() {
+  cute::warpgroup_wait<NumMma>();
+}
+
 // Template parameter:
 //   thread_extent: the logical size (in number of threads) of each "group"
 //   within which we want to elect exactly ONE representative
@@ -53,4 +62,4 @@ template <uint32_t RegCount> TL_DEVICE void warpgroup_reg_dealloc() {
   asm volatile("setmaxnreg.dec.sync.aligned.u32 %0;\n" : : "n"(RegCount));
 }
 } // namespace tl
-#endif
+#endif
```
