We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 0a410ff commit e294c68Copy full SHA for e294c68
torchrl/modules/distributions/discrete.py
@@ -389,6 +389,17 @@ def sample(
389
) -> torch.Tensor:
390
...
391
392
+ @property
393
+ def deterministic_sample(self):
394
+ return self.mode
395
+
396
397
+ def mode(self) -> torch.Tensor:
398
+ if hasattr(self, "logits"):
399
+ return (self.logits == self.logits.max(-1, True)[0]).to(torch.long)
400
+ else:
401
+ return (self.probs == self.probs.max(-1, True)[0]).to(torch.long)
402
403
def log_prob(self, value: torch.Tensor) -> torch.Tensor:
404
return super().log_prob(value.argmax(dim=-1))
405
0 commit comments