Skip to content

[Refactor] Make TensorSpec a real class and TensorSpecBase a base class #1222

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion docs/source/reference/data.rst
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ Here's an example:
TensorSpec
----------

The `TensorSpec` parent class and subclasses define the basic properties of observations and actions in TorchRL, such
The `TensorSpecBase` parent class and subclasses define the basic properties of observations and actions in TorchRL, such
as shape, device, dtype and domain.
It is important that your environment specs match the input and output that it sends and receives, as
:obj:`ParallelEnv` will create buffers from these specs to communicate with the spawn processes.
Expand All @@ -203,6 +203,7 @@ Check the :obj:`torchrl.envs.utils.check_env_specs` method for a sanity check.
:toctree: generated/
:template: rl_template.rst

TensorSpecBase
TensorSpec
BinaryDiscreteTensorSpec
BoundedTensorSpec
Expand Down
6 changes: 3 additions & 3 deletions docs/source/reference/envs.rst
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,11 @@ Each env will have the following attributes:
- :obj:`env.state_spec`: a :class:`~torchrl.data.CompositeSpec` object
containing all the input key-spec pairs (except action). For most stateful
environments, this container will be empty.
- :obj:`env.action_spec`: a :class:`~torchrl.data.TensorSpec` object
- :obj:`env.action_spec`: a :class:`~torchrl.data.TensorSpecBase` object
representing the action spec.
- :obj:`env.reward_spec`: a :class:`~torchrl.data.TensorSpec` object representing
- :obj:`env.reward_spec`: a :class:`~torchrl.data.TensorSpecBase` object representing
the reward spec.
- :obj:`env.done_spec`: a :class:`~torchrl.data.TensorSpec` object representing
- :obj:`env.done_spec`: a :class:`~torchrl.data.TensorSpecBase` object representing
the done-flag spec.
- :obj:`env.input_spec`: a :class:`~torchrl.data.CompositeSpec` object containing
all the input keys (:obj:`"_action_spec"` and :obj:`"_state_spec"`).
Expand Down
2 changes: 1 addition & 1 deletion test/smoke_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ def test_imports():
from torchrl.data import (
PrioritizedReplayBuffer,
ReplayBuffer,
TensorSpec,
TensorSpecBase,
) # noqa: F401
from torchrl.envs import Transform, TransformedEnv # noqa: F401
from torchrl.envs.gym_like import GymLikeEnv # noqa: F401
Expand Down
10 changes: 5 additions & 5 deletions torchrl/collectors/collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
VERBOSE,
)
from torchrl.collectors.utils import split_trajectories
from torchrl.data.tensor_specs import TensorSpec
from torchrl.data.tensor_specs import TensorSpecBase
from torchrl.data.utils import CloudpickleWrapper, DEVICE_TYPING
from torchrl.envs.common import EnvBase
from torchrl.envs.transforms import StepCounter, TransformedEnv
Expand All @@ -62,7 +62,7 @@ class RandomPolicy:
This is a wrapper around the action_spec.rand method.

Args:
action_spec: TensorSpec object describing the action specs
action_spec: TensorSpecBase object describing the action specs

Examples:
>>> from tensordict import TensorDict
Expand All @@ -72,7 +72,7 @@ class RandomPolicy:
>>> td = actor(TensorDict(batch_size=[])) # selects a random action in the cube [-1; 1]
"""

def __init__(self, action_spec: TensorSpec):
def __init__(self, action_spec: TensorSpecBase):
self.action_spec = action_spec

def __call__(self, td: TensorDictBase) -> TensorDictBase:
Expand Down Expand Up @@ -185,7 +185,7 @@ def _get_policy_and_device(
]
] = None,
device: Optional[DEVICE_TYPING] = None,
observation_spec: TensorSpec = None,
observation_spec: TensorSpecBase = None,
) -> Tuple[TensorDictModule, torch.device, Union[None, Callable[[], dict]]]:
"""Util method to get a policy and its device given the collector __init__ inputs.

Expand All @@ -200,7 +200,7 @@ def _get_policy_and_device(
policy (TensorDictModule, optional): a policy to be used
device (int, str or torch.device, optional): device where to place
the policy
observation_spec (TensorSpec, optional): spec of the observations
observation_spec (TensorSpecBase, optional): spec of the observations

"""
if policy is None:
Expand Down
1 change: 1 addition & 0 deletions torchrl/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
MultiOneHotDiscreteTensorSpec,
OneHotDiscreteTensorSpec,
TensorSpec,
TensorSpecBase,
UnboundedContinuousTensorSpec,
UnboundedDiscreteTensorSpec,
)
59 changes: 35 additions & 24 deletions torchrl/data/tensor_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,7 +457,7 @@ def __repr__(self):


@dataclass(repr=False)
class TensorSpec:
class TensorSpecBase:
"""Parent class of the tensor meta-data containers for observation, actions and rewards.

Args:
Expand Down Expand Up @@ -703,11 +703,11 @@ def zero(self, shape=None) -> torch.Tensor:
return torch.zeros((*shape, *self.shape), dtype=self.dtype, device=self.device)

@abc.abstractmethod
def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> "TensorSpec":
def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> "TensorSpecBase":
raise NotImplementedError

@abc.abstractmethod
def clone(self) -> "TensorSpec":
def clone(self) -> "TensorSpecBase":
raise NotImplementedError

def __repr__(self):
Expand All @@ -733,7 +733,7 @@ def __torch_function__(
if kwargs is None:
kwargs = {}
if func not in cls.SPEC_HANDLED_FUNCTIONS or not all(
issubclass(t, (TensorSpec,)) for t in types
issubclass(t, (TensorSpecBase,)) for t in types
):
return NotImplemented(
f"func {func} for spec {cls} with handles {cls.SPEC_HANDLED_FUNCTIONS}"
Expand Down Expand Up @@ -795,7 +795,7 @@ def __getitem__(self, item):
f"Indexing occurred along dimension {dim_idx} but stacking was done along dim {self.dim}."
)
out = self._specs[item]
if isinstance(out, TensorSpec):
if isinstance(out, TensorSpecBase):
return out
return torch.stack(list(out), 0)
else:
Expand All @@ -817,7 +817,7 @@ def __getitem__(self, item):
for i, _item in enumerate(item):
if i == self.dim:
out = self._specs[_item]
if isinstance(out, TensorSpec):
if isinstance(out, TensorSpecBase):
return out
return torch.stack(list(out), 0)
elif isinstance(_item, slice):
Expand All @@ -834,7 +834,7 @@ def __getitem__(self, item):
f"Trying to index a {self.__class__.__name__} along dimension 0 when the stack dimension is {self.dim}."
)
out = self._specs[item]
if isinstance(out, TensorSpec):
if isinstance(out, TensorSpecBase):
return out
return torch.stack(list(out), 0)

Expand Down Expand Up @@ -897,7 +897,7 @@ def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> T:
return torch.stack([spec.to(dest) for spec in self._specs], self.dim)


class LazyStackedTensorSpec(_LazyStackedMixin[TensorSpec], TensorSpec):
class LazyStackedTensorSpec(_LazyStackedMixin[TensorSpecBase], TensorSpecBase):
"""A lazy representation of a stack of tensor specs.

Stacks tensor-specs together along one dimension.
Expand Down Expand Up @@ -980,7 +980,7 @@ def set(self, name, spec):


@dataclass(repr=False)
class OneHotDiscreteTensorSpec(TensorSpec):
class OneHotDiscreteTensorSpec(TensorSpecBase):
"""A unidimensional, one-hot discrete tensor spec.

By default, TorchRL assumes that categorical variables are encoded as
Expand Down Expand Up @@ -1247,7 +1247,7 @@ def to_categorical_spec(self) -> DiscreteTensorSpec:


@dataclass(repr=False)
class BoundedTensorSpec(TensorSpec):
class BoundedTensorSpec(TensorSpecBase):
"""A bounded continuous tensor spec.

Args:
Expand Down Expand Up @@ -1508,7 +1508,7 @@ def _is_nested_list(index, notuple=False):


@dataclass(repr=False)
class UnboundedContinuousTensorSpec(TensorSpec):
class UnboundedContinuousTensorSpec(TensorSpecBase):
"""An unbounded continuous tensor spec.

Args:
Expand Down Expand Up @@ -1585,8 +1585,15 @@ def __getitem__(self, idx: SHAPE_INDEX_TYPING):
return self.__class__(shape=indexed_shape, device=self.device, dtype=self.dtype)


TensorSpec = type(
"TensorSpec",
UnboundedContinuousTensorSpec.__bases__,
dict(UnboundedContinuousTensorSpec.__dict__),
)


@dataclass(repr=False)
class UnboundedDiscreteTensorSpec(TensorSpec):
class UnboundedDiscreteTensorSpec(TensorSpecBase):
"""An unbounded discrete tensor spec.

Args:
Expand Down Expand Up @@ -1919,7 +1926,7 @@ def __getitem__(self, idx: SHAPE_INDEX_TYPING):
)


class DiscreteTensorSpec(TensorSpec):
class DiscreteTensorSpec(TensorSpecBase):
"""A discrete tensor spec.

An alternative to OneHotTensorSpec for categorical variables in TorchRL. Instead of
Expand Down Expand Up @@ -2415,7 +2422,7 @@ def __getitem__(self, idx: SHAPE_INDEX_TYPING):
)


class CompositeSpec(TensorSpec):
class CompositeSpec(TensorSpecBase):
"""A composition of TensorSpecs.

Args:
Expand Down Expand Up @@ -2928,13 +2935,15 @@ def __eq__(self, other):
and self._specs == other._specs
)

def update(self, dict_or_spec: Union[CompositeSpec, Dict[str, TensorSpec]]) -> None:
def update(
self, dict_or_spec: Union[CompositeSpec, Dict[str, TensorSpecBase]]
) -> None:
for key, item in dict_or_spec.items():
if key in self.keys(True) and isinstance(self[key], CompositeSpec):
self[key].update(item)
continue
try:
if isinstance(item, TensorSpec) and item.device != self.device:
if isinstance(item, TensorSpecBase) and item.device != self.device:
item = deepcopy(item)
if self.device is not None:
item = item.to(self.device)
Expand Down Expand Up @@ -3108,7 +3117,9 @@ class LazyStackedCompositeSpec(_LazyStackedMixin[CompositeSpec], CompositeSpec):

"""

def update(self, dict_or_spec: Union[CompositeSpec, Dict[str, TensorSpec]]) -> None:
def update(
self, dict_or_spec: Union[CompositeSpec, Dict[str, TensorSpecBase]]
) -> None:
pass

def __eq__(self, other):
Expand Down Expand Up @@ -3209,7 +3220,7 @@ def set(self, name, spec):


# for SPEC_CLASS in [BinaryDiscreteTensorSpec, BoundedTensorSpec, DiscreteTensorSpec, MultiDiscreteTensorSpec, MultiOneHotDiscreteTensorSpec, OneHotDiscreteTensorSpec, UnboundedContinuousTensorSpec, UnboundedDiscreteTensorSpec]:
@TensorSpec.implements_for_spec(torch.stack)
@TensorSpecBase.implements_for_spec(torch.stack)
def _stack_specs(list_of_spec, dim, out=None):
if out is not None:
raise NotImplementedError(
Expand All @@ -3219,11 +3230,11 @@ def _stack_specs(list_of_spec, dim, out=None):
if not len(list_of_spec):
raise ValueError("Cannot stack an empty list of specs.")
spec0 = list_of_spec[0]
if isinstance(spec0, TensorSpec):
if isinstance(spec0, TensorSpecBase):
device = spec0.device
all_equal = True
for spec in list_of_spec[1:]:
if not isinstance(spec, TensorSpec):
if not isinstance(spec, TensorSpecBase):
raise RuntimeError(
"Stacking specs cannot occur: Found more than one type of specs in the list."
)
Expand Down Expand Up @@ -3274,8 +3285,8 @@ def _stack_composite_specs(list_of_spec, dim, out=None):
raise NotImplementedError


@TensorSpec.implements_for_spec(torch.squeeze)
def _squeeze_spec(spec: TensorSpec, *args, **kwargs) -> TensorSpec:
@TensorSpecBase.implements_for_spec(torch.squeeze)
def _squeeze_spec(spec: TensorSpecBase, *args, **kwargs) -> TensorSpecBase:
return spec.squeeze(*args, **kwargs)


Expand All @@ -3284,8 +3295,8 @@ def _squeeze_composite_spec(spec: CompositeSpec, *args, **kwargs) -> CompositeSp
return spec.squeeze(*args, **kwargs)


@TensorSpec.implements_for_spec(torch.unsqueeze)
def _unsqueeze_spec(spec: TensorSpec, *args, **kwargs) -> TensorSpec:
@TensorSpecBase.implements_for_spec(torch.unsqueeze)
def _unsqueeze_spec(spec: TensorSpecBase, *args, **kwargs) -> TensorSpecBase:
return spec.unsqueeze(*args, **kwargs)


Expand Down
Loading