File tree 3 files changed +13
-5
lines changed
.github/unittest/linux/scripts
torchrl/modules/distributions
3 files changed +13
-5
lines changed Original file line number Diff line number Diff line change @@ -34,3 +34,5 @@ dependencies:
34
34
- transformers
35
35
- ninja
36
36
- timm
37
+ - gymnasium[atari,accept-rom-license]
38
+ - mo-gymnasium[mujoco]
Original file line number Diff line number Diff line change @@ -87,11 +87,6 @@ conda env update --file "${this_dir}/environment.yml" --prune
87
87
conda deactivate
88
88
conda activate " ${env_dir} "
89
89
90
- echo " installing gymnasium"
91
- pip3 install " gymnasium[atari,accept-rom-license]"
92
- pip3 install mo-gymnasium[mujoco] # requires here bc needs mujoco-py
93
- pip3 install " mujoco" -U
94
-
95
90
# sanity check: remove?
96
91
python3 -c " " "
97
92
import dm_control
Original file line number Diff line number Diff line change @@ -389,6 +389,17 @@ def sample(
389
389
) -> torch .Tensor :
390
390
...
391
391
392
+ @property
393
+ def deterministic_sample (self ):
394
+ return self .mode
395
+
396
+ @property
397
+ def mode (self ) -> torch .Tensor :
398
+ if hasattr (self , "logits" ):
399
+ return (self .logits == self .logits .max (- 1 , True )[0 ]).to (torch .long )
400
+ else :
401
+ return (self .probs == self .probs .max (- 1 , True )[0 ]).to (torch .long )
402
+
392
403
def log_prob (self , value : torch .Tensor ) -> torch .Tensor :
393
404
return super ().log_prob (value .argmax (dim = - 1 ))
394
405
You can’t perform that action at this time.
0 commit comments