Skip to content

Commit

Permalink
[LANG/PASS] Support Vectorize
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen committed Feb 9, 2017
1 parent 08505e3 commit b59602d
Show file tree
Hide file tree
Showing 24 changed files with 1,119 additions and 138 deletions.
1 change: 1 addition & 0 deletions include/tvm/ir_mutator.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ class IRMutator {
virtual Stmt Mutate_(const Realize* op, const Stmt& s);
virtual Stmt Mutate_(const Store* op, const Stmt& s);
virtual Stmt Mutate_(const Free* op, const Stmt& s);
virtual Stmt Mutate_(const IfThenElse* op, const Stmt& s);
virtual Expr Mutate_(const Call* op, const Expr& e);
virtual Expr Mutate_(const Load* op, const Expr& s);
virtual Expr Mutate_(const Variable* op, const Expr& e);
Expand Down
6 changes: 6 additions & 0 deletions include/tvm/ir_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,12 @@ Stmt StorageFlatten(Stmt stmt,
*/
Stmt UnrollLoop(Stmt stmt, int max_auto_step);

/*!
* \brief vectorize the constant loops
* \param stmt The statment to be vectorized.
*/
Stmt VectorizeLoop(Stmt stmt);

/*!
* \brief Make an user callable API LoweredFunc.
*
Expand Down
56 changes: 56 additions & 0 deletions include/tvm/schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ class StageNode;
class ScheduleNode;
// Node container for IterVarRelation
class IterVarRelationNode;
// Attribute of itervar.
class IterVarAttrNode;

/*! \brief the attachment type */
enum AttachType : int {
Expand All @@ -27,6 +29,12 @@ enum AttachType : int {
kScope = 3
};

/*! \brief IterVar type */
enum IterVarType : int {
kUnrolled = 1,
kVectorized = 2
};

/*! \brief Stage, contains scheduling for a stage of computation. */
class Stage : public NodeRef {
public:
Expand Down Expand Up @@ -123,6 +131,18 @@ class Stage : public NodeRef {
IterVar* p_x_outer, IterVar* p_y_outer,
IterVar* p_x_inner, IterVar* p_y_inner,
Expr x_factor, Expr y_factor);
/*!
* \brief Vectorize iteration.
* \param var The axis to be vectorized.
* \return reference to self.
*/
Stage& vectorize(IterVar var); // NOLINT(*)
/*!
* \brief Unroll iteration.
* \param var The axis to be vectorized.
* \return reference to self.
*/
Stage& unroll(IterVar var); // NOLINT(*)
// declare container type
using ContainerType = StageNode;
};
Expand Down Expand Up @@ -187,6 +207,21 @@ class IterVarRelation : public NodeRef {
inline const IterVarRelationNode* operator->() const;
};

/*!
* \brief Additional scheduable attributes about IterVar.
*/
class IterVarAttr : public NodeRef {
public:
IterVarAttr() {}
explicit IterVarAttr(IterVarType t);
explicit IterVarAttr(std::shared_ptr<Node> n) : NodeRef(n) {}
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
*/
inline const IterVarAttrNode* operator->() const;
};

// defintion of node containers
/*!
* \brief represents the schedule of the tensor
Expand Down Expand Up @@ -217,6 +252,8 @@ class StageNode : public Node {
Array<IterVar> leaf_iter_vars;
/*! \brief The relation bwteen of IterVars */
Array<IterVarRelation> relations;
/*! \brief additional attributes about iter var. */
Map<IterVar, IterVarAttr> iter_var_attrs;
/*! \brief The attachment type of the schedule */
AttachType attach_type{kNone};
/*! \brief The attach point of this schedule. */
Expand All @@ -230,6 +267,7 @@ class StageNode : public Node {
v->Visit("all_iter_vars", &all_iter_vars);
v->Visit("leaf_iter_vars", &leaf_iter_vars);
v->Visit("relations", &relations);
v->Visit("iter_var_attrs", &iter_var_attrs);
v->Visit("attach_type", &attach_type);
v->Visit("attach_ivar", &attach_ivar);
v->Visit("attach_stage", &attach_stage);
Expand Down Expand Up @@ -262,6 +300,20 @@ class ScheduleNode : public Node {
TVM_DECLARE_NODE_TYPE_INFO(ScheduleNode);
};

/*! \brief node container for IterVar attr */
class IterVarAttrNode : public Node {
public:
/*! \brief The iteration type. */
IterVarType iter_type;

void VisitAttrs(AttrVisitor* v) final {
v->Visit("iter_type", &iter_type);
}

static constexpr const char* _type_key = "IterVarAttr";
TVM_DECLARE_NODE_TYPE_INFO(IterVarAttrNode);
};

/*! \brief base node of iteration var */
class IterVarRelationNode : public Node {
};
Expand Down Expand Up @@ -361,5 +413,9 @@ inline const IterVarRelationNode* IterVarRelation::operator->() const {
return static_cast<const IterVarRelationNode*>(node_.get());
}

inline const IterVarAttrNode* IterVarAttr::operator->() const {
return static_cast<const IterVarAttrNode*>(node_.get());
}

} // namespace tvm
#endif // TVM_SCHEDULE_H_
1 change: 1 addition & 0 deletions python/tvm/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ def build(sch,
stmt = schedule.ScheduleOps(sch, bounds)
stmt = ir_pass.StorageFlatten(stmt, binds)
stmt = ir_pass.CanonicalSimplify(stmt)
stmt = ir_pass.VectorizeLoop(stmt)
stmt = ir_pass.UnrollLoop(stmt, max_auto_unroll_step)
stmt = ir_pass.Simplify(stmt)
fapi = ir_pass.MakeAPI(stmt, name, arg_list, len(arg_list))
Expand Down
20 changes: 20 additions & 0 deletions python/tvm/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,3 +177,23 @@ def tile(self, x_parent, y_parent, x_factor, y_factor):
x_outer, y_outer, x_inner, y_inner = _api_internal._StageTile(
self, x_parent, y_parent, x_factor, y_factor)
return x_outer, y_outer, x_inner, y_inner

def vectorize(self, var):
"""Vectorize the iteration.
Parameters
----------
var : IterVar
The iteration to be vectorize
"""
_api_internal._StageVectorize(self, var)

def unroll(self, var):
"""Unroll the iteration.
Parameters
----------
var : IterVar
The iteration to be unrolled.
"""
_api_internal._StageUnroll(self, var)
12 changes: 12 additions & 0 deletions src/api/api_lang.cc
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,18 @@ TVM_REGISTER_API(_StageTile)
*ret = Array<IterVar>({x_outer, y_outer, x_inner, y_inner});
});

TVM_REGISTER_API(_StageUnroll)
.set_body([](TVMArgs args, TVMRetValue* ret) {
args[0].operator Stage()
.unroll(args[1]);
});

TVM_REGISTER_API(_StageVectorize)
.set_body([](TVMArgs args, TVMRetValue* ret) {
args[0].operator Stage()
.vectorize(args[1]);
});

TVM_REGISTER_API(_ScheduleNormalize)
.set_body([](TVMArgs args, TVMRetValue* ret) {
args[0].operator Schedule()
Expand Down
1 change: 1 addition & 0 deletions src/api/api_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ REGISTER_PASS1(VerifySSA);
REGISTER_PASS1(CanonicalSimplify);
REGISTER_PASS4(Inline);
REGISTER_PASS2(StorageFlatten);
REGISTER_PASS1(VectorizeLoop);
REGISTER_PASS2(UnrollLoop);
REGISTER_PASS2(StorageSync);
REGISTER_PASS4(MakeAPI);
Expand Down
18 changes: 18 additions & 0 deletions src/arithmetic/compute_expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

#include <tvm/ir.h>
#include <pass/Interval.h>
#include <limits>

namespace tvm {
namespace arith {
Expand Down Expand Up @@ -52,6 +53,23 @@ inline bool GetConst<uint64_t>(Expr e, uint64_t *out) {
}
}

// get a small constant int
inline bool GetConstInt(Expr e, int* out) {
int64_t v1 = 0;
uint64_t v2 = 0;
if (GetConst(e, &v1)) {
if (v1 > static_cast<int64_t>(
std::numeric_limits<int>::max())) return false;
*out = static_cast<int>(v1); return true;
}
if (GetConst(e, &v2)) {
if (v2 > static_cast<uint64_t>(
std::numeric_limits<int>::max())) return false;
*out = static_cast<int>(v2); return true;
}
return false;
}

#define TVM_CONST_PROPAGATION(OP_NAME, OP) \
int64_t ia = 0, ib = 0; \
if (GetConst(a, &ia) && GetConst(b, &ib)) { \
Expand Down
Loading

0 comments on commit b59602d

Please sign in to comment.