    NdUnboundedContinuousTensorSpec,
    TensorSpec,
    UnboundedContinuousTensorSpec,
+    BinaryDiscreteTensorSpec,
)
from torchrl.data.tensordict.tensordict import _TensorDict, TensorDict
from torchrl.envs.common import _EnvClass, make_tensordict
    "DoubleToFloat",
    "CatTensors",
    "NoopResetEnv",
-    "BinerizeReward",
+    "BinarizeReward",
    "PinMemoryTransform",
    "VecNorm",
    "gSDENoise",
@@ -576,7 +577,7 @@ def __repr__(self) -> str:
    )


- class BinerizeReward(Transform):
+ class BinarizeReward(Transform):
    """
    Maps the reward to a binary value (0 or 1) if the reward is null or
    non-null, respectively.
@@ -591,19 +592,14 @@ def __init__(self, keys: Optional[Sequence[str]] = None):
        super().__init__(keys=keys)

    def _apply(self, reward: torch.Tensor) -> torch.Tensor:
-        return (reward != 0.0).to(reward.dtype)
+        if not reward.shape or reward.shape[-1] != 1:
+            raise RuntimeError(
+                f"Reward shape last dimension must be singleton, got reward of shape {reward.shape}"
+            )
+        return (reward > 0.0).to(torch.long)

    def transform_reward_spec(self, reward_spec: TensorSpec) -> TensorSpec:
-        if isinstance(reward_spec, UnboundedContinuousTensorSpec):
-            return BoundedTensorSpec(
-                0.0, 1.0, device=reward_spec.device, dtype=reward_spec.dtype
-            )
-        else:
-            raise NotImplementedError(
-                f"{self.__class__.__name__}.transform_reward_spec not "
-                f"implemented for tensor spec of type "
-                f"{type(reward_spec).__name__}"
-            )
+        return BinaryDiscreteTensorSpec(n=1, device=reward_spec.device)


class Resize(ObservationTransform):
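Note on behaviour: the old `_apply` mapped any non-zero reward (negative included) to 1 and kept the incoming dtype, while the new version maps only strictly positive rewards to 1, returns `torch.long`, and rejects rewards whose last dimension is not a singleton; `transform_reward_spec` now always reports a `BinaryDiscreteTensorSpec(n=1)` instead of only handling `UnboundedContinuousTensorSpec`. A minimal standalone sketch of the new mapping (plain PyTorch; `binarize_reward` is a hypothetical helper for illustration, not torchrl API):

```python
import torch

def binarize_reward(reward: torch.Tensor) -> torch.Tensor:
    # Mirrors the new BinarizeReward._apply: require a trailing singleton
    # dimension, then map strictly positive rewards to 1 and all others to 0.
    if not reward.shape or reward.shape[-1] != 1:
        raise RuntimeError(
            f"Reward shape last dimension must be singleton, got reward of shape {reward.shape}"
        )
    return (reward > 0.0).to(torch.long)

rewards = torch.tensor([[-1.0], [0.0], [0.5], [2.0]])  # batch of 4 scalar rewards
print(binarize_reward(rewards))  # tensor([[0], [0], [1], [1]])
```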
@@ -860,14 +856,12 @@ def reset(self, tensordict: _TensorDict) -> _TensorDict:

    def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec:
        if isinstance(observation_spec, CompositeSpec):
-            return CompositeSpec(
-                **{
-                    key: self.transform_observation_spec(_obs_spec)
-                    if key in self.keys
-                    else _obs_spec
-                    for key, _obs_spec in observation_spec._specs.items()
-                }
-            )
+            keys = [key for key in observation_spec.keys() if key in self.keys]
+            for key in keys:
+                observation_spec[key] = self.transform_observation_spec(
+                    observation_spec[key]
+                )
+            return observation_spec
        else:
            _observation_spec = observation_spec
            space = _observation_spec.space
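The rewrite above mutates the existing `CompositeSpec` in place, replacing only the entries whose keys the transform owns, and returns the same container rather than constructing a fresh `CompositeSpec` from a dict comprehension. A rough sketch of that pattern with a plain dict standing in for the spec (`transform_selected` and the sample data are hypothetical, for illustration only):

```python
def transform_selected(specs: dict, owned_keys, transform_one) -> dict:
    # Rewrite only the entries whose keys this transform owns; leave the rest alone.
    for key in [k for k in specs if k in owned_keys]:
        specs[key] = transform_one(specs[key])
    return specs  # same container, mutated in place

specs = {"pixels": ("uint8", (210, 160, 3)), "vector": ("float32", (17,))}
out = transform_selected(specs, {"pixels"}, lambda s: ("float32", s[1]))
assert out is specs and out["vector"] == ("float32", (17,))
print(out["pixels"])  # ('float32', (210, 160, 3))
```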