Skip to content

Commit

Permalink
[ARITH][BOUND] Fix bound inference to avoid allocating too much (#3526)
Browse files Browse the repository at this point in the history
* [TVM] Fix bound inference to avoid allocating too much

* [ARITH][BOUND] Pass analyzer to PropBoundToInputs
  • Loading branch information
sgrechanik-h authored and tqchen committed Jul 14, 2019
1 parent 75892d2 commit 9fad94c
Show file tree
Hide file tree
Showing 11 changed files with 70 additions and 20 deletions.
8 changes: 8 additions & 0 deletions include/tvm/operation.h
Original file line number Diff line number Diff line change
Expand Up @@ -100,13 +100,15 @@ class OperationNode : public ir::FunctionBaseNode {
/*!
* \brief Propagate the bounds to inputs
* \param self The reference to self.
* \param analyzer The analyzer to be used in the function.
* \param dom_map the domain map of Variables(corresponds to root_iter_vars)
* \param out_dom_map The output domain.
* The function is only asked to fill the bounds for Tensors that
* is already in the out_dom_map
*/
virtual void PropBoundToInputs(
const Operation& self,
arith::Analyzer* analyzer,
const std::unordered_map<const Variable*, IntSet>& dom_map,
std::unordered_map<Tensor, TensorDom>* out_dom_map) const = 0;
/*!
Expand Down Expand Up @@ -170,6 +172,7 @@ class PlaceholderOpNode : public OperationNode {
const std::unordered_map<Tensor, Tensor>& rmap) const final;
void PropBoundToInputs(
const Operation& self,
arith::Analyzer* analyzer,
const std::unordered_map<const Variable*, IntSet>& dom_map,
std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
void GatherBound(
Expand Down Expand Up @@ -247,6 +250,7 @@ class TVM_DLL ComputeOpNode : public BaseComputeOpNode {
const std::unordered_map<Tensor, Tensor>& rmap) const final;
void PropBoundToInputs(
const Operation& self,
arith::Analyzer* analyzer,
const std::unordered_map<const Variable*, IntSet>& dom_map,
std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
Stmt BuildProvide(
Expand Down Expand Up @@ -299,6 +303,7 @@ class TensorComputeOpNode : public BaseComputeOpNode {
const std::unordered_map<Tensor, Tensor>& rmap) const final;
void PropBoundToInputs(
const Operation& self,
arith::Analyzer* analyzer,
const std::unordered_map<const Variable*, IntSet>& dom_map,
std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
Stmt BuildProvide(
Expand Down Expand Up @@ -373,6 +378,7 @@ class ScanOpNode : public OperationNode {
const std::unordered_map<Tensor, Tensor>& rmap) const final;
void PropBoundToInputs(
const Operation& self,
arith::Analyzer* analyzer,
const std::unordered_map<const Variable*, IntSet>& dom_map,
std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
void GatherBound(
Expand Down Expand Up @@ -439,6 +445,7 @@ class ExternOpNode : public OperationNode {
const std::unordered_map<Tensor, Tensor>& rmap) const final;
void PropBoundToInputs(
const Operation& self,
arith::Analyzer* analyzer,
const std::unordered_map<const Variable*, IntSet>& dom_map,
std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
void GatherBound(
Expand Down Expand Up @@ -506,6 +513,7 @@ class HybridOpNode : public OperationNode {
const std::unordered_map<Tensor, Tensor>& rmap) const final;
void PropBoundToInputs(
const Operation& self,
arith::Analyzer* analyzer,
const std::unordered_map<const Variable*, IntSet>& dom_map,
std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
void GatherBound(
Expand Down
29 changes: 27 additions & 2 deletions src/op/compute_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
#include "op_util.h"
#include "../schedule/message_passing.h"
#include "../arithmetic/compute_expr.h"
#include "../arithmetic/int_set.h"

namespace tvm {

Expand Down Expand Up @@ -209,17 +210,41 @@ Operation ComputeOpNode::ReplaceInputs(

void ComputeOpNode::PropBoundToInputs(
const Operation& self,
arith::Analyzer* analyzer,
const std::unordered_map<const Variable*, IntSet>& dom_map,
std::unordered_map<Tensor, TensorDom>* out_dom_map) const {
CHECK_EQ(self.operator->(), this);
auto fvisit = [&dom_map, out_dom_map](const NodeRef& n) {
auto fvisit = [&dom_map, out_dom_map, analyzer](const NodeRef& n) {
auto *call = n.as<ir::Call>();
if (call != nullptr && call->func.defined()) {
Tensor t = Operation(call->func.node_).output(call->value_index);
if (t->op.defined() && out_dom_map->count(t)) {
TensorDom& dom = out_dom_map->at(t);
for (size_t i = 0; i < t.ndim(); ++i) {
dom.data[i].push_back(EvalSet(call->args[i], dom_map));
// We assume that the value of the argument cannot be out of bounds (otherwise it is
// undefined behaviour), so we can intersect the estimated set of the argument with the
// range expected by the tensor. However, intersection may result in overly complex
// expressions, so we perform a more relaxed form of intersection.
IntSet arg_intset = EvalSet(call->args[i], dom_map);
const arith::IntervalSetNode* arg_interval = arg_intset.as<arith::IntervalSetNode>();
if (arg_interval) {
Expr shape_i_min_value = make_zero(t->shape[i].type());
Expr shape_i_max_value = t->shape[i] - 1;
Expr min_value = arg_interval->min_value;
Expr max_value = arg_interval->max_value;
// Prefer the shape bounds only when we can prove they are tighter.
if (arith::is_neg_inf(min_value) ||
analyzer->CanProve(shape_i_min_value >= min_value)) {
min_value = shape_i_min_value;
}
if (arith::is_pos_inf(max_value) ||
analyzer->CanProve(shape_i_max_value <= max_value)) {
max_value = shape_i_max_value;
}
dom.data[i].push_back(IntSet::interval(min_value, max_value));
} else {
dom.data[i].push_back(arg_intset);
}
}
}
}
Expand Down
1 change: 1 addition & 0 deletions src/op/extern_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ Operation ExternOpNode::ReplaceInputs(

void ExternOpNode::PropBoundToInputs(
const Operation& self,
arith::Analyzer* analyzer,
const std::unordered_map<const Variable*, IntSet>& dom_map,
std::unordered_map<Tensor, TensorDom>* out_dom_map) const {
for (Tensor t : this->inputs) {
Expand Down
1 change: 1 addition & 0 deletions src/op/hybrid_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ Operation HybridOpNode::ReplaceInputs(

void HybridOpNode::PropBoundToInputs(
const Operation &self,
arith::Analyzer* analyzer,
const std::unordered_map<const Variable*, IntSet> &dom_map,
std::unordered_map<Tensor, TensorDom>* out_dom_map) const {
for (Tensor t : this->inputs) {
Expand Down
1 change: 1 addition & 0 deletions src/op/placeholder_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ Operation PlaceholderOpNode::ReplaceInputs(

void PlaceholderOpNode::PropBoundToInputs(
const Operation& self,
arith::Analyzer* analyzer,
const std::unordered_map<const Variable*, IntSet>& dom_map,
std::unordered_map<Tensor, TensorDom>* out_dom_map) const {
}
Expand Down
1 change: 1 addition & 0 deletions src/op/scan_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,7 @@ Operation ScanOpNode::ReplaceInputs(

void ScanOpNode::PropBoundToInputs(
const Operation& self,
arith::Analyzer* analyzer,
const std::unordered_map<const Variable*, IntSet>& dom_map,
std::unordered_map<Tensor, TensorDom>* out_dom_map) const {
CHECK_EQ(self.operator->(), this);
Expand Down
1 change: 1 addition & 0 deletions src/op/tensor_compute_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ Operation TensorComputeOpNode::ReplaceInputs(

void TensorComputeOpNode::PropBoundToInputs(
const Operation& self,
arith::Analyzer* analyzer,
const std::unordered_map<const Variable*, IntSet>& dom_map,
std::unordered_map<Tensor, TensorDom>* out_dom_map) const {
for (size_t i = 0; i < this->inputs.size(); ++i) {
Expand Down
7 changes: 5 additions & 2 deletions src/op/tensorize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -85,17 +85,20 @@ size_t InferTensorizeRegion(
// Get domains if inputs
std::unordered_map<Tensor, TensorDom> in_dom;
std::unordered_map<const Variable*, IntSet> temp_dmap;
arith::Analyzer analyzer;
Array<Tensor> inputs = self->InputTensors();
for (Tensor t : inputs) {
in_dom.emplace(t, TensorDom(t.ndim()));
}
for (IterVar iv : self->root_iter_vars()) {
IntSet iset = up_state.at(iv);
(*out_dom)[iv] = iset.cover_range(dom_map.at(iv));
Range iv_range = iset.cover_range(dom_map.at(iv));
(*out_dom)[iv] = iv_range;
analyzer.Bind(iv->var, iv_range);
temp_dmap[iv->var.get()] = iset;
}
// Input domains
self->PropBoundToInputs(stage->op, temp_dmap, &in_dom);
self->PropBoundToInputs(stage->op, &analyzer, temp_dmap, &in_dom);
Range none;
for (const auto& kv : in_dom) {
Array<Range> vec;
Expand Down
4 changes: 3 additions & 1 deletion src/schedule/bound.cc
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,7 @@ void InferRootBound(const Stage& stage,
PassUpDomain(op_stage, *rmap, &up_state);
// Relax if needed.
std::unordered_map<const Variable*, IntSet> dom_map;
arith::Analyzer analyzer;
for (auto iv : op->root_iter_vars()) {
Range r;
if (up_state.count(iv)) {
Expand All @@ -203,8 +204,9 @@ void InferRootBound(const Stage& stage,
} else {
dom_map[iv->var.get()] = IntSet::range(r);
}
analyzer.Bind(iv->var, r);
}
op->PropBoundToInputs(op, dom_map, &tmap);
op->PropBoundToInputs(op, &analyzer, dom_map, &tmap);
}
stage->op->GatherBound(stage->op, tmap, rmap);
}
Expand Down
22 changes: 22 additions & 0 deletions tests/python/unittest/test_schedule_bound_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,27 @@ def _body():
assert isinstance(bounds, tvm.container.Map)
assert(bounds[B.op.axis[0]].extent.value == 10)

def test_bound_simplification_failure():
# Check that the bounds are not expanded
A = tvm.compute((2,), lambda j: j, "A")

def _check(B, A=A):
s = tvm.create_schedule(B.op)
s = s.normalize()
bounds = tvm.schedule.InferBound(s)
stmt = tvm.lower(s, [B, A], simple_mode=True)
if not bounds[A.op.axis[0]].extent.value <= 2:
print(stmt)
assert bounds[A.op.axis[0]].extent.value <= 2

# These are hard to simplify, moreover we don't simplify them
_check(tvm.compute((10,), lambda i: A[tvm.min(3*i, 4*i) + tvm.min(-3*i, -2*i)]))
_check(tvm.compute((10,), lambda i: A[tvm.min(3*i, 4*i) + tvm.max(-3*i, -4*i)]))
_check(tvm.compute((10,), lambda i: A[-2*(i/2) - tvm.min(i, 0-i)]))
_check(tvm.compute((10,), lambda i: A[i + (0 - i)]))
# This would cause out of bounds, but we nevertheless include it
_check(tvm.compute((10,), lambda i: A[i]))

if __name__ == "__main__":
test_bound_nest_thread()
test_bound1()
Expand All @@ -320,3 +341,4 @@ def _body():
test_gemm_bound()
test_bound_warp()
test_bound_tensor_compute_op()
test_bound_simplification_failure()
15 changes: 0 additions & 15 deletions tests/python/unittest/test_schedule_schedule_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,20 +286,6 @@ def _compute(*indice):
stmt = tvm.schedule.ScheduleOps(s, bounds)


def test_schedule_bound_condition():
A = tvm.placeholder((64,), name='A', dtype="float32")
Apad = tvm.compute((66,), lambda i: tvm.if_then_else(
tvm.all(i>0, i < 65), A[i-1], tvm.const(0., "float32")), name='Apad')
Apad2 = tvm.compute((66,), lambda i: Apad[i]*2, name='Apad2')
s = tvm.create_schedule(Apad2.op)
AL1 = s.cache_read(A,"local",[Apad])
s = s.normalize()
bounds = tvm.schedule.InferBound(s)
stmt = tvm.schedule.ScheduleOps(s, bounds)
stmt = tvm.ir_pass.Simplify(stmt)
assert (isinstance(stmt.body.body.first.body.body.then_case, tvm.stmt.IfThenElse))


def intrin_gemv(m, n):
w = tvm.placeholder((m, n), name='w')
x = tvm.placeholder((n,), name='x')
Expand Down Expand Up @@ -514,7 +500,6 @@ def _compute(*index) :
test_schedule1()
test_schedule2()
test_schedule_cache()
test_schedule_bound_condition()
test_schedule_tensor_compute1()
test_schedule_tensor_compute2()
test_schedule_tensor_compute3()
Expand Down

0 comments on commit 9fad94c

Please sign in to comment.