@@ -131,6 +131,13 @@ class Stage : public NodeRef {
131131 IterVar* p_x_outer, IterVar* p_y_outer,
132132 IterVar* p_x_inner, IterVar* p_y_inner,
133133 Expr x_factor, Expr y_factor);
134+ /* !
135+ * \brief Specify thread launching group in
136+ * outer most scope of the stage.
137+ * This is only valid for composite operators.
138+ * \param threads The threads to be launched.
139+ */
140+ Stage& outermost_threads (Array<IterVar> threads);
134141 /* !
135142 * \brief Vectorize iteration.
136143 * \param var The axis to be vectorized.
@@ -179,6 +186,28 @@ class Schedule : public NodeRef {
179186 Stage operator [](const Tensor& tensor) {
180187 return this ->operator [](tensor->op );
181188 }
189+ /* !
190+ * \brief create a cache read of original tensor for readers.
191+ * This will mutate the body of the readers.
192+ * A new stage will be created for the tensor.
193+ * \param tensor The tensor cached.
194+ * \param scope The scope of the cache.
195+ * \param readers The readers to redirect to the tensor.
196+ * \return The created tensor.
197+ */
198+ Tensor cache_read (const Tensor& tensor,
199+ const std::string& scope,
200+ const Array<Operation>& readers);
201+ /* !
202+ * \brief Create a cache write tensor for producing tensor.
203+ * The the tensor will take over body of original tensor op.
204+ * The original tensor's body will be changed to an identity read
205+ * from the corresponding cache.
206+ * \param tensor The tensor to be produced.
207+ * \param scope The scope of the storage.
208+ * \return The created tensor.
209+ */
210+ Tensor cache_write (const Tensor& tensor, const std::string& scope);
182211 /* !
183212 * \brief Normalize the schedule.
184213 * This is needed before bound inference.
@@ -193,6 +222,11 @@ class Schedule : public NodeRef {
193222 * \return the pointer to the internal node container
194223 */
195224 inline const ScheduleNode* operator ->() const ;
225+ /* !
226+ * \brief access the internal node container
227+ * \return the pointer to the internal node container
228+ */
229+ inline ScheduleNode* operator ->();
196230 // declare container type
197231 using ContainerType = ScheduleNode;
198232};
@@ -244,17 +278,28 @@ class IterVarAttr : public NodeRef {
244278 */
245279class StageNode : public Node {
246280 public:
247- /* ! \brief The operation to be scheduled */
248- Operation op;
249281 /* ! \brief The thread scope level of the stage */
250282 std::string scope;
283+ /* ! \brief The operation of stage, can be different from original op. */
284+ Operation op;
285+ /* !
286+ * \brief The original operator.
287+ * The op field can change during schedule to alternate the dataflow,
288+ * while origin_op remains fixed.
289+ */
290+ Operation origin_op;
251291 /* ! \brief All the nodes in the iter var */
252292 Array<IterVar> all_iter_vars;
253293 /* !
254294 * \brief The current leafs in the schedule.
255295 * Operations can only be performed in leaves.
256296 */
257297 Array<IterVar> leaf_iter_vars;
298+ /* !
299+ * \brief Specify threads to be launched at the stage.
300+ * This is only valid for composite ops such as Scan.
301+ */
302+ Array<IterVar> outermost_threads;
258303 /* ! \brief The relation bwteen of IterVars */
259304 Array<IterVarRelation> relations;
260305 /* ! \brief additional attributes about iter var. */
@@ -265,17 +310,22 @@ class StageNode : public Node {
265310 IterVar attach_ivar;
266311 /* ! \brief The stage this node attaches to */
267312 Stage attach_stage;
313+ /* ! \brief Whether this is an output stage */
314+ bool is_output{false };
268315
269316 void VisitAttrs (AttrVisitor* v) final {
270317 v->Visit (" scope" , &scope);
271318 v->Visit (" op" , &op);
319+ v->Visit (" origin_op" , &origin_op);
272320 v->Visit (" all_iter_vars" , &all_iter_vars);
273321 v->Visit (" leaf_iter_vars" , &leaf_iter_vars);
322+ v->Visit (" outermost_threads" , &outermost_threads);
274323 v->Visit (" relations" , &relations);
275324 v->Visit (" iter_var_attrs" , &iter_var_attrs);
276325 v->Visit (" attach_type" , &attach_type);
277326 v->Visit (" attach_ivar" , &attach_ivar);
278327 v->Visit (" attach_stage" , &attach_stage);
328+ v->Visit (" is_output" , &is_output);
279329 }
280330
281331 static constexpr const char * _type_key = " Stage" ;
@@ -285,18 +335,18 @@ class StageNode : public Node {
285335/* ! \brief node container for schedule */
286336class ScheduleNode : public Node {
287337 public:
288- /* ! \brief The root operations */
289- Array<Operation> roots ;
338+ /* ! \brief The output operations in original data flow graph */
339+ Array<Operation> outputs ;
290340 /* !
291- * \brief list of all stages for non-placeholder ops
292- * The stage are ordered in PostDFS order of their op .
341+ * \brief list of all stages for non-placeholder ops.
342+ * The stages are sorted in dependency order.
293343 */
294344 Array<Stage> stages;
295345 /* ! \brief map of operation to the stages */
296346 Map<Operation, Stage> stage_map;
297347
298348 void VisitAttrs (AttrVisitor* v) final {
299- v->Visit (" roots " , &roots );
349+ v->Visit (" outputs " , &outputs );
300350 v->Visit (" stages" , &stages);
301351 v->Visit (" stage_map" , &stage_map);
302352 }
@@ -412,12 +462,16 @@ inline StageNode* Stage::operator->() {
412462
413463inline bool Stage::is_scheduled () const {
414464 const StageNode* n = operator ->();
415- return !(n->relations .empty () && n->attach_type == kNone );
465+ return !(n->relations .empty () && n->attach_type == kNone &&
466+ n->all_iter_vars .same_as (n->leaf_iter_vars ));
416467}
417468
418469inline const ScheduleNode* Schedule::operator ->() const {
419470 return static_cast <const ScheduleNode*>(node_.get ());
420471}
472+ inline ScheduleNode* Schedule::operator ->() {
473+ return static_cast <ScheduleNode*>(node_.get ());
474+ }
421475
422476inline const IterVarRelationNode* IterVarRelation::operator ->() const {
423477 return static_cast <const IterVarRelationNode*>(node_.get ());
0 commit comments