
Commit d8313d0

[AutoScheduler] Support early_stopping per task (#7377)
* [AutoScheduler] Support early_stopping per task
* address comment
* fix test
* Update python/tvm/auto_scheduler/task_scheduler.py
* Update python/tvm/auto_scheduler/task_scheduler.py
* trigger ci
* trigger ci
1 parent 38c9eb1 commit d8313d0
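For context, a minimal usage sketch of the option this commit adds. Only the per_task_early_stopping keyword comes from this change; the task extraction, log file name, and trial counts below are hypothetical placeholders.

from tvm import auto_scheduler

# Assumption: tasks and task_weights were already extracted elsewhere,
# e.g. via auto_scheduler.extract_tasks(mod["main"], params, target).
tuner = auto_scheduler.TaskScheduler(tasks, task_weights)

tune_option = auto_scheduler.TuningOptions(
    num_measure_trials=2000,  # total measurement budget shared by all tasks
    measure_callbacks=[auto_scheduler.RecordToFile("tuning.json")],
)

# New in this commit: retire an individual task once it has gone
# per_task_early_stopping measurement trials without improving its best latency,
# independent of the global early_stopping in TuningOptions.
tuner.tune(tune_option, per_task_early_stopping=200)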

File tree: 1 file changed (+36, -11 lines changed)


python/tvm/auto_scheduler/task_scheduler.py

Lines changed: 36 additions & 11 deletions
@@ -246,6 +246,9 @@ def __init__(
         # task_cts[i] saves how many times task i is tuned
         self.task_cts = [0 for _ in range(len(self.tasks))]

+        # task_best_cts[i] saves the round task i found the best latency
+        self.task_best_cts = [0 for _ in range(len(self.tasks))]
+
         # task_costs_history[i] saves the latency history of task i
         self.task_costs_history = [[] for _ in range(len(self.tasks))]

@@ -281,13 +284,14 @@ def tune(
         search_policy="default",
         search_policy_params=None,
         adapative_training=False,
+        per_task_early_stopping=None,
     ):
         """Tune a batch of tasks together.

         Parameters
         ----------
         tune_option: TuningOptions
-            The options of tuning
+            The tuning options applied to all tasks.
         search_policy: : Union[str, List[SearchPolicy]] = "default"
             The list of search policies.
             If it is str,
@@ -299,10 +303,17 @@
         adapative_training : bool = False
             Option used by XGBModel to reduce the model training frequency when there're
             too many logs.
+        per_task_early_stopping : Optional[int]
+            Stop tuning a task early if getting no improvement after n measurements.
         """
         # init members
         self.tune_option = tune_option
-        early_stopping = 1e20 if tune_option.early_stopping < 0 else tune_option.early_stopping
+        self.early_stopping_all = (
+            1e20 if tune_option.early_stopping < 0 else tune_option.early_stopping
+        )
+        self.early_stopping_task = (
+            1e20 if per_task_early_stopping is None else per_task_early_stopping
+        )

         self.measurer = ProgramMeasurer(
             tune_option.builder,
@@ -417,13 +428,13 @@ def tune(
             if self.cur_score < self.best_score:
                 self.best_score = self.cur_score
                 self.best_ct = self.ct
-            elif self.ct - self.best_ct >= early_stopping and all(
+            elif self.ct - self.best_ct >= self.early_stopping_all and all(
                 cost < 1e9 for cost in self.best_costs
             ):
                 if self.tune_option.verbose >= 1:
                     print(
                         "Stop early since no performance improvement in the last "
-                        + str(early_stopping)
+                        + str(self.early_stopping_all)
                         + " measurement trials."
                     )
                 break
@@ -439,15 +450,22 @@ def _tune_task(self, task_idx):
             self.num_measures_per_round, self.measurer
         )

+        self.task_cts[task_idx] += 1
+
         for res in measure_results:
             cost = array_mean(res.costs)
             if cost < self.best_costs[task_idx]:
+                self.task_best_cts[task_idx] = self.task_cts[task_idx]
                 self.best_costs[task_idx] = cost

-        if len(measure_inputs) == 0:
+        # Stop tuning this task in the rest of the process if its search space has been
+        # fully explored or it has no improvement for a long while.
+        no_change_trials = (
+            self.task_cts[task_idx] - self.task_best_cts[task_idx]
+        ) * self.num_measures_per_round
+        if len(measure_inputs) == 0 or no_change_trials > self.early_stopping_task:
             self.dead_tasks.add(task_idx)

-        self.task_cts[task_idx] += 1
         self.task_costs_history[task_idx].append(self.best_costs[task_idx])

         self.ct += len(measure_inputs)
@@ -494,17 +512,24 @@ def _restore_status(self, log_file, num_measures_per_round):
494512
if task_idx is None:
495513
continue
496514

515+
self.task_cts[task_idx] += 1
516+
497517
if res.error_no == 0:
498-
self.best_costs[task_idx] = min(self.best_costs[task_idx], array_mean(res.costs))
518+
cost = array_mean(res.costs)
519+
if self.best_costs[task_idx] < cost:
520+
self.best_costs[task_idx] = cost
521+
self.task_best_cts = self.task_cts[task_idx]
499522

500-
self.task_cts[task_idx] += 1
523+
for idx in range(len(self.tasks)):
524+
if self.task_cts[idx] - self.task_best_cts[idx] > self.early_stopping_task:
525+
self.dead_tasks.add(idx)
501526

502-
for i in range(len(self.tasks)):
503527
# The computation of taks_cts is just an estimation.
504528
# The estimation may not be accurate if the log file is changed externally or
505529
# `num_measures_per_round` is different from the last tuning.
506-
self.task_cts[i] = int(self.task_cts[i] / num_measures_per_round + 0.5)
507-
self.task_costs_history[i].append(self.best_costs[i])
530+
self.task_cts[idx] = int(self.task_cts[idx] / num_measures_per_round + 0.5)
531+
self.task_best_cts[idx] = int(self.task_best_cts[idx] / num_measures_per_round + 0.5)
532+
self.task_costs_history[idx].append(self.best_costs[idx])
508533

509534
self.cur_score = self._compute_score(self.best_costs)
510535

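The core of the change is the bookkeeping in _tune_task: a task is retired once the number of trials since its last improvement exceeds the per-task budget. The following standalone sketch mirrors that arithmetic; the field names follow the diff above, but the concrete values are made up for illustration.

# Illustration of the per-task early-stopping rule added in this commit.
# In the scheduler these values live on the TaskScheduler instance.
num_measures_per_round = 64   # measurements launched per tuning round
early_stopping_task = 300     # value passed as per_task_early_stopping

task_cts = 12                 # rounds this task has been tuned so far
task_best_cts = 6             # round in which its best latency was found

# Rounds without improvement, converted back to measurement trials.
no_change_trials = (task_cts - task_best_cts) * num_measures_per_round

# The task is marked dead once it exceeds the per-task budget:
# (12 - 6) * 64 = 384 > 300, so it would be added to dead_tasks.
if no_change_trials > early_stopping_task:
    print("task would be added to dead_tasks")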