Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 49 additions & 16 deletions src/meta_schedule/space_generator/post_order_apply.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,18 +40,7 @@ class BlockCollector : public tir::StmtVisitor {
blocks_to_collect_.clear();
VisitStmt(func->body);
for (const String& block_name : blocks_to_collect_) {
tir::BlockRV block_rv = sch_->GetBlock(block_name, func_name_);
// pick out the blocks with annotation for customized search space
if (Optional<ObjectRef> custom_sch_rule_name_opt =
tir::GetAnn<String>(sch_->GetSRef(block_rv), "schedule_rule")) {
String custom_sch_rule_name = Downcast<String>(custom_sch_rule_name_opt.value());
if (custom_sch_rule_name != "None") {
const auto* custom_sch_rule_func = runtime::Registry::Get(custom_sch_rule_name);
(*custom_sch_rule_func)(sch_, block_rv);
}
} else {
results_.push_back(block_rv);
}
results_.push_back(sch_->GetBlock(block_name, func_name_));
}
}
}
Expand Down Expand Up @@ -109,17 +98,61 @@ class PostOrderApplyNode : public SpaceGeneratorNode {
tir::Schedule sch = tir::Schedule::Traced( //
/*mod=*/mod_, //
/*rand_state=*/ForkSeed(&this->rand_state_), //
/*debug_mode=*/0, //
/*debug_mode=*/0, // tir::kVerifySRefTree | tir::kVerifyCachedFlags
/*error_render_level=*/tir::ScheduleErrorRenderLevel::kDetail);

std::vector<ScheduleAndUnvisitedBlocks> stack;
Array<tir::Schedule> result{sch};
Array<tir::Schedule> result;
Array<tir::BlockRV> all_blocks = BlockCollector::Collect(sch), func_blocks, non_func_blocks;
for (const tir::BlockRV& block_rv : all_blocks) {
if (Optional<String> custom_sch_rule_name_opt =
tir::GetAnn<String>(sch->GetSRef(block_rv), "schedule_rule")) {
if (custom_sch_rule_name_opt.value() != "None") {
func_blocks.push_back(block_rv);
}
} else {
non_func_blocks.push_back(block_rv);
}
}

// only do this once for schedule rules on block annotations
stack.emplace_back(sch, func_blocks);
while (!stack.empty()) {
// get the stack.top()
tir::Schedule sch;
Array<tir::BlockRV> blocks;
std::tie(sch, blocks) = stack.back();
stack.pop_back();
// if all blocks are visited
if (blocks.empty()) {
result.push_back(sch);
continue;
}
// otherwise, get the last block that is not visited
tir::BlockRV block_rv = blocks.back();
blocks.pop_back();
if (sch->HasBlock(block_rv)) {
// pick out the blocks with annotation for customized search space
Optional<String> custom_sch_rule_name_opt =
tir::GetAnn<String>(sch->GetSRef(block_rv), "schedule_rule");
ICHECK(custom_sch_rule_name_opt.defined() && custom_sch_rule_name_opt.value() != "None");
String custom_sch_rule_name = custom_sch_rule_name_opt.value();
const auto* custom_sch_rule_func = runtime::Registry::Get(custom_sch_rule_name);
CHECK(custom_sch_rule_func) << "The given custom schedule function is not defined!";
Array<tir::Schedule> applied = (*custom_sch_rule_func)(sch, block_rv);
for (const tir::Schedule& sch : applied) {
stack.emplace_back(sch, blocks);
}
} else {
stack.emplace_back(sch, blocks);
}
}

// Enumerate the schedule rules first because you can
// always concat multiple schedule rules as one
Array<tir::BlockRV> all_blocks = BlockCollector::Collect(sch);
for (ScheduleRule sch_rule : sch_rules_) {
for (const tir::Schedule& sch : result) {
stack.emplace_back(sch, all_blocks);
stack.emplace_back(sch, non_func_blocks);
}
result.clear();

Expand Down
Loading