Skip to content

Commit b3afa7b

Browse files
[CINN][Backend Pass Update No.7] Update merge_block_utils (#70406)
1 parent 7eab154 commit b3afa7b

File tree

3 files changed

+116
-114
lines changed

3 files changed

+116
-114
lines changed

paddle/cinn/optim/merge_block_utils.cc

Lines changed: 33 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -17,60 +17,68 @@
1717
#include "paddle/cinn/common/cas.h"
1818
#include "paddle/cinn/ir/ir_mutator.h"
1919
#include "paddle/cinn/ir/ir_printer.h"
20+
#include "paddle/cinn/ir/stmt.h"
2021
#include "paddle/common/enforce.h"
2122

2223
namespace cinn {
2324
namespace optim {
2425

2526
namespace {
27+
using ir::stmt::BlockRef;
28+
using ir::stmt::For;
29+
using ir::stmt::StmtRef;
2630

27-
struct ForInfoAnalyzer : public ir::IRMutator<Expr*> {
31+
struct ForHash {
32+
std::size_t operator()(const For& stmt) const {
33+
return std::hash<const Object*>()(stmt.get());
34+
}
35+
};
36+
37+
struct ForInfoAnalyzer {
2838
public:
29-
void operator()(ir::Expr* expr) { ir::IRMutator<>::Visit(expr, expr); }
39+
void operator()(const For& for_stmt) { Visit(for_stmt); }
3040

31-
ForTreeNode BuildTreeNode(const ir::For* node) {
41+
ForTreeNode BuildTreeNode(const For& node) {
3242
ForTreeNode tree_node = {node, std::vector<ForTreeNode>()};
33-
for (const auto for_node : for_to_children_[node]) {
34-
tree_node.children.push_back(BuildTreeNode(for_node));
43+
for (const For& stmt : for_to_children_[node]) {
44+
tree_node.children.push_back(BuildTreeNode(stmt));
3545
}
3646
return tree_node;
3747
}
3848

3949
ForTreeNode GetRootTreeNode() { return BuildTreeNode(root_node_); }
4050

4151
private:
42-
void Visit(const ir::For* node, ir::Expr* expr) override {
43-
auto old_last_node = last_node_;
44-
if (last_node_ == nullptr) {
52+
void Visit(const For& node) {
53+
if (root_node_ == nullptr) {
4554
root_node_ = node;
46-
} else {
47-
for_to_children_[last_node_].push_back(node);
4855
}
49-
last_node_ = const_cast<ir::For*>(node);
50-
ir::IRMutator<>::Visit(node, expr);
51-
last_node_ = old_last_node;
56+
const BlockRef& body = node->body();
57+
for (const StmtRef& stmt : body->stmts()) {
58+
if (stmt.isa<For>()) {
59+
for_to_children_[node].push_back(stmt.as<For>());
60+
Visit(stmt.as<For>());
61+
}
62+
}
5263
}
5364

54-
ir::For* last_node_ = nullptr;
55-
const ir::For* root_node_ = nullptr;
56-
std::unordered_map<const ir::For*, std::vector<const ir::For*>>
57-
for_to_children_;
65+
private:
66+
For root_node_{nullptr};
67+
std::unordered_map<For, std::vector<For>, ForHash> for_to_children_;
5868
};
5969

6070
} // namespace
6171

62-
bool CanMergeBlocks(const ir::For* first,
63-
const ir::For* second,
72+
bool CanMergeBlocks(const For first,
73+
const For second,
6474
const ForEqualFunc& IsEqual) {
65-
auto Get = [&](ir::Expr* expr) -> ForTreeNode {
75+
auto Get = [&](const For for_stmt) -> ForTreeNode {
6676
ForInfoAnalyzer for_info_analyzer;
67-
for_info_analyzer(expr);
77+
for_info_analyzer(for_stmt);
6878
return for_info_analyzer.GetRootTreeNode();
6979
};
70-
ir::Expr first_expr = Expr(const_cast<ir::For*>(first));
71-
ir::Expr second_expr = Expr(const_cast<ir::For*>(second));
72-
const auto first_inner_for_list = Get(&first_expr);
73-
const auto second_inner_for_list = Get(&second_expr);
80+
const auto first_inner_for_list = Get(first);
81+
const auto second_inner_for_list = Get(second);
7482
return IsEqual(first_inner_for_list, second_inner_for_list);
7583
}
7684

paddle/cinn/optim/merge_block_utils.h

Lines changed: 43 additions & 46 deletions
Original file line number | Diff line number | Diff line change
@@ -14,71 +14,68 @@
1414

1515
#pragma once
1616

17-
#include "paddle/cinn/ir/ir.h"
17+
#include "paddle/cinn/ir/stmt.h"
1818

1919
namespace cinn {
2020
namespace optim {
2121

2222
struct ForTreeNode {
23-
const ir::For* val;
23+
const ir::stmt::For val;
2424
std::vector<ForTreeNode> children;
2525
};
2626

2727
using ForEqualFunc =
2828
std::function<bool(const ForTreeNode&, const ForTreeNode&)>;
2929

30-
/**
30+
/*
3131
* Determines if two blocks of code with nested for-loops have identical loop
32-
extents and can be merged.
33-
32+
* extents and can be merged.
33+
*
3434
* This pass is applicable in scenarios where there are multiple code blocks
35-
with nested for-loops,
36-
* and we need to determine if these blocks can be consolidated to simplify the
37-
code structure.
38-
35+
* with nested for-loops, and we need to determine if these blocks can be
36+
* consolidated to simplify the code structure.
37+
*
3938
* When applied, this pass will not directly modify the IR but serves as a
40-
prerequisite check
41-
* to ensure that loop extents match. If they do, a separate merging process can
42-
be safely conducted
43-
* to combine the blocks into a single block with shared loop structures.
44-
39+
* prerequisite check to ensure that loop extents match. If they do, a separate
40+
* merging process can be safely conducted to combine the blocks into a single
41+
* block with shared loop structures.
42+
*
4543
* Performance impact: This pass itself does not directly impact performance but
46-
enables further
47-
* optimizations by identifying mergeable loop structures, which can reduce code
48-
size and potentially
49-
* improve cache efficiency by consolidating similar data processing tasks.
50-
51-
* Examples:
52-
* 1. Simple identical loops:
53-
* Input IR:
54-
* block(var_B)
55-
* for(i, 0, 10)
56-
* for(j, 0, 10)
57-
* B[i,j] = A[i,j]
44+
* enables further optimizations by identifying mergeable loop structures, which
45+
* can reduce code size and potentially improve cache efficiency by
46+
* consolidating similar data processing tasks.
5847
*
59-
* block(var_C)
60-
* for(i, 0, 10)
61-
* for(j, 0, 10)
62-
* C[i,j] = A[i,j]
63-
* Output IR:
64-
* Can be merged since loop extents are identical.
48+
* Examples:
6549
*
66-
* 2. Different loop extents:
67-
* Input IR:
68-
* block(var_B)
69-
* for(i, 0, 10)
70-
* for(j, 0, 10)
71-
* B[i,j] = A[i,j]
50+
* Simple identical loops:
51+
* Input IR:
52+
* block(var_B)
53+
* for(i, 0, 10)
54+
* for(j, 0, 10)
55+
* B[i,j] = A[i,j]
56+
* block(var_C)
57+
* for(i, 0, 10)
58+
* for(j, 0, 10)
59+
* C[i,j] = A[i,j]
60+
* Output IR:
61+
* Can be merged since loop extents are identical.
7262
*
73-
* block(var_C)
74-
* for(i, 0, 3)
75-
* for(j, 0, 4)
76-
* C[i,j] = A[i,j]
77-
* Output IR:
78-
* Cannot be merged due to differing loop extents.
63+
* Different loop extents:
64+
* Input IR:
65+
* block(var_B)
66+
* for(i, 0, 10)
67+
* for(j, 0, 10)
68+
* B[i,j] = A[i,j]
69+
* block(var_C)
70+
* for(i, 0, 3)
71+
* for(j, 0, 4)
72+
* C[i,j] = A[i,j]
73+
* Output IR:
74+
* Cannot be merged due to differing loop extents.
7975
*/
80-
bool CanMergeBlocks(const ir::For* first,
81-
const ir::For* second,
76+
77+
bool CanMergeBlocks(const ir::stmt::For first,
78+
const ir::stmt::For second,
8279
const ForEqualFunc& IsEqual);
8380

8481
} // namespace optim

test/cpp/pir/cinn/adt/merge_block_utils_test.cc

Lines changed: 40 additions & 43 deletions
Original file line number | Diff line number | Diff line change
@@ -27,8 +27,8 @@ namespace {
2727
bool IsBlockForAllEqual(const ForTreeNode& first, const ForTreeNode& second) {
2828
auto ForVarExtentEqual = [&](const ForTreeNode& first,
2929
const ForTreeNode& second) -> bool {
30-
const ir::Expr lhs = first.val->extent;
31-
const ir::Expr rhs = second.val->extent;
30+
const ir::Expr lhs = first.val->extent();
31+
const ir::Expr rhs = second.val->extent();
3232
if (cinn::common::AutoSimplify(ir::Sub::Make(lhs, rhs)) != ir::Expr(0)) {
3333
return false;
3434
}
@@ -46,74 +46,71 @@ bool IsBlockForAllEqual(const ForTreeNode& first, const ForTreeNode& second) {
4646
return true;
4747
}
4848

49-
ir::Expr MakeForLoops(const std::vector<int> extents, int index) {
50-
if (index >= extents.size()) {
51-
ir::Expr sb = ir::ScheduleBlock::Make(std::vector<Var>(),
52-
std::vector<Expr>(),
53-
std::vector<Expr>(),
54-
"block",
55-
ir::Expr(0));
56-
return sb;
49+
ir::stmt::For MakeForLoops(const std::vector<int> extents, int index) {
50+
ir::stmt::StmtRef body_stmt;
51+
if (index == extents.size() - 1) {
52+
body_stmt = ir::stmt::Schedule(std::vector<Var>(),
53+
std::vector<Expr>(),
54+
std::vector<Expr>(),
55+
std::vector<Expr>(),
56+
"block",
57+
ir::stmt::BlockRef(0));
58+
} else {
59+
body_stmt = MakeForLoops(extents, index + 1);
5760
}
5861

59-
ir::Expr extent = ir::Expr(extents.at(index));
60-
ir::Expr for_expr = ir::For::Make(ir::Var("i"),
61-
ir::Expr(0),
62-
extent,
63-
ir::ForType::Serial,
64-
ir::DeviceAPI::CUDA,
65-
MakeForLoops(extents, index + 1),
66-
ir::VectorizeInfo(),
67-
ir::BindInfo());
68-
69-
return for_expr;
62+
std::vector<ir::stmt::StmtRef> body = {body_stmt};
63+
return ir::stmt::For(ir::Var("i"),
64+
ir::Expr(0),
65+
ir::Expr(extents[index]),
66+
ir::ForType::Serial,
67+
ir::DeviceAPI::CUDA,
68+
ir::stmt::BlockRef(body),
69+
ir::VectorizeInfo(),
70+
ir::BindInfo());
7071
}
7172

7273
void TestHelper(const std::vector<int>& extents1,
7374
const std::vector<int>& extents2,
7475
bool is_same) {
7576
auto for_loop1 = MakeForLoops(extents1, 0);
7677
auto for_loop2 = MakeForLoops(extents2, 0);
77-
auto f1 = for_loop1.As<ir::For>();
78-
auto f2 = for_loop2.As<ir::For>();
7978

8079
if (is_same) {
81-
EXPECT_TRUE(CanMergeBlocks(f1, f2, IsBlockForAllEqual));
80+
EXPECT_TRUE(CanMergeBlocks(for_loop1, for_loop2, IsBlockForAllEqual));
8281
} else {
83-
EXPECT_FALSE(CanMergeBlocks(f1, f2, IsBlockForAllEqual));
82+
EXPECT_FALSE(CanMergeBlocks(for_loop1, for_loop2, IsBlockForAllEqual));
8483
}
8584
}
8685

8786
void TestHelper2(const std::vector<std::vector<int>>& extents1,
8887
const std::vector<std::vector<int>>& extents2,
8988
bool is_same) {
9089
auto MakeNestLoops =
91-
[&](const std::vector<std::vector<int>>& extents) -> ir::Expr {
92-
std::vector<ir::Expr> for_loops;
90+
[&](const std::vector<std::vector<int>>& extents) -> ir::stmt::For {
91+
std::vector<ir::stmt::StmtRef> for_loops;
9392
for (size_t i = 0; i < extents.size(); ++i) {
9493
for_loops.push_back(MakeForLoops(extents[i], 0));
9594
}
96-
ir::Expr block = ir::Block::Make(for_loops);
97-
ir::Expr for_expr = ir::For::Make(ir::Var("i"),
98-
ir::Expr(0),
99-
ir::Expr(1),
100-
ir::ForType::Serial,
101-
ir::DeviceAPI::CUDA,
102-
block,
103-
ir::VectorizeInfo(),
104-
ir::BindInfo());
105-
return for_expr;
95+
ir::stmt::BlockRef block(for_loops);
96+
ir::stmt::For for_stmt = ir::stmt::For(ir::Var("i"),
97+
ir::Expr(0),
98+
ir::Expr(1),
99+
ir::ForType::Serial,
100+
ir::DeviceAPI::CUDA,
101+
block,
102+
ir::VectorizeInfo(),
103+
ir::BindInfo());
104+
return for_stmt;
106105
};
107106

108-
auto for_expr1 = MakeNestLoops(extents1);
109-
auto for_expr2 = MakeNestLoops(extents2);
110-
auto f1 = for_expr1.As<ir::For>();
111-
auto f2 = for_expr2.As<ir::For>();
107+
auto for_stmt1 = MakeNestLoops(extents1);
108+
auto for_stmt2 = MakeNestLoops(extents2);
112109

113110
if (is_same) {
114-
EXPECT_TRUE(CanMergeBlocks(f1, f2, IsBlockForAllEqual));
111+
EXPECT_TRUE(CanMergeBlocks(for_stmt1, for_stmt2, IsBlockForAllEqual));
115112
} else {
116-
EXPECT_FALSE(CanMergeBlocks(f1, f2, IsBlockForAllEqual));
113+
EXPECT_FALSE(CanMergeBlocks(for_stmt1, for_stmt2, IsBlockForAllEqual));
117114
}
118115
}
119116

0 commit comments

Comments
 (0)