Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion HalideIR
Submodule HalideIR updated 2 files
+2 −3 Makefile
+5 −1 src/ir/IRPrinter.cpp
28 changes: 20 additions & 8 deletions include/tvm/arithmetic.h
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ class IntSet : public NodeRef {
* \brief Range of a linear integer function.
* Use to do specify the possible index values.
*
* set = { base + coeff * x | x in Z }
* set = { coeff * x + base | x in Z }
*
* When coeff != 0, it can also be written as
* set = { n | n % coeff == base }
Expand All @@ -127,16 +127,17 @@ class IntSet : public NodeRef {
* For example, if index = 0 + 4 x, then we know it can be divided by 4.
*/
struct ModularEntry {
/*! \brief The base */
int base{0};
/*! \brief linear co-efficient */
int coeff{1};
/*! \brief The base */
int base{0};

/*! \return entry represent everything */
static ModularEntry everything() {
// always safe to set 0 + x, so it can be everything.
ModularEntry e;
e.base = 0; e.coeff = 1;
e.coeff = 1;
e.base = 0;
return e;
}
/*!
Expand All @@ -157,14 +158,25 @@ struct IntSetNode : public Node {
TVM_DECLARE_BASE_NODE_INFO(IntSetNode, Node);
};


/*!
* \brief Detect if e can be rewritten as e = base + var * coeff
* \brief Detect if e can be rewritten as e = sum_{i=0}^n var[i] * coeff[i] + coeff[n]
* Where coeff and base are invariant of var.
*
* \return [base, coeff] if it is possible, empty array if it is not.
* \param e The expression to be detected.
* \param vars List of variables to be used in detection.
* \return [coeff[i]] if it is possible, empty array if it is not.
*/
Array<Expr> DetectLinearEquation(const Expr& e, const Array<Var>& vars);

/*!
* \brief Detect if expression corresponds to clip bound of the vars
*
* \param e The expression to be detected.
* \param vars List of variables to be used in detection.
* \return concat([min_value[i], max_value[i]]), None is returned if there is no min or max value
* return empty if the e does not match the pattern.
*/
Array<Expr> DetectLinearEquation(Expr e, Var var);
Array<Expr> DetectClipBound(const Expr& e, const Array<Var>& vars);

/*!
* \brief Find an symbolic integer set that contains all possible values of
Expand Down
5 changes: 5 additions & 0 deletions src/api/api_arith.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,11 @@ TVM_REGISTER_API("arith.DetectLinearEquation")
*ret = DetectLinearEquation(args[0], args[1]);
});

TVM_REGISTER_API("arith.DetectClipBound")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = DetectClipBound(args[0], args[1]);
});

TVM_REGISTER_API("arith.DeduceBound")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = DeduceBound(args[0], args[1],
Expand Down
166 changes: 156 additions & 10 deletions src/arithmetic/detect_linear_equation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,22 +21,27 @@ struct LinearEqEntry {
Expr coeff;
};

struct IntervalEntry {
Expr min_value;
Expr max_value;
};

class LinearEqDetector
: public ExprFunctor<LinearEqEntry(const Expr&, const Expr &)> {
public:
explicit LinearEqDetector(Var var)
: var_(var) {}

Array<Expr> Detect(const Expr& e) {
LinearEqEntry ret = VisitExpr(e, e);
if (fail_) return Array<Expr>();
if (!ret.base.defined()) {
ret.base = make_zero(var_.type());
bool Detect(const Expr& e, LinearEqEntry* ret) {
*ret = VisitExpr(e, e);
if (fail_) return false;
if (!ret->base.defined()) {
ret->base = make_zero(var_.type());
}
if (!ret.coeff.defined()) {
ret.coeff = make_zero(var_.type());
if (!ret->coeff.defined()) {
ret->coeff = make_zero(var_.type());
}
return Array<Expr>{ret.base, ret.coeff};
return true;
}

LinearEqEntry VisitExpr_(const Add* op, const Expr& e) final {
Expand All @@ -48,6 +53,17 @@ class LinearEqDetector
ret.coeff = AddCombine(a.coeff, b.coeff);
return ret;
}

LinearEqEntry VisitExpr_(const Sub* op, const Expr& e) final {
if (fail_) return LinearEqEntry();
LinearEqEntry a = VisitExpr(op->a, op->a);
LinearEqEntry b = VisitExpr(op->b, op->b);
LinearEqEntry ret;
ret.base = SubCombine(a.base, b.base);
ret.coeff = SubCombine(a.coeff, b.coeff);
return ret;
}

LinearEqEntry VisitExpr_(const Mul* op, const Expr& e) final {
if (fail_) return LinearEqEntry();
LinearEqEntry a = VisitExpr(op->a, op->a);
Expand Down Expand Up @@ -94,16 +110,146 @@ class LinearEqDetector
if (!b.defined()) return a;
return ComputeExpr<Add>(a, b);
}
Expr SubCombine(Expr a, Expr b) {
if (!a.defined()) return -b;
if (!b.defined()) return a;
return ComputeExpr<Sub>(a, b);
}
Expr MulCombine(Expr a, Expr b) {
if (!a.defined()) return a;
if (!b.defined()) return b;
return ComputeExpr<Mul>(a, b);
}
};

Array<Expr> DetectLinearEquation(Expr e, Var var) {
return LinearEqDetector(var).Detect(e);
Array<Expr> DetectLinearEquation(const Expr& e, const Array<Var>& vars) {
CHECK_GE(vars.size(), 1U);
Expr base = e;
Array<Expr> coeff;

for (Var v : vars) {
LinearEqEntry ret;
if (!LinearEqDetector(v).Detect(base, &ret)) {
return Array<Expr>();
}
coeff.push_back(ret.coeff);
base = std::move(ret.base);
}

std::unordered_set<const Variable*> vset;
for (size_t i = vars.size(); i != 1; --i) {
vset.insert(vars[i - 1].get());
// The previous coeff contains the variable
if (ExprUseVar(coeff[i - 2], vset)) {
return Array<Expr>();
}
}
coeff.push_back(base);
return coeff;
}

// Detect clip condition as min max value
bool DetectClipBound(
const Expr& cond,
std::unordered_map<const Variable*, IntervalEntry>* bmap) {
int flag = 0;
Var var;
auto fvisit = [&bmap, &flag, &var](const NodeRef& n) {
if (const Variable* v = n.as<Variable>()) {
if (bmap->count(v)) {
if (flag == 0) {
var = Var(n.node_);
flag = 1;
} else if (flag == 1) {
if (!var.same_as(n)) {
flag = -1;
}
}
}
}
};
PostOrderVisit(cond, fvisit);
if (flag != 1) return false;
// canonical form: exp >= 0
Expr canonical;
if (const LT* op = cond.as<LT>()) {
if (!op->a.type().is_int()) return false;
canonical = op->b - op->a - make_const(op->a.type(), 1);
} else if (const LE* op = cond.as<LE>()) {
if (!op->a.type().is_int()) return false;
canonical = op->b - op->a;
} else if (const GT* op = cond.as<GT>()) {
if (!op->a.type().is_int()) return false;
canonical = op->a - op->b - make_const(op->a.type(), 1);
} else if (const GE* op = cond.as<GE>()) {
if (!op->a.type().is_int()) return false;
canonical = op->a - op->b;
} else {
return false;
}
LinearEqEntry ret;
if (!LinearEqDetector(var).Detect(canonical, &ret)) return false;
ret.coeff = Simplify(ret.coeff);
IntervalEntry& p = (*bmap)[var.get()];
if (is_one(ret.coeff)) {
// var + shift >=0 -> var >= -shift
if (p.min_value.defined()) {
p.min_value = ir::Max::make(p.min_value, -ret.base);
} else {
p.min_value = -ret.base;
}
return true;
}
if (is_const(ret.coeff, -1)) {
// -var + shift >=0 -> var <= shift
if (p.max_value.defined()) {
p.max_value = ir::Min::make(p.max_value, ret.base);
} else {
p.max_value = ret.base;
}
return true;
}
return false;
}


template<typename OP>
void SplitCommExpr(const Expr& e, std::vector<Expr>* ret) {
if (const OP* op = e.as<OP>()) {
SplitCommExpr<OP>(op->a, ret);
SplitCommExpr<OP>(op->b, ret);
} else {
ret->push_back(e);
}
}

// Detect the lower and upper bound from the expression.
// e must be connected by and.
Array<Expr> DetectClipBound(const Expr& e, const Array<Var>& vars) {
std::vector<Expr> splits;
SplitCommExpr<ir::And>(e, &splits);
std::unordered_map<const Variable*, IntervalEntry> rmap;
for (Var v : vars) {
rmap[v.get()] = IntervalEntry();
}
for (Expr cond : splits) {
if (!DetectClipBound(cond, &rmap)) return Array<Expr>();
}
Array<Expr> ret;
for (Var v : vars) {
IntervalEntry e = rmap[v.get()];
if (e.min_value.defined()) {
e.min_value = Simplify(e.min_value);
}
if (e.max_value.defined()) {
e.max_value = Simplify(e.max_value);
}
ret.push_back(e.min_value);
ret.push_back(e.max_value);
}
return ret;
}


} // namespace arith
} // namespace tvm
6 changes: 3 additions & 3 deletions src/pass/narrow_channel_access.cc
Original file line number Diff line number Diff line change
Expand Up @@ -175,10 +175,10 @@ class ChannelAccessRewriter : public IRMutator {
r = Range::make_by_min_extent(
ir::Simplify(r->min), ir::Simplify(r->extent));
if (ExprUseVar(r->extent, var)) return body;
Array<Expr> linear_eq = DetectLinearEquation(r->min, var);
Array<Expr> linear_eq = DetectLinearEquation(r->min, {var});
if (linear_eq.size() == 0) return body;
Expr base = linear_eq[0];
Expr coeff = linear_eq[1];
Expr coeff = linear_eq[0];
Expr base = linear_eq[1];
if (!is_zero(base)) return body;
Expr left = ir::Simplify(adv_op->value - coeff * for_op->extent);
if (!can_prove(left >= 0)) return body;
Expand Down
21 changes: 21 additions & 0 deletions tests/python/unittest/test_arith_detect_clip_bound.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import tvm

def test_basic():
a = tvm.var("a")
b = tvm.var("b")
c = tvm.var("c")
m = tvm.arith.DetectClipBound(tvm.all(a * 1 < b * 6,
a - 1 > 0), [a])
assert tvm.ir_pass.Simplify(m[1] - (b * 6 - 1)).value == 0
assert m[0].value == 2
m = tvm.arith.DetectClipBound(tvm.all(a * 1 < b * 6,
a - 1 > 0), [a, b])
assert len(m) == 0
m = tvm.arith.DetectClipBound(tvm.all(a + 10 * c <= 20,
b - 1 > 0), [a, b])
assert tvm.ir_pass.Simplify(m[1] - (20 - 10 * c)).value == 0
assert tvm.ir_pass.Simplify(m[2] - 2).value == 0


if __name__ == "__main__":
test_basic()
41 changes: 30 additions & 11 deletions tests/python/unittest/test_arith_detect_linear_equation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,41 @@
def test_basic():
a = tvm.var("a")
b = tvm.var("b")
m = tvm.arith.DetectLinearEquation(a * 4 + b * 6 + 7, a)
assert m[1].value == 4
assert tvm.ir_pass.Simplify(m[0] - (b * 6 + 7)).value == 0
m = tvm.arith.DetectLinearEquation(a * 4 + b * 6 + 7, [a])
assert m[0].value == 4
assert tvm.ir_pass.Simplify(m[1] - (b * 6 + 7)).value == 0

m = tvm.arith.DetectLinearEquation(a * 4 * (a+1) + b * 6 + 7, a)
m = tvm.arith.DetectLinearEquation(a * 4 * (a+1) + b * 6 + 7, [a])
assert len(m) == 0

m = tvm.arith.DetectLinearEquation(a * 4 + (a+1) + b * 6 + 7, a)
assert m[1].value == 5
assert tvm.ir_pass.Simplify(m[0] - (b * 6 + 7 + 1)).value == 0
m = tvm.arith.DetectLinearEquation(a * 4 + (a+1) + b * 6 + 7, [a])
assert m[0].value == 5
assert tvm.ir_pass.Simplify(m[1] - (b * 6 + 7 + 1)).value == 0

m = tvm.arith.DetectLinearEquation(a * b + 7, a)
assert m[1] == b
m = tvm.arith.DetectLinearEquation(a * b + 7, [a])
assert m[0] == b

m = tvm.arith.DetectLinearEquation(b * 7, a)
assert m[1].value == 0
m = tvm.arith.DetectLinearEquation(b * 7, [a])
assert m[0].value == 0

def test_multivariate():
v = [tvm.var("v%d" % i) for i in range(4)]
b = tvm.var("b")
m = tvm.arith.DetectLinearEquation(v[0] * (b + 4) + v[0] + v[1] * 8, v)
assert(tvm.ir_pass.Equal(tvm.ir_pass.Simplify(m[0]), b + 5))
assert(m[1].value == 8)

m = tvm.arith.DetectLinearEquation(v[0] * (b + 4) + v[0] + v[1] * 8 * v[2], v)
assert(len(m) == 0)

m = tvm.arith.DetectLinearEquation(v[0] * (b + 4) + v[0] + v[1] * 8 * v[1] + v[3], v)
assert(len(m) == 0)

m = tvm.arith.DetectLinearEquation(((v[0] * b + v[1]) * 8 + v[2] + 1) * 2, v)
assert(m[1].value == 16)
assert(m[2].value == 2)
assert(m[len(m)-1].value == 2)

if __name__ == "__main__":
test_basic()
test_multivariate()