@@ -53,6 +53,35 @@ class RegionMerger : public MixedModeVisitor {
5353 public:
// Keeps the given annotated region set in regions_ for later region lookups.
5454 explicit RegionMerger (AnnotatedRegionSet regions) : regions_(regions) {}
5555
56+ void find_control_flow_regions (
57+ const Expr op,
58+ std::unordered_set<AnnotatedRegion, ObjectPtrHash, ObjectPtrEqual>& correlative_regions) {
59+ // Find correlative restriction regions from control flow.
60+
61+ // In IfNode, find from condition, true_branch and false branch.
62+ const IfNode* if_node = op.as <IfNode>();
63+ if (if_node) {
64+ auto cond_region = regions_->GetRegion (if_node->cond );
65+ auto true_branch_region = regions_->GetRegion (if_node->true_branch );
66+ auto false_branch_region = regions_->GetRegion (if_node->false_branch );
67+ if (cond_region.defined ()) {
68+ correlative_regions.insert (cond_region);
69+ } else {
70+ find_control_flow_regions (if_node->cond , correlative_regions);
71+ }
72+ if (true_branch_region.defined ()) {
73+ correlative_regions.insert (true_branch_region);
74+ } else {
75+ find_control_flow_regions (if_node->true_branch , correlative_regions);
76+ }
77+ if (false_branch_region.defined ()) {
78+ correlative_regions.insert (false_branch_region);
79+ } else {
80+ find_control_flow_regions (if_node->false_branch , correlative_regions);
81+ }
82+ }
83+ }
84+
5685 void VisitExpr_ (const CallNode* call) final {
5786 if (call->op == CompilerEndOp ()) {
5887 auto region = regions_->GetRegion (GetRef<Call>(call));
@@ -84,18 +113,23 @@ class RegionMerger : public MixedModeVisitor {
84113
85114 // Collect unmerged parent regions.
86115 std::unordered_set<AnnotatedRegion, ObjectPtrHash, ObjectPtrEqual> mergeable_regions;
116+ // Collect correlative regions to propagate restrictions.
117+ std::unordered_set<AnnotatedRegion, ObjectPtrHash, ObjectPtrEqual> correlative_regions;
87118 for (const auto & arg : region->GetInputs ()) {
88119 auto begin = Downcast<Call>(arg);
89120 ICHECK_EQ (begin->op , CompilerBeginOp ());
90121 auto parent_region = regions_->GetRegion (begin->args [0 ]);
91122 if (parent_region.defined ()) {
92123 mergeable_regions.insert (parent_region);
124+ correlative_regions.insert (parent_region);
125+ } else {
126+ find_control_flow_regions (begin->args [0 ], correlative_regions);
93127 }
94128 }
95129
96130 // Propagate all the parent restrictions to the current region.
97131 auto & region_restrictions = region_restrictions_[region->GetID ()];
98- for (const auto & parent_region : mergeable_regions ) {
132+ for (const auto & parent_region : correlative_regions ) {
99133 auto parent_restrictions = region_restrictions_[parent_region->GetID ()];
100134 region_restrictions.insert (parent_restrictions.begin (), parent_restrictions.end ());
101135 }
0 commit comments