Skip to content

Commit

Permalink
Record differences in value from old to new
Browse files Browse the repository at this point in the history
  • Loading branch information
rhaps0dy committed Mar 9, 2024
1 parent c33eeee commit 05adf12
Showing 1 changed file with 13 additions and 3 deletions.
16 changes: 13 additions & 3 deletions stable_baselines3/ppo_recurrent/ppo_recurrent.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,6 +415,9 @@ def train(self) -> None:

entropy_losses = []
pg_losses, value_losses = [], []
value_diffs_mean = []
value_diffs_min = []
value_diffs_max = []
clip_fractions = []
approx_kl_divs = []

Expand Down Expand Up @@ -459,18 +462,22 @@ def train(self) -> None:
clip_fraction = th.mean((th.abs(ratio - 1) > clip_range).float()).item()
clip_fractions.append(clip_fraction)

value_diff = values - rollout_data.old_values
if self.clip_range_vf is None:
# No clipping
values_pred = values
else:
# Clip the difference between old and new value
# NOTE: this depends on the reward scaling
values_pred = rollout_data.old_values + th.clamp(
values - rollout_data.old_values, -clip_range_vf, clip_range_vf
)
values_pred = rollout_data.old_values + th.clamp(value_diff, -clip_range_vf, clip_range_vf)
# Value loss using the TD(gae_lambda) target
value_loss = F.mse_loss(rollout_data.returns, values_pred)
value_losses.append(value_loss.item())
with th.no_grad():
value_diff_abs = value_diff.abs()
value_diffs_mean.append(value_diff_abs.mean().item())
value_diffs_min.append(value_diff_abs.min().item())
value_diffs_max.append(value_diff_abs.max().item())

# Entropy loss favors exploration
if entropy is None:
Expand Down Expand Up @@ -516,6 +523,9 @@ def train(self) -> None:
self.logger.record("train/entropy_loss", np.mean(entropy_losses))
self.logger.record("train/policy_gradient_loss", np.mean(pg_losses))
self.logger.record("train/value_loss", np.mean(value_losses))
self.logger.record("train/value_diff_mean", np.mean(value_diffs_mean))
self.logger.record("train/value_diff_min", np.min(value_diffs_min))
self.logger.record("train/value_diff_max", np.max(value_diffs_max))
self.logger.record("train/approx_kl", np.mean(approx_kl_divs))
self.logger.record("train/clip_fraction", np.mean(clip_fractions))
self.logger.record("train/loss", loss.item())
Expand Down

0 comments on commit 05adf12

Please sign in to comment.