Skip to content

Commit cdc67fc

Browse files
authored
[PassConfig] Introduce PassConfig TL_STORAGE_REWRITE_DETECT_INPLACE (#1089)
* • Enable configurable StorageRewrite inplace detection - Add kStorageRewriteDetectInplace constant and register the flag with PassContext so C++ code no longer hard-codes the key. - Wire StorageRewrite to include TileLang builtin constants and honor the new config toggle when deciding inplace reuse. - Document the flag across Python surfaces (PassConfigKey, JIT/autotuner docs) with usage guidance and simplified IR examples. * lint fix * add test * lint fix
1 parent 0c7e741 commit cdc67fc

File tree

8 files changed

+113
-24
lines changed

8 files changed

+113
-24
lines changed

src/op/builtin.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ TVM_REGISTER_PASS_CONFIG_OPTION(kEnablePTXASVerboseOutput, Bool);
3333
TVM_REGISTER_PASS_CONFIG_OPTION(kDisableVectorize256, Bool);
3434
TVM_REGISTER_PASS_CONFIG_OPTION(kDisableWGMMA, Bool);
3535
TVM_REGISTER_PASS_CONFIG_OPTION(kDisableShuffleElect, Bool);
36+
TVM_REGISTER_PASS_CONFIG_OPTION(kStorageRewriteDetectInplace, Bool);
3637

3738
DataType cuTensorMapType() { return DataType::UInt(8, 128); }
3839

src/op/builtin.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@ static constexpr const char *kEnablePTXASVerboseOutput =
4848
static constexpr const char *kDisableVectorize256 = "tl.disable_vectorize_256";
4949
static constexpr const char *kDisableWGMMA = "tl.disable_wgmma";
5050
static constexpr const char *kDisableShuffleElect = "tl.disable_shuffle_elect";
51+
static constexpr const char *kStorageRewriteDetectInplace =
52+
"tl.storage_rewrite_detect_inplace";
5153
/*!
5254
* \brief Whether to disable dynamic tail split
5355
*

src/transform/storage_rewrite.cc

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
#include <unordered_set>
3939
#include <utility>
4040

41+
#include "../op/builtin.h"
4142
#include "arith/int_operator.h"
4243
#include "runtime/thread_storage_scope.h"
4344
#include "tir/ir/buffer_common.h"
@@ -1914,6 +1915,8 @@ using namespace tir::transform;
19141915
namespace transform {
19151916
Pass StorageRewrite() {
19161917
auto pass_func = [](PrimFunc f, const IRModule &m, PassContext ctx) {
1918+
bool detect_inplace =
1919+
ctx->GetConfig<Bool>(kStorageRewriteDetectInplace, Bool(false)).value();
19171920
bool enable_reuse = true;
19181921
bool reuse_require_exact_matched_dtype = false;
19191922
bool merge_static_smem =
@@ -1939,9 +1942,9 @@ Pass StorageRewrite() {
19391942
reuse_require_exact_matched_dtype = true;
19401943
}
19411944
auto *n = f.CopyOnWrite();
1942-
n->body =
1943-
StoragePlanRewriter().Rewrite(std::move(n->body), true, enable_reuse,
1944-
reuse_require_exact_matched_dtype);
1945+
n->body = StoragePlanRewriter().Rewrite(std::move(n->body), detect_inplace,
1946+
enable_reuse,
1947+
reuse_require_exact_matched_dtype);
19451948
// Parameters may not be rewritten, but internal allocations may.
19461949
// Vectorization of AllocateConst is currently disabled, as it has
19471950
// indexing issues for types that include padding (e.g. int8x3
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
import tilelang
2+
import tilelang.testing
3+
from tilelang import language as T
4+
5+
6+
@tilelang.jit
7+
def _compile_kernel_without_inplace():
8+
num_tokens = T.symbolic("num_tokens")
9+
10+
@T.prim_func
11+
def buggy_kernel(x: T.Tensor[(num_tokens,), "float"]):
12+
with T.Kernel(num_tokens, threads=32) as pid:
13+
read = T.alloc_var("int")
14+
read = x[pid]
15+
16+
write = T.alloc_var("int")
17+
write = read * 2
18+
x[pid] = write
19+
20+
return buggy_kernel
21+
22+
23+
@tilelang.jit(
24+
pass_configs={
25+
tilelang.PassConfigKey.TL_STORAGE_REWRITE_DETECT_INPLACE: True,
26+
},)
27+
def _compile_kernel_with_inplace():
28+
num_tokens = T.symbolic("num_tokens")
29+
30+
@T.prim_func
31+
def buggy_kernel(x: T.Tensor[(num_tokens,), "float"]):
32+
with T.Kernel(num_tokens, threads=32) as pid:
33+
read = T.alloc_var("int")
34+
read = x[pid]
35+
36+
write = T.alloc_var("int")
37+
write = read * 2
38+
x[pid] = write
39+
40+
return buggy_kernel
41+
42+
43+
def _get_device_kernel_script(detect_inplace: bool) -> str:
44+
if detect_inplace:
45+
kernel = _compile_kernel_with_inplace()
46+
else:
47+
kernel = _compile_kernel_without_inplace()
48+
source = kernel.get_kernel_source()
49+
return source
50+
51+
52+
def test_storage_rewrite_detect_inplace_toggle():
53+
script_off = _get_device_kernel_script(detect_inplace=False)
54+
script_on = _get_device_kernel_script(detect_inplace=True)
55+
56+
assert script_off.count("read = (read * 2);") == 0
57+
assert script_on.count("read = (read * 2);") > 0
58+
59+
60+
if __name__ == "__main__":
61+
tilelang.testing.main()

tilelang/autotuner/param.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -37,14 +37,7 @@ class CompileArgs:
3737
target_host: Target host for cross-compilation (default: None).
3838
verbose: Whether to enable verbose output (default: False).
3939
pass_configs: Additional keyword arguments to pass to the Compiler PassContext.
40-
Available options:
41-
"tir.disable_vectorize": bool, default: False
42-
"tl.disable_tma_lower": bool, default: False
43-
"tl.disable_warp_specialized": bool, default: False
44-
"tl.config_index_bitwidth": int, default: None
45-
"tl.disable_dynamic_tail_split": bool, default: False
46-
"tl.dynamic_vectorize_size_bits": int, default: 128
47-
"tl.disable_safe_memory_legalize": bool, default: False
40+
Refer to `tilelang.PassConfigKey` for supported options.
4841
"""
4942

5043
out_idx: Optional[Union[List[int], int]] = None

tilelang/jit/__init__.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -59,14 +59,7 @@ def compile(
5959
Whether to enable verbose output (default: False).
6060
pass_configs : dict, optional
6161
Additional keyword arguments to pass to the Compiler PassContext.
62-
Available options:
63-
"tir.disable_vectorize": bool, default: False
64-
"tl.disable_tma_lower": bool, default: False
65-
"tl.disable_warp_specialized": bool, default: False
66-
"tl.config_index_bitwidth": int, default: None
67-
"tl.disable_dynamic_tail_split": bool, default: False
68-
"tl.dynamic_vectorize_size_bits": int, default: 128
69-
"tl.disable_safe_memory_legalize": bool, default: False
62+
Refer to `tilelang.transform.PassConfigKey` for supported options.
7063
"""
7164
assert isinstance(func, PrimFunc), f"target function must be a PrimFunc but got {type(func)}"
7265
if isinstance(compile_flags, str):

tilelang/jit/kernel.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -71,11 +71,7 @@ def __init__(
7171
Whether to enable verbose output (default: False).
7272
pass_configs : dict, optional
7373
Additional keyword arguments to pass to the Compiler PassContext.
74-
Available options:
75-
"tir.disable_vectorize": bool, default: False
76-
"tl.disable_tma_lower": bool, default: False
77-
"tl.disable_dynamic_tail_split": bool, default: False
78-
"tl.dynamic_vectorize_size_bits": int, default: 128
74+
Refer to `tilelang.PassConfigKey` for supported options.
7975
from_database : bool, optional
8076
Whether to create a TorchFunction from a database.
8177
"""

tilelang/transform/pass_config.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,46 @@ class PassConfigKey(str, Enum):
6969
TL_FORCE_LET_INLINE = "tl.force_let_inline"
7070
"""Force TileLang to inline let bindings during simplification. Default: False"""
7171

72+
TL_STORAGE_REWRITE_DETECT_INPLACE = "tl.storage_rewrite_detect_inplace"
73+
"""Control StorageRewrite inplace detection.
74+
75+
When False (default) StorageRewrite keeps distinct temporaries for patterns
76+
such as `dst[i] = f(src[i])`, avoiding implicit aliasing:
77+
78+
```
79+
read = T.allocate([1], "int32", "local.var")
80+
write = T.allocate([1], "int32", "local.var")
81+
read_buf = T.Buffer((1,), "int32", data=read, scope="local.var")
82+
write_buf = T.Buffer((1,), "int32", data=write, scope="local.var")
83+
write_buf[0] = read_buf[0] * 2
84+
f(write_buf[0])
85+
```
86+
87+
Setting the flag to True allows StorageRewrite to reuse the `read` buffer
88+
for the write when it can prove the update is safely inplace, producing IR
89+
like:
90+
91+
```
92+
read = T.allocate([1], "int32", "local.var")
93+
read_buf = T.Buffer((1,), "int32", data=read, scope="local.var")
94+
read_buf[0] = read_buf[0] * 2
95+
f(read_buf[0])
96+
```
97+
98+
This reduces local memory usage but introduces aliasing between the buffers.
99+
100+
Usage:
101+
102+
```python
103+
from tilelang.transform import PassContext, PassConfigKey
104+
105+
with PassContext(
106+
config={PassConfigKey.TL_STORAGE_REWRITE_DETECT_INPLACE.value: True}
107+
):
108+
mod = tilelang.transform.StorageRewrite()(mod)
109+
```
110+
"""
111+
72112
# TIR related configs
73113
TIR_ENABLE_EQUIV_TERMS_IN_CSE = "tir.enable_equiv_terms_in_cse_tir"
74114
"""Enable equivalent terms in TIR Common Subexpression Elimination. Default: True"""

0 commit comments

Comments
 (0)