We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 5b9ff55 · commit 37da520 · Copy full SHA for 37da520
torchrl/objectives/ppo.py
@@ -233,13 +233,10 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
233
f"and {log_weight.shape})"
234
)
235
gain1 = log_weight.exp() * advantage
236
- log_weight_clip = torch.empty_like(log_weight)
237
- # log_weight_clip.data.clamp_(*self._clip_bounds)
238
- idx_pos = advantage >= 0
239
- log_weight_clip[idx_pos] = log_weight[idx_pos].clamp_max(self._clip_bounds[1])
240
- log_weight_clip[~idx_pos] = log_weight[~idx_pos].clamp_min(self._clip_bounds[0])
241
+ log_weight_clip = log_weight.clamp(*self._clip_bounds)
242
gain2 = log_weight_clip.exp() * advantage
+
243
gain = torch.stack([gain1, gain2], -1).min(dim=-1)[0]
244
td_out = TensorDict({"loss_objective": -gain.mean()}, [])
245
0 commit comments