Skip to content

Commit

Permalink
Add a callback to log raw stats (#216)
Browse files Browse the repository at this point in the history
* Add a callback to log raw stats

* Fixes and use tensorboard output directly

* Add test case and changelog

* fix CI

* Update test

Co-authored-by: Antonin Raffin <antonin.raffin@ensta.org>
  • Loading branch information
vwxyzjn and araffin committed Mar 22, 2022
1 parent b7c948f commit f421dad
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 1 deletion.
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
- Upgrade to Stable-Baselines3 (SB3) >= 1.4.1a1
- Upgrade to sb3-contrib >= 1.4.1a1
- Upgraded to gym 0.21
- Support experiment tracking via Weights and Biases (@vwxyzjn)
- Support experiment tracking via Weights and Biases via the `--track` flag (@vwxyzjn)
- Support tracking raw episodic stats via `RawStatisticsCallback` (@vwxyzjn, see https://github.com/DLR-RM/rl-baselines3-zoo/pull/216)

### New Features
- Verbose mode for each trial (when doing hyperparam optimization) can now be activated using the debug mode (verbose == 2)
Expand Down
23 changes: 23 additions & 0 deletions tests/test_callbacks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import subprocess
import sys


def _assert_eq(left, right):
assert left == right, f"{left} != {right}"


def test_raw_stat_callback(tmp_path):
    """
    Smoke-test ``RawStatisticsCallback`` by running the training script end to end.

    ``train.py`` is launched in a subprocess with PPO on ``CartPole-v1``, the callback
    passed through ``-params`` and tensorboard logging pointed at ``tmp_path``.
    The test passes when the process exits with return code 0.

    :param tmp_path: directory that receives the tensorboard logs
        (presumably pytest's built-in ``tmp_path`` fixture)
    """
    args = [
        "-n",
        "200",
        "--algo",
        "ppo",
        "--env",
        "CartPole-v1",
        "-params",
        "callback:'utils.callbacks.RawStatisticsCallback'",
        "--tensorboard-log",
        str(tmp_path),
    ]

    # Launch with the interpreter that is running the tests: a bare "python" may resolve
    # to another installation (or to nothing at all) depending on PATH.
    # ``sys.executable`` can be empty when it cannot be determined, hence the fallback.
    interpreter = sys.executable or "python"
    return_code = subprocess.call([interpreter, "train.py"] + args)
    _assert_eq(return_code, 0)
34 changes: 34 additions & 0 deletions utils/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from sb3_contrib import TQC
from stable_baselines3 import SAC
from stable_baselines3.common.callbacks import BaseCallback, EvalCallback
from stable_baselines3.common.logger import TensorBoardOutputFormat
from stable_baselines3.common.vec_env import VecEnv


Expand Down Expand Up @@ -193,3 +194,36 @@ def _on_training_end(self) -> None:
if self.verbose > 0:
print("Waiting for training thread to terminate")
self.process.join()


class RawStatisticsCallback(BaseCallback):
    """
    Callback that logs raw per-episode data (return and episode length) to tensorboard.

    Whenever one of the ``infos`` dicts of the current step has an ``"episode"`` entry,
    its ``"r"`` and ``"l"`` values are written directly through the logger's
    ``TensorBoardOutputFormat`` instead of the regular logger, so the other outputs are not flooded.

    :param verbose: verbosity level, forwarded unchanged to ``BaseCallback``
    """

    def __init__(self, verbose=0):
        super().__init__(verbose)
        # Filled in by ``_init_callback`` once the logger is available
        self._tensorboard_writer = None
        # Custom counter used as the step when reporting stats
        # (avoids reporting multiple values for the same step)
        self._timesteps_counter = 0

    def _init_callback(self) -> None:
        """
        Look up the tensorboard output format among the logger outputs.

        If several are present, the last one is kept.
        Fails with an ``AssertionError`` when no writer is set afterwards.
        """
        # Write to tensorboard only, to not flood the other logger outputs
        tensorboard_formats = [
            output for output in self.logger.output_formats if isinstance(output, TensorBoardOutputFormat)
        ]
        if tensorboard_formats:
            self._tensorboard_writer = tensorboard_formats[-1]
        assert self._tensorboard_writer is not None, "You must activate tensorboard logging when using RawStatisticsCallback"

    def _on_step(self) -> bool:
        """
        Write the return and length of every episode reported in the current ``infos``.

        :return: always ``True``
        """
        for info in self.locals["infos"]:
            if "episode" not in info:
                continue

            episode_stats = info["episode"]
            key_values = {
                "raw/rollouts/episodic_return": episode_stats["r"],
                "raw/rollouts/episodic_length": episode_stats["l"],
            }
            # Same keys, all mapped to None (presumably: excluded from no output format)
            key_excluded = dict.fromkeys(key_values)
            # Advance the custom step by the length of the episode being reported
            self._timesteps_counter += episode_stats["l"]
            self._tensorboard_writer.write(key_values, key_excluded, self._timesteps_counter)

        return True

0 comments on commit f421dad

Please sign in to comment.