Allow optional gradient clipping in OnPolicy algorithms (#22)
Grad clipping takes a surprising amount of time according to profiles.
Maybe we can skip it for some runs? (probably not)
rhaps0dy authored Feb 28, 2024
2 parents 498e1bf + 804ce10 commit 4554015
Showing 4 changed files with 22 additions and 9 deletions.
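For context, the change makes the clipping step optional at construction time. A minimal usage sketch (assuming the fork keeps the upstream stable_baselines3 constructor signature; the environment and timestep count below are placeholder choices):

from stable_baselines3 import A2C, PPO

# Default behaviour: gradients are clipped to a total norm of 0.5 before each optimizer step.
model = A2C("MlpPolicy", "CartPole-v1", max_grad_norm=0.5)

# With this change, passing None skips the clip_grad_norm_ call entirely.
model = PPO("MlpPolicy", "CartPole-v1", max_grad_norm=None)
model.learn(total_timesteps=10_000)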
12 changes: 9 additions & 3 deletions stable_baselines3/a2c/a2c.py
@@ -5,7 +5,12 @@
 from torch.nn import functional as F

 from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
-from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, BasePolicy, MultiInputActorCriticPolicy
+from stable_baselines3.common.policies import (
+    ActorCriticCnnPolicy,
+    ActorCriticPolicy,
+    BasePolicy,
+    MultiInputActorCriticPolicy,
+)
 from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
 from stable_baselines3.common.utils import explained_variance

@@ -70,7 +75,7 @@ def __init__(
         gae_lambda: float = 1.0,
         ent_coef: float = 0.0,
         vf_coef: float = 0.5,
-        max_grad_norm: float = 0.5,
+        max_grad_norm: Optional[float] = 0.5,
         rms_prop_eps: float = 1e-5,
         use_rms_prop: bool = True,
         use_sde: bool = False,
@@ -168,7 +173,8 @@ def train(self) -> None:
             loss.backward()

             # Clip grad norm
-            th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
+            if self.max_grad_norm is not None:
+                th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
             self.policy.optimizer.step()

         explained_var = explained_variance(self.rollout_buffer.values.flatten(), self.rollout_buffer.returns.flatten())
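For reference, a small standalone PyTorch sketch of what the guarded call does (illustrative only, not code from this repository; the toy parameters are made up):

import torch as th

params = [th.nn.Parameter(th.randn(4, 4)) for _ in range(3)]
loss = sum((p ** 2).sum() for p in params)
loss.backward()

max_grad_norm = 0.5  # set to None to skip clipping, mirroring this commit
if max_grad_norm is not None:
    # Rescales every gradient in place so their combined L2 norm is at most max_grad_norm.
    th.nn.utils.clip_grad_norm_(params, max_grad_norm)

total_norm = th.norm(th.stack([p.grad.norm() for p in params])).item()
print(f"gradient norm after optional clipping: {total_norm:.3f}")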
2 changes: 1 addition & 1 deletion stable_baselines3/common/on_policy_algorithm.py
@@ -66,7 +66,7 @@ def __init__(
         gae_lambda: float,
         ent_coef: float,
         vf_coef: float,
-        max_grad_norm: float,
+        max_grad_norm: Optional[float],
         use_sde: bool,
         sde_sample_freq: int,
         stats_window_size: int = 100,
12 changes: 9 additions & 3 deletions stable_baselines3/ppo/ppo.py
@@ -7,7 +7,12 @@
 from torch.nn import functional as F

 from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
-from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, BasePolicy, MultiInputActorCriticPolicy
+from stable_baselines3.common.policies import (
+    ActorCriticCnnPolicy,
+    ActorCriticPolicy,
+    BasePolicy,
+    MultiInputActorCriticPolicy,
+)
 from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
 from stable_baselines3.common.utils import explained_variance, get_schedule_fn

@@ -89,7 +94,7 @@ def __init__(
         normalize_advantage: bool = True,
         ent_coef: float = 0.0,
         vf_coef: float = 0.5,
-        max_grad_norm: float = 0.5,
+        max_grad_norm: Optional[float] = 0.5,
         use_sde: bool = False,
         sde_sample_freq: int = -1,
         target_kl: Optional[float] = None,
@@ -271,7 +276,8 @@ def train(self) -> None:
                 self.policy.optimizer.zero_grad()
                 loss.backward()
                 # Clip grad norm
-                th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
+                if self.max_grad_norm is not None:
+                    th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
                 self.policy.optimizer.step()

             self._n_updates += 1
5 changes: 3 additions & 2 deletions stable_baselines3/ppo_recurrent/ppo_recurrent.py
@@ -117,7 +117,7 @@ def __init__(
         normalize_advantage: bool = True,
         ent_coef: float = 0.0,
         vf_coef: float = 0.5,
-        max_grad_norm: float = 0.5,
+        max_grad_norm: Optional[float] = 0.5,
         use_sde: bool = False,
         sde_sample_freq: int = -1,
         target_kl: Optional[float] = None,
@@ -470,7 +470,8 @@ def train(self) -> None:
                 self.policy.optimizer.zero_grad()
                 loss.backward()
                 # Clip grad norm
-                th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
+                if self.max_grad_norm is not None:
+                    th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
                 self.policy.optimizer.step()
             self._n_updates += 1
             if not continue_training:
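The commit message notes that gradient clipping shows up in profiles. A rough micro-benchmark along these lines (illustrative only, not the author's profiling setup; parameter sizes and timings are arbitrary and hardware-dependent) is one way to check that locally:

import time

import torch as th

# Synthetic parameters with gradients already populated.
params = [th.nn.Parameter(th.randn(256, 256)) for _ in range(64)]
for p in params:
    p.grad = th.randn_like(p)

start = time.perf_counter()
for _ in range(1_000):
    th.nn.utils.clip_grad_norm_(params, 0.5)
print(f"1000 clip_grad_norm_ calls: {time.perf_counter() - start:.3f} s")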
