@@ -403,171 +403,137 @@ def get_numerical_idx(feature_sets, all_num_idx, all_unique_idx, logger=None):
403
403
404
404
405
405
class TrainMonitor :
406
- def __init__ (self , sign , snapshot_ratio , level = 3 , history_ratio = 3 , tolerance_ratio = 2 ,
406
+ def __init__ (self , sign , snapshot_ratio , history_ratio = 3 , tolerance_ratio = 2 ,
407
407
extension = 5 , std_floor = 0.001 , std_ceiling = 0.01 ):
408
- self .sign , self . flat_flag = sign , False
409
- self .snapshot_ratio , self . level = snapshot_ratio , max ( 1 , int ( level ))
408
+ self .sign = sign
409
+ self .snapshot_ratio = snapshot_ratio
410
410
self .n_history = int (snapshot_ratio * history_ratio )
411
- if level < 3 :
412
- if level == 1 :
413
- tolerance_ratio /= 2
414
411
self .n_tolerance = int (snapshot_ratio * tolerance_ratio )
415
412
self .extension = extension
416
413
self .std_floor , self .std_ceiling = std_floor , std_ceiling
417
- self ._run_id = - 1
418
- self ._rs = None
419
414
self ._scores = []
420
- self ._running_sum = self . _running_square_sum = self . _running_best = None
421
- self ._is_best = self ._over_fit_performance = self . _best_checkpoint_performance = None
422
- self ._descend_counter = self ._flat_counter = self . _over_fitting_flag = None
415
+ self .flat_flag = False
416
+ self ._is_best = self ._running_best = None
417
+ self ._running_sum = self ._running_square_sum = None
423
418
self ._descend_increment = self .n_history * extension / 30
424
419
425
- @property
426
- def rs (self ):
427
- return self ._rs
428
-
429
- @property
430
- def params (self ):
431
- return {
432
- "level" : self .level , "n_history" : self .n_history , "n_tolerance" : self .n_tolerance ,
433
- "extension" : self .extension , "std_floor" : self .std_floor , "std_ceiling" : self .std_ceiling
434
- }
435
-
436
- @property
437
- def descend_counter (self ):
438
- return self ._descend_counter
439
-
440
- @property
441
- def over_fitting_flag (self ):
442
- return self ._over_fitting_flag
443
-
444
- @property
445
- def n_epoch (self ):
446
- return int (len (self ._scores ) / self .snapshot_ratio )
447
-
448
- def _reset_rs (self ):
449
- self ._rs = {"terminate" : False , "save_checkpoint" : False , "save_best" : False , "info" : None }
450
-
451
- def reset_monitors (self ):
452
- self ._reset_rs ()
453
420
self ._over_fit_performance = math .inf
454
421
self ._best_checkpoint_performance = - math .inf
455
- self ._descend_counter = self ._flat_counter = self ._over_fitting_flag = 0
456
-
457
- def start_new_run (self ):
458
- self ._run_id += 1
459
- self .reset_monitors ()
460
- return self
422
+ self ._descend_counter = self ._flat_counter = self .over_fitting_flag = 0
423
+ self .info = {"terminate" : False , "save_checkpoint" : False , "save_best" : False , "info" : None }
461
424
462
425
def punish_extension (self ):
463
426
self ._descend_counter += self ._descend_increment
464
427
465
- def check (self , new_score ):
466
- scores = self ._scores
467
- scores .append (new_score * self .sign )
468
- n_history = min (self .n_history , len (scores ))
469
- if n_history == 1 :
470
- return self ._rs
471
- # Update running sum & square sum
472
- if n_history < self .n_history or len (scores ) == self .n_history :
428
+ def _update_running_info (self , last_score , n_history ):
429
+ if n_history <= self .n_history :
473
430
if self ._running_sum is None or self ._running_square_sum is None :
474
- self ._running_sum = scores [0 ] + scores [1 ]
475
- self ._running_square_sum = scores [0 ] ** 2 + scores [1 ] ** 2
431
+ self ._running_sum = self . _scores [0 ] + self . _scores [1 ]
432
+ self ._running_square_sum = self . _scores [0 ] ** 2 + self . _scores [1 ] ** 2
476
433
else :
477
- self ._running_sum += scores [ - 1 ]
478
- self ._running_square_sum += scores [ - 1 ] ** 2
434
+ self ._running_sum += last_score
435
+ self ._running_square_sum += last_score ** 2
479
436
else :
480
- previous = scores [- n_history - 1 ]
481
- self ._running_sum += scores [- 1 ] - previous
482
- self ._running_square_sum += scores [- 1 ] ** 2 - previous ** 2
483
- # Update running best
437
+ previous = self ._scores [- n_history - 1 ]
438
+ self ._running_sum += last_score - previous
439
+ self ._running_square_sum += last_score ** 2 - previous ** 2
484
440
if self ._running_best is None :
485
- if scores [0 ] > scores [1 ]:
441
+ if self . _scores [0 ] > self . _scores [1 ]:
486
442
improvement = 0
487
- self ._running_best , self ._is_best = scores [0 ], False
443
+ self ._running_best , self ._is_best = self . _scores [0 ], False
488
444
else :
489
- improvement = scores [1 ] - scores [0 ]
490
- self ._running_best , self ._is_best = scores [1 ], True
491
- elif self ._running_best > scores [ - 1 ] :
445
+ improvement = self . _scores [1 ] - self . _scores [0 ]
446
+ self ._running_best , self ._is_best = self . _scores [1 ], True
447
+ elif self ._running_best > last_score :
492
448
improvement = 0
493
449
self ._is_best = False
494
450
else :
495
- improvement = scores [ - 1 ] - self ._running_best
496
- self ._running_best = scores [ - 1 ]
451
+ improvement = last_score - self ._running_best
452
+ self ._running_best = last_score
497
453
self ._is_best = True
498
- # Check
499
- self ._rs ["save_checkpoint" ] = False
454
+ return improvement
455
+
456
+ def _handle_overfitting (self , last_score , res , std ):
457
+ if self ._descend_counter == 0 :
458
+ self .info ["save_best" ] = True
459
+ self ._over_fit_performance = last_score
460
+ self ._descend_counter += min (self .n_tolerance / 3 , - res / std )
461
+ self .over_fitting_flag = 1
462
+
463
+ def _handle_recovering (self , improvement , last_score , res , std ):
464
+ if res > 3 * std and self ._is_best and improvement > std :
465
+ self .info ["save_best" ] = True
466
+ new_counter = self ._descend_counter - res / std
467
+ if self ._descend_counter > 0 >= new_counter :
468
+ self ._over_fit_performance = math .inf
469
+ if last_score > self ._best_checkpoint_performance :
470
+ self ._best_checkpoint_performance = last_score
471
+ if last_score > self ._running_best - std :
472
+ self .info ["save_checkpoint" ] = True
473
+ self .info ["info" ] = (
474
+ "Current snapshot ({}) seems to be working well, "
475
+ "saving checkpoint in case we need to restore" .format (len (self ._scores ))
476
+ )
477
+ self .over_fitting_flag = 0
478
+ self ._descend_counter = max (new_counter , 0 )
479
+
480
+ def _handle_is_best (self ):
481
+ if self ._is_best :
482
+ self .info ["terminate" ] = False
483
+ if self .info ["save_best" ]:
484
+ self .info ["save_checkpoint" ] = True
485
+ self .info ["save_best" ] = False
486
+ self .info ["info" ] = (
487
+ "Current snapshot ({}) leads to best result we've ever had, "
488
+ "saving checkpoint since " .format (len (self ._scores ))
489
+ )
490
+ if self .over_fitting_flag :
491
+ self .info ["info" ] += "we've suffered from over-fitting"
492
+ else :
493
+ self .info ["info" ] += "performance has improved significantly"
494
+
495
+ def _handle_period (self , last_score ):
496
+ if len (self ._scores ) % self .snapshot_ratio == 0 and last_score > self ._best_checkpoint_performance :
497
+ self ._best_checkpoint_performance = last_score
498
+ self .info ["terminate" ] = False
499
+ self .info ["save_checkpoint" ] = True
500
+ self .info ["info" ] = (
501
+ "Current snapshot ({}) leads to best checkpoint we've ever had, "
502
+ "saving checkpoint in case we need to restore" .format (len (self ._scores ))
503
+ )
504
+
505
+ def check (self , new_metric ):
506
+ last_score = new_metric * self .sign
507
+ self ._scores .append (last_score )
508
+ n_history = min (self .n_history , len (self ._scores ))
509
+ if n_history == 1 :
510
+ return self .info
511
+ improvement = self ._update_running_info (last_score , n_history )
512
+ self .info ["save_checkpoint" ] = False
500
513
mean = self ._running_sum / n_history
501
514
std = math .sqrt (max (self ._running_square_sum / n_history - mean ** 2 , 1e-12 ))
502
515
std = min (std , self .std_ceiling )
503
516
if std < self .std_floor :
504
517
if self .flat_flag :
505
518
self ._flat_counter += 1
506
519
else :
507
- if self .level >= 3 or self ._is_best :
508
- self ._flat_counter = max (self ._flat_counter - 1 , 0 )
509
- elif self .flat_flag and self .level < 3 and not self ._is_best :
510
- self ._flat_counter += 1
511
- res = scores [- 1 ] - mean
512
- if res < - std and scores [- 1 ] < self ._over_fit_performance - std :
513
- if self ._descend_counter == 0 :
514
- self ._rs ["save_best" ] = True
515
- self ._over_fit_performance = scores [- 1 ]
516
- if self ._over_fit_performance > self ._running_best :
517
- self ._best_checkpoint_performance = self ._over_fit_performance
518
- self ._rs ["save_checkpoint" ] = True
519
- self ._rs ["info" ] = (
520
- "Current snapshot ({}) seems to be over-fitting, "
521
- "saving checkpoint in case we need to restore" .format (len (scores ) + self ._run_id )
522
- )
523
- self ._descend_counter += min (self .n_tolerance / 3 , - res / std )
524
- self ._over_fitting_flag = 1
520
+ self ._flat_counter = max (self ._flat_counter - 1 , 0 )
521
+ res = last_score - mean
522
+ if res < - std and last_score < self ._over_fit_performance - std :
523
+ self ._handle_overfitting (last_score , res , std )
525
524
elif res > std :
526
- if res > 3 * std and self ._is_best and improvement > std :
527
- self ._rs ["save_best" ] = True
528
- new_counter = self ._descend_counter - res / std
529
- if self ._descend_counter > 0 >= new_counter :
530
- self ._over_fit_performance = math .inf
531
- if scores [- 1 ] > self ._best_checkpoint_performance :
532
- self ._best_checkpoint_performance = scores [- 1 ]
533
- if scores [- 1 ] > self ._running_best - std :
534
- self ._rs ["save_checkpoint" ] = True
535
- self ._rs ["info" ] = (
536
- "Current snapshot ({}) seems to be working well, "
537
- "saving checkpoint in case we need to restore" .format (len (scores )+ self ._run_id )
538
- )
539
- self ._over_fitting_flag = 0
540
- self ._descend_counter = max (new_counter , 0 )
525
+ self ._handle_recovering (improvement , last_score , res , std )
541
526
if self ._flat_counter >= self .n_tolerance * self .n_history :
542
- self ._rs ["info" ] = "Performance not improving"
543
- self ._rs ["terminate" ] = True
544
- return self ._rs
527
+ self .info ["info" ] = "Performance not improving"
528
+ self .info ["terminate" ] = True
529
+ return self .info
545
530
if self ._descend_counter >= self .n_tolerance :
546
- self ._rs ["info" ] = "Over-fitting"
547
- self ._rs ["terminate" ] = True
548
- return self ._rs
549
- if self ._is_best :
550
- self ._rs ["terminate" ] = False
551
- if self ._rs ["save_best" ]:
552
- self ._rs ["save_checkpoint" ] = True
553
- self ._rs ["save_best" ] = False
554
- self ._rs ["info" ] = (
555
- "Current snapshot ({}) leads to best result we've ever had, "
556
- "saving checkpoint since " .format (len (scores ) + self ._run_id )
557
- )
558
- if self ._over_fitting_flag :
559
- self ._rs ["info" ] += "we've suffered from over-fitting"
560
- else :
561
- self ._rs ["info" ] += "performance has improved significantly"
562
- if len (scores ) % self .snapshot_ratio == 0 and scores [- 1 ] > self ._best_checkpoint_performance :
563
- self ._best_checkpoint_performance = scores [- 1 ]
564
- self ._rs ["terminate" ] = False
565
- self ._rs ["save_checkpoint" ] = True
566
- self ._rs ["info" ] = (
567
- "Current snapshot ({}) leads to best checkpoint we've ever had, "
568
- "saving checkpoint in case we need to restore" .format (len (scores ) + self ._run_id )
569
- )
570
- return self ._rs
531
+ self .info ["info" ] = "Over-fitting"
532
+ self .info ["terminate" ] = True
533
+ return self .info
534
+ self ._handle_is_best ()
535
+ self ._handle_period (last_score )
536
+ return self .info
571
537
572
538
573
539
class DNDF :
0 commit comments