Skip to content

Commit

Permalink
Attempt to greatly simplify space_example
Browse files Browse the repository at this point in the history
  • Loading branch information
rhaps0dy committed Oct 11, 2023
1 parent 1978bc6 commit 61cae13
Showing 1 changed file with 1 addition and 26 deletions.
27 changes: 1 addition & 26 deletions stable_baselines3/common/recurrent/buffers.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,32 +114,7 @@ def space_to_example(
device: Optional[th.device] = None,
ensure_non_batch_dim: bool = False,
) -> TensorTree:
if isinstance(space, spaces.Dict):
return {
k: space_to_example(batch_shape, v, device=device, ensure_non_batch_dim=ensure_non_batch_dim)
for k, v in space.items()
}
if isinstance(space, spaces.Tuple):
return tuple(space_to_example(batch_shape, v, device=device, ensure_non_batch_dim=ensure_non_batch_dim) for v in space)

if isinstance(space, spaces.Box):
space_shape = space.shape
space_dtype = th.float32
elif isinstance(space, spaces.Discrete):
space_shape = ()
space_dtype = th.long
elif isinstance(space, spaces.MultiDiscrete):
space_shape = (len(space.nvec),)
space_dtype = th.long
elif isinstance(space, spaces.MultiBinary):
space_shape = space.n if isinstance(space.n, tuple) else (space.n,)
space_dtype = th.float32
else:
raise TypeError(f"Unknown space type {type(space)} for {space}")

if ensure_non_batch_dim and not space_shape:
space_shape = (1,)
return th.zeros((*batch_shape, *space_shape), dtype=space_dtype, device=device)
return tree_map(lambda x: th.as_tensor(x).expand((*batch_shape, *x.shape)), space.sample())


class RecurrentRolloutBuffer(RolloutBuffer):
Expand Down

0 comments on commit 61cae13

Please sign in to comment.