Callback to stop the training early if there is no model improvement after consecutive evaluations (#741)

* Added StopTrainingOnNoModelImprovement callback and callback_after_eval parameter in EvalCallback

* Correction in EvalCallback and tests for StopTrainingOnNoModelImprovement

* Update the docs related to the new StopTrainingOnNoModelImprovement callback

* Update doc

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
Co-authored-by: Antonin Raffin <antonin.raffin@dlr.de>
3 people authored Feb 25, 2022
1 parent db5366f commit cdaa9ab
Showing 4 changed files with 105 additions and 7 deletions.
32 changes: 31 additions & 1 deletion docs/guide/callbacks.rst
@@ -189,7 +189,7 @@ It will save the best model if ``best_model_save_path`` folder is specified and

.. note::

You can pass a child callback via the ``callback_on_new_best`` argument. It will be triggered each time there is a new best model.
You can pass child callbacks via the ``callback_after_eval`` and ``callback_on_new_best`` arguments. ``callback_after_eval`` will be triggered after every evaluation, and ``callback_on_new_best`` will be triggered each time there is a new best model.
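As a quick illustration of the note above, here is a minimal sketch (not part of this commit; the environment name and thresholds are only illustrative) that wires both child callbacks into ``EvalCallback``:

import gym

from stable_baselines3 import SAC
from stable_baselines3.common.callbacks import (
    EvalCallback,
    StopTrainingOnNoModelImprovement,
    StopTrainingOnRewardThreshold,
)

eval_env = gym.make("Pendulum-v1")
# Triggered only when there is a new best mean reward
callback_on_best = StopTrainingOnRewardThreshold(reward_threshold=-200, verbose=1)
# Triggered after every evaluation
callback_after_eval = StopTrainingOnNoModelImprovement(max_no_improvement_evals=3, min_evals=5, verbose=1)

eval_callback = EvalCallback(
    eval_env,
    callback_on_new_best=callback_on_best,
    callback_after_eval=callback_after_eval,
    eval_freq=1000,
    verbose=1,
)

model = SAC("MlpPolicy", "Pendulum-v1", verbose=1)
model.learn(100_000, callback=eval_callback)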


.. warning::
@@ -333,6 +333,36 @@ and in total for ``max_episodes * n_envs`` episodes.
    # early as soon as the max number of episodes is reached
    model.learn(int(1e10), callback=callback_max_episodes)

.. _StopTrainingOnNoModelImprovement:

StopTrainingOnNoModelImprovement
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Stop the training if there is no new best model (no new best mean reward) after more than a specific number of consecutive evaluations.
The idea is to save time in experiments when you know that the learning curves are reasonably well behaved and, therefore,
learning has probably stabilized after many evaluations without improvement.
It must be used with the :ref:`EvalCallback` and uses the event triggered after every evaluation.


.. code-block:: python

    import gym

    from stable_baselines3 import SAC
    from stable_baselines3.common.callbacks import EvalCallback, StopTrainingOnNoModelImprovement

    # Separate evaluation env
    eval_env = gym.make("Pendulum-v1")
    # Stop training if there is no improvement after more than 3 evaluations
    stop_train_callback = StopTrainingOnNoModelImprovement(max_no_improvement_evals=3, min_evals=5, verbose=1)
    eval_callback = EvalCallback(eval_env, eval_freq=1000, callback_after_eval=stop_train_callback, verbose=1)

    model = SAC("MlpPolicy", "Pendulum-v1", learning_rate=1e-3, verbose=1)
    # Almost infinite number of timesteps, but the training will stop early
    # as soon as the number of consecutive evaluations without model
    # improvement is greater than 3
    model.learn(int(1e10), callback=eval_callback)

.. automodule:: stable_baselines3.common.callbacks
  :members:
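To make the timing concrete, this standalone sketch (not part of the diff) mirrors the counter logic of ``StopTrainingOnNoModelImprovement`` and shows the earliest point at which the settings above (``max_no_improvement_evals=3``, ``min_evals=5``, ``eval_freq=1000``) can stop training, assuming a single environment and no improvement at all:

# Counting starts only once the number of evaluations exceeds min_evals,
# and training stops once the counter exceeds max_no_improvement_evals.
max_no_improvement_evals, min_evals, eval_freq = 3, 5, 1000

no_improvement_evals = 0
for n_calls in range(1, 100):  # n_calls = number of completed evaluations
    if n_calls > min_evals:
        no_improvement_evals += 1  # worst case: the mean reward never improves
        if no_improvement_evals > max_no_improvement_evals:
            print(f"Training would stop after evaluation {n_calls}, "
                  f"i.e. around {n_calls * eval_freq} steps with a single env")
            break

With these values the loop reports evaluation 9, i.e. roughly 9000 steps.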
3 changes: 2 additions & 1 deletion docs/misc/changelog.rst
@@ -14,6 +14,7 @@ Breaking Changes:

New Features:
^^^^^^^^^^^^^
- Added ``StopTrainingOnNoModelImprovement`` to callback collection (@caburu)
- Makes the length of keys and values in ``HumanOutputFormat`` configurable,
depending on desired maximum width of output.
- Allow PPO to turn off advantage normalization (see `PR #763 <https://github.com/DLR-RM/stable-baselines3/pull/763>`_) @vwxyzjn
@@ -925,4 +926,4 @@ And all the contributors:
@benblack769 @bstee615 @c-rizz @skandermoalla @MihaiAnca13 @davidblom603 @ayeright @cyprienc
@wkirgsn @AechPro @CUN-bjy @batu @IljaAvadiev @timokau @kachayev @cleversonahum
@eleurent @ac-93 @cove9988 @theDebugger811 @hsuehch @Demetrio92 @thomasgubler @IperGiove @ScheiklP
@simoninithomas @armandpl @manuel-delverme @Gautam-J @gianlucadecola @buoyancy99
@simoninithomas @armandpl @manuel-delverme @Gautam-J @gianlucadecola @buoyancy99 @caburu
72 changes: 67 additions & 5 deletions stable_baselines3/common/callbacks.py
@@ -278,6 +278,7 @@ class EvalCallback(EventCallback):
    :param eval_env: The environment used for initialization
    :param callback_on_new_best: Callback to trigger
        when there is a new best model according to the ``mean_reward``
    :param callback_after_eval: Callback to trigger after every evaluation
    :param n_eval_episodes: The number of episodes to test the agent
    :param eval_freq: Evaluate the agent every ``eval_freq`` call of the callback.
    :param log_path: Path to a folder where the evaluations (``evaluations.npz``)
@@ -296,6 +297,7 @@ def __init__(
        self,
        eval_env: Union[gym.Env, VecEnv],
        callback_on_new_best: Optional[BaseCallback] = None,
        callback_after_eval: Optional[BaseCallback] = None,
        n_eval_episodes: int = 5,
        eval_freq: int = 10000,
        log_path: Optional[str] = None,
@@ -305,7 +307,13 @@ def __init__(
        verbose: int = 1,
        warn: bool = True,
    ):
        super(EvalCallback, self).__init__(callback_on_new_best, verbose=verbose)
        super(EvalCallback, self).__init__(callback_after_eval, verbose=verbose)

        self.callback_on_new_best = callback_on_new_best
        if self.callback_on_new_best is not None:
            # Give access to the parent
            self.callback_on_new_best.parent = self

        self.n_eval_episodes = n_eval_episodes
        self.eval_freq = eval_freq
        self.best_mean_reward = -np.inf
@@ -342,6 +350,10 @@ def _init_callback(self) -> None:
        if self.log_path is not None:
            os.makedirs(os.path.dirname(self.log_path), exist_ok=True)

        # Init callback called on new best model
        if self.callback_on_new_best is not None:
            self.callback_on_new_best.init_callback(self.model)

    def _log_success_callback(self, locals_: Dict[str, Any], globals_: Dict[str, Any]) -> None:
        """
        Callback passed to the ``evaluate_policy`` function
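The two hunks above give a child callback access to its parent (``callback_on_new_best.parent = self``) and to the model (``init_callback``). A hypothetical child callback, sketched here for illustration only and not part of this PR, can therefore rely on both attributes:

import os

from stable_baselines3.common.callbacks import BaseCallback


class SaveAndReportOnNewBest(BaseCallback):
    """Hypothetical ``callback_on_new_best`` child: saves a checkpoint and reports the parent's best reward."""

    def __init__(self, save_path: str, verbose: int = 0):
        super().__init__(verbose=verbose)
        self.save_path = save_path
        os.makedirs(self.save_path, exist_ok=True)

    def _on_step(self) -> bool:
        assert self.parent is not None, "Must be used as a child of EvalCallback"
        # ``self.model`` is available because the parent calls init_callback() on this child
        self.model.save(os.path.join(self.save_path, f"best_{self.num_timesteps}"))
        if self.verbose > 0:
            print(f"New best mean reward: {self.parent.best_mean_reward:.2f}")
        return True  # returning False would stop training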
@@ -360,7 +372,10 @@ def _log_success_callback(self, locals_: Dict[str, Any], globals_: Dict[str, Any]) -> None:

    def _on_step(self) -> bool:

        continue_training = True

        if self.eval_freq > 0 and self.n_calls % self.eval_freq == 0:

            # Sync training and eval env if there is VecNormalize
            if self.model.get_vec_normalize_env() is not None:
                try:
@@ -432,11 +447,15 @@ def _on_step(self) -> bool:
                if self.best_model_save_path is not None:
                    self.model.save(os.path.join(self.best_model_save_path, "best_model"))
                self.best_mean_reward = mean_reward
                # Trigger callback if needed
                if self.callback is not None:
                    return self._on_event()
                # Trigger callback on new best model, if needed
                if self.callback_on_new_best is not None:
                    continue_training = self.callback_on_new_best.on_step()

        return True
            # Trigger callback after every evaluation, if needed
            if self.callback is not None:
                continue_training = continue_training and self._on_event()

        return continue_training

    def update_child_locals(self, locals_: Dict[str, Any]) -> None:
        """
@@ -538,3 +557,46 @@ def _on_step(self) -> bool:
f"{mean_ep_str}"
)
return continue_training


class StopTrainingOnNoModelImprovement(BaseCallback):
    """
    Stop the training early if there is no new best model (new best mean reward) after more than N consecutive evaluations.

    It is possible to define a minimum number of evaluations before starting to count evaluations without improvement.

    It must be used with the ``EvalCallback``.

    :param max_no_improvement_evals: Maximum number of consecutive evaluations without a new best model.
    :param min_evals: Number of evaluations before starting to count evaluations without improvement.
    :param verbose: Verbosity of the output (set to 1 for info messages)
    """

    def __init__(self, max_no_improvement_evals: int, min_evals: int = 0, verbose: int = 0):
        super(StopTrainingOnNoModelImprovement, self).__init__(verbose=verbose)
        self.max_no_improvement_evals = max_no_improvement_evals
        self.min_evals = min_evals
        self.last_best_mean_reward = -np.inf
        self.no_improvement_evals = 0

    def _on_step(self) -> bool:
        assert self.parent is not None, "``StopTrainingOnNoModelImprovement`` callback must be used with an ``EvalCallback``"

        continue_training = True

        if self.n_calls > self.min_evals:
            if self.parent.best_mean_reward > self.last_best_mean_reward:
                self.no_improvement_evals = 0
            else:
                self.no_improvement_evals += 1
                if self.no_improvement_evals > self.max_no_improvement_evals:
                    continue_training = False

        self.last_best_mean_reward = self.parent.best_mean_reward

        if self.verbose > 0 and not continue_training:
            print(
                f"Stopping training because there was no new best model in the last {self.no_improvement_evals:d} evaluations"
            )

        return continue_training
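The ``assert`` at the top of ``_on_step`` makes misuse fail fast. A small hypothetical check, not part of this PR's test suite, illustrating that behaviour:

import pytest

from stable_baselines3.common.callbacks import StopTrainingOnNoModelImprovement


def test_requires_parent_eval_callback():
    callback = StopTrainingOnNoModelImprovement(max_no_improvement_evals=1)
    # Without an EvalCallback parent, stepping the callback should raise
    with pytest.raises(AssertionError):
        callback._on_step()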
5 changes: 5 additions & 0 deletions tests/test_callbacks.py
@@ -12,6 +12,7 @@
    EvalCallback,
    EveryNTimesteps,
    StopTrainingOnMaxEpisodes,
    StopTrainingOnNoModelImprovement,
    StopTrainingOnRewardThreshold,
)
from stable_baselines3.common.env_util import make_vec_env
@@ -35,9 +36,13 @@ def test_callbacks(tmp_path, model_class):
    # Stop training if the performance is good enough
    callback_on_best = StopTrainingOnRewardThreshold(reward_threshold=-1200, verbose=1)

    # Stop training if there is no model improvement after 2 evaluations
    callback_no_model_improvement = StopTrainingOnNoModelImprovement(max_no_improvement_evals=2, min_evals=1, verbose=1)

    eval_callback = EvalCallback(
        eval_env,
        callback_on_new_best=callback_on_best,
        callback_after_eval=callback_no_model_improvement,
        best_model_save_path=log_folder,
        log_path=log_folder,
        eval_freq=100,