|
| 1 | +/* |
| 2 | + * Licensed to the Apache Software Foundation (ASF) under one |
| 3 | + * or more contributor license agreements. See the NOTICE file |
| 4 | + * distributed with this work for additional information |
| 5 | + * regarding copyright ownership. The ASF licenses this file |
| 6 | + * to you under the Apache License, Version 2.0 (the |
| 7 | + * "License"); you may not use this file except in compliance |
| 8 | + * with the License. You may obtain a copy of the License at |
| 9 | + * |
| 10 | + * http://www.apache.org/licenses/LICENSE-2.0 |
| 11 | + * |
| 12 | + * Unless required by applicable law or agreed to in writing, |
| 13 | + * software distributed under the License is distributed on an |
| 14 | + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| 15 | + * KIND, either express or implied. See the License for the |
| 16 | + * specific language governing permissions and limitations |
| 17 | + * under the License. |
| 18 | + */ |
| 19 | +#ifndef TVM_META_SCHEDULE_SEARCH_STRATEGY_H_ |
| 20 | +#define TVM_META_SCHEDULE_SEARCH_STRATEGY_H_ |
| 21 | + |
| 22 | +#include <tvm/meta_schedule/arg_info.h> |
| 23 | +#include <tvm/meta_schedule/runner.h> |
| 24 | +#include <tvm/tir/schedule/schedule.h> |
| 25 | + |
| 26 | +namespace tvm { |
| 27 | +namespace meta_schedule { |
| 28 | + |
| 29 | +// Forward declaration |
| 30 | +class TuneContext; |
| 31 | + |
| 32 | +/*! \brief The schedule (with input shapes) to be measured. */ |
| 33 | +class MeasureCandidateNode : public runtime::Object { |
| 34 | + public: |
| 35 | + /*! \brief The schedule for measurement. */ |
| 36 | + tir::Schedule sch; |
| 37 | + /*! \brief The argument information, e.g., (shape, dtype) for tensors. */ |
| 38 | + Array<ArgInfo> args_info; |
| 39 | + |
| 40 | + void VisitAttrs(tvm::AttrVisitor* v) { |
| 41 | + v->Visit("sch", &sch); |
| 42 | + v->Visit("args_info", &args_info); |
| 43 | + } |
| 44 | + |
| 45 | + static constexpr const char* _type_key = "meta_schedule.MeasureCandidate"; |
| 46 | + TVM_DECLARE_FINAL_OBJECT_INFO(MeasureCandidateNode, Object); |
| 47 | +}; |
| 48 | + |
| 49 | +/*! |
| 50 | + * \brief Managed reference to MeasureCandidateNode. |
| 51 | + * \sa MeasureCandidateNode |
| 52 | + */ |
| 53 | +class MeasureCandidate : public runtime::ObjectRef { |
| 54 | + public: |
| 55 | + /*! |
| 56 | + * \brief Constructor of MeasureCandidate. |
| 57 | + * \param sch The schedule for measurement. |
| 58 | + * \param args_info The argument information, e.g., (shape, dtype) for tensors. |
| 59 | + */ |
| 60 | + TVM_DLL MeasureCandidate(tir::Schedule sch, Array<ArgInfo> args_info); |
| 61 | + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(MeasureCandidate, ObjectRef, MeasureCandidateNode); |
| 62 | +}; |
| 63 | + |
| 64 | +/*! |
| 65 | + * \brief The search strategy for measure candidates generation. |
| 66 | + * \note The relationship between SearchStrategy and other classes are as follows: |
| 67 | + ┌──────────────────────────────────────────────────────────────┐ |
| 68 | + ┌──┴───────────────────────────────────────────────────────────┐ │ |
| 69 | +┌──┴────────────────── Tune Context ───────────────────────────┐ │ │ |
| 70 | +│ ┌─────────────────────┐ │ │ │ |
| 71 | +│ │ │ Generate │ │ │ |
| 72 | +│ │ Space Generator ├──────────────┐ │ │ │ |
| 73 | +│ │ │ │ │ │ │ |
| 74 | +│ └─────────────────────┘ ▼ │ │ │ |
| 75 | +│ Design Space │ │ │ |
| 76 | +│ ┌─────────────────────┐ │ │ │ │ |
| 77 | +│ Generate │ │ Pretuning │ │ │ │ |
| 78 | +│ ┌───────────┤ Search Strategy │◄─────────────┘ │ │ │ |
| 79 | +│ │ │ │ │ ├──┘ |
| 80 | +│ │ └─────────────────────┘ ├──┘ |
| 81 | +└────┼─────────────────────────────────────────────────────────┘ |
| 82 | + │ |
| 83 | + │ |
| 84 | +┌────┼──────────────── Managed By Task Scheduler ─────────────────────┐ |
| 85 | +│ │ ┌───────────┐ │ |
| 86 | +│ │ Send to │ │ Send to │ |
| 87 | +│ ▼ ┌─────────────►│ Builder ├──────────┐ │ |
| 88 | +│ Measure Candidate │ Builder │ │ Runner │ │ |
| 89 | +│ │ │ └───────────┘ │ │ |
| 90 | +│ │ ┌────────────┴────────┐ │ │ |
| 91 | +│ │ │ │ ┌───────────┐ │ │ |
| 92 | +│ └────►│ Task Scheduler │ │ │ │ │ |
| 93 | +│ │ │ │ Runner │◄─────────┘ │ |
| 94 | +│ └─────────────────────┘ │ │ │ |
| 95 | +│ ▲ └─────┬─────┘ │ |
| 96 | +│ │ │ │ |
| 97 | +│ └─── Runner Future ◄────┘ │ |
| 98 | +└─────────────────────────────────────────────────────────────────────┘ |
| 99 | +*/ |
| 100 | +class SearchStrategyNode : public runtime::Object { |
| 101 | + public: |
| 102 | + /*! \brief Virtual destructor */ |
| 103 | + virtual ~SearchStrategyNode() = default; |
| 104 | + |
| 105 | + /*! |
| 106 | + * \brief Initialize the search strategy with tuning context. |
| 107 | + * \param tune_context The tuning context for initialization. |
| 108 | + * \note This method is supposed to be called only once before every other method. |
| 109 | + */ |
| 110 | + virtual void InitializeWithTuneContext(const TuneContext& tune_context) = 0; |
| 111 | + |
| 112 | + /*! |
| 113 | + * \brief Pre-tuning for the search strategy. |
| 114 | + * \param design_spaces The design spaces for pre-tuning. |
| 115 | + * \note Pre-tuning is supposed to be called before the tuning process and after the |
| 116 | + * initialization. Because the search strategy is stateful, we can always call pretuning |
| 117 | + * and reset the search strategy. |
| 118 | + */ |
| 119 | + virtual void PreTuning(const Array<tir::Schedule>& design_spaces) = 0; |
| 120 | + |
| 121 | + /*! |
| 122 | + * \brief Post-tuning for the search strategy. |
| 123 | + * \note Post-tuning is supposed to be called after the tuning process and before we reset the |
| 124 | + * search strategy with another pre-tuning. Post-tuning can be empty. |
| 125 | + */ |
| 126 | + virtual void PostTuning() = 0; |
| 127 | + |
| 128 | + /*! |
| 129 | + * \brief Generate measure candidates from design spaces for measurement. |
| 130 | + * \return The measure candidates generated, nullptr if finished. |
| 131 | + */ |
| 132 | + virtual Optional<Array<MeasureCandidate>> GenerateMeasureCandidates() = 0; |
| 133 | + |
| 134 | + /*! |
| 135 | + * \brief Update the search strategy with measurement results. |
| 136 | + * \param results The measurement results from the runner. |
| 137 | + */ |
| 138 | + virtual void NotifyRunnerResults(const Array<RunnerResult>& results) = 0; |
| 139 | + |
| 140 | + static constexpr const char* _type_key = "meta_schedule.SearchStrategy"; |
| 141 | + TVM_DECLARE_BASE_OBJECT_INFO(SearchStrategyNode, Object); |
| 142 | +}; |
| 143 | + |
| 144 | +/*! \brief The python side customizable class for measure candidate generation */ |
| 145 | +class PySearchStrategyNode : public SearchStrategyNode { |
| 146 | + public: |
| 147 | + /*! |
| 148 | + * \brief The function type of `InitializeWithTuneContext` method. |
| 149 | + * \param tune_context The tuning context for initialization. |
| 150 | + */ |
| 151 | + using FInitializeWithTuneContext = runtime::TypedPackedFunc<void(const TuneContext&)>; |
| 152 | + /*! |
| 153 | + * \brief The function type of `PreTuning` method. |
| 154 | + * \param design_spaces The design spaces for pre-tuning. |
| 155 | + */ |
| 156 | + using FPreTuning = runtime::TypedPackedFunc<void(const Array<tir::Schedule>&)>; |
| 157 | + /*! \brief The function type of `PostTuning` method. */ |
| 158 | + using FPostTuning = runtime::TypedPackedFunc<void()>; |
| 159 | + /*! |
| 160 | + * \brief The function type of `GenerateMeasureCandidates` method. |
| 161 | + * \return The measure candidates generated, nullptr if finished. |
| 162 | + */ |
| 163 | + using FGenerateMeasureCandidates = runtime::TypedPackedFunc<Optional<Array<MeasureCandidate>>()>; |
| 164 | + /*! |
| 165 | + * \brief The function type of `NotifyRunnerResults` method. |
| 166 | + * \param results The measurement results from the runner. |
| 167 | + */ |
| 168 | + using FNotifyRunnerResults = runtime::TypedPackedFunc<void(const Array<RunnerResult>&)>; |
| 169 | + |
| 170 | + /*! \brief The packed function to the `InitializeWithTuneContext` method. */ |
| 171 | + FInitializeWithTuneContext f_initialize_with_tune_context; |
| 172 | + /*! \brief The packed function to the `PreTuning` method. */ |
| 173 | + FPreTuning f_pre_tuning; |
| 174 | + /*! \brief The packed function to the `PostTuning` method. */ |
| 175 | + FPostTuning f_post_tuning; |
| 176 | + /*! \brief The packed function to the `GenerateMeasureCandidates` method. */ |
| 177 | + FGenerateMeasureCandidates f_generate_measure_candidates; |
| 178 | + /*! \brief The packed function to the `NotifyRunnerResults` method. */ |
| 179 | + FNotifyRunnerResults f_notify_runner_results; |
| 180 | + |
| 181 | + void VisitAttrs(tvm::AttrVisitor* v) { |
| 182 | + // `f_initialize_with_tune_context` is not visited |
| 183 | + // `f_pre_tuning` is not visited |
| 184 | + // `f_post_tuning` is not visited |
| 185 | + // `f_generate_measure_candidates` is not visited |
| 186 | + // `f_notify_runner_results` is not visited |
| 187 | + } |
| 188 | + |
| 189 | + void InitializeWithTuneContext(const TuneContext& context) final { |
| 190 | + this->f_initialize_with_tune_context(context); |
| 191 | + } |
| 192 | + |
| 193 | + void PreTuning(const Array<tir::Schedule>& design_spaces) final { |
| 194 | + this->f_pre_tuning(design_spaces); |
| 195 | + } |
| 196 | + |
| 197 | + void PostTuning() final { this->f_post_tuning(); } |
| 198 | + |
| 199 | + Optional<Array<MeasureCandidate>> GenerateMeasureCandidates() final { |
| 200 | + return this->f_generate_measure_candidates(); |
| 201 | + } |
| 202 | + |
| 203 | + void NotifyRunnerResults(const Array<RunnerResult>& results) final { |
| 204 | + this->f_notify_runner_results(results); |
| 205 | + } |
| 206 | + |
| 207 | + static constexpr const char* _type_key = "meta_schedule.PySearchStrategy"; |
| 208 | + TVM_DECLARE_FINAL_OBJECT_INFO(PySearchStrategyNode, SearchStrategyNode); |
| 209 | +}; |
| 210 | + |
| 211 | +/*! |
| 212 | + * \brief Managed reference to SearchStrategyNode. |
| 213 | + * \sa SearchStrategyNode |
| 214 | + */ |
| 215 | +class SearchStrategy : public runtime::ObjectRef { |
| 216 | + public: |
| 217 | + /*! |
| 218 | + * \brief Create a search strategy with customized methods on the python-side. |
| 219 | + * \param f_initialize_with_tune_context The packed function of `InitializeWithTuneContext`. |
| 220 | + * \param f_pre_tuning The packed function of `PreTuning`. |
| 221 | + * \param f_post_tuning The packed function of `PostTuning`. |
| 222 | + * \param f_generate_measure_candidates The packed function of `GenerateMeasureCandidates`. |
| 223 | + * \param f_notify_runner_results The packed function of `NotifyRunnerResults`. |
| 224 | + * \return The search strategy created. |
| 225 | + */ |
| 226 | + TVM_DLL static SearchStrategy PySearchStrategy( |
| 227 | + PySearchStrategyNode::FInitializeWithTuneContext f_initialize_with_tune_context, // |
| 228 | + PySearchStrategyNode::FPreTuning f_pre_tuning, // |
| 229 | + PySearchStrategyNode::FPostTuning f_post_tuning, // |
| 230 | + PySearchStrategyNode::FGenerateMeasureCandidates f_generate_measure_candidates, // |
| 231 | + PySearchStrategyNode::FNotifyRunnerResults f_notify_runner_results); |
| 232 | + |
| 233 | + /*! |
| 234 | + * \brief Constructor of replay trace search strategy. |
| 235 | + * \param num_trials_per_iter The number of trials per iteration, i.e., the batch size. |
| 236 | + * \param num_trials_total The total number of trials for trace replaying. |
| 237 | + */ |
| 238 | + TVM_DLL static SearchStrategy ReplayTrace(int num_trials_per_iter, int num_trials_total); |
| 239 | + |
| 240 | + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(SearchStrategy, ObjectRef, SearchStrategyNode); |
| 241 | +}; |
| 242 | + |
| 243 | +} // namespace meta_schedule |
| 244 | +} // namespace tvm |
| 245 | + |
| 246 | +#endif // TVM_META_SCHEDULE_SEARCH_STRATEGY_H_ |
0 commit comments