Skip to content

Commit 99fae32

Browse files
committed
[ARITH][BOUND] Pass analyzer to PropBoundToInputs
1 parent 4c9acac commit 99fae32

File tree

9 files changed

+27
-7
lines changed

9 files changed

+27
-7
lines changed

include/tvm/operation.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,13 +100,15 @@ class OperationNode : public FunctionBaseNode {
100100
/*!
101101
* \brief Propagate the bounds to inputs
102102
* \param self The reference to self.
103+
* \param analyzer The analyzer to be used in the function.
103104
* \param dom_map the domain map of Variables(corresponds to root_iter_vars)
104105
* \param out_dom_map The output domain.
105106
* The function is only asked to fill the bounds for Tensors that
106107
* is already in the out_dom_map
107108
*/
108109
virtual void PropBoundToInputs(
109110
const Operation& self,
111+
arith::Analyzer* analyzer,
110112
const std::unordered_map<const Variable*, IntSet>& dom_map,
111113
std::unordered_map<Tensor, TensorDom>* out_dom_map) const = 0;
112114
/*!
@@ -170,6 +172,7 @@ class PlaceholderOpNode : public OperationNode {
170172
const std::unordered_map<Tensor, Tensor>& rmap) const final;
171173
void PropBoundToInputs(
172174
const Operation& self,
175+
arith::Analyzer* analyzer,
173176
const std::unordered_map<const Variable*, IntSet>& dom_map,
174177
std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
175178
void GatherBound(
@@ -247,6 +250,7 @@ class TVM_DLL ComputeOpNode : public BaseComputeOpNode {
247250
const std::unordered_map<Tensor, Tensor>& rmap) const final;
248251
void PropBoundToInputs(
249252
const Operation& self,
253+
arith::Analyzer* analyzer,
250254
const std::unordered_map<const Variable*, IntSet>& dom_map,
251255
std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
252256
Stmt BuildProvide(
@@ -299,6 +303,7 @@ class TensorComputeOpNode : public BaseComputeOpNode {
299303
const std::unordered_map<Tensor, Tensor>& rmap) const final;
300304
void PropBoundToInputs(
301305
const Operation& self,
306+
arith::Analyzer* analyzer,
302307
const std::unordered_map<const Variable*, IntSet>& dom_map,
303308
std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
304309
Stmt BuildProvide(
@@ -373,6 +378,7 @@ class ScanOpNode : public OperationNode {
373378
const std::unordered_map<Tensor, Tensor>& rmap) const final;
374379
void PropBoundToInputs(
375380
const Operation& self,
381+
arith::Analyzer* analyzer,
376382
const std::unordered_map<const Variable*, IntSet>& dom_map,
377383
std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
378384
void GatherBound(
@@ -439,6 +445,7 @@ class ExternOpNode : public OperationNode {
439445
const std::unordered_map<Tensor, Tensor>& rmap) const final;
440446
void PropBoundToInputs(
441447
const Operation& self,
448+
arith::Analyzer* analyzer,
442449
const std::unordered_map<const Variable*, IntSet>& dom_map,
443450
std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
444451
void GatherBound(
@@ -506,6 +513,7 @@ class HybridOpNode : public OperationNode {
506513
const std::unordered_map<Tensor, Tensor>& rmap) const final;
507514
void PropBoundToInputs(
508515
const Operation& self,
516+
arith::Analyzer* analyzer,
509517
const std::unordered_map<const Variable*, IntSet>& dom_map,
510518
std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
511519
void GatherBound(

src/op/compute_op.cc

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -211,10 +211,11 @@ Operation ComputeOpNode::ReplaceInputs(
211211

212212
void ComputeOpNode::PropBoundToInputs(
213213
const Operation& self,
214+
arith::Analyzer* analyzer,
214215
const std::unordered_map<const Variable*, IntSet>& dom_map,
215216
std::unordered_map<Tensor, TensorDom>* out_dom_map) const {
216217
CHECK_EQ(self.operator->(), this);
217-
auto fvisit = [&dom_map, out_dom_map](const NodeRef& n) {
218+
auto fvisit = [&dom_map, out_dom_map, analyzer](const NodeRef& n) {
218219
auto *call = n.as<ir::Call>();
219220
if (call != nullptr && call->func.defined()) {
220221
Tensor t = Operation(call->func.node_).output(call->value_index);
@@ -233,11 +234,12 @@ void ComputeOpNode::PropBoundToInputs(
233234
Expr min_value = arg_interval->min_value;
234235
Expr max_value = arg_interval->max_value;
235236
// Prefer the shape bounds only when we can prove they are tighter.
236-
arith::Analyzer an;
237-
if (arith::is_neg_inf(min_value) || an.CanProve(shape_i_min_value >= min_value)) {
237+
if (arith::is_neg_inf(min_value) ||
238+
analyzer->CanProve(shape_i_min_value >= min_value)) {
238239
min_value = shape_i_min_value;
239240
}
240-
if (arith::is_pos_inf(max_value) || an.CanProve(shape_i_max_value <= max_value)) {
241+
if (arith::is_pos_inf(max_value) ||
242+
analyzer->CanProve(shape_i_max_value <= max_value)) {
241243
max_value = shape_i_max_value;
242244
}
243245
dom.data[i].push_back(IntSet::interval(min_value, max_value));

src/op/extern_op.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,7 @@ Operation ExternOpNode::ReplaceInputs(
112112

113113
void ExternOpNode::PropBoundToInputs(
114114
const Operation& self,
115+
arith::Analyzer* analyzer,
115116
const std::unordered_map<const Variable*, IntSet>& dom_map,
116117
std::unordered_map<Tensor, TensorDom>* out_dom_map) const {
117118
for (Tensor t : this->inputs) {

src/op/hybrid_op.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,7 @@ Operation HybridOpNode::ReplaceInputs(
110110

111111
void HybridOpNode::PropBoundToInputs(
112112
const Operation &self,
113+
arith::Analyzer* analyzer,
113114
const std::unordered_map<const Variable*, IntSet> &dom_map,
114115
std::unordered_map<Tensor, TensorDom>* out_dom_map) const {
115116
for (Tensor t : this->inputs) {

src/op/placeholder_op.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ Operation PlaceholderOpNode::ReplaceInputs(
7878

7979
void PlaceholderOpNode::PropBoundToInputs(
8080
const Operation& self,
81+
arith::Analyzer* analyzer,
8182
const std::unordered_map<const Variable*, IntSet>& dom_map,
8283
std::unordered_map<Tensor, TensorDom>* out_dom_map) const {
8384
}

src/op/scan_op.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,7 @@ Operation ScanOpNode::ReplaceInputs(
176176

177177
void ScanOpNode::PropBoundToInputs(
178178
const Operation& self,
179+
arith::Analyzer* analyzer,
179180
const std::unordered_map<const Variable*, IntSet>& dom_map,
180181
std::unordered_map<Tensor, TensorDom>* out_dom_map) const {
181182
CHECK_EQ(self.operator->(), this);

src/op/tensor_compute_op.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,7 @@ Operation TensorComputeOpNode::ReplaceInputs(
110110

111111
void TensorComputeOpNode::PropBoundToInputs(
112112
const Operation& self,
113+
arith::Analyzer* analyzer,
113114
const std::unordered_map<const Variable*, IntSet>& dom_map,
114115
std::unordered_map<Tensor, TensorDom>* out_dom_map) const {
115116
for (size_t i = 0; i < this->inputs.size(); ++i) {

src/op/tensorize.cc

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,17 +85,20 @@ size_t InferTensorizeRegion(
8585
// Get domains if inputs
8686
std::unordered_map<Tensor, TensorDom> in_dom;
8787
std::unordered_map<const Variable*, IntSet> temp_dmap;
88+
arith::Analyzer analyzer;
8889
Array<Tensor> inputs = self->InputTensors();
8990
for (Tensor t : inputs) {
9091
in_dom.emplace(t, TensorDom(t.ndim()));
9192
}
9293
for (IterVar iv : self->root_iter_vars()) {
9394
IntSet iset = up_state.at(iv);
94-
(*out_dom)[iv] = iset.cover_range(dom_map.at(iv));
95+
Range iv_range = iset.cover_range(dom_map.at(iv));
96+
(*out_dom)[iv] = iv_range;
97+
analyzer.Bind(iv->var, iv_range);
9598
temp_dmap[iv->var.get()] = iset;
9699
}
97100
// Input domains
98-
self->PropBoundToInputs(stage->op, temp_dmap, &in_dom);
101+
self->PropBoundToInputs(stage->op, &analyzer, temp_dmap, &in_dom);
99102
Range none;
100103
for (const auto& kv : in_dom) {
101104
Array<Range> vec;

src/schedule/bound.cc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,7 @@ void InferRootBound(const Stage& stage,
191191
PassUpDomain(op_stage, *rmap, &up_state);
192192
// Relax if needed.
193193
std::unordered_map<const Variable*, IntSet> dom_map;
194+
arith::Analyzer analyzer;
194195
for (auto iv : op->root_iter_vars()) {
195196
Range r;
196197
if (up_state.count(iv)) {
@@ -203,8 +204,9 @@ void InferRootBound(const Stage& stage,
203204
} else {
204205
dom_map[iv->var.get()] = IntSet::range(r);
205206
}
207+
analyzer.Bind(iv->var, r);
206208
}
207-
op->PropBoundToInputs(op, dom_map, &tmap);
209+
op->PropBoundToInputs(op, &analyzer, dom_map, &tmap);
208210
}
209211
stage->op->GatherBound(stage->op, tmap, rmap);
210212
}

0 commit comments

Comments
 (0)