Commit deef000

Vincent Moens committed

Update
[ghstack-poisoned]

2 parents 51b0133 + 5708bf1 commit deef000

3 files changed: +13 -5 lines changed

.github/unittest/linux/scripts/environment.yml

Lines changed: 2 additions & 0 deletions
@@ -34,3 +34,5 @@ dependencies:
   - transformers
   - ninja
   - timm
+  - gymnasium[atari,accept-rom-license]
+  - mo-gymnasium[mujoco]

.github/unittest/linux/scripts/run_all.sh

Lines changed: 0 additions & 5 deletions
@@ -87,11 +87,6 @@ conda env update --file "${this_dir}/environment.yml" --prune
 conda deactivate
 conda activate "${env_dir}"
 
-echo "installing gymnasium"
-pip3 install "gymnasium[atari,accept-rom-license]"
-pip3 install mo-gymnasium[mujoco]  # requires here bc needs mujoco-py
-pip3 install "mujoco" -U
-
 # sanity check: remove?
 python3 -c """
 import dm_control
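
With the `pip3 install` lines gone from `run_all.sh`, the gymnasium packages are expected to arrive through `environment.yml`. A minimal sketch of how one might confirm they are still importable after the environment update. The helper `missing_modules` is hypothetical, and the import names `gymnasium`, `mo_gymnasium`, and `mujoco` are assumptions that do not appear in this diff.

```python
import importlib.util


def missing_modules(names):
    """Return the top-level module names that cannot be located.

    Hypothetical helper, not part of this commit. It only checks that a
    module can be found; it does not import or execute it.
    """
    return [name for name in names if importlib.util.find_spec(name) is None]


# Assumed import names for the packages listed in environment.yml.
required = ["gymnasium", "mo_gymnasium", "mujoco"]
missing = missing_modules(required)
print("missing modules:", missing)
```

An empty list would suggest the conda environment picked up the new entries; a non-empty list names the packages to investigate.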

torchrl/modules/distributions/discrete.py

Lines changed: 11 additions & 0 deletions
@@ -389,6 +389,17 @@ def sample(
     ) -> torch.Tensor:
         ...
 
+    @property
+    def deterministic_sample(self):
+        return self.mode
+
+    @property
+    def mode(self) -> torch.Tensor:
+        if hasattr(self, "logits"):
+            return (self.logits == self.logits.max(-1, True)[0]).to(torch.long)
+        else:
+            return (self.probs == self.probs.max(-1, True)[0]).to(torch.long)
+
     def log_prob(self, value: torch.Tensor) -> torch.Tensor:
         return super().log_prob(value.argmax(dim=-1))
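
The new `mode` property compares each entry with the maximum along the last dimension, so the result is a one-hot mask only when that maximum is unique. A minimal standalone sketch of the same expression, for review purposes. The function names `one_hot_mode` and `one_hot_argmax` are hypothetical, and the `argmax`-based variant is shown only as a possible alternative, not as what the commit does.

```python
import torch
import torch.nn.functional as F


def one_hot_mode(logits: torch.Tensor) -> torch.Tensor:
    """Same expression as the diff: mark every entry equal to the row maximum."""
    # max(-1, True) is max(dim=-1, keepdim=True); [0] selects the max values.
    return (logits == logits.max(-1, True)[0]).to(torch.long)


def one_hot_argmax(logits: torch.Tensor) -> torch.Tensor:
    """Hypothetical alternative: exactly one 1 per row, even when values tie."""
    return F.one_hot(logits.argmax(dim=-1), num_classes=logits.shape[-1])


logits = torch.tensor([[0.1, 2.0, -1.0],   # unique maximum
                       [1.5, 1.5, 0.0]])   # two entries tie for the maximum
mode = one_hot_mode(logits)
strict = one_hot_argmax(logits)
print(mode)    # the second row has two ones because of the tie
print(strict)  # every row sums to one
```

Exact ties are unlikely with floating-point logits from a network, but they can occur with uniform or hand-built inputs, which may matter because `log_prob` applies `argmax` to its input.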
