Skip to content

Commit

Permalink
[REFACTOR] Add Types to IterVar, Isolate Operator (apache#62)
Browse files Browse the repository at this point in the history
* [IterVar/REFACTOR] Add types to IterVar

* [ARITH/REFACTOR] Move IntSet to include

* [REFACTOR/OP] Move Op detail to seperate folder.

* fix test
  • Loading branch information
tqchen authored Mar 5, 2017
1 parent c8ebfbe commit 3fb8579
Show file tree
Hide file tree
Showing 50 changed files with 1,695 additions and 1,190 deletions.
125 changes: 101 additions & 24 deletions src/arithmetic/int_set.h → include/tvm/arithmetic.h
Original file line number Diff line number Diff line change
@@ -1,18 +1,22 @@
/*!
* Copyright (c) 2016 by Contributors
* \file int_set.h
* \brief Abstraction for all integer set operations.
* \file arithmetic.h
* \brief Algebra and set operations.
*/
#ifndef TVM_ARITHMETIC_INT_SET_H_
#define TVM_ARITHMETIC_INT_SET_H_
#ifndef TVM_ARITHMETIC_H_
#define TVM_ARITHMETIC_H_

#include <tvm/expr.h>
#include <tvm/schedule.h>
#include <vector>
#include <unordered_map>
#include <memory>
#include "./expr.h"

namespace tvm {
/*! \brief namespace of arithmetic */
namespace arith {

/*!
* \brief Sign of an expression or set.
*/
enum SignType {
kPositive,
kNegative,
Expand Down Expand Up @@ -101,6 +105,41 @@ class IntSet : public NodeRef {
static IntSet interval(Expr min, Expr max);
};

/*!
* \brief Range of a linear integer function.
* Use to do specify the possible index values.
*
* set = { base + coeff * x | x in Z }
*
* When coeff != 0, it can also be written as
* set = { n | n % coeff == base }
*
* This is useful to decide if the index is dividable by certain value.
* For example, if index = 0 + 4 x, then we know it can be divided by 4.
*/
struct ModularEntry {
/*! \brief The base */
int base;
/*! \brief linear co-efficient */
int coeff;

/*! \return entry represent everything */
static ModularEntry everything() {
// always safe to set 0 + x, so it can be everything.
ModularEntry e;
e.base = 0; e.coeff = 1;
return e;
}
/*!
* \brief Add two modular entries together to get a new modular entry.
* \param a The left operand.
* \param b The right operand.
* \return The combined modular entry.
*/
static ModularEntry Add(const ModularEntry& a,
const ModularEntry& b);
};

/*!
* \brief Base class of all IntSet containers.
*/
Expand All @@ -109,9 +148,6 @@ struct IntSetNode : public Node {
TVM_DECLARE_BASE_NODE_INFO(IntSetNode, Node);
};

using ExprIntSetMap = std::unordered_map<Expr, IntSet,
Halide::ExprHash, Halide::ExprEqual>;

/*!
* \brief Find an symbolic integer set that contains all possible values of
* e given the domain of each iteration variables.
Expand All @@ -122,6 +158,13 @@ using ExprIntSetMap = std::unordered_map<Expr, IntSet,
*/
IntSet EvalSet(Expr e,
const Map<IterVar, IntSet>& dom_map);
/*!
* \brief Same as EvalSet, but takes unordered_map
*
* \param e The expression to be evaluated.
* \param dom_map The domain of each variable.
* \return An integer set that can cover all the possible values of e.
*/
IntSet EvalSet(Expr e,
const std::unordered_map<const Variable*, IntSet>& dom_map);

Expand All @@ -135,11 +178,18 @@ IntSet EvalSet(Expr e,
*/
IntSet EvalSet(Range r,
const Map<IterVar, IntSet>& dom_map);
/*!
* \brief Same as EvalSet, but takes unordered_map
*
* \param r The range to be evaluated.
* \param dom_map The domain of each variable.
* \return An integer set that can cover all the possible values of e.
*/
IntSet EvalSet(Range r,
const std::unordered_map<const Variable*, IntSet>& dom_map);



/*! \brief Map from Expr to IntSet */
using ExprIntSetMap = std::unordered_map<Expr, IntSet, ExprHash, ExprEqual>;
/*!
* \brief Find the integer set of every sub-expression, given the
* domain of each iteration variables.
Expand All @@ -148,7 +198,8 @@ IntSet EvalSet(Range r,
* \param dom_map The domain of each variable.
* \return the map from the expression to its possible value.
*/
ExprIntSetMap EvalSetForEachSubExpr(Expr r,
ExprIntSetMap EvalSetForEachSubExpr(
Expr e,
const std::unordered_map<const Variable*, IntSet>& dom_map);

/*!
Expand All @@ -165,11 +216,6 @@ IntSet Union(const Array<IntSet>& sets);
*/
IntSet Intersect(const Array<IntSet>& sets);

// implementation
inline const IntSetNode* IntSet::operator->() const {
return static_cast<const IntSetNode*>(node_.get());
}

/*!
* \brief Deduce the bound of the target variable in a expression,
* give the domain of each variables. Return undefined IntSet to
Expand All @@ -178,18 +224,49 @@ inline const IntSetNode* IntSet::operator->() const {
* \param v The target variable to be deduced.
* \param cond The conditional expression.
* \param hint_map The domain of variable, used to help deduce.
* \param relax The domain of each variable, used to relax the domain.
* \param relax_map The domain of each variable, used to relax the domain,
* The deduce bound mush implies e for all value in relax_map
* \return An integer set that can cover all the possible values.
*/
IntSet DeduceBound(Expr v, Expr cond,
const Map<Var, IntSet>& hint_map,
const Map<Var, IntSet>& relax_map);
IntSet DeduceBound(Expr v, Expr e,
const std::unordered_map<const Variable*, IntSet>& hint_map,
const std::unordered_map<const Variable*, IntSet>& relax_map);
/*!
* \brief Same as DeduceBound with unordered_map signature.
*
* \param v The target variable to be deduced.
* \param cond The conditional expression.
* \param hint_map The domain of variable, used to help deduce.
* \param relax_map The domain of each variable, used to relax the domain,
* The deduce bound mush implies e for all value in relax_map
* \return An integer set that can cover all the possible values.
*/
IntSet DeduceBound(Expr v, Expr cond,
const std::unordered_map<const Variable*, IntSet>& hint_map,
const std::unordered_map<const Variable*, IntSet>& relax_map);

/*!
* \brief Evaluate the expression with modular analysis
* \param e The expression to be evaluated.
* \param mod_map Map of modular statistics of known variables.
* \return The ModularEntry covering all possible value of e.
*/
ModularEntry EvalModular(
const Expr& e,
const std::unordered_map<const Variable*, ModularEntry>& mod_map);

/*!
* \brief Same as EvalModular, used by front-end.
* \param e The expression to be evaluated.
* \param mod_map Map of modular statistics of known variables.
* \return A ModularSet covering all possible value of e.
*/
IntSet EvalModular(const Expr& e,
const Map<Var, IntSet>& mod_map);
// implementation
inline const IntSetNode* IntSet::operator->() const {
return static_cast<const IntSetNode*>(node_.get());
}
} // namespace arith
} // namespace tvm

#endif // TVM_ARITHMETIC_INT_SET_H_
#endif // TVM_ARITHMETIC_H_
121 changes: 112 additions & 9 deletions include/tvm/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ using Halide::Bool;
using Halide::Int;
using Halide::UInt;
using Halide::Handle;
using Halide::ExprHash;
using Halide::ExprEqual;

using Halide::Expr;
using Halide::VarExpr;
Expand Down Expand Up @@ -57,7 +59,14 @@ class Var : public Halide::VarExpr {
Type t = Int(32)) : VarExpr(name_hint, t) {}
explicit Var(std::shared_ptr<Node> n) : VarExpr(n) {}
explicit Var(VarExpr v) : VarExpr(v) {}

/*!
* \brief Make a new copy of var with same type, append suffix
* \param suffix The suffix to be appended.
* \return the new Var copy
*/
Var copy_with_suffix(const std::string& suffix) const {
return Var((*this)->name_hint + suffix, (*this)->type);
}
/*! \brief type indicate the container type */
using ContainerType = Variable;
};
Expand Down Expand Up @@ -90,6 +99,72 @@ class Range : public Halide::IR::Range {
static Range make_with_min_extent(Expr min, Expr extent);
};

/*!
* \brief Type of iteration variable.
* Each IterVar have a specific type.
*
* The type of iter var can be overriden via
* stage.iter_var_attrs given they are compatible.
*/
enum IterVarType : int {
/*!
* \brief Data parallel iteration.
* This normally corresponds to axis of Tensor.
* Allow all IterVar manipulations.
*
* \note This does not mean the loop
* have to be executed in parallel fashion.
*/
kDataPar = 0,
/*!
* \brief The IterVar itself is a thread-index
* of a fixed thread launching group.
* Note that this is already assumed to be paralellized.
*
* Disallow: split/fuse/vectorize/parallel
*/
kThreadIndex = 1,
/*!
* \brief Communicative reduction.
* Cannot be directly parallelized.
*
* Disallow: parallel/vectorize
*/
kCommReduce = 2,
/*!
* \brief Serial loops with loop carry dependency,
* the iteration must execute in order.
* Cannot be re-ordered.
*
* Disallow: reorder/parallel/vectorize
*/
kOrdered = 3,
/*!
* \brief IterVar is opaque,
*
* May not corresponds to any generated loop
* Disallow all IterVar manipulations and compute_at
*
* \note This is usually used to implement composite op
* or external op, where the
*/
kOpaque = 4,
// The following are possible additional
// types that are provided during schedule
/*!
* \brief The execution is unrolled.
*/
kUnrolled = 5,
/*!
* \brief The loop is vectorized.
*/
kVectorized = 6,
/*!
* \brief The loop is parallelized.
*/
kParallelized = 7
};

/*!
* \brief Iteration Variable,
* represents an iteration over an integer interval.
Expand All @@ -100,13 +175,6 @@ class IterVar : public NodeRef {
IterVar() {}
// construct from shared ptr.
explicit IterVar(std::shared_ptr<Node> n) : NodeRef(n) {}
/*!
* \brief construction of iteration variable.
* \param dom The iteration domain.
* \param var_name The name of iteration variable.
* \param thread_tag The additional tag to indicate whether the var is binded to fixed-thread.
*/
explicit IterVar(Range dom, std::string var_name = "i", std::string thread_tag = "");
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
Expand All @@ -120,6 +188,22 @@ class IterVar : public NodeRef {
using ContainerType = IterVarNode;
};

/*!
* \brief Create a new IterVar that represents an axis in thread.
*
* \param dom Optional, domain of the thread axis.
* \param tag The thread tag of the axis.
*/
IterVar thread_axis(Range dom, std::string tag);

/*!
* \brief Create a new IterVar for reduction operations.
*
* \param dom The domain of the reduction axis.
* \param name The name of the reduction axis.
*/
IterVar reduce_axis(Range dom, std::string name = "rv");

using Domain = Array<Range>;

// functions
Expand Down Expand Up @@ -168,6 +252,8 @@ class IterVarNode : public Node {
Range dom;
/*! \brief The looping variable */
Var var;
/*! \brief The type of the IterVar */
IterVarType iter_type;
/*!
* \brief additional tag on the iteration variable,
* set this if this is binded already to a known thread tag.
Expand All @@ -177,10 +263,13 @@ class IterVarNode : public Node {
void VisitAttrs(AttrVisitor* v) final {
v->Visit("dom", &dom);
v->Visit("var", &var);
v->Visit("iter_type", &iter_type);
v->Visit("thread_tag", &thread_tag);
}

static IterVar make(Range dom, Var var, std::string thread_tag);
static IterVar make(Range dom, Var var,
IterVarType iter_type,
std::string thread_tag = "");

static constexpr const char* _type_key = "IterVar";
TVM_DECLARE_NODE_TYPE_INFO(IterVarNode, Node);
Expand All @@ -195,6 +284,20 @@ inline IterVar::operator Expr() const {
return (*this)->var;
}

inline const char* IterVarType2String(IterVarType t) {
switch (t) {
case kDataPar: return "DataPar";
case kThreadIndex: return "ThreadIndex";
case kCommReduce: return "CommRedude";
case kOrdered: return "Ordered";
case kOpaque: return "Opaque";
case kUnrolled: return "Unrolled";
case kVectorized: return "Vectorized";
case kParallelized: return "Parallelized";
}
return "Unknown";
}

} // namespace tvm

namespace std {
Expand Down
Loading

0 comments on commit 3fb8579

Please sign in to comment.