Commit 3cef514

Update _Dist/NeuralNetworks
1 parent 57d0728 commit 3cef514

4 files changed: 106 additions, 140 deletions


_Dist/NeuralNetworks/Base.py

+3 -3
@@ -620,7 +620,7 @@ def fit(self, x, y, x_test=None, y_test=None, sample_weights=None, names=("train
 if time_limit <= 0:
     print("Time limit exceeded before training process started")
     return self
-monitor = TrainMonitor(Metrics.sign_dict[self._metric_name], snapshot_ratio).start_new_run()
+monitor = TrainMonitor(Metrics.sign_dict[self._metric_name], snapshot_ratio)

 if verbose >= 2:
     prepare_tensorboard_verbose(self._sess)
@@ -670,14 +670,14 @@ def fit(self, x, y, x_test=None, y_test=None, sample_weights=None, names=("train
         break
 self.log["epoch_loss"].append(epoch_loss / (j + 1))
 if use_monitor:
-    if i_epoch == n_epoch and i_epoch < self.max_epoch and not monitor.rs["terminate"]:
+    if i_epoch == n_epoch and i_epoch < self.max_epoch and not monitor.info["terminate"]:
         monitor.flat_flag = True
         monitor.punish_extension()
         n_epoch = min(n_epoch + monitor.extension, self.max_epoch)
         print(" - Extending n_epoch to {}".format(n_epoch))
     if i_epoch == self.max_epoch:
         terminate = True
-        if not monitor.rs["terminate"]:
+        if not monitor.info["terminate"]:
             if not over_fitting_flag:
                 print(
                     " - Model seems to be under-fitting but max_epoch reached. "

_Dist/NeuralNetworks/DistBase.py

+3 -3
@@ -565,7 +565,7 @@ def fit(self, x, y, x_test=None, y_test=None, sample_weights=None, names=("train
         level=logging.INFO, logger=logger
     )
     return self
-monitor = TrainMonitor(Metrics.sign_dict[self._metric_name], snapshot_ratio).start_new_run()
+monitor = TrainMonitor(Metrics.sign_dict[self._metric_name], snapshot_ratio)

 if verbose >= 2:
     prepare_tensorboard_verbose(self._sess)
@@ -618,15 +618,15 @@ def fit(self, x, y, x_test=None, y_test=None, sample_weights=None, names=("train
         break
 self.log["epoch_loss"].append(epoch_loss / (j + 1))
 if use_monitor:
-    if i_epoch == n_epoch and i_epoch < self.max_epoch and not monitor.rs["terminate"]:
+    if i_epoch == n_epoch and i_epoch < self.max_epoch and not monitor.info["terminate"]:
         monitor.flat_flag = True
         monitor.punish_extension()
         n_epoch = min(n_epoch + monitor.extension, self.max_epoch)
         self.log_msg("Extending n_epoch to {}".format(n_epoch), logger=logger)
         bar.set_max(n_epoch)
     if i_epoch == self.max_epoch:
         terminate = True
-        if not monitor.rs["terminate"]:
+        if not monitor.info["terminate"]:
             if not over_fitting_flag:
                 self.log_msg(
                     "Model seems to be under-fitting but max_epoch reached. "

_Dist/NeuralNetworks/NNUtil.py

+97 -131
@@ -403,171 +403,137 @@ def get_numerical_idx(feature_sets, all_num_idx, all_unique_idx, logger=None):
 
 
 class TrainMonitor:
-    def __init__(self, sign, snapshot_ratio, level=3, history_ratio=3, tolerance_ratio=2,
+    def __init__(self, sign, snapshot_ratio, history_ratio=3, tolerance_ratio=2,
                  extension=5, std_floor=0.001, std_ceiling=0.01):
-        self.sign, self.flat_flag = sign, False
-        self.snapshot_ratio, self.level = snapshot_ratio, max(1, int(level))
+        self.sign = sign
+        self.snapshot_ratio = snapshot_ratio
         self.n_history = int(snapshot_ratio * history_ratio)
-        if level < 3:
-            if level == 1:
-                tolerance_ratio /= 2
         self.n_tolerance = int(snapshot_ratio * tolerance_ratio)
         self.extension = extension
         self.std_floor, self.std_ceiling = std_floor, std_ceiling
-        self._run_id = -1
-        self._rs = None
         self._scores = []
-        self._running_sum = self._running_square_sum = self._running_best = None
-        self._is_best = self._over_fit_performance = self._best_checkpoint_performance = None
-        self._descend_counter = self._flat_counter = self._over_fitting_flag = None
+        self.flat_flag = False
+        self._is_best = self._running_best = None
+        self._running_sum = self._running_square_sum = None
         self._descend_increment = self.n_history * extension / 30
 
-    @property
-    def rs(self):
-        return self._rs
-
-    @property
-    def params(self):
-        return {
-            "level": self.level, "n_history": self.n_history, "n_tolerance": self.n_tolerance,
-            "extension": self.extension, "std_floor": self.std_floor, "std_ceiling": self.std_ceiling
-        }
-
-    @property
-    def descend_counter(self):
-        return self._descend_counter
-
-    @property
-    def over_fitting_flag(self):
-        return self._over_fitting_flag
-
-    @property
-    def n_epoch(self):
-        return int(len(self._scores) / self.snapshot_ratio)
-
-    def _reset_rs(self):
-        self._rs = {"terminate": False, "save_checkpoint": False, "save_best": False, "info": None}
-
-    def reset_monitors(self):
-        self._reset_rs()
         self._over_fit_performance = math.inf
         self._best_checkpoint_performance = -math.inf
-        self._descend_counter = self._flat_counter = self._over_fitting_flag = 0
-
-    def start_new_run(self):
-        self._run_id += 1
-        self.reset_monitors()
-        return self
+        self._descend_counter = self._flat_counter = self.over_fitting_flag = 0
+        self.info = {"terminate": False, "save_checkpoint": False, "save_best": False, "info": None}
 
     def punish_extension(self):
         self._descend_counter += self._descend_increment
 
-    def check(self, new_score):
-        scores = self._scores
-        scores.append(new_score * self.sign)
-        n_history = min(self.n_history, len(scores))
-        if n_history == 1:
-            return self._rs
-        # Update running sum & square sum
-        if n_history < self.n_history or len(scores) == self.n_history:
+    def _update_running_info(self, last_score, n_history):
+        if n_history <= self.n_history:
             if self._running_sum is None or self._running_square_sum is None:
-                self._running_sum = scores[0] + scores[1]
-                self._running_square_sum = scores[0] ** 2 + scores[1] ** 2
+                self._running_sum = self._scores[0] + self._scores[1]
+                self._running_square_sum = self._scores[0] ** 2 + self._scores[1] ** 2
             else:
-                self._running_sum += scores[-1]
-                self._running_square_sum += scores[-1] ** 2
+                self._running_sum += last_score
+                self._running_square_sum += last_score ** 2
         else:
-            previous = scores[-n_history - 1]
-            self._running_sum += scores[-1] - previous
-            self._running_square_sum += scores[-1] ** 2 - previous ** 2
-        # Update running best
+            previous = self._scores[-n_history - 1]
+            self._running_sum += last_score - previous
+            self._running_square_sum += last_score ** 2 - previous ** 2
         if self._running_best is None:
-            if scores[0] > scores[1]:
+            if self._scores[0] > self._scores[1]:
                 improvement = 0
-                self._running_best, self._is_best = scores[0], False
+                self._running_best, self._is_best = self._scores[0], False
             else:
-                improvement = scores[1] - scores[0]
-                self._running_best, self._is_best = scores[1], True
-        elif self._running_best > scores[-1]:
+                improvement = self._scores[1] - self._scores[0]
+                self._running_best, self._is_best = self._scores[1], True
+        elif self._running_best > last_score:
             improvement = 0
             self._is_best = False
         else:
-            improvement = scores[-1] - self._running_best
-            self._running_best = scores[-1]
+            improvement = last_score - self._running_best
+            self._running_best = last_score
             self._is_best = True
-        # Check
-        self._rs["save_checkpoint"] = False
+        return improvement
+
+    def _handle_overfitting(self, last_score, res, std):
+        if self._descend_counter == 0:
+            self.info["save_best"] = True
+            self._over_fit_performance = last_score
+        self._descend_counter += min(self.n_tolerance / 3, -res / std)
+        self.over_fitting_flag = 1
+
+    def _handle_recovering(self, improvement, last_score, res, std):
+        if res > 3 * std and self._is_best and improvement > std:
+            self.info["save_best"] = True
+        new_counter = self._descend_counter - res / std
+        if self._descend_counter > 0 >= new_counter:
+            self._over_fit_performance = math.inf
+            if last_score > self._best_checkpoint_performance:
+                self._best_checkpoint_performance = last_score
+                if last_score > self._running_best - std:
+                    self.info["save_checkpoint"] = True
+                    self.info["info"] = (
+                        "Current snapshot ({}) seems to be working well, "
+                        "saving checkpoint in case we need to restore".format(len(self._scores))
+                    )
+            self.over_fitting_flag = 0
+        self._descend_counter = max(new_counter, 0)
+
+    def _handle_is_best(self):
+        if self._is_best:
+            self.info["terminate"] = False
+            if self.info["save_best"]:
+                self.info["save_checkpoint"] = True
+                self.info["save_best"] = False
+                self.info["info"] = (
+                    "Current snapshot ({}) leads to best result we've ever had, "
+                    "saving checkpoint since ".format(len(self._scores))
+                )
+                if self.over_fitting_flag:
+                    self.info["info"] += "we've suffered from over-fitting"
+                else:
+                    self.info["info"] += "performance has improved significantly"
+
+    def _handle_period(self, last_score):
+        if len(self._scores) % self.snapshot_ratio == 0 and last_score > self._best_checkpoint_performance:
+            self._best_checkpoint_performance = last_score
+            self.info["terminate"] = False
+            self.info["save_checkpoint"] = True
+            self.info["info"] = (
+                "Current snapshot ({}) leads to best checkpoint we've ever had, "
+                "saving checkpoint in case we need to restore".format(len(self._scores))
+            )
+
+    def check(self, new_metric):
+        last_score = new_metric * self.sign
+        self._scores.append(last_score)
+        n_history = min(self.n_history, len(self._scores))
+        if n_history == 1:
+            return self.info
+        improvement = self._update_running_info(last_score, n_history)
+        self.info["save_checkpoint"] = False
         mean = self._running_sum / n_history
         std = math.sqrt(max(self._running_square_sum / n_history - mean ** 2, 1e-12))
         std = min(std, self.std_ceiling)
         if std < self.std_floor:
             if self.flat_flag:
                 self._flat_counter += 1
         else:
-            if self.level >= 3 or self._is_best:
-                self._flat_counter = max(self._flat_counter - 1, 0)
-            elif self.flat_flag and self.level < 3 and not self._is_best:
-                self._flat_counter += 1
-        res = scores[-1] - mean
-        if res < -std and scores[-1] < self._over_fit_performance - std:
-            if self._descend_counter == 0:
-                self._rs["save_best"] = True
-                self._over_fit_performance = scores[-1]
-                if self._over_fit_performance > self._running_best:
-                    self._best_checkpoint_performance = self._over_fit_performance
-                    self._rs["save_checkpoint"] = True
-                    self._rs["info"] = (
-                        "Current snapshot ({}) seems to be over-fitting, "
-                        "saving checkpoint in case we need to restore".format(len(scores) + self._run_id)
-                    )
-            self._descend_counter += min(self.n_tolerance / 3, -res / std)
-            self._over_fitting_flag = 1
+            self._flat_counter = max(self._flat_counter - 1, 0)
+        res = last_score - mean
+        if res < -std and last_score < self._over_fit_performance - std:
+            self._handle_overfitting(last_score, res, std)
         elif res > std:
-            if res > 3 * std and self._is_best and improvement > std:
-                self._rs["save_best"] = True
-            new_counter = self._descend_counter - res / std
-            if self._descend_counter > 0 >= new_counter:
-                self._over_fit_performance = math.inf
-                if scores[-1] > self._best_checkpoint_performance:
-                    self._best_checkpoint_performance = scores[-1]
-                    if scores[-1] > self._running_best - std:
-                        self._rs["save_checkpoint"] = True
-                        self._rs["info"] = (
-                            "Current snapshot ({}) seems to be working well, "
-                            "saving checkpoint in case we need to restore".format(len(scores)+self._run_id)
-                        )
-                self._over_fitting_flag = 0
-            self._descend_counter = max(new_counter, 0)
+            self._handle_recovering(improvement, last_score, res, std)
         if self._flat_counter >= self.n_tolerance * self.n_history:
-            self._rs["info"] = "Performance not improving"
-            self._rs["terminate"] = True
-            return self._rs
+            self.info["info"] = "Performance not improving"
+            self.info["terminate"] = True
+            return self.info
         if self._descend_counter >= self.n_tolerance:
-            self._rs["info"] = "Over-fitting"
-            self._rs["terminate"] = True
-            return self._rs
-        if self._is_best:
-            self._rs["terminate"] = False
-            if self._rs["save_best"]:
-                self._rs["save_checkpoint"] = True
-                self._rs["save_best"] = False
-                self._rs["info"] = (
-                    "Current snapshot ({}) leads to best result we've ever had, "
-                    "saving checkpoint since ".format(len(scores) + self._run_id)
-                )
-                if self._over_fitting_flag:
-                    self._rs["info"] += "we've suffered from over-fitting"
-                else:
-                    self._rs["info"] += "performance has improved significantly"
-        if len(scores) % self.snapshot_ratio == 0 and scores[-1] > self._best_checkpoint_performance:
-            self._best_checkpoint_performance = scores[-1]
-            self._rs["terminate"] = False
-            self._rs["save_checkpoint"] = True
-            self._rs["info"] = (
-                "Current snapshot ({}) leads to best checkpoint we've ever had, "
-                "saving checkpoint in case we need to restore".format(len(scores) + self._run_id)
-            )
-        return self._rs
+            self.info["info"] = "Over-fitting"
+            self.info["terminate"] = True
+            return self.info
+        self._handle_is_best()
+        self._handle_period(last_score)
+        return self.info
 
 
 class DNDF:
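The refactored _update_running_info keeps the same sliding-window bookkeeping the old check() had: the newest score is added to a running sum and a running square sum, and once the window holds n_history scores the entry falling out of the window is subtracted, so mean and std come from two running totals instead of a rescan. A quick numeric sanity check of that identity, with made-up scores:

import math

scores = [0.70, 0.74, 0.73, 0.78, 0.80]   # made-up snapshot scores
n_history = 3

# Direct mean/std over the last n_history scores (what check() needs)...
window = scores[-n_history:]
mean = sum(window) / n_history
std = math.sqrt(max(sum(s * s for s in window) / n_history - mean ** 2, 1e-12))

# ...equals the incremental update: add scores[-1], drop scores[-n_history - 1].
running_sum = sum(scores[-n_history - 1:-1])                      # previous window's sum
running_square_sum = sum(s * s for s in scores[-n_history - 1:-1])
running_sum += scores[-1] - scores[-n_history - 1]
running_square_sum += scores[-1] ** 2 - scores[-n_history - 1] ** 2

assert abs(running_sum / n_history - mean) < 1e-12
assert abs(math.sqrt(max(running_square_sum / n_history - (running_sum / n_history) ** 2, 1e-12)) - std) < 1e-12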

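With start_new_run and the rs property gone, callers construct the monitor once and read the info dict that check() returns (or monitor.info directly), as the fit() diffs in this commit show. A minimal sketch of that flow outside fit(); the import path, the dummy metric stream, and the checkpoint hook are assumptions for illustration:

import random

from _Dist.NeuralNetworks.NNUtil import TrainMonitor  # assumed import path

def evaluate_metric():
    # Stand-in for a real validation metric (e.g. AUC); higher is better here.
    return random.uniform(0.7, 0.9)

def save_checkpoint():
    print("checkpoint saved")

snapshot_ratio = 3                              # snapshots per epoch, as passed in fit()
monitor = TrainMonitor(1, snapshot_ratio)       # sign=+1: larger metric is better

n_epoch, max_epoch, i_epoch = 10, 32, 0
terminate = False
while i_epoch < n_epoch and not terminate:
    i_epoch += 1
    for _ in range(snapshot_ratio):             # one check() per snapshot
        info = monitor.check(evaluate_metric())
        if info["save_checkpoint"]:
            print(info["info"])
            save_checkpoint()
        terminate = info["terminate"]
        if terminate:
            break
    # Caller-side pattern from Base.py / DistBase.py / Pruner/Base.py:
    # monitor.info replaces the removed monitor.rs property.
    if i_epoch == n_epoch and i_epoch < max_epoch and not monitor.info["terminate"]:
        monitor.flat_flag = True
        monitor.punish_extension()
        n_epoch = min(n_epoch + monitor.extension, max_epoch)

Construction, the terminate/save_checkpoint handling, and the epoch-extension branch are the pieces every fit() implementation touched by this commit shares.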
_Dist/NeuralNetworks/_Tests/Pruner/Base.py

+3 -3
@@ -573,7 +573,7 @@ def fit(self, x, y, x_test=None, y_test=None, sample_weights=None, names=("train
 over_fitting_flag = 0
 n_epoch = self.n_epoch
 tmp_checkpoint_folder = os.path.join(self.model_saving_path, "tmp")
-monitor = TrainMonitor(Metrics.sign_dict[self._metric_name], snapshot_ratio).start_new_run()
+monitor = TrainMonitor(Metrics.sign_dict[self._metric_name], snapshot_ratio)

 if verbose >= 2:
     prepare_tensorboard_verbose(self._sess)
@@ -709,14 +709,14 @@ def fit(self, x, y, x_test=None, y_test=None, sample_weights=None, names=("train
 
 self.log["epoch_loss"].append(epoch_loss / (j + 1))
 if use_monitor:
-    if i_epoch == n_epoch and i_epoch < self.max_epoch and not monitor.rs["terminate"]:
+    if i_epoch == n_epoch and i_epoch < self.max_epoch and not monitor.info["terminate"]:
         monitor.flat_flag = True
         monitor.punish_extension()
         n_epoch = min(n_epoch + monitor.extension, self.max_epoch)
         print(" - Extending n_epoch to {}".format(n_epoch))
     if i_epoch == self.max_epoch:
         terminate = True
-        if not monitor.rs["terminate"]:
+        if not monitor.info["terminate"]:
             if not over_fitting_flag:
                 print(
                     " - Model seems to be under-fitting but max_epoch reached. "
