2626namespace tvm {
2727namespace meta_schedule {
2828
29+ /* *************** TIR Nodes ****************/
30+ using tir::BlockNode;
31+ using tir::ForNode;
32+
2933bool IsTrivialBinding (const tir::ScheduleState& self, const tir::StmtSRef& block_sref) {
30- const auto * block = TVM_SREF_TO_BLOCK (block, block_sref);
34+ const BlockNode * block = TVM_SREF_TO_BLOCK (block, block_sref);
3135 tir::BlockRealize realize = tir::GetBlockRealize (block_sref);
3236 Array<tir::StmtSRef> loops = tir::GetLoops (block_sref);
3337 const Array<PrimExpr>& bindings = realize->iter_values ;
@@ -37,7 +41,7 @@ bool IsTrivialBinding(const tir::ScheduleState& self, const tir::StmtSRef& block
3741 int n = loops.size ();
3842 for (int i = 0 ; i < n; ++i) {
3943 const PrimExpr& bind = bindings[i];
40- const auto * loop = TVM_SREF_TO_FOR (loop, loops[i]);
44+ const ForNode * loop = TVM_SREF_TO_FOR (loop, loops[i]);
4145 if (bind.as <tir::VarNode>() != loop->loop_var .get ()) {
4246 return false ;
4347 }
@@ -51,7 +55,7 @@ bool IsSubrootBlock(const tir::ScheduleState& self, const tir::StmtSRef& block_s
5155}
5256
5357bool IsLeafBlock (const tir::ScheduleState& self, const tir::StmtSRef& block_sref) {
54- const auto * block = TVM_SREF_TO_BLOCK (block, block_sref);
58+ const BlockNode * block = TVM_SREF_TO_BLOCK (block, block_sref);
5559 bool no_child = true ;
5660 tir::PreOrderVisit (block->body , [&no_child](const ObjectRef& obj) -> bool {
5761 if (!no_child) {
@@ -67,7 +71,7 @@ bool IsLeafBlock(const tir::ScheduleState& self, const tir::StmtSRef& block_sref
6771}
6872
6973Array<Integer> GetBlockVarTypes (const tir::ScheduleState& self, const tir::StmtSRef& block_sref) {
70- const auto * block = TVM_SREF_TO_BLOCK (block, block_sref);
74+ const BlockNode * block = TVM_SREF_TO_BLOCK (block, block_sref);
7175 Array<Integer> result;
7276 for (const tir::IterVar& iter_var : block->iter_vars ) {
7377 int iter_type = iter_var->iter_type ;
@@ -77,7 +81,7 @@ Array<Integer> GetBlockVarTypes(const tir::ScheduleState& self, const tir::StmtS
7781}
7882
7983bool IsSpatial (const tir::ScheduleState& self, const tir::StmtSRef& block_sref) {
80- const auto * block = TVM_SREF_TO_BLOCK (block, block_sref);
84+ const BlockNode * block = TVM_SREF_TO_BLOCK (block, block_sref);
8185 for (const tir::IterVar& iter_var : block->iter_vars ) {
8286 if (iter_var->iter_type != tir::IterVarType::kDataPar ) {
8387 return false ;
@@ -88,8 +92,8 @@ bool IsSpatial(const tir::ScheduleState& self, const tir::StmtSRef& block_sref)
8892
8993bool IsOutputBlock (const tir::ScheduleState& self, const tir::StmtSRef& block_sref) {
9094 tir::StmtSRef parent_sref = tir::GetScopeRoot (block_sref).value ();
91- const auto * block = TVM_SREF_TO_BLOCK (block, block_sref);
92- const auto * parent = TVM_SREF_TO_BLOCK (parent, parent_sref);
95+ const BlockNode * block = TVM_SREF_TO_BLOCK (block, block_sref);
96+ const BlockNode * parent = TVM_SREF_TO_BLOCK (parent, parent_sref);
9397 if (parent_sref->parent == nullptr ) {
9498 const tir::PrimFuncNode* func = tir::GetRootPrimFunc (self, parent_sref);
9599 for (const tir::BufferRegion& write : block->writes ) {
@@ -112,7 +116,7 @@ bool IsOutputBlock(const tir::ScheduleState& self, const tir::StmtSRef& block_sr
112116}
113117
114118int CountOp (const tir::ScheduleState& self, const tir::StmtSRef& block_sref, const Op& op) {
115- const auto * block = TVM_SREF_TO_BLOCK (block, block_sref);
119+ const BlockNode * block = TVM_SREF_TO_BLOCK (block, block_sref);
116120 int count = 0 ;
117121 tir::PostOrderVisit (block->body , [&count, &op](const ObjectRef& obj) {
118122 if (const auto * call = obj.as <tir::CallNode>()) {
@@ -125,7 +129,7 @@ int CountOp(const tir::ScheduleState& self, const tir::StmtSRef& block_sref, con
125129}
126130
127131bool HasBranch (const tir::ScheduleState& self, const tir::StmtSRef& block_sref) {
128- const auto * block = TVM_SREF_TO_BLOCK (block, block_sref);
132+ const BlockNode * block = TVM_SREF_TO_BLOCK (block, block_sref);
129133 bool has_branch = false ;
130134 arith::Analyzer analyzer;
131135 auto f_visit = [&has_branch, &analyzer](const ObjectRef& obj) -> bool {
@@ -214,8 +218,8 @@ bool IsElementWiseMatch(const tir::ScheduleState& self, const tir::StmtSRef& pro
214218 const tir::StmtSRef& consumer_sref) {
215219 // Assume consumer is the only consumer of the producer
216220 tir::StmtSRef parent_sref = tir::GetScopeRoot (producer_sref).value ();
217- const auto * producer = TVM_SREF_TO_BLOCK (producer, producer_sref);
218- const auto * consumer = TVM_SREF_TO_BLOCK (consumer, consumer_sref);
221+ const BlockNode * producer = TVM_SREF_TO_BLOCK (producer, producer_sref);
222+ const BlockNode * consumer = TVM_SREF_TO_BLOCK (consumer, consumer_sref);
219223 if (producer->writes .empty ()) {
220224 return false ;
221225 }
@@ -285,7 +289,7 @@ bool NeedsMultiLevelTiling(const tir::ScheduleState& self, const tir::StmtSRef&
285289 if (!IsTrivialBinding (self, block_sref)) {
286290 return false ;
287291 }
288- const auto * block = TVM_SREF_TO_BLOCK (block, block_sref);
292+ const BlockNode * block = TVM_SREF_TO_BLOCK (block, block_sref);
289293 // Assume complete/reduction block
290294 if (block->writes .size () != 1 ) {
291295 return false ;
@@ -333,7 +337,7 @@ bool NeedsMultiLevelTiling(const tir::ScheduleState& self, const tir::StmtSRef&
333337
334338bool IsStrictlyInlineable (const tir::ScheduleState& self, const tir::StmtSRef& block_sref) {
335339 static const Op& op_tir_exp = Op::Get (" tir.exp" );
336- const auto * block = TVM_SREF_TO_BLOCK (block, block_sref);
340+ const BlockNode * block = TVM_SREF_TO_BLOCK (block, block_sref);
337341 // Const tensors are strictly inlineable
338342 if (block->reads .empty ()) {
339343 return true ;
@@ -764,7 +768,7 @@ bool NeedsRFactorOrCrossThreadReduction(const tir::ScheduleState& self,
764768 const tir::StmtSRef& block_sref,
765769 int64_t max_parallel_extent,
766770 int64_t basic_parallel_extent) {
767- const auto * block = TVM_SREF_TO_BLOCK (block, block_sref);
771+ const BlockNode * block = TVM_SREF_TO_BLOCK (block, block_sref);
768772 Array<tir::StmtSRef> loops = tir::GetLoops (block_sref);
769773
770774 // Cond 1. The block is a reduction block and has trivial binding.
@@ -791,9 +795,9 @@ bool NeedsRFactorOrCrossThreadReduction(const tir::ScheduleState& self,
791795 }
792796
793797 // Cond 4.
794- const auto * loop_i = TVM_SREF_TO_FOR (loop_i, loops[i]);
798+ const ForNode * loop_i = TVM_SREF_TO_FOR (loop_i, loops[i]);
795799 if (i < static_cast <int >(loops.size ()) - 1 ) {
796- const auto * loop_i1 = TVM_SREF_TO_FOR (loop_i1, loops[i + 1 ]);
800+ const ForNode * loop_i1 = TVM_SREF_TO_FOR (loop_i1, loops[i + 1 ]);
797801 if (loop_i->body .get () != loop_i1) {
798802 return false ;
799803 }
0 commit comments