31 changes: 26 additions & 5 deletions torchrl/objectives/ppo.py
@@ -104,6 +104,9 @@ class PPOLoss(LossModule):
* **Scalar**: one value applied to the summed entropy of every action head.
* **Mapping** ``{head_name: coef}`` gives an individual coefficient for each action-head's entropy.
Defaults to ``0.01``.
+        log_explained_variance (bool, optional): if ``True``, the explained variance of the critic
+            predictions w.r.t. the value targets is computed and logged as ``"explained_variance"``.
+            This helps monitor critic quality during training: the best possible score is 1.0;
+            lower values are worse. Defaults to ``True``.
critic_coef (scalar, optional): critic loss multiplier when computing the total
loss. Defaults to ``1.0``. Set ``critic_coef`` to ``None`` to exclude the value
loss from the forward outputs.
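For context, a minimal sketch of enabling the new flag (hedged: `actor` and `critic` stand in for whatever TensorDict-compatible modules the surrounding training script already builds, and the other constructor arguments follow the usual torchrl signature):

    from torchrl.objectives import PPOLoss

    # `actor` and `critic` are assumed to be pre-built TensorDict-compatible modules.
    loss_module = PPOLoss(
        actor_network=actor,
        critic_network=critic,
        critic_coef=1.0,              # keep the value loss in the outputs
        log_explained_variance=True,  # log "explained_variance" alongside the losses
    )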
@@ -349,6 +352,7 @@ def __init__(
entropy_bonus: bool = True,
samples_mc_entropy: int = 1,
entropy_coeff: float | Mapping[str, float] = 0.01,
+        log_explained_variance: bool = True,
critic_coef: float | None = None,
loss_critic_type: str = "smooth_l1",
normalize_advantage: bool = False,
@@ -413,6 +417,7 @@ def __init__(
self.critic_network_params = None
self.target_critic_network_params = None

+        self.log_explained_variance = log_explained_variance
self.samples_mc_entropy = samples_mc_entropy
self.entropy_bonus = entropy_bonus
self.separate_losses = separate_losses
@@ -745,6 +750,16 @@ def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor:
self.loss_critic_type,
)

+        explained_variance = None
+        if self.log_explained_variance:
+            with torch.no_grad():  # diagnostic only: keep it out of the autograd graph
+                tgt = target_return.detach()
+                pred = state_value.detach()
+                eps = torch.finfo(tgt.dtype).eps
+                resid = torch.var(tgt - pred, unbiased=False, dim=0)
+                total = torch.var(tgt, unbiased=False, dim=0)
+                explained_variance = 1.0 - resid / (total + eps)

self._clear_weakrefs(
tensordict,
"actor_network_params",
    ...
"target_critic_network_params",
)
if self._has_critic:
-            return self.critic_coef * loss_value, clip_fraction
-        return loss_value, clip_fraction
+            return self.critic_coef * loss_value, clip_fraction, explained_variance
+        return loss_value, clip_fraction, explained_variance

@property
@_cache_values
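The statistic computed above is the standard explained variance, 1 - Var(y - y_hat) / Var(y). A standalone toy check of the same arithmetic (not library code; the tensors are made up):

    import torch

    tgt = torch.tensor([[1.0], [2.0], [3.0], [4.0]])     # value targets
    pred = torch.tensor([[1.1], [1.9], [3.2], [3.8]])    # critic predictions
    eps = torch.finfo(tgt.dtype).eps
    resid = torch.var(tgt - pred, unbiased=False, dim=0)  # Var(y - y_hat) = 0.025
    total = torch.var(tgt, unbiased=False, dim=0)         # Var(y) = 1.25
    print(1.0 - resid / (total + eps))                    # tensor([0.9800])

The `eps` term guards against division by zero when the targets are (nearly) constant, at the cost of a negligible bias in the logged value.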
@@ -804,10 +819,12 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
td_out.set("entropy", entropy.detach().mean()) # for logging
td_out.set("loss_entropy", self._weighted_loss_entropy(entropy))
if self._has_critic:
-            loss_critic, value_clip_fraction = self.loss_critic(tensordict)
+            loss_critic, value_clip_fraction, explained_variance = self.loss_critic(tensordict)
td_out.set("loss_critic", loss_critic)
if value_clip_fraction is not None:
td_out.set("value_clip_fraction", value_clip_fraction)
+            if explained_variance is not None:
+                td_out.set("explained_variance", explained_variance)
td_out = td_out.named_apply(
lambda name, value: _reduce(value, reduction=self.reduction).squeeze(-1)
if name.startswith("loss_")
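A hedged usage sketch of consuming the new key on the logging side (`data` is assumed to be a rollout tensordict with advantages and value targets already computed, `optimizer` a standard torch optimizer, and `logger` any scalar logger):

    td_out = loss_module(data)
    loss = td_out["loss_objective"] + td_out["loss_critic"] + td_out["loss_entropy"]
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    if "explained_variance" in td_out.keys():
        # Values near 1.0 mean the critic tracks the return targets well;
        # values near 0 (or negative) flag a weak or diverging critic.
        logger.log_scalar("explained_variance", td_out["explained_variance"].item())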
@@ -1172,10 +1189,12 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
td_out.set("entropy", entropy.detach().mean()) # for logging
td_out.set("loss_entropy", self._weighted_loss_entropy(entropy))
if self._has_critic:
-            loss_critic, value_clip_fraction = self.loss_critic(tensordict)
+            loss_critic, value_clip_fraction, explained_variance = self.loss_critic(tensordict)
td_out.set("loss_critic", loss_critic)
if value_clip_fraction is not None:
td_out.set("value_clip_fraction", value_clip_fraction)
+            if explained_variance is not None:
+                td_out.set("explained_variance", explained_variance)

td_out.set("ESS", _reduce(ess, self.reduction) / batch)
td_out = td_out.named_apply(
@@ -1518,10 +1537,12 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict:
td_out.set("entropy", entropy.detach().mean()) # for logging
td_out.set("loss_entropy", self._weighted_loss_entropy(entropy))
if self._has_critic:
-            loss_critic, value_clip_fraction = self.loss_critic(tensordict_copy)
+            loss_critic, value_clip_fraction, explained_variance = self.loss_critic(tensordict_copy)
td_out.set("loss_critic", loss_critic)
if value_clip_fraction is not None:
td_out.set("value_clip_fraction", value_clip_fraction)
+            if explained_variance is not None:
+                td_out.set("explained_variance", explained_variance)
td_out = td_out.named_apply(
lambda name, value: _reduce(value, reduction=self.reduction).squeeze(-1)
if name.startswith("loss_")