4141namespace tvm {
4242namespace ansor {
4343
44- TVM_REGISTER_OBJECT_TYPE (MetaTileRewritePolicyNode);
44+ TVM_REGISTER_NODE_TYPE (MetaTileRewritePolicyNode);
45+ TVM_REGISTER_OBJECT_TYPE (PreAddCustomRuleNode);
4546
4647// All possible candidates for auto_unroll
4748const std::vector<int > MetaTileRewritePolicyNode::auto_unroll_configs{0 , 16 , 64 , 512 , 1024 };
@@ -241,7 +242,7 @@ void MetaTileRewritePolicyNode::SearchOneRound(std::vector<State>* best_states,
241242
242243 // Synthesize meta structure
243244 std::vector<State> meta_structures;
244- SynthesizeMetaStructure (&meta_structures);
245+ GenerateMetaSketch (&meta_structures);
245246
246247 // PrintAllStates(meta_structures);
247248 // exit(0);
@@ -272,8 +273,8 @@ void MetaTileRewritePolicyNode::SearchOneRound(std::vector<State>* best_states,
272273 RandomSampleStates (init_population, &rand_gen_, num_random_states * 10 , random_states);
273274}
274275
275- // The baseclass of derivation rules used in meta structure synthesis
276- class StructureSynthesisRule {
276+ // The baseclass of derivation rules used in meta sketch generation
277+ class SketchGenerationRule {
277278 public:
278279 enum ConditionEnum {
279280 kPass , kApply , kApplyAndSkipRest
@@ -345,7 +346,7 @@ static inline bool ShouldAlwaysBeInlined(
345346}
346347
347348// The rule that inlines simple elementwise ops
348- class RuleAlwaysInline : public StructureSynthesisRule {
349+ class RuleAlwaysInline : public SketchGenerationRule {
349350 public:
350351 ConditionEnum MeetCondition (const MetaTileRewritePolicyNode* policy,
351352 const State& state, int stage_id) final {
@@ -362,7 +363,7 @@ class RuleAlwaysInline : public StructureSynthesisRule {
362363};
363364
364365// The rule that simply skip the current stage
365- class RuleSkipStage : public StructureSynthesisRule {
366+ class RuleSkipStage : public SketchGenerationRule {
366367 public:
367368 ConditionEnum MeetCondition (const MetaTileRewritePolicyNode* policy,
368369 const State& state, int stage_id) final {
@@ -387,7 +388,7 @@ class RuleSkipStage : public StructureSynthesisRule {
387388};
388389
389390// The rule that performs multi-level tiling
390- class RuleMultiLevelTiling : public StructureSynthesisRule {
391+ class RuleMultiLevelTiling : public SketchGenerationRule {
391392 public:
392393 ConditionEnum MeetCondition (const MetaTileRewritePolicyNode* policy,
393394 const State& state, int stage_id) final {
@@ -413,7 +414,7 @@ class RuleMultiLevelTiling : public StructureSynthesisRule {
413414};
414415
415416// The rule that performs multi-level tiling and fuses later consumers
416- class RuleMultiLevelTilingWithFusion : public StructureSynthesisRule {
417+ class RuleMultiLevelTilingWithFusion : public SketchGenerationRule {
417418 public:
418419 ConditionEnum MeetCondition (const MetaTileRewritePolicyNode* policy,
419420 const State& state, int stage_id) final {
@@ -482,7 +483,7 @@ class RuleMultiLevelTilingWithFusion : public StructureSynthesisRule {
482483};
483484
484485// The rule that adds a cache write stage
485- class RuleAddCacheWrite : public StructureSynthesisRule {
486+ class RuleAddCacheWrite : public SketchGenerationRule {
486487 public:
487488 ConditionEnum MeetCondition (const MetaTileRewritePolicyNode* policy,
488489 const State& state, int stage_id) final {
@@ -515,7 +516,7 @@ class RuleAddCacheWrite : public StructureSynthesisRule {
515516// The rule that adds a cache read stage
516517// Mainly used for GPU cooperative fetching
517518// Currently only support 1 to 1 match cache read
518- class RuleAddCacheRead : public StructureSynthesisRule {
519+ class RuleAddCacheRead : public SketchGenerationRule {
519520 public:
520521 ConditionEnum MeetCondition (const MetaTileRewritePolicyNode* policy,
521522 const State& state, int stage_id) final {
@@ -546,7 +547,7 @@ class RuleAddCacheRead : public StructureSynthesisRule {
546547};
547548
548549// The rule that adds rfactor stage
549- class RuleAddRfactor : public StructureSynthesisRule {
550+ class RuleAddRfactor : public SketchGenerationRule {
550551 public:
551552 ConditionEnum MeetCondition (const MetaTileRewritePolicyNode* policy,
552553 const State& state, int stage_id) final {
@@ -610,7 +611,7 @@ class RuleAddRfactor : public StructureSynthesisRule {
610611 }
611612};
612613
613- void MetaTileRewritePolicyNode::SynthesizeMetaStructure (
614+ void MetaTileRewritePolicyNode::GenerateMetaSketch (
614615 std::vector<State>* out_states) {
615616 State init_state = cur_task_->compute_dag .GetInitState ();
616617 std::string cpu_multi_level_tiling_structure =
@@ -634,18 +635,22 @@ void MetaTileRewritePolicyNode::SynthesizeMetaStructure(
634635 static RuleAddCacheWrite rule_add_cache_write_stage;
635636 static RuleAddCacheRead rule_add_cache_read_stage;
636637 static RuleAddRfactor rule_add_rfactor;
637- // We may apply and skip the rest when processing some rules,
638- // should take care of the rule vector order here
639- static std::vector<StructureSynthesisRule*> all_rules {
640- &rule_always_inline, &rule_add_cache_write_stage,
641- &rule_multi_level_tiling_with_fusion, &rule_multi_level_tiling,
642- &rule_add_rfactor, &rule_skip_stage
643- };
644- if (IS_GPU (cur_task_)) {
645- // Try cache read first before cache write
646- all_rules.insert (all_rules.begin () + 1 , &rule_add_cache_read_stage);
638+ if (sketch_rules.empty ()) {
639+ // We may apply and skip the rest when processing some rules,
640+ // should take care of the rule vector order here
641+ sketch_rules.push_back (&rule_always_inline);
642+ sketch_rules.push_back (&rule_add_cache_write_stage);
643+ sketch_rules.push_back (&rule_multi_level_tiling_with_fusion);
644+ sketch_rules.push_back (&rule_multi_level_tiling);
645+ sketch_rules.push_back (&rule_add_rfactor);
646+ sketch_rules.push_back (&rule_skip_stage);
647+ if (IS_GPU (cur_task_)) {
648+ // Try cache read first before cache write
649+ sketch_rules.insert (sketch_rules.begin () + 1 , &rule_add_cache_read_stage);
650+ }
651+ // TODO(xian): Add a new rule to try combination of multi-level
652+ // tiling + rfactor
647653 }
648- // TODO(xian): Add a new rule to try combination of multi-level tiling + rfactor
649654
650655 // Derivation rule based synthesizer
651656 while (!pnow->empty ()) {
@@ -661,15 +666,15 @@ void MetaTileRewritePolicyNode::SynthesizeMetaStructure(
661666 }
662667
663668 // Try all derivation rules
664- for (const auto & rule : all_rules ) {
669+ for (const auto & rule : sketch_rules ) {
665670 auto rule_check = rule->MeetCondition (this , state, stage_id);
666- if (rule_check > StructureSynthesisRule ::ConditionEnum::kPass ) {
671+ if (rule_check > SketchGenerationRule ::ConditionEnum::kPass ) {
667672 for (const auto & pair : rule->Apply (this , state, stage_id)) {
668673 cur_stage_id_map[pair.first ] = pair.second ;
669674 pnext->push_back (pair.first );
670675 }
671676 // Skip the reset rules
672- if (rule_check == StructureSynthesisRule ::ConditionEnum::kApplyAndSkipRest ) {
677+ if (rule_check == SketchGenerationRule ::ConditionEnum::kApplyAndSkipRest ) {
673678 break ;
674679 }
675680 }
@@ -1444,12 +1449,71 @@ void MetaTileRewritePolicyNode::EvolutionarySearch(
14441449 << std::fixed << std::setprecision (2 ) << duration << std::endl;
14451450}
14461451
1452+ class RuleCustomSketch : public SketchGenerationRule {
1453+ public:
1454+ RuleCustomSketch (PackedFunc meet_condition_func, PackedFunc apply_func) :
1455+ meet_condition_func_ (meet_condition_func), apply_func_(apply_func) {}
1456+
1457+ inline ConditionEnum MeetCondition (const MetaTileRewritePolicyNode* policy,
1458+ const State& state, int stage_id) final {
1459+ auto ret = meet_condition_func_ (
1460+ tvm::runtime::GetRef<MetaTileRewritePolicy>(policy), state, stage_id);
1461+ if (ret.type_code () == 0 ) {
1462+ return ConditionEnum (static_cast <int >(ret));
1463+ } else {
1464+ return kApplyAndSkipRest ;
1465+ }
1466+ }
1467+
1468+ inline std::vector<std::pair<State, int > > Apply (
1469+ const MetaTileRewritePolicyNode* policy,
1470+ const State& state, int stage_id) final {
1471+ std::vector<std::pair<State, int > > ret;
1472+
1473+ Array<Array<ObjectRef>> apply_ret = apply_func_ (
1474+ tvm::runtime::GetRef<MetaTileRewritePolicy>(policy), state, stage_id);
1475+
1476+ for (const auto & item : apply_ret) {
1477+ CHECK_EQ (item.size (), 2 );
1478+ State state = Downcast<State>(item[0 ]);
1479+ auto next = item[1 ].as <IntImmNode>();
1480+ ret.emplace_back (state, next->value );
1481+ }
1482+ return ret;
1483+ }
1484+
1485+ private:
1486+ PackedFunc meet_condition_func_;
1487+ PackedFunc apply_func_;
1488+ };
1489+
1490+ SearchCallback PreAddCustomRuleNode::make (PackedFunc meet_condition_func,
1491+ PackedFunc apply_func) {
1492+ auto node = make_object<PreAddCustomRuleNode>();
1493+ node->meet_condition_func = meet_condition_func;
1494+ node->apply_func = apply_func;
1495+ return SearchCallback (node);
1496+ }
1497+
1498+ void PreAddCustomRuleNode::callback (SearchPolicyNode* policy) {
1499+ CHECK (policy->IsInstance <MetaTileRewritePolicyNode>());
1500+ auto meta_policy = dynamic_cast <MetaTileRewritePolicyNode*>(policy);
1501+ meta_policy->sketch_rules .emplace_back (
1502+ new RuleCustomSketch (meet_condition_func, apply_func));
1503+ StdCout (policy->verbose_ ) << " Custom sketch rule added." << std::endl;
1504+ }
1505+
14471506TVM_REGISTER_GLOBAL (" ansor.MetaTileRewritePolicy" )
14481507.set_body_typed([](CostModel program_cost_model,
14491508 Map<String, ObjectRef> params,
14501509 int seed){
14511510 return MetaTileRewritePolicyNode::make (program_cost_model, params, seed);
14521511});
14531512
1513+ TVM_REGISTER_GLOBAL (" ansor.PreAddCustomRule" )
1514+ .set_body_typed([](PackedFunc meet_condition_func, PackedFunc apply_func) {
1515+ return PreAddCustomRuleNode::make (meet_condition_func, apply_func);
1516+ });
1517+
14541518} // namespace ansor
14551519} // namespace tvm
0 commit comments