diff --git a/stable_baselines3/a2c/a2c.py b/stable_baselines3/a2c/a2c.py
index fda20c9c0..68247bff1 100644
--- a/stable_baselines3/a2c/a2c.py
+++ b/stable_baselines3/a2c/a2c.py
@@ -5,7 +5,12 @@
 from torch.nn import functional as F
 
 from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
-from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, BasePolicy, MultiInputActorCriticPolicy
+from stable_baselines3.common.policies import (
+    ActorCriticCnnPolicy,
+    ActorCriticPolicy,
+    BasePolicy,
+    MultiInputActorCriticPolicy,
+)
 from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
 from stable_baselines3.common.utils import explained_variance
 
@@ -70,7 +75,7 @@ def __init__(
         gae_lambda: float = 1.0,
         ent_coef: float = 0.0,
         vf_coef: float = 0.5,
-        max_grad_norm: float = 0.5,
+        max_grad_norm: Optional[float] = 0.5,
         rms_prop_eps: float = 1e-5,
         use_rms_prop: bool = True,
         use_sde: bool = False,
@@ -168,7 +173,8 @@ def train(self) -> None:
         loss.backward()
 
         # Clip grad norm
-        th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
+        if self.max_grad_norm is not None:
+            th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
         self.policy.optimizer.step()
 
         explained_var = explained_variance(self.rollout_buffer.values.flatten(), self.rollout_buffer.returns.flatten())
diff --git a/stable_baselines3/common/on_policy_algorithm.py b/stable_baselines3/common/on_policy_algorithm.py
index 289a646af..9d1cc4b40 100644
--- a/stable_baselines3/common/on_policy_algorithm.py
+++ b/stable_baselines3/common/on_policy_algorithm.py
@@ -66,7 +66,7 @@ def __init__(
         gae_lambda: float,
         ent_coef: float,
         vf_coef: float,
-        max_grad_norm: float,
+        max_grad_norm: Optional[float],
         use_sde: bool,
         sde_sample_freq: int,
         stats_window_size: int = 100,
diff --git a/stable_baselines3/ppo/ppo.py b/stable_baselines3/ppo/ppo.py
index 965b76cf9..49501431d 100644
--- a/stable_baselines3/ppo/ppo.py
+++ b/stable_baselines3/ppo/ppo.py
@@ -7,7 +7,12 @@
 from torch.nn import functional as F
 
 from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
-from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, BasePolicy, MultiInputActorCriticPolicy
+from stable_baselines3.common.policies import (
+    ActorCriticCnnPolicy,
+    ActorCriticPolicy,
+    BasePolicy,
+    MultiInputActorCriticPolicy,
+)
 from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
 from stable_baselines3.common.utils import explained_variance, get_schedule_fn
 
@@ -89,7 +94,7 @@ def __init__(
         normalize_advantage: bool = True,
         ent_coef: float = 0.0,
         vf_coef: float = 0.5,
-        max_grad_norm: float = 0.5,
+        max_grad_norm: Optional[float] = 0.5,
         use_sde: bool = False,
         sde_sample_freq: int = -1,
         target_kl: Optional[float] = None,
@@ -271,7 +276,8 @@ def train(self) -> None:
                 self.policy.optimizer.zero_grad()
                 loss.backward()
                 # Clip grad norm
-                th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
+                if self.max_grad_norm is not None:
+                    th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
                 self.policy.optimizer.step()
 
             self._n_updates += 1
diff --git a/stable_baselines3/ppo_recurrent/ppo_recurrent.py b/stable_baselines3/ppo_recurrent/ppo_recurrent.py
index c8a4fecbc..bee7c563d 100644
--- a/stable_baselines3/ppo_recurrent/ppo_recurrent.py
+++ b/stable_baselines3/ppo_recurrent/ppo_recurrent.py
@@ -117,7 +117,7 @@ def __init__(
         normalize_advantage: bool = True,
         ent_coef: float = 0.0,
         vf_coef: float = 0.5,
-        max_grad_norm: float = 0.5,
+        max_grad_norm: Optional[float] = 0.5,
         use_sde: bool = False,
         sde_sample_freq: int = -1,
         target_kl: Optional[float] = None,
@@ -470,7 +470,8 @@ def train(self) -> None:
                 self.policy.optimizer.zero_grad()
                 loss.backward()
                 # Clip grad norm
-                th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
+                if self.max_grad_norm is not None:
+                    th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
                 self.policy.optimizer.step()
             self._n_updates += 1
             if not continue_training:
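
Usage sketch (not part of the diff): with this patch applied, passing max_grad_norm=None skips the clip_grad_norm_ call entirely, while any float keeps the previous clipping behaviour. A minimal example, assuming a standard Gymnasium environment such as CartPole-v1:

from stable_baselines3 import A2C, PPO

# Default behaviour: gradients are clipped to a max norm of 0.5, as before.
model_clipped = A2C("MlpPolicy", "CartPole-v1", max_grad_norm=0.5)

# With this patch, None disables gradient clipping (no clip_grad_norm_ call).
model_unclipped = PPO("MlpPolicy", "CartPole-v1", max_grad_norm=None)
model_unclipped.learn(total_timesteps=1_000)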