Skip to content

Commit

Permalink
Added earlystopping mechanism (OpenNMT#1389)
Browse files Browse the repository at this point in the history
* Added earlystopping mechanism
* Fixed earlystopping multi-gpu stoppage
  • Loading branch information
antoniovilarinholopes authored and vince62s committed Apr 9, 2019
1 parent d4edfc4 commit f7fc40e
Show file tree
Hide file tree
Showing 4 changed files with 217 additions and 3 deletions.
5 changes: 5 additions & 0 deletions onmt/opts.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,6 +408,11 @@ def train_opts(parser):
help="Make a single pass over the training dataset.")
group.add('--epochs', '-epochs', type=int, default=0,
help='Deprecated epochs see train_steps')
group.add('--early_stopping', '-early_stopping', type=int, default=0,
help='Number of validation steps without improving.')
group.add('--early_stopping_criteria', '-early_stopping_criteria',
nargs="*", default=None,
help='Criteria to use for early stopping.')
group.add('--optim', '-optim', default='sgd',
choices=['sgd', 'adagrad', 'adadelta', 'adam',
'sparseadam', 'adafactor', 'fusedadam'],
Expand Down
17 changes: 15 additions & 2 deletions onmt/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,10 @@ def build_trainer(opt, device_id, model, fields, optim, model_saver=None):
n_gpu = 0
gpu_verbose_level = opt.gpu_verbose_level

earlystopper = onmt.utils.EarlyStopping(
opt.early_stopping, scorers=onmt.utils.scorers_from_opts(opt)) \
if opt.early_stopping > 0 else None

report_manager = onmt.utils.build_report_manager(opt)
trainer = onmt.Trainer(model, train_loss, valid_loss, optim, trunc_size,
shard_size, norm_method,
Expand All @@ -62,7 +66,8 @@ def build_trainer(opt, device_id, model, fields, optim, model_saver=None):
model_saver=model_saver if gpu_rank == 0 else None,
average_decay=average_decay,
average_every=average_every,
model_dtype=opt.model_dtype)
model_dtype=opt.model_dtype,
earlystopper=earlystopper)
return trainer


Expand Down Expand Up @@ -98,7 +103,8 @@ def __init__(self, model, train_loss, valid_loss, optim,
accum_steps=[0],
n_gpu=1, gpu_rank=1,
gpu_verbose_level=0, report_manager=None, model_saver=None,
average_decay=0, average_every=1, model_dtype='fp32'):
average_decay=0, average_every=1, model_dtype='fp32',
earlystopper=None):
# Basic attributes.
self.model = model
self.train_loss = train_loss
Expand All @@ -119,6 +125,7 @@ def __init__(self, model, train_loss, valid_loss, optim,
self.moving_average = None
self.average_every = average_every
self.model_dtype = model_dtype
self.earlystopper = earlystopper

for i in range(len(self.accum_count_l)):
assert self.accum_count_l[i] > 0
Expand Down Expand Up @@ -248,6 +255,12 @@ def train(self,
% (self.gpu_rank, step))
self._report_step(self.optim.learning_rate(),
step, valid_stats=valid_stats)
# Run patience mechanism
if self.earlystopper is not None:
self.earlystopper(valid_stats, step)
# If the patience has reached the limit, stop training
if self.earlystopper.has_stopped():
break

if (self.model_saver is not None
and (save_checkpoint_steps != 0
Expand Down
4 changes: 3 additions & 1 deletion onmt/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
from onmt.utils.statistics import Statistics
from onmt.utils.optimizers import MultipleOptimizer, \
Optimizer, AdaFactor
from onmt.utils.earlystopping import EarlyStopping, scorers_from_opts

__all__ = ["split_corpus", "aeq", "use_gpu", "set_random_seed", "ReportMgr",
"build_report_manager", "Statistics",
"MultipleOptimizer", "Optimizer", "AdaFactor"]
"MultipleOptimizer", "Optimizer", "AdaFactor", "EarlyStopping",
"scorers_from_opts"]
194 changes: 194 additions & 0 deletions onmt/utils/earlystopping.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@

from enum import Enum
from onmt.utils.logging import logger


class PatienceEnum(Enum):
IMPROVING = 0
DECREASING = 1
STOPPED = 2


class Scorer(object):
def __init__(self, best_score, name):
self.best_score = best_score
self.name = name

def is_improving(self, stats):
raise NotImplementedError()

def is_decreasing(self, stats):
raise NotImplementedError()

def update(self, stats):
self.best_score = self._caller(stats)

def __call__(self, stats, **kwargs):
return self._caller(stats)

def _caller(self, stats):
raise NotImplementedError()


class PPLScorer(Scorer):

def __init__(self):
super(PPLScorer, self).__init__(float("inf"), "ppl")

def is_improving(self, stats):
return stats.ppl() < self.best_score

def is_decreasing(self, stats):
return stats.ppl() > self.best_score

def _caller(self, stats):
return stats.ppl()


class AccuracyScorer(Scorer):

def __init__(self):
super(AccuracyScorer, self).__init__(float("-inf"), "acc")

def is_improving(self, stats):
return stats.accuracy() > self.best_score

def is_decreasing(self, stats):
return stats.accuracy() < self.best_score

def _caller(self, stats):
return stats.accuracy()


DEFAULT_SCORERS = [PPLScorer(), AccuracyScorer()]


SCORER_BUILDER = {
"ppl": PPLScorer,
"accuracy": AccuracyScorer
}


def scorers_from_opts(opt):
if opt.early_stopping_criteria is None:
return DEFAULT_SCORERS
else:
scorers = []
for criterion in set(opt.early_stopping_criteria):
assert criterion in SCORER_BUILDER.keys(), \
"Criterion {} not found".format(criterion)
scorers.append(SCORER_BUILDER[criterion]())
return scorers


class EarlyStopping(object):

def __init__(self, tolerance, scorers=DEFAULT_SCORERS):
"""
Callable class to keep track of early stopping.
Args:
tolerance(int): number of validation steps without improving
scorer(fn): list of scorers to validate performance on dev
"""

self.tolerance = tolerance
self.stalled_tolerance = self.tolerance
self.current_tolerance = self.tolerance
self.early_stopping_scorers = scorers
self.status = PatienceEnum.IMPROVING
self.current_step_best = 0

def __call__(self, valid_stats, step):
"""
Update the internal state of early stopping mechanism, whether to
continue training or stop the train procedure.
Checks whether the scores from all pre-chosen scorers improved. If
every metric improve, then the status is switched to improving and the
tolerance is reset. If every metric deteriorate, then the status is
switched to decreasing and the tolerance is also decreased; if the
tolerance reaches 0, then the status is changed to stopped.
Finally, if some improved and others not, then it's considered stalled;
after tolerance number of stalled, the status is switched to stopped.
:param valid_stats: Statistics of dev set
"""

if self.status == PatienceEnum.STOPPED:
# Don't do anything
return

if all([scorer.is_improving(valid_stats) for scorer
in self.early_stopping_scorers]):
self._update_increasing(valid_stats, step)

elif all([scorer.is_decreasing(valid_stats) for scorer
in self.early_stopping_scorers]):
self._update_decreasing()

else:
self._update_stalled()

def _update_stalled(self):
self.stalled_tolerance -= 1

logger.info(
"Stalled patience: {}/{}".format(self.stalled_tolerance,
self.tolerance))

if self.stalled_tolerance == 0:
logger.info(
"Training finished after stalled validations. Early Stop!"
)
self._log_best_step()

self._decreasing_or_stopped_status_update(self.stalled_tolerance)

def _update_increasing(self, valid_stats, step):
self.current_step_best = step
for scorer in self.early_stopping_scorers:
logger.info(
"Model is improving {}: {:g} --> {:g}.".format(
scorer.name, scorer.best_score, scorer(valid_stats))
)
# Update best score of each criteria
scorer.update(valid_stats)

# Reset tolerance
self.current_tolerance = self.tolerance
self.stalled_tolerance = self.tolerance

# Update current status
self.status = PatienceEnum.IMPROVING

def _update_decreasing(self):
# Decrease tolerance
self.current_tolerance -= 1

# Log
logger.info(
"Decreasing patience: {}/{}".format(self.current_tolerance,
self.tolerance)
)
# Log
if self.current_tolerance == 0:
logger.info("Training finished after not improving. Early Stop!")
self._log_best_step()

self._decreasing_or_stopped_status_update(self.current_tolerance)

def _log_best_step(self):
logger.info("Best model found at step {}".format(
self.current_step_best))

def _decreasing_or_stopped_status_update(self, tolerance):
self.status = PatienceEnum.DECREASING \
if tolerance > 0 \
else PatienceEnum.STOPPED

def is_improving(self):
return self.status == PatienceEnum.IMPROVING

def has_stopped(self):
return self.status == PatienceEnum.STOPPED

0 comments on commit f7fc40e

Please sign in to comment.