From df269be79958493121e9be0071d0ac8a847cde57 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adri=C3=A0=20Garriga-Alonso?= Date: Mon, 5 Feb 2024 13:50:34 -0800 Subject: [PATCH] Make n_updates track the number of partial epoch updates --- stable_baselines3/ppo_recurrent/ppo_recurrent.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/stable_baselines3/ppo_recurrent/ppo_recurrent.py b/stable_baselines3/ppo_recurrent/ppo_recurrent.py index dd42040a5..937f59ac0 100644 --- a/stable_baselines3/ppo_recurrent/ppo_recurrent.py +++ b/stable_baselines3/ppo_recurrent/ppo_recurrent.py @@ -472,8 +472,7 @@ def train(self) -> None: # Clip grad norm th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm) self.policy.optimizer.step() - self._n_updates += 1 - + self._n_updates += 1 if not continue_training: break self.policy.optimizer.zero_grad(set_to_none=True) # Free gradients until the next call to .train()