
[BUG] FlattenObservation transform with OneHotDiscreteTensorSpec #1904

Closed as not planned
@svnv-svsv-jm

Description


Describe the bug

I'm not sure whether this is a bug, but I am unable to use the FlattenObservation transform when the observation spec is CompositeSpec(observation=OneHotDiscreteTensorSpec(...)).

To Reproduce

Create this env:

__all__ = ["_CustomEnv"]

from loguru import logger
import typing as ty

import torch

from tensordict import TensorDict
from torchrl.data import (
    CompositeSpec,
    UnboundedContinuousTensorSpec,
    BinaryDiscreteTensorSpec,
    OneHotDiscreteTensorSpec,
)
from torchrl.envs import EnvBase


WORST_REWARD = -1e6


class _CustomEnv(EnvBase):
    """Custom dummy environment."""

    def __init__(
        self,
        **kwargs: ty.Any,
    ) -> None:
        super().__init__(**kwargs)  # call the constructor of the base class
        # Action is a one-hot tensor
        self.action_spec = OneHotDiscreteTensorSpec(
            n=10,
            shape=(10,),
            device=self.device,
            dtype=torch.float32,
        )
        # Observation space
        observation_spec = OneHotDiscreteTensorSpec(
            n=13,
            shape=(8, 8, 13),
            device=self.device,
            dtype=torch.float32,
        )
        self.observation_spec = CompositeSpec(observation=observation_spec)
        # Unlimited reward space
        self.reward_spec = UnboundedContinuousTensorSpec(
            shape=torch.Size([1]),
            device=self.device,
            dtype=torch.float32,
        )
        # Done
        self.done_spec = BinaryDiscreteTensorSpec(
            n=1,
            shape=torch.Size([1]),
            device=self.device,
            dtype=torch.bool,
        )
        logger.debug(f"action_spec: {self.action_spec}")
        logger.debug(f"observation_spec: {self.observation_spec}")
        logger.debug(f"reward_spec: {self.reward_spec}")

    def _reset(self, tensordict: TensorDict = None, **kwargs: ty.Any) -> TensorDict:
        """The `_reset()` method potentialy takes in a `TensorDict` and some kwargs which may contain data used in the resetting of the environment and returns a new `TensorDict` with an initial observation of the environment.

        The output `TensorDict` has to be new because the input tensordict is immutable.

        Args:
            tensordict (TensorDict):
                Immutable input.

        Returns:
            TensorDict:
                Initial state.
        """
        logger.debug("Resetting environment.")
        # Return new TensorDict
        return TensorDict(
            {
                "observation": torch.zeros(
                    (8, 8, 13), dtype=self.observation_spec.dtype, device=self.device
                ),
                "reward": torch.Tensor([0]).to(self.reward_spec.dtype).to(self.device),
                "done": False,
            },
            batch_size=torch.Size(),
            device=self.device,
        )

    def _step(self, tensordict: TensorDict) -> TensorDict:
        """The `_step()` method takes in a `TensorDict` from which it reads an action, applies the action and returns a new `TensorDict` containing the observation, reward and done signal for that timestep.

        Args:
            tensordict (TensorDict):
                Input tensordict containing the action to apply.

        Returns:
            TensorDict:
                New tensordict with the observation, reward and done signal for this step.
        """
        # Return new TensorDict
        td = TensorDict(
            {
                "observation": torch.zeros(
                    (8, 8, 13), dtype=self.observation_spec.dtype, device=self.device
                ),
                "reward": torch.Tensor([0]).to(self.reward_spec.dtype).to(self.device),
                "done": True,
            },
            batch_size=torch.Size(),
            device=self.device,
        )
        logger.trace(f"Returning new TensorDict: {td}")
        return td
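
As a quick sanity check (my assumption: the failure should only appear once a transform rewrites the spec), the untransformed environment's spec can be inspected without error:

base_env = _CustomEnv()
# The raw spec keeps the valid one-hot layout (last dim == n == 13),
# so printing it does not go through the failing clone/rebuild path.
print(base_env.observation_spec)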

Then, transform it:

from torchrl.envs import Compose, FlattenObservation, StepCounter, TransformedEnv

env = TransformedEnv(
    _CustomEnv(),
    transform=Compose(
        FlattenObservation(
            first_dim=0,
            last_dim=-1,
            in_keys=["observation"],
            allow_positive_dim=True,
        ),
        StepCounter(),
    ),
)

Then use it anywhere. Even just printing the observation spec will raise an error:

logger.debug(f"observation_spec: {env.observation_spec}")

Will raise:

logger.debug(f"observation_spec: {env.observation_spec}")
../../../.pyenv/versions/3.10.10/envs/chess-rl/lib/python3.10/site-packages/torchrl/envs/common.py:1239: in observation_spec
    observation_spec = self.output_spec["full_observation_spec"]
../../../.pyenv/versions/3.10.10/envs/chess-rl/lib/python3.10/site-packages/torchrl/envs/transforms/transforms.py:725: in output_spec
    output_spec = self.transform.transform_output_spec(output_spec)
../../../.pyenv/versions/3.10.10/envs/chess-rl/lib/python3.10/site-packages/torchrl/envs/transforms/transforms.py:1025: in transform_output_spec
    output_spec = t.transform_output_spec(output_spec)
../../../.pyenv/versions/3.10.10/envs/chess-rl/lib/python3.10/site-packages/torchrl/envs/transforms/transforms.py:5274: in transform_output_spec
    return super().transform_output_spec(output_spec)
../../../.pyenv/versions/3.10.10/envs/chess-rl/lib/python3.10/site-packages/torchrl/envs/transforms/transforms.py:380: in transform_output_spec
    output_spec = output_spec.clone()
../../../.pyenv/versions/3.10.10/envs/chess-rl/lib/python3.10/site-packages/torchrl/data/tensor_specs.py:3720: in clone
    {
../../../.pyenv/versions/3.10.10/envs/chess-rl/lib/python3.10/site-packages/torchrl/data/tensor_specs.py:3721: in <dictcomp>
    key: item.clone() if item is not None else None
../../../.pyenv/versions/3.10.10/envs/chess-rl/lib/python3.10/site-packages/torchrl/data/tensor_specs.py:3720: in clone
    {
../../../.pyenv/versions/3.10.10/envs/chess-rl/lib/python3.10/site-packages/torchrl/data/tensor_specs.py:3721: in <dictcomp>
    key: item.clone() if item is not None else None
../../../.pyenv/versions/3.10.10/envs/chess-rl/lib/python3.10/site-packages/torchrl/data/tensor_specs.py:1250: in clone
    return self.__class__(
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = <[AttributeError("'OneHotDiscreteTensorSpec' object has no attribute 'shape'") raised in repr()] OneHotDiscreteTensorSpec object at 0x2867d5000>
n = 13, shape = torch.Size([832]), device = device(type='cpu'), dtype = torch.float32, use_register = False, mask = None

    def __init__(
        self,
        n: int,
        shape: Optional[torch.Size] = None,
        device: Optional[DEVICE_TYPING] = None,
        dtype: Optional[Union[str, torch.dtype]] = torch.bool,
        use_register: bool = False,
        mask: torch.Tensor | None = None,
    ):
        dtype, device = _default_dtype_and_device(dtype, device)
        self.use_register = use_register
        space = DiscreteBox(n)
        if shape is None:
            shape = torch.Size((space.n,))
        else:
            shape = torch.Size(shape)
            if not len(shape) or shape[-1] != space.n:
>               raise ValueError(
                    f"The last value of the shape must match n for transform of type {self.__class__}. "
                    f"Got n={space.n} and shape={shape}."
                )
E               ValueError: The last value of the shape must match n for transform of type <class 'torchrl.data.tensor_specs.OneHotDiscreteTensorSpec'>. Got n=13 and shape=torch.Size([832]).

../../../.pyenv/versions/3.10.10/envs/chess-rl/lib/python3.10/site-packages/torchrl/data/tensor_specs.py:1206: ValueError

Calling env.reset() will raise the same error, too.
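
For what it's worth, torchrl's spec checker should fail the same way, since it also reads env.observation_spec internally (a sketch, assuming the check_env_specs helper from torchrl.envs.utils in torchrl 0.3):

from torchrl.envs.utils import check_env_specs

# Accessing the transformed env's specs goes through transform_output_spec(),
# so this presumably raises the same ValueError as above.
check_env_specs(env)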

Expected behavior

TorchRL should not complain that the observation has the wrong size: we want it to have the "wrong" size, because the whole point of the transform is to flatten it.
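
In the meantime, a possible workaround (an untested sketch, using the same imports as above) is to leave the trailing one-hot dimension alone and flatten only the leading dims, so the rebuilt spec still satisfies the last-dim == n constraint:

env = TransformedEnv(
    _CustomEnv(),
    transform=Compose(
        FlattenObservation(
            first_dim=0,
            last_dim=-2,  # flatten (8, 8, 13) -> (64, 13); the one-hot dim stays last
            in_keys=["observation"],
            allow_positive_dim=True,
        ),
        StepCounter(),
    ),
)

This does not produce the fully flat (832,) vector I actually want, but it should at least keep the spec constructible.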

System info

import torchrl, numpy, sys
print(torchrl.__version__, numpy.__version__, sys.version, sys.platform)
# > 0.3.0 1.26.4 3.10.10 (main, Sep 14 2023, 16:59:47) [Clang 14.0.3 (clang-1403.0.22.14.1)] darwin

Reason and Possible fixes

Not sure of the root cause, but I'd be happy to work on a fix.
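
For what it's worth, the constraint in the traceback can be reproduced at the spec level directly (a minimal sketch; n and the shapes are taken from the repro above):

import torch
from torchrl.data import OneHotDiscreteTensorSpec

# The original spec is valid: the last dim equals n.
OneHotDiscreteTensorSpec(n=13, shape=(8, 8, 13), dtype=torch.float32)

# Rebuilding it with the fully flattened shape (8 * 8 * 13 = 832) trips the
# "last value of the shape must match n" check in __init__, which is exactly
# the ValueError reported above.
OneHotDiscreteTensorSpec(n=13, shape=(832,), dtype=torch.float32)

So a fix would presumably have to teach FlattenObservation (or the one-hot spec itself) what to do when the trailing one-hot dimension gets folded into the flattened shape, e.g. by switching to a spec type without that constraint; that is just a guess on my part.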

Checklist

  • [x] I have checked that there is no similar issue in the repo (required)
  • [x] I have read the documentation (required)
  • [x] I have provided a minimal working example to reproduce the bug (required)

Labels

Good first issue (A good way to start hacking torchrl!), bug (Something isn't working)
