Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions include/tvm/arithmetic.h
Original file line number Diff line number Diff line change
Expand Up @@ -635,7 +635,8 @@ IntSet Intersect(const Array<IntSet>& sets);
*/
IntSet DeduceBound(Expr v, Expr cond,
const Map<Var, IntSet>& hint_map,
const Map<Var, IntSet>& relax_map);
const Map<Var, IntSet>& relax_map,
const bool return_lower_bound=false);
/*!
* \brief Same as DeduceBound with unordered_map signature.
*
Expand All @@ -648,7 +649,8 @@ IntSet DeduceBound(Expr v, Expr cond,
*/
IntSet DeduceBound(Expr v, Expr cond,
const std::unordered_map<const Variable*, IntSet>& hint_map,
const std::unordered_map<const Variable*, IntSet>& relax_map);
const std::unordered_map<const Variable*, IntSet>& relax_map,
const bool return_lower_bound=false);

/*!
* \brief Infer a regular domain that covers all the calls or provides within the given statement.
Expand Down
45 changes: 38 additions & 7 deletions src/arithmetic/bound_deducer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
* \brief Utility to deduce bound of expression
*/
#include <tvm/expr.h>
#include <tvm/expr_operator.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_visitor.h>
#include <tvm/arithmetic.h>
Expand Down Expand Up @@ -78,8 +79,10 @@ class BoundDeducer: public IRVisitor {
friend class Converter;
BoundDeducer(Expr target, Expr expr,
const std::unordered_map<const Variable*, IntSet>& hint_map,
const std::unordered_map<const Variable*, IntSet>& relax_map)
: target_(target), expr_(expr), hint_map_(hint_map), relax_map_(relax_map) {}
const std::unordered_map<const Variable*, IntSet>& relax_map,
const bool return_lower_bound=false)
: target_(target), expr_(expr), hint_map_(hint_map), relax_map_(relax_map),
return_lower_bound_(return_lower_bound){}

void Deduce();

Expand Down Expand Up @@ -149,7 +152,32 @@ class BoundDeducer: public IRVisitor {
// always use relax bound
bool divided = analyzer_.CanProve(result_ % operand == 0);

result_ = result_ / operand;
if (return_lower_bound_) {
/*
* If both sides simplify to integer constants, use floordiv and
* evaluate the quotient to a new constant.
* This is particularly useful for bound checks on conditions over
* loop itervars: e.g., an expr like -5/8 should not evaluate to
* zero, because that would indicate that the loop itervar can take
* on the value 0 in loop_partition.cc.
* TODO: On the other hand, under what conditions is floordiv not
* desirable for integer constants, i.e. when is a round-towards-zero
* result wanted instead?
*/
auto dividend = analyzer_.Simplify(result_);
auto divisor = analyzer_.Simplify(operand);
auto* dividend_const_value_ptr = as_const_int(dividend);
auto* divisor_const_value_ptr = as_const_int(divisor);
if (dividend_const_value_ptr && divisor_const_value_ptr) {
auto int_bound = analyzer_.const_int_bound(floordiv(dividend, divisor));
CHECK_EQ(int_bound->min_value, int_bound->max_value);
result_ = Expr(static_cast<int32_t>(int_bound->min_value));
} else {
result_ = result_ / operand;
}
} else {
result_ = result_ / operand;
}

if (!divided) {
// Handle non-divisible case
Expand Down Expand Up @@ -208,6 +236,7 @@ class BoundDeducer: public IRVisitor {
size_t iter_{0};
// internal analyzer
Analyzer analyzer_;
bool return_lower_bound_{false};
};

class BoundDeduceInputChecker: public IRVisitor {
Expand Down Expand Up @@ -345,8 +374,9 @@ void BoundDeducer::Relax() {

IntSet DeduceBound(Expr v, Expr e,
const std::unordered_map<const Variable*, IntSet>& hint_map,
const std::unordered_map<const Variable*, IntSet>& relax_map) {
BoundDeducer d(v, e, hint_map, relax_map);
const std::unordered_map<const Variable*, IntSet>& relax_map,
const bool return_lower_bound) {
BoundDeducer d(v, e, hint_map, relax_map, return_lower_bound);
d.Deduce();
if (!d.success_) return IntSet::nothing();
Expr min = neg_inf(), max = pos_inf();
Expand All @@ -365,7 +395,8 @@ IntSet DeduceBound(Expr v, Expr e,
// return empty set to represent deduce failure.
IntSet DeduceBound(Expr v, Expr e,
const Map<Var, IntSet>& hint_map,
const Map<Var, IntSet>& relax_map) {
const Map<Var, IntSet>& relax_map,
const bool return_lower_bound) {
std::unordered_map<const Variable*, IntSet> hmap;
for (auto kv : hint_map) {
hmap[kv.first.get()] = kv.second;
Expand All @@ -374,7 +405,7 @@ IntSet DeduceBound(Expr v, Expr e,
for (auto kv : relax_map) {
rmap[kv.first.get()] = kv.second;
}
return DeduceBound(v, e, hmap, rmap);
return DeduceBound(v, e, hmap, rmap, return_lower_bound);
}

} // namespace arith
Expand Down
9 changes: 7 additions & 2 deletions src/pass/loop_partition.cc
Original file line number Diff line number Diff line change
Expand Up @@ -201,16 +201,21 @@ class PartitionFinder : public IRVisitor {
// For cond, find out the interval, if exists, in which we can prove that cond is
// true. Also find the interval, if exists, in which we can prove that cond is
// false.
// Ask for the lower bound of the interval. This specifically works around the
// issue where the interval is neg_inf, ...
// Without asking for the lower bound, the max of the interval can be zero,
// because integer division rounds towards zero.
bool return_lower_bound(true);
IntSet interval =
DeduceBound(current_var_, cond, hint_map_, relax_map_);
DeduceBound(current_var_, cond, hint_map_, relax_map_, return_lower_bound);
if (!interval.is_nothing()) {
// cond is true within interval
partitions[{cond.get(), true}] = interval;
}
Expr inverse_cond = InverseCond(cond);
if (inverse_cond.defined()) {
IntSet interval =
DeduceBound(current_var_, inverse_cond, hint_map_, relax_map_);
DeduceBound(current_var_, inverse_cond, hint_map_, relax_map_, return_lower_bound);
if (!interval.is_nothing()) {
// cond is false within interval
partitions[{cond.get(), false}] = interval;
Expand Down
21 changes: 21 additions & 0 deletions tests/python/unittest/test_pass_loop_partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,26 @@ def test_conv_tiling():
stmt = tvm.ir_pass.Simplify(stmt)
assert(not any(collect_visit(stmt, lambda x: isinstance(x, tvm.stmt.IfThenElse))))

def test_multilevel_splitting_with_indivisble_factors():
    """Loop partitioning on a two-level split whose factors do not divide the extent.

    Builds a relu over a 130-element float32 placeholder, splits its single
    axis by 8 and then splits the resulting outer axis by 16, unrolls the
    innermost axis, and lowers the schedule with ``partition_const_loop``
    enabled. The test passes when ``collect_visit`` reports exactly 10
    ``tvm.expr.Max`` hits in the lowered body.
    """
    import topi

    data = tvm.placeholder((130,), dtype="float32")
    activated = topi.nn.relu(data)
    sched = tvm.create_schedule(activated.op)
    stage = sched[activated]

    # 130 is a multiple of neither 8 nor 8 * 16.
    [flat_axis] = stage.op.axis
    outer, inner = stage.split(flat_axis, factor=8)
    outer_hi, outer_lo = stage.split(outer, factor=16)
    stage.reorder(outer_hi, outer_lo, inner)
    stage.unroll(inner)

    # Lower with partitioning of constant-extent loops switched on.
    with tvm.build_config(partition_const_loop=True):
        body = tvm.lower(sched, [data, activated]).body
        max_hits = collect_visit(
            body, lambda node: isinstance(node, tvm.expr.Max))
        assert max_hits.count(True) == 10


def test_double_splitting_with_indivisible_factors():
m = 48
dtype="float32"
Expand Down Expand Up @@ -443,4 +463,5 @@ def test_simple_rfactor():
test_cce_loop_3()
test_conv_tiling()
test_double_splitting_with_indivisible_factors()
test_multilevel_splitting_with_indivisble_factors()
test_simple_rfactor()