2727#include < tvm/tir/transform.h>
2828
2929#include < set>
30- #include < stack>
3130#include < utility>
3231
3332#include " ../../support/utils.h"
@@ -521,84 +520,84 @@ class IndexTransformer : public StmtExprMutator {
521520 var_map[sp_iter_var->var .get ()] = loop_var;
522521 }
523522
524- // Step 4. Collet block iters and iter bindings.
525- std::set<const AxisNode*> in_stack;
523+ // Step 4. Collect block iters and iter bindings.
524+ /* Whether the axis appears in the stack. */
525+ std::unordered_set<const AxisNode*> in_stack;
526526 /* A stack that stores block itervars in each block. */
527- std::stack <Array<IterVar>> block_iters_st;
527+ std::vector <Array<IterVar>> block_iters_st;
528528 /* A stack that stores itervar bindings in each block. */
529- std::stack <Array<PrimExpr>> iter_bindings_st;
529+ std::vector <Array<PrimExpr>> iter_bindings_st;
530530 /* A stack that stores generated loop vars in each block. */
531- std::stack <Array<Var>> loop_vars_st;
531+ std::vector <Array<Var>> loop_vars_st;
532532 /* A stack that stores whether to place init block in each block. */
533- std::stack <bool > place_init_st;
533+ std::vector <bool > place_init_st;
534534 /* An indicator that records whether init block has been set. */
535535 bool init_set = false ;
536- do {
537- /* Block itervars of current block. */
538- Array<IterVar> block_iters;
539- /* Itervar bindings of current block. */
540- Array<PrimExpr> iter_bindings;
541- /* Axis names of current block. */
542- Array<Axis> blk_axes;
543- /* Generated loop vars of current block. */
544- Array<Var> loop_vars;
545- /* An indicator that records whether there is reduction axis in current block. */
546- bool has_reduction_var = false ;
547- for (int i = 0 ; i < n_iter; ++i) {
548- SpIterVar sp_it_var = sp_block->sp_iter_vars [i];
549- Axis axis = sp_it_var->axis ;
550-
551- /* Add itervar to current block when
552- * - it's not used yet (not in stack) and
553- * - it's parent axis was used in outer blocks or
554- * - it's an iterator to a fixed axis.
555- */
556- auto parent = axis->GetParentAxis ();
557- bool emit_iter_var = true ;
558- if (in_stack.find (axis.get ()) !=
559- in_stack.end ()) { // the iter var has already been emitted.
560- emit_iter_var = false ;
536+ /* Block itervars of current block. */
537+ Array<IterVar> block_iters;
538+ /* Itervar bindings of current block. */
539+ Array<PrimExpr> iter_bindings;
540+ /* Generated loop vars of current block. */
541+ Array<Var> loop_vars;
542+ /* Whether the axis appears in the cuurent block. */
543+ std::unordered_set<const AxisNode*> in_block;
544+ /* An indicator that records whether there is reduction axis in current block. */
545+ bool has_reduction_var = false ;
546+
547+ auto UpdateStack = [&]() {
548+ block_iters_st.emplace_back (std::move (block_iters));
549+ iter_bindings_st.emplace_back (std::move (iter_bindings));
550+ loop_vars_st.emplace_back (std::move (loop_vars));
551+ if (init_set) {
552+ place_init_st.emplace_back (false );
553+ } else {
554+ place_init_st.emplace_back (has_reduction_var);
555+ init_set |= has_reduction_var;
556+ }
557+ };
558+
559+ for (int i = 0 ; i < n_iter; ++i) {
560+ SpIterVar sp_it_var = sp_block->sp_iter_vars [i];
561+ Axis axis = sp_it_var->axis ;
562+ auto parent = axis->GetParentAxis ();
563+ bool create_new_blk = false ;
564+ bool is_fixed_axis = axis->kind () == AxisKind::kDenseFixed || axis->kind () == AxisKind::kSparseFixed ;
565+ if (!is_fixed_axis && parent.defined ()) {
566+ const AxisNode* parent_node = parent.value ().get ();
567+ if (in_block.find (parent_node) != in_block.end ()) {
568+ /* parent node is in the current block, need to create new block. */
569+ create_new_blk = true ;
570+ } else if (in_stack.find (parent_node) != in_stack.end ()) {
571+ /* parent node is in the previous blocks in the stack, no need to create new block. */
572+ create_new_blk = false ;
561573 } else {
562- if (parent.defined ()) { // has parent
563- if (in_stack.find (parent.value ().get ()) == in_stack.end ()) { // parent not emitted yet
564- if (axis->kind () == AxisKind::kDenseVariable ||
565- axis->kind () == AxisKind::kSparseVariable ) { // is not fixed axis.
566- emit_iter_var = false ;
567- }
568- }
569- }
574+ CHECK (false ) << " The parent axis of " << axis->GetName () << " should appear before " << axis->GetName () << " when defining a sparse block." ;
570575 }
571- // LOG(INFO) << axis->name << " " << (parent.defined() ? parent.value()->name : "no-parent")
572- // << " " << emit_iter_var;
573- if (emit_iter_var) {
574- loop_vars.push_back (all_loop_vars[i]);
575- blk_axes.push_back (axis);
576- block_iters.push_back (SpIterVarToIterVar (sp_it_var, var_map));
577- iter_bindings.push_back (all_loop_vars[i]);
578- has_reduction_var |= sp_it_var->is_reduction ;
576+ }
577+ if (create_new_blk) {
578+ /* update in stack set. */
579+ for (const AxisNode* node : in_block) {
580+ in_stack.insert (node);
579581 }
582+ /* Update stack. */
583+ UpdateStack ();
584+ /* Reset block states. */
585+ loop_vars = {};
586+ block_iters = {};
587+ iter_bindings = {};
588+ has_reduction_var = false ;
589+ in_block.clear ();
580590 }
581591
582- /* Tag axes in current block as "in-stack". */
583- for (const Axis&& axis : blk_axes) {
584- in_stack.insert (axis.get ());
585- }
592+ loop_vars.push_back (all_loop_vars[i]);
593+ block_iters.push_back (SpIterVarToIterVar (sp_it_var, var_map));
594+ iter_bindings.push_back (all_loop_vars[i]);
595+ has_reduction_var |= sp_it_var->is_reduction ;
596+ in_block.insert (axis.get ());
597+ }
586598
587- /* Update stack. */
588- if (!block_iters.empty ()) {
589- block_iters_st.push (std::move (block_iters));
590- iter_bindings_st.push (std::move (iter_bindings));
591- loop_vars_st.push (std::move (loop_vars));
592- if (init_set) {
593- place_init_st.push (false );
594- } else {
595- place_init_st.push (has_reduction_var);
596- init_set |= has_reduction_var;
597- }
598- } else {
599- break ;
600- }
601- } while (true );
599+ // Update the last block.
600+ UpdateStack ();
602601
603602 // Step 5. Generate the read-region and write-retion of the block.
604603 Array<BufferRegion> reads{};
@@ -608,14 +607,14 @@ class IndexTransformer : public StmtExprMutator {
608607 // Step 6. Generate nested blocks and loops from innermost to outermost.
609608 int blk_counter = 0 ;
610609 while (!block_iters_st.empty ()) {
611- Array<IterVar> block_iters = std::move (block_iters_st.top ());
612- Array<PrimExpr> iter_bindings = std::move (iter_bindings_st.top ());
613- Array<Var> loop_vars = std::move (loop_vars_st.top ());
614- bool place_init = place_init_st.top ();
615- block_iters_st.pop ();
616- iter_bindings_st.pop ();
617- loop_vars_st.pop ();
618- place_init_st.pop ();
610+ Array<IterVar> block_iters = std::move (block_iters_st.back ());
611+ Array<PrimExpr> iter_bindings = std::move (iter_bindings_st.back ());
612+ Array<Var> loop_vars = std::move (loop_vars_st.back ());
613+ bool place_init = place_init_st.back ();
614+ block_iters_st.pop_back ();
615+ iter_bindings_st.pop_back ();
616+ loop_vars_st.pop_back ();
617+ place_init_st.pop_back ();
619618
620619 Map<String, ObjectRef> mapping;
621620 mapping.Set (" sparse" , Bool (true ));
0 commit comments