Skip to content

Commit 6503895

Browse files
authored
[ARITH] Improve detect linear equation (#529)
* [ARITH] Improve detect linear equation * fix doc
1 parent 4608222 commit 6503895

File tree

7 files changed

+236
-33
lines changed

7 files changed

+236
-33
lines changed

HalideIR

include/tvm/arithmetic.h

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ class IntSet : public NodeRef {
118118
* \brief Range of a linear integer function.
119119
* Use to do specify the possible index values.
120120
*
121-
* set = { base + coeff * x | x in Z }
121+
* set = { coeff * x + base | x in Z }
122122
*
123123
* When coeff != 0, it can also be written as
124124
* set = { n | n % coeff == base }
@@ -127,16 +127,17 @@ class IntSet : public NodeRef {
127127
* For example, if index = 0 + 4 x, then we know it can be divided by 4.
128128
*/
129129
struct ModularEntry {
130-
/*! \brief The base */
131-
int base{0};
132130
/*! \brief linear co-efficient */
133131
int coeff{1};
132+
/*! \brief The base */
133+
int base{0};
134134

135135
/*! \return entry represent everything */
136136
static ModularEntry everything() {
137137
// always safe to set 0 + x, so it can be everything.
138138
ModularEntry e;
139-
e.base = 0; e.coeff = 1;
139+
e.coeff = 1;
140+
e.base = 0;
140141
return e;
141142
}
142143
/*!
@@ -157,14 +158,25 @@ struct IntSetNode : public Node {
157158
TVM_DECLARE_BASE_NODE_INFO(IntSetNode, Node);
158159
};
159160

160-
161161
/*!
162-
* \brief Detect if e can be rewritten as e = base + var * coeff
162+
* \brief Detect if e can be rewritten as e = sum_{i=0}^n var[i] * coeff[i] + coeff[n]
163163
* Where coeff and base are invariant of var.
164164
*
165-
* \return [base, coeff] if it is possible, empty array if it is not.
165+
* \param e The expression to be detected.
166+
* \param vars List of variables to be used in detection.
167+
* \return [coeff[i]] if it is possible, empty array if it is not.
168+
*/
169+
Array<Expr> DetectLinearEquation(const Expr& e, const Array<Var>& vars);
170+
171+
/*!
172+
* \brief Detect if expression corresponds to clip bound of the vars
173+
*
174+
* \param e The expression to be detected.
175+
* \param vars List of variables to be used in detection.
176+
* \return concat([min_value[i], max_value[i]]), None is returned if there is no min or max value
177+
* return empty if the e does not match the pattern.
166178
*/
167-
Array<Expr> DetectLinearEquation(Expr e, Var var);
179+
Array<Expr> DetectClipBound(const Expr& e, const Array<Var>& vars);
168180

169181
/*!
170182
* \brief Find an symbolic integer set that contains all possible values of

src/api/api_arith.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,11 @@ TVM_REGISTER_API("arith.DetectLinearEquation")
3636
*ret = DetectLinearEquation(args[0], args[1]);
3737
});
3838

39+
TVM_REGISTER_API("arith.DetectClipBound")
40+
.set_body([](TVMArgs args, TVMRetValue *ret) {
41+
*ret = DetectClipBound(args[0], args[1]);
42+
});
43+
3944
TVM_REGISTER_API("arith.DeduceBound")
4045
.set_body([](TVMArgs args, TVMRetValue *ret) {
4146
*ret = DeduceBound(args[0], args[1],

src/arithmetic/detect_linear_equation.cc

Lines changed: 156 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -21,22 +21,27 @@ struct LinearEqEntry {
2121
Expr coeff;
2222
};
2323

24+
struct IntervalEntry {
25+
Expr min_value;
26+
Expr max_value;
27+
};
28+
2429
class LinearEqDetector
2530
: public ExprFunctor<LinearEqEntry(const Expr&, const Expr &)> {
2631
public:
2732
explicit LinearEqDetector(Var var)
2833
: var_(var) {}
2934

30-
Array<Expr> Detect(const Expr& e) {
31-
LinearEqEntry ret = VisitExpr(e, e);
32-
if (fail_) return Array<Expr>();
33-
if (!ret.base.defined()) {
34-
ret.base = make_zero(var_.type());
35+
bool Detect(const Expr& e, LinearEqEntry* ret) {
36+
*ret = VisitExpr(e, e);
37+
if (fail_) return false;
38+
if (!ret->base.defined()) {
39+
ret->base = make_zero(var_.type());
3540
}
36-
if (!ret.coeff.defined()) {
37-
ret.coeff = make_zero(var_.type());
41+
if (!ret->coeff.defined()) {
42+
ret->coeff = make_zero(var_.type());
3843
}
39-
return Array<Expr>{ret.base, ret.coeff};
44+
return true;
4045
}
4146

4247
LinearEqEntry VisitExpr_(const Add* op, const Expr& e) final {
@@ -48,6 +53,17 @@ class LinearEqDetector
4853
ret.coeff = AddCombine(a.coeff, b.coeff);
4954
return ret;
5055
}
56+
57+
LinearEqEntry VisitExpr_(const Sub* op, const Expr& e) final {
58+
if (fail_) return LinearEqEntry();
59+
LinearEqEntry a = VisitExpr(op->a, op->a);
60+
LinearEqEntry b = VisitExpr(op->b, op->b);
61+
LinearEqEntry ret;
62+
ret.base = SubCombine(a.base, b.base);
63+
ret.coeff = SubCombine(a.coeff, b.coeff);
64+
return ret;
65+
}
66+
5167
LinearEqEntry VisitExpr_(const Mul* op, const Expr& e) final {
5268
if (fail_) return LinearEqEntry();
5369
LinearEqEntry a = VisitExpr(op->a, op->a);
@@ -94,16 +110,146 @@ class LinearEqDetector
94110
if (!b.defined()) return a;
95111
return ComputeExpr<Add>(a, b);
96112
}
113+
Expr SubCombine(Expr a, Expr b) {
114+
if (!a.defined()) return -b;
115+
if (!b.defined()) return a;
116+
return ComputeExpr<Sub>(a, b);
117+
}
97118
Expr MulCombine(Expr a, Expr b) {
98119
if (!a.defined()) return a;
99120
if (!b.defined()) return b;
100121
return ComputeExpr<Mul>(a, b);
101122
}
102123
};
103124

104-
Array<Expr> DetectLinearEquation(Expr e, Var var) {
105-
return LinearEqDetector(var).Detect(e);
125+
Array<Expr> DetectLinearEquation(const Expr& e, const Array<Var>& vars) {
126+
CHECK_GE(vars.size(), 1U);
127+
Expr base = e;
128+
Array<Expr> coeff;
129+
130+
for (Var v : vars) {
131+
LinearEqEntry ret;
132+
if (!LinearEqDetector(v).Detect(base, &ret)) {
133+
return Array<Expr>();
134+
}
135+
coeff.push_back(ret.coeff);
136+
base = std::move(ret.base);
137+
}
138+
139+
std::unordered_set<const Variable*> vset;
140+
for (size_t i = vars.size(); i != 1; --i) {
141+
vset.insert(vars[i - 1].get());
142+
// The previous coeff contains the variable
143+
if (ExprUseVar(coeff[i - 2], vset)) {
144+
return Array<Expr>();
145+
}
146+
}
147+
coeff.push_back(base);
148+
return coeff;
106149
}
107150

151+
// Detect clip condition as min max value
152+
bool DetectClipBound(
153+
const Expr& cond,
154+
std::unordered_map<const Variable*, IntervalEntry>* bmap) {
155+
int flag = 0;
156+
Var var;
157+
auto fvisit = [&bmap, &flag, &var](const NodeRef& n) {
158+
if (const Variable* v = n.as<Variable>()) {
159+
if (bmap->count(v)) {
160+
if (flag == 0) {
161+
var = Var(n.node_);
162+
flag = 1;
163+
} else if (flag == 1) {
164+
if (!var.same_as(n)) {
165+
flag = -1;
166+
}
167+
}
168+
}
169+
}
170+
};
171+
PostOrderVisit(cond, fvisit);
172+
if (flag != 1) return false;
173+
// canonical form: exp >= 0
174+
Expr canonical;
175+
if (const LT* op = cond.as<LT>()) {
176+
if (!op->a.type().is_int()) return false;
177+
canonical = op->b - op->a - make_const(op->a.type(), 1);
178+
} else if (const LE* op = cond.as<LE>()) {
179+
if (!op->a.type().is_int()) return false;
180+
canonical = op->b - op->a;
181+
} else if (const GT* op = cond.as<GT>()) {
182+
if (!op->a.type().is_int()) return false;
183+
canonical = op->a - op->b - make_const(op->a.type(), 1);
184+
} else if (const GE* op = cond.as<GE>()) {
185+
if (!op->a.type().is_int()) return false;
186+
canonical = op->a - op->b;
187+
} else {
188+
return false;
189+
}
190+
LinearEqEntry ret;
191+
if (!LinearEqDetector(var).Detect(canonical, &ret)) return false;
192+
ret.coeff = Simplify(ret.coeff);
193+
IntervalEntry& p = (*bmap)[var.get()];
194+
if (is_one(ret.coeff)) {
195+
// var + shift >=0 -> var >= -shift
196+
if (p.min_value.defined()) {
197+
p.min_value = ir::Max::make(p.min_value, -ret.base);
198+
} else {
199+
p.min_value = -ret.base;
200+
}
201+
return true;
202+
}
203+
if (is_const(ret.coeff, -1)) {
204+
// -var + shift >=0 -> var <= shift
205+
if (p.max_value.defined()) {
206+
p.max_value = ir::Min::make(p.max_value, ret.base);
207+
} else {
208+
p.max_value = ret.base;
209+
}
210+
return true;
211+
}
212+
return false;
213+
}
214+
215+
216+
template<typename OP>
217+
void SplitCommExpr(const Expr& e, std::vector<Expr>* ret) {
218+
if (const OP* op = e.as<OP>()) {
219+
SplitCommExpr<OP>(op->a, ret);
220+
SplitCommExpr<OP>(op->b, ret);
221+
} else {
222+
ret->push_back(e);
223+
}
224+
}
225+
226+
// Detect the lower and upper bound from the expression.
227+
// e must be connected by and.
228+
Array<Expr> DetectClipBound(const Expr& e, const Array<Var>& vars) {
229+
std::vector<Expr> splits;
230+
SplitCommExpr<ir::And>(e, &splits);
231+
std::unordered_map<const Variable*, IntervalEntry> rmap;
232+
for (Var v : vars) {
233+
rmap[v.get()] = IntervalEntry();
234+
}
235+
for (Expr cond : splits) {
236+
if (!DetectClipBound(cond, &rmap)) return Array<Expr>();
237+
}
238+
Array<Expr> ret;
239+
for (Var v : vars) {
240+
IntervalEntry e = rmap[v.get()];
241+
if (e.min_value.defined()) {
242+
e.min_value = Simplify(e.min_value);
243+
}
244+
if (e.max_value.defined()) {
245+
e.max_value = Simplify(e.max_value);
246+
}
247+
ret.push_back(e.min_value);
248+
ret.push_back(e.max_value);
249+
}
250+
return ret;
251+
}
252+
253+
108254
} // namespace arith
109255
} // namespace tvm

src/pass/narrow_channel_access.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -175,10 +175,10 @@ class ChannelAccessRewriter : public IRMutator {
175175
r = Range::make_by_min_extent(
176176
ir::Simplify(r->min), ir::Simplify(r->extent));
177177
if (ExprUseVar(r->extent, var)) return body;
178-
Array<Expr> linear_eq = DetectLinearEquation(r->min, var);
178+
Array<Expr> linear_eq = DetectLinearEquation(r->min, {var});
179179
if (linear_eq.size() == 0) return body;
180-
Expr base = linear_eq[0];
181-
Expr coeff = linear_eq[1];
180+
Expr coeff = linear_eq[0];
181+
Expr base = linear_eq[1];
182182
if (!is_zero(base)) return body;
183183
Expr left = ir::Simplify(adv_op->value - coeff * for_op->extent);
184184
if (!can_prove(left >= 0)) return body;
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
import tvm
2+
3+
def test_basic():
4+
a = tvm.var("a")
5+
b = tvm.var("b")
6+
c = tvm.var("c")
7+
m = tvm.arith.DetectClipBound(tvm.all(a * 1 < b * 6,
8+
a - 1 > 0), [a])
9+
assert tvm.ir_pass.Simplify(m[1] - (b * 6 - 1)).value == 0
10+
assert m[0].value == 2
11+
m = tvm.arith.DetectClipBound(tvm.all(a * 1 < b * 6,
12+
a - 1 > 0), [a, b])
13+
assert len(m) == 0
14+
m = tvm.arith.DetectClipBound(tvm.all(a + 10 * c <= 20,
15+
b - 1 > 0), [a, b])
16+
assert tvm.ir_pass.Simplify(m[1] - (20 - 10 * c)).value == 0
17+
assert tvm.ir_pass.Simplify(m[2] - 2).value == 0
18+
19+
20+
if __name__ == "__main__":
21+
test_basic()

tests/python/unittest/test_arith_detect_linear_equation.py

Lines changed: 30 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,22 +3,41 @@
33
def test_basic():
44
a = tvm.var("a")
55
b = tvm.var("b")
6-
m = tvm.arith.DetectLinearEquation(a * 4 + b * 6 + 7, a)
7-
assert m[1].value == 4
8-
assert tvm.ir_pass.Simplify(m[0] - (b * 6 + 7)).value == 0
6+
m = tvm.arith.DetectLinearEquation(a * 4 + b * 6 + 7, [a])
7+
assert m[0].value == 4
8+
assert tvm.ir_pass.Simplify(m[1] - (b * 6 + 7)).value == 0
99

10-
m = tvm.arith.DetectLinearEquation(a * 4 * (a+1) + b * 6 + 7, a)
10+
m = tvm.arith.DetectLinearEquation(a * 4 * (a+1) + b * 6 + 7, [a])
1111
assert len(m) == 0
1212

13-
m = tvm.arith.DetectLinearEquation(a * 4 + (a+1) + b * 6 + 7, a)
14-
assert m[1].value == 5
15-
assert tvm.ir_pass.Simplify(m[0] - (b * 6 + 7 + 1)).value == 0
13+
m = tvm.arith.DetectLinearEquation(a * 4 + (a+1) + b * 6 + 7, [a])
14+
assert m[0].value == 5
15+
assert tvm.ir_pass.Simplify(m[1] - (b * 6 + 7 + 1)).value == 0
1616

17-
m = tvm.arith.DetectLinearEquation(a * b + 7, a)
18-
assert m[1] == b
17+
m = tvm.arith.DetectLinearEquation(a * b + 7, [a])
18+
assert m[0] == b
1919

20-
m = tvm.arith.DetectLinearEquation(b * 7, a)
21-
assert m[1].value == 0
20+
m = tvm.arith.DetectLinearEquation(b * 7, [a])
21+
assert m[0].value == 0
22+
23+
def test_multivariate():
24+
v = [tvm.var("v%d" % i) for i in range(4)]
25+
b = tvm.var("b")
26+
m = tvm.arith.DetectLinearEquation(v[0] * (b + 4) + v[0] + v[1] * 8, v)
27+
assert(tvm.ir_pass.Equal(tvm.ir_pass.Simplify(m[0]), b + 5))
28+
assert(m[1].value == 8)
29+
30+
m = tvm.arith.DetectLinearEquation(v[0] * (b + 4) + v[0] + v[1] * 8 * v[2], v)
31+
assert(len(m) == 0)
32+
33+
m = tvm.arith.DetectLinearEquation(v[0] * (b + 4) + v[0] + v[1] * 8 * v[1] + v[3], v)
34+
assert(len(m) == 0)
35+
36+
m = tvm.arith.DetectLinearEquation(((v[0] * b + v[1]) * 8 + v[2] + 1) * 2, v)
37+
assert(m[1].value == 16)
38+
assert(m[2].value == 2)
39+
assert(m[len(m)-1].value == 2)
2240

2341
if __name__ == "__main__":
2442
test_basic()
43+
test_multivariate()

0 commit comments

Comments
 (0)