
Base recurrent policy (re-submission) (#13)
Re-submission of #8 with all dependent PRs.
rhaps0dy authored Oct 13, 2023
2 parents c0ac130 + c47c736 commit bfbfc3b
Showing 8 changed files with 791 additions and 276 deletions.
702 changes: 448 additions & 254 deletions stable_baselines3/common/recurrent/policies.py

Large diffs are not rendered by default.

211 changes: 211 additions & 0 deletions stable_baselines3/common/recurrent/torch_layers.py
@@ -0,0 +1,211 @@
import abc
from typing import Any, Dict, Generic, Optional, Tuple, TypeVar

import gymnasium as gym
import torch as th

from stable_baselines3.common.preprocessing import get_flattened_obs_dim
from stable_baselines3.common.pytree_dataclass import TensorTree, tree_flatten, tree_map
from stable_baselines3.common.recurrent.type_aliases import (
GRURecurrentState,
LSTMRecurrentState,
)
from stable_baselines3.common.torch_layers import (
BaseFeaturesExtractor,
CombinedExtractor,
FlattenExtractor,
NatureCNN,
)
from stable_baselines3.common.type_aliases import TorchGymObs

RecurrentState = TypeVar("RecurrentState", bound=TensorTree)

RecurrentSubState = TypeVar("RecurrentSubState", bound=TensorTree)

ExtractorInput = TypeVar("ExtractorInput", bound=TorchGymObs)


class RecurrentFeaturesExtractor(BaseFeaturesExtractor, abc.ABC, Generic[ExtractorInput, RecurrentState]):
@abc.abstractmethod
def recurrent_initial_state(
self, n_envs: Optional[int] = None, *, device: Optional[th.device | str] = None
) -> RecurrentState:
...

@abc.abstractmethod
def forward(
self, observations: ExtractorInput, state: RecurrentState, episode_starts: th.Tensor
) -> Tuple[th.Tensor, RecurrentState]:
...

@staticmethod
def _process_sequence(
rnn: th.nn.RNNBase, inputs: th.Tensor, init_state: RecurrentSubState, episode_starts: th.Tensor
) -> Tuple[th.Tensor, RecurrentSubState]:
(state_example, *_), _ = tree_flatten(init_state, is_leaf=None)
n_layers, batch_sz, *_ = state_example.shape
assert n_layers == rnn.num_layers

        # Batch to sequence:
        # (batch_sz * seq_len, features_dim) -> (batch_sz, seq_len, features_dim)
        # -> (seq_len, batch_sz, features_dim)
seq_len = inputs.shape[0] // batch_sz
seq_inputs = inputs.view((batch_sz, seq_len, *inputs.shape[1:])).swapaxes(0, 1)
episode_starts = episode_starts.view((batch_sz, seq_len)).swapaxes(0, 1)

if th.any(episode_starts[1:]):
raise NotImplementedError("Resetting state in the middle of a sequence is not supported")

first_state_is_not_reset = (~episode_starts[0]).contiguous()

def _reset_state_component(state: th.Tensor) -> th.Tensor:
assert state.shape == (rnn.num_layers, batch_sz, rnn.hidden_size)
reset_mask = first_state_is_not_reset.view((1, batch_sz, 1))
return state * reset_mask

init_state = tree_map(_reset_state_component, init_state)
rnn_output, end_state = rnn(seq_inputs, init_state)

# (seq_len, batch_size, ...) -> (batch_size, seq_len, ...) -> (batch_size * seq_len, ...)
rnn_output = rnn_output.transpose(0, 1).reshape((batch_sz * seq_len, *rnn_output.shape[2:]))
return rnn_output, end_state


class GRUWrappedFeaturesExtractor(RecurrentFeaturesExtractor[ExtractorInput, GRURecurrentState], Generic[ExtractorInput]):
def __init__(
self,
observation_space: gym.Space,
base_extractor: BaseFeaturesExtractor,
features_dim: Optional[int] = None,
num_layers: int = 1,
bias: bool = True,
dropout: float = 0.0,
):
        if features_dim is None:
            # Default the output size to at least 64 so the recurrent layer has
            # enough capacity to optimize well.
            features_dim = max(base_extractor.features_dim, 64)

assert observation_space == base_extractor._observation_space

super().__init__(observation_space, features_dim)
self.base_extractor = base_extractor

self.rnn = th.nn.GRU(
input_size=base_extractor.features_dim,
hidden_size=features_dim,
num_layers=num_layers,
bias=bias,
batch_first=False,
dropout=dropout,
bidirectional=False,
)

def recurrent_initial_state(
self, n_envs: Optional[int] = None, *, device: Optional[th.device | str] = None
) -> GRURecurrentState:
shape: Tuple[int, ...]
if n_envs is None:
shape = (self.rnn.num_layers, self.rnn.hidden_size)
else:
shape = (self.rnn.num_layers, n_envs, self.rnn.hidden_size)
return th.zeros(shape, device=device)

def forward(
self, observations: ExtractorInput, state: GRURecurrentState, episode_starts: th.Tensor
) -> Tuple[th.Tensor, GRURecurrentState]:
features: th.Tensor = self.base_extractor(observations)
return self._process_sequence(self.rnn, features, state, episode_starts)

@property
def features_dim(self) -> int:
return self.rnn.hidden_size


class GRUFlattenExtractor(GRUWrappedFeaturesExtractor[th.Tensor]):
def __init__(
self,
observation_space: gym.Space,
features_dim: int = 64,
num_layers: int = 1,
bias: bool = True,
dropout: float = 0.0,
) -> None:
base_extractor = FlattenExtractor(observation_space)
super().__init__(
observation_space, base_extractor, features_dim=features_dim, num_layers=num_layers, bias=bias, dropout=dropout
)


class GRUNatureCNNExtractor(GRUWrappedFeaturesExtractor[th.Tensor]):
def __init__(
self,
observation_space: gym.Space,
features_dim: int = 512,
normalized_image: bool = False,
num_layers: int = 1,
bias: bool = True,
dropout: float = 0.0,
) -> None:
base_extractor = NatureCNN(observation_space, features_dim=features_dim, normalized_image=normalized_image)
super().__init__(
observation_space, base_extractor, features_dim=features_dim, num_layers=num_layers, bias=bias, dropout=dropout
)


class GRUCombinedExtractor(GRUWrappedFeaturesExtractor[Dict[Any, th.Tensor]]):
def __init__(
self,
observation_space: gym.spaces.Dict,
features_dim: int = 64,
cnn_output_dim: int = 256,
normalized_image: bool = False,
num_layers: int = 1,
bias: bool = True,
dropout: float = 0.0,
) -> None:
base_extractor = CombinedExtractor(observation_space, cnn_output_dim=cnn_output_dim, normalized_image=normalized_image)
super().__init__(
observation_space, base_extractor, features_dim=features_dim, num_layers=num_layers, bias=bias, dropout=dropout
)


class LSTMFlattenExtractor(RecurrentFeaturesExtractor[th.Tensor, LSTMRecurrentState]):
def __init__(
self,
observation_space: gym.Space,
features_dim: int = 64,
num_layers: int = 1,
bias: bool = True,
dropout: float = 0.0,
):
super().__init__(observation_space, features_dim)

self.rnn = th.nn.LSTM(
input_size=get_flattened_obs_dim(self._observation_space),
hidden_size=features_dim,
num_layers=num_layers,
bias=bias,
batch_first=False,
dropout=dropout,
bidirectional=False,
)
self.base_extractor = FlattenExtractor(observation_space)

def recurrent_initial_state(
self, n_envs: Optional[int] = None, *, device: Optional[th.device | str] = None
) -> LSTMRecurrentState:
shape: Tuple[int, ...]
if n_envs is None:
shape = (self.rnn.num_layers, self.rnn.hidden_size)
else:
shape = (self.rnn.num_layers, n_envs, self.rnn.hidden_size)
return (th.zeros(shape, device=device), th.zeros(shape, device=device))

def forward(
self, observations: th.Tensor, state: LSTMRecurrentState, episode_starts: th.Tensor
) -> Tuple[th.Tensor, LSTMRecurrentState]:
features: th.Tensor = self.base_extractor(observations)
return self._process_sequence(self.rnn, features, state, episode_starts)

@property
def features_dim(self) -> int:
return self.rnn.hidden_size
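
The core of the new file is `_process_sequence`: it reshapes a flat batch of `batch_sz * seq_len` feature rows into `(seq_len, batch_sz, features_dim)`, zeroes the state of any environment whose episode just started, runs the RNN, and flattens the output back. A minimal usage sketch of the GRU wrapper (not part of this commit; the observation space, shapes, and values are assumptions for illustration):

import gymnasium as gym
import torch as th

obs_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(4,))
extractor = GRUFlattenExtractor(obs_space, features_dim=64)

n_envs = 8
state = extractor.recurrent_initial_state(n_envs)  # zeros: (num_layers, n_envs, hidden_size)
obs = th.zeros((n_envs, 4))                        # one step per env, so seq_len == 1
episode_starts = th.ones(n_envs, dtype=th.bool)    # True resets that env's state
features, state = extractor(obs, state, episode_starts)
# features: (n_envs * seq_len, 64); `state` carries over to the next call.
# For seq_len > 1, episode_starts may be True only at the head of each
# sequence; mid-sequence resets raise NotImplementedError.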
19 changes: 7 additions & 12 deletions stable_baselines3/common/recurrent/type_aliases.py
@@ -1,24 +1,19 @@
-from typing import Optional, Tuple, TypeVar
+from typing import Generic, Tuple, TypeVar

 import torch as th

 from stable_baselines3.common.pytree_dataclass import FrozenPyTreeDataclass, TensorTree

-T = TypeVar("T")
+TensorTreeT = TypeVar("TensorTreeT", bound=TensorTree)


-def non_null(v: Optional[T]) -> T:
-    if v is None:
-        raise ValueError("Expected a value, got None")
-    return v
+LSTMRecurrentState = Tuple[th.Tensor, th.Tensor]
+GRURecurrentState = th.Tensor


-LSTMStates = Tuple[th.Tensor, th.Tensor]
-
-
-class RNNStates(FrozenPyTreeDataclass[th.Tensor]):
-    pi: LSTMStates
-    vf: LSTMStates
+class ActorCriticStates(FrozenPyTreeDataclass[th.Tensor], Generic[TensorTreeT]):
+    pi: TensorTreeT
+    vf: TensorTreeT


 class RecurrentRolloutBufferData(FrozenPyTreeDataclass[th.Tensor]):
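
The LSTM-only `RNNStates` above gives way to a container generic over the state pytree. A sketch of how it might be parameterized (not part of this commit; the shapes and the subscripted-constructor call are assumptions):

import torch as th

# GRU states are a single tensor per network; an LSTM variant would be
# ActorCriticStates[LSTMRecurrentState] holding (hidden, cell) pairs.
states = ActorCriticStates[GRURecurrentState](
    pi=th.zeros((1, 8, 64)),  # actor state: (num_layers, n_envs, hidden_size)
    vf=th.zeros((1, 8, 64)),  # critic state
)
# Being a pytree dataclass, the whole structure can be transformed at once,
# e.g. tree_map(lambda t: t.detach(), states).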
26 changes: 26 additions & 0 deletions stable_baselines3/common/type_aliases.py
@@ -12,7 +12,10 @@
Protocol,
SupportsFloat,
Tuple,
Type,
TypeVar,
Union,
get_origin,
)

import gymnasium as gym
@@ -117,3 +120,26 @@ def device(self) -> th.device:
:return: the device on which this predictor lives
"""
...


T = TypeVar("T")


def non_null(v: Optional[T]) -> T:
"""
Checks that `v` is not None, and returns it.
"""
if v is None:
raise ValueError("Expected a value, got None")
return v


def check_cast(cls: Type[T], v: Any) -> T:
"""
Checks that `v` is of type `cls`, and returns it.
NOTE: this function does not check the template arguments, only the type itself.
"""
if not isinstance(v, get_origin(cls) or cls):
raise TypeError(f"{v} should be of type {cls}")
return v
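
A usage sketch of the two relocated helpers (not part of this commit; the values are illustrative):

from typing import Any, Dict, Optional

maybe_steps: Optional[int] = 128
n_steps = non_null(maybe_steps)  # returns 128; would raise ValueError on None

obj: Any = {"obs": 1}
d = check_cast(Dict[str, int], obj)  # passes: isinstance(obj, dict) via get_origin
# Per the NOTE, type parameters are not checked:
# check_cast(Dict[str, int], {"obs": "x"}) also passes.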
2 changes: 1 addition & 1 deletion stable_baselines3/common/vec_env/base_vec_env.py
@@ -24,7 +24,7 @@
 EnvObs = Union[np.ndarray, Dict[str, np.ndarray], Tuple[np.ndarray, ...]]


-def tile_images(images_nhwc: Sequence[th.Tensor]) -> th.Tensor:  # pragma: no cover
+def tile_images(images_nhwc: Sequence[th.Tensor] | th.Tensor) -> th.Tensor:  # pragma: no cover
     """
     Tile N images into one big PxQ image
     (P,Q) are chosen to be as close as possible, and if N
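
The widened annotation matches how callers pass either a list of frames or one pre-stacked tensor. A sketch (not part of this commit; the shapes are illustrative):

import torch as th

frames = [th.zeros((64, 64, 3)) for _ in range(6)]  # six (H, W, C) frames
grid = tile_images(frames)            # sequence input, as before
grid = tile_images(th.stack(frames))  # a stacked (N, H, W, C) tensor now also type-checks
# Either way, N frames are tiled into one roughly square P x Q image.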
17 changes: 12 additions & 5 deletions stable_baselines3/ppo_recurrent/ppo_recurrent.py
@@ -12,13 +12,20 @@
 from stable_baselines3.common.policies import BasePolicy
 from stable_baselines3.common.pytree_dataclass import tree_map
 from stable_baselines3.common.recurrent.buffers import RecurrentRolloutBuffer
-from stable_baselines3.common.recurrent.policies import RecurrentActorCriticPolicy
+from stable_baselines3.common.recurrent.policies import (
+    BaseRecurrentActorCriticPolicy,
+    RecurrentActorCriticPolicy,
+)
 from stable_baselines3.common.recurrent.type_aliases import (
+    ActorCriticStates,
     RecurrentRolloutBufferData,
-    RNNStates,
 )
+from stable_baselines3.common.type_aliases import (
+    GymEnv,
+    MaybeCallback,
+    Schedule,
+    non_null,
+)
-from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
 from stable_baselines3.common.utils import (
     explained_variance,
     get_schedule_fn,
@@ -95,7 +102,7 @@ class RecurrentPPO(OnPolicyAlgorithm):

     def __init__(
         self,
-        policy: Union[str, Type[RecurrentActorCriticPolicy]],
+        policy: Union[str, Type[BaseRecurrentActorCriticPolicy]],
         env: Union[GymEnv, str],
         learning_rate: Union[float, Schedule] = 3e-4,
         n_steps: int = 128,
@@ -177,7 +184,7 @@ def __init__(
         self.clip_range_vf: Schedule = clip_range_vf  # type: ignore
         self.normalize_advantage = normalize_advantage
         self.target_kl = target_kl
-        self._last_lstm_states: Optional[RNNStates] = None
+        self._last_lstm_states: Optional[ActorCriticStates] = None

         if _init_setup_model:
             self._setup_model()
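
With the widened `policy` annotation, any `BaseRecurrentActorCriticPolicy` subclass is accepted, not only `RecurrentActorCriticPolicy`. A sketch (not part of this commit; the environment id and hyperparameters are assumptions):

model = RecurrentPPO(
    RecurrentActorCriticPolicy,  # any BaseRecurrentActorCriticPolicy subclass
    "CartPole-v1",
    n_steps=128,
)
model.learn(total_timesteps=10_000)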
