Skip to content

Commit e294c68

Browse files
author
Vincent Moens
committed
[Feature] Deterministic sample for Masked one-hot
ghstack-source-id: 27787ea Pull Request resolved: #2440
1 parent 0a410ff commit e294c68

File tree

1 file changed

+11
-0
lines changed

1 file changed

+11
-0
lines changed

torchrl/modules/distributions/discrete.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -389,6 +389,17 @@ def sample(
389389
) -> torch.Tensor:
390390
...
391391

392+
@property
393+
def deterministic_sample(self):
394+
return self.mode
395+
396+
@property
397+
def mode(self) -> torch.Tensor:
398+
if hasattr(self, "logits"):
399+
return (self.logits == self.logits.max(-1, True)[0]).to(torch.long)
400+
else:
401+
return (self.probs == self.probs.max(-1, True)[0]).to(torch.long)
402+
392403
def log_prob(self, value: torch.Tensor) -> torch.Tensor:
393404
return super().log_prob(value.argmax(dim=-1))
394405

0 commit comments

Comments
 (0)