Skip to content

Commit

Permalink
[SCHEDULE] Mutate dataflow in schedule, refactor Stage
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen committed Feb 16, 2017
1 parent 820a859 commit 48dc381
Show file tree
Hide file tree
Showing 14 changed files with 561 additions and 168 deletions.
22 changes: 11 additions & 11 deletions include/tvm/operation.h
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ using FCompute = std::function<Expr (const Array<Var>& i)>;
* \param dtype the data type of the tensor.
* \param name The name of the Tensor.
*/
Tensor Placeholder(Array<Expr> shape,
Tensor placeholder(Array<Expr> shape,
Type dtype = Float(32),
std::string name = "placeholder");

Expand All @@ -147,7 +147,7 @@ Tensor Placeholder(Array<Expr> shape,
* \param fcompute The compute function to create the tensor.
* \param name The optional name of the tensor.
*/
Tensor Compute(Array<Expr> shape, FCompute fcompute, std::string name = "tensor");
Tensor compute(Array<Expr> shape, FCompute fcompute, std::string name = "tensor");

/*!
* \brief Construct new tensors by scan over scan_axis.
Expand All @@ -158,36 +158,36 @@ Tensor Compute(Array<Expr> shape, FCompute fcompute, std::string name = "tensor"
* \param state_placeholder The placeholder for the states.
* \param name The optional name of the tensor.
*/
Array<Tensor> Scan(IterVar scan_axis,
Array<Tensor> scan(IterVar scan_axis,
Array<Tensor> init,
Array<Tensor> update,
Array<Tensor> state_placeholder,
std::string name = "scan");

// same as compute, specialized for different fcompute function
inline Tensor Compute(Array<Expr> shape,
inline Tensor compute(Array<Expr> shape,
std::function<Expr(Var)> f,
std::string name = "tensor") {
FCompute fc = [f] (const Array<Var>& i) { return f(i[0]); };
return Compute(shape, fc, name);
return compute(shape, fc, name);
}
inline Tensor Compute(Array<Expr> shape,
inline Tensor compute(Array<Expr> shape,
std::function<Expr(Var, Var)> f,
std::string name = "tensor") {
FCompute fc = [f] (const Array<Var>& i) { return f(i[0], i[1]); };
return Compute(shape, fc, name);
return compute(shape, fc, name);
}
inline Tensor Compute(Array<Expr> shape,
inline Tensor compute(Array<Expr> shape,
std::function<Expr(Var, Var, Var)> f,
std::string name = "tensor") {
FCompute fc = [f] (const Array<Var>& i) { return f(i[0], i[1], i[2]); };
return Compute(shape, fc, name);
return compute(shape, fc, name);
}
inline Tensor Compute(Array<Expr> shape,
inline Tensor compute(Array<Expr> shape,
std::function<Expr(Var, Var, Var, Var)> f,
std::string name = "tensor") {
FCompute fc = [f] (const Array<Var>& i) { return f(i[0], i[1], i[2], i[3]); };
return Compute(shape, fc, name);
return compute(shape, fc, name);
}

} // namespace tvm
Expand Down
70 changes: 62 additions & 8 deletions include/tvm/schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,13 @@ class Stage : public NodeRef {
IterVar* p_x_outer, IterVar* p_y_outer,
IterVar* p_x_inner, IterVar* p_y_inner,
Expr x_factor, Expr y_factor);
/*!
* \brief Specify thread launching group in
* outer most scope of the stage.
* This is only valid for composite operators.
* \param threads The threads to be launched.
*/
Stage& outermost_threads(Array<IterVar> threads);
/*!
* \brief Vectorize iteration.
* \param var The axis to be vectorized.
Expand Down Expand Up @@ -179,6 +186,28 @@ class Schedule : public NodeRef {
Stage operator[](const Tensor& tensor) {
return this->operator[](tensor->op);
}
/*!
* \brief create a cache read of original tensor for readers.
* This will mutate the body of the readers.
* A new stage will be created for the tensor.
* \param tensor The tensor cached.
* \param scope The scope of the cache.
* \param readers The readers to redirect to the tensor.
* \return The created tensor.
*/
Tensor cache_read(const Tensor& tensor,
const std::string& scope,
const Array<Operation>& readers);
/*!
* \brief Create a cache write tensor for producing tensor.
* The the tensor will take over body of original tensor op.
* The original tensor's body will be changed to an identity read
* from the corresponding cache.
* \param tensor The tensor to be produced.
* \param scope The scope of the storage.
* \return The created tensor.
*/
Tensor cache_write(const Tensor& tensor, const std::string& scope);
/*!
* \brief Normalize the schedule.
* This is needed before bound inference.
Expand All @@ -193,6 +222,11 @@ class Schedule : public NodeRef {
* \return the pointer to the internal node container
*/
inline const ScheduleNode* operator->() const;
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
*/
inline ScheduleNode* operator->();
// declare container type
using ContainerType = ScheduleNode;
};
Expand Down Expand Up @@ -244,17 +278,28 @@ class IterVarAttr : public NodeRef {
*/
class StageNode : public Node {
public:
/*! \brief The operation to be scheduled */
Operation op;
/*! \brief The thread scope level of the stage */
std::string scope;
/*! \brief The operation of stage, can be different from original op. */
Operation op;
/*!
* \brief The original operator.
* The op field can change during schedule to alternate the dataflow,
* while origin_op remains fixed.
*/
Operation origin_op;
/*! \brief All the nodes in the iter var */
Array<IterVar> all_iter_vars;
/*!
* \brief The current leafs in the schedule.
* Operations can only be performed in leaves.
*/
Array<IterVar> leaf_iter_vars;
/*!
* \brief Specify threads to be launched at the stage.
* This is only valid for composite ops such as Scan.
*/
Array<IterVar> outermost_threads;
/*! \brief The relation bwteen of IterVars */
Array<IterVarRelation> relations;
/*! \brief additional attributes about iter var. */
Expand All @@ -265,17 +310,22 @@ class StageNode : public Node {
IterVar attach_ivar;
/*! \brief The stage this node attaches to */
Stage attach_stage;
/*! \brief Whether this is an output stage */
bool is_output{false};

void VisitAttrs(AttrVisitor* v) final {
v->Visit("scope", &scope);
v->Visit("op", &op);
v->Visit("origin_op", &origin_op);
v->Visit("all_iter_vars", &all_iter_vars);
v->Visit("leaf_iter_vars", &leaf_iter_vars);
v->Visit("outermost_threads", &outermost_threads);
v->Visit("relations", &relations);
v->Visit("iter_var_attrs", &iter_var_attrs);
v->Visit("attach_type", &attach_type);
v->Visit("attach_ivar", &attach_ivar);
v->Visit("attach_stage", &attach_stage);
v->Visit("is_output", &is_output);
}

static constexpr const char* _type_key = "Stage";
Expand All @@ -285,18 +335,18 @@ class StageNode : public Node {
/*! \brief node container for schedule */
class ScheduleNode : public Node {
public:
/*! \brief The root operations */
Array<Operation> roots;
/*! \brief The output operations in original data flow graph */
Array<Operation> outputs;
/*!
* \brief list of all stages for non-placeholder ops
* The stage are ordered in PostDFS order of their op.
* \brief list of all stages for non-placeholder ops.
* The stages are sorted in dependency order.
*/
Array<Stage> stages;
/*! \brief map of operation to the stages */
Map<Operation, Stage> stage_map;

void VisitAttrs(AttrVisitor* v) final {
v->Visit("roots", &roots);
v->Visit("outputs", &outputs);
v->Visit("stages", &stages);
v->Visit("stage_map", &stage_map);
}
Expand Down Expand Up @@ -412,12 +462,16 @@ inline StageNode* Stage::operator->() {

inline bool Stage::is_scheduled() const {
const StageNode* n = operator->();
return !(n->relations.empty() && n->attach_type == kNone);
return !(n->relations.empty() && n->attach_type == kNone &&
n->all_iter_vars.same_as(n->leaf_iter_vars));
}

inline const ScheduleNode* Schedule::operator->() const {
return static_cast<const ScheduleNode*>(node_.get());
}
inline ScheduleNode* Schedule::operator->() {
return static_cast<ScheduleNode*>(node_.get());
}

inline const IterVarRelationNode* IterVarRelation::operator->() const {
return static_cast<const IterVarRelationNode*>(node_.get());
Expand Down
1 change: 0 additions & 1 deletion python/tvm/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@ def build(sch,
arg_list.append(x)
else:
raise ValueError("args must be Tensor, Buffer or Var")

# lowering
bounds = schedule.InferBound(sch)
stmt = schedule.ScheduleOps(sch, bounds)
Expand Down
60 changes: 60 additions & 0 deletions python/tvm/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from ._ctypes._node import NodeBase, register_node
from . import _api_internal
from . import tensor as _tensor
from . import collections as _collections

@register_node
class Buffer(NodeBase):
Expand Down Expand Up @@ -41,6 +42,53 @@ def normalize(self):
"""
_api_internal._ScheduleNormalize(self)

def cache_read(self, tensor, scope, readers):
"""Create a cache read of original tensor for readers.
This will mutate the body of the readers.
A new cache stage will be created for the tensor.
Call this before doing any split/fuse schedule.
Parameters
----------
tensor : Tensor
The tensor to be cached.
scope : str
The scope of cached
readers : list of Tensor or Operation
The readers to read the cache.
Returns
-------
cache : Tensor
The created cache tensor.
"""
if isinstance(readers, (_tensor.Tensor, _tensor.Operation)):
readers = [readers]
readers = [t.op if isinstance(t, _tensor.Tensor) else t for t in readers]
return _api_internal._ScheduleCacheRead(self, tensor, scope, readers)

def cache_write(self, tensor, scope):
"""Create a cache write of original tensor, before storing into tensor.
This will mutate the body of the tensor.
A new cache stage will created before feed into the tensor.
Parameters
----------
tensor : Tensor
The tensor to be feed to.
scope : str
The scope of cached
Returns
-------
cache : Tensor
The created cache tensor.
"""
return _api_internal._ScheduleCacheWrite(self, tensor, scope)


@register_node
class Stage(NodeBase):
"""A Stage represents schedule for one operation."""
Expand Down Expand Up @@ -104,6 +152,18 @@ def set_scope(self, scope):
"""
return _api_internal._StageSetScope(self, scope)

def outermost_threads(self, threads):
"""Force launch threads at outermost scope of the stage.
Parameters
----------
threads : list of threads
The threads to be launched.
"""
if isinstance(threads, _collections.IterVar):
threads = [threads]
_api_internal._StageOutermostThreads(self, threads)

def compute_at(self, parent, scope):
"""Attach the stage at parent's scope
Expand Down
20 changes: 19 additions & 1 deletion src/api/api_lang.cc
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ TVM_REGISTER_API(_TensorHash)

TVM_REGISTER_API(_Placeholder)
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = Placeholder(args[0],
*ret = placeholder(args[0],
args[1],
args[2]);
});
Expand Down Expand Up @@ -262,6 +262,12 @@ TVM_REGISTER_API(_StageTile)
*ret = Array<IterVar>({x_outer, y_outer, x_inner, y_inner});
});

TVM_REGISTER_API(_StageOutermostThreads)
.set_body([](TVMArgs args, TVMRetValue* ret) {
args[0].operator Stage()
.outermost_threads(args[1]);
});

TVM_REGISTER_API(_StageUnroll)
.set_body([](TVMArgs args, TVMRetValue* ret) {
args[0].operator Stage()
Expand All @@ -280,4 +286,16 @@ TVM_REGISTER_API(_ScheduleNormalize)
.normalize();
});

TVM_REGISTER_API(_ScheduleCacheRead)
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = args[0].operator Schedule()
.cache_read(args[1], args[2], args[3]);
});

TVM_REGISTER_API(_ScheduleCacheWrite)
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = args[0].operator Schedule()
.cache_write(args[1], args[2]);
});

} // namespace tvm
6 changes: 3 additions & 3 deletions src/lang/operation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ Operation PlaceholderOpNode::make(std::string name,



Tensor Placeholder(Array<Expr> shape, Type dtype, std::string name) {
Tensor placeholder(Array<Expr> shape, Type dtype, std::string name) {
return PlaceholderOpNode::make(name, shape, dtype).output(0);
}

Expand Down Expand Up @@ -82,7 +82,7 @@ Array<Expr> ComputeOpNode::output_shape(size_t i) const {
return Array<Expr>(shape);
}

Tensor Compute(Array<Expr> shape, FCompute fcompute, std::string name) {
Tensor compute(Array<Expr> shape, FCompute fcompute, std::string name) {
auto op_node = std::make_shared<ComputeOpNode>();
// compute dimension.
size_t ndim = shape.size();
Expand Down Expand Up @@ -188,7 +188,7 @@ Operation ScanOpNode::make(std::string name,
return Operation(n);
}

Array<Tensor> Scan(IterVar scan_axis,
Array<Tensor> scan(IterVar scan_axis,
Array<Tensor> init,
Array<Tensor> update,
Array<Tensor> state_placeholder,
Expand Down
Loading

0 comments on commit 48dc381

Please sign in to comment.