Skip to content

Commit 1dfac2e

Browse files
authored
[Bugfix] Use ExprDeepEqual instead of StructuralEqual when merge consecutive If stmt (#876)
* Update submodule TVM to latest commit and fix condition comparison in merge_if_stmt.cc * Update submodule TVM to latest commit 0524f760 * lint fix
1 parent 15a303d commit 1dfac2e

File tree

2 files changed

+37
-1
lines changed

2 files changed

+37
-1
lines changed

src/transform/merge_if_stmt.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ class MergeIfStmtRewriter : public StmtExprMutator {
3939
if (const IfThenElseNode *if_node = new_stmt.as<IfThenElseNode>()) {
4040
if (!if_node->else_case.defined()) {
4141
if (current_condition.defined() &&
42-
StructuralEqual()(current_condition, if_node->condition)) {
42+
ExprDeepEqual()(current_condition, if_node->condition)) {
4343
current_if_bodies.push_back(if_node->then_case);
4444
continue;
4545
} else {
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
import tilelang
2+
from tilelang import tvm as tvm
3+
from tvm.ir import IRModule
4+
import tilelang.testing
5+
import tilelang.language as T
6+
7+
8+
def merge_if_test():
9+
10+
@T.prim_func
11+
def main():
12+
A = T.alloc_fragment((1,), "float16")
13+
B = T.alloc_fragment((1,), "float16")
14+
C = T.alloc_fragment((1,), "float16")
15+
D = T.alloc_fragment((1,), "float16")
16+
if A[0] == 0:
17+
A[0] = 0
18+
if B[0] == 0:
19+
B[0] = 0
20+
if C[0] == 0:
21+
C[0] = 0
22+
if D[0] == 0:
23+
D[0] = 0
24+
25+
return main
26+
27+
28+
def test_merge_if():
29+
func = merge_if_test()
30+
original_module = IRModule.from_expr(func)
31+
transformed = tilelang.transform.MergeIfStmt()(original_module)
32+
tvm.ir.assert_structural_equal(original_module["main"], transformed["main"], True)
33+
34+
35+
if __name__ == "__main__":
36+
tilelang.testing.main()

0 commit comments

Comments
 (0)