From 4d74537a50243e81cf942affefba2b412f607098 Mon Sep 17 00:00:00 2001
From: Steven Tang <18170455+steventango@users.noreply.github.com>
Date: Thu, 20 Jul 2023 16:42:30 -0600
Subject: [PATCH] Fix eval/jsrl logging

---
 src/jsrl/jsrl.py | 70 ++++++++++++++++++++++++++++++++++++++----------
 1 file changed, 56 insertions(+), 14 deletions(-)

diff --git a/src/jsrl/jsrl.py b/src/jsrl/jsrl.py
index 7877639..f79bbe3 100644
--- a/src/jsrl/jsrl.py
+++ b/src/jsrl/jsrl.py
@@ -1,6 +1,6 @@
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union
 import numpy as np
-from stable_baselines3.common.base_class import BaseAlgorithm
+from stable_baselines3.common.base_class import BaseAlgorithm, Logger
 from stable_baselines3.common.policies import BasePolicy
 from stable_baselines3.common.type_aliases import MaybeCallback
 from stable_baselines3.common.callbacks import EvalCallback, CallbackList, BaseCallback
@@ -12,24 +12,70 @@ def __init__(self, policy, logger, *args, **kwargs):
         self.policy = policy
         self.logger = logger
         self.best_moving_mean_reward = -np.inf
+        self.tolerated_moving_mean_reward = -np.inf
         self.mean_rewards = np.full(policy.window_size, -np.inf, dtype=np.float32)
 
     def _on_step(self) -> bool:
-        self.logger.record("jsrl/horizon", self.policy.horizons[self.policy.horizon_step])
+        """
+        Log JSRL statistics for the latest evaluation and, once the reward window is full and the
+        horizon is still positive, update the horizon when the moving mean beats the tolerated threshold.
+
+        :return: always ``True``
+        """
         self.mean_rewards = np.roll(self.mean_rewards, 1)
         self.mean_rewards[0] = self.parent.last_mean_reward
         moving_mean_reward = np.mean(self.mean_rewards)
+
+        self.logger.record("jsrl/horizon", self.policy.horizon)
         self.logger.record("jsrl/moving_mean_reward", moving_mean_reward)
-        tolerated_moving_mean_reward = self.best_moving_mean_reward - self.policy.tolerance * np.abs(
-            self.best_moving_mean_reward
-        )
+        self.logger.record("jsrl/best_moving_mean_reward", self.best_moving_mean_reward)
+        self.logger.record("jsrl/tolerated_moving_mean_reward", self.tolerated_moving_mean_reward)
+        self.logger.dump(self.num_timesteps)
+
         if self.mean_rewards[-1] == -np.inf or self.policy.horizon <= 0:
-            return
+            return True
         elif self.best_moving_mean_reward == -np.inf:
             self.best_moving_mean_reward = moving_mean_reward
+            # Seed the threshold too; otherwise it stays -inf and the next check always passes.
+            self.tolerated_moving_mean_reward = moving_mean_reward - self.policy.tolerance * np.abs(moving_mean_reward)
-        elif moving_mean_reward > tolerated_moving_mean_reward:
+        elif moving_mean_reward > self.tolerated_moving_mean_reward:
             self.policy.update_horizon()
-            self.best_moving_mean_reward = max(self.best_moving_mean_reward, moving_mean_reward)
+            if moving_mean_reward > self.best_moving_mean_reward:
+                self.best_moving_mean_reward = moving_mean_reward
+                self.tolerated_moving_mean_reward = moving_mean_reward - self.policy.tolerance * np.abs(moving_mean_reward)
+        return True
+
+
+class JSRLEvalCallback(EvalCallback):
+    """EvalCallback whose logger is wrapped in a JSRLLogger once the callback is initialised."""
+    def init_callback(self, model: BaseAlgorithm) -> None:
+        super().init_callback(model)
+        self.logger = JSRLLogger(self.logger)
+
+
+class JSRLLogger():
+    """Proxy around a Logger that rewrites "eval/" to "jsrl/" in record keys."""
+    def __init__(self, logger: Logger):
+        self._logger = logger
+
+    def record(self, key: str, value: Any, exclude: Optional[Union[str, Tuple[str, ...]]] = None) -> None:
+        """
+        Log a value of some diagnostic
+        Call this once for each diagnostic quantity, each iteration
+        If called many times, last value will be used.
+
+        :param key: save to log this key
+        :param value: save to log this value
+        :param exclude: outputs to be excluded
+        """
+        key = key.replace("eval/", "jsrl/")
+        self._logger.record(key, value, exclude)
+
+    def dump(self, step: int = 0) -> None:
+        """
+        Write all of the diagnostics from the current iteration
+        """
+        self._logger.dump(step)
 
 
 def get_jsrl_policy(ExplorationPolicy: BasePolicy):
@@ -164,7 +210,7 @@ def _init_callback(
             :return: A hybrid callback calling `callback` and performing evaluation.
             """
             callback = super()._init_callback(callback, progress_bar)
-            eval_callback = EvalCallback(
+            eval_callback = JSRLEvalCallback(
                 self.env,
                 callback_after_eval=JSRLAfterEvalCallback(
                     self.policy,
@@ -181,10 +227,6 @@ def _init_callback(
                 ]
             )
             callback.init_callback(self)
-            default_record = eval_callback.logger.record
-            eval_callback.logger.record = lambda key, *args, **kwargs: default_record(
-                key.replace("eval/", "jsrl/"), *args, **kwargs
-            )
             return callback
 
         def predict(