Skip to content

Commit

Permalink
[REFACTOR][ARITH] Unified IR, introduce arith subfolder. (apache#4722)
Browse files Browse the repository at this point in the history
Spread the arithmetic.h into several components and move
into arith subfolder.

The arith namespace will be used for arithmetic expression
pattern detections and simplifications.
  • Loading branch information
tqchen authored Jan 16, 2020
1 parent dd13c2c commit c7a8319
Show file tree
Hide file tree
Showing 63 changed files with 483 additions and 357 deletions.
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ file(GLOB COMPILER_SRCS
src/ir/*.cc
src/target/*.cc
src/api/*.cc
src/arithmetic/*.cc
src/arith/*.cc
src/autotvm/*.cc
src/codegen/*.cc
src/lang/*.cc
Expand Down
275 changes: 8 additions & 267 deletions include/tvm/arithmetic.h → include/tvm/arith/analyzer.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,24 +18,23 @@
*/

/*!
* \file tvm/arithmetic.h
* \brief Algebra and set operations and simplifications.
* \file tvm/arith/analyzer.h
* \brief Algebra expression simplifications.
*/
#ifndef TVM_ARITHMETIC_H_
#define TVM_ARITHMETIC_H_
#ifndef TVM_ARITH_ANALYZER_H_
#define TVM_ARITH_ANALYZER_H_

#include <tvm/support/with.h>
#include <tvm/ir/expr.h>
#include <tvm/arith/int_set.h>

#include <vector>
#include <unordered_map>
#include <memory>
#include <limits>
#include "expr.h"

namespace tvm {
// forward delcare Tensor
class Tensor;
/*! \brief namespace of arithmetic */
/*! \brief namespace of arithmetic analysis. */
namespace arith {
//-------------------------------------------------------
// Base integer analysis API.
Expand Down Expand Up @@ -332,113 +331,6 @@ class ConstraintContext {
std::function<void()> exit_;
};

//-----------------------------------------------
// Integer set data structure.
//
// This is a API build on top of the base
// integer analysis API to provide set analysis.
//------------------------------------------------
/*!
* \brief Sign type of an integer expression.
*/
enum SignType {
kPositive,
kNegative,
kZero,
kUnknown
};

/*!
* \brief Base class of all IntSet containers.
*/
struct IntSetNode : public Object {
static constexpr const char* _type_key = "IntSet";
TVM_DECLARE_BASE_OBJECT_INFO(IntSetNode, Object);
};

/*!
* \brief Integer set class, represent a set of integers in one dimension.
*/
class IntSet : public ObjectRef {
public:
/*! \brief constructor */
IntSet() {}
// constructor from not container.
explicit IntSet(ObjectPtr<Object> n) : ObjectRef(n) {}
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
*/
inline const IntSetNode* operator->() const;
/*!
* \brief Find a range that covers the region.
* \param max_range The range to be covered.
* \return The covering range.
*/
Range cover_range(Range max_range) const;
/*! \return Lower bound of the set */
PrimExpr min() const;
/*! \return upper bound of the set */
PrimExpr max() const;
/*! \return Whether the set represent nothing */
bool is_nothing() const;
/*! \return Whether the set represent everything */
bool is_everything() const;
/*! \return Whether the set is a single point */
bool is_single_point() const;
/*! \return Whether the set is proved to be bigger than 0 */
bool can_prove_positive() const;
/*! \return Whether the set is proved to be smaller than 0 */
bool can_prove_negative() const;
/*! \return Whether the set is proved to be smaller than or equal to 0 */
bool can_prove_non_positive() const;
/*! \return Whether the set is proved to be larger than or equal to 0 */
bool can_prove_non_negative() const;
/*! \return The sign of the elements in the integer set */
SignType sign_type() const;
/*!
* \brief The single point value, call only if is_single_point is true
* \return The point value.
*/
PrimExpr point_value() const;
/*!
* \brief Try to match IntSet with range r.
*
* \note It is guanrateed that IntSet::range(r).match_range(r) == true
* \return true if we can prove they are the same.
*/
bool match_range(const Range& r) const;
/*! \return The set contains nothing */
static IntSet nothing();
/*! \return The set contains everything */
static IntSet everything();
/*!
* \brief construct a point set.
* \param point The point in the set.
* \return construct a single point set
*/
static IntSet single_point(PrimExpr point);
/*!
* \brief construct a integer set from vector expression.
* \param vec The vector expression, can also be single point.
* \return The result set containing the indices in the vector.
*/
static IntSet vector(PrimExpr vec);
/*!
* \brief Construct a set representing a range.
* \param r The range
* \return constructed set.
*/
static IntSet range(Range r);
/*!
* \brief Construct a set representing a interval.
* \param min The minimum value of the interval.
* \param max The maximum value of the interval.
* \return constructed set.
*/
static IntSet interval(PrimExpr min, PrimExpr max);
};

/*!
* \brief Integer set analyzer.
*/
Expand Down Expand Up @@ -545,157 +437,6 @@ class Analyzer {
PrimExpr Simplify(const PrimExpr& expr);
};

//-----------------------------------------------
// Integer set legacy API.
//------------------------------------------------
/*!
* \brief Find an symbolic integer set that contains all possible values of
* e given the domain of each iteration variables.
*
* \param e The expression to be evaluated.
* \param dom_map The domain of each variable.
* \return An integer set that can cover all the possible values of e.
*/
IntSet EvalSet(PrimExpr e,
const Map<IterVar, IntSet>& dom_map);
/*!
* \brief Same as EvalSet, but takes unordered_map
*
* \param e The expression to be evaluated.
* \param dom_map The domain of each variable.
* \return An integer set that can cover all the possible values of e.
*/
IntSet EvalSet(PrimExpr e,
const std::unordered_map<const VarNode*, IntSet>& dom_map);

/*!
* \brief Find an symbolic integer set that contains is union over
* all the possible conditional values in dom_map.
*
* \param r The initial range.
* \param dom_map The domain of each variable.
* \return An integer set that can cover all the possible values.
*/
IntSet EvalSet(Range r,
const Map<IterVar, IntSet>& dom_map);

/*!
* \brief Find an symbolic integer set that contains is union over
* all the possible conditional values in dom_map.
*
* \param s The initial set.
* \param dom_map The domain of each variable.
* \return An integer set that can cover all the possible values.
*/
IntSet EvalSet(IntSet s,
const std::unordered_map<const VarNode*, IntSet>& dom_map);
/*!
* \brief Same as EvalSet, but takes unordered_map
*
* \param r The range to be evaluated.
* \param dom_map The domain of each variable.
* \return An integer set that can cover all the possible values of e.
*/
IntSet EvalSet(Range r,
const std::unordered_map<const VarNode*, IntSet>& dom_map);

/*! \brief Map from Expr to IntSet */
using ExprIntSetMap = std::unordered_map<PrimExpr, IntSet, ObjectHash, ObjectEqual>;
/*!
* \brief Find the integer set of every sub-expression, given the
* domain of each iteration variables.
*
* \param e The expression to be evaluated.
* \param dom_map The domain of each variable.
* \return the map from the expression to its possible value.
*/
ExprIntSetMap EvalSetForEachSubExpr(
PrimExpr e,
const std::unordered_map<const VarNode*, IntSet>& dom_map);

/*!
* \brief Create an union set of all sets
* \param sets The sets to be unioned
* \return the set after union
*/
IntSet Union(const Array<IntSet>& sets);

/*!
* \brief Create an union set of all sets
* \param sets The sets to be intersected
* \return the set after intersected
*/
IntSet Intersect(const Array<IntSet>& sets);

/*!
* \brief Deduce the bound of the target variable in a expression,
* give the domain of each variables. Return undefined IntSet to
* represent failure.
*
* \note The returned set may be smaller than set that
* contains all possible values of v that satisfies the bound.
*
* \param v The target variable to be deduced.
* \param cond The conditional expression.
* \param hint_map The domain of variable, used to help deduce.
* \param relax_map The domain of each variable, used to relax the domain,
* The deduce bound must implies e for all value in relax_map
* \return An integer set that always satisfies the condition.
*/
IntSet DeduceBound(PrimExpr v, PrimExpr cond,
const Map<Var, IntSet>& hint_map,
const Map<Var, IntSet>& relax_map);
/*!
* \brief Same as DeduceBound with unordered_map signature.
*
* \param v The target variable to be deduced.
* \param cond The conditional expression.
* \param hint_map The domain of variable, used to help deduce.
* \param relax_map The domain of each variable, used to relax the domain,
* The deduce bound mush implies e for all value in relax_map
* \return An integer set that always satisfies the condition.
*/
IntSet DeduceBound(PrimExpr v, PrimExpr cond,
const std::unordered_map<const VarNode*, IntSet>& hint_map,
const std::unordered_map<const VarNode*, IntSet>& relax_map);

/*!
* \brief Infer a regular domain that covers all the calls or provides within the given statement.
* \param body The given statement.
* \param tensor The name of the calls or provides.
* \param consider_calls If calls (read) are considered.
* \param consider_provides If provides (write) are considered.
* \return The domain that covers all the calls or provides within the given statement.
*/
Domain DomainTouched(Stmt body, const Tensor &tensor, bool consider_calls, bool consider_provides);

// Expression pattern detector.
/*!
* \brief Detect if e can be rewritten as e = sum_{i=0}^{n-1} var[i] * coeff[i] + coeff[n]
* Where coeff[i] and base are invariant of var[j] for all i and j.
*
* \param e The expression to be detected.
* \param vars List of variables to be used in detection.
* \return [coeff[i]] if it is possible, empty array if it is not.
*/
Array<PrimExpr> DetectLinearEquation(const PrimExpr& e,
const Array<Var>& vars);

/*!
* \brief Detect if expression corresponds to clip bound of the vars
*
* \param e The expression to be detected.
* \param vars List of variables to be used in detection.
* \return concat([min_value[i], max_value[i]]), None is returned if there is no min or max value
* return empty if the e does not match the pattern.
*/
Array<PrimExpr> DetectClipBound(const PrimExpr& e,
const Array<Var>& vars);

// implementation
inline const IntSetNode* IntSet::operator->() const {
return static_cast<const IntSetNode*>(get());
}
} // namespace arith
} // namespace tvm
#endif // TVM_ARITHMETIC_H_
#endif // TVM_ARITH_ANALYZER_H_
82 changes: 82 additions & 0 deletions include/tvm/arith/bound.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tvm/arith/bound.h
* \brief Bound deducers.
*/
#ifndef TVM_ARITH_BOUND_H_
#define TVM_ARITH_BOUND_H_

#include <tvm/node/container.h>
#include <tvm/ir/expr.h>
#include <tvm/arith/int_set.h>
#include <tvm/expr.h>

#include <unordered_map>

namespace tvm {
// forward delcare Tensor
class Tensor;
namespace arith {

/*!
* \brief Deduce the bound of the target variable in a expression,
* give the domain of each variables. Return undefined IntSet to
* represent failure.
*
* \note The returned set may be smaller than set that
* contains all possible values of v that satisfies the bound.
*
* \param v The target variable to be deduced.
* \param cond The conditional expression.
* \param hint_map The domain of variable, used to help deduce.
* \param relax_map The domain of each variable, used to relax the domain,
* The deduce bound must implies e for all value in relax_map
* \return An integer set that always satisfies the condition.
*/
IntSet DeduceBound(PrimExpr v, PrimExpr cond,
const Map<Var, IntSet>& hint_map,
const Map<Var, IntSet>& relax_map);
/*!
* \brief Same as DeduceBound with unordered_map signature.
*
* \param v The target variable to be deduced.
* \param cond The conditional expression.
* \param hint_map The domain of variable, used to help deduce.
* \param relax_map The domain of each variable, used to relax the domain,
* The deduce bound mush implies e for all value in relax_map
* \return An integer set that always satisfies the condition.
*/
IntSet DeduceBound(PrimExpr v, PrimExpr cond,
const std::unordered_map<const VarNode*, IntSet>& hint_map,
const std::unordered_map<const VarNode*, IntSet>& relax_map);

/*!
* \brief Infer a regular domain that covers all the calls or provides within the given statement.
* \param body The given statement.
* \param tensor The name of the calls or provides.
* \param consider_calls If calls (read) are considered.
* \param consider_provides If provides (write) are considered.
* \return The domain that covers all the calls or provides within the given statement.
*/
Domain DomainTouched(Stmt body, const Tensor &tensor, bool consider_calls, bool consider_provides);

} // namespace arith
} // namespace tvm
#endif // TVM_ARITH_BOUND_H_
Loading

0 comments on commit c7a8319

Please sign in to comment.