Fix eval/jsrl logging
steventango committed Jul 20, 2023
1 parent f85abf9 commit 4d74537
Showing 1 changed file with 46 additions and 14 deletions.
60 changes: 46 additions & 14 deletions src/jsrl/jsrl.py
@@ -1,6 +1,6 @@
from typing import Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
from stable_baselines3.common.base_class import BaseAlgorithm
from stable_baselines3.common.base_class import BaseAlgorithm, Logger
from stable_baselines3.common.policies import BasePolicy
from stable_baselines3.common.type_aliases import MaybeCallback
from stable_baselines3.common.callbacks import EvalCallback, CallbackList, BaseCallback
@@ -12,24 +12,60 @@ def __init__(self, policy, logger, *args, **kwargs):
self.policy = policy
self.logger = logger
self.best_moving_mean_reward = -np.inf
self.tolerated_moving_mean_reward = -np.inf
self.mean_rewards = np.full(policy.window_size, -np.inf, dtype=np.float32)

def _on_step(self) -> bool:
self.logger.record("jsrl/horizon", self.policy.horizons[self.policy.horizon_step])
self.mean_rewards = np.roll(self.mean_rewards, 1)
self.mean_rewards[0] = self.parent.last_mean_reward
moving_mean_reward = np.mean(self.mean_rewards)

self.logger.record("jsrl/horizon", self.policy.horizon)
self.logger.record("jsrl/moving_mean_reward", moving_mean_reward)
tolerated_moving_mean_reward = self.best_moving_mean_reward - self.policy.tolerance * np.abs(
self.best_moving_mean_reward
)
self.logger.record("jsrl/best_moving_mean_reward", moving_mean_reward)
self.logger.record("jsrl/tolerated_moving_mean_reward", self.tolerated_moving_mean_reward)
self.logger.dump(self.num_timesteps)

if self.mean_rewards[-1] == -np.inf or self.policy.horizon <= 0:
return
return True
elif self.best_moving_mean_reward == -np.inf:
self.best_moving_mean_reward = moving_mean_reward
elif moving_mean_reward > tolerated_moving_mean_reward:
elif moving_mean_reward > self.tolerated_moving_mean_reward:
self.policy.update_horizon()
self.best_moving_mean_reward = max(self.best_moving_mean_reward, moving_mean_reward)
if moving_mean_reward > self.best_moving_mean_reward:
self.best_moving_mean_reward = moving_mean_reward
self.tolerated_moving_mean_reward = moving_mean_reward - self.policy.tolerance * np.abs(moving_mean_reward)
return True

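For orientation, here is a minimal stand-alone sketch of the gate that the rewritten `_on_step` implements: the guide-policy horizon is only updated once the moving mean reward over the last `window_size` evaluations clears a tolerance band around the best moving mean seen so far. It assumes `policy.update_horizon()` simply steps the horizon down, which is not shown in this diff.

```python
import numpy as np

# Sketch only: mirrors the branch structure of _on_step above, with a plain
# decrement standing in for policy.update_horizon().
def horizon_gate(mean_rewards, best, tolerated, tolerance, horizon):
    moving_mean = float(np.mean(mean_rewards))
    if mean_rewards[-1] == -np.inf or horizon <= 0:
        return best, tolerated, horizon      # window not full yet, or nothing left to shrink
    if best == -np.inf:
        best = moving_mean                   # first full window sets the baseline
    elif moving_mean > tolerated:
        horizon -= 1                         # reward recovered within tolerance: shrink the guide horizon
    if moving_mean > best:
        best = moving_mean
        tolerated = moving_mean - tolerance * np.abs(moving_mean)
    return best, tolerated, horizon
```

The window itself is maintained with `np.roll`, so until `window_size` evaluations have completed the last slot is still `-np.inf` and the gate stays closed.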

class JSRLEvalCallback(EvalCallback):
def init_callback(self, model: BaseAlgorithm) -> None:
super().init_callback(model)
self.logger = JSRLLogger(self.logger)


class JSRLLogger():
def __init__(self, logger: Logger):
self._logger = logger

def record(self, key: str, value: Any, exclude: Optional[Union[str, Tuple[str, ...]]] = None) -> None:
"""
Log a value of some diagnostic
Call this once for each diagnostic quantity, each iteration
If called many times, last value will be used.
:param key: save to log this key
:param value: save to log this value
:param exclude: outputs to be excluded
"""
key = key.replace("eval/", "jsrl/")
self._logger.record(key, value, exclude)

def dump(self, step: int = 0) -> None:
"""
Write all of the diagnostics from the current iteration
"""
self._logger.dump(step)

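A short usage illustration of the wrapper (the `configure` call and the keys are placeholders, not from this repository): every `eval/` key is re-namespaced to `jsrl/` before being delegated to the wrapped SB3 `Logger`.

```python
from stable_baselines3.common.logger import configure

base_logger = configure(None, ["stdout"])      # a plain SB3 Logger writing to stdout
jsrl_logger = JSRLLogger(base_logger)
jsrl_logger.record("eval/mean_reward", 123.4)  # delegated as "jsrl/mean_reward"
jsrl_logger.record("time/fps", 60)             # non-eval keys pass through unchanged
jsrl_logger.dump(step=10_000)
```

Delegating to a held `Logger` rather than subclassing keeps the wrapper compatible with whatever logger the model was configured with.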

def get_jsrl_policy(ExplorationPolicy: BasePolicy):
@@ -164,7 +200,7 @@ def _init_callback(
:return: A hybrid callback calling `callback` and performing evaluation.
"""
callback = super()._init_callback(callback, progress_bar)
eval_callback = EvalCallback(
eval_callback = JSRLEvalCallback(
self.env,
callback_after_eval=JSRLAfterEvalCallback(
self.policy,
@@ -181,10 +217,6 @@ def _init_callback(
]
)
callback.init_callback(self)
default_record = eval_callback.logger.record
eval_callback.logger.record = lambda key, *args, **kwargs: default_record(
key.replace("eval/", "jsrl/"), *args, **kwargs
)
return callback
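For contrast, a stand-alone sketch (throwaway loggers from `configure`; not repository code) of the lambda patch removed above next to the wrapper now installed by `JSRLEvalCallback`. Rebinding `record` mutates the logger object itself, and since SB3 callbacks are typically handed the model's own `Logger` instance, that rewrite reached every consumer of it; the wrapper keeps the `eval/` to `jsrl/` remap local to the eval callback's `logger` attribute.

```python
from stable_baselines3.common.logger import configure

# Removed pattern: rebind record on the shared logger object.
patched = configure(None, ["stdout"])
default_record = patched.record
patched.record = lambda key, *args, **kwargs: default_record(
    key.replace("eval/", "jsrl/"), *args, **kwargs
)

# Current pattern: wrap a logger locally; the underlying Logger is untouched.
wrapped = JSRLLogger(configure(None, ["stdout"]))
```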

def predict(
