Skip to content

Commit

Permalink
Make n_updates track the number of partial epoch updates
Browse files Browse the repository at this point in the history
  • Loading branch information
rhaps0dy committed Feb 5, 2024
1 parent 4da11bc commit df269be
Showing 1 changed file with 1 addition and 2 deletions.
3 changes: 1 addition & 2 deletions stable_baselines3/ppo_recurrent/ppo_recurrent.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,8 +472,7 @@ def train(self) -> None:
# Clip grad norm
th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
self.policy.optimizer.step()
self._n_updates += 1

self._n_updates += 1
if not continue_training:
break
self.policy.optimizer.zero_grad(set_to_none=True) # Free gradients until the next call to .train()
Expand Down

0 comments on commit df269be

Please sign in to comment.