
Commit d114dfc

[SCHEDULE] Mutate dataflow in schedule, refactor Stage (#44)
1 parent 820a859 commit d114dfc

14 files changed, +561 -168 lines changed

include/tvm/operation.h

Lines changed: 11 additions & 11 deletions
@@ -136,7 +136,7 @@ using FCompute = std::function<Expr (const Array<Var>& i)>;
  * \param dtype the data type of the tensor.
  * \param name The name of the Tensor.
  */
-Tensor Placeholder(Array<Expr> shape,
+Tensor placeholder(Array<Expr> shape,
                    Type dtype = Float(32),
                    std::string name = "placeholder");

@@ -147,7 +147,7 @@ Tensor Placeholder(Array<Expr> shape,
  * \param fcompute The compute function to create the tensor.
  * \param name The optional name of the tensor.
  */
-Tensor Compute(Array<Expr> shape, FCompute fcompute, std::string name = "tensor");
+Tensor compute(Array<Expr> shape, FCompute fcompute, std::string name = "tensor");

 /*!
  * \brief Construct new tensors by scan over scan_axis.
@@ -158,36 +158,36 @@ Tensor Compute(Array<Expr> shape, FCompute fcompute, std::string name = "tensor"
  * \param state_placeholder The placeholder for the states.
  * \param name The optional name of the tensor.
  */
-Array<Tensor> Scan(IterVar scan_axis,
+Array<Tensor> scan(IterVar scan_axis,
                    Array<Tensor> init,
                    Array<Tensor> update,
                    Array<Tensor> state_placeholder,
                    std::string name = "scan");

 // same as compute, specialized for different fcompute function
-inline Tensor Compute(Array<Expr> shape,
+inline Tensor compute(Array<Expr> shape,
                       std::function<Expr(Var)> f,
                       std::string name = "tensor") {
   FCompute fc = [f] (const Array<Var>& i) { return f(i[0]); };
-  return Compute(shape, fc, name);
+  return compute(shape, fc, name);
 }
-inline Tensor Compute(Array<Expr> shape,
+inline Tensor compute(Array<Expr> shape,
                       std::function<Expr(Var, Var)> f,
                       std::string name = "tensor") {
   FCompute fc = [f] (const Array<Var>& i) { return f(i[0], i[1]); };
-  return Compute(shape, fc, name);
+  return compute(shape, fc, name);
 }
-inline Tensor Compute(Array<Expr> shape,
+inline Tensor compute(Array<Expr> shape,
                       std::function<Expr(Var, Var, Var)> f,
                       std::string name = "tensor") {
   FCompute fc = [f] (const Array<Var>& i) { return f(i[0], i[1], i[2]); };
-  return Compute(shape, fc, name);
+  return compute(shape, fc, name);
 }
-inline Tensor Compute(Array<Expr> shape,
+inline Tensor compute(Array<Expr> shape,
                       std::function<Expr(Var, Var, Var, Var)> f,
                       std::string name = "tensor") {
   FCompute fc = [f] (const Array<Var>& i) { return f(i[0], i[1], i[2], i[3]); };
-  return Compute(shape, fc, name);
+  return compute(shape, fc, name);
 }

 } // namespace tvm

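The rename aligns the C++ entry points with the snake_case Python frontend. A minimal Python sketch of the calls these functions back, assuming this revision exposes tvm.Var, tvm.placeholder, and tvm.compute:

import tvm

# A symbolic-length vector and an elementwise computation over it.
# tvm.placeholder and tvm.compute dispatch to the lowercase C++
# placeholder() and compute() declared above.
n = tvm.Var("n")
A = tvm.placeholder((n,), name="A")
B = tvm.compute((n,), lambda i: A[i] * 2.0, name="B")
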
include/tvm/schedule.h

Lines changed: 62 additions & 8 deletions
@@ -131,6 +131,13 @@ class Stage : public NodeRef {
              IterVar* p_x_outer, IterVar* p_y_outer,
              IterVar* p_x_inner, IterVar* p_y_inner,
              Expr x_factor, Expr y_factor);
+  /*!
+   * \brief Specify the thread launching group in the
+   *  outermost scope of the stage.
+   *  This is only valid for composite operators.
+   * \param threads The threads to be launched.
+   */
+  Stage& outermost_threads(Array<IterVar> threads);
   /*!
    * \brief Vectorize iteration.
    * \param var The axis to be vectorized.
@@ -179,6 +186,28 @@ class Schedule : public NodeRef {
   Stage operator[](const Tensor& tensor) {
     return this->operator[](tensor->op);
   }
+  /*!
+   * \brief Create a cache read of the original tensor for readers.
+   *  This will mutate the body of the readers.
+   *  A new stage will be created for the tensor.
+   * \param tensor The tensor to be cached.
+   * \param scope The scope of the cache.
+   * \param readers The readers to redirect to the cache.
+   * \return The created tensor.
+   */
+  Tensor cache_read(const Tensor& tensor,
+                    const std::string& scope,
+                    const Array<Operation>& readers);
+  /*!
+   * \brief Create a cache write tensor for the producing tensor.
+   *  The cache tensor will take over the body of the original tensor op.
+   *  The original tensor's body will be changed to an identity read
+   *  from the corresponding cache.
+   * \param tensor The tensor to be produced.
+   * \param scope The scope of the storage.
+   * \return The created tensor.
+   */
+  Tensor cache_write(const Tensor& tensor, const std::string& scope);
   /*!
    * \brief Normalize the schedule.
    *  This is needed before bound inference.
@@ -193,6 +222,11 @@ class Schedule : public NodeRef {
    * \return the pointer to the internal node container
    */
   inline const ScheduleNode* operator->() const;
+  /*!
+   * \brief access the internal node container
+   * \return the pointer to the internal node container
+   */
+  inline ScheduleNode* operator->();
   // declare container type
   using ContainerType = ScheduleNode;
 };
@@ -244,17 +278,28 @@ class IterVarAttr : public NodeRef {
  */
 class StageNode : public Node {
  public:
-  /*! \brief The operation to be scheduled */
-  Operation op;
   /*! \brief The thread scope level of the stage */
   std::string scope;
+  /*! \brief The operation of the stage; this can differ from the original op. */
+  Operation op;
+  /*!
+   * \brief The original operator.
+   *  The op field can change during scheduling to alter the dataflow,
+   *  while origin_op remains fixed.
+   */
+  Operation origin_op;
   /*! \brief All the nodes in the iter var */
   Array<IterVar> all_iter_vars;
   /*!
    * \brief The current leaves in the schedule.
    *  Operations can only be performed in leaves.
    */
   Array<IterVar> leaf_iter_vars;
+  /*!
+   * \brief Specify threads to be launched at the stage.
+   *  This is only valid for composite ops such as Scan.
+   */
+  Array<IterVar> outermost_threads;
   /*! \brief The relations between IterVars */
   Array<IterVarRelation> relations;
   /*! \brief additional attributes about iter var. */
@@ -265,17 +310,22 @@ class StageNode : public Node {
   IterVar attach_ivar;
   /*! \brief The stage this node attaches to */
   Stage attach_stage;
+  /*! \brief Whether this is an output stage */
+  bool is_output{false};

   void VisitAttrs(AttrVisitor* v) final {
     v->Visit("scope", &scope);
     v->Visit("op", &op);
+    v->Visit("origin_op", &origin_op);
     v->Visit("all_iter_vars", &all_iter_vars);
     v->Visit("leaf_iter_vars", &leaf_iter_vars);
+    v->Visit("outermost_threads", &outermost_threads);
     v->Visit("relations", &relations);
     v->Visit("iter_var_attrs", &iter_var_attrs);
     v->Visit("attach_type", &attach_type);
     v->Visit("attach_ivar", &attach_ivar);
     v->Visit("attach_stage", &attach_stage);
+    v->Visit("is_output", &is_output);
   }

   static constexpr const char* _type_key = "Stage";
@@ -285,18 +335,18 @@ class StageNode : public Node {
 /*! \brief node container for schedule */
 class ScheduleNode : public Node {
  public:
-  /*! \brief The root operations */
-  Array<Operation> roots;
+  /*! \brief The output operations in the original dataflow graph */
+  Array<Operation> outputs;
   /*!
-   * \brief list of all stages for non-placeholder ops.
-   *  The stages are ordered in PostDFS order of their op.
+   * \brief list of all stages for non-placeholder ops.
+   *  The stages are sorted in dependency order.
    */
   Array<Stage> stages;
   /*! \brief map of operation to the stages */
   Map<Operation, Stage> stage_map;

   void VisitAttrs(AttrVisitor* v) final {
-    v->Visit("roots", &roots);
+    v->Visit("outputs", &outputs);
     v->Visit("stages", &stages);
     v->Visit("stage_map", &stage_map);
   }
@@ -412,12 +462,16 @@ inline StageNode* Stage::operator->() {

 inline bool Stage::is_scheduled() const {
   const StageNode* n = operator->();
-  return !(n->relations.empty() && n->attach_type == kNone);
+  return !(n->relations.empty() && n->attach_type == kNone &&
+           n->all_iter_vars.same_as(n->leaf_iter_vars));
 }

 inline const ScheduleNode* Schedule::operator->() const {
   return static_cast<const ScheduleNode*>(node_.get());
 }
+inline ScheduleNode* Schedule::operator->() {
+  return static_cast<ScheduleNode*>(node_.get());
+}

 inline const IterVarRelationNode* IterVarRelation::operator->() const {
   return static_cast<const IterVarRelationNode*>(node_.get());

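The op/origin_op split above is the core of this change: schedule primitives may now replace a stage's operation to reroute dataflow, while origin_op preserves the user's original definition. A minimal sketch of the observable effect, where the tvm.Schedule constructor and the "local" scope string are assumptions about this revision:

import tvm

n = tvm.Var("n")
A = tvm.placeholder((n,), name="A")
B = tvm.compute((n,), lambda i: A[i] + 1.0, name="B")
s = tvm.Schedule(B.op)

# cache_write creates a new stage that takes over B's compute body;
# B's own stage is rewritten into an identity read from the cache,
# so its `op` now differs from the `origin_op` it started with.
BL = s.cache_write(B, "local")
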
python/tvm/build.py

Lines changed: 0 additions & 1 deletion
@@ -63,7 +63,6 @@ def build(sch,
             arg_list.append(x)
         else:
             raise ValueError("args must be Tensor, Buffer or Var")
-
     # lowering
     bounds = schedule.InferBound(sch)
     stmt = schedule.ScheduleOps(sch, bounds)

python/tvm/schedule.py

Lines changed: 60 additions & 0 deletions
@@ -4,6 +4,7 @@
 from ._ctypes._node import NodeBase, register_node
 from . import _api_internal
 from . import tensor as _tensor
+from . import collections as _collections

 @register_node
 class Buffer(NodeBase):
@@ -41,6 +42,53 @@ def normalize(self):
         """
         _api_internal._ScheduleNormalize(self)

+    def cache_read(self, tensor, scope, readers):
+        """Create a cache read of the original tensor for readers.
+
+        This will mutate the body of the readers.
+        A new cache stage will be created for the tensor.
+        Call this before doing any split/fuse schedule.
+
+        Parameters
+        ----------
+        tensor : Tensor
+            The tensor to be cached.
+        scope : str
+            The scope of the cache.
+        readers : list of Tensor or Operation
+            The readers to read the cache.
+
+        Returns
+        -------
+        cache : Tensor
+            The created cache tensor.
+        """
+        if isinstance(readers, (_tensor.Tensor, _tensor.Operation)):
+            readers = [readers]
+        readers = [t.op if isinstance(t, _tensor.Tensor) else t for t in readers]
+        return _api_internal._ScheduleCacheRead(self, tensor, scope, readers)
+
+    def cache_write(self, tensor, scope):
+        """Create a cache write of the original tensor, before storing into tensor.
+
+        This will mutate the body of the tensor.
+        A new cache stage will be created before feeding into the tensor.
+
+        Parameters
+        ----------
+        tensor : Tensor
+            The tensor to be fed.
+        scope : str
+            The scope of the cache.
+
+        Returns
+        -------
+        cache : Tensor
+            The created cache tensor.
+        """
+        return _api_internal._ScheduleCacheWrite(self, tensor, scope)
+
+
 @register_node
 class Stage(NodeBase):
     """A Stage represents schedule for one operation."""
@@ -104,6 +152,18 @@ def set_scope(self, scope):
         """
         return _api_internal._StageSetScope(self, scope)

+    def outermost_threads(self, threads):
+        """Force launch threads at the outermost scope of the stage.
+
+        Parameters
+        ----------
+        threads : IterVar or list of IterVar
+            The threads to be launched.
+        """
+        if isinstance(threads, _collections.IterVar):
+            threads = [threads]
+        _api_internal._StageOutermostThreads(self, threads)
+
     def compute_at(self, parent, scope):
         """Attach the stage at parent's scope

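Putting the two new Schedule methods together, a hedged usage sketch; the "shared" and "local" scope strings and the tvm.Schedule constructor are assumptions about this revision:

import tvm

n = tvm.Var("n")
A = tvm.placeholder((n, n), name="A")
B = tvm.compute((n, n), lambda i, j: A[i, j] * 2.0, name="B")
s = tvm.Schedule(B.op)

# Redirect B's reads of A through a cache; per the docstring, call this
# before any split/fuse, since it mutates B's body.
AA = s.cache_read(A, "shared", [B])

# Stage B's computation into a cache; B becomes an identity copy of it.
BL = s.cache_write(B, "local")
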
src/api/api_lang.cc

Lines changed: 19 additions & 1 deletion
@@ -161,7 +161,7 @@ TVM_REGISTER_API(_TensorHash)

 TVM_REGISTER_API(_Placeholder)
 .set_body([](TVMArgs args, TVMRetValue* ret) {
-    *ret = Placeholder(args[0],
+    *ret = placeholder(args[0],
                        args[1],
                        args[2]);
   });
@@ -262,6 +262,12 @@ TVM_REGISTER_API(_StageTile)
     *ret = Array<IterVar>({x_outer, y_outer, x_inner, y_inner});
   });

+TVM_REGISTER_API(_StageOutermostThreads)
+.set_body([](TVMArgs args, TVMRetValue* ret) {
+    args[0].operator Stage()
+        .outermost_threads(args[1]);
+  });
+
 TVM_REGISTER_API(_StageUnroll)
 .set_body([](TVMArgs args, TVMRetValue* ret) {
     args[0].operator Stage()
@@ -280,4 +286,16 @@ TVM_REGISTER_API(_ScheduleNormalize)
     .normalize();
   });

+TVM_REGISTER_API(_ScheduleCacheRead)
+.set_body([](TVMArgs args, TVMRetValue* ret) {
+    *ret = args[0].operator Schedule()
+        .cache_read(args[1], args[2], args[3]);
+  });
+
+TVM_REGISTER_API(_ScheduleCacheWrite)
+.set_body([](TVMArgs args, TVMRetValue* ret) {
+    *ret = args[0].operator Schedule()
+        .cache_write(args[1], args[2]);
+  });
+
 } // namespace tvm

src/lang/operation.cc

Lines changed: 3 additions & 3 deletions
@@ -53,7 +53,7 @@ Operation PlaceholderOpNode::make(std::string name,



-Tensor Placeholder(Array<Expr> shape, Type dtype, std::string name) {
+Tensor placeholder(Array<Expr> shape, Type dtype, std::string name) {
   return PlaceholderOpNode::make(name, shape, dtype).output(0);
 }

@@ -82,7 +82,7 @@ Array<Expr> ComputeOpNode::output_shape(size_t i) const {
   return Array<Expr>(shape);
 }

-Tensor Compute(Array<Expr> shape, FCompute fcompute, std::string name) {
+Tensor compute(Array<Expr> shape, FCompute fcompute, std::string name) {
   auto op_node = std::make_shared<ComputeOpNode>();
   // compute dimension.
   size_t ndim = shape.size();
@@ -188,7 +188,7 @@ Operation ScanOpNode::make(std::string name,
   return Operation(n);
 }

-Array<Tensor> Scan(IterVar scan_axis,
+Array<Tensor> scan(IterVar scan_axis,
                    Array<Tensor> init,
                    Array<Tensor> update,
                    Array<Tensor> state_placeholder,

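For completeness, a sketch of driving the renamed scan from Python. The frontend signature is assumed to mirror the C++ scan(scan_axis, init, update, state_placeholder) above; tvm.IterVar, the update shape, and the cumulative-sum recurrence are illustrative assumptions:

import tvm

# Row-wise cumulative sum as a scan:
#   state[0, :] = X[0, :]
#   state[t, :] = state[t-1, :] + X[t, :]
m = tvm.Var("m")
n = tvm.Var("n")
t = tvm.IterVar((1, m), name="t")              # the scan axis
X = tvm.placeholder((m, n), name="X")
state = tvm.placeholder((m, n), name="state")  # state placeholder
init = tvm.compute((1, n), lambda _, i: X[0, i])
update = tvm.compute((n,), lambda i: state[t - 1, i] + X[t, i])
res = tvm.scan(t, init, update, state)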