
Commit 60b59b1

Fix specs for BinarizeReward and CatFrames transforms (#86)
The remaining failing tests are due to an SSL issue.
1 parent 308ed49 commit 60b59b1


2 files changed: 19 additions & 21 deletions


test/test_transforms.py

Lines changed: 4 additions & 0 deletions
@@ -663,6 +663,10 @@ def test_noop_reset_env(self, random):
     def test_binerized_reward(self, device):
         pass
 
+    @pytest.mark.parametrize("device", get_available_devices())
+    def test_reward_scaling(self, device):
+        pass
+
     @pytest.mark.skipif(not torch.cuda.device_count(), reason="no cuda device found")
     @pytest.mark.parametrize("device", get_available_devices())
     def test_pin_mem(self, device):

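Both test_binerized_reward and the newly added test_reward_scaling are still pass placeholders. A minimal sketch of what test_binerized_reward could assert, based only on the BinarizeReward changes further down in this commit (the import path and the keys argument are assumptions, not the repository's eventual test):

# Hypothetical sketch of a filled-in test, based on the BinarizeReward diff below.
import pytest
import torch
from torchrl.envs.transforms.transforms import BinarizeReward  # assumed import path

def test_binerized_reward_sketch():
    br = BinarizeReward(keys=["reward"])
    reward = torch.tensor([[-1.0], [0.0], [2.5]])       # last dim must be singleton
    binarized = br._apply(reward)
    assert binarized.dtype is torch.long                # new dtype: long, not reward.dtype
    assert binarized.squeeze(-1).tolist() == [0, 0, 1]  # strictly positive -> 1, else 0
    with pytest.raises(RuntimeError):
        br._apply(torch.zeros(3, 2))                    # non-singleton last dim is rejected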
torchrl/envs/transforms/transforms.py

Lines changed: 15 additions & 21 deletions
@@ -27,6 +27,7 @@
     NdUnboundedContinuousTensorSpec,
     TensorSpec,
     UnboundedContinuousTensorSpec,
+    BinaryDiscreteTensorSpec,
 )
 from torchrl.data.tensordict.tensordict import _TensorDict, TensorDict
 from torchrl.envs.common import _EnvClass, make_tensordict
@@ -50,7 +51,7 @@
     "DoubleToFloat",
     "CatTensors",
     "NoopResetEnv",
-    "BinerizeReward",
+    "BinarizeReward",
     "PinMemoryTransform",
     "VecNorm",
     "gSDENoise",
@@ -576,7 +577,7 @@ def __repr__(self) -> str:
         )
 
 
-class BinerizeReward(Transform):
+class BinarizeReward(Transform):
     """
     Maps the reward to a binary value (0 or 1) if the reward is null or
     non-null, respectively.
@@ -591,19 +592,14 @@ def __init__(self, keys: Optional[Sequence[str]] = None):
         super().__init__(keys=keys)
 
     def _apply(self, reward: torch.Tensor) -> torch.Tensor:
-        return (reward != 0.0).to(reward.dtype)
+        if not reward.shape or reward.shape[-1] != 1:
+            raise RuntimeError(
+                f"Reward shape last dimension must be singleton, got reward of shape {reward.shape}"
+            )
+        return (reward > 0.0).to(torch.long)
 
     def transform_reward_spec(self, reward_spec: TensorSpec) -> TensorSpec:
-        if isinstance(reward_spec, UnboundedContinuousTensorSpec):
-            return BoundedTensorSpec(
-                0.0, 1.0, device=reward_spec.device, dtype=reward_spec.dtype
-            )
-        else:
-            raise NotImplementedError(
-                f"{self.__class__.__name__}.transform_reward_spec not "
-                f"implemented for tensor spec of type "
-                f"{type(reward_spec).__name__}"
-            )
+        return BinaryDiscreteTensorSpec(n=1, device=reward_spec.device)
 
 
 class Resize(ObservationTransform):
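With the hunks above, the renamed BinarizeReward rejects rewards whose last dimension is not a singleton, maps strictly positive rewards to 1 and everything else to 0 (as torch.long), and advertises a BinaryDiscreteTensorSpec regardless of the incoming reward spec. A rough usage sketch, not repository code; the torchrl.data import path and the default UnboundedContinuousTensorSpec constructor are assumptions:

# Rough sketch: demonstrates the new reward-spec and _apply behaviour.
import torch
from torchrl.data import BinaryDiscreteTensorSpec, UnboundedContinuousTensorSpec  # assumed import path
from torchrl.envs.transforms.transforms import BinarizeReward

br = BinarizeReward(keys=["reward"])
original_spec = UnboundedContinuousTensorSpec()            # assumed default constructor
binary_spec = br.transform_reward_spec(original_spec)
assert isinstance(binary_spec, BinaryDiscreteTensorSpec)   # always binary now, n=1

reward = torch.tensor([[0.3], [-0.7], [0.0]])
print(br._apply(reward).squeeze(-1).tolist())              # -> [1, 0, 0]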
@@ -860,14 +856,12 @@ def reset(self, tensordict: _TensorDict) -> _TensorDict:
 
     def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec:
         if isinstance(observation_spec, CompositeSpec):
-            return CompositeSpec(
-                **{
-                    key: self.transform_observation_spec(_obs_spec)
-                    if key in self.keys
-                    else _obs_spec
-                    for key, _obs_spec in observation_spec._specs.items()
-                }
-            )
+            keys = [key for key in observation_spec.keys() if key in self.keys]
+            for key in keys:
+                observation_spec[key] = self.transform_observation_spec(
+                    observation_spec[key]
+                )
+            return observation_spec
         else:
             _observation_spec = observation_spec
             space = _observation_spec.space

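In the CatFrames hunk above, transform_observation_spec now mutates the CompositeSpec in place instead of rebuilding it from a dict comprehension, so entries outside self.keys are carried over untouched and the same spec object is returned. A toy illustration of that pattern; the key names, leaf specs, and stand-in per-key update are invented for the example, and the import path and default constructors are assumptions:

# Hypothetical illustration of the in-place CompositeSpec update; not repository code.
from torchrl.data import CompositeSpec, UnboundedContinuousTensorSpec  # assumed import path

def update_leaf(spec):
    # stand-in for the per-key spec rewrite CatFrames performs
    return spec

observation_spec = CompositeSpec(
    next_pixels=UnboundedContinuousTensorSpec(),   # assumed default constructors
    next_reward=UnboundedContinuousTensorSpec(),
)
selected = ["next_pixels"]                         # plays the role of self.keys

keys = [key for key in observation_spec.keys() if key in selected]
for key in keys:
    observation_spec[key] = update_leaf(observation_spec[key])
# the untouched "next_reward" entry and the spec object itself are preserved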