Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions include/tvm/arith/analyzer.h
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,35 @@ class RewriteSimplifier {
* (n < 10) || (n < 5) => (n < 5)
*/
kApplyConstraintsToBooleanBranches = (1 << 2),

/* Special handling for expressions `(A+B)*C < (A*B)*D`
*
* Expressions of the form `(A+B)*C < (A*B)*D` can occur occur
* when comparing the number of operations required for two
* different orderings in which matrix multiplications can be
* performed. Proving or disproving this conditional allows an
* optimal order of execution to be selected, even for dynamic
* argument shapes.
*
* The default behavior of `ConstIntBounds` assumes that each term
* in an expression is independent, and is insufficient to prove
* these inequalities. For example, the maximum value of `(A+B)*C
* - (A*B)*D` is determined by taking the maximum value of
* `(A+B)*C` and subtracting the minimum value of `(A*B)*D`.
* While this algorithm can be applied in all cases, the bound it
* provides is looser than strictly required.
*
* This extension adds a check for this case. When `A`, `B`, `C`,
* and `D` are all positive values, as is the case for tensor
* shapes, the inequality can be written as `1/A + 1/B < D/C`. If
* this inequality holds for the minimum values of `A`, `B`, and
* `D`, along with the maximum value of `C`, then the inequality
* holds for all values.
*
* This extension requires little to no performance overhead, and
* may be enabled by default in future releases.
*/
kComparisonOfProductAndSum = (1 << 3),
};

/*! \brief Enable an optional extension or extensions
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/arith/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
estimate_region_strict_bound,
estimate_region_upper_bound,
)
from .analyzer import ModularSet, ConstIntBound, Analyzer, ProofStrength
from .analyzer import ModularSet, ConstIntBound, Analyzer, ProofStrength, Extension
from .bound import deduce_bound
from .pattern import detect_linear_equation, detect_clip_bound, detect_common_subexpr
from .int_solver import solve_linear_equations, solve_linear_inequalities
Expand Down
38 changes: 36 additions & 2 deletions python/tvm/arith/analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# under the License.
# pylint: disable=invalid-name
"""Arithmetic data structure and utility"""
from enum import IntEnum
import enum
from typing import Union

import tvm._ffi
Expand All @@ -26,13 +26,26 @@
from . import _ffi_api


class ProofStrength(IntEnum):
class ProofStrength(enum.IntEnum):
"""Proof strength of the analysis"""

DEFAULT = 0
SYMBOLIC_BOUND = 1


class Extension(enum.Flag):
"""Extensions enabled for RewriteSimplifier

Values should match `RewriteSimplifier::Extensions`
"""

NoExtensions = 0
TransitivelyProveInequalities = 1 << 0
ConvertBooleanToAndOfOrs = 1 << 1
ApplyConstraintsToBooleanBranches = 1 << 2
ComparisonOfProductAndSum = 1 << 3


@tvm._ffi.register_object("arith.ModularSet")
class ModularSet(Object):
"""Represent range of (coeff * x + base) for x in Z"""
Expand Down Expand Up @@ -107,6 +120,8 @@ def __init__(self):
self._enter_constraint_context = _mod("enter_constraint_context")
self._can_prove_equal = _mod("can_prove_equal")
self._can_prove = _mod("can_prove")
self._get_enabled_extensions = _mod("get_enabled_extensions")
self._set_enabled_extensions = _mod("set_enabled_extensions")

def const_int_bound(self, expr):
"""Find constant integer bound for expr.
Expand Down Expand Up @@ -311,3 +326,22 @@ def can_prove_equal(self, lhs: "PrimExpr", rhs: "PrimExpr"):
Whether we can prove that lhs == rhs
"""
return self._can_prove_equal(lhs, rhs)

@property
def enabled_extensions(self) -> Extension:
"""Return the currently enabled extensions"""
value = self._get_enabled_extensions()
return Extension(value)

@enabled_extensions.setter
def enabled_extensions(self, flags: Union[int, Extension]):
"""Enable extensions for the analyzer

Parameters
----------
flags: Union[int,Extension]

The extensions to enable.
"""
flags = Extension(flags).value
self._set_enabled_extensions(flags)
10 changes: 10 additions & 0 deletions src/arith/analyzer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,16 @@ TVM_REGISTER_GLOBAL("arith.CreateAnalyzer").set_body([](TVMArgs args, TVMRetValu
} else if (name == "can_prove_equal") {
return PackedFunc(
[self](TVMArgs args, TVMRetValue* ret) { *ret = self->CanProveEqual(args[0], args[1]); });
} else if (name == "get_enabled_extensions") {
return PackedFunc([self](TVMArgs args, TVMRetValue* ret) {
*ret = static_cast<std::int64_t>(self->rewrite_simplify.GetEnabledExtensions());
});
} else if (name == "set_enabled_extensions") {
return PackedFunc([self](TVMArgs args, TVMRetValue* ret) {
std::int64_t flags = args[0];
self->rewrite_simplify.SetEnabledExtensions(
static_cast<RewriteSimplifier::Extension>(flags));
});
}
return PackedFunc();
};
Expand Down
168 changes: 0 additions & 168 deletions src/arith/const_int_bound.cc
Original file line number Diff line number Diff line change
Expand Up @@ -240,10 +240,6 @@ class ConstIntBoundAnalyzer::Impl
ret.min_value = InfAwareAdd(a.min_value, b.min_value);
ret.max_value = InfAwareAdd(a.max_value, b.max_value);

if (auto bound = BoundUsingReciprocal(GetRef<PrimExpr>(op))) {
ret = Intersect(ret, bound.value());
}

return ret;
}

Expand All @@ -254,12 +250,6 @@ class ConstIntBoundAnalyzer::Impl
ret.min_value = InfAwareAdd(a.min_value, -b.max_value);
ret.max_value = InfAwareAdd(a.max_value, -b.min_value);

if (auto bound = BoundUsingReciprocal(GetRef<Sub>(op))) {
ret = Intersect(ret, bound.value());
}
if (auto bound = BoundUsingReciprocal(Sub(op->b, op->a))) {
ret = Intersect(ret, Negative(bound.value()));
}
return ret;
}

Expand Down Expand Up @@ -775,164 +765,6 @@ class ConstIntBoundAnalyzer::Impl
std::ceil(std::log2(arg_bounds.max_value)));
}
}

std::optional<Entry> BoundUsingReciprocal(PrimExpr expr) {
// Match expressions of the form `(A+B)*C - (A*B)*D`. Depending on
// previous simplifications, the exact form of the expression may vary.
auto opt_special_case = [&]() -> std::optional<std::tuple<Entry, Entry, Entry, Entry>> {
PVar<PrimExpr> A, B, C, D;

if (PMatchesOneOf{
(A + B) * C - (A * B) * D,
(A + B) * C - (B * A) * D,
}
.Match(expr)) {
return std::tuple{VisitExpr(A.Eval()), VisitExpr(B.Eval()), VisitExpr(C.Eval()),
VisitExpr(D.Eval())};
} else if (PMatchesOneOf{
(A + B) * C - A * B,
(A + B) * C - B * A,
}
.Match(expr)) {
return std::tuple{VisitExpr(A.Eval()), VisitExpr(B.Eval()), VisitExpr(C.Eval()),
MakeBound(1, 1)};
} else if (PMatchesOneOf{
(A * B) * D - (A + B) * C,
(B * A) * D - (A + B) * C,
}
.Match(expr)) {
return std::tuple{Negative(VisitExpr(A.Eval())), Negative(VisitExpr(B.Eval())),
Negative(VisitExpr(C.Eval())), Negative(VisitExpr(D.Eval()))};
} else if (PMatchesOneOf{
A * B - (A + B) * C,
B * A - (A + B) * C,
}
.Match(expr)) {
return std::tuple{Negative(VisitExpr(A.Eval())), Negative(VisitExpr(B.Eval())),
Negative(VisitExpr(C.Eval())), MakeBound(-1, -1)};
} else if (PMatchesOneOf{
(A * B) * D + (A + B) * C,
(B * A) * D + (A + B) * C,
(A + B) * C + (A * B) * D,
(A + B) * C + (B * A) * D,
}
.Match(expr)) {
return std::tuple{Negative(VisitExpr(A.Eval())), Negative(VisitExpr(B.Eval())),
VisitExpr(C.Eval()), Negative(VisitExpr(D.Eval()))};
} else if (PMatchesOneOf{
(A * B) + (A + B) * C,
(B * A) + (A + B) * C,
(A + B) * C + (A * B),
(A + B) * C + (B * A),
}
.Match(expr)) {
return std::tuple{Negative(VisitExpr(A.Eval())), Negative(VisitExpr(B.Eval())),
VisitExpr(C.Eval()), MakeBound(-1, -1)};
} else {
return std::nullopt;
}
}();

if (!opt_special_case.has_value()) {
return std::nullopt;
}
// Unpacking the tuple would be cleaner with a structured binding.
// However, until C++20, structured bindings cannot be captured for
// use in a lambda function.
auto A_bound = std::get<0>(*opt_special_case);
auto B_bound = std::get<1>(*opt_special_case);
auto C_bound = std::get<2>(*opt_special_case);
auto D_bound = std::get<3>(*opt_special_case);

// If C and D have different signs, flip the signs of A/B/C so
// that C will match the sign of D.
if ((D_bound.max_value < 0 && C_bound.min_value > 0) ||
(D_bound.min_value > 0 && C_bound.max_value < 0)) {
A_bound = Negative(A_bound);
B_bound = Negative(B_bound);
C_bound = Negative(C_bound);
}

// If all terms are negative, then we'll be providing an upper bound
// rather than a lower bound. To avoid code duplication, flip all the
// signs here, find a lower bound, then flip the sign to produce the
// upper bound of the original expression.
bool all_terms_negative = (A_bound.max_value < 0 && B_bound.max_value < 0 &&
C_bound.max_value < 0 && D_bound.max_value < 0);
if (all_terms_negative) {
A_bound = Negative(A_bound);
B_bound = Negative(B_bound);
C_bound = Negative(C_bound);
D_bound = Negative(D_bound);
}

bool all_terms_positive = (A_bound.min_value > 0 && B_bound.min_value > 0 &&
C_bound.min_value > 0 && D_bound.min_value > 0);
if (!all_terms_positive) {
return std::nullopt;
}

// (A + B) * C - (A * B) * D
// (A*B*C*D) * ( (A+B)/(A*B*D) - 1/C )
// (A*B*C*D) * ( (1/A + 1/B)/D - 1/C )
// (A*B*C*D) * (1/(A*D) + 1/(B*D) - 1/C)
//
// The constant (A*B*C*D) is positive, and its minimum value is the
// product of the minimum values of A, B, C, and D. If the reciprocal
// term (1/(A*D) + 1/(B*D) - 1/C) is positive, then this constant can
// be used to provide a lower bound on the expression.

bool reciprocal_term_is_positive = [&]() {
if (D_bound.max_value == ConstIntBound::kPosInf) {
// If D can grow without bound, the `1/(A*D)` and `1/(B*D)`
// terms will approach zero, at which point the `-1/C` term
// will determine the sign the sign.
return false;
}

if (std::min(A_bound.max_value, B_bound.max_value) * D_bound.max_value <= C_bound.min_value) {
// 1/(A*D) + 1/(B*D) - 1/C is positive if 1/C < 1/(A*D) + 1/(B*D).
// Since each term is positive, this condition can hold if either
// A*D <= C or B*D <= C.
return true;
}
if (A_bound.max_value != ConstIntBound::kPosInf &&
B_bound.max_value != ConstIntBound::kPosInf) {
// Even if neither term is sufficient on its own, if both A and B
// have known upper bounds, the inequality 1/C < 1/(A*D) + 1/(B*D)
// may still be provable.
//
// The maximum value of the LHS is found when C is minimized. The
// minimum value of the RHS is found when A, B, and D are
// maximized. If the condition holds in this case, then it holds
// in all cases.
//
// 1/C_min < 1/(A_max * D_max) + 1/(B_max*D_max)
// A_max*B_max*D_max < C_min*B_max + C_min*A_max
// A_max*B_max*D_max < C_min*(A_max + B_max)
//
if (A_bound.max_value * B_bound.max_value * D_bound.max_value <
C_bound.min_value * (A_bound.max_value + B_bound.max_value)) {
return true;
}
}
return false;
}();

if (!reciprocal_term_is_positive) {
return std::nullopt;
}

auto ret = Everything(expr->dtype);
ret.min_value = A_bound.min_value * B_bound.min_value * C_bound.min_value * D_bound.min_value;

// If we flipped the sign of the original expression, flip the sign of
// the resulting set of possible values.
if (all_terms_negative) {
ret = Negative(ret);
}
return ret;
}
};

ConstIntBound ConstIntBoundAnalyzer::operator()(const PrimExpr& expr) const {
Expand Down
Loading