
Commit b0282ca

zxybazh and junrushao committed
[MetaSchedule] Add Gradient Based Task Scheduler (apache#10366)
Co-authored-by: Junru Shao <junrushao1994@gmail.com>
1 parent 656f7c5 commit b0282ca

34 files changed: +894 −323 lines

include/tvm/meta_schedule/search_strategy.h

Lines changed: 6 additions & 6 deletions
@@ -252,21 +252,21 @@ class SearchStrategy : public runtime::ObjectRef {
   /*!
    * \brief Constructor of replay trace search strategy.
    * \param num_trials_per_iter The number of trials per iteration, i.e., the batch size.
-   * \param num_trials_total The total number of trials for trace replaying.
+   * \param max_trials_per_task The total number of trials for trace replaying.
    */
-  TVM_DLL static SearchStrategy ReplayTrace(int num_trials_per_iter, int num_trials_total);
+  TVM_DLL static SearchStrategy ReplayTrace(int num_trials_per_iter, int max_trials_per_task);

   /*!
    * \brief Constructor of replay func search strategy.
    * \param num_trials_per_iter The number of trials per iteration, i.e., the batch size.
-   * \param num_trials_total The total number of trials for func replaying.
+   * \param max_trials_per_task The total number of trials for func replaying.
    */
-  TVM_DLL static SearchStrategy ReplayFunc(int num_trials_per_iter, int num_trials_total);
+  TVM_DLL static SearchStrategy ReplayFunc(int num_trials_per_iter, int max_trials_per_task);

   /*!
    * \brief Constructor of evolutionary search strategy.
    * \param num_trials_per_iter The number of trials per iteration, i.e., the batch size.
-   * \param num_trials_total The total number of trials for evolutionary search.
+   * \param max_trials_per_task The total number of trials for evolutionary search.
    * \param population_size The initial sample population.
    * \param init_measured_ratio The ratio of measures samples in initial population.
    * \param init_min_unmeasured The minimal size of unmeasured population in the initial sampling.
@@ -276,7 +276,7 @@ class SearchStrategy : public runtime::ObjectRef {
    * \param eps_greedy The ratio to select samples in a greedy fashion via their predicted score.
    */
   TVM_DLL static SearchStrategy EvolutionarySearch(int num_trials_per_iter,  //
-                                                   int num_trials_total,     //
+                                                   int max_trials_per_task,  //
                                                    int population_size,      //
                                                    double init_measured_ratio,  //
                                                    int init_min_unmeasured,  //
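The rename is mechanical: every search strategy constructor that previously took num_trials_total now takes max_trials_per_task, keeping the per-task budget distinct from the global budget that moves into the task scheduler. A minimal Python-side sketch using the two-argument ReplayFunc strategy updated later in this commit; the trial counts below are illustrative, not defaults:

    from tvm.meta_schedule.search_strategy import ReplayFunc

    # Per-task budget only; the global budget (max_trials) now lives in the task scheduler.
    strategy = ReplayFunc(
        num_trials_per_iter=32,    # batch size per tuning iteration (illustrative)
        max_trials_per_task=512,   # formerly `num_trials_total`
    )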

include/tvm/meta_schedule/task_scheduler.h

Lines changed: 52 additions & 42 deletions
@@ -75,10 +75,14 @@ class TaskSchedulerNode : public runtime::Object {
   Runner runner{nullptr};
   /*! \brief The database of the scheduler. */
   Database database{nullptr};
+  /*! \brief The maximum number of trials allowed. */
+  int max_trials;
   /*! \brief The cost model of the scheduler. */
   Optional<CostModel> cost_model;
   /*! \brief The list of measure callbacks of the scheduler. */
   Array<MeasureCallback> measure_callbacks;
+  /*! \brief The number of trials already conducted. */
+  int num_trials_already;

   /*! \brief The default destructor. */
   virtual ~TaskSchedulerNode() = default;
@@ -88,8 +92,10 @@ class TaskSchedulerNode : public runtime::Object {
     v->Visit("builder", &builder);
     v->Visit("runner", &runner);
     v->Visit("database", &database);
+    v->Visit("max_trials", &max_trials);
     v->Visit("cost_model", &cost_model);
     v->Visit("measure_callbacks", &measure_callbacks);
+    v->Visit("num_trials_already", &num_trials_already);
   }

   /*! \brief Auto-tuning. */
@@ -102,23 +108,16 @@ class TaskSchedulerNode : public runtime::Object {
   virtual void InitializeTask(int task_id);

   /*!
-   * \brief Set specific task to be stopped.
-   * \param task_id The task id to be stopped.
-   */
-  virtual void SetTaskStopped(int task_id);
-
-  /*!
-   * \brief Check whether the task is running.
+   * \brief Touch the task and update its status
    * \param task_id The task id to be checked.
-   * \return Whether the task is running.
    */
-  virtual bool IsTaskRunning(int task_id);
+  virtual void TouchTask(int task_id);

   /*!
    * \brief Wait until the task is finished.
    * \param task_id The task id to be joined.
    */
-  virtual void JoinRunningTask(int task_id);
+  virtual Array<RunnerResult> JoinRunningTask(int task_id);

   /*!
    * \brief Fetch the next task id.
@@ -142,23 +141,17 @@ class PyTaskSchedulerNode : public TaskSchedulerNode {
   using FInitializeTask = runtime::TypedPackedFunc<void(int)>;

   /*!
-   * \brief The function type of `SetTaskStopped` method.
-   * \param task_id The task id to be stopped.
-   */
-  using FSetTaskStopped = runtime::TypedPackedFunc<void(int)>;
-
-  /*!
-   * \brief The function type of `IsTaskRunning` method.
+   * \brief The function type of `TouchTask` method.
    * \param task_id The task id to be checked.
    * \return Whether the task is running.
    */
-  using FIsTaskRunning = runtime::TypedPackedFunc<bool(int)>;
+  using FTouchTask = runtime::TypedPackedFunc<void(int)>;

   /*!
    * \brief The function type of `JoinRunningTask` method.
    * \param task_id The task id to be joined.
    */
-  using FJoinRunningTask = runtime::TypedPackedFunc<void(int)>;
+  using FJoinRunningTask = runtime::TypedPackedFunc<Array<RunnerResult>(int)>;

   /*!
    * \brief The function type of `NextTaskId` method.
@@ -170,10 +163,8 @@ class PyTaskSchedulerNode : public TaskSchedulerNode {
   FTune f_tune;
   /*! \brief The packed function to the `InitializeTask` function. */
   FInitializeTask f_initialize_task;
-  /*! \brief The packed function to the `SetTaskStopped` function. */
-  FSetTaskStopped f_set_task_stopped;
-  /*! \brief The packed function to the `IsTaskRunning` function. */
-  FIsTaskRunning f_is_task_running;
+  /*! \brief The packed function to the `TouchTask` function. */
+  FTouchTask f_touch_task;
   /*! \brief The packed function to the `JoinRunningTask` function. */
   FJoinRunningTask f_join_running_task;
   /*! \brief The packed function to the `NextTaskId` function. */
@@ -182,8 +173,7 @@ class PyTaskSchedulerNode : public TaskSchedulerNode {
   void VisitAttrs(tvm::AttrVisitor* v) {
     // `f_tune` is not visited
     // `f_initialize_task` is not visited
-    // `f_set_task_stopped` is not visited
-    // `f_is_task_running` is not visited
+    // `f_touch_task` is not visited
     // `f_join_running_task` is not visited
     // `f_next_task_id` is not visited
   }
@@ -204,23 +194,15 @@ class PyTaskSchedulerNode : public TaskSchedulerNode {
     }
   }

-  void SetTaskStopped(int task_id) final {
-    if (f_set_task_stopped == nullptr) {
-      TaskSchedulerNode::SetTaskStopped(task_id);
-    } else {
-      f_set_task_stopped(task_id);
-    }
-  }
-
-  bool IsTaskRunning(int task_id) final {
-    if (f_is_task_running == nullptr) {
-      return TaskSchedulerNode::IsTaskRunning(task_id);
+  void TouchTask(int task_id) final {
+    if (f_touch_task == nullptr) {
+      return TaskSchedulerNode::TouchTask(task_id);
     } else {
-      return f_is_task_running(task_id);
+      return f_touch_task(task_id);
     }
   }

-  void JoinRunningTask(int task_id) final {
+  Array<RunnerResult> JoinRunningTask(int task_id) final {
     if (f_join_running_task == nullptr) {
       return TaskSchedulerNode::JoinRunningTask(task_id);
     } else {
@@ -249,6 +231,7 @@ class TaskScheduler : public runtime::ObjectRef {
   * \param builder The builder of the scheduler.
   * \param runner The runner of the scheduler.
   * \param database The database of the scheduler.
+  * \param max_trials The maximum number of trials.
   * \param cost_model The cost model of the scheduler.
   * \param measure_callbacks The measure callbacks of the scheduler.
   * \return The task scheduler created.
@@ -257,20 +240,47 @@ class TaskScheduler : public runtime::ObjectRef {
                                          Builder builder,                                    //
                                          Runner runner,                                      //
                                          Database database,                                  //
+                                         int max_trials,                                     //
                                          Optional<CostModel> cost_model,                     //
                                          Optional<Array<MeasureCallback>> measure_callbacks);
+  /*!
+   * \brief Create a task scheduler that fetches tasks in a gradient based fashion.
+   * \param tasks The tasks to be tuned.
+   * \param task_weights The weights of each task.
+   * \param builder The builder of the scheduler.
+   * \param runner The runner of the scheduler.
+   * \param database The database of the scheduler.
+   * \param max_trials The maximum number of trials.
+   * \param cost_model The cost model of the scheduler.
+   * \param measure_callbacks The measure callbacks of the scheduler.
+   * \param alpha The parameter alpha to control gradient computation.
+   * \param window_size The parameter to control backward window size.
+   * \param seed The random seed.
+   * \return The task scheduler created.
+   */
+  TVM_DLL static TaskScheduler GradientBased(Array<TuneContext> tasks,
+                                             Array<FloatImm> task_weights,                        //
+                                             Builder builder,                                     //
+                                             Runner runner,                                       //
+                                             Database database,                                   //
+                                             int max_trials,                                      //
+                                             Optional<CostModel> cost_model,                      //
+                                             Optional<Array<MeasureCallback>> measure_callbacks,  //
+                                             double alpha,                                        //
+                                             int window_size,                                     //
+                                             support::LinearCongruentialEngine::TRandState seed);
  /*!
   * \brief Create a task scheduler with customized methods on the python-side.
   * \param tasks The tasks to be tuned.
   * \param builder The builder of the scheduler.
   * \param runner The runner of the scheduler.
   * \param database The database of the scheduler.
+  * \param max_trials The maximum number of trials.
   * \param cost_model The cost model of the scheduler.
   * \param measure_callbacks The measure callbacks of the scheduler.
   * \param f_tune The packed function of `Tune`.
   * \param f_initialize_task The packed function of `InitializeTask`.
-  * \param f_set_task_stopped The packed function of `SetTaskStopped`.
-  * \param f_is_task_running The packed function of `IsTaskRunning`.
+  * \param f_touch_task The packed function of `TouchTask`.
   * \param f_join_running_task The packed function of `JoinRunningTask`.
   * \param f_next_task_id The packed function of `NextTaskId`.
   * \return The task scheduler created.
@@ -280,12 +290,12 @@ class TaskScheduler : public runtime::ObjectRef {
                                        Builder builder,                                            //
                                        Runner runner,                                              //
                                        Database database,                                          //
+                                       int max_trials,                                             //
                                        Optional<CostModel> cost_model,                             //
                                        Optional<Array<MeasureCallback>> measure_callbacks,         //
                                        PyTaskSchedulerNode::FTune f_tune,                          //
                                        PyTaskSchedulerNode::FInitializeTask f_initialize_task,     //
-                                       PyTaskSchedulerNode::FSetTaskStopped f_set_task_stopped,    //
-                                       PyTaskSchedulerNode::FIsTaskRunning f_is_task_running,      //
+                                       PyTaskSchedulerNode::FTouchTask f_touch_task,               //
                                        PyTaskSchedulerNode::FJoinRunningTask f_join_running_task,  //
                                        PyTaskSchedulerNode::FNextTaskId f_next_task_id);
   TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(TaskScheduler, ObjectRef, TaskSchedulerNode);
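The new GradientBased constructor is the centerpiece of this commit: instead of round-robin, the scheduler picks the next task from a gradient of expected gain, weighted by task_weights, with alpha and window_size shaping the gradient estimate and max_trials capping the global budget. A hedged Python-side sketch that mirrors the C++ parameter list above; the binding name ms.task_scheduler.GradientBased, the tune() call, and all concrete values are assumptions for illustration, and the builder, runner, database, cost model, and callback objects are assumed to be constructed elsewhere:

    from tvm import meta_schedule as ms

    # Sketch only: assumed to mirror TaskScheduler::GradientBased declared above.
    scheduler = ms.task_scheduler.GradientBased(
        tasks=tune_contexts,          # list of TuneContext, one per tuning task
        task_weights=task_weights,    # per-task weights used in the gradient
        builder=builder,
        runner=runner,
        database=database,
        max_trials=2000,              # global trial budget shared by all tasks
        cost_model=cost_model,
        measure_callbacks=callbacks,
        alpha=0.2,                    # controls gradient computation (illustrative)
        window_size=3,                # backward window of the gradient estimate (illustrative)
        seed=-1,                      # -1: let the scheduler draw a device-random seed
    )
    scheduler.tune()                  # run until max_trials is exhausted or all tasks terminate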

include/tvm/meta_schedule/tune_context.h

Lines changed: 2 additions & 2 deletions
@@ -62,7 +62,7 @@ class TuneContextNode : public runtime::Object {
   /*! \brief The task scheduler that owns the tune context */
   const TaskSchedulerNode* task_scheduler;
   /*! \brief Whether the tuning task has been stopped or finished. */
-  bool is_stopped;
+  bool is_terminated;
   /*! \brief The measure candidates. */
   Optional<Array<MeasureCandidate>> measure_candidates;
   /*! \brief The building results. */
@@ -81,7 +81,7 @@ class TuneContextNode : public runtime::Object {
     v->Visit("task_name", &task_name);
     v->Visit("rand_state", &rand_state);
     v->Visit("num_threads", &num_threads);
-    v->Visit("is_stopped", &is_stopped);
+    v->Visit("is_terminated", &is_terminated);
     v->Visit("builder_results", &builder_results);
     v->Visit("runner_futures", &runner_futures);
     v->Visit("measure_candidates", &measure_candidates);

include/tvm/support/random_engine.h

Lines changed: 9 additions & 9 deletions
@@ -99,15 +99,15 @@ class LinearCongruentialEngine {
    * \brief Change the start random state of RNG with the seed of a new random state value.
    * \param rand_state The random state given in result_type.
    */
-  void Seed(TRandState rand_state = 1) {
-    ICHECK(rand_state != -1) << "The seed can't be -1 which should be changed to random seed!";
-    rand_state %= modulus;  // Make sure the seed is within the range of modulus.
-    if (rand_state == 0)
-      rand_state = 1;  // Avoid getting all 0 given the current parameter set.
-    else if (rand_state < 0)
-      rand_state += modulus;             // Make sure the rand state is non-negative.
-    ICHECK(rand_state_ptr_ != nullptr);  // Make sure the pointer is not null.
-    *rand_state_ptr_ = rand_state;  // Change pointed random state to given random state value.
+  void Seed(TRandState rand_state) {
+    if (rand_state == -1) {
+      rand_state = DeviceRandom();
+    } else if (rand_state == 0) {
+      rand_state = 1;
+    }
+    ICHECK(rand_state >= 0) << "The random state should be nonnegative";
+    ICHECK(rand_state_ptr_ != nullptr);
+    *rand_state_ptr_ = rand_state % modulus;
   }

   /*!
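The rewritten Seed drops the default argument and normalizes the seed in three steps: -1 requests a device-random seed, 0 is bumped to 1 (an all-zero state would make the LCG degenerate), and the result must be non-negative and is reduced modulo the engine's modulus. A small, runnable Python mirror of that normalization, for illustration only; device_random stands in for the engine's DeviceRandom() and MODULUS is a placeholder for LinearCongruentialEngine::modulus:

    import random

    MODULUS = 2 ** 31 - 1  # placeholder; the real value is defined by the engine

    def normalize_seed(rand_state: int, device_random=lambda: random.getrandbits(30)) -> int:
        """Mirror of the new Seed() normalization, for illustration only."""
        if rand_state == -1:
            rand_state = device_random()   # -1 means "pick a random seed for me"
        elif rand_state == 0:
            rand_state = 1                 # avoid the degenerate all-zero LCG state
        assert rand_state >= 0, "The random state should be nonnegative"
        return rand_state % MODULUS        # stored state always lies in [0, modulus)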

include/tvm/tir/schedule/schedule.h

Lines changed: 1 addition & 1 deletion
@@ -128,7 +128,7 @@ class ScheduleNode : public runtime::Object {
    * \brief Seed the randomness
    * \param seed The new random seed, -1 if use device random, otherwise non-negative
    */
-  virtual void Seed(support::LinearCongruentialEngine::TRandState seed = -1) = 0;
+  virtual void Seed(support::LinearCongruentialEngine::TRandState seed) = 0;
   /*! \brief Fork the random state */
   virtual support::LinearCongruentialEngine::TRandState ForkSeed() = 0;


python/tvm/meta_schedule/search_strategy/evolutionary_search.py

Lines changed: 7 additions & 6 deletions
@@ -34,7 +34,7 @@ class EvolutionarySearch(SearchStrategy):
     ----------
     num_trials_per_iter : int
         Number of trials per iteration.
-    num_trials_total : int
+    max_trials_per_task : int
         Total number of trials.
     population_size : int
         The initial population of traces from measured samples and randomly generated samples.
@@ -53,7 +53,7 @@ class EvolutionarySearch(SearchStrategy):
     """

     num_trials_per_iter: int
-    num_trials_total: int
+    max_trials_per_task: int
     population_size: int
     init_measured_ratio: int
     init_min_unmeasured: int
@@ -66,7 +66,7 @@ def __init__(
         self,
         *,
         num_trials_per_iter: int,
-        num_trials_total: int,
+        max_trials_per_task: int,
         population_size: int,
         init_measured_ratio: float,
         init_min_unmeasured: int,
@@ -79,7 +79,7 @@ def __init__(
         self.__init_handle_by_constructor__(
             _ffi_api.SearchStrategyEvolutionarySearch,  # type: ignore # pylint: disable=no-member
             num_trials_per_iter,
-            num_trials_total,
+            max_trials_per_task,
             population_size,
             init_measured_ratio,
             init_min_unmeasured,
@@ -94,7 +94,8 @@ class EvolutionarySearchConfig(NamedTuple):
     """Configuration for EvolutionarySearch"""

     num_trials_per_iter: int
-    num_trials_total: int
+    max_trials_per_task: int
+    max_trials_global: int
     population_size: int = 2048
     init_measured_ratio: float = 0.2
     init_min_unmeasured: int = 50
@@ -106,7 +107,7 @@ class EvolutionarySearchConfig(NamedTuple):
     def create_strategy(self) -> EvolutionarySearch:
         return EvolutionarySearch(
             num_trials_per_iter=self.num_trials_per_iter,
-            num_trials_total=self.num_trials_total,
+            max_trials_per_task=self.max_trials_per_task,
             population_size=self.population_size,
             init_measured_ratio=self.init_measured_ratio,
             init_min_unmeasured=self.init_min_unmeasured,
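EvolutionarySearchConfig now carries two budgets: max_trials_per_task, which create_strategy forwards to the strategy, and max_trials_global, which is presumably consumed at the task-scheduler level rather than by the strategy itself. A minimal sketch, assuming the config class is re-exported from the package as the strategies are and that the remaining fields all carry defaults; the numbers are illustrative:

    from tvm.meta_schedule.search_strategy import EvolutionarySearchConfig

    config = EvolutionarySearchConfig(
        num_trials_per_iter=64,     # batch size (illustrative)
        max_trials_per_task=1000,   # per-task budget, passed to the strategy
        max_trials_global=20000,    # global budget; not forwarded by create_strategy()
    )
    strategy = config.create_strategy()  # -> EvolutionarySearch built from the fields above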

python/tvm/meta_schedule/search_strategy/replay_func.py

Lines changed: 7 additions & 6 deletions
@@ -33,31 +33,32 @@ class ReplayFunc(SearchStrategy):
     ----------
     num_trials_per_iter : int
         Number of trials per iteration.
-    num_trials_total : int
+    max_trials_per_task : int
         Total number of trials.
     """

     num_trials_per_iter: int
-    num_trials_total: int
+    max_trials_per_task: int

     def __init__(
         self,
         num_trials_per_iter: int,
-        num_trials_total: int,
+        max_trials_per_task: int,
     ):
         """Constructor"""
         self.__init_handle_by_constructor__(
             _ffi_api.SearchStrategyReplayFunc,  # type: ignore # pylint: disable=no-member
             num_trials_per_iter,
-            num_trials_total,
+            max_trials_per_task,
         )


 class ReplayFuncConfig(NamedTuple):
     """Configuration for ReplayFunc"""

     num_trials_per_iter: int
-    num_trials_total: int
+    max_trials_per_task: int
+    max_trials_global: int

     def create_strategy(self) -> ReplayFunc:
-        return ReplayFunc(self.num_trials_per_iter, self.num_trials_total)
+        return ReplayFunc(self.num_trials_per_iter, self.max_trials_per_task)

0 commit comments
