Skip to content

Commit

Permalink
[ARITH] Remove the legacy Simplify, migrate to Analyzer. (#5385)
Browse files Browse the repository at this point in the history
The legacy Simplify/CanonicalSimplify are now a thin wrapper around the Analyzer.
This PR removes these functions and migrated every place that requires
simplification to enforce Analyzer creation.
The new API would encourage more Analyzer sharing and potentially enable
context-aware analyzer-based simplification.
  • Loading branch information
tqchen authored Apr 21, 2020
1 parent b8efe27 commit d9cecdf
Show file tree
Hide file tree
Showing 56 changed files with 369 additions and 384 deletions.
52 changes: 26 additions & 26 deletions include/tvm/arith/analyzer.h
Original file line number Diff line number Diff line change
Expand Up @@ -112,16 +112,16 @@ class ConstIntBoundAnalyzer {
* \param expr The expression of interest.
* \return the result of the analysis.
*/
ConstIntBound operator()(const PrimExpr& expr);
TVM_DLL ConstIntBound operator()(const PrimExpr& expr);

/*!
* \brief analyze the expr with the intermediate memorized to avoid redundant computation
* \param expr The expression of interest.
* \param bound The lookup table to store the intermediate results
* \return the result of the analysis.
*/
ConstIntBound operator()(const PrimExpr& expr,
std::unordered_map<const PrimExprNode*, ConstIntBound>* bound);
TVM_DLL ConstIntBound operator()(const PrimExpr& expr,
std::unordered_map<const PrimExprNode*, ConstIntBound>* bound);

/*!
* \brief Update constant int bound information of var.
Expand All @@ -130,22 +130,22 @@ class ConstIntBoundAnalyzer {
* \param info The bound information.
* \param override Whether do we allow override of existing information.
*/
void Update(const Var& var,
const ConstIntBound& info,
bool override = false);
TVM_DLL void Update(const Var& var,
const ConstIntBound& info,
bool override = false);
/*!
* \brief Bind variable to a range.
*
* \param var The variable.
* \param range The range we bind to.
*/
void Bind(const Var& var, const Range& range);
TVM_DLL void Bind(const Var& var, const Range& range);

private:
friend class Analyzer;
friend class ConstraintContext;
explicit ConstIntBoundAnalyzer(Analyzer* parent);
~ConstIntBoundAnalyzer();
TVM_DLL ~ConstIntBoundAnalyzer();
/*!
* \brief Update the internal state to enter constraint.
* \param constraint A constraint expression.
Expand Down Expand Up @@ -212,23 +212,23 @@ class ModularSetAnalyzer {
* \param expr The expression of interest.
* \return the result of the analysis.
*/
ModularSet operator()(const PrimExpr& expr);
TVM_DLL ModularSet operator()(const PrimExpr& expr);
/*!
* \brief Update constant int bound information of var.
*
* \param var The variable of interest.
* \param info The bound information.
* \param override Whether do we allow override of existing information.
*/
void Update(const Var& var,
const ModularSet& info,
bool override = false);
TVM_DLL void Update(const Var& var,
const ModularSet& info,
bool override = false);

private:
friend class Analyzer;
friend class ConstraintContext;
explicit ModularSetAnalyzer(Analyzer* parent);
~ModularSetAnalyzer();
TVM_DLL ~ModularSetAnalyzer();
/*!
* \brief Update the internal state to enter constraint.
* \param constraint A constraint expression.
Expand All @@ -252,7 +252,7 @@ class RewriteSimplifier {
* \param expr The expression of interest.
* \return the result of the analysis.
*/
PrimExpr operator()(const PrimExpr& expr);
TVM_DLL PrimExpr operator()(const PrimExpr& expr);

/*!
* \brief Update binding of var to a new expression.
Expand All @@ -261,9 +261,9 @@ class RewriteSimplifier {
* \param new_expr
* \param override Whether do we allow override of existing information.
*/
void Update(const Var& var,
const PrimExpr& new_expr,
bool override = false);
TVM_DLL void Update(const Var& var,
const PrimExpr& new_expr,
bool override = false);

std::function<void()> EnterConstraint(const PrimExpr& constraint);

Expand All @@ -272,7 +272,7 @@ class RewriteSimplifier {
friend class ConstraintContext;
friend class CanonicalSimplifier;
explicit RewriteSimplifier(Analyzer* parent);
~RewriteSimplifier();
TVM_DLL ~RewriteSimplifier();
class Impl;
/*! \brief Internal impl */
Impl* impl_;
Expand All @@ -288,7 +288,7 @@ class CanonicalSimplifier {
* \param expr The expression of interest.
* \return the result of the analysis.
*/
PrimExpr operator()(const PrimExpr& expr);
TVM_DLL PrimExpr operator()(const PrimExpr& expr);

/*!
* \brief Update binding of var to a new expression.
Expand All @@ -297,15 +297,15 @@ class CanonicalSimplifier {
* \param new_expr
* \param override Whether do we allow override of existing information.
*/
void Update(const Var& var,
const PrimExpr& new_expr,
bool override = false);
TVM_DLL void Update(const Var& var,
const PrimExpr& new_expr,
bool override = false);

private:
friend class Analyzer;
friend class ConstraintContext;
explicit CanonicalSimplifier(Analyzer* parent);
~CanonicalSimplifier();
TVM_DLL ~CanonicalSimplifier();
class Impl;
/*! \brief Internal impl */
Impl* impl_;
Expand Down Expand Up @@ -363,12 +363,12 @@ class IntSetAnalyzer {
* \param dom_map The domain map to indicate which variable to relax.
* \return the result of the analysis.
*/
IntSet operator()(const PrimExpr& expr, const Map<Var, IntSet>& dom_map);
TVM_DLL IntSet operator()(const PrimExpr& expr, const Map<Var, IntSet>& dom_map);

private:
friend class Analyzer;
explicit IntSetAnalyzer(Analyzer* parent);
~IntSetAnalyzer();
TVM_DLL ~IntSetAnalyzer();
class Impl;
/*! \brief Internal impl */
Impl* impl_;
Expand All @@ -384,7 +384,7 @@ class IntSetAnalyzer {
* If the analyzer uses memoization, we need to clear the internal
* cache when information about a Var has been overridden.
*/
class Analyzer {
class TVM_DLL Analyzer {
public:
/*
* Disable copy constructor.
Expand Down
33 changes: 0 additions & 33 deletions include/tvm/tir/ir_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,39 +41,6 @@
namespace tvm {
namespace tir {

/*!
* \brief Simplify the expression.
* \param expr The expression to be simplifed.
* \param vrange The range information about the variable.
* \return Canonicalized statement.
*/
TVM_DLL PrimExpr Simplify(PrimExpr expr, Map<Var, Range> vrange = Map<Var, Range>());

/*!
* \brief Simplify the statement.
* \param stmt The statement to be simplifed.
* \param vrange The range information about the variable.
* \return Canonicalized statement.
*/
Stmt Simplify(Stmt stmt, Map<Var, Range> vrange = Map<Var, Range>());

/*!
* \brief Simplify by applying canonical form.
* \param stmt The statement to be canonically simplifed.
* \param vrange The range information about the variable.
* \return Canonicalized statement.
*/
Stmt CanonicalSimplify(Stmt stmt,
Map<Var, Range> vrange = Map<Var, Range>());

/*!
* \brief Simplify by applying canonical form.
* \param expr The statement to be canonically simplifed.
* \param vrange The range information about the variable.
* \return Canonicalized expression.
*/
TVM_DLL PrimExpr CanonicalSimplify(PrimExpr expr,
Map<Var, Range> vrange = Map<Var, Range>());

/*!
* \brief verifies whether the IR stmt or Expr is in SSA form.
Expand Down
10 changes: 6 additions & 4 deletions python/tvm/autotvm/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@
from random import randrange

import numpy as np

from tvm.tir import expr, ir_pass
import tvm.arith
from tvm.tir import expr

logger = logging.getLogger('autotvm')

Expand Down Expand Up @@ -156,7 +156,8 @@ def get_const_int(exp):
if isinstance(exp, int):
return exp
if not isinstance(exp, (expr.IntImm,)):
exp = ir_pass.Simplify(exp)
ana = tvm.arith.Analyzer()
exp = ana.simplify(exp)
if not isinstance(exp, (expr.IntImm,)):
raise ValueError("Expect value to be constant int")
return exp.value
Expand All @@ -180,7 +181,8 @@ def get_const_tuple(in_tuple):
if isinstance(elem, expr.Var):
ret.append(elem)
elif not isinstance(elem, (expr.IntImm, int)):
elem = ir_pass.Simplify(elem)
ana = tvm.arith.Analyzer()
elem = ana.simplify(elem)
if not isinstance(elem, (expr.IntImm)):
ret.append(elem)
else:
Expand Down
1 change: 1 addition & 0 deletions python/tvm/driver/build_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,7 @@ def _build_for_device(input_mod, target, target_host):
lambda f: "calling_conv" in f.attrs and
f.attrs["calling_conv"].value == CallingConv.DEVICE_KERNEL_LAUNCH),
tvm.tir.transform.LowerWarpMemory(),
tvm.tir.transform.Simplify(),
tvm.tir.transform.LowerDeviceStorageAccessInfo(),
tvm.tir.transform.LowerIntrin()])
mod_dev = opt_device(mod_mixed)
Expand Down
11 changes: 6 additions & 5 deletions python/tvm/te/hybrid/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,10 @@
import tvm.tir
import tvm.te
import tvm.te._ffi_api
import tvm.arith

from tvm.tir import expr as _expr
from tvm.tir import stmt as _stmt
from tvm.tir import ir_pass as _ir_pass
from tvm.te.tensor import Tensor, Operation
from tvm.tir import all as _all
from tvm.tir import any as _any
Expand Down Expand Up @@ -160,6 +160,7 @@ def __init__(self, args, usage, symbols, closure_vars, func_name=None):
self.outputs = [] # Output tensors' name
self.side_effect = set() # Tensors with side effects
self.parsed_body = None # The parsed HalideIR body
self.analyzer = tvm.arith.Analyzer()
self.returned = False # If this function has a valid return


Expand Down Expand Up @@ -326,7 +327,7 @@ def visit_Assign(self, node):
_internal_assert(len(node.targets) == 1, "So far only one-valued assignment is supported!")
lhs = node.targets[0]
if isinstance(rhs, _expr.PrimExpr):
rhs = _ir_pass.Simplify(rhs)
rhs = self.analyzer.simplify(rhs)
if isinstance(lhs, ast.Name):
#TODO: support defined intermediate buffer later
lhs_ = lhs
Expand Down Expand Up @@ -410,7 +411,7 @@ def visit_With(self, node):


def visit_If(self, node):
cond = _ir_pass.CanonicalSimplify(self.visit(node.test))
cond = self.analyzer.simplify(self.visit(node.test))

# Return no IfThenElse if proven
if isinstance(cond, _expr.IntImm):
Expand Down Expand Up @@ -501,8 +502,8 @@ def visit_For(self, node):
_name = node.target.id

if isinstance(for_type, tuple):
low = _ir_pass.CanonicalSimplify(low)
ext = _ir_pass.CanonicalSimplify(ext)
low = self.analyzer.simplify(low)
ext = self.analyzer.simplify(ext)
_internal_assert(isinstance(low, _expr.ConstExpr) and
isinstance(ext, _expr.ConstExpr), \
"Const range should start from a const " + \
Expand Down
21 changes: 21 additions & 0 deletions python/tvm/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
import logging
import numpy as np
import tvm
import tvm.arith
import tvm.tir
import tvm._ffi


Expand Down Expand Up @@ -168,4 +170,23 @@ def compare_derivative(j, n_der, grad):
x_name, grad.shape, dist, max_diff, avg_diff)


def assert_prim_expr_equal(lhs, rhs):
"""Assert lhs and rhs equals to each iother.
Parameters
----------
lhs : tvm.tir.PrimExpr
The left operand.
rhs : tvm.tir.PrimExpr
The left operand.
"""
ana = tvm.arith.Analyzer()
res = ana.simplify(lhs - rhs)
equal = isinstance(res, tvm.tir.IntImm) and res.value == 0
if not equal:
raise ValueError("{} and {} are not equal".format(lhs, rhs))



tvm._ffi._init_api("testing", __name__)
3 changes: 1 addition & 2 deletions python/tvm/tir/ir_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@

from . import stmt as _stmt
from . import expr as _expr
from . import ir_pass as _pass


class WithScope(object):
Expand Down Expand Up @@ -212,7 +211,7 @@ def for_range(self, begin, end, name="i", dtype="int32", for_type="serial"):
self.nidx += 1
self._seq_stack.append([])
loop_var = _expr.Var(name, dtype=dtype)
extent = end if begin == 0 else _pass.Simplify(end - begin)
extent = end if begin == 0 else (end - begin)
def _exit_cb():
if for_type == "serial":
for_type_id = 0
Expand Down
8 changes: 5 additions & 3 deletions src/arith/detect_linear_equation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -207,8 +207,9 @@ bool DetectClipBound(
return false;
}
LinearEqEntry ret;
Analyzer analyzer;
if (!LinearEqDetector(var).Detect(canonical, &ret)) return false;
ret.coeff = Simplify(ret.coeff);
ret.coeff = analyzer.Simplify(ret.coeff);
IntervalEntry& p = (*bmap)[var.get()];
if (is_const_int(ret.coeff, 1)) {
// var + shift >=0 -> var >= -shift
Expand Down Expand Up @@ -254,14 +255,15 @@ Array<PrimExpr> DetectClipBound(const PrimExpr& e, const Array<Var>& vars) {
for (PrimExpr cond : splits) {
if (!DetectClipBound(cond, &rmap)) return Array<PrimExpr>();
}
Analyzer analyzer;
Array<PrimExpr> ret;
for (Var v : vars) {
IntervalEntry e = rmap[v.get()];
if (e.min_value.defined()) {
e.min_value = Simplify(e.min_value);
e.min_value = analyzer.Simplify(e.min_value);
}
if (e.max_value.defined()) {
e.max_value = Simplify(e.max_value);
e.max_value = analyzer.Simplify(e.max_value);
}
ret.push_back(e.min_value);
ret.push_back(e.max_value);
Expand Down
Loading

0 comments on commit d9cecdf

Please sign in to comment.