Skip to content

Commit d9c6c5d

Browse files
yzh119Siyuan Feng
authored andcommitted
Explicitly set types for TVM_SREF_TO_ERROR/FOR/BLOCK (apache#406)
1 parent d52d50d commit d9c6c5d

File tree

11 files changed

+50
-39
lines changed

11 files changed

+50
-39
lines changed

src/meta_schedule/analysis.cc

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,12 @@
2626
namespace tvm {
2727
namespace meta_schedule {
2828

29+
/**************** TIR Nodes ****************/
30+
using tir::BlockNode;
31+
using tir::ForNode;
32+
2933
bool IsTrivialBinding(const tir::ScheduleState& self, const tir::StmtSRef& block_sref) {
30-
const auto* block = TVM_SREF_TO_BLOCK(block, block_sref);
34+
const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
3135
tir::BlockRealize realize = tir::GetBlockRealize(block_sref);
3236
Array<tir::StmtSRef> loops = tir::GetLoops(block_sref);
3337
const Array<PrimExpr>& bindings = realize->iter_values;
@@ -37,7 +41,7 @@ bool IsTrivialBinding(const tir::ScheduleState& self, const tir::StmtSRef& block
3741
int n = loops.size();
3842
for (int i = 0; i < n; ++i) {
3943
const PrimExpr& bind = bindings[i];
40-
const auto* loop = TVM_SREF_TO_FOR(loop, loops[i]);
44+
const ForNode* loop = TVM_SREF_TO_FOR(loop, loops[i]);
4145
if (bind.as<tir::VarNode>() != loop->loop_var.get()) {
4246
return false;
4347
}
@@ -51,7 +55,7 @@ bool IsSubrootBlock(const tir::ScheduleState& self, const tir::StmtSRef& block_s
5155
}
5256

5357
bool IsLeafBlock(const tir::ScheduleState& self, const tir::StmtSRef& block_sref) {
54-
const auto* block = TVM_SREF_TO_BLOCK(block, block_sref);
58+
const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
5559
bool no_child = true;
5660
tir::PreOrderVisit(block->body, [&no_child](const ObjectRef& obj) -> bool {
5761
if (!no_child) {
@@ -67,7 +71,7 @@ bool IsLeafBlock(const tir::ScheduleState& self, const tir::StmtSRef& block_sref
6771
}
6872

6973
Array<Integer> GetBlockVarTypes(const tir::ScheduleState& self, const tir::StmtSRef& block_sref) {
70-
const auto* block = TVM_SREF_TO_BLOCK(block, block_sref);
74+
const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
7175
Array<Integer> result;
7276
for (const tir::IterVar& iter_var : block->iter_vars) {
7377
int iter_type = iter_var->iter_type;
@@ -77,7 +81,7 @@ Array<Integer> GetBlockVarTypes(const tir::ScheduleState& self, const tir::StmtS
7781
}
7882

7983
bool IsSpatial(const tir::ScheduleState& self, const tir::StmtSRef& block_sref) {
80-
const auto* block = TVM_SREF_TO_BLOCK(block, block_sref);
84+
const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
8185
for (const tir::IterVar& iter_var : block->iter_vars) {
8286
if (iter_var->iter_type != tir::IterVarType::kDataPar) {
8387
return false;
@@ -88,8 +92,8 @@ bool IsSpatial(const tir::ScheduleState& self, const tir::StmtSRef& block_sref)
8892

8993
bool IsOutputBlock(const tir::ScheduleState& self, const tir::StmtSRef& block_sref) {
9094
tir::StmtSRef parent_sref = tir::GetScopeRoot(block_sref).value();
91-
const auto* block = TVM_SREF_TO_BLOCK(block, block_sref);
92-
const auto* parent = TVM_SREF_TO_BLOCK(parent, parent_sref);
95+
const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
96+
const BlockNode* parent = TVM_SREF_TO_BLOCK(parent, parent_sref);
9397
if (parent_sref->parent == nullptr) {
9498
const tir::PrimFuncNode* func = tir::GetRootPrimFunc(self, parent_sref);
9599
for (const tir::BufferRegion& write : block->writes) {
@@ -112,7 +116,7 @@ bool IsOutputBlock(const tir::ScheduleState& self, const tir::StmtSRef& block_sr
112116
}
113117

114118
int CountOp(const tir::ScheduleState& self, const tir::StmtSRef& block_sref, const Op& op) {
115-
const auto* block = TVM_SREF_TO_BLOCK(block, block_sref);
119+
const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
116120
int count = 0;
117121
tir::PostOrderVisit(block->body, [&count, &op](const ObjectRef& obj) {
118122
if (const auto* call = obj.as<tir::CallNode>()) {
@@ -125,7 +129,7 @@ int CountOp(const tir::ScheduleState& self, const tir::StmtSRef& block_sref, con
125129
}
126130

127131
bool HasBranch(const tir::ScheduleState& self, const tir::StmtSRef& block_sref) {
128-
const auto* block = TVM_SREF_TO_BLOCK(block, block_sref);
132+
const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
129133
bool has_branch = false;
130134
arith::Analyzer analyzer;
131135
auto f_visit = [&has_branch, &analyzer](const ObjectRef& obj) -> bool {
@@ -214,8 +218,8 @@ bool IsElementWiseMatch(const tir::ScheduleState& self, const tir::StmtSRef& pro
214218
const tir::StmtSRef& consumer_sref) {
215219
// Assume consumer is the only consumer of the producer
216220
tir::StmtSRef parent_sref = tir::GetScopeRoot(producer_sref).value();
217-
const auto* producer = TVM_SREF_TO_BLOCK(producer, producer_sref);
218-
const auto* consumer = TVM_SREF_TO_BLOCK(consumer, consumer_sref);
221+
const BlockNode* producer = TVM_SREF_TO_BLOCK(producer, producer_sref);
222+
const BlockNode* consumer = TVM_SREF_TO_BLOCK(consumer, consumer_sref);
219223
if (producer->writes.empty()) {
220224
return false;
221225
}
@@ -285,7 +289,7 @@ bool NeedsMultiLevelTiling(const tir::ScheduleState& self, const tir::StmtSRef&
285289
if (!IsTrivialBinding(self, block_sref)) {
286290
return false;
287291
}
288-
const auto* block = TVM_SREF_TO_BLOCK(block, block_sref);
292+
const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
289293
// Assume complete/reduction block
290294
if (block->writes.size() != 1) {
291295
return false;
@@ -333,7 +337,7 @@ bool NeedsMultiLevelTiling(const tir::ScheduleState& self, const tir::StmtSRef&
333337

334338
bool IsStrictlyInlineable(const tir::ScheduleState& self, const tir::StmtSRef& block_sref) {
335339
static const Op& op_tir_exp = Op::Get("tir.exp");
336-
const auto* block = TVM_SREF_TO_BLOCK(block, block_sref);
340+
const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
337341
// Const tensors are strictly inlineable
338342
if (block->reads.empty()) {
339343
return true;
@@ -764,7 +768,7 @@ bool NeedsRFactorOrCrossThreadReduction(const tir::ScheduleState& self,
764768
const tir::StmtSRef& block_sref,
765769
int64_t max_parallel_extent,
766770
int64_t basic_parallel_extent) {
767-
const auto* block = TVM_SREF_TO_BLOCK(block, block_sref);
771+
const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
768772
Array<tir::StmtSRef> loops = tir::GetLoops(block_sref);
769773

770774
// Cond 1. The block is a reduction block and has trivial binding.
@@ -791,9 +795,9 @@ bool NeedsRFactorOrCrossThreadReduction(const tir::ScheduleState& self,
791795
}
792796

793797
// Cond 4.
794-
const auto* loop_i = TVM_SREF_TO_FOR(loop_i, loops[i]);
798+
const ForNode* loop_i = TVM_SREF_TO_FOR(loop_i, loops[i]);
795799
if (i < static_cast<int>(loops.size()) - 1) {
796-
const auto* loop_i1 = TVM_SREF_TO_FOR(loop_i1, loops[i + 1]);
800+
const ForNode* loop_i1 = TVM_SREF_TO_FOR(loop_i1, loops[i + 1]);
797801
if (loop_i->body.get() != loop_i1) {
798802
return false;
799803
}

src/meta_schedule/space/search_rule.cc

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,10 @@
2626
namespace tvm {
2727
namespace meta_schedule {
2828

29+
/**************** TIR Nodes ****************/
30+
using tir::ForNode;
31+
using tir::BlockNode;
32+
2933
/********** Constructors **********/
3034

3135
SearchRule::SearchRule(String name, SearchRuleNode::FApply apply) {
@@ -962,7 +966,7 @@ class RuleAddRFactor {
962966
const BlockRV& block_rv) const {
963967
// Check the conditions of the rule.
964968
tir::StmtSRef block_sref = sch->GetSRef(block_rv);
965-
const auto* block = TVM_SREF_TO_BLOCK(block, block_sref);
969+
const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
966970
if (HasAnyAnn(block_sref)) {
967971
return {sch};
968972
}
@@ -1041,7 +1045,7 @@ class RuleCrossThreadReduction {
10411045

10421046
// Check the conditions of the rule.
10431047
const tir::StmtSRef& block_sref = sch->GetSRef(block_rv);
1044-
const auto* block = TVM_SREF_TO_BLOCK(block, block_sref);
1048+
const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
10451049
if (HasAnyAnn(block_sref)) {
10461050
return {sch};
10471051
}

src/meta_schedule/utils.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,9 @@ using tir::AsArray;
4242
using tir::AsOptArray;
4343
using tir::AsVector;
4444

45+
/**************** TIR Nodes ****************/
46+
using tir::ForNode;
47+
4548
/*!
4649
* \brief Compute mean of a FloatImm array.
4750
* Taken from Ansor
@@ -179,7 +182,7 @@ inline Optional<tir::StmtSRef> FindBlockSRef(const tir::ScheduleState& sch, FPre
179182
/**************** TIR Annotation ****************/
180183

181184
inline bool HasBinding(const tir::StmtSRef& loop_sref, const String& thread_tag) {
182-
const auto* loop = TVM_SREF_TO_FOR(loop, loop_sref);
185+
const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
183186
if (!loop->thread_binding.defined()) {
184187
return false;
185188
}
@@ -191,7 +194,7 @@ inline bool HasBinding(const tir::StmtSRef& loop_sref, const String& thread_tag)
191194
}
192195

193196
inline Optional<String> GetBinding(const tir::StmtSRef& loop_sref) {
194-
const auto* loop = TVM_SREF_TO_FOR(loop, loop_sref);
197+
const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
195198
if (!loop->thread_binding.defined()) {
196199
return NullOpt;
197200
}

src/tir/schedule/analysis/analysis.cc

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ bool IsCompleteBlock(const ScheduleState& self, const StmtSRef& block_sref,
110110
const StmtSRef& scope_root) {
111111
BlockScope scope = self->GetBlockScope(scope_root);
112112
// Cond 1. All block vars are data parallel
113-
const auto* block = TVM_SREF_TO_BLOCK(block, block_sref);
113+
const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
114114
for (const IterVar& iter_var : block->iter_vars) {
115115
if (iter_var->iter_type != kDataPar) {
116116
return false;
@@ -515,7 +515,7 @@ bool RegionCoveredConsumer(const ScheduleState& self, const StmtSRef& consumer_b
515515
if (consumer_block_sref->parent == nullptr) {
516516
return true;
517517
}
518-
const auto* consumer_block = TVM_SREF_TO_BLOCK(consumer_block, consumer_block_sref);
518+
const BlockNode* consumer_block = TVM_SREF_TO_BLOCK(consumer_block, consumer_block_sref);
519519
BlockScope scope = self->GetBlockScope(scope_root);
520520
// Step 1. Gather all the producers
521521
struct Producer {
@@ -534,7 +534,7 @@ bool RegionCoveredConsumer(const ScheduleState& self, const StmtSRef& consumer_b
534534
// i.e. the RAW predecessor is producer
535535
if (edge->kind == DepKind::kRAW) {
536536
const StmtSRef& producer_block_sref = edge->src;
537-
const auto* producer_block = TVM_SREF_TO_BLOCK(producer_block, producer_block_sref);
537+
const BlockNode* producer_block = TVM_SREF_TO_BLOCK(producer_block, producer_block_sref);
538538
for (const BufferRegion& output_region : producer_block->writes) {
539539
const VarNode* buffer_var = output_region->buffer->data.get();
540540
buffer_producers[buffer_var].emplace_back(producer_block_sref, output_region);
@@ -808,7 +808,7 @@ bool CompleteBlock(const ScheduleState& self, const StmtSRef& block_sref,
808808
const StmtSRef& scope_root) {
809809
BlockScope scope = self->GetBlockScope(scope_root);
810810
// Cond 2. Check if all the block vars are data parallel
811-
const auto* block = TVM_SREF_TO_BLOCK(block, block_sref);
811+
const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
812812
for (const IterVar& iter_var : block->iter_vars) {
813813
if (iter_var->iter_type != kDataPar) {
814814
return false;
@@ -839,7 +839,7 @@ bool ReductionBlock(const ScheduleState& self, const StmtSRef& block_sref,
839839
// return false;
840840
// }
841841
// Cond 4. Check whether the block body has the init statement.
842-
const auto* block = TVM_SREF_TO_BLOCK(block, block_sref);
842+
const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
843843
if (!block->init.defined()) {
844844
return false;
845845
}
@@ -885,8 +885,8 @@ bool ReductionBlock(const ScheduleState& self, const StmtSRef& block_sref,
885885
bool CanMergeReduction(const ScheduleState& self, const StmtSRef& init_block_sref,
886886
const StmtSRef& update_block_sref, const StmtSRef& scope_root) {
887887
BlockScope scope = self->GetBlockScope(scope_root);
888-
const auto* init = TVM_SREF_TO_BLOCK(init, init_block_sref);
889-
const auto* update = TVM_SREF_TO_BLOCK(update, update_block_sref);
888+
const BlockNode* init = TVM_SREF_TO_BLOCK(init, init_block_sref);
889+
const BlockNode* update = TVM_SREF_TO_BLOCK(update, update_block_sref);
890890
// Cond 1. Check the binding of update block is valid
891891
if (!self->IsAffineBlockBinding(update_block_sref)) {
892892
return false;
@@ -939,7 +939,7 @@ IterVarType GetLoopIterType(const ScheduleState& self, const StmtSRef& loop_sref
939939
int n_spatial = 0;
940940
int n_reduce = 0;
941941
int n_other = 0;
942-
const auto* loop = TVM_SREF_TO_FOR(loop, loop_sref);
942+
const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
943943
const Var& loop_var = loop->loop_var;
944944
auto f_visit = [&loop_var, &n_spatial, &n_reduce, &n_other](const ObjectRef& obj) -> bool {
945945
if (const auto* realize = obj.as<BlockRealizeNode>()) {
@@ -993,7 +993,7 @@ Array<StmtSRef> CollectComputeLocation(const ScheduleState& self, const StmtSRef
993993
result.reserve(loop_srefs.size());
994994
bool visited_reduce = false;
995995
for (const StmtSRef& loop_sref : loop_srefs) {
996-
const auto* loop = TVM_SREF_TO_FOR(loop, loop_sref);
996+
const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
997997
IterVarType iter_type = GetLoopIterType(self, loop_sref);
998998
if (iter_type == IterVarType::kDataPar) {
999999
if (visited_reduce) {

src/tir/schedule/concrete_schedule.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -338,7 +338,7 @@ Array<LoopRV> ConcreteScheduleNode::Split(const LoopRV& loop_rv,
338338
TVM_TIR_SCHEDULE_BEGIN();
339339
// Prepare for the splitting
340340
StmtSRef loop_sref = this->GetSRef(loop_rv);
341-
const auto* loop = TVM_SREF_TO_FOR(loop, loop_sref);
341+
const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
342342
PrimExpr len = loop->extent;
343343
// Find out the None
344344
int n = factor_rvs.size();

src/tir/schedule/concrete_schedule.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -215,13 +215,13 @@ class ConcreteScheduleNode : public ScheduleNode {
215215

216216
inline Block ConcreteScheduleNode::Get(const BlockRV& block_rv) const {
217217
StmtSRef sref = this->GetSRef(block_rv);
218-
const auto* block = TVM_SREF_TO_BLOCK(block, sref);
218+
const BlockNode* block = TVM_SREF_TO_BLOCK(block, sref);
219219
return GetRef<Block>(block);
220220
}
221221

222222
inline For ConcreteScheduleNode::Get(const LoopRV& loop_rv) const {
223223
StmtSRef sref = this->GetSRef(loop_rv);
224-
const auto* loop = TVM_SREF_TO_FOR(loop, sref);
224+
const ForNode* loop = TVM_SREF_TO_FOR(loop, sref);
225225
return GetRef<For>(loop);
226226
}
227227

src/tir/schedule/primitive/reduction.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -390,7 +390,7 @@ StmtSRef RFactor(ScheduleState self, const StmtSRef& loop_sref, int factor_axis)
390390
std::unordered_map<const VarNode*, For> loop_vars;
391391
Array<StmtSRef> loops = GetLoops(block_sref);
392392
for (const StmtSRef& l_sref : loops) {
393-
const auto* l = TVM_SREF_TO_FOR(l, l_sref);
393+
const ForNode* l = TVM_SREF_TO_FOR(l, l_sref);
394394
if (l == loop) {
395395
CHECK(!data_par_iters.count(l->loop_var.get()))
396396
<< "ValueError: The rfactor loop cannot be touched by data parallel block vars";
@@ -598,7 +598,7 @@ StmtSRef RFactor(ScheduleState self, const StmtSRef& loop_sref, int factor_axis)
598598
// IR replacement later.
599599
Optional<StmtSRef> replace_top = NullOpt;
600600
for (int i = static_cast<int>(loops.size()) - 1; i >= 0; --i) {
601-
const auto* l = TVM_SREF_TO_FOR(l, loops[i]);
601+
const ForNode* l = TVM_SREF_TO_FOR(l, loops[i]);
602602
if (l->body->IsInstance<SeqStmtNode>()) {
603603
ICHECK_NE(i, static_cast<int>(loops.size()) - 1)
604604
<< "ValueError: The body of the innermost loop must not be a SeqStmt";

src/tir/schedule/primitive/sampling.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ std::vector<int64_t> SamplePerfectTile(tir::ScheduleState self, Sampler* sampler
2626
const tir::StmtSRef& loop_sref, int n,
2727
int max_innermost_factor,
2828
Optional<Array<Integer>>* decision) {
29-
const auto* loop = TVM_SREF_TO_FOR(loop, loop_sref);
29+
const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
3030
int64_t extent = GetLoopIntExtent(loop);
3131
std::vector<int64_t> result;
3232
if (extent == -1) {

src/tir/schedule/state.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1046,7 +1046,7 @@ void ScheduleStateNode::DebugVerify() const {
10461046
/**************** BlockInfo-related ****************/
10471047

10481048
BlockInfo ScheduleStateNode::GetBlockInfo(const StmtSRef& block_sref) const {
1049-
const auto* block = TVM_SREF_TO_BLOCK(block, block_sref);
1049+
const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
10501050
auto it = this->block_info.find(block_sref);
10511051
CHECK(it != this->block_info.end())
10521052
<< "IndexError: Cannot find the corresponding BlockScope to the block sref:\n"

src/tir/schedule/utils.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ BufferRegion SubstituteBufferRegion(const BufferRegion& buffer_region,
173173
}
174174

175175
BlockRealize GetBlockRealize(const StmtSRef& block_sref) {
176-
const auto* block = TVM_SREF_TO_BLOCK(block, block_sref);
176+
const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
177177
// We cannot support getting the BlockRealize of the root block, since the parent sref of the root
178178
// block sref is `nullptr`.
179179
CHECK(block_sref->parent != nullptr)
@@ -443,7 +443,7 @@ void UpdateAffineFlag(ScheduleState self, const StmtSRef& block_sref) {
443443
return;
444444
}
445445
BlockRealize realize = GetBlockRealize(block_sref);
446-
const auto* block = TVM_SREF_TO_BLOCK(block, block_sref);
446+
const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
447447
Map<Var, Range> loop_var_ranges;
448448
for (StmtSRefNode* loop_sref = block_sref->parent; loop_sref != nullptr;
449449
loop_sref = loop_sref->parent) {

0 commit comments

Comments
 (0)