@@ -21,22 +21,27 @@ struct LinearEqEntry {
2121 Expr coeff;
2222};
2323
24+ struct IntervalEntry {
25+ Expr min_value;
26+ Expr max_value;
27+ };
28+
2429class LinearEqDetector
2530 : public ExprFunctor<LinearEqEntry(const Expr&, const Expr &)> {
2631 public:
2732 explicit LinearEqDetector (Var var)
2833 : var_(var) {}
2934
30- Array<Expr> Detect (const Expr& e) {
31- LinearEqEntry ret = VisitExpr (e, e);
32- if (fail_) return Array<Expr>() ;
33- if (!ret. base .defined ()) {
34- ret. base = make_zero (var_.type ());
35+ bool Detect (const Expr& e, LinearEqEntry* ret ) {
36+ * ret = VisitExpr (e, e);
37+ if (fail_) return false ;
38+ if (!ret-> base .defined ()) {
39+ ret-> base = make_zero (var_.type ());
3540 }
36- if (!ret. coeff .defined ()) {
37- ret. coeff = make_zero (var_.type ());
41+ if (!ret-> coeff .defined ()) {
42+ ret-> coeff = make_zero (var_.type ());
3843 }
39- return Array<Expr>{ret. base , ret. coeff } ;
44+ return true ;
4045 }
4146
4247 LinearEqEntry VisitExpr_ (const Add* op, const Expr& e) final {
@@ -48,6 +53,17 @@ class LinearEqDetector
4853 ret.coeff = AddCombine (a.coeff , b.coeff );
4954 return ret;
5055 }
56+
57+ LinearEqEntry VisitExpr_ (const Sub* op, const Expr& e) final {
58+ if (fail_) return LinearEqEntry ();
59+ LinearEqEntry a = VisitExpr (op->a , op->a );
60+ LinearEqEntry b = VisitExpr (op->b , op->b );
61+ LinearEqEntry ret;
62+ ret.base = SubCombine (a.base , b.base );
63+ ret.coeff = SubCombine (a.coeff , b.coeff );
64+ return ret;
65+ }
66+
5167 LinearEqEntry VisitExpr_ (const Mul* op, const Expr& e) final {
5268 if (fail_) return LinearEqEntry ();
5369 LinearEqEntry a = VisitExpr (op->a , op->a );
@@ -94,16 +110,146 @@ class LinearEqDetector
94110 if (!b.defined ()) return a;
95111 return ComputeExpr<Add>(a, b);
96112 }
113+ Expr SubCombine (Expr a, Expr b) {
114+ if (!a.defined ()) return -b;
115+ if (!b.defined ()) return a;
116+ return ComputeExpr<Sub>(a, b);
117+ }
97118 Expr MulCombine (Expr a, Expr b) {
98119 if (!a.defined ()) return a;
99120 if (!b.defined ()) return b;
100121 return ComputeExpr<Mul>(a, b);
101122 }
102123};
103124
104- Array<Expr> DetectLinearEquation (Expr e, Var var) {
105- return LinearEqDetector (var).Detect (e);
125+ Array<Expr> DetectLinearEquation (const Expr& e, const Array<Var>& vars) {
126+ CHECK_GE (vars.size (), 1U );
127+ Expr base = e;
128+ Array<Expr> coeff;
129+
130+ for (Var v : vars) {
131+ LinearEqEntry ret;
132+ if (!LinearEqDetector (v).Detect (base, &ret)) {
133+ return Array<Expr>();
134+ }
135+ coeff.push_back (ret.coeff );
136+ base = std::move (ret.base );
137+ }
138+
139+ std::unordered_set<const Variable*> vset;
140+ for (size_t i = vars.size (); i != 1 ; --i) {
141+ vset.insert (vars[i - 1 ].get ());
142+ // The previous coeff contains the variable
143+ if (ExprUseVar (coeff[i - 2 ], vset)) {
144+ return Array<Expr>();
145+ }
146+ }
147+ coeff.push_back (base);
148+ return coeff;
106149}
107150
151+ // Detect clip condition as min max value
152+ bool DetectClipBound (
153+ const Expr& cond,
154+ std::unordered_map<const Variable*, IntervalEntry>* bmap) {
155+ int flag = 0 ;
156+ Var var;
157+ auto fvisit = [&bmap, &flag, &var](const NodeRef& n) {
158+ if (const Variable* v = n.as <Variable>()) {
159+ if (bmap->count (v)) {
160+ if (flag == 0 ) {
161+ var = Var (n.node_ );
162+ flag = 1 ;
163+ } else if (flag == 1 ) {
164+ if (!var.same_as (n)) {
165+ flag = -1 ;
166+ }
167+ }
168+ }
169+ }
170+ };
171+ PostOrderVisit (cond, fvisit);
172+ if (flag != 1 ) return false ;
173+ // canonical form: exp >= 0
174+ Expr canonical;
175+ if (const LT* op = cond.as <LT>()) {
176+ if (!op->a .type ().is_int ()) return false ;
177+ canonical = op->b - op->a - make_const (op->a .type (), 1 );
178+ } else if (const LE* op = cond.as <LE>()) {
179+ if (!op->a .type ().is_int ()) return false ;
180+ canonical = op->b - op->a ;
181+ } else if (const GT* op = cond.as <GT>()) {
182+ if (!op->a .type ().is_int ()) return false ;
183+ canonical = op->a - op->b - make_const (op->a .type (), 1 );
184+ } else if (const GE* op = cond.as <GE>()) {
185+ if (!op->a .type ().is_int ()) return false ;
186+ canonical = op->a - op->b ;
187+ } else {
188+ return false ;
189+ }
190+ LinearEqEntry ret;
191+ if (!LinearEqDetector (var).Detect (canonical, &ret)) return false ;
192+ ret.coeff = Simplify (ret.coeff );
193+ IntervalEntry& p = (*bmap)[var.get ()];
194+ if (is_one (ret.coeff )) {
195+ // var + shift >=0 -> var >= -shift
196+ if (p.min_value .defined ()) {
197+ p.min_value = ir::Max::make (p.min_value , -ret.base );
198+ } else {
199+ p.min_value = -ret.base ;
200+ }
201+ return true ;
202+ }
203+ if (is_const (ret.coeff , -1 )) {
204+ // -var + shift >=0 -> var <= shift
205+ if (p.max_value .defined ()) {
206+ p.max_value = ir::Min::make (p.max_value , ret.base );
207+ } else {
208+ p.max_value = ret.base ;
209+ }
210+ return true ;
211+ }
212+ return false ;
213+ }
214+
215+
216+ template <typename OP>
217+ void SplitCommExpr (const Expr& e, std::vector<Expr>* ret) {
218+ if (const OP* op = e.as <OP>()) {
219+ SplitCommExpr<OP>(op->a , ret);
220+ SplitCommExpr<OP>(op->b , ret);
221+ } else {
222+ ret->push_back (e);
223+ }
224+ }
225+
226+ // Detect the lower and upper bound from the expression.
227+ // e must be connected by and.
228+ Array<Expr> DetectClipBound (const Expr& e, const Array<Var>& vars) {
229+ std::vector<Expr> splits;
230+ SplitCommExpr<ir::And>(e, &splits);
231+ std::unordered_map<const Variable*, IntervalEntry> rmap;
232+ for (Var v : vars) {
233+ rmap[v.get ()] = IntervalEntry ();
234+ }
235+ for (Expr cond : splits) {
236+ if (!DetectClipBound (cond, &rmap)) return Array<Expr>();
237+ }
238+ Array<Expr> ret;
239+ for (Var v : vars) {
240+ IntervalEntry e = rmap[v.get ()];
241+ if (e.min_value .defined ()) {
242+ e.min_value = Simplify (e.min_value );
243+ }
244+ if (e.max_value .defined ()) {
245+ e.max_value = Simplify (e.max_value );
246+ }
247+ ret.push_back (e.min_value );
248+ ret.push_back (e.max_value );
249+ }
250+ return ret;
251+ }
252+
253+
108254} // namespace arith
109255} // namespace tvm
0 commit comments