From 05adf12690981fdc9286d02adbfc01a9fb853a61 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adri=C3=A0=20Garriga-Alonso?=
Date: Sat, 9 Mar 2024 14:26:29 -0800
Subject: [PATCH] Record differences in value from old to new

---
 stable_baselines3/ppo_recurrent/ppo_recurrent.py | 16 +++++++++++++---
 1 file changed, 13 insertions(+), 3 deletions(-)

diff --git a/stable_baselines3/ppo_recurrent/ppo_recurrent.py b/stable_baselines3/ppo_recurrent/ppo_recurrent.py
index 04a4d4da5..b047f9558 100644
--- a/stable_baselines3/ppo_recurrent/ppo_recurrent.py
+++ b/stable_baselines3/ppo_recurrent/ppo_recurrent.py
@@ -415,6 +415,9 @@ def train(self) -> None:
         entropy_losses = []
         pg_losses, value_losses = [], []
+        value_diffs_mean = []
+        value_diffs_min = []
+        value_diffs_max = []
         clip_fractions = []
         approx_kl_divs = []
@@ -459,18 +462,22 @@
                 clip_fraction = th.mean((th.abs(ratio - 1) > clip_range).float()).item()
                 clip_fractions.append(clip_fraction)
 
+                value_diff = values - rollout_data.old_values
                 if self.clip_range_vf is None:
                     # No clipping
                     values_pred = values
                 else:
                     # Clip the difference between old and new value
                     # NOTE: this depends on the reward scaling
-                    values_pred = rollout_data.old_values + th.clamp(
-                        values - rollout_data.old_values, -clip_range_vf, clip_range_vf
-                    )
+                    values_pred = rollout_data.old_values + th.clamp(value_diff, -clip_range_vf, clip_range_vf)
                 # Value loss using the TD(gae_lambda) target
                 value_loss = F.mse_loss(rollout_data.returns, values_pred)
                 value_losses.append(value_loss.item())
+                with th.no_grad():
+                    value_diff_abs = value_diff.abs()
+                    value_diffs_mean.append(value_diff_abs.mean().item())
+                    value_diffs_min.append(value_diff_abs.min().item())
+                    value_diffs_max.append(value_diff_abs.max().item())
 
                 # Entropy loss favor exploration
                 if entropy is None:
@@ -516,6 +523,9 @@
         self.logger.record("train/entropy_loss", np.mean(entropy_losses))
         self.logger.record("train/policy_gradient_loss", np.mean(pg_losses))
         self.logger.record("train/value_loss", np.mean(value_losses))
+        self.logger.record("train/value_diff_mean", np.mean(value_diffs_mean))
+        self.logger.record("train/value_diff_min", np.min(value_diffs_min))
+        self.logger.record("train/value_diff_max", np.max(value_diffs_max))
         self.logger.record("train/approx_kl", np.mean(approx_kl_divs))
         self.logger.record("train/clip_fraction", np.mean(clip_fractions))
         self.logger.record("train/loss", loss.item())
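
Note (illustration only, not part of the change): the sketch below shows how the new
value-difference statistics are computed and then aggregated across minibatches. The
tensors here are synthetic stand-ins for the rollout data; the variable names simply
mirror the patch.

    import numpy as np
    import torch as th

    value_diffs_mean, value_diffs_min, value_diffs_max = [], [], []

    for _ in range(4):  # stand-in for the minibatch loop
        values = th.randn(64)                      # new value predictions
        old_values = values + 0.1 * th.randn(64)   # values stored at rollout time
        value_diff = values - old_values
        with th.no_grad():                         # statistics only, no gradient needed
            value_diff_abs = value_diff.abs()
            value_diffs_mean.append(value_diff_abs.mean().item())
            value_diffs_min.append(value_diff_abs.min().item())
            value_diffs_max.append(value_diff_abs.max().item())

    # Aggregation matches the logging calls in the patch: mean of per-batch means,
    # min of per-batch mins, max of per-batch maxes.
    print("train/value_diff_mean", np.mean(value_diffs_mean))
    print("train/value_diff_min", np.min(value_diffs_min))
    print("train/value_diff_max", np.max(value_diffs_max))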