Skip to content

Commit 8499d01

Browse files
tqchendhruvaray
authored andcommitted
[ARITH] Remove the legacy Simplify, migrate to Analyzer. (apache#5385)
The legacy Simplify/CanonicalSimplify are now a thin wrapper around the Analyzer. This PR removes these functions and migrated every place that requires simplification to enforce Analyzer creation. The new API would encourage more Analyzer sharing and potentially enable context-aware analyzer-based simplification.
1 parent 118f943 commit 8499d01

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

56 files changed

+369
-384
lines changed

include/tvm/arith/analyzer.h

Lines changed: 26 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -112,16 +112,16 @@ class ConstIntBoundAnalyzer {
112112
* \param expr The expression of interest.
113113
* \return the result of the analysis.
114114
*/
115-
ConstIntBound operator()(const PrimExpr& expr);
115+
TVM_DLL ConstIntBound operator()(const PrimExpr& expr);
116116

117117
/*!
118118
* \brief analyze the expr with the intermediate memorized to avoid redundant computation
119119
* \param expr The expression of interest.
120120
* \param bound The lookup table to store the intermediate results
121121
* \return the result of the analysis.
122122
*/
123-
ConstIntBound operator()(const PrimExpr& expr,
124-
std::unordered_map<const PrimExprNode*, ConstIntBound>* bound);
123+
TVM_DLL ConstIntBound operator()(const PrimExpr& expr,
124+
std::unordered_map<const PrimExprNode*, ConstIntBound>* bound);
125125

126126
/*!
127127
* \brief Update constant int bound information of var.
@@ -130,22 +130,22 @@ class ConstIntBoundAnalyzer {
130130
* \param info The bound information.
131131
* \param override Whether do we allow override of existing information.
132132
*/
133-
void Update(const Var& var,
134-
const ConstIntBound& info,
135-
bool override = false);
133+
TVM_DLL void Update(const Var& var,
134+
const ConstIntBound& info,
135+
bool override = false);
136136
/*!
137137
* \brief Bind variable to a range.
138138
*
139139
* \param var The variable.
140140
* \param range The range we bind to.
141141
*/
142-
void Bind(const Var& var, const Range& range);
142+
TVM_DLL void Bind(const Var& var, const Range& range);
143143

144144
private:
145145
friend class Analyzer;
146146
friend class ConstraintContext;
147147
explicit ConstIntBoundAnalyzer(Analyzer* parent);
148-
~ConstIntBoundAnalyzer();
148+
TVM_DLL ~ConstIntBoundAnalyzer();
149149
/*!
150150
* \brief Update the internal state to enter constraint.
151151
* \param constraint A constraint expression.
@@ -212,23 +212,23 @@ class ModularSetAnalyzer {
212212
* \param expr The expression of interest.
213213
* \return the result of the analysis.
214214
*/
215-
ModularSet operator()(const PrimExpr& expr);
215+
TVM_DLL ModularSet operator()(const PrimExpr& expr);
216216
/*!
217217
* \brief Update constant int bound information of var.
218218
*
219219
* \param var The variable of interest.
220220
* \param info The bound information.
221221
* \param override Whether do we allow override of existing information.
222222
*/
223-
void Update(const Var& var,
224-
const ModularSet& info,
225-
bool override = false);
223+
TVM_DLL void Update(const Var& var,
224+
const ModularSet& info,
225+
bool override = false);
226226

227227
private:
228228
friend class Analyzer;
229229
friend class ConstraintContext;
230230
explicit ModularSetAnalyzer(Analyzer* parent);
231-
~ModularSetAnalyzer();
231+
TVM_DLL ~ModularSetAnalyzer();
232232
/*!
233233
* \brief Update the internal state to enter constraint.
234234
* \param constraint A constraint expression.
@@ -252,7 +252,7 @@ class RewriteSimplifier {
252252
* \param expr The expression of interest.
253253
* \return the result of the analysis.
254254
*/
255-
PrimExpr operator()(const PrimExpr& expr);
255+
TVM_DLL PrimExpr operator()(const PrimExpr& expr);
256256

257257
/*!
258258
* \brief Update binding of var to a new expression.
@@ -261,9 +261,9 @@ class RewriteSimplifier {
261261
* \param new_expr
262262
* \param override Whether do we allow override of existing information.
263263
*/
264-
void Update(const Var& var,
265-
const PrimExpr& new_expr,
266-
bool override = false);
264+
TVM_DLL void Update(const Var& var,
265+
const PrimExpr& new_expr,
266+
bool override = false);
267267

268268
std::function<void()> EnterConstraint(const PrimExpr& constraint);
269269

@@ -272,7 +272,7 @@ class RewriteSimplifier {
272272
friend class ConstraintContext;
273273
friend class CanonicalSimplifier;
274274
explicit RewriteSimplifier(Analyzer* parent);
275-
~RewriteSimplifier();
275+
TVM_DLL ~RewriteSimplifier();
276276
class Impl;
277277
/*! \brief Internal impl */
278278
Impl* impl_;
@@ -288,7 +288,7 @@ class CanonicalSimplifier {
288288
* \param expr The expression of interest.
289289
* \return the result of the analysis.
290290
*/
291-
PrimExpr operator()(const PrimExpr& expr);
291+
TVM_DLL PrimExpr operator()(const PrimExpr& expr);
292292

293293
/*!
294294
* \brief Update binding of var to a new expression.
@@ -297,15 +297,15 @@ class CanonicalSimplifier {
297297
* \param new_expr
298298
* \param override Whether do we allow override of existing information.
299299
*/
300-
void Update(const Var& var,
301-
const PrimExpr& new_expr,
302-
bool override = false);
300+
TVM_DLL void Update(const Var& var,
301+
const PrimExpr& new_expr,
302+
bool override = false);
303303

304304
private:
305305
friend class Analyzer;
306306
friend class ConstraintContext;
307307
explicit CanonicalSimplifier(Analyzer* parent);
308-
~CanonicalSimplifier();
308+
TVM_DLL ~CanonicalSimplifier();
309309
class Impl;
310310
/*! \brief Internal impl */
311311
Impl* impl_;
@@ -363,12 +363,12 @@ class IntSetAnalyzer {
363363
* \param dom_map The domain map to indicate which variable to relax.
364364
* \return the result of the analysis.
365365
*/
366-
IntSet operator()(const PrimExpr& expr, const Map<Var, IntSet>& dom_map);
366+
TVM_DLL IntSet operator()(const PrimExpr& expr, const Map<Var, IntSet>& dom_map);
367367

368368
private:
369369
friend class Analyzer;
370370
explicit IntSetAnalyzer(Analyzer* parent);
371-
~IntSetAnalyzer();
371+
TVM_DLL ~IntSetAnalyzer();
372372
class Impl;
373373
/*! \brief Internal impl */
374374
Impl* impl_;
@@ -384,7 +384,7 @@ class IntSetAnalyzer {
384384
* If the analyzer uses memoization, we need to clear the internal
385385
* cache when information about a Var has been overridden.
386386
*/
387-
class Analyzer {
387+
class TVM_DLL Analyzer {
388388
public:
389389
/*
390390
* Disable copy constructor.

include/tvm/tir/ir_pass.h

Lines changed: 0 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -41,39 +41,6 @@
4141
namespace tvm {
4242
namespace tir {
4343

44-
/*!
45-
* \brief Simplify the expression.
46-
* \param expr The expression to be simplifed.
47-
* \param vrange The range information about the variable.
48-
* \return Canonicalized statement.
49-
*/
50-
TVM_DLL PrimExpr Simplify(PrimExpr expr, Map<Var, Range> vrange = Map<Var, Range>());
51-
52-
/*!
53-
* \brief Simplify the statement.
54-
* \param stmt The statement to be simplifed.
55-
* \param vrange The range information about the variable.
56-
* \return Canonicalized statement.
57-
*/
58-
Stmt Simplify(Stmt stmt, Map<Var, Range> vrange = Map<Var, Range>());
59-
60-
/*!
61-
* \brief Simplify by applying canonical form.
62-
* \param stmt The statement to be canonically simplifed.
63-
* \param vrange The range information about the variable.
64-
* \return Canonicalized statement.
65-
*/
66-
Stmt CanonicalSimplify(Stmt stmt,
67-
Map<Var, Range> vrange = Map<Var, Range>());
68-
69-
/*!
70-
* \brief Simplify by applying canonical form.
71-
* \param expr The statement to be canonically simplifed.
72-
* \param vrange The range information about the variable.
73-
* \return Canonicalized expression.
74-
*/
75-
TVM_DLL PrimExpr CanonicalSimplify(PrimExpr expr,
76-
Map<Var, Range> vrange = Map<Var, Range>());
7744

7845
/*!
7946
* \brief verifies whether the IR stmt or Expr is in SSA form.

python/tvm/autotvm/util.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@
2323
from random import randrange
2424

2525
import numpy as np
26-
27-
from tvm.tir import expr, ir_pass
26+
import tvm.arith
27+
from tvm.tir import expr
2828

2929
logger = logging.getLogger('autotvm')
3030

@@ -156,7 +156,8 @@ def get_const_int(exp):
156156
if isinstance(exp, int):
157157
return exp
158158
if not isinstance(exp, (expr.IntImm,)):
159-
exp = ir_pass.Simplify(exp)
159+
ana = tvm.arith.Analyzer()
160+
exp = ana.simplify(exp)
160161
if not isinstance(exp, (expr.IntImm,)):
161162
raise ValueError("Expect value to be constant int")
162163
return exp.value
@@ -180,7 +181,8 @@ def get_const_tuple(in_tuple):
180181
if isinstance(elem, expr.Var):
181182
ret.append(elem)
182183
elif not isinstance(elem, (expr.IntImm, int)):
183-
elem = ir_pass.Simplify(elem)
184+
ana = tvm.arith.Analyzer()
185+
elem = ana.simplify(elem)
184186
if not isinstance(elem, (expr.IntImm)):
185187
ret.append(elem)
186188
else:

python/tvm/driver/build_module.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -287,6 +287,7 @@ def _build_for_device(input_mod, target, target_host):
287287
lambda f: "calling_conv" in f.attrs and
288288
f.attrs["calling_conv"].value == CallingConv.DEVICE_KERNEL_LAUNCH),
289289
tvm.tir.transform.LowerWarpMemory(),
290+
tvm.tir.transform.Simplify(),
290291
tvm.tir.transform.LowerDeviceStorageAccessInfo(),
291292
tvm.tir.transform.LowerIntrin()])
292293
mod_dev = opt_device(mod_mixed)

python/tvm/te/hybrid/parser.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,10 @@
2929
import tvm.tir
3030
import tvm.te
3131
import tvm.te._ffi_api
32+
import tvm.arith
3233

3334
from tvm.tir import expr as _expr
3435
from tvm.tir import stmt as _stmt
35-
from tvm.tir import ir_pass as _ir_pass
3636
from tvm.te.tensor import Tensor, Operation
3737
from tvm.tir import all as _all
3838
from tvm.tir import any as _any
@@ -160,6 +160,7 @@ def __init__(self, args, usage, symbols, closure_vars, func_name=None):
160160
self.outputs = [] # Output tensors' name
161161
self.side_effect = set() # Tensors with side effects
162162
self.parsed_body = None # The parsed HalideIR body
163+
self.analyzer = tvm.arith.Analyzer()
163164
self.returned = False # If this function has a valid return
164165

165166

@@ -326,7 +327,7 @@ def visit_Assign(self, node):
326327
_internal_assert(len(node.targets) == 1, "So far only one-valued assignment is supported!")
327328
lhs = node.targets[0]
328329
if isinstance(rhs, _expr.PrimExpr):
329-
rhs = _ir_pass.Simplify(rhs)
330+
rhs = self.analyzer.simplify(rhs)
330331
if isinstance(lhs, ast.Name):
331332
#TODO: support defined intermediate buffer later
332333
lhs_ = lhs
@@ -410,7 +411,7 @@ def visit_With(self, node):
410411

411412

412413
def visit_If(self, node):
413-
cond = _ir_pass.CanonicalSimplify(self.visit(node.test))
414+
cond = self.analyzer.simplify(self.visit(node.test))
414415

415416
# Return no IfThenElse if proven
416417
if isinstance(cond, _expr.IntImm):
@@ -501,8 +502,8 @@ def visit_For(self, node):
501502
_name = node.target.id
502503

503504
if isinstance(for_type, tuple):
504-
low = _ir_pass.CanonicalSimplify(low)
505-
ext = _ir_pass.CanonicalSimplify(ext)
505+
low = self.analyzer.simplify(low)
506+
ext = self.analyzer.simplify(ext)
506507
_internal_assert(isinstance(low, _expr.ConstExpr) and
507508
isinstance(ext, _expr.ConstExpr), \
508509
"Const range should start from a const " + \

python/tvm/testing.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
import logging
2121
import numpy as np
2222
import tvm
23+
import tvm.arith
24+
import tvm.tir
2325
import tvm._ffi
2426

2527

@@ -168,4 +170,23 @@ def compare_derivative(j, n_der, grad):
168170
x_name, grad.shape, dist, max_diff, avg_diff)
169171

170172

173+
def assert_prim_expr_equal(lhs, rhs):
174+
"""Assert lhs and rhs equals to each iother.
175+
176+
Parameters
177+
----------
178+
lhs : tvm.tir.PrimExpr
179+
The left operand.
180+
181+
rhs : tvm.tir.PrimExpr
182+
The left operand.
183+
"""
184+
ana = tvm.arith.Analyzer()
185+
res = ana.simplify(lhs - rhs)
186+
equal = isinstance(res, tvm.tir.IntImm) and res.value == 0
187+
if not equal:
188+
raise ValueError("{} and {} are not equal".format(lhs, rhs))
189+
190+
191+
171192
tvm._ffi._init_api("testing", __name__)

python/tvm/tir/ir_builder.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121

2222
from . import stmt as _stmt
2323
from . import expr as _expr
24-
from . import ir_pass as _pass
2524

2625

2726
class WithScope(object):
@@ -212,7 +211,7 @@ def for_range(self, begin, end, name="i", dtype="int32", for_type="serial"):
212211
self.nidx += 1
213212
self._seq_stack.append([])
214213
loop_var = _expr.Var(name, dtype=dtype)
215-
extent = end if begin == 0 else _pass.Simplify(end - begin)
214+
extent = end if begin == 0 else (end - begin)
216215
def _exit_cb():
217216
if for_type == "serial":
218217
for_type_id = 0

src/arith/detect_linear_equation.cc

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -207,8 +207,9 @@ bool DetectClipBound(
207207
return false;
208208
}
209209
LinearEqEntry ret;
210+
Analyzer analyzer;
210211
if (!LinearEqDetector(var).Detect(canonical, &ret)) return false;
211-
ret.coeff = Simplify(ret.coeff);
212+
ret.coeff = analyzer.Simplify(ret.coeff);
212213
IntervalEntry& p = (*bmap)[var.get()];
213214
if (is_const_int(ret.coeff, 1)) {
214215
// var + shift >=0 -> var >= -shift
@@ -254,14 +255,15 @@ Array<PrimExpr> DetectClipBound(const PrimExpr& e, const Array<Var>& vars) {
254255
for (PrimExpr cond : splits) {
255256
if (!DetectClipBound(cond, &rmap)) return Array<PrimExpr>();
256257
}
258+
Analyzer analyzer;
257259
Array<PrimExpr> ret;
258260
for (Var v : vars) {
259261
IntervalEntry e = rmap[v.get()];
260262
if (e.min_value.defined()) {
261-
e.min_value = Simplify(e.min_value);
263+
e.min_value = analyzer.Simplify(e.min_value);
262264
}
263265
if (e.max_value.defined()) {
264-
e.max_value = Simplify(e.max_value);
266+
e.max_value = analyzer.Simplify(e.max_value);
265267
}
266268
ret.push_back(e.min_value);
267269
ret.push_back(e.max_value);

0 commit comments

Comments
 (0)