From 585f46e0f2c4fe397dbc3f1e4ad50b08687fd75a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adri=C3=A0=20Garriga-Alonso?= Date: Wed, 11 Oct 2023 12:17:08 -0700 Subject: [PATCH 01/16] Introduce BaseRecurrentActorCriticPolicy, and two possible kinds of ActorCriticPolicies --- .../common/recurrent/policies.py | 473 +++++++++++++----- .../common/recurrent/torch_layers.py | 158 ++++++ 2 files changed, 513 insertions(+), 118 deletions(-) create mode 100644 stable_baselines3/common/recurrent/torch_layers.py diff --git a/stable_baselines3/common/recurrent/policies.py b/stable_baselines3/common/recurrent/policies.py index e18f5c59c..a7260bcf8 100644 --- a/stable_baselines3/common/recurrent/policies.py +++ b/stable_baselines3/common/recurrent/policies.py @@ -1,4 +1,5 @@ -from typing import Any, Dict, List, Optional, Tuple, Type, Union +import abc +from typing import Any, Dict, Generic, List, Optional, Tuple, Type, Union import torch as th from gymnasium import spaces @@ -6,7 +7,14 @@ from stable_baselines3.common.distributions import Distribution from stable_baselines3.common.policies import ActorCriticPolicy +from stable_baselines3.common.preprocessing import preprocess_obs from stable_baselines3.common.pytree_dataclass import tree_flatten +from stable_baselines3.common.recurrent.torch_layers import ( + GRUNatureCNNExtractor, + GRUWrappedFeaturesExtractor, + RecurrentFeaturesExtractor, + RecurrentState, +) from stable_baselines3.common.recurrent.type_aliases import ( LSTMStates, RNNStates, @@ -23,7 +31,164 @@ from stable_baselines3.common.utils import zip_strict -class RecurrentActorCriticPolicy(ActorCriticPolicy): +class BaseRecurrentActorCriticPolicy(ActorCriticPolicy, Generic[RecurrentState]): + @abc.abstractmethod + def recurrent_initial_state( + self, n_envs: Optional[int] = None, *, device: Optional[th.device | str] = None + ) -> RecurrentState: + ... + + @abc.abstractmethod + def forward( # type: ignore[override] + self, + obs: TorchGymObs, + state: RecurrentState, + episode_starts: th.Tensor, + deterministic: bool = False, + ) -> Tuple[th.Tensor, th.Tensor, th.Tensor, RecurrentState]: + """Advances to the next hidden state, and computes all the outputs of a recurrent policy. + + In this docstring the dimension letters are: Time (T), Batch (B) and others (...). + + :param obs: shape (T, B, ...) the policy will be applied in sequence to all the observations. + :param state: shape (B, ...), the hidden state of the recurrent network + :param episode_starts: shape (T, B), whether the current state is the start of an episode. This should be be 0 + everywhere except for T=0, where it may be 1. + :param deterministic: if True return the best action, else a sample. + :returns: (actions, values, log_prob, state). The actions, values and log-action-probabilities for every time + step T, and the final state. + """ + ... + + @abc.abstractmethod + def get_distribution( # type: ignore[override] + self, + obs: TorchGymObs, + state: RecurrentState, + episode_starts: th.Tensor, + ) -> Tuple[Distribution, RecurrentState]: + """ + Get the policy distribution for each step given the observations. + + :param obs: shape (T, B, ...) the policy will be applied in sequence to all the observations. + :param state: shape (B, ...), the hidden state of the recurrent network + :param episode_starts: shape (T, B), whether the current state is the start of an episode. This should be be 0 + everywhere except for T=0, where it may be 1. + :return: the action distribution, the new hidden states. + """ + ... 
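# A hedged sketch of the call contract documented above (the names `policy` and `obs_shape`
# and the sizes T=5, B=4 are illustrative, not taken from this patch):
#
#     state = policy.recurrent_initial_state(n_envs=4, device=policy.device)
#     obs = th.zeros((5, 4, *obs_shape))                # (T, B, ...) observations
#     episode_starts = th.zeros((5, 4), dtype=th.bool)
#     episode_starts[0] = True                          # episode starts may only appear at T=0
#     actions, values, log_prob, state = policy.forward(obs, state, episode_starts)
#
# The returned actions, values and log-probabilities cover every time step, and the returned
# `state` is the hidden state after the last step, e.g. a (num_layers, B, hidden_size) tensor
# for the GRU-based extractors added in this patch.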
+ + @abc.abstractmethod + def predict_values( # type: ignore[override] + self, + obs: TorchGymObs, + state: RecurrentState, + episode_starts: th.Tensor, + ) -> th.Tensor: + """ + Get the estimated values according to the current policy given the observations. + + :param obs: shape (T, B, ...) the policy will be applied in sequence to all the observations. + :param state: shape (B, ...), the hidden state of the recurrent network + :param episode_starts: shape (T, B), whether the current state is the start of an episode. This should be be 0 + everywhere except for T=0, where it may be 1. + :return: The value for each time step. + """ + ... + + @abc.abstractmethod + def evaluate_actions( # type: ignore[override] + self, obs: TorchGymObs, actions: th.Tensor, state: RecurrentState, episode_starts: th.Tensor + ) -> Tuple[th.Tensor, th.Tensor, th.Tensor]: + """ + Evaluate actions according to the current policy, + given the observations. + + :param obs: shape (T, B, ...) the policy will be applied in sequence to all the observations. + :param actions: The actions taken at each step. + :param state: shape (B, ...), the hidden state of the recurrent network + :param episode_starts: shape (T, B), whether the current state is the start of an episode. This should be be 0 + everywhere except for T=0, where it may be 1. + :return: estimated value, log likelihood of taking those actions + and entropy of the action distribution. + """ + ... + + @abc.abstractmethod + def _predict( # type: ignore[override] + self, + observation: TorchGymObs, + state: RecurrentState, + episode_starts: th.Tensor, + deterministic: bool = False, + ) -> Tuple[th.Tensor, RecurrentState]: + """ + Get the action according to the policy for a given observation. + + :param obs: shape (T, B, ...) the policy will be applied in sequence to all the observations. + :param state: shape (B, ...), the hidden state of the recurrent network + :param episode_starts: shape (T, B), whether the current state is the start of an episode. This should be be 0 + everywhere except for T=0, where it may be 1. + :param deterministic: if True return the best action, else a sample. + :return: the model's action and the next hidden state + """ + ... + + def predict( # type: ignore[override] + self, + obs: TorchGymObs, + state: Optional[RecurrentState] = None, + episode_start: Optional[th.Tensor] = None, + deterministic: bool = False, + ) -> Tuple[th.Tensor, Optional[RecurrentState]]: + """ + Get the policy action from an observation (and optional hidden state). + Includes sugar-coating to handle different observations (e.g. normalizing images). + + :param obs: shape (T, B, ...) the policy will be applied in sequence to all the observations. + :param state: shape (B, ...), the hidden state of the recurrent network + :param episode_starts: shape (T, B), whether the current state is the start of an episode. This should be be 0 + everywhere except for T=0, where it may be 1. + :param deterministic: if True return the best action, else a sample. 
+ :return: the model's action and the next hidden state + """ + # Switch to eval mode (this affects batch norm / dropout) + self.set_training_mode(False) + + obs, vectorized_env = self.obs_to_tensor(obs) + one_obs_tensor: th.Tensor + (one_obs_tensor, *_), _ = tree_flatten(obs) # type: ignore + n_envs = len(one_obs_tensor) + + if state is None: + state = self.recurrent_initial_state(n_envs, device=self.device) + + if episode_start is None: + episode_start = th.zeros(n_envs, dtype=th.bool) + + with th.no_grad(): + # Convert to PyTorch tensors + actions, state = self._predict(obs, state=state, episode_starts=episode_start, deterministic=deterministic) + + if isinstance(self.action_space, spaces.Box): + if callable(self.squash_output): + # Rescale to proper domain when using squashing + actions = self.unscale_action(actions) + else: + # Actions could be on arbitrary scale, so clip the actions to avoid + # out of bound error (e.g. if sampling from a Gaussian distribution) + actions = th.clip( + actions, th.as_tensor(self.action_space.low).to(actions), th.as_tensor(self.action_space.high).to(actions) + ) + + # Remove batch dimension if needed + if not vectorized_env: + actions = actions.squeeze(dim=0) + + return actions, state + + +class RecurrentActorCriticPolicy(BaseRecurrentActorCriticPolicy): """ Recurrent policy class for actor-critic algorithms (has both policy and value prediction). To be used with A2C, PPO and the likes. @@ -172,16 +337,6 @@ def _process_sequence( episode_starts: th.Tensor, lstm: nn.LSTM, ) -> Tuple[th.Tensor, LSTMStates]: - """ - Do a forward pass in the LSTM network. - - :param features: Input tensor - :param lstm_states: previous cell and hidden states of the LSTM - :param episode_starts: Indicates when a new episode starts, - in that case, we need to reset LSTM states. - :param lstm: LSTM object. - :return: LSTM output and updated LSTM states. - """ # LSTM logic # (sequence length, batch size, features dim) # (batch size = n_envs for data collection or n_seq when doing gradient update) @@ -277,16 +432,6 @@ def forward( # type: ignore[override] episode_starts: th.Tensor, deterministic: bool = False, ) -> Tuple[th.Tensor, th.Tensor, th.Tensor, RNNStates]: - """ - Forward pass in all the networks (actor and critic) - - :param obs: Observation. Observation - :param state: The last hidden and memory states for the LSTM. - :param episode_starts: Whether the observations correspond to new episodes - or not (we reset the lstm states in that case). - :param deterministic: Whether to sample or use deterministic actions - :return: action, value and log probability of the action - """ (latent_pi, latent_vf), state = self._recurrent_latent_pi_and_vf(obs, state, episode_starts) latent_pi = self.mlp_extractor.forward_actor(latent_pi) latent_vf = self.mlp_extractor.forward_critic(latent_vf) @@ -304,16 +449,6 @@ def get_distribution( # type: ignore[override] state: RNNStates, episode_starts: th.Tensor, ) -> Tuple[Distribution, RNNStates]: - """ - Get the current policy distribution given the observations. - - :param obs: Observation. - :param state: The last hidden and memory states for the LSTM. - :param episode_starts: Whether the observations correspond to new episodes - or not (we reset the lstm states in that case). - :return: the action distribution and new hidden states. 
- """ - # Call the method from the parent of the parent class (latent_pi, _), state = self._recurrent_latent_pi_and_vf(obs, state, episode_starts) latent_pi = self.mlp_extractor.forward_actor(latent_pi) return self._get_action_dist_from_latent(latent_pi), state @@ -324,15 +459,6 @@ def predict_values( # type: ignore[override] state: RNNStates, episode_starts: th.Tensor, ) -> th.Tensor: - """ - Get the estimated values according to the current policy given the observations. - - :param obs: Observation. - :param state: The last hidden and memory states for the LSTM. - :param episode_starts: Whether the observations correspond to new episodes - or not (we reset the lstm states in that case). - :return: the estimated values. - """ latent_vf = self._recurrent_latent_vf_nostate(obs, state, episode_starts) latent_vf = self.mlp_extractor.forward_critic(latent_vf) return self.value_net(latent_vf) @@ -340,19 +466,6 @@ def predict_values( # type: ignore[override] def evaluate_actions( # type: ignore[override] self, obs: TorchGymObs, actions: th.Tensor, state: RNNStates, episode_starts: th.Tensor ) -> Tuple[th.Tensor, th.Tensor, th.Tensor]: - """ - Evaluate actions according to the current policy, - given the observations. - - :param obs: Observation. - :param actions: - :param state: The last hidden and memory states for the LSTM. - :param episode_starts: Whether the observations correspond to new episodes - or not (we reset the lstm states in that case). - :return: estimated value, log likelihood of taking those actions - and entropy of the action distribution. - """ - # Preprocess the observation if needed (latent_pi, latent_vf), state = self._recurrent_latent_pi_and_vf(obs, state, episode_starts) latent_pi = self.mlp_extractor.forward_actor(latent_pi) latent_vf = self.mlp_extractor.forward_critic(latent_vf) @@ -369,73 +482,9 @@ def _predict( # type: ignore[override] episode_starts: th.Tensor, deterministic: bool = False, ) -> Tuple[th.Tensor, RNNStates]: - """ - Get the action according to the policy for a given observation. - - :param observation: - :param state: The last hidden and memory states for the LSTM. - :param episode_starts: Whether the observations correspond to new episodes - or not (we reset the lstm states in that case). - :param deterministic: Whether to use stochastic or deterministic actions - :return: Taken action according to the policy and hidden states of the RNN - """ distribution, state = self.get_distribution(observation, state, episode_starts) return distribution.get_actions(deterministic=deterministic), state - def predict( # type: ignore[override] - self, - observation: TorchGymObs, - state: Optional[RNNStates] = None, - episode_start: Optional[th.Tensor] = None, - deterministic: bool = False, - ) -> Tuple[th.Tensor, Optional[RNNStates]]: - """ - Get the policy action from an observation (and optional hidden state). - Includes sugar-coating to handle different observations (e.g. normalizing images). - - :param observation: the input observation - :param state: The last hidden and memory states for the LSTM. - :param episode_starts: Whether the observations correspond to new episodes - or not (we reset the lstm states in that case). - :param deterministic: Whether or not to return deterministic actions. 
- :return: the model's action and the next hidden state - (used in recurrent policies) - """ - # Switch to eval mode (this affects batch norm / dropout) - self.set_training_mode(False) - - observation, vectorized_env = self.obs_to_tensor(observation) - one_obs_tensor: th.Tensor - (one_obs_tensor, *_), _ = tree_flatten(observation) # type: ignore - n_envs = len(one_obs_tensor) - - if state is None: - state = self.recurrent_initial_state(n_envs) - - if episode_start is None: - episode_start = th.zeros(n_envs, dtype=th.bool) - - with th.no_grad(): - # Convert to PyTorch tensors - actions, state = self._predict(observation, state=state, episode_starts=episode_start, deterministic=deterministic) - - if isinstance(self.action_space, spaces.Box): - if self.squash_output: - # Rescale to proper domain when using squashing - actions = self.unscale_action(actions) - else: - # Actions could be on arbitrary scale, so clip the actions to avoid - # out of bound error (e.g. if sampling from a Gaussian distribution) - actions = th.clip( - actions, th.as_tensor(self.action_space.low).to(actions), th.as_tensor(self.action_space.high).to(actions) - ) - - # Remove batch dimension if needed - if not vectorized_env: - actions = actions.squeeze(dim=0) - - return actions, state - class RecurrentActorCriticCnnPolicy(RecurrentActorCriticPolicy): """ @@ -615,3 +664,191 @@ def __init__( enable_critic_lstm, lstm_kwargs, ) + + +class RecurrentFeaturesExtractorActorCriticPolicy(ActorCriticPolicy, Generic[RecurrentState]): + features_extractor: RecurrentFeaturesExtractor[RecurrentState] + + def __init__( + self, + observation_space: spaces.Space, + action_space: spaces.Space, + lr_schedule: Schedule, + net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None, + activation_fn: Type[nn.Module] = nn.Tanh, + ortho_init: bool = True, + use_sde: bool = False, + log_std_init: float = 0.0, + full_std: bool = True, + use_expln: bool = False, + squash_output: bool = False, + features_extractor_class: Type[BaseFeaturesExtractor] = GRUNatureCNNExtractor, + features_extractor_kwargs: Optional[Dict[str, Any]] = None, + share_features_extractor: bool = True, + normalize_images: bool = True, + optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, + optimizer_kwargs: Optional[Dict[str, Any]] = None, + ): + if features_extractor_kwargs is None: + features_extractor_kwargs = {} + # Automatically deactivate dtype and bounds checks + if normalize_images is False and issubclass(features_extractor_class, GRUNatureCNNExtractor): + features_extractor_kwargs = features_extractor_kwargs.copy() + features_extractor_kwargs.update(dict(normalized_image=True)) + + if not issubclass(features_extractor_class, RecurrentFeaturesExtractor): + base_features_extractor = features_extractor_class(observation_space, **features_extractor_kwargs) + + features_extractor_class = GRUWrappedFeaturesExtractor + new_features_extractor_kwargs = dict(base_extractor=base_features_extractor) + if "features_dim" in features_extractor_kwargs: + new_features_extractor_kwargs["features_dim"] = features_extractor_kwargs["features_dim"] + features_extractor_kwargs = new_features_extractor_kwargs + print(features_extractor_class, features_extractor_kwargs) + + super().__init__( + observation_space, + action_space, + lr_schedule, + net_arch, + activation_fn, + ortho_init, + use_sde, + log_std_init, + full_std, + use_expln, + squash_output, + features_extractor_class, + features_extractor_kwargs, + share_features_extractor, + normalize_images, + optimizer_class, + 
optimizer_kwargs, + ) + + def recurrent_initial_state( + self, n_envs: Optional[int] = None, *, device: Optional[th.device | str] = None + ) -> RecurrentState: + return self.features_extractor.recurrent_initial_state(n_envs, device=device) + + def _recurrent_extract_features( + self, obs: TorchGymObs, state: RecurrentState, episode_starts: th.Tensor + ) -> Tuple[th.Tensor, RecurrentState]: + if not self.share_features_extractor: + raise NotImplementedError("Non-shared features extractor not supported for recurrent extractors") + + preprocessed_obs = preprocess_obs(obs, self.observation_space, normalize_images=self.normalize_images) # type: ignore + return self.features_extractor(preprocessed_obs, state, episode_starts) + + def forward( # type: ignore[override] + self, + obs: TorchGymObs, + state: RecurrentState, + episode_starts: th.Tensor, + deterministic: bool = False, + ) -> Tuple[th.Tensor, th.Tensor, th.Tensor, RecurrentState]: + """Advances to the next hidden state, and computes all the outputs of a recurrent policy. + + In this docstring the dimension letters are: Time (T), Batch (B) and others (...). + + :param obs: shape (T, B, ...) the policy will be applied in sequence to all the observations. + :param state: shape (B, ...), the hidden state of the recurrent network + :param episode_starts: shape (T, B), whether the current state is the start of an episode. This should be be 0 + everywhere except for T=0, where it may be 1. + :param deterministic: if True return the best action, else a sample. + :returns: (actions, values, log_prob, state). The actions, values and log-action-probabilities for every time + step T, and the final state. + """ + latents, state = self._recurrent_extract_features(obs, state, episode_starts) + latent_pi = self.mlp_extractor.forward_actor(latents) + latent_vf = self.mlp_extractor.forward_critic(latents) + + # Evaluate the values for the given observations + values = self.value_net(latent_vf) + distribution = self._get_action_dist_from_latent(latent_pi) + actions = distribution.get_actions(deterministic=deterministic) + log_prob = distribution.log_prob(actions) + return actions, values, log_prob, state + + def get_distribution( # type: ignore[override] + self, + obs: TorchGymObs, + state: RecurrentState, + episode_starts: th.Tensor, + ) -> Tuple[Distribution, RecurrentState]: + """ + Get the policy distribution for each step given the observations. + + :param obs: shape (T, B, ...) the policy will be applied in sequence to all the observations. + :param state: shape (B, ...), the hidden state of the recurrent network + :param episode_starts: shape (T, B), whether the current state is the start of an episode. This should be be 0 + everywhere except for T=0, where it may be 1. + :return: the action distribution, the new hidden states. + """ + latent_pi, state = self._recurrent_extract_features(obs, state, episode_starts) + latent_pi = self.mlp_extractor.forward_actor(latent_pi) + return self._get_action_dist_from_latent(latent_pi), state + + def predict_values( # type: ignore[override] + self, + obs: TorchGymObs, + state: RecurrentState, + episode_starts: th.Tensor, + ) -> th.Tensor: + """ + Get the estimated values according to the current policy given the observations. + + :param obs: shape (T, B, ...) the policy will be applied in sequence to all the observations. + :param state: shape (B, ...), the hidden state of the recurrent network + :param episode_starts: shape (T, B), whether the current state is the start of an episode. 
This should be be 0 + everywhere except for T=0, where it may be 1. + :return: The value for each time step. + """ + latent_vf, _ = self._recurrent_extract_features(obs, state, episode_starts) + latent_vf = self.mlp_extractor.forward_critic(latent_vf) + return self.value_net(latent_vf) + + def evaluate_actions( # type: ignore[override] + self, obs: TorchGymObs, actions: th.Tensor, state: RecurrentState, episode_starts: th.Tensor + ) -> Tuple[th.Tensor, th.Tensor, th.Tensor]: + """ + Evaluate actions according to the current policy, + given the observations. + + :param obs: shape (T, B, ...) the policy will be applied in sequence to all the observations. + :param actions: The actions taken at each step. + :param state: shape (B, ...), the hidden state of the recurrent network + :param episode_starts: shape (T, B), whether the current state is the start of an episode. This should be be 0 + everywhere except for T=0, where it may be 1. + :return: estimated value, log likelihood of taking those actions + and entropy of the action distribution. + """ + # Preprocess the observation if needed + latents, state = self._recurrent_extract_features(obs, state, episode_starts) + latent_pi = self.mlp_extractor.forward_actor(latents) + latent_vf = self.mlp_extractor.forward_critic(latents) + + distribution = self._get_action_dist_from_latent(latent_pi) + log_prob = distribution.log_prob(actions) + values = self.value_net(latent_vf) + return values, log_prob, non_null(distribution.entropy()) + + def _predict( # type: ignore[override] + self, + observation: TorchGymObs, + state: RecurrentState, + episode_starts: th.Tensor, + deterministic: bool = False, + ) -> Tuple[th.Tensor, RecurrentState]: + """ + Get the action according to the policy for a given observation. + + :param obs: shape (T, B, ...) the policy will be applied in sequence to all the observations. + :param state: shape (B, ...), the hidden state of the recurrent network + :param episode_starts: shape (T, B), whether the current state is the start of an episode. This should be be 0 + everywhere except for T=0, where it may be 1. + :param deterministic: if True return the best action, else a sample. + :return: the model's action and the next hidden state + """ + distribution, state = self.get_distribution(observation, state, episode_starts) + return distribution.get_actions(deterministic=deterministic), state diff --git a/stable_baselines3/common/recurrent/torch_layers.py b/stable_baselines3/common/recurrent/torch_layers.py new file mode 100644 index 000000000..4e15f4ffd --- /dev/null +++ b/stable_baselines3/common/recurrent/torch_layers.py @@ -0,0 +1,158 @@ +import abc +from typing import Generic, Optional, Tuple, TypeVar + +import gymnasium as gym +import torch as th + +from stable_baselines3.common.pytree_dataclass import TensorTree, tree_flatten, tree_map +from stable_baselines3.common.torch_layers import ( + BaseFeaturesExtractor, + CombinedExtractor, + FlattenExtractor, + NatureCNN, +) +from stable_baselines3.common.type_aliases import TorchGymObs + +RecurrentState = TypeVar("RecurrentState", bound=TensorTree) + +RecurrentSubState = TypeVar("RecurrentSubState", bound=TensorTree) + + +class RecurrentFeaturesExtractor(BaseFeaturesExtractor, abc.ABC, Generic[RecurrentState]): + @abc.abstractmethod + def recurrent_initial_state( + self, n_envs: Optional[int] = None, *, device: Optional[th.device | str] = None + ) -> RecurrentState: + ... 
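# A minimal sketch of how these extractor methods are meant to be used together (assumed
# from the concrete extractors and policies in this patch; `extractor`, `obs`, `episode_starts`
# and `n_envs` are placeholders):
#
#     state = extractor.recurrent_initial_state(n_envs, device=device)
#     features, state = extractor(obs, state, episode_starts)
#
# The caller owns the recurrent state: it is created once for a set of environments and then
# threaded through every forward call, with `episode_starts` marking where it must be reset.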
+ + @abc.abstractmethod + def forward( + self, observations: TorchGymObs, state: RecurrentState, episode_starts: th.Tensor + ) -> Tuple[th.Tensor, RecurrentState]: + ... + + @staticmethod + def _process_sequence( + rnn: th.nn.RNNBase, inputs: th.Tensor, init_state: RecurrentSubState, episode_starts: th.Tensor + ) -> Tuple[th.Tensor, RecurrentSubState]: + (state_example, *_), _ = tree_flatten(init_state, is_leaf=None) + n_layers, batch_sz, *_ = state_example.shape + + # Batch to sequence + # (padded batch size, features_dim) -> (n_seq, max length, features_dim) -> (max length, n_seq, features_dim) + seq_len = inputs.shape[0] // batch_sz + seq_inputs = inputs.view((batch_sz, seq_len, *inputs.shape[1:])).swapaxes(0, 1) + episode_starts = episode_starts.view((batch_sz, seq_len)).swapaxes(0, 1) + + if th.any(episode_starts[1:]): + raise NotImplementedError("Resetting state in the middle of a sequence is not supported") + + first_state_is_not_reset = (~episode_starts[0]).contiguous() + # Shape here is (n_layers, batch_sz) + init_state = tree_map(lambda x: x * first_state_is_not_reset.view((1, batch_sz, *(1,) * (x.ndim - 2))), init_state) + rnn_output, end_state = rnn(seq_inputs, init_state) + + # (seq_len, batch_size, ...) -> (batch_size, seq_len, ...) -> (batch_size * seq_len, ...) + rnn_output = rnn_output.transpose(0, 1).reshape((batch_sz * seq_len, *rnn_output.shape[2:])) + return rnn_output, end_state + + +GRURecurrentState = th.Tensor + + +class GRUWrappedFeaturesExtractor(RecurrentFeaturesExtractor[GRURecurrentState]): + def __init__( + self, + observation_space: gym.Space, + base_extractor: BaseFeaturesExtractor, + features_dim: Optional[int] = None, + num_layers: int = 1, + bias: bool = True, + dropout: float = 0.0, + ): + if features_dim is None: + # Ensure features_dim is at least 64 by default so it optimizes fine + features_dim = max(base_extractor.features_dim, 64) + + assert observation_space == base_extractor._observation_space + + super().__init__(observation_space, features_dim) + self.base_extractor = base_extractor + + self.rnn = th.nn.GRU( + input_size=base_extractor.features_dim, + hidden_size=features_dim, + num_layers=num_layers, + bias=bias, + batch_first=False, + dropout=dropout, + bidirectional=False, + ) + + def recurrent_initial_state( + self, n_envs: Optional[int] = None, *, device: Optional[th.device | str] = None + ) -> GRURecurrentState: + shape: Tuple[int, ...] 
+ if n_envs is None: + shape = (self.rnn.num_layers, self.rnn.hidden_size) + else: + shape = (self.rnn.num_layers, n_envs, self.rnn.hidden_size) + return th.zeros(shape, device=device) + + def forward( + self, observations: TorchGymObs, state: GRURecurrentState, episode_starts: th.Tensor + ) -> Tuple[th.Tensor, GRURecurrentState]: + features: th.Tensor = self.base_extractor(observations) + return self._process_sequence(self.rnn, features, state, episode_starts) + + @property + def features_dim(self) -> int: + return self.rnn.hidden_size + + +class GRUFlattenExtractor(GRUWrappedFeaturesExtractor): + def __init__( + self, + observation_space: gym.Space, + features_dim: int = 64, + num_layers: int = 1, + bias: bool = True, + dropout: float = 0.0, + ) -> None: + base_extractor = FlattenExtractor(observation_space) + super().__init__( + observation_space, base_extractor, features_dim=features_dim, num_layers=num_layers, bias=bias, dropout=dropout + ) + + +class GRUNatureCNNExtractor(GRUWrappedFeaturesExtractor): + def __init__( + self, + observation_space: gym.Space, + features_dim: int = 512, + normalized_image: bool = False, + num_layers: int = 1, + bias: bool = True, + dropout: float = 0.0, + ) -> None: + base_extractor = NatureCNN(observation_space, features_dim=features_dim, normalized_image=normalized_image) + super().__init__( + observation_space, base_extractor, features_dim=features_dim, num_layers=num_layers, bias=bias, dropout=dropout + ) + + +class GRUCombinedExtractor(GRUWrappedFeaturesExtractor): + def __init__( + self, + observation_space: gym.spaces.Dict, + features_dim: int = 64, + cnn_output_dim: int = 256, + normalized_image: bool = False, + num_layers: int = 1, + bias: bool = True, + dropout: float = 0.0, + ) -> None: + base_extractor = CombinedExtractor(observation_space, cnn_output_dim=cnn_output_dim, normalized_image=normalized_image) + super().__init__( + observation_space, base_extractor, features_dim=features_dim, num_layers=num_layers, bias=bias, dropout=dropout + ) From 041a20b3f0af8a18e10ad1e21921c0519c2b30af Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adri=C3=A0=20Garriga-Alonso?= Date: Wed, 11 Oct 2023 12:44:49 -0700 Subject: [PATCH 02/16] In the process of making the LSTMs extractors, to de-duplicate code. 
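
The batch-to-sequence handling in RecurrentActorCriticPolicy._process_sequence duplicates the
logic already implemented by RecurrentFeaturesExtractor._process_sequence. Wrapping the
policy's actor and critic nn.LSTM modules in LSTMFlattenExtractor lets both code paths share
a single implementation.

As a rough sketch of the usage this series is building towards (it mirrors the tests added
later in the series; the environment and hyperparameters are illustrative):

    from stable_baselines3.common.recurrent.policies import (
        RecurrentFeaturesExtractorActorCriticPolicy,
    )
    from stable_baselines3.common.recurrent.torch_layers import GRUFlattenExtractor
    from stable_baselines3.ppo_recurrent.ppo_recurrent import RecurrentPPO

    model = RecurrentPPO(
        RecurrentFeaturesExtractorActorCriticPolicy,
        "Pendulum-v1",
        n_steps=16,
        policy_kwargs=dict(features_extractor_class=GRUFlattenExtractor),
    )
    model.learn(total_timesteps=200)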
--- .../common/recurrent/policies.py | 99 +++++-------------- .../common/recurrent/torch_layers.py | 69 ++++++++++--- .../common/recurrent/type_aliases.py | 14 ++- .../ppo_recurrent/ppo_recurrent.py | 4 +- 4 files changed, 96 insertions(+), 90 deletions(-) diff --git a/stable_baselines3/common/recurrent/policies.py b/stable_baselines3/common/recurrent/policies.py index a7260bcf8..4cf292dc4 100644 --- a/stable_baselines3/common/recurrent/policies.py +++ b/stable_baselines3/common/recurrent/policies.py @@ -10,14 +10,16 @@ from stable_baselines3.common.preprocessing import preprocess_obs from stable_baselines3.common.pytree_dataclass import tree_flatten from stable_baselines3.common.recurrent.torch_layers import ( + ExtractorInput, GRUNatureCNNExtractor, GRUWrappedFeaturesExtractor, + LSTMFlattenExtractor, RecurrentFeaturesExtractor, RecurrentState, ) from stable_baselines3.common.recurrent.type_aliases import ( - LSTMStates, - RNNStates, + ActorCriticStates, + LSTMRecurrentState, non_null, ) from stable_baselines3.common.torch_layers import ( @@ -28,7 +30,6 @@ NatureCNN, ) from stable_baselines3.common.type_aliases import Schedule, TorchGymObs -from stable_baselines3.common.utils import zip_strict class BaseRecurrentActorCriticPolicy(ActorCriticPolicy, Generic[RecurrentState]): @@ -279,9 +280,9 @@ def __init__( self.lstm_kwargs = lstm_kwargs or {} self.shared_lstm = shared_lstm self.enable_critic_lstm = enable_critic_lstm - self.lstm_actor = nn.LSTM( - self.features_dim, - lstm_hidden_size, + self.lstm_actor = LSTMFlattenExtractor( + spaces.Box(-1e9, 1e9, (self.features_dim,)), + features_dim=lstm_hidden_size, num_layers=n_lstm_layers, **self.lstm_kwargs, ) @@ -306,9 +307,9 @@ def __init__( # Use a separate LSTM for the critic if self.enable_critic_lstm: - self.lstm_critic = nn.LSTM( - self.features_dim, - lstm_hidden_size, + self.lstm_critic = LSTMFlattenExtractor( + spaces.Box(-1e9, 1e9, (self.features_dim,)), + features_dim=lstm_hidden_size, num_layers=n_lstm_layers, **self.lstm_kwargs, ) @@ -330,67 +331,21 @@ def _build_mlp_extractor(self) -> None: device=self.device, ) - @staticmethod - def _process_sequence( - features: th.Tensor, - lstm_states: LSTMStates, - episode_starts: th.Tensor, - lstm: nn.LSTM, - ) -> Tuple[th.Tensor, LSTMStates]: - # LSTM logic - # (sequence length, batch size, features dim) - # (batch size = n_envs for data collection or n_seq when doing gradient update) - n_seq = lstm_states[0].shape[1] - # Batch to sequence - # (padded batch size, features_dim) -> (n_seq, max length, features_dim) -> (max length, n_seq, features_dim) - # note: max length (max sequence length) is always 1 during data collection - features_sequence = features.reshape((n_seq, -1, lstm.input_size)).swapaxes(0, 1) - episode_starts = episode_starts.reshape((n_seq, -1)).swapaxes(0, 1) - - # If we don't have to reset the state in the middle of a sequence - # we can avoid the for loop, which speeds up things - if not th.any(episode_starts[1:]): - not_reset_first = (~episode_starts[0]).view(1, n_seq, 1) - lstm_output, lstm_states = lstm( - features_sequence, (not_reset_first * lstm_states[0], not_reset_first * lstm_states[1]) - ) - lstm_output = th.flatten(lstm_output.transpose(0, 1), start_dim=0, end_dim=1) - return lstm_output, lstm_states - - raise RuntimeError("The inefficient code path should not happen.") - - lstm_output = [] - # Iterate over the sequence - for features, episode_start in zip_strict(features_sequence, episode_starts): - hidden, lstm_states = lstm( - features.unsqueeze(dim=0), - 
( - # Reset the states at the beginning of a new episode - (~episode_start).view(1, n_seq, 1) * lstm_states[0], - (~episode_start).view(1, n_seq, 1) * lstm_states[1], - ), - ) - lstm_output += [hidden] - # Sequence to batch - # (sequence length, n_seq, lstm_out_dim) -> (batch_size, lstm_out_dim) - lstm_output = th.flatten(th.cat(lstm_output).transpose(0, 1), start_dim=0, end_dim=1) - return lstm_output, lstm_states - def recurrent_initial_state(self, n_envs: Optional[int] = None, *, device: Optional[th.device | str] = None): shape: tuple[int, ...] if n_envs is None: shape = (self.lstm_hidden_state_shape[0], self.lstm_hidden_state_shape[2]) else: shape = (self.lstm_hidden_state_shape[0], n_envs, self.lstm_hidden_state_shape[2]) - return RNNStates( + return ActorCriticStates( (th.zeros(shape, device=device), th.zeros(shape, device=device)), (th.zeros(shape, device=device), th.zeros(shape, device=device)), ) # Methods for getting `latent_vf` or `latent_pi` def _recurrent_latent_pi_and_vf( - self, obs: TorchGymObs, state: RNNStates, episode_starts: th.Tensor - ) -> Tuple[Tuple[th.Tensor, th.Tensor], RNNStates]: + self, obs: TorchGymObs, state: ActorCriticStates, episode_starts: th.Tensor + ) -> Tuple[Tuple[th.Tensor, th.Tensor], ActorCriticStates]: features = self.extract_features(obs) pi_features: th.Tensor vf_features: th.Tensor @@ -404,11 +359,11 @@ def _recurrent_latent_pi_and_vf( latent_vf, lstm_states_vf = self._recurrent_latent_vf_from_features(vf_features, state, episode_starts) if lstm_states_vf is None: lstm_states_vf = (lstm_states_pi[0].detach(), lstm_states_pi[1].detach()) - return ((latent_pi, latent_vf), RNNStates(lstm_states_pi, lstm_states_vf)) + return ((latent_pi, latent_vf), ActorCriticStates(lstm_states_pi, lstm_states_vf)) def _recurrent_latent_vf_from_features( - self, vf_features: th.Tensor, state: RNNStates, episode_starts: th.Tensor - ) -> Tuple[th.Tensor, Optional[LSTMStates]]: + self, vf_features: th.Tensor, state: ActorCriticStates, episode_starts: th.Tensor + ) -> Tuple[th.Tensor, Optional[LSTMRecurrentState]]: "Get only the vf features, not advancing the hidden state" if self.lstm_critic is None: if self.shared_lstm: @@ -421,17 +376,17 @@ def _recurrent_latent_vf_from_features( latent_vf, state_vf = self._process_sequence(vf_features, state.vf, episode_starts, self.lstm_critic) return latent_vf, state_vf - def _recurrent_latent_vf_nostate(self, obs: TorchGymObs, state: RNNStates, episode_starts: th.Tensor) -> th.Tensor: + def _recurrent_latent_vf_nostate(self, obs: TorchGymObs, state: ActorCriticStates, episode_starts: th.Tensor) -> th.Tensor: vf_features: th.Tensor = super(ActorCriticPolicy, self).extract_features(obs, self.vf_features_extractor) return self._recurrent_latent_vf_from_features(vf_features, state, episode_starts)[0] def forward( # type: ignore[override] self, obs: TorchGymObs, - state: RNNStates, + state: ActorCriticStates, episode_starts: th.Tensor, deterministic: bool = False, - ) -> Tuple[th.Tensor, th.Tensor, th.Tensor, RNNStates]: + ) -> Tuple[th.Tensor, th.Tensor, th.Tensor, ActorCriticStates]: (latent_pi, latent_vf), state = self._recurrent_latent_pi_and_vf(obs, state, episode_starts) latent_pi = self.mlp_extractor.forward_actor(latent_pi) latent_vf = self.mlp_extractor.forward_critic(latent_vf) @@ -446,9 +401,9 @@ def forward( # type: ignore[override] def get_distribution( # type: ignore[override] self, obs: TorchGymObs, - state: RNNStates, + state: ActorCriticStates, episode_starts: th.Tensor, - ) -> Tuple[Distribution, RNNStates]: + ) -> 
Tuple[Distribution, ActorCriticStates]: (latent_pi, _), state = self._recurrent_latent_pi_and_vf(obs, state, episode_starts) latent_pi = self.mlp_extractor.forward_actor(latent_pi) return self._get_action_dist_from_latent(latent_pi), state @@ -456,7 +411,7 @@ def get_distribution( # type: ignore[override] def predict_values( # type: ignore[override] self, obs: TorchGymObs, - state: RNNStates, + state: ActorCriticStates, episode_starts: th.Tensor, ) -> th.Tensor: latent_vf = self._recurrent_latent_vf_nostate(obs, state, episode_starts) @@ -464,7 +419,7 @@ def predict_values( # type: ignore[override] return self.value_net(latent_vf) def evaluate_actions( # type: ignore[override] - self, obs: TorchGymObs, actions: th.Tensor, state: RNNStates, episode_starts: th.Tensor + self, obs: TorchGymObs, actions: th.Tensor, state: ActorCriticStates, episode_starts: th.Tensor ) -> Tuple[th.Tensor, th.Tensor, th.Tensor]: (latent_pi, latent_vf), state = self._recurrent_latent_pi_and_vf(obs, state, episode_starts) latent_pi = self.mlp_extractor.forward_actor(latent_pi) @@ -478,10 +433,10 @@ def evaluate_actions( # type: ignore[override] def _predict( # type: ignore[override] self, observation: TorchGymObs, - state: RNNStates, + state: ActorCriticStates, episode_starts: th.Tensor, deterministic: bool = False, - ) -> Tuple[th.Tensor, RNNStates]: + ) -> Tuple[th.Tensor, ActorCriticStates]: distribution, state = self.get_distribution(observation, state, episode_starts) return distribution.get_actions(deterministic=deterministic), state @@ -666,8 +621,8 @@ def __init__( ) -class RecurrentFeaturesExtractorActorCriticPolicy(ActorCriticPolicy, Generic[RecurrentState]): - features_extractor: RecurrentFeaturesExtractor[RecurrentState] +class RecurrentFeaturesExtractorActorCriticPolicy(ActorCriticPolicy, Generic[ExtractorInput, RecurrentState]): + features_extractor: RecurrentFeaturesExtractor[ExtractorInput, RecurrentState] def __init__( self, diff --git a/stable_baselines3/common/recurrent/torch_layers.py b/stable_baselines3/common/recurrent/torch_layers.py index 4e15f4ffd..f79221366 100644 --- a/stable_baselines3/common/recurrent/torch_layers.py +++ b/stable_baselines3/common/recurrent/torch_layers.py @@ -1,10 +1,15 @@ import abc -from typing import Generic, Optional, Tuple, TypeVar +from typing import Any, Dict, Generic, Optional, Tuple, TypeVar import gymnasium as gym import torch as th +from stable_baselines3.common.preprocessing import get_flattened_obs_dim from stable_baselines3.common.pytree_dataclass import TensorTree, tree_flatten, tree_map +from stable_baselines3.common.recurrent.type_aliases import ( + GRURecurrentState, + LSTMRecurrentState, +) from stable_baselines3.common.torch_layers import ( BaseFeaturesExtractor, CombinedExtractor, @@ -17,8 +22,10 @@ RecurrentSubState = TypeVar("RecurrentSubState", bound=TensorTree) +ExtractorInput = TypeVar("ExtractorInput", bound=TorchGymObs) + -class RecurrentFeaturesExtractor(BaseFeaturesExtractor, abc.ABC, Generic[RecurrentState]): +class RecurrentFeaturesExtractor(BaseFeaturesExtractor, abc.ABC, Generic[ExtractorInput, RecurrentState]): @abc.abstractmethod def recurrent_initial_state( self, n_envs: Optional[int] = None, *, device: Optional[th.device | str] = None @@ -27,7 +34,7 @@ def recurrent_initial_state( @abc.abstractmethod def forward( - self, observations: TorchGymObs, state: RecurrentState, episode_starts: th.Tensor + self, observations: ExtractorInput, state: RecurrentState, episode_starts: th.Tensor ) -> Tuple[th.Tensor, RecurrentState]: ... 
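# A worked example of the `_process_sequence` contract (the helper introduced in PATCH 01/16
# and left unchanged here; the concrete sizes are illustrative): given an `init_state` whose
# leaves have shape (n_layers, batch_sz=4, hidden_size), an `inputs` tensor of shape
# (4 * 5, features_dim) is read as 4 contiguous length-5 sequences, reshaped to
# (5, 4, features_dim) for the nn.GRU/nn.LSTM call, and the RNN output is flattened back to
# (4 * 5, hidden_size). `episode_starts` is reshaped the same way and may only be True at the
# first step of a sequence (anything else raises NotImplementedError); sequences that do start
# an episode have their slice of `init_state` zeroed before the RNN call.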
@@ -57,10 +64,7 @@ def _process_sequence( return rnn_output, end_state -GRURecurrentState = th.Tensor - - -class GRUWrappedFeaturesExtractor(RecurrentFeaturesExtractor[GRURecurrentState]): +class GRUWrappedFeaturesExtractor(RecurrentFeaturesExtractor[ExtractorInput, GRURecurrentState], Generic[ExtractorInput]): def __init__( self, observation_space: gym.Space, @@ -100,7 +104,7 @@ def recurrent_initial_state( return th.zeros(shape, device=device) def forward( - self, observations: TorchGymObs, state: GRURecurrentState, episode_starts: th.Tensor + self, observations: ExtractorInput, state: GRURecurrentState, episode_starts: th.Tensor ) -> Tuple[th.Tensor, GRURecurrentState]: features: th.Tensor = self.base_extractor(observations) return self._process_sequence(self.rnn, features, state, episode_starts) @@ -110,7 +114,7 @@ def features_dim(self) -> int: return self.rnn.hidden_size -class GRUFlattenExtractor(GRUWrappedFeaturesExtractor): +class GRUFlattenExtractor(GRUWrappedFeaturesExtractor[th.Tensor]): def __init__( self, observation_space: gym.Space, @@ -125,7 +129,7 @@ def __init__( ) -class GRUNatureCNNExtractor(GRUWrappedFeaturesExtractor): +class GRUNatureCNNExtractor(GRUWrappedFeaturesExtractor[th.Tensor]): def __init__( self, observation_space: gym.Space, @@ -141,7 +145,7 @@ def __init__( ) -class GRUCombinedExtractor(GRUWrappedFeaturesExtractor): +class GRUCombinedExtractor(GRUWrappedFeaturesExtractor[Dict[Any, th.Tensor]]): def __init__( self, observation_space: gym.spaces.Dict, @@ -156,3 +160,46 @@ def __init__( super().__init__( observation_space, base_extractor, features_dim=features_dim, num_layers=num_layers, bias=bias, dropout=dropout ) + + +class LSTMFlattenExtractor(RecurrentFeaturesExtractor[th.Tensor, LSTMRecurrentState]): + def __init__( + self, + observation_space: gym.Space, + features_dim: int = 64, + num_layers: int = 1, + bias: bool = True, + dropout: float = 0.0, + ): + super().__init__(observation_space, features_dim) + + self.rnn = th.nn.LSTM( + input_size=get_flattened_obs_dim(self._observation_space), + hidden_size=features_dim, + num_layers=num_layers, + bias=bias, + batch_first=False, + dropout=dropout, + bidirectional=False, + ) + self.base_extractor = FlattenExtractor(observation_space) + + def recurrent_initial_state( + self, n_envs: Optional[int] = None, *, device: Optional[th.device | str] = None + ) -> LSTMRecurrentState: + shape: Tuple[int, ...] 
+ if n_envs is None: + shape = (self.rnn.num_layers, self.rnn.hidden_size) + else: + shape = (self.rnn.num_layers, n_envs, self.rnn.hidden_size) + return (th.zeros(shape, device=device), th.zeros(shape, device=device)) + + def forward( + self, observations: th.Tensor, state: LSTMRecurrentState, episode_starts: th.Tensor + ) -> Tuple[th.Tensor, LSTMRecurrentState]: + features: th.Tensor = self.base_extractor(observations) + return self._process_sequence(self.rnn, features, state, episode_starts) + + @property + def features_dim(self) -> int: + return self.rnn.hidden_size diff --git a/stable_baselines3/common/recurrent/type_aliases.py b/stable_baselines3/common/recurrent/type_aliases.py index fe4bfcb4a..70bb123ad 100644 --- a/stable_baselines3/common/recurrent/type_aliases.py +++ b/stable_baselines3/common/recurrent/type_aliases.py @@ -1,4 +1,4 @@ -from typing import Optional, Tuple, TypeVar +from typing import Generic, Optional, Tuple, TypeVar import torch as th @@ -13,12 +13,16 @@ def non_null(v: Optional[T]) -> T: return v -LSTMStates = Tuple[th.Tensor, th.Tensor] +TensorTreeT = TypeVar("TensorTreeT", bound=TensorTree) -class RNNStates(FrozenPyTreeDataclass[th.Tensor]): - pi: LSTMStates - vf: LSTMStates +LSTMRecurrentState = Tuple[th.Tensor, th.Tensor] +GRURecurrentState = th.Tensor + + +class ActorCriticStates(FrozenPyTreeDataclass[th.Tensor], Generic[TensorTreeT]): + pi: TensorTreeT + vf: TensorTreeT class RecurrentRolloutBufferData(FrozenPyTreeDataclass[th.Tensor]): diff --git a/stable_baselines3/ppo_recurrent/ppo_recurrent.py b/stable_baselines3/ppo_recurrent/ppo_recurrent.py index da479c9ba..126d083c2 100644 --- a/stable_baselines3/ppo_recurrent/ppo_recurrent.py +++ b/stable_baselines3/ppo_recurrent/ppo_recurrent.py @@ -14,8 +14,8 @@ from stable_baselines3.common.recurrent.buffers import RecurrentRolloutBuffer from stable_baselines3.common.recurrent.policies import RecurrentActorCriticPolicy from stable_baselines3.common.recurrent.type_aliases import ( + ActorCriticStates, RecurrentRolloutBufferData, - RNNStates, non_null, ) from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule @@ -177,7 +177,7 @@ def __init__( self.clip_range_vf: Schedule = clip_range_vf # type: ignore self.normalize_advantage = normalize_advantage self.target_kl = target_kl - self._last_lstm_states: Optional[RNNStates] = None + self._last_lstm_states: Optional[ActorCriticStates] = None if _init_setup_model: self._setup_model() From c0b0310513f149b5087d8a60098be5d63bd4cb74 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adri=C3=A0=20Garriga-Alonso?= Date: Wed, 11 Oct 2023 12:49:59 -0700 Subject: [PATCH 03/16] Make the LSTMs in the original ActorCritic be Extractors --- .../common/recurrent/policies.py | 32 ++++++++++--------- 1 file changed, 17 insertions(+), 15 deletions(-) diff --git a/stable_baselines3/common/recurrent/policies.py b/stable_baselines3/common/recurrent/policies.py index 4cf292dc4..693f054a5 100644 --- a/stable_baselines3/common/recurrent/policies.py +++ b/stable_baselines3/common/recurrent/policies.py @@ -344,8 +344,8 @@ def recurrent_initial_state(self, n_envs: Optional[int] = None, *, device: Optio # Methods for getting `latent_vf` or `latent_pi` def _recurrent_latent_pi_and_vf( - self, obs: TorchGymObs, state: ActorCriticStates, episode_starts: th.Tensor - ) -> Tuple[Tuple[th.Tensor, th.Tensor], ActorCriticStates]: + self, obs: TorchGymObs, state: ActorCriticStates[LSTMRecurrentState], episode_starts: th.Tensor + ) -> Tuple[Tuple[th.Tensor, th.Tensor], 
ActorCriticStates[LSTMRecurrentState]]: features = self.extract_features(obs) pi_features: th.Tensor vf_features: th.Tensor @@ -355,38 +355,40 @@ def _recurrent_latent_pi_and_vf( else: assert isinstance(features, tuple) pi_features, vf_features = features - latent_pi, lstm_states_pi = self._process_sequence(pi_features, state.pi, episode_starts, self.lstm_actor) + latent_pi, lstm_states_pi = self.lstm_actor.forward(pi_features, state.pi, episode_starts) latent_vf, lstm_states_vf = self._recurrent_latent_vf_from_features(vf_features, state, episode_starts) if lstm_states_vf is None: lstm_states_vf = (lstm_states_pi[0].detach(), lstm_states_pi[1].detach()) return ((latent_pi, latent_vf), ActorCriticStates(lstm_states_pi, lstm_states_vf)) def _recurrent_latent_vf_from_features( - self, vf_features: th.Tensor, state: ActorCriticStates, episode_starts: th.Tensor + self, vf_features: th.Tensor, state: ActorCriticStates[LSTMRecurrentState], episode_starts: th.Tensor ) -> Tuple[th.Tensor, Optional[LSTMRecurrentState]]: "Get only the vf features, not advancing the hidden state" if self.lstm_critic is None: if self.shared_lstm: with th.no_grad(): - latent_vf, _ = self._process_sequence(vf_features, state.pi, episode_starts, self.lstm_actor) + latent_vf, _ = self.lstm_actor.forward(vf_features, state.pi, episode_starts) else: latent_vf = non_null(self.critic)(vf_features) state_vf = None else: - latent_vf, state_vf = self._process_sequence(vf_features, state.vf, episode_starts, self.lstm_critic) + latent_vf, state_vf = self.lstm_critic(vf_features, state.vf, episode_starts) return latent_vf, state_vf - def _recurrent_latent_vf_nostate(self, obs: TorchGymObs, state: ActorCriticStates, episode_starts: th.Tensor) -> th.Tensor: + def _recurrent_latent_vf_nostate( + self, obs: TorchGymObs, state: ActorCriticStates[LSTMRecurrentState], episode_starts: th.Tensor + ) -> th.Tensor: vf_features: th.Tensor = super(ActorCriticPolicy, self).extract_features(obs, self.vf_features_extractor) return self._recurrent_latent_vf_from_features(vf_features, state, episode_starts)[0] def forward( # type: ignore[override] self, obs: TorchGymObs, - state: ActorCriticStates, + state: ActorCriticStates[LSTMRecurrentState], episode_starts: th.Tensor, deterministic: bool = False, - ) -> Tuple[th.Tensor, th.Tensor, th.Tensor, ActorCriticStates]: + ) -> Tuple[th.Tensor, th.Tensor, th.Tensor, ActorCriticStates[LSTMRecurrentState]]: (latent_pi, latent_vf), state = self._recurrent_latent_pi_and_vf(obs, state, episode_starts) latent_pi = self.mlp_extractor.forward_actor(latent_pi) latent_vf = self.mlp_extractor.forward_critic(latent_vf) @@ -401,9 +403,9 @@ def forward( # type: ignore[override] def get_distribution( # type: ignore[override] self, obs: TorchGymObs, - state: ActorCriticStates, + state: ActorCriticStates[LSTMRecurrentState], episode_starts: th.Tensor, - ) -> Tuple[Distribution, ActorCriticStates]: + ) -> Tuple[Distribution, ActorCriticStates[LSTMRecurrentState]]: (latent_pi, _), state = self._recurrent_latent_pi_and_vf(obs, state, episode_starts) latent_pi = self.mlp_extractor.forward_actor(latent_pi) return self._get_action_dist_from_latent(latent_pi), state @@ -411,7 +413,7 @@ def get_distribution( # type: ignore[override] def predict_values( # type: ignore[override] self, obs: TorchGymObs, - state: ActorCriticStates, + state: ActorCriticStates[LSTMRecurrentState], episode_starts: th.Tensor, ) -> th.Tensor: latent_vf = self._recurrent_latent_vf_nostate(obs, state, episode_starts) @@ -419,7 +421,7 @@ def predict_values( 
# type: ignore[override] return self.value_net(latent_vf) def evaluate_actions( # type: ignore[override] - self, obs: TorchGymObs, actions: th.Tensor, state: ActorCriticStates, episode_starts: th.Tensor + self, obs: TorchGymObs, actions: th.Tensor, state: ActorCriticStates[LSTMRecurrentState], episode_starts: th.Tensor ) -> Tuple[th.Tensor, th.Tensor, th.Tensor]: (latent_pi, latent_vf), state = self._recurrent_latent_pi_and_vf(obs, state, episode_starts) latent_pi = self.mlp_extractor.forward_actor(latent_pi) @@ -433,10 +435,10 @@ def evaluate_actions( # type: ignore[override] def _predict( # type: ignore[override] self, observation: TorchGymObs, - state: ActorCriticStates, + state: ActorCriticStates[LSTMRecurrentState], episode_starts: th.Tensor, deterministic: bool = False, - ) -> Tuple[th.Tensor, ActorCriticStates]: + ) -> Tuple[th.Tensor, ActorCriticStates[LSTMRecurrentState]]: distribution, state = self.get_distribution(observation, state, episode_starts) return distribution.get_actions(deterministic=deterministic), state From 658a523d1b2bf2729f1ff58acaec36fa601ce1e6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adri=C3=A0=20Garriga-Alonso?= Date: Wed, 11 Oct 2023 19:08:39 -0700 Subject: [PATCH 04/16] introduce checked_cast, test non_null --- .../common/recurrent/policies.py | 3 +-- .../common/recurrent/type_aliases.py | 11 +-------- stable_baselines3/common/type_aliases.py | 18 +++++++++++++++ .../ppo_recurrent/ppo_recurrent.py | 6 ++++- tests/test_type_aliases.py | 23 +++++++++++++++++++ 5 files changed, 48 insertions(+), 13 deletions(-) create mode 100644 tests/test_type_aliases.py diff --git a/stable_baselines3/common/recurrent/policies.py b/stable_baselines3/common/recurrent/policies.py index 693f054a5..46f1c4361 100644 --- a/stable_baselines3/common/recurrent/policies.py +++ b/stable_baselines3/common/recurrent/policies.py @@ -20,7 +20,6 @@ from stable_baselines3.common.recurrent.type_aliases import ( ActorCriticStates, LSTMRecurrentState, - non_null, ) from stable_baselines3.common.torch_layers import ( BaseFeaturesExtractor, @@ -29,7 +28,7 @@ MlpExtractor, NatureCNN, ) -from stable_baselines3.common.type_aliases import Schedule, TorchGymObs +from stable_baselines3.common.type_aliases import Schedule, TorchGymObs, non_null class BaseRecurrentActorCriticPolicy(ActorCriticPolicy, Generic[RecurrentState]): diff --git a/stable_baselines3/common/recurrent/type_aliases.py b/stable_baselines3/common/recurrent/type_aliases.py index 70bb123ad..5aba6acce 100644 --- a/stable_baselines3/common/recurrent/type_aliases.py +++ b/stable_baselines3/common/recurrent/type_aliases.py @@ -1,18 +1,9 @@ -from typing import Generic, Optional, Tuple, TypeVar +from typing import Generic, Tuple, TypeVar import torch as th from stable_baselines3.common.pytree_dataclass import FrozenPyTreeDataclass, TensorTree -T = TypeVar("T") - - -def non_null(v: Optional[T]) -> T: - if v is None: - raise ValueError("Expected a value, got None") - return v - - TensorTreeT = TypeVar("TensorTreeT", bound=TensorTree) diff --git a/stable_baselines3/common/type_aliases.py b/stable_baselines3/common/type_aliases.py index 32b549e20..c578f3ca1 100644 --- a/stable_baselines3/common/type_aliases.py +++ b/stable_baselines3/common/type_aliases.py @@ -12,7 +12,10 @@ Protocol, SupportsFloat, Tuple, + Type, + TypeVar, Union, + get_origin, ) import gymnasium as gym @@ -117,3 +120,18 @@ def device(self) -> th.device: :return: the device on which this predictor lives """ ... 
+ + +T = TypeVar("T") + + +def non_null(v: Optional[T]) -> T: + if v is None: + raise ValueError("Expected a value, got None") + return v + + +def check_cast(cls: Type[T], v: Any) -> T: + if not isinstance(v, get_origin(cls) or cls): + raise TypeError(f"{v} should be of type {cls}") + return v diff --git a/stable_baselines3/ppo_recurrent/ppo_recurrent.py b/stable_baselines3/ppo_recurrent/ppo_recurrent.py index 126d083c2..70cfacb0e 100644 --- a/stable_baselines3/ppo_recurrent/ppo_recurrent.py +++ b/stable_baselines3/ppo_recurrent/ppo_recurrent.py @@ -16,9 +16,13 @@ from stable_baselines3.common.recurrent.type_aliases import ( ActorCriticStates, RecurrentRolloutBufferData, +) +from stable_baselines3.common.type_aliases import ( + GymEnv, + MaybeCallback, + Schedule, non_null, ) -from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule from stable_baselines3.common.utils import ( explained_variance, get_schedule_fn, diff --git a/tests/test_type_aliases.py b/tests/test_type_aliases.py new file mode 100644 index 000000000..7bebc0077 --- /dev/null +++ b/tests/test_type_aliases.py @@ -0,0 +1,23 @@ +import pytest + +from stable_baselines3.common.type_aliases import check_cast, non_null + + +def test_non_null(): + for a in (1, "a", [2]): + assert non_null(a) == a + + with pytest.raises(ValueError): + non_null(None) + + +def test_check_cast(): + assert check_cast(dict, {}) == {} + assert check_cast(dict[str, int], {}) == {} + + with pytest.raises(TypeError): + check_cast(list[int], {}) + + # NOTE: check_cast does not check the template arguments, only the main class. + a: list[str] = ["a"] + assert check_cast(list[int], a) == a From b2e8fa167a3f784aa0c3f2ac8c10b5f3b27a4d0b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adri=C3=A0=20Garriga-Alonso?= Date: Wed, 11 Oct 2023 19:11:25 -0700 Subject: [PATCH 05/16] test Uppercase types too --- tests/test_type_aliases.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/test_type_aliases.py b/tests/test_type_aliases.py index 7bebc0077..6744e9c3d 100644 --- a/tests/test_type_aliases.py +++ b/tests/test_type_aliases.py @@ -1,3 +1,5 @@ +from typing import Dict, List + import pytest from stable_baselines3.common.type_aliases import check_cast, non_null @@ -14,9 +16,11 @@ def test_non_null(): def test_check_cast(): assert check_cast(dict, {}) == {} assert check_cast(dict[str, int], {}) == {} + assert check_cast(Dict[str, int], {}) == {} with pytest.raises(TypeError): check_cast(list[int], {}) + check_cast(List[int], {}) # NOTE: check_cast does not check the template arguments, only the main class. 
a: list[str] = ["a"] From 177cb5a4e31f02dbbd5f075e238bd0619d313c31 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adri=C3=A0=20Garriga-Alonso?= Date: Wed, 11 Oct 2023 19:47:34 -0700 Subject: [PATCH 06/16] Correct class ascendancy --- stable_baselines3/common/recurrent/policies.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stable_baselines3/common/recurrent/policies.py b/stable_baselines3/common/recurrent/policies.py index 693f054a5..5bb1b496b 100644 --- a/stable_baselines3/common/recurrent/policies.py +++ b/stable_baselines3/common/recurrent/policies.py @@ -623,7 +623,7 @@ def __init__( ) -class RecurrentFeaturesExtractorActorCriticPolicy(ActorCriticPolicy, Generic[ExtractorInput, RecurrentState]): +class RecurrentFeaturesExtractorActorCriticPolicy(BaseRecurrentActorCriticPolicy, Generic[ExtractorInput, RecurrentState]): features_extractor: RecurrentFeaturesExtractor[ExtractorInput, RecurrentState] def __init__( From 4361f5cbec064e6f6012428d538ae36fddfa127d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adri=C3=A0=20Garriga-Alonso?= Date: Wed, 11 Oct 2023 20:41:07 -0700 Subject: [PATCH 07/16] Add some tests for the recurrent feature extractors --- .../ppo_recurrent/ppo_recurrent.py | 7 ++- tests/test_lstm.py | 59 +++++++++++++++++-- 2 files changed, 60 insertions(+), 6 deletions(-) diff --git a/stable_baselines3/ppo_recurrent/ppo_recurrent.py b/stable_baselines3/ppo_recurrent/ppo_recurrent.py index 126d083c2..e2e97e676 100644 --- a/stable_baselines3/ppo_recurrent/ppo_recurrent.py +++ b/stable_baselines3/ppo_recurrent/ppo_recurrent.py @@ -12,7 +12,10 @@ from stable_baselines3.common.policies import BasePolicy from stable_baselines3.common.pytree_dataclass import tree_map from stable_baselines3.common.recurrent.buffers import RecurrentRolloutBuffer -from stable_baselines3.common.recurrent.policies import RecurrentActorCriticPolicy +from stable_baselines3.common.recurrent.policies import ( + BaseRecurrentActorCriticPolicy, + RecurrentActorCriticPolicy, +) from stable_baselines3.common.recurrent.type_aliases import ( ActorCriticStates, RecurrentRolloutBufferData, @@ -95,7 +98,7 @@ class RecurrentPPO(OnPolicyAlgorithm): def __init__( self, - policy: Union[str, Type[RecurrentActorCriticPolicy]], + policy: Union[str, Type[BaseRecurrentActorCriticPolicy]], env: Union[GymEnv, str], learning_rate: Union[float, Schedule] = 3e-4, n_steps: int = 128, diff --git a/tests/test_lstm.py b/tests/test_lstm.py index e3f8ebbdd..176b84ba3 100644 --- a/tests/test_lstm.py +++ b/tests/test_lstm.py @@ -13,6 +13,15 @@ from stable_baselines3.common.env_util import make_vec_env from stable_baselines3.common.envs import FakeImageEnv from stable_baselines3.common.evaluation import evaluate_policy +from stable_baselines3.common.recurrent.policies import ( + BaseRecurrentActorCriticPolicy, + RecurrentFeaturesExtractorActorCriticPolicy, +) +from stable_baselines3.common.recurrent.torch_layers import ( + GRUCombinedExtractor, + GRUFlattenExtractor, + GRUNatureCNNExtractor, +) from stable_baselines3.common.vec_env import VecNormalize @@ -102,6 +111,19 @@ def test_cnn(policy_kwargs): model.learn(total_timesteps=32) +def test_cnn_recurrent_extractor(): + model = RecurrentPPO( + RecurrentFeaturesExtractorActorCriticPolicy, + FakeImageEnv(screen_height=40, screen_width=40, n_channels=3), + n_steps=16, + seed=0, + policy_kwargs=dict(features_extractor_class=GRUNatureCNNExtractor, features_extractor_kwargs=dict(features_dim=32)), + n_epochs=2, + ) + + model.learn(total_timesteps=32) + + @pytest.mark.parametrize( 
"policy_kwargs", [ @@ -181,6 +203,20 @@ def test_run_sde(): model.learn(total_timesteps=200) +def test_run_sde_recurrent_extractor(): + model = RecurrentPPO( + RecurrentFeaturesExtractorActorCriticPolicy, + "Pendulum-v1", + n_steps=16, + seed=0, + sde_sample_freq=4, + use_sde=True, + clip_range_vf=0.1, + policy_kwargs=dict(features_extractor_class=GRUFlattenExtractor), + ) + model.learn(total_timesteps=200) + + @pytest.mark.parametrize( "policy_kwargs", [ @@ -206,8 +242,16 @@ def test_dict_obs(policy_kwargs): evaluate_policy(model, env, warn=False) +def test_dict_obs_recurrent_extractor(): + policy_kwargs = dict(features_extractor_class=GRUCombinedExtractor) + env = make_vec_env("CartPole-v1", n_envs=1, wrapper_class=ToDictWrapper) + model = RecurrentPPO(RecurrentFeaturesExtractorActorCriticPolicy, env, n_steps=32, policy_kwargs=policy_kwargs).learn(64) + evaluate_policy(model, env, warn=False) + + @pytest.mark.slow -def test_ppo_lstm_performance(): +@pytest.mark.parametrize("policy", ["MlpLstmPolicy", "GRUFeatureExtractorPolicy"]) +def test_ppo_lstm_performance(policy: str | type[BaseRecurrentActorCriticPolicy]): # env = make_vec_env("CartPole-v1", n_envs=16) def make_env(): env = CartPoleNoVelEnv() @@ -222,8 +266,16 @@ def make_env(): eval_freq=5000 // env.num_envs, ) + if policy == "GRUFeatureExtractorPolicy": + policy = RecurrentFeaturesExtractorActorCriticPolicy + extra_policy_kwargs = dict( + features_extractor_class=GRUFlattenExtractor, features_extractor_kwargs=dict(features_dim=64) + ) + else: + extra_policy_kwargs = dict(lstm_hidden_size=64, enable_critic_lstm=True) + model = RecurrentPPO( - "MlpLstmPolicy", + policy, env, n_steps=128, learning_rate=0.0007, @@ -235,9 +287,8 @@ def make_env(): gae_lambda=0.98, policy_kwargs=dict( net_arch=dict(vf=[64], pi=[]), - lstm_hidden_size=64, ortho_init=False, - enable_critic_lstm=True, + **extra_policy_kwargs, ), ) From 2ad87d5663e69cc8784eb6017593ad403f8475f1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adri=C3=A0=20Garriga-Alonso?= Date: Fri, 13 Oct 2023 09:56:27 -0700 Subject: [PATCH 08/16] Allow th.Tensor as input to tile_images --- stable_baselines3/common/vec_env/base_vec_env.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stable_baselines3/common/vec_env/base_vec_env.py b/stable_baselines3/common/vec_env/base_vec_env.py index cb76e4996..e95a0520e 100644 --- a/stable_baselines3/common/vec_env/base_vec_env.py +++ b/stable_baselines3/common/vec_env/base_vec_env.py @@ -24,7 +24,7 @@ EnvObs = Union[np.ndarray, Dict[str, np.ndarray], Tuple[np.ndarray, ...]] -def tile_images(images_nhwc: Sequence[th.Tensor]) -> th.Tensor: # pragma: no cover +def tile_images(images_nhwc: Sequence[th.Tensor] | th.Tensor) -> th.Tensor: # pragma: no cover """ Tile N images into one big PxQ image (P,Q) are chosen to be as close as possible, and if N From ffccebb2e9b9f672e9c2b1bb7b3fa045da3551e7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adri=C3=A0=20Garriga-Alonso?= Date: Fri, 13 Oct 2023 10:23:12 -0700 Subject: [PATCH 09/16] Add documentation comments --- stable_baselines3/common/type_aliases.py | 8 ++++++++ tests/test_type_aliases.py | 5 ++++- 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/stable_baselines3/common/type_aliases.py b/stable_baselines3/common/type_aliases.py index c578f3ca1..f63ab2202 100644 --- a/stable_baselines3/common/type_aliases.py +++ b/stable_baselines3/common/type_aliases.py @@ -126,12 +126,20 @@ def device(self) -> th.device: def non_null(v: Optional[T]) -> T: + """ + Checks that `v` is not None, and 
returns it. + """ if v is None: raise ValueError("Expected a value, got None") return v def check_cast(cls: Type[T], v: Any) -> T: + """ + Checks that `v` is of type `cls`, and returns it. + + NOTE: this function does not check the template arguments, only the type itself. + """ if not isinstance(v, get_origin(cls) or cls): raise TypeError(f"{v} should be of type {cls}") return v diff --git a/tests/test_type_aliases.py b/tests/test_type_aliases.py index 6744e9c3d..c3ef81700 100644 --- a/tests/test_type_aliases.py +++ b/tests/test_type_aliases.py @@ -23,5 +23,8 @@ def test_check_cast(): check_cast(List[int], {}) # NOTE: check_cast does not check the template arguments, only the main class. + # Tests should give an accurate understanding of how the function works, so we still check for this behavior. a: list[str] = ["a"] - assert check_cast(list[int], a) == a + assert ( + check_cast(list[int], a) == a + ), "If you managed to write code to trigger this assert that's good! We'd like template arguments to be checked." From d6ccd985d074d202f4abda92414f8b25a64f423b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adri=C3=A0=20Garriga-Alonso?= Date: Fri, 13 Oct 2023 10:33:55 -0700 Subject: [PATCH 10/16] specify device for default episode_start --- stable_baselines3/common/recurrent/policies.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stable_baselines3/common/recurrent/policies.py b/stable_baselines3/common/recurrent/policies.py index 5bb1b496b..a2339e90f 100644 --- a/stable_baselines3/common/recurrent/policies.py +++ b/stable_baselines3/common/recurrent/policies.py @@ -165,7 +165,7 @@ def predict( # type: ignore[override] state = self.recurrent_initial_state(n_envs, device=self.device) if episode_start is None: - episode_start = th.zeros(n_envs, dtype=th.bool) + episode_start = th.zeros(n_envs, dtype=th.bool, device=self.device) with th.no_grad(): # Convert to PyTorch tensors From 2e795e98aa7e8d4a30b02f6fa8a2725f717f2600 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adri=C3=A0=20Garriga-Alonso?= Date: Fri, 13 Oct 2023 10:35:30 -0700 Subject: [PATCH 11/16] Delete misleading comment --- stable_baselines3/common/recurrent/policies.py | 1 - 1 file changed, 1 deletion(-) diff --git a/stable_baselines3/common/recurrent/policies.py b/stable_baselines3/common/recurrent/policies.py index a2339e90f..58a9499e5 100644 --- a/stable_baselines3/common/recurrent/policies.py +++ b/stable_baselines3/common/recurrent/policies.py @@ -168,7 +168,6 @@ def predict( # type: ignore[override] episode_start = th.zeros(n_envs, dtype=th.bool, device=self.device) with th.no_grad(): - # Convert to PyTorch tensors actions, state = self._predict(obs, state=state, episode_starts=episode_start, deterministic=deterministic) if isinstance(self.action_space, spaces.Box): From fa3929edaa0bf0d842fb45d16374b6e004ab275d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adri=C3=A0=20Garriga-Alonso?= Date: Fri, 13 Oct 2023 10:37:08 -0700 Subject: [PATCH 12/16] Make limits NaN --- stable_baselines3/common/recurrent/policies.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/stable_baselines3/common/recurrent/policies.py b/stable_baselines3/common/recurrent/policies.py index 58a9499e5..6dc351562 100644 --- a/stable_baselines3/common/recurrent/policies.py +++ b/stable_baselines3/common/recurrent/policies.py @@ -1,4 +1,5 @@ import abc +import math from typing import Any, Dict, Generic, List, Optional, Tuple, Type, Union import torch as th @@ -279,8 +280,10 @@ def __init__( self.lstm_kwargs = lstm_kwargs or {} 
self.shared_lstm = shared_lstm self.enable_critic_lstm = enable_critic_lstm + + LSTM_BOX_LIMIT = math.nan # It does not matter what the limit is, it won't get used. self.lstm_actor = LSTMFlattenExtractor( - spaces.Box(-1e9, 1e9, (self.features_dim,)), + spaces.Box(LSTM_BOX_LIMIT, LSTM_BOX_LIMIT, (self.features_dim,)), features_dim=lstm_hidden_size, num_layers=n_lstm_layers, **self.lstm_kwargs, @@ -307,7 +310,7 @@ def __init__( # Use a separate LSTM for the critic if self.enable_critic_lstm: self.lstm_critic = LSTMFlattenExtractor( - spaces.Box(-1e9, 1e9, (self.features_dim,)), + spaces.Box(LSTM_BOX_LIMIT, LSTM_BOX_LIMIT, (self.features_dim,)), features_dim=lstm_hidden_size, num_layers=n_lstm_layers, **self.lstm_kwargs, From f4303434c2b345323a03252e8744f0df9ecc7355 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adri=C3=A0=20Garriga-Alonso?= Date: Fri, 13 Oct 2023 10:42:22 -0700 Subject: [PATCH 13/16] Less weird way of updating kwargs --- stable_baselines3/common/recurrent/policies.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stable_baselines3/common/recurrent/policies.py b/stable_baselines3/common/recurrent/policies.py index 6dc351562..216606ee3 100644 --- a/stable_baselines3/common/recurrent/policies.py +++ b/stable_baselines3/common/recurrent/policies.py @@ -653,7 +653,7 @@ def __init__( # Automatically deactivate dtype and bounds checks if normalize_images is False and issubclass(features_extractor_class, GRUNatureCNNExtractor): features_extractor_kwargs = features_extractor_kwargs.copy() - features_extractor_kwargs.update(dict(normalized_image=True)) + features_extractor_kwargs["normalized_image"] = True if not issubclass(features_extractor_class, RecurrentFeaturesExtractor): base_features_extractor = features_extractor_class(observation_space, **features_extractor_kwargs) From 750d78e93e7c70c8d7b9b4411942f0c1755db0c8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adri=C3=A0=20Garriga-Alonso?= Date: Fri, 13 Oct 2023 10:45:33 -0700 Subject: [PATCH 14/16] Use kwargs in child inits --- .../common/recurrent/policies.py | 157 +++++++++--------- 1 file changed, 78 insertions(+), 79 deletions(-) diff --git a/stable_baselines3/common/recurrent/policies.py b/stable_baselines3/common/recurrent/policies.py index 216606ee3..19a50a14f 100644 --- a/stable_baselines3/common/recurrent/policies.py +++ b/stable_baselines3/common/recurrent/policies.py @@ -258,23 +258,23 @@ def __init__( ): self.lstm_output_dim = lstm_hidden_size super().__init__( - observation_space, - action_space, - lr_schedule, - net_arch, - activation_fn, - ortho_init, - use_sde, - log_std_init, - full_std, - use_expln, - squash_output, - features_extractor_class, - features_extractor_kwargs, - share_features_extractor, - normalize_images, - optimizer_class, - optimizer_kwargs, + observation_space=observation_space, + action_space=action_space, + lr_schedule=lr_schedule, + net_arch=net_arch, + activation_fn=activation_fn, + ortho_init=ortho_init, + use_sde=use_sde, + log_std_init=log_std_init, + full_std=full_std, + use_expln=use_expln, + squash_output=squash_output, + features_extractor_class=features_extractor_class, + features_extractor_kwargs=features_extractor_kwargs, + share_features_extractor=share_features_extractor, + normalize_images=normalize_images, + optimizer_class=optimizer_class, + optimizer_kwargs=optimizer_kwargs, ) self.lstm_kwargs = lstm_kwargs or {} @@ -510,28 +510,28 @@ def __init__( lstm_kwargs: Optional[Dict[str, Any]] = None, ): super().__init__( - observation_space, - action_space, - lr_schedule, 
- net_arch, - activation_fn, - ortho_init, - use_sde, - log_std_init, - full_std, - use_expln, - squash_output, - features_extractor_class, - features_extractor_kwargs, - share_features_extractor, - normalize_images, - optimizer_class, - optimizer_kwargs, - lstm_hidden_size, - n_lstm_layers, - shared_lstm, - enable_critic_lstm, - lstm_kwargs, + observation_space=observation_space, + action_space=action_space, + lr_schedule=lr_schedule, + net_arch=net_arch, + activation_fn=activation_fn, + ortho_init=ortho_init, + use_sde=use_sde, + log_std_init=log_std_init, + full_std=full_std, + use_expln=use_expln, + squash_output=squash_output, + features_extractor_class=features_extractor_class, + features_extractor_kwargs=features_extractor_kwargs, + share_features_extractor=share_features_extractor, + normalize_images=normalize_images, + optimizer_class=optimizer_class, + optimizer_kwargs=optimizer_kwargs, + lstm_hidden_size=lstm_hidden_size, + n_lstm_layers=n_lstm_layers, + shared_lstm=shared_lstm, + enable_critic_lstm=enable_critic_lstm, + lstm_kwargs=lstm_kwargs, ) @@ -600,28 +600,28 @@ def __init__( lstm_kwargs: Optional[Dict[str, Any]] = None, ): super().__init__( - observation_space, - action_space, - lr_schedule, - net_arch, - activation_fn, - ortho_init, - use_sde, - log_std_init, - full_std, - use_expln, - squash_output, - features_extractor_class, - features_extractor_kwargs, - share_features_extractor, - normalize_images, - optimizer_class, - optimizer_kwargs, - lstm_hidden_size, - n_lstm_layers, - shared_lstm, - enable_critic_lstm, - lstm_kwargs, + observation_space=observation_space, + action_space=action_space, + lr_schedule=lr_schedule, + net_arch=net_arch, + activation_fn=activation_fn, + ortho_init=ortho_init, + use_sde=use_sde, + log_std_init=log_std_init, + full_std=full_std, + use_expln=use_expln, + squash_output=squash_output, + features_extractor_class=features_extractor_class, + features_extractor_kwargs=features_extractor_kwargs, + share_features_extractor=share_features_extractor, + normalize_images=normalize_images, + optimizer_class=optimizer_class, + optimizer_kwargs=optimizer_kwargs, + lstm_hidden_size=lstm_hidden_size, + n_lstm_layers=n_lstm_layers, + shared_lstm=shared_lstm, + enable_critic_lstm=enable_critic_lstm, + lstm_kwargs=lstm_kwargs, ) @@ -663,26 +663,25 @@ def __init__( if "features_dim" in features_extractor_kwargs: new_features_extractor_kwargs["features_dim"] = features_extractor_kwargs["features_dim"] features_extractor_kwargs = new_features_extractor_kwargs - print(features_extractor_class, features_extractor_kwargs) super().__init__( - observation_space, - action_space, - lr_schedule, - net_arch, - activation_fn, - ortho_init, - use_sde, - log_std_init, - full_std, - use_expln, - squash_output, - features_extractor_class, - features_extractor_kwargs, - share_features_extractor, - normalize_images, - optimizer_class, - optimizer_kwargs, + observation_space=observation_space, + action_space=action_space, + lr_schedule=lr_schedule, + net_arch=net_arch, + activation_fn=activation_fn, + ortho_init=ortho_init, + use_sde=use_sde, + log_std_init=log_std_init, + full_std=full_std, + use_expln=use_expln, + squash_output=squash_output, + features_extractor_class=features_extractor_class, + features_extractor_kwargs=features_extractor_kwargs, + share_features_extractor=share_features_extractor, + normalize_images=normalize_images, + optimizer_class=optimizer_class, + optimizer_kwargs=optimizer_kwargs, ) def recurrent_initial_state( From 
53d2f2e14c732194a9e3be8cd85495f0a39808fe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adri=C3=A0=20Garriga-Alonso?= Date: Fri, 13 Oct 2023 11:27:02 -0700 Subject: [PATCH 15/16] Turn shape comment into an assert --- stable_baselines3/common/recurrent/torch_layers.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/stable_baselines3/common/recurrent/torch_layers.py b/stable_baselines3/common/recurrent/torch_layers.py index f79221366..d312a3d92 100644 --- a/stable_baselines3/common/recurrent/torch_layers.py +++ b/stable_baselines3/common/recurrent/torch_layers.py @@ -44,6 +44,7 @@ def _process_sequence( ) -> Tuple[th.Tensor, RecurrentSubState]: (state_example, *_), _ = tree_flatten(init_state, is_leaf=None) n_layers, batch_sz, *_ = state_example.shape + assert n_layers == rnn.num_layers # Batch to sequence # (padded batch size, features_dim) -> (n_seq, max length, features_dim) -> (max length, n_seq, features_dim) @@ -55,8 +56,13 @@ def _process_sequence( raise NotImplementedError("Resetting state in the middle of a sequence is not supported") first_state_is_not_reset = (~episode_starts[0]).contiguous() - # Shape here is (n_layers, batch_sz) - init_state = tree_map(lambda x: x * first_state_is_not_reset.view((1, batch_sz, *(1,) * (x.ndim - 2))), init_state) + + def _reset_state_component(state: th.Tensor) -> th.Tensor: + assert state.shape == (rnn.num_layers, batch_sz, rnn.hidden_size) + reset_mask = first_state_is_not_reset.view((1, batch_sz, 1)) + return state * reset_mask + + init_state = tree_map(_reset_state_component, init_state) rnn_output, end_state = rnn(seq_inputs, init_state) # (seq_len, batch_size, ...) -> (batch_size, seq_len, ...) -> (batch_size * seq_len, ...) From fe52f29df7daad0b3d463469a919053faa3bb3c3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adri=C3=A0=20Garriga-Alonso?= Date: Fri, 13 Oct 2023 11:32:22 -0700 Subject: [PATCH 16/16] Check identity with `is`. --- tests/test_type_aliases.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/tests/test_type_aliases.py b/tests/test_type_aliases.py index c3ef81700..66da5b581 100644 --- a/tests/test_type_aliases.py +++ b/tests/test_type_aliases.py @@ -14,17 +14,18 @@ def test_non_null(): def test_check_cast(): - assert check_cast(dict, {}) == {} - assert check_cast(dict[str, int], {}) == {} - assert check_cast(Dict[str, int], {}) == {} + EMPTY_DICT = {} + assert check_cast(dict, EMPTY_DICT) is EMPTY_DICT + assert check_cast(dict[str, int], EMPTY_DICT) is EMPTY_DICT + assert check_cast(Dict[str, int], EMPTY_DICT) is EMPTY_DICT with pytest.raises(TypeError): - check_cast(list[int], {}) - check_cast(List[int], {}) + check_cast(list[int], EMPTY_DICT) + check_cast(List[int], EMPTY_DICT) # NOTE: check_cast does not check the template arguments, only the main class. # Tests should give an accurate understanding of how the function works, so we still check for this behavior. a: list[str] = ["a"] assert ( - check_cast(list[int], a) == a + check_cast(list[int], a) is a ), "If you managed to write code to trigger this assert that's good! We'd like template arguments to be checked."
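
A minimal usage sketch for the two typing helpers exercised above, assuming
only the non_null/check_cast definitions these patches add (the variable
names are illustrative): None is rejected with a ValueError, only the origin
class is checked, and the input object is returned unchanged, so `is`
comparisons hold.

    from typing import Optional

    from stable_baselines3.common.type_aliases import check_cast, non_null

    maybe_state: Optional[dict] = {"hidden": 0}
    state = non_null(maybe_state)  # narrows Optional[dict] to dict, raises ValueError on None

    # Only the origin class (dict) is verified; the [str, int] arguments are not checked.
    payload = check_cast(dict[str, int], state)
    assert payload is state  # the very same object is returned, not a copy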
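
The new tests in tests/test_lstm.py double as usage documentation for
RecurrentFeaturesExtractorActorCriticPolicy. Distilled into one construction
sketch (the hyperparameters are illustrative, and the RecurrentPPO import
path simply follows this tree's stable_baselines3/ppo_recurrent/ppo_recurrent.py
module): the recurrence lives entirely in the GRU features extractor, so none
of the lstm_hidden_size / enable_critic_lstm keyword arguments of the LSTM
policies are involved.

    from stable_baselines3.common.recurrent.policies import (
        RecurrentFeaturesExtractorActorCriticPolicy,
    )
    from stable_baselines3.common.recurrent.torch_layers import GRUFlattenExtractor
    from stable_baselines3.ppo_recurrent.ppo_recurrent import RecurrentPPO

    model = RecurrentPPO(
        RecurrentFeaturesExtractorActorCriticPolicy,
        "Pendulum-v1",
        n_steps=16,
        seed=0,
        policy_kwargs=dict(
            features_extractor_class=GRUFlattenExtractor,
            features_extractor_kwargs=dict(features_dim=64),
        ),
    )
    model.learn(total_timesteps=200)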
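
PATCH 15 turns the shape comment in _process_sequence into an assert and names
the masking step _reset_state_component. The arithmetic it asserts about is
easiest to see with concrete shapes: broadcasting the (batch,) mask derived
from episode_starts[0] to (1, batch, 1) zeroes the incoming recurrent state of
exactly the environments whose sequence starts a new episode, for every layer
and hidden unit. A self-contained sketch with illustrative sizes (not code
taken from the patch):

    import torch as th

    n_layers, batch_sz, hidden_size = 1, 4, 8
    state = th.randn(n_layers, batch_sz, hidden_size)           # e.g. a GRU hidden state
    episode_starts = th.tensor([[True, False, False, True]])    # (seq_len=1, batch_sz)

    first_state_is_not_reset = ~episode_starts[0]                # (batch_sz,)
    reset_mask = first_state_is_not_reset.view(1, batch_sz, 1)   # broadcasts over layers and features
    masked_state = state * reset_mask

    # Environments 0 and 3 start a new episode: their state is zeroed.
    assert th.all(masked_state[:, 0] == 0) and th.all(masked_state[:, 3] == 0)
    # Environments 1 and 2 keep their state untouched.
    assert th.equal(masked_state[:, 1], state[:, 1]) and th.equal(masked_state[:, 2], state[:, 2])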