@@ -98,7 +98,13 @@ class CandidateSelector final : public StmtExprVisitor {
9898 void VisitStmt_ (const ForNode* op) final {
9999 // partition const loop when sets partition_const_loop_
100100 if (!is_const_int (op->min ) || !is_const_int (op->extent ) || partition_const_loop_) {
101+ // always treat var with hint to be partitioned
101102 const VarNode* var = op->loop_var .get ();
103+ if (partition_hint_vars.count (var)) {
104+ candidates.insert (GetRef<Stmt>(op));
105+ StmtExprVisitor::VisitStmt_ (op);
106+ return ;
107+ }
102108 record_.insert ({var, false });
103109 StmtExprVisitor::VisitStmt_ (op);
104110 if (record_.at (var) && !no_split_) {
@@ -117,6 +123,12 @@ class CandidateSelector final : public StmtExprVisitor {
117123 Var var = iv->var ;
118124 runtime::ThreadScope scope = runtime::ThreadScope::Create (iv->thread_tag );
119125 if ((scope.rank == 0 ) && (!is_const_int (op->value ) || partition_const_loop_)) {
126+ // always treat var with hint to be partitioned
127+ if (partition_hint_vars.count (var.get ())) {
128+ candidates.insert (GetRef<Stmt>(op));
129+ StmtExprVisitor::VisitStmt_ (op);
130+ return ;
131+ }
120132 record_.insert ({var.get (), false });
121133 StmtExprVisitor::VisitStmt_ (op);
122134 if (record_.at (var.get ()) && !no_split_) {
@@ -125,6 +137,15 @@ class CandidateSelector final : public StmtExprVisitor {
125137 record_.erase (var.get ());
126138 return ;
127139 }
140+ } else if (op->attr_key == attr::pragma_loop_partition_hint) {
141+ const VarNode* var = nullptr ;
142+ if (op->node ->IsInstance <VarNode>()) {
143+ var = op->node .as <VarNode>();
144+ } else if (op->node ->IsInstance <IterVarNode>()) {
145+ var = op->node .as <IterVarNode>()->var .get ();
146+ }
147+ ICHECK (var);
148+ partition_hint_vars.insert (var);
128149 }
129150 StmtExprVisitor::VisitStmt_ (op);
130151 }
@@ -162,6 +183,7 @@ class CandidateSelector final : public StmtExprVisitor {
162183 }
163184
164185 std::unordered_set<Stmt, ObjectPtrHash, ObjectPtrEqual> candidates;
186+ std::unordered_set<const VarNode*> partition_hint_vars;
165187
166188 private:
167189 bool in_likely_{false };
@@ -170,15 +192,28 @@ class CandidateSelector final : public StmtExprVisitor {
170192 std::unordered_map<const VarNode*, VarIsUsed> record_;
171193};
172194
// Generates a PartitionFinder visitor override for the comparison node type OpNodeT.
// When the current loop var carries a partition hint, the finder tries its best to
// find partitions: the whole comparison is handed to DeduceCondition and its operands
// are not traversed any further. Without a hint the default StmtExprVisitor traversal
// is used.
#define DEFINE_PARTITION_FINDER_VISIT_CMP_OP(OpNodeT) \
  void VisitExpr_(const OpNodeT* op) final {          \
    if (!has_partition_hint_) {                       \
      StmtExprVisitor::VisitExpr_(op);                \
      return;                                         \
    }                                                 \
    DeduceCondition(GetRef<PrimExpr>(op));            \
  }
204+
173205// Populate partitions data structure, i.e., for a specific variable,
174- // find an interval in which each condition
175- // (currently, "likely" conditions) has fixed true or false value
206+ // find an interval in which each condition has fixed true or false value
176207class PartitionFinder : public StmtExprVisitor {
177208 public:
178209 explicit PartitionFinder (Var current_var,
179210 const std::unordered_map<const VarNode*, IntSet>& hint_map,
180- const std::unordered_map<const VarNode*, IntSet>& relax_map)
181- : current_var_(current_var), hint_map_(hint_map), relax_map_(relax_map) {
211+ const std::unordered_map<const VarNode*, IntSet>& relax_map,
212+ bool has_partition_hint)
213+ : current_var_(current_var),
214+ has_partition_hint_(has_partition_hint),
215+ hint_map_(hint_map),
216+ relax_map_(relax_map) {
182217 for (const auto & kv : hint_map) {
183218 out_vars_.insert (kv.first );
184219 }
@@ -218,33 +253,43 @@ class PartitionFinder : public StmtExprVisitor {
218253
219254 void VisitExpr_ (const CallNode* op) final {
220255 if (op->op .same_as (builtin::likely ())) {
221- PrimExpr cond = op->args [0 ];
222- if (UsesVar (cond, [this ](const VarNode* var) { return var == current_var_.get (); })) {
223- // For cond, find out the interval, if exists, in which we can prove that cond is
224- // true. Also find the interval, if exists, in which we can prove that cond is
225- // false.
226- IntSet interval = DeduceBound (current_var_, cond, hint_map_, relax_map_);
227- if (!interval.IsNothing ()) {
228- // cond is true within interval
229- partitions[{cond, true }] = interval;
230- }
231- PrimExpr inverse_cond = InverseCond (cond);
232- if (inverse_cond.defined ()) {
233- IntSet interval = DeduceBound (current_var_, inverse_cond, hint_map_, relax_map_);
234- if (!interval.IsNothing ()) {
235- // cond is false within interval
236- partitions[{cond, false }] = interval;
237- }
238- }
239- }
256+ DeduceCondition (op->args [0 ]);
240257 } else {
241258 StmtExprVisitor::VisitExpr_ (op);
242259 }
243260 }
244261
262+ DEFINE_PARTITION_FINDER_VISIT_CMP_OP (GENode);
263+ DEFINE_PARTITION_FINDER_VISIT_CMP_OP (GTNode);
264+ DEFINE_PARTITION_FINDER_VISIT_CMP_OP (LENode);
265+ DEFINE_PARTITION_FINDER_VISIT_CMP_OP (LTNode);
266+ DEFINE_PARTITION_FINDER_VISIT_CMP_OP (EQNode);
267+ DEFINE_PARTITION_FINDER_VISIT_CMP_OP (NENode);
268+
245269 Partition partitions;
246270
247271 private:
272+ void DeduceCondition (const PrimExpr& cond) {
273+ // For cond, find out the interval, if exists, in which we can prove that cond is
274+ // true. Also find the interval, if exists, in which we can prove that cond is
275+ // false.
276+ if (UsesVar (cond, [this ](const VarNode* var) { return var == current_var_.get (); })) {
277+ IntSet interval = DeduceBound (current_var_, cond, hint_map_, relax_map_);
278+ if (!interval.IsNothing ()) {
279+ // cond is true within interval
280+ partitions[{cond, true }] = interval;
281+ }
282+ PrimExpr inverse_cond = InverseCond (cond);
283+ if (inverse_cond.defined ()) {
284+ IntSet interval = DeduceBound (current_var_, inverse_cond, hint_map_, relax_map_);
285+ if (!interval.IsNothing ()) {
286+ // cond is false within interval
287+ partitions[{cond, false }] = interval;
288+ }
289+ }
290+ }
291+ }
292+
248293 PrimExpr InverseCond (const PrimExpr& cond) {
249294 PrimExpr inverse_cond;
250295 if (const LTNode* op = cond.as <LTNode>()) {
@@ -270,6 +315,7 @@ class PartitionFinder : public StmtExprVisitor {
270315 }
271316
272317 Var current_var_;
318+ bool has_partition_hint_;
273319 std::unordered_set<const VarNode*> out_vars_;
274320 std::unordered_map<const VarNode*, IntSet> hint_map_;
275321 std::unordered_map<const VarNode*, IntSet> relax_map_;
@@ -472,7 +518,8 @@ Stmt LoopPartitioner::TryPartition(const Stmt& stmt, Var var, PrimExpr min, Prim
472518 // include hint of var.
473519 hint_map_.insert ({var.get (), IntSet::Interval (min, max)});
474520
475- PartitionFinder finder (var, hint_map_, relax_map_);
521+ bool has_partition_hint_ = selector.partition_hint_vars .count (var.get ());
522+ PartitionFinder finder (var, hint_map_, relax_map_, has_partition_hint_);
476523 finder (body);
477524
478525 hint_map_.erase (var.get ());
@@ -601,7 +648,7 @@ inline Stmt LoopPartitioner::MakeFor(const Object* node, PrimExpr extent, Stmt b
601648 }
602649}
603650
604- class RemoveLikelyTags : public StmtExprMutator {
651+ class RemoveLikelyTagsAndHints : public StmtExprMutator {
605652 public:
606653 PrimExpr VisitExpr_ (const CallNode* op) final {
607654 if (op->op .same_as (builtin::likely ())) {
@@ -611,12 +658,19 @@ class RemoveLikelyTags : public StmtExprMutator {
611658 return StmtExprMutator::VisitExpr_ (op);
612659 }
613660 }
661+
662+ Stmt VisitStmt_ (const AttrStmtNode* op) final {
663+ if (op->attr_key == attr::pragma_loop_partition_hint) {
664+ return VisitStmt (op->body );
665+ }
666+ return StmtExprMutator::VisitStmt_ (op);
667+ }
614668};
615669
616670Stmt LoopPartition (Stmt stmt, bool partition_const_loop, bool no_unroll_loop_with_extent_one) {
617671 stmt = LoopPartitioner (partition_const_loop, no_unroll_loop_with_extent_one)
618672 .VisitAndMutate (std::move (stmt));
619- stmt = RemoveLikelyTags ()(std::move (stmt));
673+ stmt = RemoveLikelyTagsAndHints ()(std::move (stmt));
620674 return stmt;
621675}
622676
0 commit comments