From 1b53767da71ab94169c60f46f0b8e9d820fd59d5 Mon Sep 17 00:00:00 2001 From: Nicole White Date: Tue, 10 Oct 2017 12:00:55 -0700 Subject: [PATCH] Fix off-by-one error in EarlyStopping callback (#8100) --- keras/callbacks.py | 2 +- tests/keras/test_callbacks.py | 25 +++++++++++++++++++++++++ 2 files changed, 26 insertions(+), 1 deletion(-) diff --git a/keras/callbacks.py b/keras/callbacks.py index 50f606516dd..cab51b1fe22 100644 --- a/keras/callbacks.py +++ b/keras/callbacks.py @@ -497,10 +497,10 @@ def on_epoch_end(self, epoch, logs=None): self.best = current self.wait = 0 else: + self.wait += 1 if self.wait >= self.patience: self.stopped_epoch = epoch self.model.stop_training = True - self.wait += 1 def on_train_end(self, logs=None): if self.stopped_epoch > 0 and self.verbose > 0: diff --git a/tests/keras/test_callbacks.py b/tests/keras/test_callbacks.py index 63d552d6d75..a62eb7d096a 100644 --- a/tests/keras/test_callbacks.py +++ b/tests/keras/test_callbacks.py @@ -256,6 +256,31 @@ def test_EarlyStopping_reuse(): assert len(hist.epoch) >= patience +@keras_test +def test_EarlyStopping_patience(): + class DummyModel(object): + def __init__(self): + self.stop_training = False + + early_stop = callbacks.EarlyStopping(monitor='val_loss', patience=2) + early_stop.model = DummyModel() + + losses = [0.0860, 0.1096, 0.1040, 0.1019] + + # Should stop after epoch 3, as the loss has not improved after patience=2 epochs. + epochs_trained = 0 + early_stop.on_train_begin() + + for epoch in range(len(losses)): + epochs_trained += 1 + early_stop.on_epoch_end(epoch, logs={'val_loss': losses[epoch]}) + + if early_stop.model.stop_training: + break + + assert epochs_trained == 3 + + @keras_test def test_LearningRateScheduler(): np.random.seed(1337)