@@ -95,18 +95,21 @@ class PostOrderApplyNode : public SpaceGeneratorNode {
95
95
96
96
Array<tir::Schedule> GenerateDesignSpace (const IRModule& mod_) final {
97
97
using ScheduleAndUnvisitedBlocks = std::pair<tir::Schedule, Array<tir::BlockRV>>;
98
- tir::Schedule sch = tir::Schedule::Traced ( //
99
- /* mod=*/ mod_, //
100
- /* rand_state=*/ ForkSeed (&this ->rand_state_ ), //
101
- /* debug_mode=*/ tir::kVerifySRefTree | tir::kVerifyCachedFlags , //
98
+ tir::Schedule sch = tir::Schedule::Traced ( //
99
+ /* mod=*/ mod_, //
100
+ /* rand_state=*/ ForkSeed (&this ->rand_state_ ), //
101
+ /* debug_mode=*/ 0 , // tir::kVerifySRefTree | tir::kVerifyCachedFlags
102
102
/* error_render_level=*/ tir::ScheduleErrorRenderLevel::kDetail );
103
103
104
104
std::vector<ScheduleAndUnvisitedBlocks> stack;
105
105
Array<tir::Schedule> result;
106
106
Array<tir::BlockRV> all_blocks = BlockCollector::Collect (sch), func_blocks, non_func_blocks;
107
107
for (const tir::BlockRV& block_rv : all_blocks) {
108
- if (tir::GetAnn<String>(sch->GetSRef (block_rv), " schedule_rule" )) {
109
- func_blocks.push_back (block_rv);
108
+ if (Optional<String> custom_sch_rule_name_opt =
109
+ tir::GetAnn<String>(sch->GetSRef (block_rv), " schedule_rule" )) {
110
+ if (custom_sch_rule_name_opt.value () != " None" ) {
111
+ func_blocks.push_back (block_rv);
112
+ }
110
113
} else {
111
114
non_func_blocks.push_back (block_rv);
112
115
}
@@ -130,21 +133,19 @@ class PostOrderApplyNode : public SpaceGeneratorNode {
130
133
blocks.pop_back ();
131
134
if (sch->HasBlock (block_rv)) {
132
135
// pick out the blocks with annotation for customized search space
133
- Optional<ObjectRef > custom_sch_rule_name_opt =
136
+ Optional<String > custom_sch_rule_name_opt =
134
137
tir::GetAnn<String>(sch->GetSRef (block_rv), " schedule_rule" );
135
- ICHECK (custom_sch_rule_name_opt.defined ());
136
- String custom_sch_rule_name = Downcast<String>(custom_sch_rule_name_opt.value ());
137
- if (custom_sch_rule_name != " None" ) {
138
- const auto * custom_sch_rule_func = runtime::Registry::Get (custom_sch_rule_name);
139
- CHECK (custom_sch_rule_func) << " The given custom schedule function is not defined!" ;
140
- Array<tir::Schedule> applied = (*custom_sch_rule_func)(sch, block_rv);
141
- for (const tir::Schedule& sch : applied) {
142
- stack.emplace_back (sch, blocks);
143
- }
144
- continue ;
138
+ ICHECK (custom_sch_rule_name_opt.defined () && custom_sch_rule_name_opt.value () != " None" );
139
+ String custom_sch_rule_name = custom_sch_rule_name_opt.value ();
140
+ const auto * custom_sch_rule_func = runtime::Registry::Get (custom_sch_rule_name);
141
+ CHECK (custom_sch_rule_func) << " The given custom schedule function is not defined!" ;
142
+ Array<tir::Schedule> applied = (*custom_sch_rule_func)(sch, block_rv);
143
+ for (const tir::Schedule& sch : applied) {
144
+ stack.emplace_back (sch, blocks);
145
145
}
146
+ } else {
147
+ stack.emplace_back (sch, blocks);
146
148
}
147
- stack.emplace_back (sch, blocks);
148
149
}
149
150
150
151
// Enumerate the schedule rules first because you can
0 commit comments