use alpha not log alpha in autotune (vwxyzjn#414)
* use alpha not log alpha in autotune

* log_alpha.exp() instead of torch.exp(log_alpha) is slightly cleaner.

* fix alpha optimization in sac_atari: use log_alpha.exp() instead of log_alpha

* update docs to use log_alpha.exp() instead of log_alpha in autotune

---------

Co-authored-by: Rousslan F.J. Dossa <dosssman@hotmail.fr>
Co-authored-by: Timo Klein <tkleia34@gmail.com>
3 people authored Sep 14, 2023
1 parent f36d4a6 commit 7e24ae2
Showing 3 changed files with 3 additions and 3 deletions.
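
For context, here is the pattern the commit converges on, written out as a minimal self-contained sketch of the continuous-action autotune update. The target_entropy value, log-prob batch, and learning rate below are hypothetical placeholders; the real scripts derive them from the action space and from actor.get_action.

```python
import torch

# Hypothetical placeholders; the CleanRL scripts compute these from the env/actor.
target_entropy = -6.0                           # e.g. -prod(action_space.shape)
log_pi = torch.randn(256) - 3.0                 # a batch of actor log-probabilities

log_alpha = torch.zeros(1, requires_grad=True)  # alpha is parameterized as exp(log_alpha)
a_optimizer = torch.optim.Adam([log_alpha], lr=1e-3)

# Temperature loss after this commit: scale by alpha = log_alpha.exp(), not by log_alpha.
alpha_loss = (-log_alpha.exp() * (log_pi + target_entropy)).mean()

a_optimizer.zero_grad()
alpha_loss.backward()
a_optimizer.step()

alpha = log_alpha.exp().item()                  # the value fed into the actor/critic losses
```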
2 changes: 1 addition & 1 deletion cleanrl/sac_atari.py
@@ -311,7 +311,7 @@ def get_action(self, x):

     if args.autotune:
         # re-use action probabilities for temperature loss
-        alpha_loss = (action_probs.detach() * (-log_alpha * (log_pi + target_entropy).detach())).mean()
+        alpha_loss = (action_probs.detach() * (-log_alpha.exp() * (log_pi + target_entropy).detach())).mean()

         a_optimizer.zero_grad()
         alpha_loss.backward()
2 changes: 1 addition & 1 deletion cleanrl/sac_continuous_action.py
@@ -276,7 +276,7 @@ def get_action(self, x):
     if args.autotune:
         with torch.no_grad():
             _, log_pi, _ = actor.get_action(data.observations)
-        alpha_loss = (-log_alpha * (log_pi + target_entropy)).mean()
+        alpha_loss = (-log_alpha.exp() * (log_pi + target_entropy)).mean()

         a_optimizer.zero_grad()
         alpha_loss.backward()
2 changes: 1 addition & 1 deletion docs/rl-algorithms/sac.md
@@ -359,7 +359,7 @@ Surpassing Human-Level Performance on ImageNet Classification"](https://arxiv.or
 ```python hl_lines="3"
 if args.autotune:
     # re-use action probabilities for temperature loss
-    alpha_loss = (action_probs.detach() * (-log_alpha * (log_pi + target_entropy).detach())).mean()
+    alpha_loss = (action_probs.detach() * (-log_alpha.exp() * (log_pi + target_entropy).detach())).mean()

     a_optimizer.zero_grad()
     alpha_loss.backward()
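
A quick sanity check of why the .exp() matters (the log_pi values and target entropy below are made up): with log_alpha.exp() in the loss, the gradient on log_alpha is scaled by the current alpha, whereas with a bare log_alpha the gradient would be -mean(log_pi + target_entropy) regardless of alpha's value.

```python
import torch

log_alpha = torch.zeros(1, requires_grad=True)   # alpha = exp(0) = 1.0
log_pi = torch.tensor([-1.2, -0.7, -2.0])        # made-up log-probabilities
target_entropy = -3.0                            # made-up entropy target

alpha_loss = (-log_alpha.exp() * (log_pi + target_entropy)).mean()
alpha_loss.backward()

# d(alpha_loss)/d(log_alpha) = -alpha * mean(log_pi + target_entropy)
expected = -log_alpha.exp().item() * (log_pi + target_entropy).mean().item()
print(log_alpha.grad.item(), expected)  # both ≈ 4.3
```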
