@@ -75,10 +75,14 @@ class TaskSchedulerNode : public runtime::Object {
75
75
Runner runner{nullptr };
76
76
/* ! \brief The database of the scheduler. */
77
77
Database database{nullptr };
78
+ /* ! \brief The maximum number of trials allowed. */
79
+ int max_trials;
78
80
/* ! \brief The cost model of the scheduler. */
79
81
Optional<CostModel> cost_model;
80
82
/* ! \brief The list of measure callbacks of the scheduler. */
81
83
Array<MeasureCallback> measure_callbacks;
84
+ /* ! \brief The number of trials already conducted. */
85
+ int num_trials_already;
82
86
83
87
/* ! \brief The default destructor. */
84
88
virtual ~TaskSchedulerNode () = default ;
@@ -88,8 +92,10 @@ class TaskSchedulerNode : public runtime::Object {
88
92
v->Visit (" builder" , &builder);
89
93
v->Visit (" runner" , &runner);
90
94
v->Visit (" database" , &database);
95
+ v->Visit (" max_trials" , &max_trials);
91
96
v->Visit (" cost_model" , &cost_model);
92
97
v->Visit (" measure_callbacks" , &measure_callbacks);
98
+ v->Visit (" num_trials_already" , &num_trials_already);
93
99
}
94
100
95
101
/* ! \brief Auto-tuning. */
@@ -102,23 +108,16 @@ class TaskSchedulerNode : public runtime::Object {
102
108
virtual void InitializeTask (int task_id);
103
109
104
110
/* !
105
- * \brief Set specific task to be stopped.
106
- * \param task_id The task id to be stopped.
107
- */
108
- virtual void SetTaskStopped (int task_id);
109
-
110
- /* !
111
- * \brief Check whether the task is running.
111
+ * \brief Touch the task and update its status
112
112
* \param task_id The task id to be checked.
113
- * \return Whether the task is running.
114
113
*/
115
- virtual bool IsTaskRunning (int task_id);
114
+ virtual void TouchTask (int task_id);
116
115
117
116
/* !
118
117
* \brief Wait until the task is finished.
119
118
* \param task_id The task id to be joined.
120
119
*/
121
- virtual void JoinRunningTask (int task_id);
120
+ virtual Array<RunnerResult> JoinRunningTask (int task_id);
122
121
123
122
/* !
124
123
* \brief Fetch the next task id.
@@ -142,23 +141,17 @@ class PyTaskSchedulerNode : public TaskSchedulerNode {
142
141
using FInitializeTask = runtime::TypedPackedFunc<void (int )>;
143
142
144
143
/* !
145
- * \brief The function type of `SetTaskStopped` method.
146
- * \param task_id The task id to be stopped.
147
- */
148
- using FSetTaskStopped = runtime::TypedPackedFunc<void (int )>;
149
-
150
- /* !
151
- * \brief The function type of `IsTaskRunning` method.
144
+ * \brief The function type of `TouchTask` method.
152
145
* \param task_id The task id to be checked.
153
146
* \return Whether the task is running.
154
147
*/
155
- using FIsTaskRunning = runtime::TypedPackedFunc<bool (int )>;
148
+ using FTouchTask = runtime::TypedPackedFunc<void (int )>;
156
149
157
150
/* !
158
151
* \brief The function type of `JoinRunningTask` method.
159
152
* \param task_id The task id to be joined.
160
153
*/
161
- using FJoinRunningTask = runtime::TypedPackedFunc<void (int )>;
154
+ using FJoinRunningTask = runtime::TypedPackedFunc<Array<RunnerResult> (int )>;
162
155
163
156
/* !
164
157
* \brief The function type of `NextTaskId` method.
@@ -170,10 +163,8 @@ class PyTaskSchedulerNode : public TaskSchedulerNode {
170
163
FTune f_tune;
171
164
/* ! \brief The packed function to the `InitializeTask` function. */
172
165
FInitializeTask f_initialize_task;
173
- /* ! \brief The packed function to the `SetTaskStopped` function. */
174
- FSetTaskStopped f_set_task_stopped;
175
- /* ! \brief The packed function to the `IsTaskRunning` function. */
176
- FIsTaskRunning f_is_task_running;
166
+ /* ! \brief The packed function to the `TouchTask` function. */
167
+ FTouchTask f_touch_task;
177
168
/* ! \brief The packed function to the `JoinRunningTask` function. */
178
169
FJoinRunningTask f_join_running_task;
179
170
/* ! \brief The packed function to the `NextTaskId` function. */
@@ -182,8 +173,7 @@ class PyTaskSchedulerNode : public TaskSchedulerNode {
182
173
void VisitAttrs (tvm::AttrVisitor* v) {
183
174
// `f_tune` is not visited
184
175
// `f_initialize_task` is not visited
185
- // `f_set_task_stopped` is not visited
186
- // `f_is_task_running` is not visited
176
+ // `f_touch_task` is not visited
187
177
// `f_join_running_task` is not visited
188
178
// `f_next_task_id` is not visited
189
179
}
@@ -204,23 +194,15 @@ class PyTaskSchedulerNode : public TaskSchedulerNode {
204
194
}
205
195
}
206
196
207
- void SetTaskStopped (int task_id) final {
208
- if (f_set_task_stopped == nullptr ) {
209
- TaskSchedulerNode::SetTaskStopped (task_id);
210
- } else {
211
- f_set_task_stopped (task_id);
212
- }
213
- }
214
-
215
- bool IsTaskRunning (int task_id) final {
216
- if (f_is_task_running == nullptr ) {
217
- return TaskSchedulerNode::IsTaskRunning (task_id);
197
+ void TouchTask (int task_id) final {
198
+ if (f_touch_task == nullptr ) {
199
+ return TaskSchedulerNode::TouchTask (task_id);
218
200
} else {
219
- return f_is_task_running (task_id);
201
+ return f_touch_task (task_id);
220
202
}
221
203
}
222
204
223
- void JoinRunningTask (int task_id) final {
205
+ Array<RunnerResult> JoinRunningTask (int task_id) final {
224
206
if (f_join_running_task == nullptr ) {
225
207
return TaskSchedulerNode::JoinRunningTask (task_id);
226
208
} else {
@@ -249,6 +231,7 @@ class TaskScheduler : public runtime::ObjectRef {
249
231
* \param builder The builder of the scheduler.
250
232
* \param runner The runner of the scheduler.
251
233
* \param database The database of the scheduler.
234
+ * \param max_trials The maximum number of trials.
252
235
* \param cost_model The cost model of the scheduler.
253
236
* \param measure_callbacks The measure callbacks of the scheduler.
254
237
* \return The task scheduler created.
@@ -257,20 +240,47 @@ class TaskScheduler : public runtime::ObjectRef {
257
240
Builder builder, //
258
241
Runner runner, //
259
242
Database database, //
243
+ int max_trials, //
260
244
Optional<CostModel> cost_model, //
261
245
Optional<Array<MeasureCallback>> measure_callbacks);
246
+ /* !
247
+ * \brief Create a task scheduler that fetches tasks in a gradient based fashion.
248
+ * \param tasks The tasks to be tuned.
249
+ * \param task_weights The weights of each task.
250
+ * \param builder The builder of the scheduler.
251
+ * \param runner The runner of the scheduler.
252
+ * \param database The database of the scheduler.
253
+ * \param max_trials The maximum number of trials.
254
+ * \param cost_model The cost model of the scheduler.
255
+ * \param measure_callbacks The measure callbacks of the scheduler.
256
+ * \param alpha The parameter alpha to control gradient computation.
257
+ * \param window_size The parameter to control backward window size.
258
+ * \param seed The random seed.
259
+ * \return The task scheduler created.
260
+ */
261
+ TVM_DLL static TaskScheduler GradientBased (Array<TuneContext> tasks,
262
+ Array<FloatImm> task_weights, //
263
+ Builder builder, //
264
+ Runner runner, //
265
+ Database database, //
266
+ int max_trials, //
267
+ Optional<CostModel> cost_model, //
268
+ Optional<Array<MeasureCallback>> measure_callbacks, //
269
+ double alpha, //
270
+ int window_size, //
271
+ support::LinearCongruentialEngine::TRandState seed);
262
272
/* !
263
273
* \brief Create a task scheduler with customized methods on the python-side.
264
274
* \param tasks The tasks to be tuned.
265
275
* \param builder The builder of the scheduler.
266
276
* \param runner The runner of the scheduler.
267
277
* \param database The database of the scheduler.
278
+ * \param max_trials The maximum number of trials.
268
279
* \param cost_model The cost model of the scheduler.
269
280
* \param measure_callbacks The measure callbacks of the scheduler.
270
281
* \param f_tune The packed function of `Tune`.
271
282
* \param f_initialize_task The packed function of `InitializeTask`.
272
- * \param f_set_task_stopped The packed function of `SetTaskStopped`.
273
- * \param f_is_task_running The packed function of `IsTaskRunning`.
283
+ * \param f_touch_task The packed function of `TouchTask`.
274
284
* \param f_join_running_task The packed function of `JoinRunningTask`.
275
285
* \param f_next_task_id The packed function of `NextTaskId`.
276
286
* \return The task scheduler created.
@@ -280,12 +290,12 @@ class TaskScheduler : public runtime::ObjectRef {
280
290
Builder builder, //
281
291
Runner runner, //
282
292
Database database, //
293
+ int max_trials, //
283
294
Optional<CostModel> cost_model, //
284
295
Optional<Array<MeasureCallback>> measure_callbacks, //
285
296
PyTaskSchedulerNode::FTune f_tune, //
286
297
PyTaskSchedulerNode::FInitializeTask f_initialize_task, //
287
- PyTaskSchedulerNode::FSetTaskStopped f_set_task_stopped, //
288
- PyTaskSchedulerNode::FIsTaskRunning f_is_task_running, //
298
+ PyTaskSchedulerNode::FTouchTask f_touch_task, //
289
299
PyTaskSchedulerNode::FJoinRunningTask f_join_running_task, //
290
300
PyTaskSchedulerNode::FNextTaskId f_next_task_id);
291
301
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS (TaskScheduler, ObjectRef, TaskSchedulerNode);
0 commit comments