Skip to content

Commit 4b08cea

Browse files
cccxinlijunrushao
authored andcommitted
[RELAY] Fix bug in MergeCompilerRegions pass (apache#15211)
1 parent 4216478 commit 4b08cea

File tree

2 files changed

+97
-1
lines changed

2 files changed

+97
-1
lines changed

src/relay/transforms/merge_compiler_regions.cc

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,35 @@ class RegionMerger : public MixedModeVisitor {
5353
public:
5454
explicit RegionMerger(AnnotatedRegionSet regions) : regions_(regions) {}
5555

56+
void find_control_flow_regions(
57+
const Expr op,
58+
std::unordered_set<AnnotatedRegion, ObjectPtrHash, ObjectPtrEqual>& correlative_regions) {
59+
// Find correlative restriction regions from control flow.
60+
61+
// In IfNode, find from condition, true_branch and false branch.
62+
const IfNode* if_node = op.as<IfNode>();
63+
if (if_node) {
64+
auto cond_region = regions_->GetRegion(if_node->cond);
65+
auto true_branch_region = regions_->GetRegion(if_node->true_branch);
66+
auto false_branch_region = regions_->GetRegion(if_node->false_branch);
67+
if (cond_region.defined()) {
68+
correlative_regions.insert(cond_region);
69+
} else {
70+
find_control_flow_regions(if_node->cond, correlative_regions);
71+
}
72+
if (true_branch_region.defined()) {
73+
correlative_regions.insert(true_branch_region);
74+
} else {
75+
find_control_flow_regions(if_node->true_branch, correlative_regions);
76+
}
77+
if (false_branch_region.defined()) {
78+
correlative_regions.insert(false_branch_region);
79+
} else {
80+
find_control_flow_regions(if_node->false_branch, correlative_regions);
81+
}
82+
}
83+
}
84+
5685
void VisitExpr_(const CallNode* call) final {
5786
if (call->op == CompilerEndOp()) {
5887
auto region = regions_->GetRegion(GetRef<Call>(call));
@@ -84,18 +113,23 @@ class RegionMerger : public MixedModeVisitor {
84113

85114
// Collect unmerged parent regions.
86115
std::unordered_set<AnnotatedRegion, ObjectPtrHash, ObjectPtrEqual> mergeable_regions;
116+
// Collect correlative regions to propagate restrictions.
117+
std::unordered_set<AnnotatedRegion, ObjectPtrHash, ObjectPtrEqual> correlative_regions;
87118
for (const auto& arg : region->GetInputs()) {
88119
auto begin = Downcast<Call>(arg);
89120
ICHECK_EQ(begin->op, CompilerBeginOp());
90121
auto parent_region = regions_->GetRegion(begin->args[0]);
91122
if (parent_region.defined()) {
92123
mergeable_regions.insert(parent_region);
124+
correlative_regions.insert(parent_region);
125+
} else {
126+
find_control_flow_regions(begin->args[0], correlative_regions);
93127
}
94128
}
95129

96130
// Propogate all the parent restrictions to the current region.
97131
auto& region_restrictions = region_restrictions_[region->GetID()];
98-
for (const auto& parent_region : mergeable_regions) {
132+
for (const auto& parent_region : correlative_regions) {
99133
auto parent_restrictions = region_restrictions_[parent_region->GetID()];
100134
region_restrictions.insert(parent_restrictions.begin(), parent_restrictions.end());
101135
}

tests/python/relay/test_pass_merge_compiler_regions.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
"""Unit tests for merge compiler regions."""
1818
import tvm
1919
from tvm import relay
20+
import tvm.relay.transform as transform
2021
from tvm.relay.op.annotation import compiler_begin, compiler_end
2122
from tvm.relay.testing import run_opt_pass
2223

@@ -214,6 +215,67 @@ def expected():
214215
assert tvm.ir.structural_equal(mod, ref_mod)
215216

216217

218+
def test_if_else():
219+
"""
220+
This tests that the restriction regions propagate successful in
221+
if_else control flow.
222+
223+
O = supported by target
224+
X = not supported by target
225+
226+
227+
O1 - - - | O1 --|
228+
| | |
229+
X | X
230+
| | |
231+
If cond ? O1: X | --> + + If cond ? O1: X +
232+
| | |
233+
O2 <- - -| O2 <-|
234+
235+
236+
Avoid O1 merge to O2.
237+
"""
238+
239+
target = "test_if_else"
240+
241+
@tvm.ir.register_op_attr("sigmoid", "target." + target)
242+
def sigmoid(expr): # pylint: disable=unused-variable
243+
return True
244+
245+
@tvm.ir.register_op_attr("erf", "target." + target)
246+
def erf(expr): # pylint: disable=unused-variable
247+
return True
248+
249+
@tvm.ir.register_op_attr("add", "target." + target)
250+
def add(expr): # pylint: disable=unused-variable
251+
return True
252+
253+
"""Test that If-else nodes merges regions correctly."""
254+
255+
def get_mod():
256+
data = relay.var("data", shape=(1, 32))
257+
add0 = relay.add(data, data)
258+
sub0 = relay.subtract(add0, data)
259+
eq = relay.equal(relay.sum(add0), relay.sum(sub0))
260+
261+
true_branch = relay.sigmoid(add0)
262+
false_branch = relay.sigmoid(sub0)
263+
ife = relay.If(eq, true_branch, false_branch)
264+
erf = relay.erf(ife)
265+
out = relay.add(add0, erf)
266+
func = relay.Function([data], out)
267+
mod = tvm.IRModule.from_expr(func)
268+
269+
return mod
270+
271+
for annotate_non_call_ops in [True, False]:
272+
result = transform.AnnotateTarget(target, annotate_non_call_ops)(get_mod())
273+
merge = transform.MergeCompilerRegions()(result)
274+
# Ensure partition finished without segment fault.
275+
partition = transform.PartitionGraph()(merge)
276+
277+
217278
if __name__ == "__main__":
218279
test_diamond_graph_fanouts()
219280
test_example_graph()
281+
test_if_else()

0 commit comments

Comments
 (0)