Skip to content

Commit f2abd9f

Browse files
ajtullochtqchen
authored andcommitted
[TVM] Rewrite simplification rule to eliminate unnecessary conditionals. (#4076)
The current bounds checking infrastructure inserts checks like: ``` for (i, 0, bounds[n]) { if (likely(i < bounds[n]) { ... } } ``` into the TVM IR which is currently not removed by simplification infrastructure. This is a little unclean, as these are trivially true since for a loop var `i` with a given min and extent, we are guaranteed that `i >= min` and `i < min + extent`. Thus, we can insert these checks into the IR and use them to eliminate trivial bounds checks early on.
1 parent c12275e commit f2abd9f

File tree

6 files changed

+97
-1
lines changed

6 files changed

+97
-1
lines changed

include/tvm/arithmetic.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,8 @@ class RewriteSimplifier {
245245
const Expr& new_expr,
246246
bool override = false);
247247

248+
std::function<void()> EnterConstraint(const Expr& constraint);
249+
248250
private:
249251
friend class Analyzer;
250252
friend class ConstraintContext;

src/arithmetic/analyzer.cc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,8 +67,10 @@ void ConstraintContext::EnterWithScope() {
6767
// entering the scope.
6868
auto f0 = analyzer_->const_int_bound.EnterConstraint(constraint_);
6969
auto f1 = analyzer_->modular_set.EnterConstraint(constraint_);
70+
auto f2 = analyzer_->rewrite_simplify.EnterConstraint(constraint_);
7071
// recovery function.
71-
exit_ = [f0, f1]() {
72+
exit_ = [f0, f1, f2]() {
73+
if (f2 != nullptr) f2();
7274
if (f1 != nullptr) f1();
7375
if (f0 != nullptr) f0();
7476
};

src/arithmetic/rewrite_simplify.cc

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,17 @@ Mutate_(const Add* op, const Expr& self) {
220220
return ret;
221221
}
222222

223+
std::function<void()> RewriteSimplifier::Impl::EnterConstraint(const Expr& constraint) {
224+
size_t old_literal_size = literal_constraints_.size();
225+
literal_constraints_.push_back(constraint);
226+
size_t new_literal_size = literal_constraints_.size();
227+
auto frecover = [old_literal_size, new_literal_size, this]() {
228+
CHECK_EQ(literal_constraints_.size(), new_literal_size);
229+
literal_constraints_.resize(old_literal_size);
230+
};
231+
return frecover;
232+
}
233+
223234
Expr RewriteSimplifier::Impl::
224235
Mutate_(const Sub* op, const Expr& self) {
225236
Expr ret = IRMutator::Mutate_(op, self);
@@ -1705,6 +1716,14 @@ Mutate_(const Call* op, const Expr& self) {
17051716
return op->args[0] & op->args[1];
17061717
}
17071718
}
1719+
if (op->is_intrinsic(Call::likely)) {
1720+
for (const auto& constraint : literal_constraints_) {
1721+
// Cases such as for (i, 0, bound) {if (likely(iter_var < bound)) { .. } }
1722+
if (Equal(constraint, op->args[0])) {
1723+
return make_const(op->type, true);
1724+
}
1725+
}
1726+
}
17081727
return ret;
17091728
}
17101729

@@ -1761,6 +1780,10 @@ void RewriteSimplifier::Update(const Var& var,
17611780
impl_->Update(var, info, override);
17621781
}
17631782

1783+
std::function<void()> RewriteSimplifier::EnterConstraint(const Expr& constraint) {
1784+
return impl_->EnterConstraint(constraint);
1785+
}
1786+
17641787
RewriteSimplifier::RewriteSimplifier(Analyzer* parent)
17651788
: impl_(new Impl(parent)) {
17661789
}

src/arithmetic/rewrite_simplify.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
#include <tvm/expr_operator.h>
2929
#include <tvm/ir_mutator.h>
3030
#include <unordered_map>
31+
#include <vector>
3132
#include "const_fold.h"
3233
#include "pattern_match.h"
3334
#include "ir_mutator_with_analyzer.h"
@@ -74,6 +75,8 @@ class RewriteSimplifier::Impl : public IRMutatorWithAnalyzer {
7475
Expr Mutate_(const Cast* op, const Expr& self) override;
7576
Expr Mutate_(const Let* op, const Expr& self) override;
7677

78+
std::function<void()> EnterConstraint(const Expr& constraint);
79+
7780
protected:
7881
/*! \brief internal structure for comparison. */
7982
enum CompareResult {
@@ -89,6 +92,9 @@ class RewriteSimplifier::Impl : public IRMutatorWithAnalyzer {
8992
int recur_depth_{0};
9093
// internal variable map
9194
std::unordered_map<Var, Expr, ExprHash, ExprEqual> var_map_;
95+
96+
std::vector<Expr> literal_constraints_;
97+
9298
// maximum number of recursion allowed during a single pass.
9399
static const constexpr int kMaxRecurDepth = 5;
94100

src/arithmetic/stmt_simplify.cc

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,13 @@ class StmtSimplifier : public IRMutatorWithAnalyzer {
5151
return Mutate(stmt);
5252
}
5353

54+
Stmt Mutate_(const For* op, const Stmt& s) final {
55+
analyzer_->Bind(op->loop_var, Range::make_by_min_extent(op->min, op->extent));
56+
With<ConstraintContext> ctx1(analyzer_, op->loop_var >= op->min);
57+
With<ConstraintContext> ctx2(analyzer_, op->loop_var < op->min + op->extent);
58+
return IRMutator::Mutate_(op, s);
59+
}
60+
5461
Stmt Mutate_(const LetStmt* op, const Stmt& s) {
5562
Expr value = this->Mutate(op->value);
5663
if (!ir::HasSideEffect(value)) {

tests/python/unittest/test_arith_stmt_simplify.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,62 @@ def test_thread_extent_simplify():
4747
assert isinstance(body.body.body.body, tvm.stmt.Store)
4848

4949

50+
def test_basic_likely_elimination():
51+
n = tvm.var('n')
52+
X = tvm.placeholder(shape=(n,), name="x")
53+
W = tvm.placeholder(shape=(n + 1,), dtype="int32", name="w")
54+
55+
def f(i):
56+
start = W[i]
57+
extent = W[i+1] - W[i]
58+
rv = tvm.reduce_axis((0, extent))
59+
return tvm.sum(X[rv + start], axis=rv)
60+
Y = tvm.compute(X.shape, f, name="y")
61+
s = tvm.create_schedule([Y.op])
62+
stmt = tvm.lower(s, [X, W, Y], simple_mode=True)
63+
assert('if' not in str(stmt))
64+
65+
def test_complex_likely_elimination():
66+
def cumsum(X):
67+
"""
68+
Y[i] = sum(X[:i])
69+
"""
70+
(m, ) = X.shape
71+
s_state = tvm.placeholder((m + 1, ), dtype="int32", name="state")
72+
s_init = tvm.compute((1, ), lambda _: tvm.const(0, "int32"))
73+
s_update = tvm.compute((m + 1, ), lambda l: s_state[l - 1] + X[l - 1])
74+
return tvm.scan(s_init, s_update, s_state, inputs=[X], name="cumsum")
75+
76+
def sparse_lengths_sum(data, indices, lengths):
77+
oshape = list(data.shape)
78+
oshape[0] = lengths.shape[0]
79+
length_offsets = cumsum(lengths)
80+
81+
def sls(n, d):
82+
gg = tvm.reduce_axis((0, lengths[n]))
83+
indices_idx = length_offsets[n] + gg
84+
data_idx = indices[indices_idx]
85+
data_val = data[data_idx, d]
86+
return tvm.sum(data_val, axis=gg)
87+
88+
return tvm.compute(oshape, sls)
89+
90+
m, n, d, i, l = tvm.var('m'), tvm.var('n'), tvm.var('d'), tvm.var('i'), tvm.var('l')
91+
data_ph = tvm.placeholder((m, d * 32), name="data")
92+
indices_ph = tvm.placeholder((i,), name="indices", dtype="int32")
93+
lengths_ph = tvm.placeholder((n,), name="lengths", dtype="int32")
94+
Y = sparse_lengths_sum(data_ph, indices_ph, lengths_ph)
95+
s = tvm.create_schedule([Y.op])
96+
(n, d) = s[Y].op.axis
97+
(do, di) = s[Y].split(d, factor=32)
98+
(gg,) = s[Y].op.reduce_axis
99+
s[Y].reorder(n, do, gg, di)
100+
s[Y].vectorize(di)
101+
stmt = tvm.lower(s, [data_ph, indices_ph, lengths_ph, Y], simple_mode=True)
102+
assert('if' not in str(stmt))
103+
50104
if __name__ == "__main__":
51105
test_stmt_simplify()
52106
test_thread_extent_simplify()
107+
test_basic_likely_elimination()
108+
test_complex_likely_elimination()

0 commit comments

Comments
 (0)