Describe the bug
Not sure this is a bug, but I am unable to use the `FlattenObservation` transform when the observation spec is `CompositeSpec(observation=OneHotDiscreteTensorSpec(...))`.
To Reproduce
Create this env:
```python
__all__ = ["_CustomEnv"]

from loguru import logger
import typing as ty

import torch
from tensordict import TensorDict
from torchrl.data import (
    CompositeSpec,
    UnboundedContinuousTensorSpec,
    BinaryDiscreteTensorSpec,
    OneHotDiscreteTensorSpec,
)
from torchrl.envs import EnvBase

WORST_REWARD = -1e6


class _CustomEnv(EnvBase):
    """Custom dummy environment."""

    def __init__(self, **kwargs: ty.Any) -> None:
        super().__init__(**kwargs)  # call the constructor of the base class
        # Action is a one-hot tensor
        self.action_spec = OneHotDiscreteTensorSpec(
            n=10,
            shape=(10,),
            device=self.device,
            dtype=torch.float32,
        )
        # Observation space
        observation_spec = OneHotDiscreteTensorSpec(
            n=13,
            shape=(8, 8, 13),
            device=self.device,
            dtype=torch.float32,
        )
        self.observation_spec = CompositeSpec(observation=observation_spec)
        # Unbounded reward space
        self.reward_spec = UnboundedContinuousTensorSpec(
            shape=torch.Size([1]),
            device=self.device,
            dtype=torch.float32,
        )
        # Done
        self.done_spec = BinaryDiscreteTensorSpec(
            n=1,
            shape=torch.Size([1]),
            device=self.device,
            dtype=torch.bool,
        )
        logger.debug(f"action_spec: {self.action_spec}")
        logger.debug(f"observation_spec: {self.observation_spec}")
        logger.debug(f"reward_spec: {self.reward_spec}")

    def _reset(self, tensordict: TensorDict = None, **kwargs: ty.Any) -> TensorDict:
        """The `_reset()` method potentially takes in a `TensorDict` and some kwargs, which may contain data used in resetting the environment, and returns a new `TensorDict` with an initial observation of the environment.

        The output `TensorDict` has to be new because the input tensordict is immutable.

        Args:
            tensordict (TensorDict):
                Immutable input.

        Returns:
            TensorDict:
                Initial state.
        """
        logger.debug("Resetting environment.")
        # Return a new TensorDict
        return TensorDict(
            {
                "observation": torch.zeros(
                    (8, 8, 13), dtype=self.observation_spec.dtype, device=self.device
                ),
                "reward": torch.Tensor([0]).to(self.reward_spec.dtype).to(self.device),
                "done": False,
            },
            batch_size=torch.Size(),
            device=self.device,
        )

    def _step(self, tensordict: TensorDict) -> TensorDict:
        """The `_step()` method takes in a `TensorDict` from which it reads an action, applies the action, and returns a new `TensorDict` containing the observation, reward and done signal for that timestep.

        Args:
            tensordict (TensorDict):
                Input tensordict containing the action to apply.

        Returns:
            TensorDict:
                The next observation, reward and done signal.
        """
        # Return a new TensorDict
        td = TensorDict(
            {
                "observation": torch.zeros(
                    (8, 8, 13), dtype=self.observation_spec.dtype, device=self.device
                ),
                "reward": torch.Tensor([0]).to(self.reward_spec.dtype).to(self.device),
                "done": True,
            },
            batch_size=torch.Size(),
            device=self.device,
        )
        logger.trace(f"Returning new TensorDict: {td}")
        return td

    def _set_seed(self, seed: ty.Optional[int]) -> None:
        # Not in the original snippet: `EnvBase` declares `_set_seed` as abstract,
        # so a stub is needed for the class to be instantiable. This dummy env has
        # no RNG state to seed.
        pass
```
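For reference, the untransformed environment can be instantiated and its spec inspected without any error (assuming the `_set_seed` stub above); the failure only appears once the transform below is applied:

```python
base_env = _CustomEnv()
# Prints the CompositeSpec holding the (8, 8, 13) one-hot observation spec defined above.
print(base_env.observation_spec)
```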
Then, transform it:
```python
from torchrl.envs import Compose, FlattenObservation, StepCounter, TransformedEnv

env = TransformedEnv(
    _CustomEnv(),
    transform=Compose(
        FlattenObservation(
            first_dim=0,
            last_dim=-1,
            # the original snippet read `self.in_keys` from the surrounding class;
            # the only observation key here is "observation"
            in_keys=["observation"],
            allow_positive_dim=True,
        ),
        StepCounter(),
    ),
)
```
Then use it anywhere; even just printing the observation spec will raise an error:
logger.debug(f"observation_spec: {env.observation_spec}")
Will raise:
logger.debug(f"observation_spec: {env.observation_spec}")
../../../.pyenv/versions/3.10.10/envs/chess-rl/lib/python3.10/site-packages/torchrl/envs/common.py:1239: in observation_spec
observation_spec = self.output_spec["full_observation_spec"]
../../../.pyenv/versions/3.10.10/envs/chess-rl/lib/python3.10/site-packages/torchrl/envs/transforms/transforms.py:725: in output_spec
output_spec = self.transform.transform_output_spec(output_spec)
../../../.pyenv/versions/3.10.10/envs/chess-rl/lib/python3.10/site-packages/torchrl/envs/transforms/transforms.py:1025: in transform_output_spec
output_spec = t.transform_output_spec(output_spec)
../../../.pyenv/versions/3.10.10/envs/chess-rl/lib/python3.10/site-packages/torchrl/envs/transforms/transforms.py:5274: in transform_output_spec
return super().transform_output_spec(output_spec)
../../../.pyenv/versions/3.10.10/envs/chess-rl/lib/python3.10/site-packages/torchrl/envs/transforms/transforms.py:380: in transform_output_spec
output_spec = output_spec.clone()
../../../.pyenv/versions/3.10.10/envs/chess-rl/lib/python3.10/site-packages/torchrl/data/tensor_specs.py:3720: in clone
{
../../../.pyenv/versions/3.10.10/envs/chess-rl/lib/python3.10/site-packages/torchrl/data/tensor_specs.py:3721: in <dictcomp>
key: item.clone() if item is not None else None
../../../.pyenv/versions/3.10.10/envs/chess-rl/lib/python3.10/site-packages/torchrl/data/tensor_specs.py:3720: in clone
{
../../../.pyenv/versions/3.10.10/envs/chess-rl/lib/python3.10/site-packages/torchrl/data/tensor_specs.py:3721: in <dictcomp>
key: item.clone() if item is not None else None
../../../.pyenv/versions/3.10.10/envs/chess-rl/lib/python3.10/site-packages/torchrl/data/tensor_specs.py:1250: in clone
return self.__class__(
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <[AttributeError("'OneHotDiscreteTensorSpec' object has no attribute 'shape'") raised in repr()] OneHotDiscreteTensorSpec object at 0x2867d5000>
n = 13, shape = torch.Size([832]), device = device(type='cpu'), dtype = torch.float32, use_register = False, mask = None
def __init__(
self,
n: int,
shape: Optional[torch.Size] = None,
device: Optional[DEVICE_TYPING] = None,
dtype: Optional[Union[str, torch.dtype]] = torch.bool,
use_register: bool = False,
mask: torch.Tensor | None = None,
):
dtype, device = _default_dtype_and_device(dtype, device)
self.use_register = use_register
space = DiscreteBox(n)
if shape is None:
shape = torch.Size((space.n,))
else:
shape = torch.Size(shape)
if not len(shape) or shape[-1] != space.n:
> raise ValueError(
f"The last value of the shape must match n for transform of type {self.__class__}. "
f"Got n={space.n} and shape={shape}."
)
E ValueError: The last value of the shape must match n for transform of type <class 'torchrl.data.tensor_specs.OneHotDiscreteTensorSpec'>. Got n=13 and shape=torch.Size([832]).
../../../.pyenv/versions/3.10.10/envs/chess-rl/lib/python3.10/site-packages/torchrl/data/tensor_specs.py:1206: ValueError
Calling `env.reset()` will raise the same error, too.
Expected behavior
TorchRL should not complain that the observation has the wrong size. The flattened size is intentionally "wrong" for a one-hot spec: flattening the `(8, 8, 13)` observation gives `8 * 8 * 13 = 832` elements, so the last dimension no longer equals `n=13`; that is exactly what the `FlattenObservation` transform is supposed to produce.
System info
Describe the characteristics of your environment:
- Describe how the library was installed (pip, source, ...)
- Python version
- Versions of any other relevant libraries
```python
import torchrl, numpy, sys
print(torchrl.__version__, numpy.__version__, sys.version, sys.platform)
# > 0.3.0 1.26.4 3.10.10 (main, Sep 14 2023, 16:59:47) [Clang 14.0.3 (clang-1403.0.22.14.1)] darwin
```
Reason and Possible fixes
Not sure, but I'd be happy to work on one.
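One direction I can imagine (untested, and only a sketch): instead of reshaping the `OneHotDiscreteTensorSpec` to the flattened shape, the transform could advertise the flattened key with a spec type that does not enforce `shape[-1] == n`. A user-side version of that idea, as a hypothetical subclass of `FlattenObservation`, might look like this (the name `FlattenOneHotObservation` and the choice of `UnboundedContinuousTensorSpec` are my own, not existing API):

```python
import torch
from torchrl.data import UnboundedContinuousTensorSpec
from torchrl.envs import FlattenObservation


class FlattenOneHotObservation(FlattenObservation):
    """Hypothetical workaround: report the flattened keys with a spec that
    allows arbitrary shapes instead of a reshaped OneHotDiscreteTensorSpec."""

    def transform_observation_spec(self, observation_spec):
        # Assumption: `observation_spec` is the CompositeSpec holding the keys
        # listed in `self.in_keys`/`self.out_keys`.
        for in_key, out_key in zip(self.in_keys, self.out_keys):
            spec = observation_spec[in_key]
            observation_spec[out_key] = UnboundedContinuousTensorSpec(
                # With first_dim=0 and last_dim=-1 the whole observation is
                # flattened, e.g. (8, 8, 13) -> (832,).
                shape=torch.Size([spec.shape.numel()]),
                device=spec.device,
                dtype=spec.dtype,
            )
        return observation_spec
```

Swapping this subclass into the `TransformedEnv` call above in place of `FlattenObservation` would, I believe, keep `env.observation_spec` consistent with the flattened observations, but I have not verified it against whatever in-place modification seems to corrupt the spec in the traceback above.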
Checklist
- [x] I have checked that there is no similar issue in the repo (required)
- [x] I have read the documentation (required)
- [x] I have provided a minimal working example to reproduce the bug (required)