Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/op/builtin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ TVM_REGISTER_PASS_CONFIG_OPTION(kEnablePTXASVerboseOutput, Bool);
TVM_REGISTER_PASS_CONFIG_OPTION(kDisableVectorize256, Bool);
TVM_REGISTER_PASS_CONFIG_OPTION(kDisableWGMMA, Bool);
TVM_REGISTER_PASS_CONFIG_OPTION(kDisableShuffleElect, Bool);
TVM_REGISTER_PASS_CONFIG_OPTION(kStorageRewriteDetectInplace, Bool);

// Returns the DataType constructed as DataType::UInt(8, 128).
// NOTE(review): presumably this is the opaque type used for CUDA tensor-map
// descriptors (unsigned 8-bit, 128 lanes) — confirm against DataType::UInt.
DataType cuTensorMapType() { return DataType::UInt(8, 128); }

Expand Down
2 changes: 2 additions & 0 deletions src/op/builtin.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ static constexpr const char *kEnablePTXASVerboseOutput =
static constexpr const char *kDisableVectorize256 = "tl.disable_vectorize_256";
static constexpr const char *kDisableWGMMA = "tl.disable_wgmma";
static constexpr const char *kDisableShuffleElect = "tl.disable_shuffle_elect";
/*!
 * \brief Whether StorageRewrite should run its inplace detection
 *
 * Registered as a Bool pass-config option; the StorageRewrite pass reads it
 * with a fallback of false when the option is not set.
 *
 * kStorageRewriteDetectInplace = "tl.storage_rewrite_detect_inplace"
 */
static constexpr const char *kStorageRewriteDetectInplace =
"tl.storage_rewrite_detect_inplace";
/*!
* \brief Whether to disable dynamic tail split
*
Expand Down
9 changes: 6 additions & 3 deletions src/transform/storage_rewrite.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
#include <unordered_set>
#include <utility>

#include "../op/builtin.h"
#include "arith/int_operator.h"
#include "runtime/thread_storage_scope.h"
#include "tir/ir/buffer_common.h"
Expand Down Expand Up @@ -1914,6 +1915,8 @@ using namespace tir::transform;
namespace transform {
Pass StorageRewrite() {
auto pass_func = [](PrimFunc f, const IRModule &m, PassContext ctx) {
bool detect_inplace =
ctx->GetConfig<Bool>(kStorageRewriteDetectInplace, Bool(false)).value();
bool enable_reuse = true;
bool reuse_require_exact_matched_dtype = false;
bool merge_static_smem =
Expand All @@ -1939,9 +1942,9 @@ Pass StorageRewrite() {
reuse_require_exact_matched_dtype = true;
}
auto *n = f.CopyOnWrite();
n->body =
StoragePlanRewriter().Rewrite(std::move(n->body), true, enable_reuse,
reuse_require_exact_matched_dtype);
n->body = StoragePlanRewriter().Rewrite(std::move(n->body), detect_inplace,
enable_reuse,
reuse_require_exact_matched_dtype);
// Parameters may not be rewritten, but internal allocations may.
// Vectorization of AllocateConst is currently disabled, as it has
// indexing issues for types that include padding (e.g. int8x3
Expand Down
61 changes: 61 additions & 0 deletions testing/python/components/test_storage_rewrite_detect_inplace.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import tilelang
import tilelang.testing
from tilelang import language as T


# Builds the toy kernel with the default pass configuration (no pass_configs
# are handed to tilelang.jit), i.e. without enabling
# TL_STORAGE_REWRITE_DETECT_INPLACE.
@tilelang.jit
def _compile_kernel_without_inplace():
# Symbolic length so the kernel is not specialised to one tensor size.
num_tokens = T.symbolic("num_tokens")

# NOTE: the local names `read` and `write` are matched literally by the
# assertions on the generated kernel source further down in this file
# ("read = (read * 2);") — do not rename them.
@T.prim_func
def buggy_kernel(x: T.Tensor[(num_tokens,), "float"]):
with T.Kernel(num_tokens, threads=32) as pid:
read = T.alloc_var("int")
read = x[pid]

write = T.alloc_var("int")
write = read * 2
x[pid] = write

return buggy_kernel


# Same toy kernel as the default-config variant in this file, but compiled
# with TL_STORAGE_REWRITE_DETECT_INPLACE set to True in the pass configs.
@tilelang.jit(
pass_configs={
tilelang.PassConfigKey.TL_STORAGE_REWRITE_DETECT_INPLACE: True,
},)
def _compile_kernel_with_inplace():
# Symbolic length so the kernel is not specialised to one tensor size.
num_tokens = T.symbolic("num_tokens")

# NOTE: the local names `read` and `write` are matched literally by the
# assertions on the generated kernel source further down in this file
# ("read = (read * 2);") — do not rename them.
@T.prim_func
def buggy_kernel(x: T.Tensor[(num_tokens,), "float"]):
with T.Kernel(num_tokens, threads=32) as pid:
read = T.alloc_var("int")
read = x[pid]

write = T.alloc_var("int")
write = read * 2
x[pid] = write

return buggy_kernel


def _get_device_kernel_script(detect_inplace: bool) -> str:
    """Compile the toy kernel and return its generated kernel source.

    Args:
        detect_inplace: When True, use the variant compiled with
            TL_STORAGE_REWRITE_DETECT_INPLACE enabled; otherwise use the
            variant compiled with the default pass configuration.

    Returns:
        The text returned by the compiled kernel's ``get_kernel_source()``.
    """
    # Pick the compile entry point first, then invoke it exactly once.
    compile_fn = (
        _compile_kernel_with_inplace if detect_inplace else _compile_kernel_without_inplace)
    return compile_fn().get_kernel_source()


def test_storage_rewrite_detect_inplace_toggle():
    """The in-place statement must only show up when detection is enabled."""
    # Statement emitted when the `write` temporary is folded into `read`.
    inplace_stmt = "read = (read * 2);"

    script_off = _get_device_kernel_script(detect_inplace=False)
    script_on = _get_device_kernel_script(detect_inplace=True)

    assert inplace_stmt not in script_off
    assert inplace_stmt in script_on


# Allow running this test file directly as a script via tilelang's test runner.
if __name__ == "__main__":
tilelang.testing.main()
9 changes: 1 addition & 8 deletions tilelang/autotuner/param.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,7 @@ class CompileArgs:
target_host: Target host for cross-compilation (default: None).
verbose: Whether to enable verbose output (default: False).
pass_configs: Additional keyword arguments to pass to the Compiler PassContext.
Available options:
"tir.disable_vectorize": bool, default: False
"tl.disable_tma_lower": bool, default: False
"tl.disable_warp_specialized": bool, default: False
"tl.config_index_bitwidth": int, default: None
"tl.disable_dynamic_tail_split": bool, default: False
"tl.dynamic_vectorize_size_bits": int, default: 128
"tl.disable_safe_memory_legalize": bool, default: False
Refer to `tilelang.PassConfigKey` for supported options.
"""

out_idx: Optional[Union[List[int], int]] = None
Expand Down
9 changes: 1 addition & 8 deletions tilelang/jit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,14 +59,7 @@ def compile(
Whether to enable verbose output (default: False).
pass_configs : dict, optional
Additional keyword arguments to pass to the Compiler PassContext.
Available options:
"tir.disable_vectorize": bool, default: False
"tl.disable_tma_lower": bool, default: False
"tl.disable_warp_specialized": bool, default: False
"tl.config_index_bitwidth": int, default: None
"tl.disable_dynamic_tail_split": bool, default: False
"tl.dynamic_vectorize_size_bits": int, default: 128
"tl.disable_safe_memory_legalize": bool, default: False
Refer to `tilelang.PassConfigKey` for supported options.
"""
assert isinstance(func, PrimFunc), f"target function must be a PrimFunc but got {type(func)}"
if isinstance(compile_flags, str):
Expand Down
6 changes: 1 addition & 5 deletions tilelang/jit/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,7 @@ def __init__(
Whether to enable verbose output (default: False).
pass_configs : dict, optional
Additional keyword arguments to pass to the Compiler PassContext.
Available options:
"tir.disable_vectorize": bool, default: False
"tl.disable_tma_lower": bool, default: False
"tl.disable_dynamic_tail_split": bool, default: False
"tl.dynamic_vectorize_size_bits": int, default: 128
Refer to `tilelang.PassConfigKey` for supported options.
from_database : bool, optional
Whether to create a TorchFunction from a database.
"""
Expand Down
40 changes: 40 additions & 0 deletions tilelang/transform/pass_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,46 @@ class PassConfigKey(str, Enum):
TL_FORCE_LET_INLINE = "tl.force_let_inline"
"""Force TileLang to inline let bindings during simplification. Default: False"""

TL_STORAGE_REWRITE_DETECT_INPLACE = "tl.storage_rewrite_detect_inplace"
"""Control StorageRewrite inplace detection. Default: False

When False (default) StorageRewrite keeps distinct temporaries for patterns
such as `dst[i] = f(src[i])`, avoiding implicit aliasing:

```
read = T.allocate([1], "int32", "local.var")
write = T.allocate([1], "int32", "local.var")
read_buf = T.Buffer((1,), "int32", data=read, scope="local.var")
write_buf = T.Buffer((1,), "int32", data=write, scope="local.var")
write_buf[0] = read_buf[0] * 2
f(write_buf[0])
```

Setting the flag to True allows StorageRewrite to reuse the `read` buffer
for the write when it can prove the update is safely inplace, producing IR
like:

```
read = T.allocate([1], "int32", "local.var")
read_buf = T.Buffer((1,), "int32", data=read, scope="local.var")
read_buf[0] = read_buf[0] * 2
f(read_buf[0])
```

This reduces local memory usage but introduces aliasing between the buffers.

Usage:

```python
import tilelang
from tilelang.transform import PassContext, PassConfigKey

with PassContext(
config={PassConfigKey.TL_STORAGE_REWRITE_DETECT_INPLACE.value: True}
):
mod = tilelang.transform.StorageRewrite()(mod)
```
"""

# TIR related configs
TIR_ENABLE_EQUIV_TERMS_IN_CSE = "tir.enable_equiv_terms_in_cse_tir"
"""Enable equivalent terms in TIR Common Subexpression Elimination. Default: True"""
Expand Down