2626#include < tvm/tir/stmt_functor.h>
2727#include < tvm/tir/transform.h>
2828
29+ #include < set>
30+ #include < stack>
2931#include < utility>
3032
3133#include " ../schedule/analysis.h"
@@ -87,8 +89,8 @@ Map<Var, Buffer> UpdateBufferMap(PrimFunc f) {
8789 */
8890class IndexTransformer : public StmtExprMutator {
8991 public:
90- explicit IndexTransformer (AccessAndDependencyCollector collector)
91- : collector_(std::move(collector)) {}
92+ explicit IndexTransformer (AccessAndDependencyCollector collector, AxisTree axis_tree )
93+ : collector_(std::move(collector)), axis_tree_(std::move(axis_tree)) {}
9294
9395 private:
9496 /* !
@@ -281,43 +283,124 @@ class IndexTransformer : public StmtExprMutator {
281283 sp_block->init .defined () ? VisitStmt (sp_block->init .value ()) : Optional<Stmt>(NullOpt);
282284 Stmt body = VisitStmt (sp_block->body );
283285
284- // Step 2. Create the new outer loop vars.
285- Array<Var> loop_vars;
286+ // Step 2. Create the new loop vars.
286287 std::unordered_map<const VarNode*, PrimExpr> var_map;
287- loop_vars. reserve (n_iter) ;
288+ Array<Var> all_loop_vars ;
288289 var_map.reserve (n_iter);
289290 for (const SpIterVar& sp_iter : sp_block->sp_iter_vars ) {
290291 Var loop_var (" v_" + sp_iter->var ->name_hint );
291- loop_vars .push_back (loop_var);
292+ all_loop_vars .push_back (loop_var);
292293 var_map[sp_iter->var .get ()] = loop_var;
293294 }
294295
295- // Step 3. Create block iters and iter bindings.
296- Array<IterVar> block_iters;
297- Array<PrimExpr> iter_bindings;
298- block_iters.reserve (n_iter);
299- iter_bindings.reserve (n_iter);
300- for (int i = 0 ; i < n_iter; ++i) {
301- block_iters.push_back (SpIterVarToIterVar (sp_block->sp_iter_vars [i], var_map));
302- iter_bindings.push_back (loop_vars[i]);
303- }
296+ // Step 3. Collect block iters and iter bindings.
297+ std::set<String> in_stack;
298+ in_stack.insert (" root" );
299+ /* A stack that stores block itervars in each block. */
300+ std::stack<Array<IterVar>> block_iters_st;
301+ /* A stack that stores itervar bindings in each block. */
302+ std::stack<Array<PrimExpr>> iter_bindings_st;
303+ /* A stack that stores generated loop vars in each block. */
304+ std::stack<Array<Var>> loop_vars_st;
305+ /* A stack that stores whether to place init block in each block. */
306+ std::stack<bool > place_init_st;
307+ /* An indicator that records whether init block has been set. */
308+ bool init_set = false ;
309+ do {
310+ /* Block itervars of current block. */
311+ Array<IterVar> block_iters;
312+ /* Itervar bindings of current block. */
313+ Array<PrimExpr> iter_bindings;
314+ /* Axis names of current block. */
315+ Array<String> axis_names;
316+ /* Generated loop vars of current block. */
317+ Array<Var> loop_vars;
318+ /* An indicator that records whether there is reduction axis in current block. */
319+ bool has_reduction_var = false ;
320+ for (int i = 0 ; i < n_iter; ++i) {
321+ SpIterVar sp_it_var = sp_block->sp_iter_vars [i];
322+ String axis_name = sp_it_var->axis ->name ;
323+ auto && parent_axis = axis_tree_->parent .Get (axis_name);
324+ CHECK (parent_axis.defined ()) << " Sparse IterVar not defined in Axis Tree." ;
325+ String parent_axis_name = parent_axis.value ();
326+ bool is_fixed_axis = sp_it_var->axis ->is_fixed ();
327+ /* Add itervar to current block when
328+ * - it's not used yet (not in stack), and
329+ * - either its parent axis was used in outer blocks,
330+ * or it's an iterator to a fixed axis.
331+ */
332+ if ((is_fixed_axis || in_stack.find (parent_axis_name) != in_stack.end ()) &&
333+ in_stack.find (axis_name) == in_stack.end ()) {
334+ loop_vars.push_back (all_loop_vars[i]);
335+ axis_names.push_back (std::move (axis_name));
336+ block_iters.push_back (SpIterVarToIterVar (sp_it_var, var_map));
337+ iter_bindings.push_back (all_loop_vars[i]);
338+ has_reduction_var |= sp_it_var->is_reduction ;
339+ }
340+ }
341+
342+ /* Tag axes in current block as "in-stack". */
343+ for (const String&& axis_name : axis_names) {
344+ in_stack.insert (std::move (axis_name));
345+ }
346+
347+ /* Update stack. */
348+ if (!block_iters.empty ()) {
349+ block_iters_st.push (std::move (block_iters));
350+ iter_bindings_st.push (std::move (iter_bindings));
351+ loop_vars_st.push (std::move (loop_vars));
352+ if (init_set) {
353+ place_init_st.push (false );
354+ } else {
355+ place_init_st.push (has_reduction_var);
356+ init_set |= has_reduction_var;
357+ }
358+ } else {
359+ break ;
360+ }
361+ } while (true );
304362
305363 // Step 4. Generate the read-region and write-region of the block.
306364 Array<BufferRegion> reads{nullptr };
307365 Array<BufferRegion> writes{nullptr };
308366 GenerateReadWriteRegions (sp_block, &reads, &writes);
309367
310- // Step 5. Create the block and block-realize
311- Map<String, ObjectRef> mapping;
312- mapping.Set (" sparse" , Bool (true ));
313- Block block (block_iters, std::move (reads), std::move (writes), sp_block->name , std::move (body),
314- std::move (init), {}, {}, std::move (mapping));
315- BlockRealize block_realize (std::move (iter_bindings), const_true (), std::move (block));
316-
317- // Step 6. Create outer loops and the block binding.
318- Stmt loop = GenerateLoops (std::move (block_realize), block_iters, loop_vars);
368+ // Step 5. Generate nested blocks and loops from innermost to outermost.
369+ int blk_counter = 0 ;
370+ while (!block_iters_st.empty ()) {
371+ Array<IterVar> block_iters = std::move (block_iters_st.top ());
372+ Array<PrimExpr> iter_bindings = std::move (iter_bindings_st.top ());
373+ Array<Var> loop_vars = std::move (loop_vars_st.top ());
374+ bool place_init = place_init_st.top ();
375+ block_iters_st.pop ();
376+ iter_bindings_st.pop ();
377+ loop_vars_st.pop ();
378+ place_init_st.pop ();
379+
380+ Map<String, ObjectRef> mapping;
381+ mapping.Set (" sparse" , Bool (true ));
382+ String blk_name_hint = sp_block->name ;
383+ if (blk_counter != 0 ) {
384+ blk_name_hint = blk_name_hint + " _" + std::to_string (blk_counter);
385+ }
386+ Block block (/* iter_vars=*/ block_iters,
387+ /* reads=*/ reads,
388+ /* writes=*/ writes,
389+ /* name_hint=*/ blk_name_hint,
390+ /* body=*/ std::move (body),
391+ /* init=*/ place_init ? std::move (init) : NullOpt,
392+ /* alloc_buffers=*/ {},
393+ /* match_buffers=*/ {},
394+ /* annotations=*/ std::move (mapping),
395+ /* span=*/ sp_block->span );
396+ BlockRealize block_realize (std::move (iter_bindings), const_true (), std::move (block));
397+ // Generate outer loop and the block binding.
398+ Stmt loop = GenerateLoops (std::move (block_realize), block_iters, loop_vars);
399+ body = loop;
400+ blk_counter += 1 ;
401+ }
319402
320- return loop ;
403+ return body ;
321404 }
322405
323406 /* !
@@ -380,9 +463,10 @@ class IndexTransformer : public StmtExprMutator {
380463 }
381464
382465 /* !
383- * \brief generated nested for loops for sparse block.
466+ * \brief Generate nested for-loops for sparse block.
384467 * \param block_iters The iterators defined in sparse blocks.
385468 * \param loop_vars The loop variables bound to block iterators.
469+ * \return The outermost loop.
386470 */
387471 Stmt GenerateLoops (Stmt body, const Array<IterVar>& block_iters, const Array<Var>& loop_vars) {
388472 int n_iter = static_cast <int >(block_iters.size ());
@@ -394,6 +478,7 @@ class IndexTransformer : public StmtExprMutator {
394478 }
395479
396480 AccessAndDependencyCollector collector_;
481+ AxisTree axis_tree_;
397482 arith::Analyzer ana_;
398483 std::unordered_set<const SparseBufferNode*> buffer_read_;
399484 std::unordered_set<const SparseBufferNode*> buffer_write_;
@@ -411,11 +496,12 @@ Stmt WrapWithRootBlock(Stmt body) {
411496}
412497
413498/* !
414- * \brief Rewrite the given primitive function
499+ * \brief Rewrite the given primitive function.
500+ * \param axis_tree The axis dependency tree.
415501 * \param f The Sparse-TIR primitive function to lower.
416502 * \return lowered primitive function in TIR.
417503 */
418- PrimFunc LowerSparseTIR (PrimFunc f) {
504+ PrimFunc LowerSparseTIR (AxisTree axis_tree, PrimFunc f) {
419505 // Only apply this pass to TIR that is not from TE schedules
420506 if (!IsFromLegacyTESchedule (f)) {
421507 PrimFuncNode* fptr = f.CopyOnWrite ();
@@ -425,7 +511,7 @@ PrimFunc LowerSparseTIR(PrimFunc f) {
425511 AccessAndDependencyCollector collector;
426512 collector.Collect (f->body );
427513 // Step 3. Lower indices.
428- fptr->body = IndexTransformer (collector)(std::move (f->body ));
514+ fptr->body = IndexTransformer (collector, axis_tree )(std::move (f->body ));
429515 // Step 4. Wrap the function body with a root block.
430516 fptr->body = WrapWithRootBlock (std::move (fptr->body ));
431517 return f;
@@ -438,10 +524,11 @@ namespace transform {
438524
439525/* !
440526 * \brief The lowering pass from TIR to Sparse TIR.
527+ * \param axis_tree The axis dependency tree.
441528 */
442- Pass LowerSparseTIR () {
529+ Pass LowerSparseTIR (AxisTree axis_tree ) {
443530 auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
444- return LowerSparseTIR (std::move (f));
531+ return LowerSparseTIR (std::move (axis_tree), std::move ( f));
445532 };
446533 return CreatePrimFuncPass (pass_func, 0 , " tir.LowerSparseTIR" , {});
447534}
0 commit comments