Record the proportion of values that get clipped
rhaps0dy committed Mar 14, 2024
1 parent 05adf12 commit 6e9fc9f
Showing 1 changed file with 7 additions and 0 deletions.
7 changes: 7 additions & 0 deletions stable_baselines3/ppo_recurrent/ppo_recurrent.py
@@ -1,3 +1,4 @@
+import math
 import sys
 import time
 import warnings
@@ -412,13 +413,16 @@ def train(self) -> None:
         # Optional: clip range for the value function
         if self.clip_range_vf is not None:
             clip_range_vf = self.clip_range_vf(self._current_progress_remaining)
+        else:
+            clip_range_vf = math.inf

         entropy_losses = []
         pg_losses, value_losses = [], []
         value_diffs_mean = []
         value_diffs_min = []
         value_diffs_max = []
         clip_fractions = []
+        clip_fractions_vf = []
         approx_kl_divs = []

         continue_training = True
@@ -478,6 +482,8 @@ def train(self) -> None:
                 value_diffs_mean.append(value_diff_abs.mean().item())
                 value_diffs_min.append(value_diff_abs.min().item())
                 value_diffs_max.append(value_diff_abs.max().item())
+                clip_fraction_vf = th.mean((value_diff_abs > clip_range_vf).float()).item()
+                clip_fractions_vf.append(clip_fraction_vf)

                 # Entropy loss favor exploration
                 if entropy is None:
@@ -528,6 +534,7 @@ def train(self) -> None:
         self.logger.record("train/value_diff_max", np.max(value_diffs_max))
         self.logger.record("train/approx_kl", np.mean(approx_kl_divs))
         self.logger.record("train/clip_fraction", np.mean(clip_fractions))
+        self.logger.record("train/clip_fraction_vf", np.mean(clip_fractions_vf))
         self.logger.record("train/loss", loss.item())
         self.logger.record("train/explained_variance", explained_var.item())
         if hasattr(self.policy, "log_std"):
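For context, the new metric mirrors the policy-side clip_fraction: it counts how often the updated value prediction moves further than clip_range_vf away from the value estimate stored in the rollout buffer. Below is a minimal standalone sketch of that computation; it assumes value_diff_abs is the absolute difference between the new value predictions and the old rollout-buffer values (the hunk above uses value_diff_abs but does not show where it is defined), and the helper name value_clip_fraction is hypothetical.

import math
from typing import Optional

import torch as th


def value_clip_fraction(values: th.Tensor, old_values: th.Tensor, clip_range_vf: Optional[float]) -> float:
    """Fraction of value predictions that value-function clipping would actually clip."""
    # When value clipping is disabled, treat the range as infinite (as in the diff),
    # so nothing counts as clipped and the logged fraction is 0.0.
    if clip_range_vf is None:
        clip_range_vf = math.inf
    value_diff_abs = th.abs(values - old_values)
    return th.mean((value_diff_abs > clip_range_vf).float()).item()


# Example: with a clip range of 0.2, two of the four predictions move too far.
values = th.tensor([1.0, 2.5, 0.3, -1.0])
old_values = th.tensor([1.1, 2.0, 0.2, -2.0])
print(value_clip_fraction(values, old_values, 0.2))  # 0.5

Because clip_range_vf falls back to math.inf when value clipping is disabled, the logged fraction is simply 0.0 in that case rather than an error.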
