Adds per-head entropy coefficients to PPOLoss #2972

Open · wants to merge 2 commits into base: main
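
The gist of the change: `entropy_coef` may now be a mapping from head name to coefficient, and the entropy loss becomes the negative coefficient-weighted sum of each head's mean entropy. Below is a minimal, self-contained sketch of that weighting; the head names (`move`, `gripper`) and the numbers are made up for illustration and are not part of this PR.

```python
import torch
from tensordict import TensorDict

# Hypothetical per-head coefficients, keyed by head name (illustrative only).
entropy_coef = {"move": 0.01, "gripper": 0.0001}

# Per-head entropies, shaped like the composite entropy tensordict the loss sees.
entropy = TensorDict(
    {
        "move_entropy": torch.full((8,), 1.2),
        "gripper_entropy": torch.full((8,), 0.3),
    },
    batch_size=[8],
)

# Same weighting as _weighted_loss_entropy in the diff below.
loss_terms = []
for key, h in entropy.flatten_keys(separator=".").items():
    name = key.split(".")[-1].removesuffix("_entropy")
    coeff = torch.tensor(entropy_coef[name], dtype=h.dtype, device=h.device)
    loss_terms.append(coeff * h.mean())

loss_entropy = -torch.stack(loss_terms).sum()
print(loss_entropy)  # tensor(-0.0120), i.e. -(0.01 * 1.2 + 0.0001 * 0.3)
```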
torchrl/objectives/ppo.py: 60 changes (51 additions & 9 deletions)
@@ -8,6 +8,7 @@
import warnings
from copy import deepcopy
from dataclasses import dataclass
+ from typing import Mapping

import torch
from tensordict import (
@@ -330,7 +331,7 @@ def __init__(
*,
entropy_bonus: bool = True,
samples_mc_entropy: int = 1,
- entropy_coef: float = 0.01,
+ entropy_coef: float | Mapping[str, float] = 0.01,
critic_coef: float | None = None,
loss_critic_type: str = "smooth_l1",
normalize_advantage: bool = False,
@@ -408,7 +409,22 @@ def __init__(
torch, "get_default_device", lambda: torch.device("cpu")
)()

self.register_buffer("entropy_coef", torch.tensor(entropy_coef, device=device))
if isinstance(entropy_coef, Mapping):
# Store the mapping for per-head coefficients
self._entropy_coef_map = {str(k): float(v) for k, v in entropy_coef.items()}
# Register an empty buffer for compatibility
self.register_buffer("entropy_coef", torch.tensor(0.0))
elif isinstance(entropy_coef, (float, int, torch.Tensor)):
# Register the scalar entropy coefficient
coef = (
float(entropy_coef)
if not torch.is_tensor(entropy_coef)
else float(entropy_coef.item())
)
self.register_buffer("entropy_coef", torch.tensor(coef))
self._entropy_coef_map = None
else:
raise TypeError("entropy_coef must be a float or a Mapping[str, float]")
if critic_coef is not None:
self.register_buffer(
"critic_coef", torch.tensor(critic_coef, device=device)
@@ -540,7 +556,6 @@ def _get_entropy(
return entropy.unsqueeze(-1)

def _get_cur_log_prob(self, tensordict):
-
if isinstance(
self.actor_network,
(ProbabilisticTensorDictSequential, ProbabilisticTensorDictModule),
@@ -589,7 +604,6 @@ def _get_cur_log_prob(self, tensordict):
def _log_weight(
self, tensordict: TensorDictBase, adv_shape: torch.Size
) -> tuple[torch.Tensor, d.Distribution, torch.Tensor]:
-
prev_log_prob = _maybe_get_or_select(
tensordict,
self.tensor_keys.sample_log_prob,
@@ -745,9 +759,12 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
if is_tensor_collection(entropy):
# Reports the entropy of each action head.
td_out.set("composite_entropy", entropy.detach())
- entropy = _sum_td_features(entropy)
- td_out.set("entropy", entropy.detach().mean()) # for logging
- td_out.set("loss_entropy", -self.entropy_coef * entropy)
+ td_out.set(
+ "entropy", _sum_td_features(entropy).detach().mean()
+ ) # for logging
+ else:
+ td_out.set("entropy", entropy.detach().mean()) # for logging
+ td_out.set("loss_entropy", self._weighted_loss_entropy(entropy))
if self._has_critic:
loss_critic, value_clip_fraction = self.loss_critic(tensordict)
td_out.set("loss_critic", loss_critic)
@@ -814,6 +831,31 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams
}
self._value_estimator.set_keys(**tensor_keys)

+ def _weighted_loss_entropy(
+ self, entropy: torch.Tensor | TensorDictBase
+ ) -> torch.Tensor:
+ """Compute the weighted entropy loss.
+
+ If `self._entropy_coef_map` is provided, apply per-head entropy coefficients.
+ Otherwise, use the scalar `self.entropy_coef`.
+ """
+ if self._entropy_coef_map is None:
+ if is_tensor_collection(entropy):
+ entropy = _sum_td_features(entropy)
+ return -self.entropy_coef * entropy
+
+ loss_terms = []
+ for key, h in entropy.flatten_keys(separator=".").items():
+ name = key.split(".")[-1].removesuffix("_entropy")
+ try:
+ coeff = self._entropy_coef_map[name]
+ except KeyError as exc:
+ raise KeyError(f"Missing entropy coef for head '{name}'") from exc
+ coeff_t = torch.tensor(coeff, dtype=h.dtype, device=h.device)
+ loss_terms.append(coeff_t * h.mean())
+
+ return -torch.stack(loss_terms).sum()
+

class ClipPPOLoss(PPOLoss):
"""Clipped PPO loss.
@@ -939,7 +981,7 @@ def __init__(
clip_epsilon: float = 0.2,
entropy_bonus: bool = True,
samples_mc_entropy: int = 1,
- entropy_coef: float = 0.01,
+ entropy_coef: float | Mapping[str, float] = 0.01,
critic_coef: float | None = None,
loss_critic_type: str = "smooth_l1",
normalize_advantage: bool = False,
@@ -1066,7 +1108,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
td_out.set("composite_entropy", entropy.detach())
entropy = _sum_td_features(entropy)
td_out.set("entropy", entropy.detach().mean()) # for logging
td_out.set("loss_entropy", -self.entropy_coef * entropy)
td_out.set("loss_entropy", self._weighted_loss_entropy(entropy))
if self._has_critic:
loss_critic, value_clip_fraction = self.loss_critic(tensordict)
td_out.set("loss_critic", loss_critic)
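
One behavioral note when using the mapping form: every head present in the composite entropy must have an entry in `entropy_coef`, otherwise the new `_weighted_loss_entropy` raises a `KeyError` naming the offending head. A tiny self-contained sketch of that lookup, with hypothetical head names:

```python
# Mirrors the per-head lookup in _weighted_loss_entropy above; head names are made up.
entropy_coef = {"move": 0.01}  # no entry for "gripper"

for key in ["move_entropy", "gripper_entropy"]:
    name = key.removesuffix("_entropy")
    try:
        coeff = entropy_coef[name]
    except KeyError as exc:
        raise KeyError(f"Missing entropy coef for head '{name}'") from exc
    print(name, coeff)
# Prints "move 0.01", then raises: KeyError: "Missing entropy coef for head 'gripper'"
```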