Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
163 changes: 163 additions & 0 deletions docs/compiler_internals/letstmt_inline.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
# LetStmt Inlining in TileLang

This document explains how `LetStmt` inlining works in TileLang's simplification pipeline, which is an important optimization that affects code generation and performance.

## Overview

A `LetStmt` (Let Statement) is a temporary variable binding in the IR (Intermediate Representation). During compilation, TileLang's simplifier may choose to inline these temporary variables to simplify the code. TileLang also provides a standalone `LetInline` pass that performs eager substitution before the main legalization pipeline. However, not all `LetStmt` nodes can be safely inlined.

## When Does LetStmt Get Inlined?

The inlining logic is implemented in `src/transform/simplify.cc`. A `LetStmt` will be inlined if **both** of the following conditions are met:

### 1. The value satisfies `CanInlineLetStmt`

The `CanInlineLetStmt` helper returns `true` when:

- **The value is a constant** (`is_const_number(op->value)` returns true)
- **The value is a variable** (`op->value.as<VarNode>()` returns a node)
- **The value is an integer expression without side effects**:
- The value has `int` dtype
- The side effect level is `kPure` or lower (no observable side effects)

```cpp
bool CanInlineLetStmt(const LetStmtNode *op) {
if (is_const_number(op->value))
return true;
if (op->value.as<VarNode>())
return true;
// Won't face the deep expression explosion problem as in Let expression.
// attempt to inline as much as possible if the value integer type(can be
// index).
if (!op->value.dtype().is_int())
return false;
return SideEffect(op->value) <= CallEffectKind::kPure;
}
```

### 2. The variable is NOT used in buffer definitions

Even if `CanInlineLetStmt` returns true, the variable will **not** be inlined if it's used in a buffer's definition (shape, strides, elem_offset, or data fields).

This protection exists because:
- Buffer definitions are not updated during the simplification pass
- If a variable used in a buffer definition is inlined, later references to that buffer would fail to find the variable definition
- This would cause compilation errors or incorrect behavior

The mutator checks this before dropping the binding:

```cpp
bool used_in_buffer_def = used_in_buffer_def_.count(op->var.get());

if (can_inline && !used_in_buffer_def) {
return body; // Inline: remove LetStmt and return body directly
}
```

## Example: Why Buffer Definition Variables Are Protected

Consider this code:

```python
let stride = M * 16
let buffer_a = Buffer(data, shape=[M, N], strides=[stride, 1])
buffer_a[i, j] = ...
```

- `stride` satisfies `CanInlineLetStmt` (it's an int expression with no side effects)
- However, `stride` is used in `buffer_a`'s `strides` field
- If we inline it, the buffer definition becomes `strides=[M*16, 1]`
- But the Buffer object's fields are not updated during simplification
- Later code accessing `buffer_a` would fail to find the `stride` variable

Therefore, `stride` is added to `used_in_buffer_def_` and will **not** be inlined.

## How Variables Are Collected

The `CollectVarsUsedInBufferDefinition` helper traverses all `BufferLoad` and `BufferStore` nodes and collects variables used in their buffer definitions:

```cpp
void VisitBuffer(const Buffer &buf) {
// Collect variables that should remain defined
VarUseDefAnalyzer usage(Array<Var>{});
usage(buf->data);
for (const auto &dim : buf->shape) {
usage(dim);
}
for (const auto &dim : buf->strides) {
usage(dim);
}
usage(buf->elem_offset);

// Track for use in LetStmtNode mutator
for (const auto &var : usage.undefined_) {
used_in_buffer_def_.insert(var.get());
}
}
```

## Practical Example: Temporary Variable Issue

Consider this TileLang code:

```python
for i in T.Parallel(block_N):
idx = bx * block_N + i
tmp = T.max(A[idx], 1)
B[idx] = tmp / 2
A[idx] = tmp * 2
```

In this case:
- `tmp` is an integer-like temporary variable
- It satisfies `CanInlineLetStmt` (pure int expression)
- It's **not** used in any buffer definition
- Therefore, `tmp` **will be inlined**

This means the IR becomes:

```python
for i in T.Parallel(block_N):
idx = bx * block_N + i
B[idx] = T.max(A[idx], 1) / 2
A[idx] = T.max(A[idx], 1) * 2
```

If this causes issues (e.g., `A[idx]` being read twice with different values due to the first write), it indicates a potential problem with the inlining heuristic or the code pattern.

Comment on lines +126 to +127
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Revise the hazard explanation; current sequence doesn’t demonstrate a read-after-write issue.

In the example, both reads occur before the write to A[idx], so “due to the first write” isn’t accurate. Consider a case where A[idx] is written between two reads to illustrate why duplicated loads from inlining can change behavior, or rephrase to note duplicate loads impact performance but not semantics here.

-If this causes issues (e.g., `A[idx]` being read twice with different values due to the first write)...
+Be aware that inlining duplicates loads. If a write to the same location occurs between two uses, re‑evaluation can observe updated memory and change behavior. In this example, both reads occur before the write, so semantics are unchanged, though duplicated loads may affect performance.
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
If this causes issues (e.g., `A[idx]` being read twice with different values due to the first write), it indicates a potential problem with the inlining heuristic or the code pattern.
Be aware that inlining duplicates loads. If a write to the same location occurs between two uses, re-evaluation can observe updated memory and change behavior. In this example, both reads occur before the write, so semantics are unchanged, though duplicated loads may affect performance.
🤖 Prompt for AI Agents
In docs/compiler_internals/letstmt_inline.md around lines 126 to 127, the hazard
explanation is incorrect because the shown sequence has both reads before the
write so it does not demonstrate a read-after-write semantic change; either
replace the example with one where A[idx] is written between two reads (so the
first write can change the value seen by the second read) to illustrate the
semantic hazard, or rephrase the sentence to say that duplicated loads in this
specific sequence only affect performance (not semantics) and clarify when
duplicated loads can change behavior (i.e., when a write to the same location
occurs between the reads).

## Controlling Let Inlining via Pass Config

TileLang exposes an explicit pass configuration key, `tilelang.PassConfigKey.TL_FORCE_LET_INLINE` (`"tl.force_let_inline"`), that allows users to force the eager `LetInline` pass to run before the legalization pipeline begins. When enabled, the pipeline invokes `tilelang.transform.LetInline()` at the start of `LowerAndLegalize` (see `tilelang/engine/phase.py`). This knob is useful when debugging LetStmt-related issues or when deterministic inlining behavior is desired across different environments.

```python
from tilelang import transform
from tilelang.engine.phase import LowerAndLegalize

with transform.PassContext(
config={transform.PassConfigKey.TL_FORCE_LET_INLINE: True}
):
lowered_mod = LowerAndLegalize(input_mod, target)
```

If the flag is left unset (the default), the eager pass is only applied when downstream transforms opt in (for example, by calling `_Simplify(..., inline_let=True)` inside Tile operators). The guard in `tilelang/engine/phase.py` ensures the eager pass is only triggered when the user explicitly requests it.

## Summary

The LetStmt inlining mechanism is a **conservative optimization** that:
1. Aggressively inlines simple, pure integer expressions to simplify the IR
2. Protects variables used in buffer definitions to avoid breaking buffer access
3. Helps reduce IR complexity and improve code generation
4. Can be forced through `TL_FORCE_LET_INLINE` when deterministic eager inlining is required

Understanding when inlining happens is crucial for:
- Debugging compilation issues
- Understanding generated code
- Writing efficient TileLang programs
- Identifying potential optimization opportunities or bugs

## Related Files

- `src/transform/simplify.cc`: Main Simplify implementation
- `src/transform/frontend_legalize.cc`: Standalone LetInline pass
- `tilelang/engine/phase.py`: Pipeline integration for eager LetInlining
- `testing/python/transform/test_tilelang_transform_let_inline.py`: Regression coverage for the pass
7 changes: 7 additions & 0 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,13 @@ deeplearning_operators/matmul
deeplearning_operators/deepseek_mla
:::

:::{toctree}
:maxdepth: 1
:caption: COMPILER INTERNALS

compiler_internals/letstmt_inline
:::

:::{toctree}
:maxdepth: 1
:caption: API Reference
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@ def test_example_tilelang_block_sparse_attn():


def test_example_tilelang_sparse_gqa_decode_varlen_indice():
example_tilelang_sparse_gqa_decode_varlen_indice.main()
example_tilelang_sparse_gqa_decode_varlen_indice.main(batch=1, max_cache_seqlen=2048)


def test_example_tilelang_sparse_gqa_decode_varlen_mask():
example_tilelang_sparse_gqa_decode_varlen_mask.main()
example_tilelang_sparse_gqa_decode_varlen_mask.main(batch=1, max_cache_seqlen=2048)


def test_example_triton_sparse_gqa_decode_varlen_indice():
Expand Down
22 changes: 14 additions & 8 deletions examples/fusedmoe/example_fusedmoe_tilelang.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,15 +521,21 @@ def custom_kernel(data: Tuple[torch.Tensor, Dict, Dict]) -> torch.Tensor:
return output


def main():
def main(d_hidden=7168,
d_expert=2048,
n_routed_experts=8,
n_shared_experts=1,
n_experts_per_token=4,
batch_size=1,
seq_len=8192):
config = {
"dhidden": 7168,
"dexpert": 2048,
"nroutedexperts": 8,
"nsharedexperts": 1,
"nexpertspertoken": 4,
"bs": 1,
"seqlen": 8192,
"dhidden": d_hidden,
"dexpert": d_expert,
"nroutedexperts": n_routed_experts,
"nsharedexperts": n_shared_experts,
"nexpertspertoken": n_experts_per_token,
"bs": batch_size,
"seqlen": seq_len,
"seed": 81394
}

Expand Down
9 changes: 8 additions & 1 deletion examples/fusedmoe/test_example_fusedmoe.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,14 @@


def test_example_fusedmoe_tilelang():
example_fusedmoe_tilelang.main()
example_fusedmoe_tilelang.main(
d_hidden=1024,
d_expert=256,
n_routed_experts=8,
n_shared_experts=1,
n_experts_per_token=4,
batch_size=1,
seq_len=1024)


if __name__ == "__main__":
Expand Down
1 change: 1 addition & 0 deletions src/op/builtin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ TVM_REGISTER_PASS_CONFIG_OPTION(kConfigIndexBitwidth, Integer);
TVM_REGISTER_PASS_CONFIG_OPTION(kDisableDynamicTailSplit, Bool);
TVM_REGISTER_PASS_CONFIG_OPTION(kDynamicAlignment, Integer);
TVM_REGISTER_PASS_CONFIG_OPTION(kEnableAggressiveSharedMemoryMerge, Bool);
TVM_REGISTER_PASS_CONFIG_OPTION(kForceLetInline, Bool);
TVM_REGISTER_PASS_CONFIG_OPTION(kDisableFastMath, Bool);
TVM_REGISTER_PASS_CONFIG_OPTION(kEnableFastMath, Bool);
TVM_REGISTER_PASS_CONFIG_OPTION(kPtxasRegisterUsageLevel, Integer);
Expand Down
10 changes: 9 additions & 1 deletion src/op/builtin.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,14 @@ static constexpr const char *kDisableDynamicTailSplit =
static constexpr const char *kDisableThreadStorageSync =
"tl.disable_thread_storage_sync";

/*!
* \brief Force inline Let bindings during simplification.
*
* kForceLetInline = "tl.force_let_inline"
*
*/
static constexpr const char *kForceLetInline = "tl.force_let_inline";

/*!
* \brief The size of the vectorized dimension in buffer, designed by user
*
Expand Down Expand Up @@ -441,4 +449,4 @@ TVM_DLL const Op &increase_descriptor_offset();
} // namespace tl
} // namespace tvm

#endif // TVM_TL_OP_BUILTIN_H_
#endif // TVM_TL_OP_BUILTIN_H_
89 changes: 82 additions & 7 deletions src/transform/inject_pipeline.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include <tvm/tir/builtin.h>
#include <tvm/tir/transform.h>

#include <functional>
#include <unordered_set>
#include <utility>

Expand Down Expand Up @@ -845,24 +846,77 @@ class PipelineInjector : private StmtExprMutator {
// Step 2: Find the body and buffer allocations of the pipeline. The body
// can be direct child of the for-loop. If the for-loop has BlockRealize as
// its child, the pipeline body will be the child of the block.
Stmt pipeline_body{nullptr};
Stmt pipeline_body_root{nullptr};
bool pipeline_body_from_block = false;
Array<Buffer> pipeline_allocs;
if (const auto *realize = for_node->body.as<BlockRealizeNode>()) {
const auto &block = realize->block;
for (const auto &buffer : block->alloc_buffers) {
ICHECK(buffer->IsInstance<BufferNode>());
buffer_data_to_buffer_.Set(buffer->data, buffer);
}
pipeline_body = block->body;
pipeline_body_root = block->body;
pipeline_allocs = block->alloc_buffers;
pipeline_body_from_block = true;
} else {
pipeline_body = for_node->body;
pipeline_body_root = for_node->body;
}

const SeqStmtNode *pipeline_body_seq = pipeline_body.as<SeqStmtNode>();
CHECK(pipeline_body_seq) << "ValueError: The body of the software pipeline "
"should be SeqStmt, got "
<< pipeline_body->GetTypeKey();
const SeqStmtNode *pipeline_body_seq = nullptr;
std::vector<std::function<Stmt(Stmt)>> rewrap_fns;
auto append_attr_wrapper = [&rewrap_fns](const AttrStmtNode *attr) {
ObjectRef node = attr->node;
String attr_key = attr->attr_key;
PrimExpr value = attr->value;
Span span = attr->span;
rewrap_fns.emplace_back(
[node = std::move(node), attr_key = std::move(attr_key),
value = std::move(value), span](Stmt body) -> Stmt {
return AttrStmt(node, attr_key, value, body, span);
});
};
{
Stmt current = pipeline_body_root;
while (true) {
if (const auto *seq_stmt = current.as<SeqStmtNode>()) {
pipeline_body_seq = seq_stmt;
break;
}
if (const auto *if_then_else = current.as<IfThenElseNode>()) {
ICHECK(!if_then_else->else_case.defined())
<< "InjectSoftwarePipeline: Can't handle the body of the loop "
"because the IfThenElse node has an else branch";
PrimExpr condition = if_then_else->condition;
Span span = if_then_else->span;
rewrap_fns.emplace_back(
[condition = std::move(condition), span](Stmt body) -> Stmt {
return IfThenElse(condition, body, Stmt(), span);
});
current = if_then_else->then_case;
continue;
}
if (const auto *let_stmt = current.as<LetStmtNode>()) {
Var var = let_stmt->var;
PrimExpr value = let_stmt->value;
Span span = let_stmt->span;
rewrap_fns.emplace_back([var = std::move(var),
value = std::move(value),
span](Stmt body) -> Stmt {
return LetStmt(var, value, body, span);
});
current = let_stmt->body;
continue;
}
if (const auto *attr = current.as<AttrStmtNode>()) {
append_attr_wrapper(attr);
current = attr->body;
continue;
}
LOG(FATAL) << "ValueError: The body of the software pipeline should be "
<< "SeqStmt, got " << current->GetTypeKey();
}
}
ICHECK(pipeline_body_seq != nullptr);

// Step 3: Blockize the components of the pipeline. Each child of the
// pipelined loop will be converted into a block.
Expand Down Expand Up @@ -934,6 +988,27 @@ class PipelineInjector : private StmtExprMutator {
Stmt pipeline = PipelineRewriter(buffer_data_to_buffer_, pipeline_allocs,
GetRef<For>(op), pipeline_info)
.BuildPipeline();
auto apply_wrappers = [&](Stmt stmt) {
for (auto it = rewrap_fns.rbegin(); it != rewrap_fns.rend(); ++it) {
stmt = (*it)(stmt);
}
return stmt;
};
if (!rewrap_fns.empty()) {
if (pipeline_body_from_block) {
BlockRealize pipeline_realize = Downcast<BlockRealize>(pipeline);
Block pipeline_block = pipeline_realize->block;
{
BlockNode *block_node = pipeline_block.CopyOnWrite();
block_node->body = apply_wrappers(block_node->body);
}
pipeline = BlockRealize(pipeline_realize->iter_values,
pipeline_realize->predicate, pipeline_block,
pipeline_realize->span);
} else {
pipeline = apply_wrappers(pipeline);
}
}

if (const auto *realize = op->body.as<BlockRealizeNode>()) {
const auto &block = realize->block;
Expand Down
Loading
Loading