@@ -246,6 +246,9 @@ def __init__(
246
246
# task_cts[i] saves how many times task i is tuned
247
247
self .task_cts = [0 for _ in range (len (self .tasks ))]
248
248
249
+ # task_best_cts[i] saves the round in which task i found its best latency
250
+ self .task_best_cts = [0 for _ in range (len (self .tasks ))]
251
+
249
252
# task_costs_history[i] saves the latency history of task i
250
253
self .task_costs_history = [[] for _ in range (len (self .tasks ))]
251
254
@@ -281,13 +284,14 @@ def tune(
281
284
search_policy = "default" ,
282
285
search_policy_params = None ,
283
286
adapative_training = False ,
287
+ per_task_early_stopping = None ,
284
288
):
285
289
"""Tune a batch of tasks together.
286
290
287
291
Parameters
288
292
----------
289
293
tune_option: TuningOptions
290
- The options of tuning
294
+ The tuning options applied to all tasks.
291
295
search_policy : Union[str, List[SearchPolicy]] = "default"
292
296
The list of search policies.
293
297
If it is str,
@@ -299,10 +303,17 @@ def tune(
299
303
adapative_training : bool = False
300
304
Option used by XGBModel to reduce the model training frequency when there're
301
305
too many logs.
306
+ per_task_early_stopping : Optional[int]
307
+ Stop tuning a task early if getting no improvement after n measurements.
302
308
"""
303
309
# init members
304
310
self .tune_option = tune_option
305
- early_stopping = 1e20 if tune_option .early_stopping < 0 else tune_option .early_stopping
311
+ self .early_stopping_all = (
312
+ 1e20 if tune_option .early_stopping < 0 else tune_option .early_stopping
313
+ )
314
+ self .early_stopping_task = (
315
+ 1e20 if per_task_early_stopping is None else per_task_early_stopping
316
+ )
306
317
307
318
self .measurer = ProgramMeasurer (
308
319
tune_option .builder ,
@@ -417,13 +428,13 @@ def tune(
417
428
if self .cur_score < self .best_score :
418
429
self .best_score = self .cur_score
419
430
self .best_ct = self .ct
420
- elif self .ct - self .best_ct >= early_stopping and all (
431
+ elif self .ct - self .best_ct >= self . early_stopping_all and all (
421
432
cost < 1e9 for cost in self .best_costs
422
433
):
423
434
if self .tune_option .verbose >= 1 :
424
435
print (
425
436
"Stop early since no performance improvement in the last "
426
- + str (early_stopping )
437
+ + str (self . early_stopping_all )
427
438
+ " measurement trials."
428
439
)
429
440
break
@@ -439,15 +450,22 @@ def _tune_task(self, task_idx):
439
450
self .num_measures_per_round , self .measurer
440
451
)
441
452
453
+ self .task_cts [task_idx ] += 1
454
+
442
455
for res in measure_results :
443
456
cost = array_mean (res .costs )
444
457
if cost < self .best_costs [task_idx ]:
458
+ self .task_best_cts [task_idx ] = self .task_cts [task_idx ]
445
459
self .best_costs [task_idx ] = cost
446
460
447
- if len (measure_inputs ) == 0 :
461
+ # Stop tuning this task in the rest of the process if its search space has been
462
+ # fully explored or it has no improvement for a long while.
463
+ no_change_trials = (
464
+ self .task_cts [task_idx ] - self .task_best_cts [task_idx ]
465
+ ) * self .num_measures_per_round
466
+ if len (measure_inputs ) == 0 or no_change_trials > self .early_stopping_task :
448
467
self .dead_tasks .add (task_idx )
449
468
450
- self .task_cts [task_idx ] += 1
451
469
self .task_costs_history [task_idx ].append (self .best_costs [task_idx ])
452
470
453
471
self .ct += len (measure_inputs )
@@ -494,17 +512,24 @@ def _restore_status(self, log_file, num_measures_per_round):
494
512
if task_idx is None :
495
513
continue
496
514
515
+ self .task_cts [task_idx ] += 1
516
+
497
517
if res .error_no == 0 :
498
- self .best_costs [task_idx ] = min (self .best_costs [task_idx ], array_mean (res .costs ))
518
+ cost = array_mean (res .costs )
519
+ if cost < self .best_costs [task_idx ]:
520
+ self .task_best_cts [task_idx ] = self .task_cts [task_idx ]
521
+ self .best_costs [task_idx ] = cost
499
522
500
- self .task_cts [task_idx ] += 1
523
+ for idx in range (len (self .tasks )):
524
+ if self .task_cts [idx ] - self .task_best_cts [idx ] > self .early_stopping_task :
525
+ self .dead_tasks .add (idx )
501
526
502
- for i in range (len (self .tasks )):
503
527
# The computation of task_cts is just an estimation.
504
528
# The estimation may not be accurate if the log file is changed externally or
505
529
# `num_measures_per_round` is different from the last tuning.
506
- self .task_cts [i ] = int (self .task_cts [i ] / num_measures_per_round + 0.5 )
507
- self .task_costs_history [i ].append (self .best_costs [i ])
530
+ self .task_cts [idx ] = int (self .task_cts [idx ] / num_measures_per_round + 0.5 )
531
+ self .task_best_cts [idx ] = int (self .task_best_cts [idx ] / num_measures_per_round + 0.5 )
532
+ self .task_costs_history [idx ].append (self .best_costs [idx ])
508
533
509
534
self .cur_score = self ._compute_score (self .best_costs )
510
535
0 commit comments