
Commit 37da520

[BugFix] Fix PPO clip (#786)
1 parent 5b9ff55 commit 37da520

1 file changed (+2, -5 lines)

torchrl/objectives/ppo.py

Lines changed: 2 additions & 5 deletions
@@ -233,13 +233,10 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
                 f"and {log_weight.shape})"
             )
         gain1 = log_weight.exp() * advantage
-        log_weight_clip = torch.empty_like(log_weight)
-        # log_weight_clip.data.clamp_(*self._clip_bounds)
-        idx_pos = advantage >= 0
-        log_weight_clip[idx_pos] = log_weight[idx_pos].clamp_max(self._clip_bounds[1])
-        log_weight_clip[~idx_pos] = log_weight[~idx_pos].clamp_min(self._clip_bounds[0])
 
+        log_weight_clip = log_weight.clamp(*self._clip_bounds)
         gain2 = log_weight_clip.exp() * advantage
+
         gain = torch.stack([gain1, gain2], -1).min(dim=-1)[0]
         td_out = TensorDict({"loss_objective": -gain.mean()}, [])
 

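For context, here is a minimal, self-contained sketch (not TorchRL's actual module API) of the clipped PPO surrogate computed the same way as the patched code. It assumes the clip bounds are the log-space interval [log(1 - clip_epsilon), log(1 + clip_epsilon)]; the function name, default epsilon, and tensor shapes below are illustrative only.

import math

import torch


def clipped_ppo_loss(log_weight, advantage, clip_epsilon=0.2):
    # log_weight = log pi_new(a|s) - log pi_old(a|s)
    # Assumed clip bounds: [log(1 - eps), log(1 + eps)]
    clip_bounds = (math.log1p(-clip_epsilon), math.log1p(clip_epsilon))
    gain1 = log_weight.exp() * advantage                    # unclipped surrogate
    log_weight_clip = log_weight.clamp(*clip_bounds)        # single symmetric clamp, as in the patch
    gain2 = log_weight_clip.exp() * advantage               # clipped surrogate
    gain = torch.stack([gain1, gain2], -1).min(dim=-1)[0]   # pessimistic elementwise minimum
    return -gain.mean()                                     # negated so it can be minimized


# Example usage with dummy data
loss = clipped_ppo_loss(torch.randn(8, 1) * 0.1, torch.randn(8, 1))

Because the elementwise minimum with the unclipped gain already selects the pessimistic branch for either sign of the advantage, a single symmetric clamp of the log-weight is enough to recover the standard PPO clipped objective, which is what the simplified line in this commit relies on.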
0 commit comments