Skip to content

Commit f8f0282

Browse files
authored
[SCHEDULE] Refactor bound inference logic (#41)
1 parent 5c07413 commit f8f0282

File tree

4 files changed

+205
-115
lines changed

4 files changed

+205
-115
lines changed

include/tvm/schedule_pass.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ namespace schedule {
2222
* \param sch The root schedule to infer all the bounds.
2323
* \return the result bound of the iteration Variable
2424
*/
25-
Map<IterVar, Range> InferBound(Schedule sch);
25+
Map<IterVar, Range> InferBound(const Schedule& sch);
2626

2727
/*!
2828
* \brief Schedule s' dependent operations.

src/arithmetic/int_set.cc

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -432,7 +432,6 @@ TVM_STATIC_IR_FUNCTOR(IntSetEvaluator, vtable)
432432
.set_dispatch<And>(Binary<And>)
433433
.set_dispatch<Or>(Binary<Or>);
434434

435-
436435
IntSet EvalSet(Expr e,
437436
const std::unordered_map<const Variable*, IntSet>& dom_map) {
438437
return IntSetEvaluator(dom_map).Eval(e);
@@ -444,17 +443,12 @@ IntSet EvalSet(Expr e,
444443
for (auto kv : dom_map) {
445444
dmap[kv.first->var.as<Variable>()] = kv.second;
446445
}
447-
IntSetEvaluator m(dmap);
448-
return m.Eval(e);
446+
return EvalSet(e, dmap);
449447
}
450448

451449
IntSet EvalSet(Range r,
452-
const Map<IterVar, IntSet>& dom_map) {
453-
std::unordered_map<const Variable*, IntSet> dmap;
454-
for (auto kv : dom_map) {
455-
dmap[kv.first->var.as<Variable>()] = kv.second;
456-
}
457-
IntSetEvaluator m(dmap);
450+
const std::unordered_map<const Variable*, IntSet>& dom_map) {
451+
IntSetEvaluator m(dom_map);
458452
IntSet min_set = m.Eval(r->min);
459453
IntSet ext_set = m.Eval(r->extent).cover_interval();
460454
const Interval& ei = ext_set.as<IntervalSet>()->i;
@@ -463,13 +457,21 @@ IntSet EvalSet(Range r,
463457
return Combine<Add>(min_set, ext_set);
464458
}
465459

460+
IntSet EvalSet(Range r,
461+
const Map<IterVar, IntSet>& dom_map) {
462+
std::unordered_map<const Variable*, IntSet> dmap;
463+
for (auto kv : dom_map) {
464+
dmap[kv.first->var.as<Variable>()] = kv.second;
465+
}
466+
return EvalSet(r, dmap);
467+
}
468+
466469
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
467470
.set_dispatch<IntervalSet>([](const IntervalSet *op, IRPrinter *p) {
468471
p->stream << "interval-set["
469472
<< "[" << op->i.min << ", "
470473
<< op->i.max << ']';
471474
});
472475

473-
474476
} // namespace arith
475477
} // namespace tvm

src/arithmetic/int_set.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,9 @@ IntSet EvalSet(Expr e,
103103
*/
104104
IntSet EvalSet(Range r,
105105
const Map<IterVar, IntSet>& dom_map);
106+
IntSet EvalSet(Range r,
107+
const std::unordered_map<const Variable*, IntSet>& dom_map);
108+
106109

107110
/*!
108111
* \brief Create an union set of all sets

0 commit comments

Comments
 (0)