Skip to content

Commit a16ccf4

Browse files
zxybazhjunrushaospectrometerHBHMasterJH5574jinhongyii
authored
[Meta Schedule][M3a] SearchStrategy (#9132)
* Add c++ side SearchStrategy. Co-authored-by: Junru Shao <junrushao1994@gmail.com> Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com> Co-authored-by: Ruihang Lai <lairuihangdongdong@qq.com> Co-authored-by: Hongyi Jin <3231950289@qq.com> Co-authored-by: Wuwei Lin <wuwei@apache.org> Co-authored-by: Siyuan Feng <Hzfengsy@sjtu.edu.cn> * Add python-side code & test. Co-authored-by: Junru Shao <junrushao1994@gmail.com> Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com> Co-authored-by: Ruihang Lai <lairuihangdongdong@qq.com> Co-authored-by: Hongyi Jin <3231950289@qq.com> Co-authored-by: Wuwei Lin <wuwei@apache.org> Co-authored-by: Siyuan Feng <Hzfengsy@sjtu.edu.cn> * Add docs. * Minor fix. * Add workflow. * Add docs. * Fix docs. * Add notes. Co-authored-by: Junru Shao <junrushao1994@gmail.com> Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com> Co-authored-by: Ruihang Lai <lairuihangdongdong@qq.com> Co-authored-by: Hongyi Jin <3231950289@qq.com> Co-authored-by: Wuwei Lin <wuwei@apache.org> Co-authored-by: Siyuan Feng <Hzfengsy@sjtu.edu.cn>
1 parent e4946f4 commit a16ccf4

File tree

19 files changed

+1121
-6
lines changed

19 files changed

+1121
-6
lines changed

include/tvm/meta_schedule/builder.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
namespace tvm {
2626
namespace meta_schedule {
2727

28-
/*! \brief The builder's input. */
28+
/*! \brief The builder's input, containing an IRModule and the target. */
2929
class BuilderInputNode : public runtime::Object {
3030
public:
3131
/*! \brief The IRModule to be built. */
@@ -57,7 +57,7 @@ class BuilderInput : public runtime::ObjectRef {
5757
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(BuilderInput, runtime::ObjectRef, BuilderInputNode);
5858
};
5959

60-
/*! \brief The builder's output. */
60+
/*! \brief The builder's output, containing the artifact path or error message if any. */
6161
class BuilderResultNode : public runtime::Object {
6262
public:
6363
/*! \brief The path to the built artifact. */

include/tvm/meta_schedule/runner.h

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
#ifndef TVM_META_SCHEDULE_RUNNER_H_
20+
#define TVM_META_SCHEDULE_RUNNER_H_
21+
22+
#include <tvm/ir/expr.h>
23+
24+
namespace tvm {
25+
namespace meta_schedule {
26+
27+
/*! \brief Runner's output containing measurement result of MeasureCandidate or error msg if any. */
28+
class RunnerResultNode : public runtime::Object {
29+
public:
30+
/*! \brief The run time in seconds. If not None, error_msg should be None. */
31+
Optional<Array<FloatImm>> run_secs;
32+
/*! \brief The error message, if any. If not None, run_secs should be None. */
33+
Optional<String> error_msg;
34+
35+
void VisitAttrs(tvm::AttrVisitor* v) {
36+
v->Visit("run_secs", &run_secs);
37+
v->Visit("error_msg", &error_msg);
38+
}
39+
40+
static constexpr const char* _type_key = "meta_schedule.RunnerResult";
41+
TVM_DECLARE_FINAL_OBJECT_INFO(RunnerResultNode, runtime::Object);
42+
};
43+
44+
/*!
45+
* \brief Managed reference to RunnerResultNode
46+
* \sa RunnerResultNode
47+
*/
48+
class RunnerResult : public runtime::ObjectRef {
49+
public:
50+
/*!
51+
* \brief Constructor for RunnerResult.
52+
* \param run_secs The run time in seconds.
53+
* \param error_msg The error message, if any.
54+
*/
55+
TVM_DLL explicit RunnerResult(Optional<Array<FloatImm>> run_secs, Optional<String> error_msg);
56+
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(RunnerResult, runtime::ObjectRef, RunnerResultNode);
57+
};
58+
59+
} // namespace meta_schedule
60+
} // namespace tvm
61+
62+
#endif // TVM_META_SCHEDULE_RUNNER_H_
Lines changed: 246 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,246 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
#ifndef TVM_META_SCHEDULE_SEARCH_STRATEGY_H_
20+
#define TVM_META_SCHEDULE_SEARCH_STRATEGY_H_
21+
22+
#include <tvm/meta_schedule/arg_info.h>
23+
#include <tvm/meta_schedule/runner.h>
24+
#include <tvm/tir/schedule/schedule.h>
25+
26+
namespace tvm {
27+
namespace meta_schedule {
28+
29+
// Forward declaration
30+
class TuneContext;
31+
32+
/*! \brief The schedule (with input shapes) to be measured. */
33+
class MeasureCandidateNode : public runtime::Object {
34+
public:
35+
/*! \brief The schedule for measurement. */
36+
tir::Schedule sch;
37+
/*! \brief The argument information, e.g., (shape, dtype) for tensors. */
38+
Array<ArgInfo> args_info;
39+
40+
void VisitAttrs(tvm::AttrVisitor* v) {
41+
v->Visit("sch", &sch);
42+
v->Visit("args_info", &args_info);
43+
}
44+
45+
static constexpr const char* _type_key = "meta_schedule.MeasureCandidate";
46+
TVM_DECLARE_FINAL_OBJECT_INFO(MeasureCandidateNode, Object);
47+
};
48+
49+
/*!
50+
* \brief Managed reference to MeasureCandidateNode.
51+
* \sa MeasureCandidateNode
52+
*/
53+
class MeasureCandidate : public runtime::ObjectRef {
54+
public:
55+
/*!
56+
* \brief Constructor of MeasureCandidate.
57+
* \param sch The schedule for measurement.
58+
* \param args_info The argument information, e.g., (shape, dtype) for tensors.
59+
*/
60+
TVM_DLL MeasureCandidate(tir::Schedule sch, Array<ArgInfo> args_info);
61+
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(MeasureCandidate, ObjectRef, MeasureCandidateNode);
62+
};
63+
64+
/*!
65+
* \brief The search strategy for measure candidates generation.
66+
* \note The relationship between SearchStrategy and other classes are as follows:
67+
┌──────────────────────────────────────────────────────────────┐
68+
┌──┴───────────────────────────────────────────────────────────┐ │
69+
┌──┴────────────────── Tune Context ───────────────────────────┐ │ │
70+
│ ┌─────────────────────┐ │ │ │
71+
│ │ │ Generate │ │ │
72+
│ │ Space Generator ├──────────────┐ │ │ │
73+
│ │ │ │ │ │ │
74+
│ └─────────────────────┘ ▼ │ │ │
75+
│ Design Space │ │ │
76+
│ ┌─────────────────────┐ │ │ │ │
77+
│ Generate │ │ Pretuning │ │ │ │
78+
│ ┌───────────┤ Search Strategy │◄─────────────┘ │ │ │
79+
│ │ │ │ │ ├──┘
80+
│ │ └─────────────────────┘ ├──┘
81+
└────┼─────────────────────────────────────────────────────────┘
82+
83+
84+
┌────┼──────────────── Managed By Task Scheduler ─────────────────────┐
85+
│ │ ┌───────────┐ │
86+
│ │ Send to │ │ Send to │
87+
│ ▼ ┌─────────────►│ Builder ├──────────┐ │
88+
│ Measure Candidate │ Builder │ │ Runner │ │
89+
│ │ │ └───────────┘ │ │
90+
│ │ ┌────────────┴────────┐ │ │
91+
│ │ │ │ ┌───────────┐ │ │
92+
│ └────►│ Task Scheduler │ │ │ │ │
93+
│ │ │ │ Runner │◄─────────┘ │
94+
│ └─────────────────────┘ │ │ │
95+
│ ▲ └─────┬─────┘ │
96+
│ │ │ │
97+
│ └─── Runner Future ◄────┘ │
98+
└─────────────────────────────────────────────────────────────────────┘
99+
*/
100+
class SearchStrategyNode : public runtime::Object {
101+
public:
102+
/*! \brief Virtual destructor */
103+
virtual ~SearchStrategyNode() = default;
104+
105+
/*!
106+
* \brief Initialize the search strategy with tuning context.
107+
* \param tune_context The tuning context for initialization.
108+
* \note This method is supposed to be called only once before every other method.
109+
*/
110+
virtual void InitializeWithTuneContext(const TuneContext& tune_context) = 0;
111+
112+
/*!
113+
* \brief Pre-tuning for the search strategy.
114+
* \param design_spaces The design spaces for pre-tuning.
115+
* \note Pre-tuning is supposed to be called before the tuning process and after the
116+
* initialization. Because the search strategy is stateful, we can always call pretuning
117+
* and reset the search strategy.
118+
*/
119+
virtual void PreTuning(const Array<tir::Schedule>& design_spaces) = 0;
120+
121+
/*!
122+
* \brief Post-tuning for the search strategy.
123+
* \note Post-tuning is supposed to be called after the tuning process and before we reset the
124+
* search strategy with another pre-tuning. Post-tuning can be empty.
125+
*/
126+
virtual void PostTuning() = 0;
127+
128+
/*!
129+
* \brief Generate measure candidates from design spaces for measurement.
130+
* \return The measure candidates generated, nullptr if finished.
131+
*/
132+
virtual Optional<Array<MeasureCandidate>> GenerateMeasureCandidates() = 0;
133+
134+
/*!
135+
* \brief Update the search strategy with measurement results.
136+
* \param results The measurement results from the runner.
137+
*/
138+
virtual void NotifyRunnerResults(const Array<RunnerResult>& results) = 0;
139+
140+
static constexpr const char* _type_key = "meta_schedule.SearchStrategy";
141+
TVM_DECLARE_BASE_OBJECT_INFO(SearchStrategyNode, Object);
142+
};
143+
144+
/*! \brief The python side customizable class for measure candidate generation */
145+
class PySearchStrategyNode : public SearchStrategyNode {
146+
public:
147+
/*!
148+
* \brief The function type of `InitializeWithTuneContext` method.
149+
* \param tune_context The tuning context for initialization.
150+
*/
151+
using FInitializeWithTuneContext = runtime::TypedPackedFunc<void(const TuneContext&)>;
152+
/*!
153+
* \brief The function type of `PreTuning` method.
154+
* \param design_spaces The design spaces for pre-tuning.
155+
*/
156+
using FPreTuning = runtime::TypedPackedFunc<void(const Array<tir::Schedule>&)>;
157+
/*! \brief The function type of `PostTuning` method. */
158+
using FPostTuning = runtime::TypedPackedFunc<void()>;
159+
/*!
160+
* \brief The function type of `GenerateMeasureCandidates` method.
161+
* \return The measure candidates generated, nullptr if finished.
162+
*/
163+
using FGenerateMeasureCandidates = runtime::TypedPackedFunc<Optional<Array<MeasureCandidate>>()>;
164+
/*!
165+
* \brief The function type of `NotifyRunnerResults` method.
166+
* \param results The measurement results from the runner.
167+
*/
168+
using FNotifyRunnerResults = runtime::TypedPackedFunc<void(const Array<RunnerResult>&)>;
169+
170+
/*! \brief The packed function to the `InitializeWithTuneContext` method. */
171+
FInitializeWithTuneContext f_initialize_with_tune_context;
172+
/*! \brief The packed function to the `PreTuning` method. */
173+
FPreTuning f_pre_tuning;
174+
/*! \brief The packed function to the `PostTuning` method. */
175+
FPostTuning f_post_tuning;
176+
/*! \brief The packed function to the `GenerateMeasureCandidates` method. */
177+
FGenerateMeasureCandidates f_generate_measure_candidates;
178+
/*! \brief The packed function to the `NotifyRunnerResults` method. */
179+
FNotifyRunnerResults f_notify_runner_results;
180+
181+
void VisitAttrs(tvm::AttrVisitor* v) {
182+
// `f_initialize_with_tune_context` is not visited
183+
// `f_pre_tuning` is not visited
184+
// `f_post_tuning` is not visited
185+
// `f_generate_measure_candidates` is not visited
186+
// `f_notify_runner_results` is not visited
187+
}
188+
189+
void InitializeWithTuneContext(const TuneContext& context) final {
190+
this->f_initialize_with_tune_context(context);
191+
}
192+
193+
void PreTuning(const Array<tir::Schedule>& design_spaces) final {
194+
this->f_pre_tuning(design_spaces);
195+
}
196+
197+
void PostTuning() final { this->f_post_tuning(); }
198+
199+
Optional<Array<MeasureCandidate>> GenerateMeasureCandidates() final {
200+
return this->f_generate_measure_candidates();
201+
}
202+
203+
void NotifyRunnerResults(const Array<RunnerResult>& results) final {
204+
this->f_notify_runner_results(results);
205+
}
206+
207+
static constexpr const char* _type_key = "meta_schedule.PySearchStrategy";
208+
TVM_DECLARE_FINAL_OBJECT_INFO(PySearchStrategyNode, SearchStrategyNode);
209+
};
210+
211+
/*!
212+
* \brief Managed reference to SearchStrategyNode.
213+
* \sa SearchStrategyNode
214+
*/
215+
class SearchStrategy : public runtime::ObjectRef {
216+
public:
217+
/*!
218+
* \brief Create a search strategy with customized methods on the python-side.
219+
* \param f_initialize_with_tune_context The packed function of `InitializeWithTuneContext`.
220+
* \param f_pre_tuning The packed function of `PreTuning`.
221+
* \param f_post_tuning The packed function of `PostTuning`.
222+
* \param f_generate_measure_candidates The packed function of `GenerateMeasureCandidates`.
223+
* \param f_notify_runner_results The packed function of `NotifyRunnerResults`.
224+
* \return The search strategy created.
225+
*/
226+
TVM_DLL static SearchStrategy PySearchStrategy(
227+
PySearchStrategyNode::FInitializeWithTuneContext f_initialize_with_tune_context, //
228+
PySearchStrategyNode::FPreTuning f_pre_tuning, //
229+
PySearchStrategyNode::FPostTuning f_post_tuning, //
230+
PySearchStrategyNode::FGenerateMeasureCandidates f_generate_measure_candidates, //
231+
PySearchStrategyNode::FNotifyRunnerResults f_notify_runner_results);
232+
233+
/*!
234+
* \brief Constructor of replay trace search strategy.
235+
* \param num_trials_per_iter The number of trials per iteration, i.e., the batch size.
236+
* \param num_trials_total The total number of trials for trace replaying.
237+
*/
238+
TVM_DLL static SearchStrategy ReplayTrace(int num_trials_per_iter, int num_trials_total);
239+
240+
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(SearchStrategy, ObjectRef, SearchStrategyNode);
241+
};
242+
243+
} // namespace meta_schedule
244+
} // namespace tvm
245+
246+
#endif // TVM_META_SCHEDULE_SEARCH_STRATEGY_H_

include/tvm/meta_schedule/space_generator.h

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,42 @@ namespace meta_schedule {
2828
// Forward declaration
2929
class TuneContext;
3030

31-
/*! \brief The abstract class for design space generation. */
31+
/*!
32+
* \brief The abstract class for design space generation.
33+
* \note The relationship between SpaceGenerator and other classes are as follows:
34+
┌──────────────────────────────────────────────────────────────┐
35+
┌──┴───────────────────────────────────────────────────────────┐ │
36+
┌──┴────────────────── Tune Context ───────────────────────────┐ │ │
37+
│ ┌─────────────────────┐ │ │ │
38+
│ │ │ Generate │ │ │
39+
│ │ Space Generator ├──────────────┐ │ │ │
40+
│ │ │ │ │ │ │
41+
│ └─────────────────────┘ ▼ │ │ │
42+
│ Design Space │ │ │
43+
│ ┌─────────────────────┐ │ │ │ │
44+
│ Generate │ │ Pretuning │ │ │ │
45+
│ ┌───────────┤ Search Strategy │◄─────────────┘ │ │ │
46+
│ │ │ │ │ ├──┘
47+
│ │ └─────────────────────┘ ├──┘
48+
└────┼─────────────────────────────────────────────────────────┘
49+
50+
51+
┌────┼──────────────── Managed By Task Scheduler ─────────────────────┐
52+
│ │ ┌───────────┐ │
53+
│ │ Send to │ │ Send to │
54+
│ ▼ ┌─────────────►│ Builder ├──────────┐ │
55+
│ Measure Candidate │ Builder │ │ Runner │ │
56+
│ │ │ └───────────┘ │ │
57+
│ │ ┌────────────┴────────┐ │ │
58+
│ │ │ │ ┌───────────┐ │ │
59+
│ └────►│ Task Scheduler │ │ │ │ │
60+
│ │ │ │ Runner │◄─────────┘ │
61+
│ └─────────────────────┘ │ │ │
62+
│ ▲ └─────┬─────┘ │
63+
│ │ │ │
64+
│ └─── Runner Future ◄────┘ │
65+
└─────────────────────────────────────────────────────────────────────┘
66+
*/
3267
class SpaceGeneratorNode : public Object {
3368
public:
3469
/*! \brief Default destructor */
@@ -37,6 +72,7 @@ class SpaceGeneratorNode : public Object {
3772
/*!
3873
* \brief Initialize the design space generator with tuning context.
3974
* \param tune_context The tuning context for initialization.
75+
* \note This method is supposed to be called only once before every other method.
4076
*/
4177
virtual void InitializeWithTuneContext(const TuneContext& tune_context) = 0;
4278

0 commit comments

Comments
 (0)