diff --git a/stable_baselines3/common/recurrent/buffers.py b/stable_baselines3/common/recurrent/buffers.py
index 79824b16a..5140497e8 100644
--- a/stable_baselines3/common/recurrent/buffers.py
+++ b/stable_baselines3/common/recurrent/buffers.py
@@ -114,32 +114,7 @@ def space_to_example(
     device: Optional[th.device] = None,
     ensure_non_batch_dim: bool = False,
 ) -> TensorTree:
-    if isinstance(space, spaces.Dict):
-        return {
-            k: space_to_example(batch_shape, v, device=device, ensure_non_batch_dim=ensure_non_batch_dim)
-            for k, v in space.items()
-        }
-    if isinstance(space, spaces.Tuple):
-        return tuple(space_to_example(batch_shape, v, device=device, ensure_non_batch_dim=ensure_non_batch_dim) for v in space)
-
-    if isinstance(space, spaces.Box):
-        space_shape = space.shape
-        space_dtype = th.float32
-    elif isinstance(space, spaces.Discrete):
-        space_shape = ()
-        space_dtype = th.long
-    elif isinstance(space, spaces.MultiDiscrete):
-        space_shape = (len(space.nvec),)
-        space_dtype = th.long
-    elif isinstance(space, spaces.MultiBinary):
-        space_shape = space.n if isinstance(space.n, tuple) else (space.n,)
-        space_dtype = th.float32
-    else:
-        raise TypeError(f"Unknown space type {type(space)} for {space}")
-
-    if ensure_non_batch_dim and not space_shape:
-        space_shape = (1,)
-    return th.zeros((*batch_shape, *space_shape), dtype=space_dtype, device=device)
+    return tree_map(lambda x: th.as_tensor(x).expand((*batch_shape, *x.shape)), space.sample())


 class RecurrentRolloutBuffer(RolloutBuffer):