From 36f3de60617e6941299c804c6653e3938217af5a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adri=C3=A0=20Garriga-Alonso?= Date: Fri, 15 Sep 2023 14:15:05 -0700 Subject: [PATCH 01/31] Bring over recurrent policies from sb3_contrib --- stable_baselines3/__init__.py | 2 + .../common/recurrent/__init__.py | 0 stable_baselines3/common/recurrent/buffers.py | 384 +++++++++++ .../common/recurrent/policies.py | 611 ++++++++++++++++++ .../common/recurrent/type_aliases.py | 33 + stable_baselines3/ppo_recurrent/__init__.py | 4 + stable_baselines3/ppo_recurrent/policies.py | 9 + .../ppo_recurrent/ppo_recurrent.py | 494 ++++++++++++++ tests/test_cnn.py | 18 +- tests/test_deterministic.py | 10 +- tests/test_dict_env.py | 19 +- tests/test_identity.py | 13 +- tests/test_lstm.py | 248 +++++++ tests/test_save_load.py | 4 +- tests/test_train_eval_mode.py | 14 +- 15 files changed, 1833 insertions(+), 30 deletions(-) create mode 100644 stable_baselines3/common/recurrent/__init__.py create mode 100644 stable_baselines3/common/recurrent/buffers.py create mode 100644 stable_baselines3/common/recurrent/policies.py create mode 100644 stable_baselines3/common/recurrent/type_aliases.py create mode 100644 stable_baselines3/ppo_recurrent/__init__.py create mode 100644 stable_baselines3/ppo_recurrent/policies.py create mode 100644 stable_baselines3/ppo_recurrent/ppo_recurrent.py create mode 100644 tests/test_lstm.py diff --git a/stable_baselines3/__init__.py b/stable_baselines3/__init__.py index 0775a8ec5..4bb05dcd2 100644 --- a/stable_baselines3/__init__.py +++ b/stable_baselines3/__init__.py @@ -6,6 +6,7 @@ from stable_baselines3.dqn import DQN from stable_baselines3.her.her_replay_buffer import HerReplayBuffer from stable_baselines3.ppo import PPO +from stable_baselines3.ppo_recurrent import RecurrentPPO from stable_baselines3.sac import SAC from stable_baselines3.td3 import TD3 @@ -27,6 +28,7 @@ def HER(*args, **kwargs): "DDPG", "DQN", "PPO", + "RecurrentPPO", "SAC", "TD3", "HerReplayBuffer", diff --git a/stable_baselines3/common/recurrent/__init__.py b/stable_baselines3/common/recurrent/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/stable_baselines3/common/recurrent/buffers.py b/stable_baselines3/common/recurrent/buffers.py new file mode 100644 index 000000000..0f1bcef46 --- /dev/null +++ b/stable_baselines3/common/recurrent/buffers.py @@ -0,0 +1,384 @@ +from functools import partial +from typing import Callable, Generator, Optional, Tuple, Union + +import numpy as np +import torch as th +from gymnasium import spaces + +from stable_baselines3.common.buffers import DictRolloutBuffer, RolloutBuffer +from stable_baselines3.common.recurrent.type_aliases import ( + RecurrentDictRolloutBufferSamples, + RecurrentRolloutBufferSamples, + RNNStates, +) +from stable_baselines3.common.vec_env import VecNormalize + + +def pad( + seq_start_indices: np.ndarray, + seq_end_indices: np.ndarray, + device: th.device, + tensor: np.ndarray, + padding_value: float = 0.0, +) -> th.Tensor: + """ + Chunk sequences and pad them to have constant dimensions. 
+ + :param seq_start_indices: Indices of the transitions that start a sequence + :param seq_end_indices: Indices of the transitions that end a sequence + :param device: PyTorch device + :param tensor: Tensor of shape (batch_size, *tensor_shape) + :param padding_value: Value used to pad sequence to the same length + (zero padding by default) + :return: (n_seq, max_length, *tensor_shape) + """ + # Create sequences given start and end + seq = [th.tensor(tensor[start : end + 1], device=device) for start, end in zip(seq_start_indices, seq_end_indices)] + return th.nn.utils.rnn.pad_sequence(seq, batch_first=True, padding_value=padding_value) + + +def pad_and_flatten( + seq_start_indices: np.ndarray, + seq_end_indices: np.ndarray, + device: th.device, + tensor: np.ndarray, + padding_value: float = 0.0, +) -> th.Tensor: + """ + Pad and flatten the sequences of scalar values, + while keeping the sequence order. + From (batch_size, 1) to (n_seq, max_length, 1) -> (n_seq * max_length,) + + :param seq_start_indices: Indices of the transitions that start a sequence + :param seq_end_indices: Indices of the transitions that end a sequence + :param device: PyTorch device (cpu, gpu, ...) + :param tensor: Tensor of shape (max_length, n_seq, 1) + :param padding_value: Value used to pad sequence to the same length + (zero padding by default) + :return: (n_seq * max_length,) aka (padded_batch_size,) + """ + return pad(seq_start_indices, seq_end_indices, device, tensor, padding_value).flatten() + + +def create_sequencers( + episode_starts: np.ndarray, + env_change: np.ndarray, + device: th.device, +) -> Tuple[np.ndarray, Callable, Callable]: + """ + Create the utility function to chunk data into + sequences and pad them to create fixed size tensors. + + :param episode_starts: Indices where an episode starts + :param env_change: Indices where the data collected + come from a different env (when using multiple env for data collection) + :param device: PyTorch device + :return: Indices of the transitions that start a sequence, + pad and pad_and_flatten utilities tailored for this batch + (sequence starts and ends indices are fixed) + """ + # Create sequence if env changes too + seq_start = (episode_starts | env_change).flatten() + # First index is always the beginning of a sequence + seq_start[0] = True + # Retrieve indices of sequence starts + seq_start_indices = np.where(seq_start == True)[0] # noqa: E712 + # End of sequence are just before sequence starts + # Last index is also always end of a sequence + seq_end_indices = np.concatenate([(seq_start_indices - 1)[1:], np.array([len(episode_starts)])]) + + # Create padding method for this minibatch + # to avoid repeating arguments (seq_start_indices, seq_end_indices) + local_pad = partial(pad, seq_start_indices, seq_end_indices, device) + local_pad_and_flatten = partial(pad_and_flatten, seq_start_indices, seq_end_indices, device) + return seq_start_indices, local_pad, local_pad_and_flatten + + +class RecurrentRolloutBuffer(RolloutBuffer): + """ + Rollout buffer that also stores the LSTM cell and hidden states. + + :param buffer_size: Max number of element in the buffer + :param observation_space: Observation space + :param action_space: Action space + :param hidden_state_shape: Shape of the buffer that will collect lstm states + (n_steps, lstm.num_layers, n_envs, lstm.hidden_size) + :param device: PyTorch device + :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator + Equivalent to classic advantage when set to 1. 
+ :param gamma: Discount factor + :param n_envs: Number of parallel environments + """ + + def __init__( + self, + buffer_size: int, + observation_space: spaces.Space, + action_space: spaces.Space, + hidden_state_shape: Tuple[int, int, int, int], + device: Union[th.device, str] = "auto", + gae_lambda: float = 1, + gamma: float = 0.99, + n_envs: int = 1, + ): + self.hidden_state_shape = hidden_state_shape + self.seq_start_indices, self.seq_end_indices = None, None + super().__init__(buffer_size, observation_space, action_space, device, gae_lambda, gamma, n_envs) + + def reset(self): + super().reset() + self.hidden_states_pi = np.zeros(self.hidden_state_shape, dtype=np.float32) + self.cell_states_pi = np.zeros(self.hidden_state_shape, dtype=np.float32) + self.hidden_states_vf = np.zeros(self.hidden_state_shape, dtype=np.float32) + self.cell_states_vf = np.zeros(self.hidden_state_shape, dtype=np.float32) + + def add(self, *args, lstm_states: RNNStates, **kwargs) -> None: + """ + :param hidden_states: LSTM cell and hidden state + """ + self.hidden_states_pi[self.pos] = np.array(lstm_states.pi[0].cpu().numpy()) + self.cell_states_pi[self.pos] = np.array(lstm_states.pi[1].cpu().numpy()) + self.hidden_states_vf[self.pos] = np.array(lstm_states.vf[0].cpu().numpy()) + self.cell_states_vf[self.pos] = np.array(lstm_states.vf[1].cpu().numpy()) + + super().add(*args, **kwargs) + + def get(self, batch_size: Optional[int] = None) -> Generator[RecurrentRolloutBufferSamples, None, None]: + assert self.full, "Rollout buffer must be full before sampling from it" + + # Prepare the data + if not self.generator_ready: + # hidden_state_shape = (self.n_steps, lstm.num_layers, self.n_envs, lstm.hidden_size) + # swap first to (self.n_steps, self.n_envs, lstm.num_layers, lstm.hidden_size) + for tensor in ["hidden_states_pi", "cell_states_pi", "hidden_states_vf", "cell_states_vf"]: + self.__dict__[tensor] = self.__dict__[tensor].swapaxes(1, 2) + + # flatten but keep the sequence order + # 1. (n_steps, n_envs, *tensor_shape) -> (n_envs, n_steps, *tensor_shape) + # 2. 
(n_envs, n_steps, *tensor_shape) -> (n_envs * n_steps, *tensor_shape) + for tensor in [ + "observations", + "actions", + "values", + "log_probs", + "advantages", + "returns", + "hidden_states_pi", + "cell_states_pi", + "hidden_states_vf", + "cell_states_vf", + "episode_starts", + ]: + self.__dict__[tensor] = self.swap_and_flatten(self.__dict__[tensor]) + self.generator_ready = True + + # Return everything, don't create minibatches + if batch_size is None: + batch_size = self.buffer_size * self.n_envs + + # Sampling strategy that allows any mini batch size but requires + # more complexity and use of padding + # Trick to shuffle a bit: keep the sequence order + # but split the indices in two + split_index = np.random.randint(self.buffer_size * self.n_envs) + indices = np.arange(self.buffer_size * self.n_envs) + indices = np.concatenate((indices[split_index:], indices[:split_index])) + + env_change = np.zeros(self.buffer_size * self.n_envs).reshape(self.buffer_size, self.n_envs) + # Flag first timestep as change of environment + env_change[0, :] = 1.0 + env_change = self.swap_and_flatten(env_change) + + start_idx = 0 + while start_idx < self.buffer_size * self.n_envs: + batch_inds = indices[start_idx : start_idx + batch_size] + yield self._get_samples(batch_inds, env_change) + start_idx += batch_size + + def _get_samples( + self, + batch_inds: np.ndarray, + env_change: np.ndarray, + env: Optional[VecNormalize] = None, + ) -> RecurrentRolloutBufferSamples: + # Retrieve sequence starts and utility function + self.seq_start_indices, self.pad, self.pad_and_flatten = create_sequencers( + self.episode_starts[batch_inds], env_change[batch_inds], self.device + ) + + # Number of sequences + n_seq = len(self.seq_start_indices) + max_length = self.pad(self.actions[batch_inds]).shape[1] + padded_batch_size = n_seq * max_length + # We retrieve the lstm hidden states that will allow + # to properly initialize the LSTM at the beginning of each sequence + lstm_states_pi = ( + # 1. (n_envs * n_steps, n_layers, dim) -> (batch_size, n_layers, dim) + # 2. (batch_size, n_layers, dim) -> (n_seq, n_layers, dim) + # 3. 
(n_seq, n_layers, dim) -> (n_layers, n_seq, dim) + self.hidden_states_pi[batch_inds][self.seq_start_indices].swapaxes(0, 1), + self.cell_states_pi[batch_inds][self.seq_start_indices].swapaxes(0, 1), + ) + lstm_states_vf = ( + # (n_envs * n_steps, n_layers, dim) -> (n_layers, n_seq, dim) + self.hidden_states_vf[batch_inds][self.seq_start_indices].swapaxes(0, 1), + self.cell_states_vf[batch_inds][self.seq_start_indices].swapaxes(0, 1), + ) + lstm_states_pi = (self.to_torch(lstm_states_pi[0]).contiguous(), self.to_torch(lstm_states_pi[1]).contiguous()) + lstm_states_vf = (self.to_torch(lstm_states_vf[0]).contiguous(), self.to_torch(lstm_states_vf[1]).contiguous()) + + return RecurrentRolloutBufferSamples( + # (batch_size, obs_dim) -> (n_seq, max_length, obs_dim) -> (n_seq * max_length, obs_dim) + observations=self.pad(self.observations[batch_inds]).reshape((padded_batch_size, *self.obs_shape)), + actions=self.pad(self.actions[batch_inds]).reshape((padded_batch_size,) + self.actions.shape[1:]), + old_values=self.pad_and_flatten(self.values[batch_inds]), + old_log_prob=self.pad_and_flatten(self.log_probs[batch_inds]), + advantages=self.pad_and_flatten(self.advantages[batch_inds]), + returns=self.pad_and_flatten(self.returns[batch_inds]), + lstm_states=RNNStates(lstm_states_pi, lstm_states_vf), + episode_starts=self.pad_and_flatten(self.episode_starts[batch_inds]), + mask=self.pad_and_flatten(np.ones_like(self.returns[batch_inds])), + ) + + +class RecurrentDictRolloutBuffer(DictRolloutBuffer): + """ + Dict Rollout buffer used in on-policy algorithms like A2C/PPO. + Extends the RecurrentRolloutBuffer to use dictionary observations + + :param buffer_size: Max number of element in the buffer + :param observation_space: Observation space + :param action_space: Action space + :param hidden_state_shape: Shape of the buffer that will collect lstm states + :param device: PyTorch device + :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator + Equivalent to classic advantage when set to 1. 
+ :param gamma: Discount factor + :param n_envs: Number of parallel environments + """ + + def __init__( + self, + buffer_size: int, + observation_space: spaces.Space, + action_space: spaces.Space, + hidden_state_shape: Tuple[int, int, int, int], + device: Union[th.device, str] = "auto", + gae_lambda: float = 1, + gamma: float = 0.99, + n_envs: int = 1, + ): + self.hidden_state_shape = hidden_state_shape + self.seq_start_indices, self.seq_end_indices = None, None + super().__init__(buffer_size, observation_space, action_space, device, gae_lambda, gamma, n_envs=n_envs) + + def reset(self): + super().reset() + self.hidden_states_pi = np.zeros(self.hidden_state_shape, dtype=np.float32) + self.cell_states_pi = np.zeros(self.hidden_state_shape, dtype=np.float32) + self.hidden_states_vf = np.zeros(self.hidden_state_shape, dtype=np.float32) + self.cell_states_vf = np.zeros(self.hidden_state_shape, dtype=np.float32) + + def add(self, *args, lstm_states: RNNStates, **kwargs) -> None: + """ + :param hidden_states: LSTM cell and hidden state + """ + self.hidden_states_pi[self.pos] = np.array(lstm_states.pi[0].cpu().numpy()) + self.cell_states_pi[self.pos] = np.array(lstm_states.pi[1].cpu().numpy()) + self.hidden_states_vf[self.pos] = np.array(lstm_states.vf[0].cpu().numpy()) + self.cell_states_vf[self.pos] = np.array(lstm_states.vf[1].cpu().numpy()) + + super().add(*args, **kwargs) + + def get(self, batch_size: Optional[int] = None) -> Generator[RecurrentDictRolloutBufferSamples, None, None]: + assert self.full, "Rollout buffer must be full before sampling from it" + + # Prepare the data + if not self.generator_ready: + # hidden_state_shape = (self.n_steps, lstm.num_layers, self.n_envs, lstm.hidden_size) + # swap first to (self.n_steps, self.n_envs, lstm.num_layers, lstm.hidden_size) + for tensor in ["hidden_states_pi", "cell_states_pi", "hidden_states_vf", "cell_states_vf"]: + self.__dict__[tensor] = self.__dict__[tensor].swapaxes(1, 2) + + for key, obs in self.observations.items(): + self.observations[key] = self.swap_and_flatten(obs) + + for tensor in [ + "actions", + "values", + "log_probs", + "advantages", + "returns", + "hidden_states_pi", + "cell_states_pi", + "hidden_states_vf", + "cell_states_vf", + "episode_starts", + ]: + self.__dict__[tensor] = self.swap_and_flatten(self.__dict__[tensor]) + self.generator_ready = True + + # Return everything, don't create minibatches + if batch_size is None: + batch_size = self.buffer_size * self.n_envs + + # Trick to shuffle a bit: keep the sequence order + # but split the indices in two + split_index = np.random.randint(self.buffer_size * self.n_envs) + indices = np.arange(self.buffer_size * self.n_envs) + indices = np.concatenate((indices[split_index:], indices[:split_index])) + + env_change = np.zeros(self.buffer_size * self.n_envs).reshape(self.buffer_size, self.n_envs) + # Flag first timestep as change of environment + env_change[0, :] = 1.0 + env_change = self.swap_and_flatten(env_change) + + start_idx = 0 + while start_idx < self.buffer_size * self.n_envs: + batch_inds = indices[start_idx : start_idx + batch_size] + yield self._get_samples(batch_inds, env_change) + start_idx += batch_size + + def _get_samples( + self, + batch_inds: np.ndarray, + env_change: np.ndarray, + env: Optional[VecNormalize] = None, + ) -> RecurrentDictRolloutBufferSamples: + # Retrieve sequence starts and utility function + self.seq_start_indices, self.pad, self.pad_and_flatten = create_sequencers( + self.episode_starts[batch_inds], env_change[batch_inds], self.device + ) 
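+ # create_sequencers returns the start index of every sequence in this minibatch
+ # (an episode boundary or an env change starts a new sequence) together with the
+ # pad / pad_and_flatten helpers used below to turn the flat batch arrays into
+ # (n_seq, max_length, ...) padded tensors.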
+ + n_seq = len(self.seq_start_indices) + max_length = self.pad(self.actions[batch_inds]).shape[1] + padded_batch_size = n_seq * max_length + # We retrieve the lstm hidden states that will allow + # to properly initialize the LSTM at the beginning of each sequence + lstm_states_pi = ( + # (n_envs * n_steps, n_layers, dim) -> (n_layers, n_seq, dim) + self.hidden_states_pi[batch_inds][self.seq_start_indices].swapaxes(0, 1), + self.cell_states_pi[batch_inds][self.seq_start_indices].swapaxes(0, 1), + ) + lstm_states_vf = ( + # (n_envs * n_steps, n_layers, dim) -> (n_layers, n_seq, dim) + self.hidden_states_vf[batch_inds][self.seq_start_indices].swapaxes(0, 1), + self.cell_states_vf[batch_inds][self.seq_start_indices].swapaxes(0, 1), + ) + lstm_states_pi = (self.to_torch(lstm_states_pi[0]).contiguous(), self.to_torch(lstm_states_pi[1]).contiguous()) + lstm_states_vf = (self.to_torch(lstm_states_vf[0]).contiguous(), self.to_torch(lstm_states_vf[1]).contiguous()) + + observations = {key: self.pad(obs[batch_inds]) for (key, obs) in self.observations.items()} + observations = {key: obs.reshape((padded_batch_size,) + self.obs_shape[key]) for (key, obs) in observations.items()} + + return RecurrentDictRolloutBufferSamples( + observations=observations, + actions=self.pad(self.actions[batch_inds]).reshape((padded_batch_size,) + self.actions.shape[1:]), + old_values=self.pad_and_flatten(self.values[batch_inds]), + old_log_prob=self.pad_and_flatten(self.log_probs[batch_inds]), + advantages=self.pad_and_flatten(self.advantages[batch_inds]), + returns=self.pad_and_flatten(self.returns[batch_inds]), + lstm_states=RNNStates(lstm_states_pi, lstm_states_vf), + episode_starts=self.pad_and_flatten(self.episode_starts[batch_inds]), + mask=self.pad_and_flatten(np.ones_like(self.returns[batch_inds])), + ) diff --git a/stable_baselines3/common/recurrent/policies.py b/stable_baselines3/common/recurrent/policies.py new file mode 100644 index 000000000..5e5090b2e --- /dev/null +++ b/stable_baselines3/common/recurrent/policies.py @@ -0,0 +1,611 @@ +from typing import Any, Dict, List, Optional, Tuple, Type, Union + +import numpy as np +import torch as th +from gymnasium import spaces +from stable_baselines3.common.distributions import Distribution +from stable_baselines3.common.policies import ActorCriticPolicy +from stable_baselines3.common.torch_layers import ( + BaseFeaturesExtractor, + CombinedExtractor, + FlattenExtractor, + MlpExtractor, + NatureCNN, +) +from stable_baselines3.common.type_aliases import Schedule +from stable_baselines3.common.utils import zip_strict +from torch import nn + +from sb3_contrib.common.recurrent.type_aliases import RNNStates + + +class RecurrentActorCriticPolicy(ActorCriticPolicy): + """ + Recurrent policy class for actor-critic algorithms (has both policy and value prediction). + To be used with A2C, PPO and the likes. + It assumes that both the actor and the critic LSTM + have the same architecture. + + :param observation_space: Observation space + :param action_space: Action space + :param lr_schedule: Learning rate schedule (could be constant) + :param net_arch: The specification of the policy and value networks. 
+ :param activation_fn: Activation function + :param ortho_init: Whether to use or not orthogonal initialization + :param use_sde: Whether to use State Dependent Exploration or not + :param log_std_init: Initial value for the log standard deviation + :param full_std: Whether to use (n_features x n_actions) parameters + for the std instead of only (n_features,) when using gSDE + :param use_expln: Use ``expln()`` function instead of ``exp()`` to ensure + a positive standard deviation (cf paper). It allows to keep variance + above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough. + :param squash_output: Whether to squash the output using a tanh function, + this allows to ensure boundaries when using gSDE. + :param features_extractor_class: Features extractor to use. + :param features_extractor_kwargs: Keyword arguments + to pass to the features extractor. + :param share_features_extractor: If True, the features extractor is shared between the policy and value networks. + :param normalize_images: Whether to normalize images or not, + dividing by 255.0 (True by default) + :param optimizer_class: The optimizer to use, + ``th.optim.Adam`` by default + :param optimizer_kwargs: Additional keyword arguments, + excluding the learning rate, to pass to the optimizer + :param lstm_hidden_size: Number of hidden units for each LSTM layer. + :param n_lstm_layers: Number of LSTM layers. + :param shared_lstm: Whether the LSTM is shared between the actor and the critic + (in that case, only the actor gradient is used) + By default, the actor and the critic have two separate LSTM. + :param enable_critic_lstm: Use a seperate LSTM for the critic. + :param lstm_kwargs: Additional keyword arguments to pass the the LSTM + constructor. + """ + + def __init__( + self, + observation_space: spaces.Space, + action_space: spaces.Space, + lr_schedule: Schedule, + net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None, + activation_fn: Type[nn.Module] = nn.Tanh, + ortho_init: bool = True, + use_sde: bool = False, + log_std_init: float = 0.0, + full_std: bool = True, + use_expln: bool = False, + squash_output: bool = False, + features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor, + features_extractor_kwargs: Optional[Dict[str, Any]] = None, + share_features_extractor: bool = True, + normalize_images: bool = True, + optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, + optimizer_kwargs: Optional[Dict[str, Any]] = None, + lstm_hidden_size: int = 256, + n_lstm_layers: int = 1, + shared_lstm: bool = False, + enable_critic_lstm: bool = True, + lstm_kwargs: Optional[Dict[str, Any]] = None, + ): + self.lstm_output_dim = lstm_hidden_size + super().__init__( + observation_space, + action_space, + lr_schedule, + net_arch, + activation_fn, + ortho_init, + use_sde, + log_std_init, + full_std, + use_expln, + squash_output, + features_extractor_class, + features_extractor_kwargs, + share_features_extractor, + normalize_images, + optimizer_class, + optimizer_kwargs, + ) + + self.lstm_kwargs = lstm_kwargs or {} + self.shared_lstm = shared_lstm + self.enable_critic_lstm = enable_critic_lstm + self.lstm_actor = nn.LSTM( + self.features_dim, + lstm_hidden_size, + num_layers=n_lstm_layers, + **self.lstm_kwargs, + ) + # For the predict() method, to initialize hidden states + # (n_lstm_layers, batch_size, lstm_hidden_size) + self.lstm_hidden_state_shape = (n_lstm_layers, 1, lstm_hidden_size) + self.critic = None + self.lstm_critic = None + assert not ( + self.shared_lstm and 
self.enable_critic_lstm + ), "You must choose between shared LSTM, seperate or no LSTM for the critic." + + assert not ( + self.shared_lstm and not self.share_features_extractor + ), "If the features extractor is not shared, the LSTM cannot be shared." + + # No LSTM for the critic, we still need to convert + # output of features extractor to the correct size + # (size of the output of the actor lstm) + if not (self.shared_lstm or self.enable_critic_lstm): + self.critic = nn.Linear(self.features_dim, lstm_hidden_size) + + # Use a separate LSTM for the critic + if self.enable_critic_lstm: + self.lstm_critic = nn.LSTM( + self.features_dim, + lstm_hidden_size, + num_layers=n_lstm_layers, + **self.lstm_kwargs, + ) + + # Setup optimizer with initial learning rate + self.optimizer = self.optimizer_class(self.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs) + + def _build_mlp_extractor(self) -> None: + """ + Create the policy and value networks. + Part of the layers can be shared. + """ + self.mlp_extractor = MlpExtractor( + self.lstm_output_dim, + net_arch=self.net_arch, + activation_fn=self.activation_fn, + device=self.device, + ) + + @staticmethod + def _process_sequence( + features: th.Tensor, + lstm_states: Tuple[th.Tensor, th.Tensor], + episode_starts: th.Tensor, + lstm: nn.LSTM, + ) -> Tuple[th.Tensor, th.Tensor]: + """ + Do a forward pass in the LSTM network. + + :param features: Input tensor + :param lstm_states: previous cell and hidden states of the LSTM + :param episode_starts: Indicates when a new episode starts, + in that case, we need to reset LSTM states. + :param lstm: LSTM object. + :return: LSTM output and updated LSTM states. + """ + # LSTM logic + # (sequence length, batch size, features dim) + # (batch size = n_envs for data collection or n_seq when doing gradient update) + n_seq = lstm_states[0].shape[1] + # Batch to sequence + # (padded batch size, features_dim) -> (n_seq, max length, features_dim) -> (max length, n_seq, features_dim) + # note: max length (max sequence length) is always 1 during data collection + features_sequence = features.reshape((n_seq, -1, lstm.input_size)).swapaxes(0, 1) + episode_starts = episode_starts.reshape((n_seq, -1)).swapaxes(0, 1) + + # If we don't have to reset the state in the middle of a sequence + # we can avoid the for loop, which speeds up things + if th.all(episode_starts == 0.0): + lstm_output, lstm_states = lstm(features_sequence, lstm_states) + lstm_output = th.flatten(lstm_output.transpose(0, 1), start_dim=0, end_dim=1) + return lstm_output, lstm_states + + lstm_output = [] + # Iterate over the sequence + for features, episode_start in zip_strict(features_sequence, episode_starts): + hidden, lstm_states = lstm( + features.unsqueeze(dim=0), + ( + # Reset the states at the beginning of a new episode + (1.0 - episode_start).view(1, n_seq, 1) * lstm_states[0], + (1.0 - episode_start).view(1, n_seq, 1) * lstm_states[1], + ), + ) + lstm_output += [hidden] + # Sequence to batch + # (sequence length, n_seq, lstm_out_dim) -> (batch_size, lstm_out_dim) + lstm_output = th.flatten(th.cat(lstm_output).transpose(0, 1), start_dim=0, end_dim=1) + return lstm_output, lstm_states + + def forward( + self, + obs: th.Tensor, + lstm_states: RNNStates, + episode_starts: th.Tensor, + deterministic: bool = False, + ) -> Tuple[th.Tensor, th.Tensor, th.Tensor, RNNStates]: + """ + Forward pass in all the networks (actor and critic) + + :param obs: Observation. Observation + :param lstm_states: The last hidden and memory states for the LSTM. 
+ :param episode_starts: Whether the observations correspond to new episodes + or not (we reset the lstm states in that case). + :param deterministic: Whether to sample or use deterministic actions + :return: action, value and log probability of the action + """ + # Preprocess the observation if needed + features = self.extract_features(obs) + if self.share_features_extractor: + pi_features = vf_features = features # alis + else: + pi_features, vf_features = features + # latent_pi, latent_vf = self.mlp_extractor(features) + latent_pi, lstm_states_pi = self._process_sequence(pi_features, lstm_states.pi, episode_starts, self.lstm_actor) + if self.lstm_critic is not None: + latent_vf, lstm_states_vf = self._process_sequence(vf_features, lstm_states.vf, episode_starts, self.lstm_critic) + elif self.shared_lstm: + # Re-use LSTM features but do not backpropagate + latent_vf = latent_pi.detach() + lstm_states_vf = (lstm_states_pi[0].detach(), lstm_states_pi[1].detach()) + else: + # Critic only has a feedforward network + latent_vf = self.critic(vf_features) + lstm_states_vf = lstm_states_pi + + latent_pi = self.mlp_extractor.forward_actor(latent_pi) + latent_vf = self.mlp_extractor.forward_critic(latent_vf) + + # Evaluate the values for the given observations + values = self.value_net(latent_vf) + distribution = self._get_action_dist_from_latent(latent_pi) + actions = distribution.get_actions(deterministic=deterministic) + log_prob = distribution.log_prob(actions) + return actions, values, log_prob, RNNStates(lstm_states_pi, lstm_states_vf) + + def get_distribution( + self, + obs: th.Tensor, + lstm_states: Tuple[th.Tensor, th.Tensor], + episode_starts: th.Tensor, + ) -> Tuple[Distribution, Tuple[th.Tensor, ...]]: + """ + Get the current policy distribution given the observations. + + :param obs: Observation. + :param lstm_states: The last hidden and memory states for the LSTM. + :param episode_starts: Whether the observations correspond to new episodes + or not (we reset the lstm states in that case). + :return: the action distribution and new hidden states. + """ + # Call the method from the parent of the parent class + features = super(ActorCriticPolicy, self).extract_features(obs, self.pi_features_extractor) + latent_pi, lstm_states = self._process_sequence(features, lstm_states, episode_starts, self.lstm_actor) + latent_pi = self.mlp_extractor.forward_actor(latent_pi) + return self._get_action_dist_from_latent(latent_pi), lstm_states + + def predict_values( + self, + obs: th.Tensor, + lstm_states: Tuple[th.Tensor, th.Tensor], + episode_starts: th.Tensor, + ) -> th.Tensor: + """ + Get the estimated values according to the current policy given the observations. + + :param obs: Observation. + :param lstm_states: The last hidden and memory states for the LSTM. + :param episode_starts: Whether the observations correspond to new episodes + or not (we reset the lstm states in that case). + :return: the estimated values. 
+ """ + # Call the method from the parent of the parent class + features = super(ActorCriticPolicy, self).extract_features(obs, self.vf_features_extractor) + + if self.lstm_critic is not None: + latent_vf, lstm_states_vf = self._process_sequence(features, lstm_states, episode_starts, self.lstm_critic) + elif self.shared_lstm: + # Use LSTM from the actor + latent_pi, _ = self._process_sequence(features, lstm_states, episode_starts, self.lstm_actor) + latent_vf = latent_pi.detach() + else: + latent_vf = self.critic(features) + + latent_vf = self.mlp_extractor.forward_critic(latent_vf) + return self.value_net(latent_vf) + + def evaluate_actions( + self, obs: th.Tensor, actions: th.Tensor, lstm_states: RNNStates, episode_starts: th.Tensor + ) -> Tuple[th.Tensor, th.Tensor, th.Tensor]: + """ + Evaluate actions according to the current policy, + given the observations. + + :param obs: Observation. + :param actions: + :param lstm_states: The last hidden and memory states for the LSTM. + :param episode_starts: Whether the observations correspond to new episodes + or not (we reset the lstm states in that case). + :return: estimated value, log likelihood of taking those actions + and entropy of the action distribution. + """ + # Preprocess the observation if needed + features = self.extract_features(obs) + if self.share_features_extractor: + pi_features = vf_features = features # alias + else: + pi_features, vf_features = features + latent_pi, _ = self._process_sequence(pi_features, lstm_states.pi, episode_starts, self.lstm_actor) + if self.lstm_critic is not None: + latent_vf, _ = self._process_sequence(vf_features, lstm_states.vf, episode_starts, self.lstm_critic) + elif self.shared_lstm: + latent_vf = latent_pi.detach() + else: + latent_vf = self.critic(vf_features) + + latent_pi = self.mlp_extractor.forward_actor(latent_pi) + latent_vf = self.mlp_extractor.forward_critic(latent_vf) + + distribution = self._get_action_dist_from_latent(latent_pi) + log_prob = distribution.log_prob(actions) + values = self.value_net(latent_vf) + return values, log_prob, distribution.entropy() + + def _predict( + self, + observation: th.Tensor, + lstm_states: Tuple[th.Tensor, th.Tensor], + episode_starts: th.Tensor, + deterministic: bool = False, + ) -> Tuple[th.Tensor, Tuple[th.Tensor, ...]]: + """ + Get the action according to the policy for a given observation. + + :param observation: + :param lstm_states: The last hidden and memory states for the LSTM. + :param episode_starts: Whether the observations correspond to new episodes + or not (we reset the lstm states in that case). + :param deterministic: Whether to use stochastic or deterministic actions + :return: Taken action according to the policy and hidden states of the RNN + """ + distribution, lstm_states = self.get_distribution(observation, lstm_states, episode_starts) + return distribution.get_actions(deterministic=deterministic), lstm_states + + def predict( + self, + observation: Union[np.ndarray, Dict[str, np.ndarray]], + state: Optional[Tuple[np.ndarray, ...]] = None, + episode_start: Optional[np.ndarray] = None, + deterministic: bool = False, + ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]: + """ + Get the policy action from an observation (and optional hidden state). + Includes sugar-coating to handle different observations (e.g. normalizing images). + + :param observation: the input observation + :param lstm_states: The last hidden and memory states for the LSTM. 
+ :param episode_starts: Whether the observations correspond to new episodes + or not (we reset the lstm states in that case). + :param deterministic: Whether or not to return deterministic actions. + :return: the model's action and the next hidden state + (used in recurrent policies) + """ + # Switch to eval mode (this affects batch norm / dropout) + self.set_training_mode(False) + + observation, vectorized_env = self.obs_to_tensor(observation) + + if isinstance(observation, dict): + n_envs = observation[next(iter(observation.keys()))].shape[0] + else: + n_envs = observation.shape[0] + # state : (n_layers, n_envs, dim) + if state is None: + # Initialize hidden states to zeros + state = np.concatenate([np.zeros(self.lstm_hidden_state_shape) for _ in range(n_envs)], axis=1) + state = (state, state) + + if episode_start is None: + episode_start = np.array([False for _ in range(n_envs)]) + + with th.no_grad(): + # Convert to PyTorch tensors + states = th.tensor(state[0], dtype=th.float32, device=self.device), th.tensor( + state[1], dtype=th.float32, device=self.device + ) + episode_starts = th.tensor(episode_start, dtype=th.float32, device=self.device) + actions, states = self._predict( + observation, lstm_states=states, episode_starts=episode_starts, deterministic=deterministic + ) + states = (states[0].cpu().numpy(), states[1].cpu().numpy()) + + # Convert to numpy + actions = actions.cpu().numpy() + + if isinstance(self.action_space, spaces.Box): + if self.squash_output: + # Rescale to proper domain when using squashing + actions = self.unscale_action(actions) + else: + # Actions could be on arbitrary scale, so clip the actions to avoid + # out of bound error (e.g. if sampling from a Gaussian distribution) + actions = np.clip(actions, self.action_space.low, self.action_space.high) + + # Remove batch dimension if needed + if not vectorized_env: + actions = actions.squeeze(axis=0) + + return actions, states + + +class RecurrentActorCriticCnnPolicy(RecurrentActorCriticPolicy): + """ + CNN recurrent policy class for actor-critic algorithms (has both policy and value prediction). + Used by A2C, PPO and the likes. + + :param observation_space: Observation space + :param action_space: Action space + :param lr_schedule: Learning rate schedule (could be constant) + :param net_arch: The specification of the policy and value networks. + :param activation_fn: Activation function + :param ortho_init: Whether to use or not orthogonal initialization + :param use_sde: Whether to use State Dependent Exploration or not + :param log_std_init: Initial value for the log standard deviation + :param full_std: Whether to use (n_features x n_actions) parameters + for the std instead of only (n_features,) when using gSDE + :param use_expln: Use ``expln()`` function instead of ``exp()`` to ensure + a positive standard deviation (cf paper). It allows to keep variance + above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough. + :param squash_output: Whether to squash the output using a tanh function, + this allows to ensure boundaries when using gSDE. + :param features_extractor_class: Features extractor to use. + :param features_extractor_kwargs: Keyword arguments + to pass to the features extractor. + :param share_features_extractor: If True, the features extractor is shared between the policy and value networks. 
+ :param normalize_images: Whether to normalize images or not, + dividing by 255.0 (True by default) + :param optimizer_class: The optimizer to use, + ``th.optim.Adam`` by default + :param optimizer_kwargs: Additional keyword arguments, + excluding the learning rate, to pass to the optimizer + :param lstm_hidden_size: Number of hidden units for each LSTM layer. + :param n_lstm_layers: Number of LSTM layers. + :param shared_lstm: Whether the LSTM is shared between the actor and the critic. + By default, only the actor has a recurrent network. + :param enable_critic_lstm: Use a seperate LSTM for the critic. + :param lstm_kwargs: Additional keyword arguments to pass the the LSTM + constructor. + """ + + def __init__( + self, + observation_space: spaces.Space, + action_space: spaces.Space, + lr_schedule: Schedule, + net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None, + activation_fn: Type[nn.Module] = nn.Tanh, + ortho_init: bool = True, + use_sde: bool = False, + log_std_init: float = 0.0, + full_std: bool = True, + use_expln: bool = False, + squash_output: bool = False, + features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN, + features_extractor_kwargs: Optional[Dict[str, Any]] = None, + share_features_extractor: bool = True, + normalize_images: bool = True, + optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, + optimizer_kwargs: Optional[Dict[str, Any]] = None, + lstm_hidden_size: int = 256, + n_lstm_layers: int = 1, + shared_lstm: bool = False, + enable_critic_lstm: bool = True, + lstm_kwargs: Optional[Dict[str, Any]] = None, + ): + super().__init__( + observation_space, + action_space, + lr_schedule, + net_arch, + activation_fn, + ortho_init, + use_sde, + log_std_init, + full_std, + use_expln, + squash_output, + features_extractor_class, + features_extractor_kwargs, + share_features_extractor, + normalize_images, + optimizer_class, + optimizer_kwargs, + lstm_hidden_size, + n_lstm_layers, + shared_lstm, + enable_critic_lstm, + lstm_kwargs, + ) + + +class RecurrentMultiInputActorCriticPolicy(RecurrentActorCriticPolicy): + """ + MultiInputActorClass policy class for actor-critic algorithms (has both policy and value prediction). + Used by A2C, PPO and the likes. + + :param observation_space: Observation space + :param action_space: Action space + :param lr_schedule: Learning rate schedule (could be constant) + :param net_arch: The specification of the policy and value networks. + :param activation_fn: Activation function + :param ortho_init: Whether to use or not orthogonal initialization + :param use_sde: Whether to use State Dependent Exploration or not + :param log_std_init: Initial value for the log standard deviation + :param full_std: Whether to use (n_features x n_actions) parameters + for the std instead of only (n_features,) when using gSDE + :param use_expln: Use ``expln()`` function instead of ``exp()`` to ensure + a positive standard deviation (cf paper). It allows to keep variance + above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough. + :param squash_output: Whether to squash the output using a tanh function, + this allows to ensure boundaries when using gSDE. + :param features_extractor_class: Features extractor to use. + :param features_extractor_kwargs: Keyword arguments + to pass to the features extractor. + :param share_features_extractor: If True, the features extractor is shared between the policy and value networks. 
+ :param normalize_images: Whether to normalize images or not, + dividing by 255.0 (True by default) + :param optimizer_class: The optimizer to use, + ``th.optim.Adam`` by default + :param optimizer_kwargs: Additional keyword arguments, + excluding the learning rate, to pass to the optimizer + :param lstm_hidden_size: Number of hidden units for each LSTM layer. + :param n_lstm_layers: Number of LSTM layers. + :param shared_lstm: Whether the LSTM is shared between the actor and the critic. + By default, only the actor has a recurrent network. + :param enable_critic_lstm: Use a seperate LSTM for the critic. + :param lstm_kwargs: Additional keyword arguments to pass the the LSTM + constructor. + """ + + def __init__( + self, + observation_space: spaces.Space, + action_space: spaces.Space, + lr_schedule: Schedule, + net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None, + activation_fn: Type[nn.Module] = nn.Tanh, + ortho_init: bool = True, + use_sde: bool = False, + log_std_init: float = 0.0, + full_std: bool = True, + use_expln: bool = False, + squash_output: bool = False, + features_extractor_class: Type[BaseFeaturesExtractor] = CombinedExtractor, + features_extractor_kwargs: Optional[Dict[str, Any]] = None, + share_features_extractor: bool = True, + normalize_images: bool = True, + optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, + optimizer_kwargs: Optional[Dict[str, Any]] = None, + lstm_hidden_size: int = 256, + n_lstm_layers: int = 1, + shared_lstm: bool = False, + enable_critic_lstm: bool = True, + lstm_kwargs: Optional[Dict[str, Any]] = None, + ): + super().__init__( + observation_space, + action_space, + lr_schedule, + net_arch, + activation_fn, + ortho_init, + use_sde, + log_std_init, + full_std, + use_expln, + squash_output, + features_extractor_class, + features_extractor_kwargs, + share_features_extractor, + normalize_images, + optimizer_class, + optimizer_kwargs, + lstm_hidden_size, + n_lstm_layers, + shared_lstm, + enable_critic_lstm, + lstm_kwargs, + ) diff --git a/stable_baselines3/common/recurrent/type_aliases.py b/stable_baselines3/common/recurrent/type_aliases.py new file mode 100644 index 000000000..21ac0e0d9 --- /dev/null +++ b/stable_baselines3/common/recurrent/type_aliases.py @@ -0,0 +1,33 @@ +from typing import NamedTuple, Tuple + +import torch as th +from stable_baselines3.common.type_aliases import TensorDict + + +class RNNStates(NamedTuple): + pi: Tuple[th.Tensor, ...] + vf: Tuple[th.Tensor, ...] 
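+ # Each entry is a (hidden_state, cell_state) pair following torch.nn.LSTM conventions,
+ # shaped (n_lstm_layers, n_envs, lstm_hidden_size) during rollout collection and
+ # (n_lstm_layers, n_seq, lstm_hidden_size) inside a training minibatch, e.g. (illustrative):
+ #   h = th.zeros(n_lstm_layers, n_envs, lstm_hidden_size)
+ #   states = RNNStates(pi=(h, h.clone()), vf=(h.clone(), h.clone()))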
+ + +class RecurrentRolloutBufferSamples(NamedTuple): + observations: th.Tensor + actions: th.Tensor + old_values: th.Tensor + old_log_prob: th.Tensor + advantages: th.Tensor + returns: th.Tensor + lstm_states: RNNStates + episode_starts: th.Tensor + mask: th.Tensor + + +class RecurrentDictRolloutBufferSamples(NamedTuple): + observations: TensorDict + actions: th.Tensor + old_values: th.Tensor + old_log_prob: th.Tensor + advantages: th.Tensor + returns: th.Tensor + lstm_states: RNNStates + episode_starts: th.Tensor + mask: th.Tensor diff --git a/stable_baselines3/ppo_recurrent/__init__.py b/stable_baselines3/ppo_recurrent/__init__.py new file mode 100644 index 000000000..f8301048b --- /dev/null +++ b/stable_baselines3/ppo_recurrent/__init__.py @@ -0,0 +1,4 @@ +from sb3_contrib.ppo_recurrent.policies import CnnLstmPolicy, MlpLstmPolicy, MultiInputLstmPolicy +from sb3_contrib.ppo_recurrent.ppo_recurrent import RecurrentPPO + +__all__ = ["CnnLstmPolicy", "MlpLstmPolicy", "MultiInputLstmPolicy", "RecurrentPPO"] diff --git a/stable_baselines3/ppo_recurrent/policies.py b/stable_baselines3/ppo_recurrent/policies.py new file mode 100644 index 000000000..d9b374582 --- /dev/null +++ b/stable_baselines3/ppo_recurrent/policies.py @@ -0,0 +1,9 @@ +from sb3_contrib.common.recurrent.policies import ( + RecurrentActorCriticCnnPolicy, + RecurrentActorCriticPolicy, + RecurrentMultiInputActorCriticPolicy, +) + +MlpLstmPolicy = RecurrentActorCriticPolicy +CnnLstmPolicy = RecurrentActorCriticCnnPolicy +MultiInputLstmPolicy = RecurrentMultiInputActorCriticPolicy diff --git a/stable_baselines3/ppo_recurrent/ppo_recurrent.py b/stable_baselines3/ppo_recurrent/ppo_recurrent.py new file mode 100644 index 000000000..acd44c9c1 --- /dev/null +++ b/stable_baselines3/ppo_recurrent/ppo_recurrent.py @@ -0,0 +1,494 @@ +import sys +import time +from copy import deepcopy +from typing import Any, ClassVar, Dict, Optional, Type, TypeVar, Union + +import numpy as np +import torch as th +from gymnasium import spaces +from stable_baselines3.common.buffers import RolloutBuffer +from stable_baselines3.common.callbacks import BaseCallback +from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm +from stable_baselines3.common.policies import BasePolicy +from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule +from stable_baselines3.common.utils import explained_variance, get_schedule_fn, obs_as_tensor, safe_mean +from stable_baselines3.common.vec_env import VecEnv + +from sb3_contrib.common.recurrent.buffers import RecurrentDictRolloutBuffer, RecurrentRolloutBuffer +from sb3_contrib.common.recurrent.policies import RecurrentActorCriticPolicy +from sb3_contrib.common.recurrent.type_aliases import RNNStates +from sb3_contrib.ppo_recurrent.policies import CnnLstmPolicy, MlpLstmPolicy, MultiInputLstmPolicy + +SelfRecurrentPPO = TypeVar("SelfRecurrentPPO", bound="RecurrentPPO") + + +class RecurrentPPO(OnPolicyAlgorithm): + """ + Proximal Policy Optimization algorithm (PPO) (clip version) + with support for recurrent policies (LSTM). + + Based on the original Stable Baselines 3 implementation. + + Introduction to PPO: https://spinningup.openai.com/en/latest/algorithms/ppo.html + + :param policy: The policy model to use (MlpPolicy, CnnPolicy, ...) 
+ :param env: The environment to learn from (if registered in Gym, can be str) + :param learning_rate: The learning rate, it can be a function + of the current progress remaining (from 1 to 0) + :param n_steps: The number of steps to run for each environment per update + (i.e. batch size is n_steps * n_env where n_env is number of environment copies running in parallel) + :param batch_size: Minibatch size + :param n_epochs: Number of epoch when optimizing the surrogate loss + :param gamma: Discount factor + :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator + :param clip_range: Clipping parameter, it can be a function of the current progress + remaining (from 1 to 0). + :param clip_range_vf: Clipping parameter for the value function, + it can be a function of the current progress remaining (from 1 to 0). + This is a parameter specific to the OpenAI implementation. If None is passed (default), + no clipping will be done on the value function. + IMPORTANT: this clipping depends on the reward scaling. + :param normalize_advantage: Whether to normalize or not the advantage + :param ent_coef: Entropy coefficient for the loss calculation + :param vf_coef: Value function coefficient for the loss calculation + :param max_grad_norm: The maximum value for the gradient clipping + :param target_kl: Limit the KL divergence between updates, + because the clipping is not enough to prevent large update + see issue #213 (cf https://github.com/hill-a/stable-baselines/issues/213) + By default, there is no limit on the kl div. + :param stats_window_size: Window size for the rollout logging, specifying the number of episodes to average + the reported success rate, mean episode length, and mean reward over + :param tensorboard_log: the log location for tensorboard (if None, no logging) + :param policy_kwargs: additional arguments to be passed to the policy on creation + :param verbose: the verbosity level: 0 no output, 1 info, 2 debug + :param seed: Seed for the pseudo random generators + :param device: Device (cpu, cuda, ...) on which the code should be run. + Setting it to auto, the code will be run on the GPU if possible. 
+ :param _init_setup_model: Whether or not to build the network at the creation of the instance + """ + + policy_aliases: ClassVar[Dict[str, Type[BasePolicy]]] = { + "MlpLstmPolicy": MlpLstmPolicy, + "CnnLstmPolicy": CnnLstmPolicy, + "MultiInputLstmPolicy": MultiInputLstmPolicy, + } + + def __init__( + self, + policy: Union[str, Type[RecurrentActorCriticPolicy]], + env: Union[GymEnv, str], + learning_rate: Union[float, Schedule] = 3e-4, + n_steps: int = 128, + batch_size: Optional[int] = 128, + n_epochs: int = 10, + gamma: float = 0.99, + gae_lambda: float = 0.95, + clip_range: Union[float, Schedule] = 0.2, + clip_range_vf: Union[None, float, Schedule] = None, + normalize_advantage: bool = True, + ent_coef: float = 0.0, + vf_coef: float = 0.5, + max_grad_norm: float = 0.5, + use_sde: bool = False, + sde_sample_freq: int = -1, + target_kl: Optional[float] = None, + stats_window_size: int = 100, + tensorboard_log: Optional[str] = None, + policy_kwargs: Optional[Dict[str, Any]] = None, + verbose: int = 0, + seed: Optional[int] = None, + device: Union[th.device, str] = "auto", + _init_setup_model: bool = True, + ): + super().__init__( + policy, + env, + learning_rate=learning_rate, + n_steps=n_steps, + gamma=gamma, + gae_lambda=gae_lambda, + ent_coef=ent_coef, + vf_coef=vf_coef, + max_grad_norm=max_grad_norm, + use_sde=use_sde, + sde_sample_freq=sde_sample_freq, + stats_window_size=stats_window_size, + tensorboard_log=tensorboard_log, + policy_kwargs=policy_kwargs, + verbose=verbose, + seed=seed, + device=device, + _init_setup_model=False, + supported_action_spaces=( + spaces.Box, + spaces.Discrete, + spaces.MultiDiscrete, + spaces.MultiBinary, + ), + ) + + self.batch_size = batch_size + self.n_epochs = n_epochs + self.clip_range = clip_range + self.clip_range_vf = clip_range_vf + self.normalize_advantage = normalize_advantage + self.target_kl = target_kl + self._last_lstm_states = None + + if _init_setup_model: + self._setup_model() + + def _setup_model(self) -> None: + self._setup_lr_schedule() + self.set_random_seed(self.seed) + + buffer_cls = RecurrentDictRolloutBuffer if isinstance(self.observation_space, spaces.Dict) else RecurrentRolloutBuffer + + self.policy = self.policy_class( + self.observation_space, + self.action_space, + self.lr_schedule, + use_sde=self.use_sde, + **self.policy_kwargs, # pytype:disable=not-instantiable + ) + self.policy = self.policy.to(self.device) + + # We assume that LSTM for the actor and the critic + # have the same architecture + lstm = self.policy.lstm_actor + + if not isinstance(self.policy, RecurrentActorCriticPolicy): + raise ValueError("Policy must subclass RecurrentActorCriticPolicy") + + single_hidden_state_shape = (lstm.num_layers, self.n_envs, lstm.hidden_size) + # hidden and cell states for actor and critic + self._last_lstm_states = RNNStates( + ( + th.zeros(single_hidden_state_shape, device=self.device), + th.zeros(single_hidden_state_shape, device=self.device), + ), + ( + th.zeros(single_hidden_state_shape, device=self.device), + th.zeros(single_hidden_state_shape, device=self.device), + ), + ) + + hidden_state_buffer_shape = (self.n_steps, lstm.num_layers, self.n_envs, lstm.hidden_size) + + self.rollout_buffer = buffer_cls( + self.n_steps, + self.observation_space, + self.action_space, + hidden_state_buffer_shape, + self.device, + gamma=self.gamma, + gae_lambda=self.gae_lambda, + n_envs=self.n_envs, + ) + + # Initialize schedules for policy/value clipping + self.clip_range = get_schedule_fn(self.clip_range) + if self.clip_range_vf is not None: 
+ if isinstance(self.clip_range_vf, (float, int)): + assert self.clip_range_vf > 0, "`clip_range_vf` must be positive, pass `None` to deactivate vf clipping" + + self.clip_range_vf = get_schedule_fn(self.clip_range_vf) + + def collect_rollouts( + self, + env: VecEnv, + callback: BaseCallback, + rollout_buffer: RolloutBuffer, + n_rollout_steps: int, + ) -> bool: + """ + Collect experiences using the current policy and fill a ``RolloutBuffer``. + The term rollout here refers to the model-free notion and should not + be used with the concept of rollout used in model-based RL or planning. + + :param env: The training environment + :param callback: Callback that will be called at each step + (and at the beginning and end of the rollout) + :param rollout_buffer: Buffer to fill with rollouts + :param n_steps: Number of experiences to collect per environment + :return: True if function returned with at least `n_rollout_steps` + collected, False if callback terminated rollout prematurely. + """ + assert isinstance( + rollout_buffer, (RecurrentRolloutBuffer, RecurrentDictRolloutBuffer) + ), f"{rollout_buffer} doesn't support recurrent policy" + + assert self._last_obs is not None, "No previous observation was provided" + # Switch to eval mode (this affects batch norm / dropout) + self.policy.set_training_mode(False) + + n_steps = 0 + rollout_buffer.reset() + # Sample new weights for the state dependent exploration + if self.use_sde: + self.policy.reset_noise(env.num_envs) + + callback.on_rollout_start() + + lstm_states = deepcopy(self._last_lstm_states) + + while n_steps < n_rollout_steps: + if self.use_sde and self.sde_sample_freq > 0 and n_steps % self.sde_sample_freq == 0: + # Sample a new noise matrix + self.policy.reset_noise(env.num_envs) + + with th.no_grad(): + # Convert to pytorch tensor or to TensorDict + obs_tensor = obs_as_tensor(self._last_obs, self.device) + episode_starts = th.tensor(self._last_episode_starts, dtype=th.float32, device=self.device) + actions, values, log_probs, lstm_states = self.policy.forward(obs_tensor, lstm_states, episode_starts) + + actions = actions.cpu().numpy() + + # Rescale and perform action + clipped_actions = actions + # Clip the actions to avoid out of bound error + if isinstance(self.action_space, spaces.Box): + clipped_actions = np.clip(actions, self.action_space.low, self.action_space.high) + + new_obs, rewards, dones, infos = env.step(clipped_actions) + + self.num_timesteps += env.num_envs + + # Give access to local variables + callback.update_locals(locals()) + if callback.on_step() is False: + return False + + self._update_info_buffer(infos) + n_steps += 1 + + if isinstance(self.action_space, spaces.Discrete): + # Reshape in case of discrete action + actions = actions.reshape(-1, 1) + + # Handle timeout by bootstraping with value function + # see GitHub issue #633 + for idx, done_ in enumerate(dones): + if ( + done_ + and infos[idx].get("terminal_observation") is not None + and infos[idx].get("TimeLimit.truncated", False) + ): + terminal_obs = self.policy.obs_to_tensor(infos[idx]["terminal_observation"])[0] + with th.no_grad(): + terminal_lstm_state = ( + lstm_states.vf[0][:, idx : idx + 1, :].contiguous(), + lstm_states.vf[1][:, idx : idx + 1, :].contiguous(), + ) + # terminal_lstm_state = None + episode_starts = th.tensor([False], dtype=th.float32, device=self.device) + terminal_value = self.policy.predict_values(terminal_obs, terminal_lstm_state, episode_starts)[0] + rewards[idx] += self.gamma * terminal_value + + rollout_buffer.add( + 
self._last_obs, + actions, + rewards, + self._last_episode_starts, + values, + log_probs, + lstm_states=self._last_lstm_states, + ) + + self._last_obs = new_obs + self._last_episode_starts = dones + self._last_lstm_states = lstm_states + + with th.no_grad(): + # Compute value for the last timestep + episode_starts = th.tensor(dones, dtype=th.float32, device=self.device) + values = self.policy.predict_values(obs_as_tensor(new_obs, self.device), lstm_states.vf, episode_starts) + + rollout_buffer.compute_returns_and_advantage(last_values=values, dones=dones) + + callback.on_rollout_end() + + return True + + def train(self) -> None: + """ + Update policy using the currently gathered rollout buffer. + """ + # Switch to train mode (this affects batch norm / dropout) + self.policy.set_training_mode(True) + # Update optimizer learning rate + self._update_learning_rate(self.policy.optimizer) + # Compute current clip range + clip_range = self.clip_range(self._current_progress_remaining) + # Optional: clip range for the value function + if self.clip_range_vf is not None: + clip_range_vf = self.clip_range_vf(self._current_progress_remaining) + + entropy_losses = [] + pg_losses, value_losses = [], [] + clip_fractions = [] + + continue_training = True + + # train for n_epochs epochs + for epoch in range(self.n_epochs): + approx_kl_divs = [] + # Do a complete pass on the rollout buffer + for rollout_data in self.rollout_buffer.get(self.batch_size): + actions = rollout_data.actions + if isinstance(self.action_space, spaces.Discrete): + # Convert discrete action from float to long + actions = rollout_data.actions.long().flatten() + + # Convert mask from float to bool + mask = rollout_data.mask > 1e-8 + + # Re-sample the noise matrix because the log_std has changed + if self.use_sde: + self.policy.reset_noise(self.batch_size) + + values, log_prob, entropy = self.policy.evaluate_actions( + rollout_data.observations, + actions, + rollout_data.lstm_states, + rollout_data.episode_starts, + ) + + values = values.flatten() + # Normalize advantage + advantages = rollout_data.advantages + if self.normalize_advantage: + advantages = (advantages - advantages[mask].mean()) / (advantages[mask].std() + 1e-8) + + # ratio between old and new policy, should be one at the first iteration + ratio = th.exp(log_prob - rollout_data.old_log_prob) + + # clipped surrogate loss + policy_loss_1 = advantages * ratio + policy_loss_2 = advantages * th.clamp(ratio, 1 - clip_range, 1 + clip_range) + policy_loss = -th.mean(th.min(policy_loss_1, policy_loss_2)[mask]) + + # Logging + pg_losses.append(policy_loss.item()) + clip_fraction = th.mean((th.abs(ratio - 1) > clip_range).float()[mask]).item() + clip_fractions.append(clip_fraction) + + if self.clip_range_vf is None: + # No clipping + values_pred = values + else: + # Clip the different between old and new value + # NOTE: this depends on the reward scaling + values_pred = rollout_data.old_values + th.clamp( + values - rollout_data.old_values, -clip_range_vf, clip_range_vf + ) + # Value loss using the TD(gae_lambda) target + # Mask padded sequences + value_loss = th.mean(((rollout_data.returns - values_pred) ** 2)[mask]) + + value_losses.append(value_loss.item()) + + # Entropy loss favor exploration + if entropy is None: + # Approximate entropy when no analytical form + entropy_loss = -th.mean(-log_prob[mask]) + else: + entropy_loss = -th.mean(entropy[mask]) + + entropy_losses.append(entropy_loss.item()) + + loss = policy_loss + self.ent_coef * entropy_loss + self.vf_coef * value_loss + 
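+ # Every loss term above is averaged only over real transitions: `mask` is False
+ # for the zero-padded timesteps added by the recurrent rollout buffer, so the
+ # padding never contributes to the gradient.
+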
+ # Calculate approximate form of reverse KL Divergence for early stopping + # see issue #417: https://github.com/DLR-RM/stable-baselines3/issues/417 + # and discussion in PR #419: https://github.com/DLR-RM/stable-baselines3/pull/419 + # and Schulman blog: http://joschu.net/blog/kl-approx.html + with th.no_grad(): + log_ratio = log_prob - rollout_data.old_log_prob + approx_kl_div = th.mean(((th.exp(log_ratio) - 1) - log_ratio)[mask]).cpu().numpy() + approx_kl_divs.append(approx_kl_div) + + if self.target_kl is not None and approx_kl_div > 1.5 * self.target_kl: + continue_training = False + if self.verbose >= 1: + print(f"Early stopping at step {epoch} due to reaching max kl: {approx_kl_div:.2f}") + break + + # Optimization step + self.policy.optimizer.zero_grad() + loss.backward() + # Clip grad norm + th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm) + self.policy.optimizer.step() + + if not continue_training: + break + + self._n_updates += self.n_epochs + explained_var = explained_variance(self.rollout_buffer.values.flatten(), self.rollout_buffer.returns.flatten()) + + # Logs + self.logger.record("train/entropy_loss", np.mean(entropy_losses)) + self.logger.record("train/policy_gradient_loss", np.mean(pg_losses)) + self.logger.record("train/value_loss", np.mean(value_losses)) + self.logger.record("train/approx_kl", np.mean(approx_kl_divs)) + self.logger.record("train/clip_fraction", np.mean(clip_fractions)) + self.logger.record("train/loss", loss.item()) + self.logger.record("train/explained_variance", explained_var) + if hasattr(self.policy, "log_std"): + self.logger.record("train/std", th.exp(self.policy.log_std).mean().item()) + + self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard") + self.logger.record("train/clip_range", clip_range) + if self.clip_range_vf is not None: + self.logger.record("train/clip_range_vf", clip_range_vf) + + def learn( + self: SelfRecurrentPPO, + total_timesteps: int, + callback: MaybeCallback = None, + log_interval: int = 1, + tb_log_name: str = "RecurrentPPO", + reset_num_timesteps: bool = True, + progress_bar: bool = False, + ) -> SelfRecurrentPPO: + iteration = 0 + + total_timesteps, callback = self._setup_learn( + total_timesteps, + callback, + reset_num_timesteps, + tb_log_name, + progress_bar, + ) + + callback.on_training_start(locals(), globals()) + + while self.num_timesteps < total_timesteps: + continue_training = self.collect_rollouts(self.env, callback, self.rollout_buffer, n_rollout_steps=self.n_steps) + + if continue_training is False: + break + + iteration += 1 + self._update_current_progress_remaining(self.num_timesteps, total_timesteps) + + # Display training infos + if log_interval is not None and iteration % log_interval == 0: + time_elapsed = max((time.time_ns() - self.start_time) / 1e9, sys.float_info.epsilon) + fps = int((self.num_timesteps - self._num_timesteps_at_start) / time_elapsed) + self.logger.record("time/iterations", iteration, exclude="tensorboard") + if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0: + self.logger.record("rollout/ep_rew_mean", safe_mean([ep_info["r"] for ep_info in self.ep_info_buffer])) + self.logger.record("rollout/ep_len_mean", safe_mean([ep_info["l"] for ep_info in self.ep_info_buffer])) + self.logger.record("time/fps", fps) + self.logger.record("time/time_elapsed", int(time_elapsed), exclude="tensorboard") + self.logger.record("time/total_timesteps", self.num_timesteps, exclude="tensorboard") + self.logger.dump(step=self.num_timesteps) + + 
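The early-stopping statistic computed in this loop is the low-variance estimator kl ~ E[(exp(log_ratio) - 1) - log_ratio] from Schulman's note linked above; each term is non-negative, so the estimate never goes below zero. A quick numeric sanity check with invented log-ratios:

import torch as th

# Toy log-ratios between the new and old policy.
log_ratio = th.tensor([0.0, 0.1, -0.2])
approx_kl = th.mean((th.exp(log_ratio) - 1) - log_ratio)
# mean([0.0, 0.0052, 0.0187]) is about 0.008; identical policies give exactly 0.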
self.train() + + callback.on_training_end() + + return self diff --git a/tests/test_cnn.py b/tests/test_cnn.py index 1436cb9be..c7bb1a31e 100644 --- a/tests/test_cnn.py +++ b/tests/test_cnn.py @@ -6,11 +6,20 @@ import torch as th from gymnasium import spaces -from stable_baselines3 import A2C, DQN, PPO, SAC, TD3 +from stable_baselines3 import A2C, DQN, PPO, SAC, TD3, RecurrentPPO from stable_baselines3.common.envs import FakeImageEnv -from stable_baselines3.common.preprocessing import is_image_space, is_image_space_channels_first +from stable_baselines3.common.preprocessing import ( + is_image_space, + is_image_space_channels_first, +) from stable_baselines3.common.utils import zip_strict -from stable_baselines3.common.vec_env import DummyVecEnv, VecFrameStack, VecNormalize, VecTransposeImage, is_vecenv_wrapped +from stable_baselines3.common.vec_env import ( + DummyVecEnv, + VecFrameStack, + VecNormalize, + VecTransposeImage, + is_vecenv_wrapped, +) @pytest.mark.parametrize("model_class", [A2C, PPO, SAC, TD3, DQN]) @@ -315,7 +324,7 @@ def test_image_space_checks(): assert not is_image_space_channels_first(channel_mid_space) -@pytest.mark.parametrize("model_class", [A2C, PPO, DQN, SAC, TD3]) +@pytest.mark.parametrize("model_class", [A2C, PPO, DQN, SAC, TD3, RecurrentPPO]) @pytest.mark.parametrize("normalize_images", [True, False]) def test_image_like_input(model_class, normalize_images): """ @@ -342,6 +351,7 @@ def test_image_like_input(model_class, normalize_images): ), seed=1, ) + policy = "CnnLstmPolicy" if model_class == RecurrentPPO else "CnnPolicy" if model_class in {A2C, PPO}: kwargs.update(dict(n_steps=64)) diff --git a/tests/test_deterministic.py b/tests/test_deterministic.py index c165e485d..86a336983 100644 --- a/tests/test_deterministic.py +++ b/tests/test_deterministic.py @@ -1,14 +1,14 @@ import numpy as np import pytest -from stable_baselines3 import A2C, DQN, PPO, SAC, TD3 +from stable_baselines3 import A2C, DQN, PPO, SAC, TD3, RecurrentPPO from stable_baselines3.common.noise import NormalActionNoise N_STEPS_TRAINING = 500 SEED = 0 -@pytest.mark.parametrize("algo", [A2C, DQN, PPO, SAC, TD3]) +@pytest.mark.parametrize("algo", [A2C, DQN, PPO, SAC, TD3, RecurrentPPO]) def test_deterministic_training_common(algo): results = [[], []] rewards = [[], []] @@ -25,9 +25,13 @@ def test_deterministic_training_common(algo): kwargs.update({"learning_starts": 100, "target_update_interval": 100}) elif algo == PPO: kwargs.update({"n_steps": 64, "n_epochs": 4}) + elif algo == RecurrentPPO: + kwargs.update({"policy_kwargs": dict(net_arch=[], enable_critic_lstm=True, lstm_hidden_size=8)}) + kwargs.update({"n_steps": 50, "n_epochs": 4}) + policy_str = "MlpLstmPolicy" if algo == RecurrentPPO else "MlpPolicy" for i in range(2): - model = algo("MlpPolicy", env_id, seed=SEED, **kwargs) + model = algo(policy_str, env_id, seed=SEED, **kwargs) model.learn(N_STEPS_TRAINING) env = model.get_env() obs = env.reset() diff --git a/tests/test_dict_env.py b/tests/test_dict_env.py index 3a34ff0d2..fd2045ac7 100644 --- a/tests/test_dict_env.py +++ b/tests/test_dict_env.py @@ -5,12 +5,17 @@ import pytest from gymnasium import spaces -from stable_baselines3 import A2C, DDPG, DQN, PPO, SAC, TD3 +from stable_baselines3 import A2C, DDPG, DQN, PPO, SAC, TD3, RecurrentPPO from stable_baselines3.common.env_checker import check_env from stable_baselines3.common.env_util import make_vec_env from stable_baselines3.common.envs import BitFlippingEnv, SimpleMultiObsEnv from stable_baselines3.common.evaluation import 
evaluate_policy -from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv, VecFrameStack, VecNormalize +from stable_baselines3.common.vec_env import ( + DummyVecEnv, + SubprocVecEnv, + VecFrameStack, + VecNormalize, +) from stable_baselines3.common.vec_env.util import obs_as_tensor @@ -110,7 +115,7 @@ def test_goal_env(model_class): evaluate_policy(model, model.get_env()) -@pytest.mark.parametrize("model_class", [PPO, A2C, DQN, DDPG, SAC, TD3]) +@pytest.mark.parametrize("model_class", [PPO, A2C, DQN, DDPG, SAC, TD3, RecurrentPPO]) def test_consistency(model_class): """ Make sure that dict obs with vector only vs using flatten obs is equivalent. @@ -157,7 +162,7 @@ def test_consistency(model_class): assert np.allclose(action_1, action_2) -@pytest.mark.parametrize("model_class", [PPO, A2C, DQN, DDPG, SAC, TD3]) +@pytest.mark.parametrize("model_class", [PPO, A2C, DQN, DDPG, SAC, TD3, RecurrentPPO]) @pytest.mark.parametrize("channel_last", [False, True]) def test_dict_spaces(model_class, channel_last): """ @@ -201,7 +206,7 @@ def test_dict_spaces(model_class, channel_last): evaluate_policy(model, env, n_eval_episodes=5, warn=False) -@pytest.mark.parametrize("model_class", [PPO, A2C, SAC, DQN]) +@pytest.mark.parametrize("model_class", [PPO, A2C, SAC, DQN, RecurrentPPO]) def test_multiprocessing(model_class): use_discrete_actions = model_class not in [SAC, TD3, DDPG] @@ -238,7 +243,7 @@ def make_env(): model.learn(total_timesteps=n_steps) -@pytest.mark.parametrize("model_class", [PPO, A2C, DQN, DDPG, SAC, TD3]) +@pytest.mark.parametrize("model_class", [PPO, A2C, DQN, DDPG, SAC, TD3, RecurrentPPO]) @pytest.mark.parametrize("channel_last", [False, True]) def test_dict_vec_framestack(model_class, channel_last): """ @@ -286,7 +291,7 @@ def test_dict_vec_framestack(model_class, channel_last): evaluate_policy(model, env, n_eval_episodes=5, warn=False) -@pytest.mark.parametrize("model_class", [PPO, A2C, DQN, DDPG, SAC, TD3]) +@pytest.mark.parametrize("model_class", [PPO, A2C, DQN, DDPG, SAC, TD3, RecurrentPPO]) def test_vec_normalize(model_class): """ Additional tests for PPO/A2C/SAC/DDPG/TD3/DQN to check observation space support diff --git a/tests/test_identity.py b/tests/test_identity.py index 10b7d3767..a4bf1e327 100644 --- a/tests/test_identity.py +++ b/tests/test_identity.py @@ -1,8 +1,13 @@ import numpy as np import pytest -from stable_baselines3 import A2C, DDPG, DQN, PPO, SAC, TD3 -from stable_baselines3.common.envs import IdentityEnv, IdentityEnvBox, IdentityEnvMultiBinary, IdentityEnvMultiDiscrete +from stable_baselines3 import A2C, DDPG, DQN, PPO, SAC, TD3, RecurrentPPO +from stable_baselines3.common.envs import ( + IdentityEnv, + IdentityEnvBox, + IdentityEnvMultiBinary, + IdentityEnvMultiDiscrete, +) from stable_baselines3.common.evaluation import evaluate_policy from stable_baselines3.common.noise import NormalActionNoise from stable_baselines3.common.vec_env import DummyVecEnv @@ -10,7 +15,7 @@ DIM = 4 -@pytest.mark.parametrize("model_class", [A2C, PPO, DQN]) +@pytest.mark.parametrize("model_class", [A2C, PPO, DQN, RecurrentPPO]) @pytest.mark.parametrize("env", [IdentityEnv(DIM), IdentityEnvMultiDiscrete(DIM), IdentityEnvMultiBinary(DIM)]) def test_discrete(model_class, env): env_ = DummyVecEnv([lambda: env]) @@ -30,7 +35,7 @@ def test_discrete(model_class, env): assert np.shape(model.predict(obs)[0]) == np.shape(obs) -@pytest.mark.parametrize("model_class", [A2C, PPO, SAC, DDPG, TD3]) +@pytest.mark.parametrize("model_class", [A2C, PPO, SAC, DDPG, TD3, RecurrentPPO]) def 
test_continuous(model_class): env = IdentityEnvBox(eps=0.5) diff --git a/tests/test_lstm.py b/tests/test_lstm.py new file mode 100644 index 000000000..e3f8ebbdd --- /dev/null +++ b/tests/test_lstm.py @@ -0,0 +1,248 @@ +from typing import Dict, Optional + +import gymnasium as gym +import numpy as np +import pytest +from gymnasium import spaces +from gymnasium.envs.classic_control import CartPoleEnv +from gymnasium.wrappers.time_limit import TimeLimit + +from stable_baselines3 import RecurrentPPO +from stable_baselines3.common.callbacks import EvalCallback +from stable_baselines3.common.env_checker import check_env +from stable_baselines3.common.env_util import make_vec_env +from stable_baselines3.common.envs import FakeImageEnv +from stable_baselines3.common.evaluation import evaluate_policy +from stable_baselines3.common.vec_env import VecNormalize + + +class ToDictWrapper(gym.Wrapper): + """ + Simple wrapper to test MultInputPolicy on Dict obs. + """ + + def __init__(self, env): + super().__init__(env) + self.observation_space = spaces.Dict({"obs": self.env.observation_space}) + + def reset(self, **kwargs): + return {"obs": self.env.reset(**kwargs)[0]}, {} + + def step(self, action): + obs, reward, done, truncated, infos = self.env.step(action) + return {"obs": obs}, reward, done, truncated, infos + + +class CartPoleNoVelEnv(CartPoleEnv): + """Variant of CartPoleEnv with velocity information removed. This task requires memory to solve.""" + + def __init__(self): + super().__init__() + high = np.array( + [ + self.x_threshold * 2, + self.theta_threshold_radians * 2, + ] + ) + self.observation_space = spaces.Box(-high, high, dtype=np.float32) + + @staticmethod + def _pos_obs(full_obs): + xpos, _xvel, thetapos, _thetavel = full_obs + return np.array([xpos, thetapos]) + + def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None): + full_obs, info = super().reset(seed=seed, options=options) + return CartPoleNoVelEnv._pos_obs(full_obs), info + + def step(self, action): + full_obs, rew, terminated, truncated, info = super().step(action) + return CartPoleNoVelEnv._pos_obs(full_obs), rew, terminated, truncated, info + + +def test_env(): + check_env(CartPoleNoVelEnv()) + + +@pytest.mark.parametrize( + "policy_kwargs", + [ + {}, + {"share_features_extractor": False}, + dict(shared_lstm=True, enable_critic_lstm=False), + dict( + enable_critic_lstm=True, + lstm_hidden_size=4, + lstm_kwargs=dict(dropout=0.5), + n_lstm_layers=2, + ), + dict( + enable_critic_lstm=False, + lstm_hidden_size=4, + lstm_kwargs=dict(dropout=0.5), + n_lstm_layers=2, + ), + dict( + enable_critic_lstm=False, + lstm_hidden_size=4, + share_features_extractor=False, + ), + ], +) +def test_cnn(policy_kwargs): + model = RecurrentPPO( + "CnnLstmPolicy", + FakeImageEnv(screen_height=40, screen_width=40, n_channels=3), + n_steps=16, + seed=0, + policy_kwargs=dict(**policy_kwargs, features_extractor_kwargs=dict(features_dim=32)), + n_epochs=2, + ) + + model.learn(total_timesteps=32) + + +@pytest.mark.parametrize( + "policy_kwargs", + [ + {}, + dict(shared_lstm=True, enable_critic_lstm=False), + dict( + enable_critic_lstm=True, + lstm_hidden_size=4, + lstm_kwargs=dict(dropout=0.5), + n_lstm_layers=2, + ), + dict( + enable_critic_lstm=False, + lstm_hidden_size=4, + lstm_kwargs=dict(dropout=0.5), + n_lstm_layers=2, + ), + ], +) +def test_policy_kwargs(policy_kwargs): + model = RecurrentPPO( + "MlpLstmPolicy", + "CartPole-v1", + n_steps=16, + seed=0, + policy_kwargs=policy_kwargs, + ) + + model.learn(total_timesteps=32) + 
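Beyond these unit tests, a short end-to-end sketch of how the ported model is meant to be driven at inference time. The use of `state` and `episode_start` in `predict`, and the returned LSTM states, are assumed to follow the sb3_contrib interface this code was brought over from:

import numpy as np

from stable_baselines3 import RecurrentPPO

model = RecurrentPPO("MlpLstmPolicy", "CartPole-v1", n_steps=16, seed=0)
model.learn(total_timesteps=64)

vec_env = model.get_env()
obs = vec_env.reset()
lstm_states = None  # None means "start from the zero state"
episode_starts = np.ones((vec_env.num_envs,), dtype=bool)
for _ in range(100):
    action, lstm_states = model.predict(obs, state=lstm_states, episode_start=episode_starts, deterministic=True)
    obs, rewards, dones, infos = vec_env.step(action)
    episode_starts = dones  # reset the LSTM state at episode boundaries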
+ +def test_check(): + policy_kwargs = dict(shared_lstm=True, enable_critic_lstm=True) + with pytest.raises(AssertionError): + RecurrentPPO( + "MlpLstmPolicy", + "CartPole-v1", + n_steps=16, + seed=0, + policy_kwargs=policy_kwargs, + ) + + policy_kwargs = dict(shared_lstm=True, enable_critic_lstm=False, share_features_extractor=False) + with pytest.raises(AssertionError): + RecurrentPPO( + "MlpLstmPolicy", + "CartPole-v1", + n_steps=16, + seed=0, + policy_kwargs=policy_kwargs, + ) + + +@pytest.mark.parametrize("env", ["Pendulum-v1", "CartPole-v1"]) +def test_run(env): + model = RecurrentPPO( + "MlpLstmPolicy", + env, + n_steps=16, + seed=0, + ) + + model.learn(total_timesteps=32) + + +def test_run_sde(): + model = RecurrentPPO( + "MlpLstmPolicy", + "Pendulum-v1", + n_steps=16, + seed=0, + sde_sample_freq=4, + use_sde=True, + clip_range_vf=0.1, + ) + + model.learn(total_timesteps=200) + + +@pytest.mark.parametrize( + "policy_kwargs", + [ + {}, + dict(shared_lstm=True, enable_critic_lstm=False), + dict( + enable_critic_lstm=True, + lstm_hidden_size=4, + lstm_kwargs=dict(dropout=0.5), + n_lstm_layers=2, + ), + dict( + enable_critic_lstm=False, + lstm_hidden_size=4, + lstm_kwargs=dict(dropout=0.5), + n_lstm_layers=2, + ), + ], +) +def test_dict_obs(policy_kwargs): + env = make_vec_env("CartPole-v1", n_envs=1, wrapper_class=ToDictWrapper) + model = RecurrentPPO("MultiInputLstmPolicy", env, n_steps=32, policy_kwargs=policy_kwargs).learn(64) + evaluate_policy(model, env, warn=False) + + +@pytest.mark.slow +def test_ppo_lstm_performance(): + # env = make_vec_env("CartPole-v1", n_envs=16) + def make_env(): + env = CartPoleNoVelEnv() + env = TimeLimit(env, max_episode_steps=500) + return env + + env = VecNormalize(make_vec_env(make_env, n_envs=8)) + + eval_callback = EvalCallback( + VecNormalize(make_vec_env(make_env, n_envs=4), training=False, norm_reward=False), + n_eval_episodes=20, + eval_freq=5000 // env.num_envs, + ) + + model = RecurrentPPO( + "MlpLstmPolicy", + env, + n_steps=128, + learning_rate=0.0007, + verbose=1, + batch_size=256, + seed=1, + n_epochs=10, + max_grad_norm=1, + gae_lambda=0.98, + policy_kwargs=dict( + net_arch=dict(vf=[64], pi=[]), + lstm_hidden_size=64, + ortho_init=False, + enable_critic_lstm=True, + ), + ) + + model.learn(total_timesteps=50_000, callback=eval_callback) + # Maximum episode reward is 500. + # In CartPole-v1, a non-recurrent policy can easily get >= 450. + # In CartPoleNoVelEnv, a non-recurrent policy doesn't get more than ~50. 
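For readers of this test, `reward_threshold` turns the final evaluation into a hard assertion on the mean return, so the check below fails loudly if the LSTM never learns to infer the missing velocities. An equivalent explicit sketch, reusing the `model` and `env` defined in this test:

from stable_baselines3.common.evaluation import evaluate_policy

# Equivalent explicit form of `evaluate_policy(model, env, reward_threshold=450)`.
mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=10)
assert mean_reward > 450, f"Mean reward {mean_reward:.1f} is below the 450 threshold"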
+ evaluate_policy(model, env, reward_threshold=450) diff --git a/tests/test_save_load.py b/tests/test_save_load.py index 888bbf7a2..8161ee44a 100644 --- a/tests/test_save_load.py +++ b/tests/test_save_load.py @@ -14,7 +14,7 @@ import pytest import torch as th -from stable_baselines3 import A2C, DDPG, DQN, PPO, SAC, TD3 +from stable_baselines3 import A2C, DDPG, DQN, PPO, SAC, TD3, RecurrentPPO from stable_baselines3.common.base_class import BaseAlgorithm from stable_baselines3.common.env_util import make_vec_env from stable_baselines3.common.envs import FakeImageEnv, IdentityEnv, IdentityEnvBox @@ -22,7 +22,7 @@ from stable_baselines3.common.utils import get_device from stable_baselines3.common.vec_env import DummyVecEnv -MODEL_LIST = [PPO, A2C, TD3, SAC, DQN, DDPG] +MODEL_LIST = [PPO, A2C, TD3, SAC, DQN, DDPG, RecurrentPPO] def select_env(model_class: BaseAlgorithm) -> gym.Env: diff --git a/tests/test_train_eval_mode.py b/tests/test_train_eval_mode.py index dcbda74e1..30f989947 100644 --- a/tests/test_train_eval_mode.py +++ b/tests/test_train_eval_mode.py @@ -6,17 +6,11 @@ import torch as th import torch.nn as nn -from stable_baselines3 import A2C, DQN, PPO, SAC, TD3 +from stable_baselines3 import A2C, DQN, PPO, SAC, TD3, RecurrentPPO from stable_baselines3.common.preprocessing import get_flattened_obs_dim from stable_baselines3.common.torch_layers import BaseFeaturesExtractor -MODEL_LIST = [ - PPO, - A2C, - TD3, - SAC, - DQN, -] +MODEL_LIST = [PPO, A2C, TD3, SAC, DQN, RecurrentPPO] class FlattenBatchNormDropoutExtractor(BaseFeaturesExtractor): @@ -268,7 +262,7 @@ def test_sac_train_with_batch_norm(): assert th.isclose(critic_running_mean_after, critic_target_running_mean_after).all() -@pytest.mark.parametrize("model_class", [A2C, PPO]) +@pytest.mark.parametrize("model_class", [A2C, PPO, RecurrentPPO]) @pytest.mark.parametrize("env_id", ["Pendulum-v1", "CartPole-v1"]) def test_a2c_ppo_train_with_batch_norm(model_class, env_id): model = model_class( @@ -319,7 +313,7 @@ def test_offpolicy_collect_rollout_batch_norm(model_class): assert th.isclose(param_before, param_after).all() -@pytest.mark.parametrize("model_class", [A2C, PPO]) +@pytest.mark.parametrize("model_class", [A2C, PPO, RecurrentPPO]) @pytest.mark.parametrize("env_id", ["Pendulum-v1", "CartPole-v1"]) def test_a2c_ppo_collect_rollouts_with_batch_norm(model_class, env_id): model = model_class( From d9ac3f85dbff2c10e3fda886e3671275478303cc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adri=C3=A0=20Garriga-Alonso?= Date: Fri, 15 Sep 2023 14:23:52 -0700 Subject: [PATCH 02/31] fix type annotaitons --- tests/test_train_eval_mode.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/test_train_eval_mode.py b/tests/test_train_eval_mode.py index 30f989947..c9c8c6dde 100644 --- a/tests/test_train_eval_mode.py +++ b/tests/test_train_eval_mode.py @@ -1,4 +1,4 @@ -from typing import Union +from typing import Tuple, Union import gymnasium as gym import numpy as np @@ -37,7 +37,7 @@ def forward(self, observations: th.Tensor) -> th.Tensor: return result -def clone_batch_norm_stats(batch_norm: nn.BatchNorm1d) -> (th.Tensor, th.Tensor): +def clone_batch_norm_stats(batch_norm: nn.BatchNorm1d) -> Tuple[th.Tensor, th.Tensor]: """ Clone the bias and running mean from the given batch norm layer. 
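This follow-up patch fixes return annotations: `(th.Tensor, th.Tensor)` is a tuple of classes, not a type expression, so checkers such as mypy reject it, whereas `typing.Tuple[...]` is the correct spelling. A minimal illustration with a made-up helper name:

from typing import Tuple

import torch as th

# Wrong: `-> (th.Tensor, th.Tensor)` builds a tuple of classes at import time
# and is not a valid return type for static type checkers.
# Right: parameterize Tuple with the two element types.
def clone_stats() -> Tuple[th.Tensor, th.Tensor]:
    bias, running_mean = th.zeros(4), th.zeros(4)
    return bias.clone(), running_mean.clone()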
@@ -47,7 +47,7 @@ def clone_batch_norm_stats(batch_norm: nn.BatchNorm1d) -> (th.Tensor, th.Tensor) return batch_norm.bias.clone(), batch_norm.running_mean.clone() -def clone_dqn_batch_norm_stats(model: DQN) -> (th.Tensor, th.Tensor, th.Tensor, th.Tensor): +def clone_dqn_batch_norm_stats(model: DQN) -> Tuple[th.Tensor, th.Tensor, th.Tensor, th.Tensor]: """ Clone the bias and running mean from the Q-network and target network. @@ -65,7 +65,7 @@ def clone_dqn_batch_norm_stats(model: DQN) -> (th.Tensor, th.Tensor, th.Tensor, def clone_td3_batch_norm_stats( model: TD3, -) -> (th.Tensor, th.Tensor, th.Tensor, th.Tensor, th.Tensor, th.Tensor, th.Tensor, th.Tensor): +) -> Tuple[th.Tensor, th.Tensor, th.Tensor, th.Tensor, th.Tensor, th.Tensor, th.Tensor, th.Tensor]: """ Clone the bias and running mean from the actor and critic networks and actor-target and critic-target networks. @@ -98,7 +98,7 @@ def clone_td3_batch_norm_stats( def clone_sac_batch_norm_stats( model: SAC, -) -> (th.Tensor, th.Tensor, th.Tensor, th.Tensor, th.Tensor, th.Tensor): +) -> Tuple[th.Tensor, th.Tensor, th.Tensor, th.Tensor, th.Tensor, th.Tensor]: """ Clone the bias and running mean from the actor and critic networks and critic-target networks. @@ -117,7 +117,7 @@ def clone_sac_batch_norm_stats( return (actor_bias, actor_running_mean, critic_bias, critic_running_mean, critic_target_bias, critic_target_running_mean) -def clone_on_policy_batch_norm(model: Union[A2C, PPO]) -> (th.Tensor, th.Tensor): +def clone_on_policy_batch_norm(model: Union[A2C, PPO]) -> Tuple[th.Tensor, th.Tensor]: return clone_batch_norm_stats(model.policy.features_extractor.batch_norm) From 1eb682d41201335f71277585fcbc91f3ef8c2cb4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adri=C3=A0=20Garriga-Alonso?= Date: Fri, 15 Sep 2023 15:09:04 -0700 Subject: [PATCH 03/31] A recurrent buffer candidate --- stable_baselines3/common/pytree_dataclass.py | 46 +++ stable_baselines3/common/recurrent/buffers.py | 309 ++++-------------- .../common/recurrent/type_aliases.py | 40 ++- 3 files changed, 138 insertions(+), 257 deletions(-) create mode 100644 stable_baselines3/common/pytree_dataclass.py diff --git a/stable_baselines3/common/pytree_dataclass.py b/stable_baselines3/common/pytree_dataclass.py new file mode 100644 index 000000000..54160e59d --- /dev/null +++ b/stable_baselines3/common/pytree_dataclass.py @@ -0,0 +1,46 @@ +import dataclasses +from typing import Optional, Sequence, Type + +import optree as ot +from typing_extensions import dataclass_transform + +__all__ = ["register_dataclass_as_pytree", "dataclass_frozen_pytree", "tree_empty", "OT_NAMESPACE"] + +OT_NAMESPACE = "stable-baselines3" + + +def register_dataclass_as_pytree(Cls, whitelist: Optional[Sequence[str]] = None): + """Register a dataclass as a pytree, using the given whitelist of field names. + + :param Cls: The dataclass to register. + :param whitelist: The names of the fields to include in the pytree. If None, all fields are included. + :return: The dataclass, with the pytree registration applied. This is useful to be able to register a decorator. 
+ """ + + assert dataclasses.is_dataclass(Cls) + + names = tuple(f.name for f in dataclasses.fields(Cls) if whitelist is None or f.name in whitelist) + + def flatten_fn(inst): + return (getattr(inst, n) for n in names), None, names + + def unflatten_fn(context, values): + return Cls(**dict(zip(names, values))) + + ot.register_pytree_node(Cls, flatten_fn, unflatten_fn, namespace=OT_NAMESPACE) + + Cls.__iter__ = lambda self: iter(getattr(self, n) for n in names) + return Cls + + +@dataclass_transform() +def dataclass_frozen_pytree(Cls: Type, **kwargs) -> Type[ot.PyTree]: + """Decorator to make a frozen dataclass and register it as a PyTree.""" + dataCls = dataclasses.dataclass(frozen=True, slots=True, **kwargs)(Cls) + register_dataclass_as_pytree(dataCls) + return dataCls + + +def tree_empty(tree: ot.PyTree) -> bool: + flattened_state, _ = ot.tree_flatten(tree, namespace=OT_NAMESPACE) + return not bool(len(flattened_state)) diff --git a/stable_baselines3/common/recurrent/buffers.py b/stable_baselines3/common/recurrent/buffers.py index 0f1bcef46..432cf1517 100644 --- a/stable_baselines3/common/recurrent/buffers.py +++ b/stable_baselines3/common/recurrent/buffers.py @@ -1,97 +1,62 @@ +import dataclasses from functools import partial -from typing import Callable, Generator, Optional, Tuple, Union +from typing import Callable, Generator, Optional, Tuple, Type, Union import numpy as np +import optree as ot import torch as th from gymnasium import spaces +from optree import PyTree from stable_baselines3.common.buffers import DictRolloutBuffer, RolloutBuffer +from stable_baselines3.common.pytree_dataclass import dataclass_frozen_pytree from stable_baselines3.common.recurrent.type_aliases import ( + HiddenState, RecurrentDictRolloutBufferSamples, RecurrentRolloutBufferSamples, - RNNStates, + space_to_example, ) from stable_baselines3.common.vec_env import VecNormalize -def pad( - seq_start_indices: np.ndarray, - seq_end_indices: np.ndarray, - device: th.device, - tensor: np.ndarray, - padding_value: float = 0.0, -) -> th.Tensor: - """ - Chunk sequences and pad them to have constant dimensions. - - :param seq_start_indices: Indices of the transitions that start a sequence - :param seq_end_indices: Indices of the transitions that end a sequence - :param device: PyTorch device - :param tensor: Tensor of shape (batch_size, *tensor_shape) - :param padding_value: Value used to pad sequence to the same length - (zero padding by default) - :return: (n_seq, max_length, *tensor_shape) - """ - # Create sequences given start and end - seq = [th.tensor(tensor[start : end + 1], device=device) for start, end in zip(seq_start_indices, seq_end_indices)] - return th.nn.utils.rnn.pad_sequence(seq, batch_first=True, padding_value=padding_value) - - -def pad_and_flatten( - seq_start_indices: np.ndarray, - seq_end_indices: np.ndarray, - device: th.device, - tensor: np.ndarray, - padding_value: float = 0.0, -) -> th.Tensor: - """ - Pad and flatten the sequences of scalar values, - while keeping the sequence order. - From (batch_size, 1) to (n_seq, max_length, 1) -> (n_seq * max_length,) - - :param seq_start_indices: Indices of the transitions that start a sequence - :param seq_end_indices: Indices of the transitions that end a sequence - :param device: PyTorch device (cpu, gpu, ...) 
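The new `pytree_dataclass` module above registers dataclasses with `optree` under a dedicated namespace so that `tree_map` and friends can traverse their tensor fields. A small usage sketch, assuming the module lands as written in this patch:

import optree as ot
import torch as th

from stable_baselines3.common.pytree_dataclass import OT_NAMESPACE, dataclass_frozen_pytree

@dataclass_frozen_pytree
class Pair:
    a: th.Tensor
    b: th.Tensor

p = Pair(a=th.zeros(2), b=th.ones(3))
# tree_map rebuilds a Pair whose leaves have been transformed.
doubled = ot.tree_map(lambda x: x * 2, p, namespace=OT_NAMESPACE)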
- :param tensor: Tensor of shape (max_length, n_seq, 1) - :param padding_value: Value used to pad sequence to the same length - (zero padding by default) - :return: (n_seq * max_length,) aka (padded_batch_size,) - """ - return pad(seq_start_indices, seq_end_indices, device, tensor, padding_value).flatten() - - -def create_sequencers( - episode_starts: np.ndarray, - env_change: np.ndarray, - device: th.device, -) -> Tuple[np.ndarray, Callable, Callable]: - """ - Create the utility function to chunk data into - sequences and pad them to create fixed size tensors. - - :param episode_starts: Indices where an episode starts - :param env_change: Indices where the data collected - come from a different env (when using multiple env for data collection) - :param device: PyTorch device - :return: Indices of the transitions that start a sequence, - pad and pad_and_flatten utilities tailored for this batch - (sequence starts and ends indices are fixed) - """ - # Create sequence if env changes too - seq_start = (episode_starts | env_change).flatten() - # First index is always the beginning of a sequence - seq_start[0] = True - # Retrieve indices of sequence starts - seq_start_indices = np.where(seq_start == True)[0] # noqa: E712 - # End of sequence are just before sequence starts - # Last index is also always end of a sequence - seq_end_indices = np.concatenate([(seq_start_indices - 1)[1:], np.array([len(episode_starts)])]) - - # Create padding method for this minibatch - # to avoid repeating arguments (seq_start_indices, seq_end_indices) - local_pad = partial(pad, seq_start_indices, seq_end_indices, device) - local_pad_and_flatten = partial(pad_and_flatten, seq_start_indices, seq_end_indices, device) - return seq_start_indices, local_pad, local_pad_and_flatten +@dataclass_frozen_pytree +class RecurrentRolloutBufferData: + observations: PyTree[th.Tensor] + actions: th.Tensor + rewards: th.Tensor + returns: th.Tensor + episode_starts: th.Tensor + values: th.Tensor + log_probs: th.Tensor + advantages: th.Tensor + hidden_states: HiddenState + + @classmethod + def make_zeros( + cls: Type["RecurrentRolloutBufferData"], + batch_shape: Union[Tuple[int], Tuple[int, int]], + action_dim: int, + observation_space: spaces.Space, + hidden_state_example: HiddenState, + *, + device: Optional[th.device] = None + ) -> "RecurrentRolloutBufferData": + seq_shape = batch_shape[:-1] + batch_dim = batch_shape[-1] + return RecurrentRolloutBufferData( + observations=space_to_example(batch_shape, observation_space, device=device), + actions=th.zeros((*batch_shape, action_dim), dtype=th.float32, device=device), + rewards=th.zeros(batch_shape, dtype=th.float32, device=device), + returns=th.zeros(batch_shape, dtype=th.float32, device=device), + episode_starts=th.zeros(batch_shape, dtype=th.float32, device=device), + values=th.zeros(batch_shape, dtype=th.float32, device=device), + log_probs=th.zeros(batch_shape, dtype=th.float32, device=device), + advantages=th.zeros(batch_shape, dtype=th.float32, device=device), + hidden_states=ot.tree_map( + lambda x: th.zeros((*seq_shape, x.shape[0], batch_dim, x.shape[1:]), dtype=x.dtype, device=device), + hidden_state_example, + ), + ) class RecurrentRolloutBuffer(RolloutBuffer): @@ -115,33 +80,37 @@ def __init__( buffer_size: int, observation_space: spaces.Space, action_space: spaces.Space, - hidden_state_shape: Tuple[int, int, int, int], + hidden_state_example: HiddenState, device: Union[th.device, str] = "auto", gae_lambda: float = 1, gamma: float = 0.99, n_envs: int = 1, ): - 
self.hidden_state_shape = hidden_state_shape - self.seq_start_indices, self.seq_end_indices = None, None + self.hidden_state_example = ot.tree_map( + lambda x: th.zeros((), dtype=x.dtype, device=device).expand_as(x), hidden_state_example + ) super().__init__(buffer_size, observation_space, action_space, device, gae_lambda, gamma, n_envs) def reset(self): - super().reset() - self.hidden_states_pi = np.zeros(self.hidden_state_shape, dtype=np.float32) - self.cell_states_pi = np.zeros(self.hidden_state_shape, dtype=np.float32) - self.hidden_states_vf = np.zeros(self.hidden_state_shape, dtype=np.float32) - self.cell_states_vf = np.zeros(self.hidden_state_shape, dtype=np.float32) + self.data = RecurrentRolloutBufferData.make_zeros( + batch_shape=(self.buffer_size, self.n_envs), + action_dim=self.action_dim, + observation_space=self.observation_space, + hidden_state_example=self.hidden_state_example, + device=self.device, + ) + super(RolloutBuffer, self).reset() - def add(self, *args, lstm_states: RNNStates, **kwargs) -> None: + def add(self, data: RecurrentRolloutBufferData, **kwargs) -> None: """ :param hidden_states: LSTM cell and hidden state """ - self.hidden_states_pi[self.pos] = np.array(lstm_states.pi[0].cpu().numpy()) - self.cell_states_pi[self.pos] = np.array(lstm_states.pi[1].cpu().numpy()) - self.hidden_states_vf[self.pos] = np.array(lstm_states.vf[0].cpu().numpy()) - self.cell_states_vf[self.pos] = np.array(lstm_states.vf[1].cpu().numpy()) - - super().add(*args, **kwargs) + new_data = dataclasses.replace(data, actions=data.actions.reshape((self.n_envs, self.action_dim))) + ot.tree_map(lambda buf, x: buf[self.pos].copy_(x, non_blocking=True), self.data, new_data) + # Increment pos + self.pos += 1 + if self.pos == self.buffer_size: + self.full = True def get(self, batch_size: Optional[int] = None) -> Generator[RecurrentRolloutBufferSamples, None, None]: assert self.full, "Rollout buffer must be full before sampling from it" @@ -176,6 +145,11 @@ def get(self, batch_size: Optional[int] = None) -> Generator[RecurrentRolloutBuf if batch_size is None: batch_size = self.buffer_size * self.n_envs + if batch_size == self.buffer_size * self.n_envs: + return self._get_samples(slice(None)) + + raise NotImplementedError("only supports full batches") + # Sampling strategy that allows any mini batch size but requires # more complexity and use of padding # Trick to shuffle a bit: keep the sequence order @@ -235,150 +209,7 @@ def _get_samples( old_log_prob=self.pad_and_flatten(self.log_probs[batch_inds]), advantages=self.pad_and_flatten(self.advantages[batch_inds]), returns=self.pad_and_flatten(self.returns[batch_inds]), - lstm_states=RNNStates(lstm_states_pi, lstm_states_vf), - episode_starts=self.pad_and_flatten(self.episode_starts[batch_inds]), - mask=self.pad_and_flatten(np.ones_like(self.returns[batch_inds])), - ) - - -class RecurrentDictRolloutBuffer(DictRolloutBuffer): - """ - Dict Rollout buffer used in on-policy algorithms like A2C/PPO. - Extends the RecurrentRolloutBuffer to use dictionary observations - - :param buffer_size: Max number of element in the buffer - :param observation_space: Observation space - :param action_space: Action space - :param hidden_state_shape: Shape of the buffer that will collect lstm states - :param device: PyTorch device - :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator - Equivalent to classic advantage when set to 1. 
- :param gamma: Discount factor - :param n_envs: Number of parallel environments - """ - - def __init__( - self, - buffer_size: int, - observation_space: spaces.Space, - action_space: spaces.Space, - hidden_state_shape: Tuple[int, int, int, int], - device: Union[th.device, str] = "auto", - gae_lambda: float = 1, - gamma: float = 0.99, - n_envs: int = 1, - ): - self.hidden_state_shape = hidden_state_shape - self.seq_start_indices, self.seq_end_indices = None, None - super().__init__(buffer_size, observation_space, action_space, device, gae_lambda, gamma, n_envs=n_envs) - - def reset(self): - super().reset() - self.hidden_states_pi = np.zeros(self.hidden_state_shape, dtype=np.float32) - self.cell_states_pi = np.zeros(self.hidden_state_shape, dtype=np.float32) - self.hidden_states_vf = np.zeros(self.hidden_state_shape, dtype=np.float32) - self.cell_states_vf = np.zeros(self.hidden_state_shape, dtype=np.float32) - - def add(self, *args, lstm_states: RNNStates, **kwargs) -> None: - """ - :param hidden_states: LSTM cell and hidden state - """ - self.hidden_states_pi[self.pos] = np.array(lstm_states.pi[0].cpu().numpy()) - self.cell_states_pi[self.pos] = np.array(lstm_states.pi[1].cpu().numpy()) - self.hidden_states_vf[self.pos] = np.array(lstm_states.vf[0].cpu().numpy()) - self.cell_states_vf[self.pos] = np.array(lstm_states.vf[1].cpu().numpy()) - - super().add(*args, **kwargs) - - def get(self, batch_size: Optional[int] = None) -> Generator[RecurrentDictRolloutBufferSamples, None, None]: - assert self.full, "Rollout buffer must be full before sampling from it" - - # Prepare the data - if not self.generator_ready: - # hidden_state_shape = (self.n_steps, lstm.num_layers, self.n_envs, lstm.hidden_size) - # swap first to (self.n_steps, self.n_envs, lstm.num_layers, lstm.hidden_size) - for tensor in ["hidden_states_pi", "cell_states_pi", "hidden_states_vf", "cell_states_vf"]: - self.__dict__[tensor] = self.__dict__[tensor].swapaxes(1, 2) - - for key, obs in self.observations.items(): - self.observations[key] = self.swap_and_flatten(obs) - - for tensor in [ - "actions", - "values", - "log_probs", - "advantages", - "returns", - "hidden_states_pi", - "cell_states_pi", - "hidden_states_vf", - "cell_states_vf", - "episode_starts", - ]: - self.__dict__[tensor] = self.swap_and_flatten(self.__dict__[tensor]) - self.generator_ready = True - - # Return everything, don't create minibatches - if batch_size is None: - batch_size = self.buffer_size * self.n_envs - - # Trick to shuffle a bit: keep the sequence order - # but split the indices in two - split_index = np.random.randint(self.buffer_size * self.n_envs) - indices = np.arange(self.buffer_size * self.n_envs) - indices = np.concatenate((indices[split_index:], indices[:split_index])) - - env_change = np.zeros(self.buffer_size * self.n_envs).reshape(self.buffer_size, self.n_envs) - # Flag first timestep as change of environment - env_change[0, :] = 1.0 - env_change = self.swap_and_flatten(env_change) - - start_idx = 0 - while start_idx < self.buffer_size * self.n_envs: - batch_inds = indices[start_idx : start_idx + batch_size] - yield self._get_samples(batch_inds, env_change) - start_idx += batch_size - - def _get_samples( - self, - batch_inds: np.ndarray, - env_change: np.ndarray, - env: Optional[VecNormalize] = None, - ) -> RecurrentDictRolloutBufferSamples: - # Retrieve sequence starts and utility function - self.seq_start_indices, self.pad, self.pad_and_flatten = create_sequencers( - self.episode_starts[batch_inds], env_change[batch_inds], self.device - ) 
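The hidden states that initialize each sequence are gathered at the sequence starts and transposed so the layer axis comes first, as the shape comments just below spell out. A toy shape check of that indexing-plus-`swapaxes` pattern, with made-up sizes:

import numpy as np

n_layers, hidden_dim = 2, 8
flat_states = np.zeros((96, n_layers, hidden_dim))   # (n_envs * n_steps, n_layers, dim)
seq_start_indices = np.array([0, 17, 44])            # invented sequence starts
init_states = flat_states[seq_start_indices].swapaxes(0, 1)
assert init_states.shape == (n_layers, 3, hidden_dim)  # (n_layers, n_seq, dim)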
- - n_seq = len(self.seq_start_indices) - max_length = self.pad(self.actions[batch_inds]).shape[1] - padded_batch_size = n_seq * max_length - # We retrieve the lstm hidden states that will allow - # to properly initialize the LSTM at the beginning of each sequence - lstm_states_pi = ( - # (n_envs * n_steps, n_layers, dim) -> (n_layers, n_seq, dim) - self.hidden_states_pi[batch_inds][self.seq_start_indices].swapaxes(0, 1), - self.cell_states_pi[batch_inds][self.seq_start_indices].swapaxes(0, 1), - ) - lstm_states_vf = ( - # (n_envs * n_steps, n_layers, dim) -> (n_layers, n_seq, dim) - self.hidden_states_vf[batch_inds][self.seq_start_indices].swapaxes(0, 1), - self.cell_states_vf[batch_inds][self.seq_start_indices].swapaxes(0, 1), - ) - lstm_states_pi = (self.to_torch(lstm_states_pi[0]).contiguous(), self.to_torch(lstm_states_pi[1]).contiguous()) - lstm_states_vf = (self.to_torch(lstm_states_vf[0]).contiguous(), self.to_torch(lstm_states_vf[1]).contiguous()) - - observations = {key: self.pad(obs[batch_inds]) for (key, obs) in self.observations.items()} - observations = {key: obs.reshape((padded_batch_size,) + self.obs_shape[key]) for (key, obs) in observations.items()} - - return RecurrentDictRolloutBufferSamples( - observations=observations, - actions=self.pad(self.actions[batch_inds]).reshape((padded_batch_size,) + self.actions.shape[1:]), - old_values=self.pad_and_flatten(self.values[batch_inds]), - old_log_prob=self.pad_and_flatten(self.log_probs[batch_inds]), - advantages=self.pad_and_flatten(self.advantages[batch_inds]), - returns=self.pad_and_flatten(self.returns[batch_inds]), - lstm_states=RNNStates(lstm_states_pi, lstm_states_vf), + lstm_states=HiddenState(lstm_states_pi, lstm_states_vf), episode_starts=self.pad_and_flatten(self.episode_starts[batch_inds]), mask=self.pad_and_flatten(np.ones_like(self.returns[batch_inds])), ) diff --git a/stable_baselines3/common/recurrent/type_aliases.py b/stable_baselines3/common/recurrent/type_aliases.py index 21ac0e0d9..189967531 100644 --- a/stable_baselines3/common/recurrent/type_aliases.py +++ b/stable_baselines3/common/recurrent/type_aliases.py @@ -1,33 +1,37 @@ -from typing import NamedTuple, Tuple +from typing import Iterable, NamedTuple, Optional, Sequence, Tuple import torch as th +from gymnasium import spaces +from optree import PyTree +from stable_baselines3.common.pytree_dataclass import dataclass_frozen_pytree from stable_baselines3.common.type_aliases import TensorDict +HiddenState = PyTree[th.Tensor] -class RNNStates(NamedTuple): - pi: Tuple[th.Tensor, ...] - vf: Tuple[th.Tensor, ...] 
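The fixed `RNNStates` NamedTuple removed here is replaced by `HiddenState`, defined just above as an arbitrary pytree of tensors. A sketch, with invented sizes, of what such a state can look like when the actor and critic keep separate LSTM states:

import torch as th

n_layers, n_envs, hidden_size = 1, 4, 8
# (hidden, cell) pair for one LSTM, each shaped (n_layers, n_envs, hidden_size).
lstm_state = (th.zeros(n_layers, n_envs, hidden_size), th.zeros(n_layers, n_envs, hidden_size))
# Any nesting of tensors is a valid HiddenState, e.g. one entry per network.
hidden_states = {"pi": lstm_state, "vf": lstm_state}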
- -class RecurrentRolloutBufferSamples(NamedTuple): - observations: th.Tensor - actions: th.Tensor - old_values: th.Tensor - old_log_prob: th.Tensor - advantages: th.Tensor - returns: th.Tensor - lstm_states: RNNStates - episode_starts: th.Tensor - mask: th.Tensor +def space_to_example( + batch_shape: Tuple[int, ...], space: spaces.Space, *, device: Optional[th.device] = None +) -> PyTree[th.Tensor]: + if isinstance(space, spaces.Box): + return torch.zeros((*batch_shape, space.shape), dtype=th.float32, device=device) + elif isinstance(space, spaces.Discrete): + return torch.zeros((*batch_shape), dtype=th.int64, device=device) + elif isinstance(space, spaces.Dict): + return {k: space_to_example(v) for k, v in space.items()} + elif isinstance(space, spaces.Tuple): + return tuple(space_to_example(v) for v in space) + else: + raise TypeError(f"Unknown space type {type(space)} for {space}") -class RecurrentDictRolloutBufferSamples(NamedTuple): - observations: TensorDict +@dataclass_frozen_pytree +class RecurrentRolloutBufferSamples: + observations: PyTree[th.Tensor] actions: th.Tensor old_values: th.Tensor old_log_prob: th.Tensor advantages: th.Tensor returns: th.Tensor - lstm_states: RNNStates + hidden_states: RNNStates episode_starts: th.Tensor mask: th.Tensor From 8180318e26bb2fc30e715019a0f8bc377a1b0234 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adri=C3=A0=20Garriga-Alonso?= Date: Mon, 18 Sep 2023 15:52:09 -0700 Subject: [PATCH 04/31] Recurrent buffer using pytrees --- stable_baselines3/common/pytree_dataclass.py | 4 +- stable_baselines3/common/recurrent/buffers.py | 233 +++++++----------- .../common/recurrent/type_aliases.py | 36 ++- 3 files changed, 116 insertions(+), 157 deletions(-) diff --git a/stable_baselines3/common/pytree_dataclass.py b/stable_baselines3/common/pytree_dataclass.py index 54160e59d..b5303284e 100644 --- a/stable_baselines3/common/pytree_dataclass.py +++ b/stable_baselines3/common/pytree_dataclass.py @@ -36,7 +36,9 @@ def unflatten_fn(context, values): @dataclass_transform() def dataclass_frozen_pytree(Cls: Type, **kwargs) -> Type[ot.PyTree]: """Decorator to make a frozen dataclass and register it as a PyTree.""" - dataCls = dataclasses.dataclass(frozen=True, slots=True, **kwargs)(Cls) + true_kwargs = dict(frozen=True, slots=True) + true_kwargs.update(kwargs) + dataCls = dataclasses.dataclass(**true_kwargs)(Cls) register_dataclass_as_pytree(dataCls) return dataCls diff --git a/stable_baselines3/common/recurrent/buffers.py b/stable_baselines3/common/recurrent/buffers.py index 432cf1517..ac73dd6d6 100644 --- a/stable_baselines3/common/recurrent/buffers.py +++ b/stable_baselines3/common/recurrent/buffers.py @@ -1,62 +1,29 @@ import dataclasses -from functools import partial -from typing import Callable, Generator, Optional, Tuple, Type, Union +from typing import Any, Callable, Generator, Optional, Union -import numpy as np import optree as ot import torch as th from gymnasium import spaces -from optree import PyTree -from stable_baselines3.common.buffers import DictRolloutBuffer, RolloutBuffer -from stable_baselines3.common.pytree_dataclass import dataclass_frozen_pytree +from stable_baselines3.common.buffers import RolloutBuffer +from stable_baselines3.common.pytree_dataclass import OT_NAMESPACE as NS from stable_baselines3.common.recurrent.type_aliases import ( HiddenState, - RecurrentDictRolloutBufferSamples, + PyTreeGeneric, RecurrentRolloutBufferSamples, space_to_example, ) from stable_baselines3.common.vec_env import VecNormalize -@dataclass_frozen_pytree -class 
RecurrentRolloutBufferData: - observations: PyTree[th.Tensor] - actions: th.Tensor - rewards: th.Tensor - returns: th.Tensor - episode_starts: th.Tensor - values: th.Tensor - log_probs: th.Tensor - advantages: th.Tensor - hidden_states: HiddenState - - @classmethod - def make_zeros( - cls: Type["RecurrentRolloutBufferData"], - batch_shape: Union[Tuple[int], Tuple[int, int]], - action_dim: int, - observation_space: spaces.Space, - hidden_state_example: HiddenState, - *, - device: Optional[th.device] = None - ) -> "RecurrentRolloutBufferData": - seq_shape = batch_shape[:-1] - batch_dim = batch_shape[-1] - return RecurrentRolloutBufferData( - observations=space_to_example(batch_shape, observation_space, device=device), - actions=th.zeros((*batch_shape, action_dim), dtype=th.float32, device=device), - rewards=th.zeros(batch_shape, dtype=th.float32, device=device), - returns=th.zeros(batch_shape, dtype=th.float32, device=device), - episode_starts=th.zeros(batch_shape, dtype=th.float32, device=device), - values=th.zeros(batch_shape, dtype=th.float32, device=device), - log_probs=th.zeros(batch_shape, dtype=th.float32, device=device), - advantages=th.zeros(batch_shape, dtype=th.float32, device=device), - hidden_states=ot.tree_map( - lambda x: th.zeros((*seq_shape, x.shape[0], batch_dim, x.shape[1:]), dtype=x.dtype, device=device), - hidden_state_example, - ), - ) +def index_into_pytree( + idx: Any, + tree: PyTreeGeneric, + is_leaf: Optional[Union[bool, Callable[[PyTreeGeneric], bool]]] = None, + none_is_leaf: bool = False, + namespace: str = NS, +) -> PyTreeGeneric: + return ot.tree_map(lambda x: x[idx], tree, is_leaf=is_leaf, none_is_leaf=none_is_leaf, namespace=namespace) class RecurrentRolloutBuffer(RolloutBuffer): @@ -91,20 +58,81 @@ def __init__( ) super().__init__(buffer_size, observation_space, action_space, device, gae_lambda, gamma, n_envs) - def reset(self): - self.data = RecurrentRolloutBufferData.make_zeros( - batch_shape=(self.buffer_size, self.n_envs), - action_dim=self.action_dim, - observation_space=self.observation_space, - hidden_state_example=self.hidden_state_example, - device=self.device, + batch_shape = (self.buffer_size, self.n_envs) + action_dim = self.action_dim + observation_space = self.observation_space + hidden_state_example = self.hidden_state_example + device = self.device + + self.data = RecurrentRolloutBufferSamples( + observations=space_to_example(batch_shape, observation_space, device=device, ensure_non_batch_dim=True), + actions=th.zeros((*batch_shape, action_dim), dtype=th.float32, device=device), + old_values=th.zeros(batch_shape, dtype=th.float32, device=device), + old_log_probs=th.zeros(batch_shape, dtype=th.float32, device=device), + advantages=th.zeros(batch_shape, dtype=th.float32, device=device), + returns=th.zeros(batch_shape, dtype=th.float32, device=device), + hidden_states=ot.tree_map( + lambda x: th.zeros( + (*batch_shape[:-1], x.shape[0], batch_shape[-1], x.shape[1:]), dtype=x.dtype, device=device + ), + hidden_state_example, + ), + episode_starts=th.zeros(batch_shape, dtype=th.float32, device=device), + rewards=th.zeros(batch_shape, dtype=th.float32, device=device), ) + + # Expose attributes of the RecurrentRolloutBufferData in the top-level to conform to the RolloutBuffer interface + @property + def episode_starts(self) -> th.Tensor: + return self.data.episode_starts + + @property + def values(self) -> th.Tensor: + return self.data.old_values + + @property + def rewards(self) -> th.Tensor: + assert self.data.rewards is not None, 
"RecurrentRolloutBufferData should store rewards" + return self.data.rewards + + @property + def advantages(self) -> th.Tensor: + return self.data.advantages + + @property + def returns(self) -> th.Tensor: + return self.data.returns + + @returns.setter + def _set_returns(self, new_returns: th.Tensor): + self.data.returns.copy_(new_returns, non_blocking=True) + + def reset(self): + ot.tree_map(lambda x: x.zero_(), self.data, namespace=NS) super(RolloutBuffer, self).reset() - def add(self, data: RecurrentRolloutBufferData, **kwargs) -> None: + def extend(self, *args) -> None: + """ + Add a new batch of transitions to the buffer + """ + + # Do a for loop along the batch axis. + # Treat lists as leaves to avoid flattening the infos. + def _is_list(t): + return isinstance(t, list) + + tensors, _ = ot.tree_flatten(args, is_leaf=_is_list, namespace=NS) + len_tensors = len(tensors[0]) + assert all(len(t) == len_tensors for t in tensors), "All tensors must have the same batch size" + for i in range(len_tensors): + self.add(*index_into_pytree(i, args, is_leaf=_is_list, namespace=NS)) + + def add(self, data: RecurrentRolloutBufferSamples, **kwargs) -> None: """ :param hidden_states: LSTM cell and hidden state """ + if data.rewards is None: + raise ValueError("Recorded samples must contain a reward") new_data = dataclasses.replace(data, actions=data.actions.reshape((self.n_envs, self.action_dim))) ot.tree_map(lambda buf, x: buf[self.pos].copy_(x, non_blocking=True), self.data, new_data) # Increment pos @@ -115,101 +143,18 @@ def add(self, data: RecurrentRolloutBufferData, **kwargs) -> None: def get(self, batch_size: Optional[int] = None) -> Generator[RecurrentRolloutBufferSamples, None, None]: assert self.full, "Rollout buffer must be full before sampling from it" - # Prepare the data - if not self.generator_ready: - # hidden_state_shape = (self.n_steps, lstm.num_layers, self.n_envs, lstm.hidden_size) - # swap first to (self.n_steps, self.n_envs, lstm.num_layers, lstm.hidden_size) - for tensor in ["hidden_states_pi", "cell_states_pi", "hidden_states_vf", "cell_states_vf"]: - self.__dict__[tensor] = self.__dict__[tensor].swapaxes(1, 2) - - # flatten but keep the sequence order - # 1. (n_steps, n_envs, *tensor_shape) -> (n_envs, n_steps, *tensor_shape) - # 2. 
(n_envs, n_steps, *tensor_shape) -> (n_envs * n_steps, *tensor_shape) - for tensor in [ - "observations", - "actions", - "values", - "log_probs", - "advantages", - "returns", - "hidden_states_pi", - "cell_states_pi", - "hidden_states_vf", - "cell_states_vf", - "episode_starts", - ]: - self.__dict__[tensor] = self.swap_and_flatten(self.__dict__[tensor]) - self.generator_ready = True - # Return everything, don't create minibatches - if batch_size is None: - batch_size = self.buffer_size * self.n_envs - - if batch_size == self.buffer_size * self.n_envs: - return self._get_samples(slice(None)) - - raise NotImplementedError("only supports full batches") - - # Sampling strategy that allows any mini batch size but requires - # more complexity and use of padding - # Trick to shuffle a bit: keep the sequence order - # but split the indices in two - split_index = np.random.randint(self.buffer_size * self.n_envs) - indices = np.arange(self.buffer_size * self.n_envs) - indices = np.concatenate((indices[split_index:], indices[:split_index])) - - env_change = np.zeros(self.buffer_size * self.n_envs).reshape(self.buffer_size, self.n_envs) - # Flag first timestep as change of environment - env_change[0, :] = 1.0 - env_change = self.swap_and_flatten(env_change) + if batch_size is None or batch_size == self.buffer_size * self.n_envs: + yield self._get_samples(slice(None)) + return - start_idx = 0 - while start_idx < self.buffer_size * self.n_envs: - batch_inds = indices[start_idx : start_idx + batch_size] - yield self._get_samples(batch_inds, env_change) - start_idx += batch_size + for start_idx in range(0, self.buffer_size * self.n_envs, batch_size): + yield self._get_samples(slice(start_idx, start_idx + batch_size, None)) def _get_samples( self, - batch_inds: np.ndarray, - env_change: np.ndarray, + batch_inds: Union[slice, th.Tensor], env: Optional[VecNormalize] = None, ) -> RecurrentRolloutBufferSamples: - # Retrieve sequence starts and utility function - self.seq_start_indices, self.pad, self.pad_and_flatten = create_sequencers( - self.episode_starts[batch_inds], env_change[batch_inds], self.device - ) - - # Number of sequences - n_seq = len(self.seq_start_indices) - max_length = self.pad(self.actions[batch_inds]).shape[1] - padded_batch_size = n_seq * max_length - # We retrieve the lstm hidden states that will allow - # to properly initialize the LSTM at the beginning of each sequence - lstm_states_pi = ( - # 1. (n_envs * n_steps, n_layers, dim) -> (batch_size, n_layers, dim) - # 2. (batch_size, n_layers, dim) -> (n_seq, n_layers, dim) - # 3. 
(n_seq, n_layers, dim) -> (n_layers, n_seq, dim) - self.hidden_states_pi[batch_inds][self.seq_start_indices].swapaxes(0, 1), - self.cell_states_pi[batch_inds][self.seq_start_indices].swapaxes(0, 1), - ) - lstm_states_vf = ( - # (n_envs * n_steps, n_layers, dim) -> (n_layers, n_seq, dim) - self.hidden_states_vf[batch_inds][self.seq_start_indices].swapaxes(0, 1), - self.cell_states_vf[batch_inds][self.seq_start_indices].swapaxes(0, 1), - ) - lstm_states_pi = (self.to_torch(lstm_states_pi[0]).contiguous(), self.to_torch(lstm_states_pi[1]).contiguous()) - lstm_states_vf = (self.to_torch(lstm_states_vf[0]).contiguous(), self.to_torch(lstm_states_vf[1]).contiguous()) - - return RecurrentRolloutBufferSamples( - # (batch_size, obs_dim) -> (n_seq, max_length, obs_dim) -> (n_seq * max_length, obs_dim) - observations=self.pad(self.observations[batch_inds]).reshape((padded_batch_size, *self.obs_shape)), - actions=self.pad(self.actions[batch_inds]).reshape((padded_batch_size,) + self.actions.shape[1:]), - old_values=self.pad_and_flatten(self.values[batch_inds]), - old_log_prob=self.pad_and_flatten(self.log_probs[batch_inds]), - advantages=self.pad_and_flatten(self.advantages[batch_inds]), - returns=self.pad_and_flatten(self.returns[batch_inds]), - lstm_states=HiddenState(lstm_states_pi, lstm_states_vf), - episode_starts=self.pad_and_flatten(self.episode_starts[batch_inds]), - mask=self.pad_and_flatten(np.ones_like(self.returns[batch_inds])), - ) + data_without_reward: RecurrentRolloutBufferSamples = dataclasses.replace(self.data, rewards=None) + return ot.tree_map(lambda tens: self.to_device(tens[batch_inds]), data_without_reward) diff --git a/stable_baselines3/common/recurrent/type_aliases.py b/stable_baselines3/common/recurrent/type_aliases.py index 189967531..a181de398 100644 --- a/stable_baselines3/common/recurrent/type_aliases.py +++ b/stable_baselines3/common/recurrent/type_aliases.py @@ -1,30 +1,42 @@ -from typing import Iterable, NamedTuple, Optional, Sequence, Tuple +from typing import Optional, Tuple, TypeVar import torch as th from gymnasium import spaces from optree import PyTree + from stable_baselines3.common.pytree_dataclass import dataclass_frozen_pytree -from stable_baselines3.common.type_aliases import TensorDict HiddenState = PyTree[th.Tensor] +PyTreeGeneric = TypeVar("PyTreeGeneric", bound=PyTree) + + def space_to_example( - batch_shape: Tuple[int, ...], space: spaces.Space, *, device: Optional[th.device] = None + batch_shape: Tuple[int, ...], + space: spaces.Space, + *, + device: Optional[th.device] = None, + ensure_non_batch_dim: bool = False, ) -> PyTree[th.Tensor]: + if isinstance(space, spaces.Dict): + return {k: space_to_example(v, device=device, ensure_non_batch_dim=ensure_non_batch_dim) for k, v in space.items()} + if isinstance(space, spaces.Tuple): + return tuple(space_to_example(v, device=device, ensure_non_batch_dim=ensure_non_batch_dim) for v in space) + if isinstance(space, spaces.Box): - return torch.zeros((*batch_shape, space.shape), dtype=th.float32, device=device) + space_shape = space.shape elif isinstance(space, spaces.Discrete): - return torch.zeros((*batch_shape), dtype=th.int64, device=device) - elif isinstance(space, spaces.Dict): - return {k: space_to_example(v) for k, v in space.items()} - elif isinstance(space, spaces.Tuple): - return tuple(space_to_example(v) for v in space) + space_shape = () else: raise TypeError(f"Unknown space type {type(space)} for {space}") + if ensure_non_batch_dim and space_shape: + space_shape = (1,) + return 
th.zeros((*batch_shape, *space_shape), dtype=th.float32, device=device) + -@dataclass_frozen_pytree +@functools.partial(dataclass_frozen_pytree, frozen=False) class RecurrentRolloutBufferSamples: observations: PyTree[th.Tensor] actions: th.Tensor @@ -32,6 +44,6 @@ class RecurrentRolloutBufferSamples: old_log_prob: th.Tensor advantages: th.Tensor returns: th.Tensor - hidden_states: RNNStates + hidden_states: HiddenState episode_starts: th.Tensor - mask: th.Tensor + rewards: Optional[th.Tensor] = None From 1596f9c25bf427db3ca4f4a1473941587b091301 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adri=C3=A0=20Garriga-Alonso?= Date: Mon, 18 Sep 2023 17:16:27 -0700 Subject: [PATCH 05/31] Buffer passes the most basic test --- stable_baselines3/common/recurrent/buffers.py | 112 ++++++++++++------ .../common/recurrent/type_aliases.py | 44 +++---- tests/test_buffers.py | 52 ++++++-- 3 files changed, 130 insertions(+), 78 deletions(-) diff --git a/stable_baselines3/common/recurrent/buffers.py b/stable_baselines3/common/recurrent/buffers.py index ac73dd6d6..ccae73f57 100644 --- a/stable_baselines3/common/recurrent/buffers.py +++ b/stable_baselines3/common/recurrent/buffers.py @@ -1,17 +1,18 @@ import dataclasses -from typing import Any, Callable, Generator, Optional, Union +from typing import Any, Callable, Generator, Optional, Tuple, Union import optree as ot import torch as th from gymnasium import spaces +from optree import PyTree from stable_baselines3.common.buffers import RolloutBuffer from stable_baselines3.common.pytree_dataclass import OT_NAMESPACE as NS from stable_baselines3.common.recurrent.type_aliases import ( HiddenState, PyTreeGeneric, + RecurrentRolloutBufferData, RecurrentRolloutBufferSamples, - space_to_example, ) from stable_baselines3.common.vec_env import VecNormalize @@ -26,6 +27,33 @@ def index_into_pytree( return ot.tree_map(lambda x: x[idx], tree, is_leaf=is_leaf, none_is_leaf=none_is_leaf, namespace=namespace) +def space_to_example( + batch_shape: Tuple[int, ...], + space: spaces.Space, + *, + device: Optional[th.device] = None, + ensure_non_batch_dim: bool = False, +) -> PyTree[th.Tensor]: + if isinstance(space, spaces.Dict): + return { + k: space_to_example(batch_shape, v, device=device, ensure_non_batch_dim=ensure_non_batch_dim) + for k, v in space.items() + } + if isinstance(space, spaces.Tuple): + return tuple(space_to_example(batch_shape, v, device=device, ensure_non_batch_dim=ensure_non_batch_dim) for v in space) + + if isinstance(space, spaces.Box): + space_shape = space.shape + elif isinstance(space, spaces.Discrete): + space_shape = () + else: + raise TypeError(f"Unknown space type {type(space)} for {space}") + + if ensure_non_batch_dim and space_shape: + space_shape = (1,) + return th.zeros((*batch_shape, *space_shape), dtype=th.float32, device=device) + + class RecurrentRolloutBuffer(RolloutBuffer): """ Rollout buffer that also stores the LSTM cell and hidden states. 
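For orientation, a hedged usage sketch of the `space_to_example` helper defined above (the import path assumes this patch is applied), using a made-up dictionary observation space:

import torch as th
from gymnasium import spaces

from stable_baselines3.common.recurrent.buffers import space_to_example

obs_space = spaces.Dict(
    {
        "img": spaces.Box(low=0.0, high=1.0, shape=(3, 40, 40)),
        "direction": spaces.Discrete(4),
    }
)
example = space_to_example((16, 8), obs_space, device=th.device("cpu"))
# example["img"].shape == (16, 8, 3, 40, 40)
# example["direction"].shape == (16, 8)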
@@ -42,6 +70,10 @@ class RecurrentRolloutBuffer(RolloutBuffer): :param n_envs: Number of parallel environments """ + advantages: th.Tensor + returns: th.Tensor + data: RecurrentRolloutBufferData + def __init__( self, buffer_size: int, @@ -53,34 +85,39 @@ def __init__( gamma: float = 0.99, n_envs: int = 1, ): - self.hidden_state_example = ot.tree_map( - lambda x: th.zeros((), dtype=x.dtype, device=device).expand_as(x), hidden_state_example - ) - super().__init__(buffer_size, observation_space, action_space, device, gae_lambda, gamma, n_envs) + super(RolloutBuffer, self).__init__(buffer_size, observation_space, action_space, device=device, n_envs=n_envs) + self.hidden_state_example = hidden_state_example + self.gae_lambda = gae_lambda + self.gamma = gamma batch_shape = (self.buffer_size, self.n_envs) - action_dim = self.action_dim - observation_space = self.observation_space - hidden_state_example = self.hidden_state_example device = self.device - self.data = RecurrentRolloutBufferSamples( - observations=space_to_example(batch_shape, observation_space, device=device, ensure_non_batch_dim=True), - actions=th.zeros((*batch_shape, action_dim), dtype=th.float32, device=device), - old_values=th.zeros(batch_shape, dtype=th.float32, device=device), - old_log_probs=th.zeros(batch_shape, dtype=th.float32, device=device), - advantages=th.zeros(batch_shape, dtype=th.float32, device=device), - returns=th.zeros(batch_shape, dtype=th.float32, device=device), + self.hidden_state_example = ot.tree_map( + lambda x: th.zeros((), dtype=x.dtype, device=device).expand_as(x), hidden_state_example + ) + self.advantages = th.zeros(batch_shape, dtype=th.float32, device=device) + self.returns = th.zeros(batch_shape, dtype=th.float32, device=device) + self.data = RecurrentRolloutBufferData( + observations=space_to_example(batch_shape, self.observation_space, device=device, ensure_non_batch_dim=True), + actions=th.zeros((*batch_shape, self.action_dim), dtype=th.float32, device=device), + rewards=th.zeros(batch_shape, dtype=th.float32, device=device), + episode_starts=th.zeros(batch_shape, dtype=th.float32, device=device), + values=th.zeros(batch_shape, dtype=th.float32, device=device), + log_probs=th.zeros(batch_shape, dtype=th.float32, device=device), hidden_states=ot.tree_map( - lambda x: th.zeros( - (*batch_shape[:-1], x.shape[0], batch_shape[-1], x.shape[1:]), dtype=x.dtype, device=device - ), + lambda x: th.zeros(self._reshape_hidden_state_shape(batch_shape, x.shape), dtype=x.dtype, device=device), hidden_state_example, + namespace=NS, ), - episode_starts=th.zeros(batch_shape, dtype=th.float32, device=device), - rewards=th.zeros(batch_shape, dtype=th.float32, device=device), ) + @staticmethod + def _reshape_hidden_state_shape(batch_shape: Tuple[int, ...], state_shape: Tuple[int, ...]) -> Tuple[int, ...]: + if len(state_shape) < 2: + raise NotImplementedError("State shape must be 2+ dimensions currently") + return (*batch_shape[:-1], state_shape[0], batch_shape[-1], *state_shape[1:]) + # Expose attributes of the RecurrentRolloutBufferData in the top-level to conform to the RolloutBuffer interface @property def episode_starts(self) -> th.Tensor: @@ -95,19 +132,9 @@ def rewards(self) -> th.Tensor: assert self.data.rewards is not None, "RecurrentRolloutBufferData should store rewards" return self.data.rewards - @property - def advantages(self) -> th.Tensor: - return self.data.advantages - - @property - def returns(self) -> th.Tensor: - return self.data.returns - - @returns.setter - def _set_returns(self, new_returns: 
th.Tensor): - self.data.returns.copy_(new_returns, non_blocking=True) - def reset(self): + self.returns.zero_() + self.advantages.zero_() ot.tree_map(lambda x: x.zero_(), self.data, namespace=NS) super(RolloutBuffer, self).reset() @@ -127,20 +154,20 @@ def _is_list(t): for i in range(len_tensors): self.add(*index_into_pytree(i, args, is_leaf=_is_list, namespace=NS)) - def add(self, data: RecurrentRolloutBufferSamples, **kwargs) -> None: + def add(self, data: RecurrentRolloutBufferData, **kwargs) -> None: """ :param hidden_states: LSTM cell and hidden state """ if data.rewards is None: raise ValueError("Recorded samples must contain a reward") new_data = dataclasses.replace(data, actions=data.actions.reshape((self.n_envs, self.action_dim))) - ot.tree_map(lambda buf, x: buf[self.pos].copy_(x, non_blocking=True), self.data, new_data) + ot.tree_map(lambda buf, x: buf[self.pos].copy_(x, non_blocking=True), self.data, new_data, namespace=NS) # Increment pos self.pos += 1 if self.pos == self.buffer_size: self.full = True - def get(self, batch_size: Optional[int] = None) -> Generator[RecurrentRolloutBufferSamples, None, None]: + def get(self, batch_size: Optional[int] = None) -> Generator[RecurrentRolloutBufferData, None, None]: assert self.full, "Rollout buffer must be full before sampling from it" # Return everything, don't create minibatches @@ -156,5 +183,14 @@ def _get_samples( batch_inds: Union[slice, th.Tensor], env: Optional[VecNormalize] = None, ) -> RecurrentRolloutBufferSamples: - data_without_reward: RecurrentRolloutBufferSamples = dataclasses.replace(self.data, rewards=None) - return ot.tree_map(lambda tens: self.to_device(tens[batch_inds]), data_without_reward) + samples = RecurrentRolloutBufferSamples( + observations=self.data.observations, + actions=self.data.actions, + episode_starts=self.data.episode_starts, + old_values=self.data.values, + old_log_probs=self.data.log_probs, + advantages=self.advantages, + returns=self.returns, + hidden_states=self.data.hidden_states, + ) + return ot.tree_map(lambda tens: self.to_device(tens[batch_inds]), samples, namespace=NS) diff --git a/stable_baselines3/common/recurrent/type_aliases.py b/stable_baselines3/common/recurrent/type_aliases.py index a181de398..23719762c 100644 --- a/stable_baselines3/common/recurrent/type_aliases.py +++ b/stable_baselines3/common/recurrent/type_aliases.py @@ -12,38 +12,24 @@ PyTreeGeneric = TypeVar("PyTreeGeneric", bound=PyTree) -def space_to_example( - batch_shape: Tuple[int, ...], - space: spaces.Space, - *, - device: Optional[th.device] = None, - ensure_non_batch_dim: bool = False, -) -> PyTree[th.Tensor]: - if isinstance(space, spaces.Dict): - return {k: space_to_example(v, device=device, ensure_non_batch_dim=ensure_non_batch_dim) for k, v in space.items()} - if isinstance(space, spaces.Tuple): - return tuple(space_to_example(v, device=device, ensure_non_batch_dim=ensure_non_batch_dim) for v in space) - - if isinstance(space, spaces.Box): - space_shape = space.shape - elif isinstance(space, spaces.Discrete): - space_shape = () - else: - raise TypeError(f"Unknown space type {type(space)} for {space}") - - if ensure_non_batch_dim and space_shape: - space_shape = (1,) - return th.zeros((*batch_shape, *space_shape), dtype=th.float32, device=device) - - -@functools.partial(dataclass_frozen_pytree, frozen=False) +@dataclass_frozen_pytree +class RecurrentRolloutBufferData: + observations: PyTree[th.Tensor] + actions: th.Tensor + rewards: th.Tensor + episode_starts: th.Tensor + values: th.Tensor + log_probs: th.Tensor 
+ hidden_states: HiddenState + + +@dataclass_frozen_pytree class RecurrentRolloutBufferSamples: observations: PyTree[th.Tensor] actions: th.Tensor + episode_starts: th.Tensor old_values: th.Tensor - old_log_prob: th.Tensor + old_log_probs: th.Tensor + hidden_states: HiddenState advantages: th.Tensor returns: th.Tensor - hidden_states: HiddenState - episode_starts: th.Tensor - rewards: Optional[th.Tensor] = None diff --git a/tests/test_buffers.py b/tests/test_buffers.py index b84f00e6b..06ce325d4 100644 --- a/tests/test_buffers.py +++ b/tests/test_buffers.py @@ -1,13 +1,30 @@ import gymnasium as gym import numpy as np +import optree as ot import pytest import torch as th from gymnasium import spaces -from stable_baselines3.common.buffers import DictReplayBuffer, DictRolloutBuffer, ReplayBuffer, RolloutBuffer +from stable_baselines3.common.buffers import ( + DictReplayBuffer, + DictRolloutBuffer, + ReplayBuffer, + RolloutBuffer, +) from stable_baselines3.common.env_checker import check_env from stable_baselines3.common.env_util import make_vec_env -from stable_baselines3.common.type_aliases import DictReplayBufferSamples, ReplayBufferSamples +from stable_baselines3.common.pytree_dataclass import OT_NAMESPACE +from stable_baselines3.common.recurrent.buffers import ( + RecurrentRolloutBuffer, + RecurrentRolloutBufferData, +) +from stable_baselines3.common.recurrent.type_aliases import ( + RecurrentRolloutBufferSamples, +) +from stable_baselines3.common.type_aliases import ( + DictReplayBufferSamples, + ReplayBufferSamples, +) from stable_baselines3.common.utils import get_device from stable_baselines3.common.vec_env import VecNormalize @@ -108,7 +125,9 @@ def test_replay_buffer_normalization(replay_buffer_cls): assert np.allclose(sample.rewards.mean(0), np.zeros(1), atol=1) -@pytest.mark.parametrize("replay_buffer_cls", [DictReplayBuffer, DictRolloutBuffer, ReplayBuffer, RolloutBuffer]) +@pytest.mark.parametrize( + "replay_buffer_cls", [DictReplayBuffer, DictRolloutBuffer, ReplayBuffer, RolloutBuffer, RecurrentRolloutBuffer] +) @pytest.mark.parametrize("device", ["cpu", "cuda", "auto"]) def test_device_buffer(replay_buffer_cls, device): if device == "cuda" and not th.cuda.is_available(): @@ -119,10 +138,17 @@ def test_device_buffer(replay_buffer_cls, device): DictRolloutBuffer: DummyDictEnv, ReplayBuffer: DummyEnv, DictReplayBuffer: DummyDictEnv, + RecurrentRolloutBuffer: DummyDictEnv, }[replay_buffer_cls] env = make_vec_env(env) - buffer = replay_buffer_cls(100, env.observation_space, env.action_space, device=device) + if replay_buffer_cls == RecurrentRolloutBuffer: + hidden_states = {"a": {"b": th.zeros(2, 4)}} + buffer = RecurrentRolloutBuffer( + 100, env.observation_space, env.action_space, hidden_state_example=hidden_states, device=device + ) + else: + buffer = replay_buffer_cls(100, env.observation_space, env.action_space, device=device) # Interract and store transitions obs = env.reset() @@ -133,21 +159,25 @@ def test_device_buffer(replay_buffer_cls, device): if replay_buffer_cls in [RolloutBuffer, DictRolloutBuffer]: episode_start, values, log_prob = th.zeros(1), th.zeros(1), th.ones(1) buffer.add(obs, action, reward, episode_start, values, log_prob) + elif replay_buffer_cls == RecurrentRolloutBuffer: + episode_start, values, log_prob = th.zeros(1), th.zeros(1), th.ones(1) + hidden_states = {"a": {"b": th.zeros(2, buffer.n_envs, 4)}} + buffer.add(RecurrentRolloutBufferData(obs, action, reward, episode_start, values, log_prob, hidden_states)) else: buffer.add(obs, next_obs, action, 
reward, done, info) obs = next_obs # Get data from the buffer - if replay_buffer_cls in [RolloutBuffer, DictRolloutBuffer]: + if replay_buffer_cls in [RolloutBuffer, DictRolloutBuffer, RecurrentRolloutBuffer]: data = buffer.get(50) elif replay_buffer_cls in [ReplayBuffer, DictReplayBuffer]: - data = buffer.sample(50) + data = [buffer.sample(50)] # Check that all data are on the desired device desired_device = get_device(device).type - for value in list(data): - if isinstance(value, dict): - for key in value.keys(): - assert value[key].device.type == desired_device - elif isinstance(value, th.Tensor): + for minibatch in list(data): + flattened_tensors, _ = ot.tree_flatten(minibatch, namespace=OT_NAMESPACE) + assert len(flattened_tensors) > 3 + for value in flattened_tensors: + assert isinstance(value, th.Tensor) assert value.device.type == desired_device From a190a7ac4c6756561e3f9471f60914f60390dfe3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adri=C3=A0=20Garriga-Alonso?= Date: Mon, 18 Sep 2023 17:16:38 -0700 Subject: [PATCH 06/31] Fixed typing in `ppo_recurrent` --- .../ppo_recurrent/ppo_recurrent.py | 33 ++++++++++++------- 1 file changed, 22 insertions(+), 11 deletions(-) diff --git a/stable_baselines3/ppo_recurrent/ppo_recurrent.py b/stable_baselines3/ppo_recurrent/ppo_recurrent.py index acd44c9c1..42837dcff 100644 --- a/stable_baselines3/ppo_recurrent/ppo_recurrent.py +++ b/stable_baselines3/ppo_recurrent/ppo_recurrent.py @@ -1,25 +1,36 @@ import sys import time from copy import deepcopy -from typing import Any, ClassVar, Dict, Optional, Type, TypeVar, Union +from typing import Any, ClassVar, Dict, Optional, Type, Union import numpy as np import torch as th from gymnasium import spaces +from typing_extensions import Self + from stable_baselines3.common.buffers import RolloutBuffer from stable_baselines3.common.callbacks import BaseCallback from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm from stable_baselines3.common.policies import BasePolicy +from stable_baselines3.common.recurrent.buffers import ( + RecurrentDictRolloutBuffer, + RecurrentRolloutBuffer, +) +from stable_baselines3.common.recurrent.policies import RecurrentActorCriticPolicy +from stable_baselines3.common.recurrent.type_aliases import RNNStates from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule -from stable_baselines3.common.utils import explained_variance, get_schedule_fn, obs_as_tensor, safe_mean +from stable_baselines3.common.utils import ( + explained_variance, + get_schedule_fn, + obs_as_tensor, + safe_mean, +) from stable_baselines3.common.vec_env import VecEnv - -from sb3_contrib.common.recurrent.buffers import RecurrentDictRolloutBuffer, RecurrentRolloutBuffer -from sb3_contrib.common.recurrent.policies import RecurrentActorCriticPolicy -from sb3_contrib.common.recurrent.type_aliases import RNNStates -from sb3_contrib.ppo_recurrent.policies import CnnLstmPolicy, MlpLstmPolicy, MultiInputLstmPolicy - -SelfRecurrentPPO = TypeVar("SelfRecurrentPPO", bound="RecurrentPPO") +from stable_baselines3.ppo_recurrent.policies import ( + CnnLstmPolicy, + MlpLstmPolicy, + MultiInputLstmPolicy, +) class RecurrentPPO(OnPolicyAlgorithm): @@ -445,14 +456,14 @@ def train(self) -> None: self.logger.record("train/clip_range_vf", clip_range_vf) def learn( - self: SelfRecurrentPPO, + self: Self["RecurrentPPO"], total_timesteps: int, callback: MaybeCallback = None, log_interval: int = 1, tb_log_name: str = "RecurrentPPO", reset_num_timesteps: bool = True, progress_bar: bool = False, 
- ) -> SelfRecurrentPPO: + ) -> Self["RecurrentPPO"]: iteration = 0 total_timesteps, callback = self._setup_learn( From 71f7904156c974a2e7de2126aaf4bed50673787b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adri=C3=A0=20Garriga-Alonso?= Date: Tue, 19 Sep 2023 11:28:23 -0700 Subject: [PATCH 07/31] Basic tests pass but still not properly recurrent --- stable_baselines3/common/buffers.py | 7 +- stable_baselines3/common/recurrent/buffers.py | 23 +++- .../common/recurrent/policies.py | 76 +++++++---- .../common/recurrent/type_aliases.py | 2 +- stable_baselines3/ppo_recurrent/__init__.py | 8 +- stable_baselines3/ppo_recurrent/policies.py | 2 +- .../ppo_recurrent/ppo_recurrent.py | 121 +++++++++++------- 7 files changed, 152 insertions(+), 87 deletions(-) diff --git a/stable_baselines3/common/buffers.py b/stable_baselines3/common/buffers.py index 589e707d4..66f343303 100644 --- a/stable_baselines3/common/buffers.py +++ b/stable_baselines3/common/buffers.py @@ -5,7 +5,6 @@ import numpy as np import torch as th from gymnasium import spaces - from stable_baselines3.common.preprocessing import get_action_dim, get_obs_shape from stable_baselines3.common.type_aliases import ( DictReplayBufferSamples, @@ -422,7 +421,7 @@ def reset(self) -> None: self.actions = th.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=th.float32, device=self.device) self.rewards = th.zeros((self.buffer_size, self.n_envs), dtype=th.float32, device=self.device) self.returns = th.zeros((self.buffer_size, self.n_envs), dtype=th.float32, device=self.device) - self.episode_starts = th.zeros((self.buffer_size, self.n_envs), dtype=th.float32, device=self.device) + self.episode_starts = th.zeros((self.buffer_size, self.n_envs), dtype=th.bool, device=self.device) self.values = th.zeros((self.buffer_size, self.n_envs), dtype=th.float32, device=self.device) self.log_probs = th.zeros((self.buffer_size, self.n_envs), dtype=th.float32, device=self.device) self.advantages = th.zeros((self.buffer_size, self.n_envs), dtype=th.float32, device=self.device) @@ -457,7 +456,7 @@ def compute_returns_and_advantage(self, last_values: th.Tensor, dones: th.Tensor next_non_terminal = ~dones next_values = last_values else: - next_non_terminal = 1.0 - self.episode_starts[step + 1] + next_non_terminal = ~self.episode_starts[step + 1] next_values = self.values[step + 1] delta = self.rewards[step] + self.gamma * next_values * next_non_terminal - self.values[step] last_gae_lam = delta + self.gamma * self.gae_lambda * next_non_terminal * last_gae_lam @@ -791,7 +790,7 @@ def __init__( self.actions = th.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=th.float32) self.rewards = th.zeros((self.buffer_size, self.n_envs), dtype=th.float32, device=self.device) self.returns = th.zeros((self.buffer_size, self.n_envs), dtype=th.float32, device=self.device) - self.episode_starts = th.zeros((self.buffer_size, self.n_envs), dtype=th.float32, device=self.device) + self.episode_starts = th.zeros((self.buffer_size, self.n_envs), dtype=th.bool, device=self.device) self.values = th.zeros((self.buffer_size, self.n_envs), dtype=th.float32, device=self.device) self.log_probs = th.zeros((self.buffer_size, self.n_envs), dtype=th.float32, device=self.device) self.advantages = th.zeros((self.buffer_size, self.n_envs), dtype=th.float32, device=self.device) diff --git a/stable_baselines3/common/recurrent/buffers.py b/stable_baselines3/common/recurrent/buffers.py index ccae73f57..15a1034c8 100644 --- a/stable_baselines3/common/recurrent/buffers.py +++ 
b/stable_baselines3/common/recurrent/buffers.py @@ -49,7 +49,7 @@ def space_to_example( else: raise TypeError(f"Unknown space type {type(space)} for {space}") - if ensure_non_batch_dim and space_shape: + if ensure_non_batch_dim and not space_shape: space_shape = (1,) return th.zeros((*batch_shape, *space_shape), dtype=th.float32, device=device) @@ -94,7 +94,7 @@ def __init__( device = self.device self.hidden_state_example = ot.tree_map( - lambda x: th.zeros((), dtype=x.dtype, device=device).expand_as(x), hidden_state_example + lambda x: th.zeros((), dtype=x.dtype, device=device).expand_as(x), hidden_state_example, namespace=NS ) self.advantages = th.zeros(batch_shape, dtype=th.float32, device=device) self.returns = th.zeros(batch_shape, dtype=th.float32, device=device) @@ -102,7 +102,7 @@ def __init__( observations=space_to_example(batch_shape, self.observation_space, device=device, ensure_non_batch_dim=True), actions=th.zeros((*batch_shape, self.action_dim), dtype=th.float32, device=device), rewards=th.zeros(batch_shape, dtype=th.float32, device=device), - episode_starts=th.zeros(batch_shape, dtype=th.float32, device=device), + episode_starts=th.zeros(batch_shape, dtype=th.bool, device=device), values=th.zeros(batch_shape, dtype=th.float32, device=device), log_probs=th.zeros(batch_shape, dtype=th.float32, device=device), hidden_states=ot.tree_map( @@ -125,7 +125,7 @@ def episode_starts(self) -> th.Tensor: @property def values(self) -> th.Tensor: - return self.data.old_values + return self.data.values @property def rewards(self) -> th.Tensor: @@ -175,8 +175,17 @@ def get(self, batch_size: Optional[int] = None) -> Generator[RecurrentRolloutBuf yield self._get_samples(slice(None)) return - for start_idx in range(0, self.buffer_size * self.n_envs, batch_size): - yield self._get_samples(slice(start_idx, start_idx + batch_size, None)) + if batch_size % self.n_envs != 0 or batch_size < self.n_envs: + raise ValueError( + f"The batch size must be a multiple of the number of environments (n_envs={self.n_envs}) ", + f"but batch_size={batch_size}.", + ) + batch_size //= self.n_envs + + for start_idx in range(0, self.buffer_size, batch_size): + out = self._get_samples(slice(start_idx, start_idx + batch_size, None)) + assert len(out.observations) != 0 + yield out def _get_samples( self, @@ -188,7 +197,7 @@ def _get_samples( actions=self.data.actions, episode_starts=self.data.episode_starts, old_values=self.data.values, - old_log_probs=self.data.log_probs, + old_log_prob=self.data.log_probs, advantages=self.advantages, returns=self.returns, hidden_states=self.data.hidden_states, diff --git a/stable_baselines3/common/recurrent/policies.py b/stable_baselines3/common/recurrent/policies.py index 5e5090b2e..1296ebe69 100644 --- a/stable_baselines3/common/recurrent/policies.py +++ b/stable_baselines3/common/recurrent/policies.py @@ -1,10 +1,16 @@ +import math from typing import Any, Dict, List, Optional, Tuple, Type, Union import numpy as np +import optree as ot import torch as th from gymnasium import spaces +from torch import nn + from stable_baselines3.common.distributions import Distribution from stable_baselines3.common.policies import ActorCriticPolicy +from stable_baselines3.common.pytree_dataclass import dataclass_frozen_pytree +from stable_baselines3.common.recurrent.buffers import space_to_example from stable_baselines3.common.torch_layers import ( BaseFeaturesExtractor, CombinedExtractor, @@ -12,11 +18,15 @@ MlpExtractor, NatureCNN, ) -from stable_baselines3.common.type_aliases import Schedule 
+from stable_baselines3.common.type_aliases import Schedule, TorchGymObs from stable_baselines3.common.utils import zip_strict -from torch import nn +from tests.test_buffers import OT_NAMESPACE as NS -from sb3_contrib.common.recurrent.type_aliases import RNNStates + +@dataclass_frozen_pytree +class LSTMStates: + pi: th.Tensor + vf: th.Tensor class RecurrentActorCriticPolicy(ActorCriticPolicy): @@ -159,6 +169,8 @@ def _build_mlp_extractor(self) -> None: device=self.device, ) + self.observation_example = space_to_example((), self.observation_space) + @staticmethod def _process_sequence( features: th.Tensor, @@ -200,8 +212,8 @@ def _process_sequence( features.unsqueeze(dim=0), ( # Reset the states at the beginning of a new episode - (1.0 - episode_start).view(1, n_seq, 1) * lstm_states[0], - (1.0 - episode_start).view(1, n_seq, 1) * lstm_states[1], + (~episode_start).view(1, n_seq, 1) * lstm_states[0], + (~episode_start).view(1, n_seq, 1) * lstm_states[1], ), ) lstm_output += [hidden] @@ -213,10 +225,10 @@ def _process_sequence( def forward( self, obs: th.Tensor, - lstm_states: RNNStates, + lstm_states: LSTMStates, episode_starts: th.Tensor, deterministic: bool = False, - ) -> Tuple[th.Tensor, th.Tensor, th.Tensor, RNNStates]: + ) -> Tuple[th.Tensor, th.Tensor, th.Tensor, LSTMStates]: """ Forward pass in all the networks (actor and critic) @@ -254,7 +266,7 @@ def forward( distribution = self._get_action_dist_from_latent(latent_pi) actions = distribution.get_actions(deterministic=deterministic) log_prob = distribution.log_prob(actions) - return actions, values, log_prob, RNNStates(lstm_states_pi, lstm_states_vf) + return actions, values, log_prob, LSTMStates(lstm_states_pi, lstm_states_vf) def get_distribution( self, @@ -307,8 +319,20 @@ def predict_values( latent_vf = self.mlp_extractor.forward_critic(latent_vf) return self.value_net(latent_vf) + def extract_features(self, obs: TorchGymObs) -> Union[th.Tensor, Tuple[th.Tensor, th.Tensor]]: + obs_flat = ot.tree_map(lambda x, x_nobatch: x.view(-1, *x_nobatch.shape), obs, self.observation_example, namespace=NS) + obs_batch_shapes = ot.tree_map( + lambda x, x_nobatch: x.shape[: x.ndim - x_nobatch.ndim], obs, self.observation_example, namespace=NS + ) + + (obs_batch_shape, *_), _ = ot.tree_flatten(obs_batch_shapes, namespace=NS) + + features_flat = super().extract_features(obs_flat) + features = ot.tree_map(lambda x: x.view(*obs_batch_shape, *x.shape[1:]), features_flat, namespace=NS) + return features + def evaluate_actions( - self, obs: th.Tensor, actions: th.Tensor, lstm_states: RNNStates, episode_starts: th.Tensor + self, obs: th.Tensor, actions: th.Tensor, lstm_states: LSTMStates, episode_starts: th.Tensor ) -> Tuple[th.Tensor, th.Tensor, th.Tensor]: """ Evaluate actions according to the current policy, @@ -336,13 +360,18 @@ def evaluate_actions( else: latent_vf = self.critic(vf_features) + features_batch_shape = pi_features.shape[:-1] latent_pi = self.mlp_extractor.forward_actor(latent_pi) latent_vf = self.mlp_extractor.forward_critic(latent_vf) distribution = self._get_action_dist_from_latent(latent_pi) log_prob = distribution.log_prob(actions) values = self.value_net(latent_vf) - return values, log_prob, distribution.entropy() + return ( + values.view(features_batch_shape), + log_prob.view(features_batch_shape), + distribution.entropy().view(features_batch_shape), + ) def _predict( self, @@ -366,11 +395,11 @@ def _predict( def predict( self, - observation: Union[np.ndarray, Dict[str, np.ndarray]], - state: Optional[Tuple[np.ndarray, ...]] 
= None, - episode_start: Optional[np.ndarray] = None, + observation: Union[th.Tensor, Dict[str, th.Tensor]], + state: Optional[Tuple[th.Tensor, ...]] = None, + episode_start: Optional[th.Tensor] = None, deterministic: bool = False, - ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]: + ) -> Tuple[th.Tensor, Optional[Tuple[th.Tensor, ...]]]: """ Get the policy action from an observation (and optional hidden state). Includes sugar-coating to handle different observations (e.g. normalizing images). @@ -395,25 +424,16 @@ def predict( # state : (n_layers, n_envs, dim) if state is None: # Initialize hidden states to zeros - state = np.concatenate([np.zeros(self.lstm_hidden_state_shape) for _ in range(n_envs)], axis=1) - state = (state, state) + state_component = th.cat([th.zeros(self.lstm_hidden_state_shape)] * n_envs, axis=1) + state = (state_component, state_component) if episode_start is None: - episode_start = np.array([False for _ in range(n_envs)]) + episode_start = th.zeros((n_envs,), dtype=th.bool) with th.no_grad(): - # Convert to PyTorch tensors - states = th.tensor(state[0], dtype=th.float32, device=self.device), th.tensor( - state[1], dtype=th.float32, device=self.device - ) - episode_starts = th.tensor(episode_start, dtype=th.float32, device=self.device) actions, states = self._predict( - observation, lstm_states=states, episode_starts=episode_starts, deterministic=deterministic + observation, lstm_states=state, episode_starts=episode_start, deterministic=deterministic ) - states = (states[0].cpu().numpy(), states[1].cpu().numpy()) - - # Convert to numpy - actions = actions.cpu().numpy() if isinstance(self.action_space, spaces.Box): if self.squash_output: @@ -422,7 +442,7 @@ def predict( else: # Actions could be on arbitrary scale, so clip the actions to avoid # out of bound error (e.g. 
if sampling from a Gaussian distribution) - actions = np.clip(actions, self.action_space.low, self.action_space.high) + actions = th.clip(actions, th.as_tensor(self.action_space.low), th.as_tensor(self.action_space.high)) # Remove batch dimension if needed if not vectorized_env: diff --git a/stable_baselines3/common/recurrent/type_aliases.py b/stable_baselines3/common/recurrent/type_aliases.py index 23719762c..3707deffc 100644 --- a/stable_baselines3/common/recurrent/type_aliases.py +++ b/stable_baselines3/common/recurrent/type_aliases.py @@ -29,7 +29,7 @@ class RecurrentRolloutBufferSamples: actions: th.Tensor episode_starts: th.Tensor old_values: th.Tensor - old_log_probs: th.Tensor + old_log_prob: th.Tensor hidden_states: HiddenState advantages: th.Tensor returns: th.Tensor diff --git a/stable_baselines3/ppo_recurrent/__init__.py b/stable_baselines3/ppo_recurrent/__init__.py index f8301048b..c1d93f43b 100644 --- a/stable_baselines3/ppo_recurrent/__init__.py +++ b/stable_baselines3/ppo_recurrent/__init__.py @@ -1,4 +1,8 @@ -from sb3_contrib.ppo_recurrent.policies import CnnLstmPolicy, MlpLstmPolicy, MultiInputLstmPolicy -from sb3_contrib.ppo_recurrent.ppo_recurrent import RecurrentPPO +from stable_baselines3.ppo_recurrent.policies import ( + CnnLstmPolicy, + MlpLstmPolicy, + MultiInputLstmPolicy, +) +from stable_baselines3.ppo_recurrent.ppo_recurrent import RecurrentPPO __all__ = ["CnnLstmPolicy", "MlpLstmPolicy", "MultiInputLstmPolicy", "RecurrentPPO"] diff --git a/stable_baselines3/ppo_recurrent/policies.py b/stable_baselines3/ppo_recurrent/policies.py index d9b374582..95ab3dd50 100644 --- a/stable_baselines3/ppo_recurrent/policies.py +++ b/stable_baselines3/ppo_recurrent/policies.py @@ -1,4 +1,4 @@ -from sb3_contrib.common.recurrent.policies import ( +from stable_baselines3.common.recurrent.policies import ( RecurrentActorCriticCnnPolicy, RecurrentActorCriticPolicy, RecurrentMultiInputActorCriticPolicy, diff --git a/stable_baselines3/ppo_recurrent/ppo_recurrent.py b/stable_baselines3/ppo_recurrent/ppo_recurrent.py index 42837dcff..4ee07119f 100644 --- a/stable_baselines3/ppo_recurrent/ppo_recurrent.py +++ b/stable_baselines3/ppo_recurrent/ppo_recurrent.py @@ -1,10 +1,12 @@ import sys import time +import warnings from copy import deepcopy from typing import Any, ClassVar, Dict, Optional, Type, Union import numpy as np import torch as th +import torch.nn.functional as F from gymnasium import spaces from typing_extensions import Self @@ -13,11 +15,14 @@ from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm from stable_baselines3.common.policies import BasePolicy from stable_baselines3.common.recurrent.buffers import ( - RecurrentDictRolloutBuffer, RecurrentRolloutBuffer, + index_into_pytree, ) -from stable_baselines3.common.recurrent.policies import RecurrentActorCriticPolicy -from stable_baselines3.common.recurrent.type_aliases import RNNStates +from stable_baselines3.common.recurrent.policies import ( + LSTMStates, + RecurrentActorCriticPolicy, +) +from stable_baselines3.common.recurrent.type_aliases import RecurrentRolloutBufferData from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule from stable_baselines3.common.utils import ( explained_variance, @@ -138,6 +143,31 @@ def __init__( ), ) + # Sanity check, otherwise it will lead to noisy gradient and NaN + # because of the advantage normalization + if normalize_advantage: + assert ( + batch_size > 1 + ), "`batch_size` must be greater than 1. 
See https://github.com/DLR-RM/stable-baselines3/issues/440" + + if self.env is not None: + # Check that `n_steps * n_envs > 1` to avoid NaN + # when doing advantage normalization + buffer_size = self.env.num_envs * self.n_steps + assert buffer_size > 1 or ( + not normalize_advantage + ), f"`n_steps * n_envs` must be greater than 1. Currently n_steps={self.n_steps} and n_envs={self.env.num_envs}" + # Check that the rollout buffer size is a multiple of the mini-batch size + untruncated_batches = buffer_size // batch_size + if buffer_size % batch_size > 0: + warnings.warn( + f"You have specified a mini-batch size of {batch_size}," + f" but because the `RolloutBuffer` is of size `n_steps * n_envs = {buffer_size}`," + f" after every {untruncated_batches} untruncated mini-batches," + f" there will be a truncated mini-batch of size {buffer_size % batch_size}\n" + f"We recommend using a `batch_size` that is a factor of `n_steps * n_envs`.\n" + f"Info: (n_steps={self.n_steps} and n_envs={self.env.num_envs})" + ) self.batch_size = batch_size self.n_epochs = n_epochs self.clip_range = clip_range @@ -153,8 +183,6 @@ def _setup_model(self) -> None: self._setup_lr_schedule() self.set_random_seed(self.seed) - buffer_cls = RecurrentDictRolloutBuffer if isinstance(self.observation_space, spaces.Dict) else RecurrentRolloutBuffer - self.policy = self.policy_class( self.observation_space, self.action_space, @@ -173,7 +201,7 @@ def _setup_model(self) -> None: single_hidden_state_shape = (lstm.num_layers, self.n_envs, lstm.hidden_size) # hidden and cell states for actor and critic - self._last_lstm_states = RNNStates( + self._last_lstm_states = LSTMStates( ( th.zeros(single_hidden_state_shape, device=self.device), th.zeros(single_hidden_state_shape, device=self.device), @@ -184,13 +212,23 @@ def _setup_model(self) -> None: ), ) - hidden_state_buffer_shape = (self.n_steps, lstm.num_layers, self.n_envs, lstm.hidden_size) + single_1envhidden_state_shape = (lstm.num_layers, lstm.hidden_size) + example_lstm_states = LSTMStates( + ( + th.zeros(single_1envhidden_state_shape, device=self.device), + th.zeros(single_1envhidden_state_shape, device=self.device), + ), + ( + th.zeros(single_1envhidden_state_shape, device=self.device), + th.zeros(single_1envhidden_state_shape, device=self.device), + ), + ) - self.rollout_buffer = buffer_cls( + self.rollout_buffer = RecurrentRolloutBuffer( self.n_steps, self.observation_space, self.action_space, - hidden_state_buffer_shape, + example_lstm_states, self.device, gamma=self.gamma, gae_lambda=self.gae_lambda, @@ -225,9 +263,7 @@ def collect_rollouts( :return: True if function returned with at least `n_rollout_steps` collected, False if callback terminated rollout prematurely. 
""" - assert isinstance( - rollout_buffer, (RecurrentRolloutBuffer, RecurrentDictRolloutBuffer) - ), f"{rollout_buffer} doesn't support recurrent policy" + assert isinstance(rollout_buffer, RecurrentRolloutBuffer), f"{rollout_buffer} doesn't support recurrent policy" assert self._last_obs is not None, "No previous observation was provided" # Switch to eval mode (this affects batch norm / dropout) @@ -251,11 +287,9 @@ def collect_rollouts( with th.no_grad(): # Convert to pytorch tensor or to TensorDict obs_tensor = obs_as_tensor(self._last_obs, self.device) - episode_starts = th.tensor(self._last_episode_starts, dtype=th.float32, device=self.device) + episode_starts = th.as_tensor(self._last_episode_starts).to(dtype=th.bool, device=self.device) actions, values, log_probs, lstm_states = self.policy.forward(obs_tensor, lstm_states, episode_starts) - actions = actions.cpu().numpy() - # Rescale and perform action clipped_actions = actions # Clip the actions to avoid out of bound error @@ -293,18 +327,22 @@ def collect_rollouts( lstm_states.vf[1][:, idx : idx + 1, :].contiguous(), ) # terminal_lstm_state = None - episode_starts = th.tensor([False], dtype=th.float32, device=self.device) - terminal_value = self.policy.predict_values(terminal_obs, terminal_lstm_state, episode_starts)[0] + episode_starts = th.zeros((1,), dtype=th.bool, device=self.device) + terminal_value = self.policy.predict_values(terminal_obs, terminal_lstm_state, episode_starts)[ + 0 + ].squeeze(-1) rewards[idx] += self.gamma * terminal_value rollout_buffer.add( - self._last_obs, - actions, - rewards, - self._last_episode_starts, - values, - log_probs, - lstm_states=self._last_lstm_states, + RecurrentRolloutBufferData( + self._last_obs, + actions, + rewards, + self._last_episode_starts, + values.squeeze(-1), + log_probs, + hidden_states=self._last_lstm_states, + ) ) self._last_obs = new_obs @@ -313,7 +351,7 @@ def collect_rollouts( with th.no_grad(): # Compute value for the last timestep - episode_starts = th.tensor(dones, dtype=th.float32, device=self.device) + episode_starts = th.as_tensor(dones).to(dtype=th.bool, device=self.device) values = self.policy.predict_values(obs_as_tensor(new_obs, self.device), lstm_states.vf, episode_starts) rollout_buffer.compute_returns_and_advantage(last_values=values, dones=dones) @@ -352,9 +390,6 @@ def train(self) -> None: # Convert discrete action from float to long actions = rollout_data.actions.long().flatten() - # Convert mask from float to bool - mask = rollout_data.mask > 1e-8 - # Re-sample the noise matrix because the log_std has changed if self.use_sde: self.policy.reset_noise(self.batch_size) @@ -362,15 +397,15 @@ def train(self) -> None: values, log_prob, entropy = self.policy.evaluate_actions( rollout_data.observations, actions, - rollout_data.lstm_states, + index_into_pytree(0, rollout_data.hidden_states), rollout_data.episode_starts, ) - values = values.flatten() # Normalize advantage advantages = rollout_data.advantages - if self.normalize_advantage: - advantages = (advantages - advantages[mask].mean()) / (advantages[mask].std() + 1e-8) + # Normalization does not make sense if mini batchsize == 1, see GH issue #325 + if self.normalize_advantage and len(advantages) > 1: + advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8) # ratio between old and new policy, should be one at the first iteration ratio = th.exp(log_prob - rollout_data.old_log_prob) @@ -378,34 +413,32 @@ def train(self) -> None: # clipped surrogate loss policy_loss_1 = advantages * ratio 
policy_loss_2 = advantages * th.clamp(ratio, 1 - clip_range, 1 + clip_range) - policy_loss = -th.mean(th.min(policy_loss_1, policy_loss_2)[mask]) + policy_loss = -th.min(policy_loss_1, policy_loss_2).mean() # Logging pg_losses.append(policy_loss.item()) - clip_fraction = th.mean((th.abs(ratio - 1) > clip_range).float()[mask]).item() + clip_fraction = th.mean((th.abs(ratio - 1) > clip_range).float()).item() clip_fractions.append(clip_fraction) if self.clip_range_vf is None: # No clipping values_pred = values else: - # Clip the different between old and new value + # Clip the difference between old and new value # NOTE: this depends on the reward scaling values_pred = rollout_data.old_values + th.clamp( values - rollout_data.old_values, -clip_range_vf, clip_range_vf ) # Value loss using the TD(gae_lambda) target - # Mask padded sequences - value_loss = th.mean(((rollout_data.returns - values_pred) ** 2)[mask]) - + value_loss = F.mse_loss(rollout_data.returns, values_pred) value_losses.append(value_loss.item()) # Entropy loss favor exploration if entropy is None: # Approximate entropy when no analytical form - entropy_loss = -th.mean(-log_prob[mask]) + entropy_loss = -th.mean(-log_prob) else: - entropy_loss = -th.mean(entropy[mask]) + entropy_loss = -th.mean(entropy) entropy_losses.append(entropy_loss.item()) @@ -417,7 +450,7 @@ def train(self) -> None: # and Schulman blog: http://joschu.net/blog/kl-approx.html with th.no_grad(): log_ratio = log_prob - rollout_data.old_log_prob - approx_kl_div = th.mean(((th.exp(log_ratio) - 1) - log_ratio)[mask]).cpu().numpy() + approx_kl_div = th.mean((th.exp(log_ratio) - 1) - log_ratio).cpu().numpy() approx_kl_divs.append(approx_kl_div) if self.target_kl is not None and approx_kl_div > 1.5 * self.target_kl: @@ -446,7 +479,7 @@ def train(self) -> None: self.logger.record("train/approx_kl", np.mean(approx_kl_divs)) self.logger.record("train/clip_fraction", np.mean(clip_fractions)) self.logger.record("train/loss", loss.item()) - self.logger.record("train/explained_variance", explained_var) + self.logger.record("train/explained_variance", explained_var.item()) if hasattr(self.policy, "log_std"): self.logger.record("train/std", th.exp(self.policy.log_std).mean().item()) @@ -456,14 +489,14 @@ def train(self) -> None: self.logger.record("train/clip_range_vf", clip_range_vf) def learn( - self: Self["RecurrentPPO"], + self: Self, total_timesteps: int, callback: MaybeCallback = None, log_interval: int = 1, tb_log_name: str = "RecurrentPPO", reset_num_timesteps: bool = True, progress_bar: bool = False, - ) -> Self["RecurrentPPO"]: + ) -> Self: iteration = 0 total_timesteps, callback = self._setup_learn( From 554a74c1d8d8a05d6b7b61fcd83f863632a67009 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adri=C3=A0=20Garriga-Alonso?= Date: Tue, 19 Sep 2023 11:36:34 -0700 Subject: [PATCH 08/31] no more circular imports --- stable_baselines3/common/recurrent/policies.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stable_baselines3/common/recurrent/policies.py b/stable_baselines3/common/recurrent/policies.py index 1296ebe69..e9ddd0e0d 100644 --- a/stable_baselines3/common/recurrent/policies.py +++ b/stable_baselines3/common/recurrent/policies.py @@ -9,6 +9,7 @@ from stable_baselines3.common.distributions import Distribution from stable_baselines3.common.policies import ActorCriticPolicy +from stable_baselines3.common.pytree_dataclass import OT_NAMESPACE as NS from stable_baselines3.common.pytree_dataclass import dataclass_frozen_pytree from 
stable_baselines3.common.recurrent.buffers import space_to_example from stable_baselines3.common.torch_layers import ( @@ -20,7 +21,6 @@ ) from stable_baselines3.common.type_aliases import Schedule, TorchGymObs from stable_baselines3.common.utils import zip_strict -from tests.test_buffers import OT_NAMESPACE as NS @dataclass_frozen_pytree From cb1e271cc75d5e262fd81062a76a0c5e4abb3749 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adri=C3=A0=20Garriga-Alonso?= Date: Tue, 19 Sep 2023 11:58:45 -0700 Subject: [PATCH 09/31] Introduce MlpPolicy and stuff --- stable_baselines3/ppo_recurrent/ppo_recurrent.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/stable_baselines3/ppo_recurrent/ppo_recurrent.py b/stable_baselines3/ppo_recurrent/ppo_recurrent.py index 4ee07119f..4c9a6bd53 100644 --- a/stable_baselines3/ppo_recurrent/ppo_recurrent.py +++ b/stable_baselines3/ppo_recurrent/ppo_recurrent.py @@ -87,6 +87,9 @@ class RecurrentPPO(OnPolicyAlgorithm): "MlpLstmPolicy": MlpLstmPolicy, "CnnLstmPolicy": CnnLstmPolicy, "MultiInputLstmPolicy": MultiInputLstmPolicy, + "MlpPolicy": MlpLstmPolicy, + "CnnPolicy": CnnLstmPolicy, + "MultiInputPolicy": MultiInputLstmPolicy, } def __init__( From 55c1b62682f8d3c8e9fbb9ebf7fa8789ba20f32f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adri=C3=A0=20Garriga-Alonso?= Date: Tue, 19 Sep 2023 12:06:15 -0700 Subject: [PATCH 10/31] OnPolicyAlgorihtm in some places --- stable_baselines3/common/buffers.py | 1 + tests/test_cnn.py | 5 ++--- tests/test_dict_env.py | 11 ++++++----- tests/test_identity.py | 3 ++- tests/test_train_eval_mode.py | 1 + 5 files changed, 12 insertions(+), 9 deletions(-) diff --git a/stable_baselines3/common/buffers.py b/stable_baselines3/common/buffers.py index 66f343303..188717a89 100644 --- a/stable_baselines3/common/buffers.py +++ b/stable_baselines3/common/buffers.py @@ -5,6 +5,7 @@ import numpy as np import torch as th from gymnasium import spaces + from stable_baselines3.common.preprocessing import get_action_dim, get_obs_shape from stable_baselines3.common.type_aliases import ( DictReplayBufferSamples, diff --git a/tests/test_cnn.py b/tests/test_cnn.py index c7bb1a31e..76c747cd6 100644 --- a/tests/test_cnn.py +++ b/tests/test_cnn.py @@ -8,6 +8,7 @@ from stable_baselines3 import A2C, DQN, PPO, SAC, TD3, RecurrentPPO from stable_baselines3.common.envs import FakeImageEnv +from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm from stable_baselines3.common.preprocessing import ( is_image_space, is_image_space_channels_first, @@ -351,9 +352,7 @@ def test_image_like_input(model_class, normalize_images): ), seed=1, ) - policy = "CnnLstmPolicy" if model_class == RecurrentPPO else "CnnPolicy" - - if model_class in {A2C, PPO}: + if issubclass(model_class, OnPolicyAlgorithm): kwargs.update(dict(n_steps=64)) else: # Avoid memory error when using replay buffer diff --git a/tests/test_dict_env.py b/tests/test_dict_env.py index fd2045ac7..320f8e66f 100644 --- a/tests/test_dict_env.py +++ b/tests/test_dict_env.py @@ -10,6 +10,7 @@ from stable_baselines3.common.env_util import make_vec_env from stable_baselines3.common.envs import BitFlippingEnv, SimpleMultiObsEnv from stable_baselines3.common.evaluation import evaluate_policy +from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm from stable_baselines3.common.vec_env import ( DummyVecEnv, SubprocVecEnv, @@ -132,7 +133,7 @@ def test_consistency(model_class): kwargs = {} n_steps = 256 - if model_class in {A2C, PPO}: + if issubclass(model_class, OnPolicyAlgorithm): kwargs = dict( 
n_steps=128, ) @@ -176,7 +177,7 @@ def test_dict_spaces(model_class, channel_last): kwargs = {} n_steps = 256 - if model_class in {A2C, PPO}: + if issubclass(model_class, OnPolicyAlgorithm): kwargs = dict( n_steps=128, policy_kwargs=dict( @@ -220,7 +221,7 @@ def make_env(): kwargs = {} n_steps = 128 - if model_class in {A2C, PPO}: + if issubclass(model_class, OnPolicyAlgorithm): kwargs = dict( n_steps=128, policy_kwargs=dict( @@ -261,7 +262,7 @@ def test_dict_vec_framestack(model_class, channel_last): kwargs = {} n_steps = 256 - if model_class in {A2C, PPO}: + if issubclass(model_class, OnPolicyAlgorithm): kwargs = dict( n_steps=128, policy_kwargs=dict( @@ -303,7 +304,7 @@ def test_vec_normalize(model_class): kwargs = {} n_steps = 256 - if model_class in {A2C, PPO}: + if issubclass(model_class, OnPolicyAlgorithm): kwargs = dict( n_steps=128, policy_kwargs=dict( diff --git a/tests/test_identity.py b/tests/test_identity.py index a4bf1e327..aa65182e2 100644 --- a/tests/test_identity.py +++ b/tests/test_identity.py @@ -10,6 +10,7 @@ ) from stable_baselines3.common.evaluation import evaluate_policy from stable_baselines3.common.noise import NormalActionNoise +from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm from stable_baselines3.common.vec_env import DummyVecEnv DIM = 4 @@ -39,7 +40,7 @@ def test_discrete(model_class, env): def test_continuous(model_class): env = IdentityEnvBox(eps=0.5) - n_steps = {A2C: 2000, PPO: 2000, SAC: 400, TD3: 400, DDPG: 400}[model_class] + n_steps = 2000 if issubclass(model_class, OnPolicyAlgorithm) else 400 kwargs = dict(policy_kwargs=dict(net_arch=[64, 64]), seed=0, gamma=0.95) diff --git a/tests/test_train_eval_mode.py b/tests/test_train_eval_mode.py index c9c8c6dde..aade0bd12 100644 --- a/tests/test_train_eval_mode.py +++ b/tests/test_train_eval_mode.py @@ -127,6 +127,7 @@ def clone_on_policy_batch_norm(model: Union[A2C, PPO]) -> Tuple[th.Tensor, th.Te SAC: clone_sac_batch_norm_stats, TD3: clone_td3_batch_norm_stats, PPO: clone_on_policy_batch_norm, + RecurrentPPO: clone_on_policy_batch_norm, } From ef27de4c4d1f46a1931640d8227fde68658548f6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adri=C3=A0=20Garriga-Alonso?= Date: Tue, 19 Sep 2023 14:15:32 -0700 Subject: [PATCH 11/31] attempt more patches --- stable_baselines3/common/recurrent/buffers.py | 15 +++++++++++++-- stable_baselines3/common/recurrent/policies.py | 13 +++++++------ tests/test_save_load.py | 8 ++++++-- 3 files changed, 26 insertions(+), 10 deletions(-) diff --git a/stable_baselines3/common/recurrent/buffers.py b/stable_baselines3/common/recurrent/buffers.py index 15a1034c8..53cc58fc2 100644 --- a/stable_baselines3/common/recurrent/buffers.py +++ b/stable_baselines3/common/recurrent/buffers.py @@ -44,14 +44,19 @@ def space_to_example( if isinstance(space, spaces.Box): space_shape = space.shape + space_dtype = th.float32 elif isinstance(space, spaces.Discrete): space_shape = () + space_dtype = th.long + elif isinstance(space, spaces.MultiDiscrete): + space_shape = (len(space.nvec),) + space_dtype = th.long else: raise TypeError(f"Unknown space type {type(space)} for {space}") if ensure_non_batch_dim and not space_shape: space_shape = (1,) - return th.zeros((*batch_shape, *space_shape), dtype=th.float32, device=device) + return th.zeros((*batch_shape, *space_shape), dtype=space_dtype, device=device) class RecurrentRolloutBuffer(RolloutBuffer): @@ -161,7 +166,13 @@ def add(self, data: RecurrentRolloutBufferData, **kwargs) -> None: if data.rewards is None: raise ValueError("Recorded 
samples must contain a reward") new_data = dataclasses.replace(data, actions=data.actions.reshape((self.n_envs, self.action_dim))) - ot.tree_map(lambda buf, x: buf[self.pos].copy_(x, non_blocking=True), self.data, new_data, namespace=NS) + + ot.tree_map( + lambda buf, x: buf[self.pos].copy_(x if x.ndim + 1 == buf.ndim else x.unsqueeze(-1), non_blocking=True), + self.data, + new_data, + namespace=NS, + ) # Increment pos self.pos += 1 if self.pos == self.buffer_size: diff --git a/stable_baselines3/common/recurrent/policies.py b/stable_baselines3/common/recurrent/policies.py index e9ddd0e0d..a8eefe8e2 100644 --- a/stable_baselines3/common/recurrent/policies.py +++ b/stable_baselines3/common/recurrent/policies.py @@ -360,17 +360,18 @@ def evaluate_actions( else: latent_vf = self.critic(vf_features) - features_batch_shape = pi_features.shape[:-1] + features_batch_shape = pi_features.shape[:2] latent_pi = self.mlp_extractor.forward_actor(latent_pi) latent_vf = self.mlp_extractor.forward_critic(latent_vf) - distribution = self._get_action_dist_from_latent(latent_pi) - log_prob = distribution.log_prob(actions) values = self.value_net(latent_vf) + distribution = self._get_action_dist_from_latent(latent_pi) + log_prob = distribution.log_prob(actions.view((-1, *values.shape[2:]))) + entropy = distribution.entropy() return ( - values.view(features_batch_shape), - log_prob.view(features_batch_shape), - distribution.entropy().view(features_batch_shape), + values.view(*features_batch_shape, *values.shape[2:]), + log_prob.view(*features_batch_shape, *log_prob.shape[2:]), + entropy.view(*features_batch_shape, *entropy.shape[2:]), ) def _predict( diff --git a/tests/test_save_load.py b/tests/test_save_load.py index 8161ee44a..2b7954f4a 100644 --- a/tests/test_save_load.py +++ b/tests/test_save_load.py @@ -18,6 +18,8 @@ from stable_baselines3.common.base_class import BaseAlgorithm from stable_baselines3.common.env_util import make_vec_env from stable_baselines3.common.envs import FakeImageEnv, IdentityEnv, IdentityEnvBox +from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm +from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm from stable_baselines3.common.save_util import load_from_pkl, open_path, save_to_pkl from stable_baselines3.common.utils import get_device from stable_baselines3.common.vec_env import DummyVecEnv @@ -182,10 +184,12 @@ def test_set_env(tmp_path, model_class): env4 = DummyVecEnv([lambda: select_env(model_class) for _ in range(2)]) kwargs = {} - if model_class in {DQN, DDPG, SAC, TD3}: + if issubclass(model_class, OffPolicyAlgorithm): kwargs = dict(learning_starts=50, train_freq=4) - elif model_class in {A2C, PPO}: + elif issubclass(model_class, OnPolicyAlgorithm): kwargs = dict(n_steps=64) + else: + raise TypeError(f"Unknown model class: {model_class}") # create model model = model_class("MlpPolicy", env, policy_kwargs=dict(net_arch=[16]), **kwargs) From 5fbb1e4634ea9e5d1010d18ed147fc407516c90e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adri=C3=A0=20Garriga-Alonso?= Date: Tue, 19 Sep 2023 14:29:31 -0700 Subject: [PATCH 12/31] which tests pass now? 
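Aside: the buffer's `add` above leans on `optree.tree_map` to write one transition into every leaf of the stored pytree in place. Below is a self-contained toy version of that pattern, using plain builtin dicts (so no custom namespace is needed); the names and shapes are illustrative only, not taken from the diffs.

import optree as ot
import torch as th

# Toy storage laid out as (buffer_size, n_envs, *feature_shape).
storage = {"obs": th.zeros(4, 2, 3), "act": th.zeros(4, 2)}
step = {"obs": th.ones(2, 3), "act": th.ones(2)}  # one transition per env
pos = 1
# Mirrors buf[self.pos].copy_(x): copy each leaf of `step` into row `pos` of the matching leaf.
ot.tree_map(lambda buf, x: buf[pos].copy_(x), storage, step)
assert storage["obs"][pos].eq(1).all()
assert storage["obs"][0].eq(0).all()
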
--- stable_baselines3/common/recurrent/buffers.py | 4 ++-- stable_baselines3/common/recurrent/policies.py | 13 +++---------- stable_baselines3/common/recurrent/type_aliases.py | 3 +-- tests/test_buffers.py | 3 --- 4 files changed, 6 insertions(+), 17 deletions(-) diff --git a/stable_baselines3/common/recurrent/buffers.py b/stable_baselines3/common/recurrent/buffers.py index 53cc58fc2..a86d37fd2 100644 --- a/stable_baselines3/common/recurrent/buffers.py +++ b/stable_baselines3/common/recurrent/buffers.py @@ -24,7 +24,7 @@ def index_into_pytree( none_is_leaf: bool = False, namespace: str = NS, ) -> PyTreeGeneric: - return ot.tree_map(lambda x: x[idx], tree, is_leaf=is_leaf, none_is_leaf=none_is_leaf, namespace=namespace) + return ot.tree_map(lambda x: x[idx], tree, is_leaf=is_leaf, none_is_leaf=none_is_leaf, namespace=namespace) # type: ignore def space_to_example( @@ -178,7 +178,7 @@ def add(self, data: RecurrentRolloutBufferData, **kwargs) -> None: if self.pos == self.buffer_size: self.full = True - def get(self, batch_size: Optional[int] = None) -> Generator[RecurrentRolloutBufferData, None, None]: + def get(self, batch_size: Optional[int] = None) -> Generator[RecurrentRolloutBufferSamples, None, None]: assert self.full, "Rollout buffer must be full before sampling from it" # Return everything, don't create minibatches diff --git a/stable_baselines3/common/recurrent/policies.py b/stable_baselines3/common/recurrent/policies.py index a8eefe8e2..c9848adf9 100644 --- a/stable_baselines3/common/recurrent/policies.py +++ b/stable_baselines3/common/recurrent/policies.py @@ -1,7 +1,5 @@ -import math from typing import Any, Dict, List, Optional, Tuple, Type, Union -import numpy as np import optree as ot import torch as th from gymnasium import spaces @@ -360,19 +358,14 @@ def evaluate_actions( else: latent_vf = self.critic(vf_features) - features_batch_shape = pi_features.shape[:2] latent_pi = self.mlp_extractor.forward_actor(latent_pi) latent_vf = self.mlp_extractor.forward_critic(latent_vf) - values = self.value_net(latent_vf) distribution = self._get_action_dist_from_latent(latent_pi) - log_prob = distribution.log_prob(actions.view((-1, *values.shape[2:]))) + log_prob = distribution.log_prob(actions) + values = self.value_net(latent_vf) entropy = distribution.entropy() - return ( - values.view(*features_batch_shape, *values.shape[2:]), - log_prob.view(*features_batch_shape, *log_prob.shape[2:]), - entropy.view(*features_batch_shape, *entropy.shape[2:]), - ) + return (values, log_prob, entropy) def _predict( self, diff --git a/stable_baselines3/common/recurrent/type_aliases.py b/stable_baselines3/common/recurrent/type_aliases.py index 3707deffc..56be23b9d 100644 --- a/stable_baselines3/common/recurrent/type_aliases.py +++ b/stable_baselines3/common/recurrent/type_aliases.py @@ -1,7 +1,6 @@ -from typing import Optional, Tuple, TypeVar +from typing import TypeVar import torch as th -from gymnasium import spaces from optree import PyTree from stable_baselines3.common.pytree_dataclass import dataclass_frozen_pytree diff --git a/tests/test_buffers.py b/tests/test_buffers.py index 06ce325d4..dd6cad5f7 100644 --- a/tests/test_buffers.py +++ b/tests/test_buffers.py @@ -18,9 +18,6 @@ RecurrentRolloutBuffer, RecurrentRolloutBufferData, ) -from stable_baselines3.common.recurrent.type_aliases import ( - RecurrentRolloutBufferSamples, -) from stable_baselines3.common.type_aliases import ( DictReplayBufferSamples, ReplayBufferSamples, From 522096a340076c786c95b365d38c4c618198bac5 Mon Sep 17 00:00:00 2001 
From: =?UTF-8?q?Adri=C3=A0=20Garriga-Alonso?= Date: Tue, 19 Sep 2023 14:39:43 -0700 Subject: [PATCH 13/31] attempt to fix pytype --- stable_baselines3/common/recurrent/buffers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/stable_baselines3/common/recurrent/buffers.py b/stable_baselines3/common/recurrent/buffers.py index a86d37fd2..a957db718 100644 --- a/stable_baselines3/common/recurrent/buffers.py +++ b/stable_baselines3/common/recurrent/buffers.py @@ -178,7 +178,7 @@ def add(self, data: RecurrentRolloutBufferData, **kwargs) -> None: if self.pos == self.buffer_size: self.full = True - def get(self, batch_size: Optional[int] = None) -> Generator[RecurrentRolloutBufferSamples, None, None]: + def get(self, batch_size: Optional[int] = None) -> Generator[RecurrentRolloutBufferSamples, None, None]: # type: ignore[signature-mismatch] #FIXME assert self.full, "Rollout buffer must be full before sampling from it" # Return everything, don't create minibatches @@ -198,7 +198,7 @@ def get(self, batch_size: Optional[int] = None) -> Generator[RecurrentRolloutBuf assert len(out.observations) != 0 yield out - def _get_samples( + def _get_samples( # type: ignore[override] self, batch_inds: Union[slice, th.Tensor], env: Optional[VecNormalize] = None, From 8190d7884ae64d863efed89cfbf6e78d53d03bb0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adri=C3=A0=20Garriga-Alonso?= Date: Tue, 19 Sep 2023 14:51:25 -0700 Subject: [PATCH 14/31] Store actions as long directly, fewer shape modifications --- stable_baselines3/common/recurrent/buffers.py | 6 +++++- stable_baselines3/common/recurrent/policies.py | 6 ++++-- stable_baselines3/ppo_recurrent/ppo_recurrent.py | 3 --- tests/test_dict_env.py | 3 ++- 4 files changed, 11 insertions(+), 7 deletions(-) diff --git a/stable_baselines3/common/recurrent/buffers.py b/stable_baselines3/common/recurrent/buffers.py index a957db718..4d3c521a3 100644 --- a/stable_baselines3/common/recurrent/buffers.py +++ b/stable_baselines3/common/recurrent/buffers.py @@ -105,7 +105,11 @@ def __init__( self.returns = th.zeros(batch_shape, dtype=th.float32, device=device) self.data = RecurrentRolloutBufferData( observations=space_to_example(batch_shape, self.observation_space, device=device, ensure_non_batch_dim=True), - actions=th.zeros((*batch_shape, self.action_dim), dtype=th.float32, device=device), + actions=th.zeros( + (*batch_shape, self.action_dim), + dtype=th.long if isinstance(self.action_space, (spaces.Discrete, spaces.MultiDiscrete)) else th.float32, + device=device, + ), rewards=th.zeros(batch_shape, dtype=th.float32, device=device), episode_starts=th.zeros(batch_shape, dtype=th.bool, device=device), values=th.zeros(batch_shape, dtype=th.float32, device=device), diff --git a/stable_baselines3/common/recurrent/policies.py b/stable_baselines3/common/recurrent/policies.py index c9848adf9..a3f9fe9ab 100644 --- a/stable_baselines3/common/recurrent/policies.py +++ b/stable_baselines3/common/recurrent/policies.py @@ -361,11 +361,13 @@ def evaluate_actions( latent_pi = self.mlp_extractor.forward_actor(latent_pi) latent_vf = self.mlp_extractor.forward_critic(latent_vf) + action_batch_shape = actions.shape[:-1] distribution = self._get_action_dist_from_latent(latent_pi) - log_prob = distribution.log_prob(actions) + distribution_shape = distribution.distribution.batch_shape + distribution.distribution.event_shape + log_prob = distribution.log_prob(actions.view(distribution_shape)) values = self.value_net(latent_vf) entropy = distribution.entropy() - return (values, 
log_prob, entropy) + return (values.view(action_batch_shape), log_prob.view(action_batch_shape), entropy.view(action_batch_shape)) def _predict( self, diff --git a/stable_baselines3/ppo_recurrent/ppo_recurrent.py b/stable_baselines3/ppo_recurrent/ppo_recurrent.py index 4c9a6bd53..4bf70d097 100644 --- a/stable_baselines3/ppo_recurrent/ppo_recurrent.py +++ b/stable_baselines3/ppo_recurrent/ppo_recurrent.py @@ -389,9 +389,6 @@ def train(self) -> None: # Do a complete pass on the rollout buffer for rollout_data in self.rollout_buffer.get(self.batch_size): actions = rollout_data.actions - if isinstance(self.action_space, spaces.Discrete): - # Convert discrete action from float to long - actions = rollout_data.actions.long().flatten() # Re-sample the noise matrix because the log_std has changed if self.use_sde: diff --git a/tests/test_dict_env.py b/tests/test_dict_env.py index 320f8e66f..072305ec0 100644 --- a/tests/test_dict_env.py +++ b/tests/test_dict_env.py @@ -10,6 +10,7 @@ from stable_baselines3.common.env_util import make_vec_env from stable_baselines3.common.envs import BitFlippingEnv, SimpleMultiObsEnv from stable_baselines3.common.evaluation import evaluate_policy +from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm from stable_baselines3.common.vec_env import ( DummyVecEnv, @@ -229,7 +230,7 @@ def make_env(): features_extractor_kwargs=dict(cnn_output_dim=32), ), ) - elif model_class in {SAC, TD3, DQN}: + elif issubclass(model_class, OffPolicyAlgorithm): kwargs = dict( buffer_size=1000, policy_kwargs=dict( From 415dcb040e2b70b42e706f64318b3a9865ff0c5e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adri=C3=A0=20Garriga-Alonso?= Date: Tue, 19 Sep 2023 15:06:21 -0700 Subject: [PATCH 15/31] Fix spaces/shapes, increase parallelism --- .circleci/config.yml | 2 +- stable_baselines3/common/recurrent/buffers.py | 3 +++ stable_baselines3/common/recurrent/policies.py | 2 +- stable_baselines3/ppo_recurrent/ppo_recurrent.py | 11 ++++++----- 4 files changed, 11 insertions(+), 7 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index 70f122604..2c5dae597 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -71,7 +71,7 @@ jobs: password: "$GHCR_DOCKER_TOKEN" resource_class: medium working_directory: /workspace/third_party/stable-baselines3 - parallelism: 16 + parallelism: 24 steps: - checkout - run: diff --git a/stable_baselines3/common/recurrent/buffers.py b/stable_baselines3/common/recurrent/buffers.py index 4d3c521a3..de61eeb30 100644 --- a/stable_baselines3/common/recurrent/buffers.py +++ b/stable_baselines3/common/recurrent/buffers.py @@ -51,6 +51,9 @@ def space_to_example( elif isinstance(space, spaces.MultiDiscrete): space_shape = (len(space.nvec),) space_dtype = th.long + elif isinstance(space, spaces.MultiBinary): + space_shape = (space.n,) + space_dtype = th.float32 else: raise TypeError(f"Unknown space type {type(space)} for {space}") diff --git a/stable_baselines3/common/recurrent/policies.py b/stable_baselines3/common/recurrent/policies.py index a3f9fe9ab..824b6ae79 100644 --- a/stable_baselines3/common/recurrent/policies.py +++ b/stable_baselines3/common/recurrent/policies.py @@ -363,7 +363,7 @@ def evaluate_actions( action_batch_shape = actions.shape[:-1] distribution = self._get_action_dist_from_latent(latent_pi) - distribution_shape = distribution.distribution.batch_shape + distribution.distribution.event_shape + distribution_shape = distribution.mode().shape # 
FIXME: don't instantiate a tensor for a shape log_prob = distribution.log_prob(actions.view(distribution_shape)) values = self.value_net(latent_vf) entropy = distribution.entropy() diff --git a/stable_baselines3/ppo_recurrent/ppo_recurrent.py b/stable_baselines3/ppo_recurrent/ppo_recurrent.py index 4bf70d097..afa3a40be 100644 --- a/stable_baselines3/ppo_recurrent/ppo_recurrent.py +++ b/stable_baselines3/ppo_recurrent/ppo_recurrent.py @@ -2,13 +2,12 @@ import time import warnings from copy import deepcopy -from typing import Any, ClassVar, Dict, Optional, Type, Union +from typing import Any, ClassVar, Dict, Optional, Type, TypeVar, Union import numpy as np import torch as th import torch.nn.functional as F from gymnasium import spaces -from typing_extensions import Self from stable_baselines3.common.buffers import RolloutBuffer from stable_baselines3.common.callbacks import BaseCallback @@ -37,6 +36,8 @@ MultiInputLstmPolicy, ) +SelfRecurrentPPO = TypeVar("SelfRecurrentPPO", bound="RecurrentPPO") + class RecurrentPPO(OnPolicyAlgorithm): """ @@ -150,7 +151,7 @@ def __init__( # because of the advantage normalization if normalize_advantage: assert ( - batch_size > 1 + batch_size is None or batch_size > 1 ), "`batch_size` must be greater than 1. See https://github.com/DLR-RM/stable-baselines3/issues/440" if self.env is not None: @@ -489,14 +490,14 @@ def train(self) -> None: self.logger.record("train/clip_range_vf", clip_range_vf) def learn( - self: Self, + self: SelfRecurrentPPO, total_timesteps: int, callback: MaybeCallback = None, log_interval: int = 1, tb_log_name: str = "RecurrentPPO", reset_num_timesteps: bool = True, progress_bar: bool = False, - ) -> Self: + ) -> SelfRecurrentPPO: iteration = 0 total_timesteps, callback = self._setup_learn( From e16f00854782864c845fc7438d63c7972bd07381 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adri=C3=A0=20Garriga-Alonso?= Date: Wed, 20 Sep 2023 16:29:56 -0700 Subject: [PATCH 16/31] Batch size can't be None --- stable_baselines3/ppo_recurrent/ppo_recurrent.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/stable_baselines3/ppo_recurrent/ppo_recurrent.py b/stable_baselines3/ppo_recurrent/ppo_recurrent.py index afa3a40be..b6781b579 100644 --- a/stable_baselines3/ppo_recurrent/ppo_recurrent.py +++ b/stable_baselines3/ppo_recurrent/ppo_recurrent.py @@ -99,7 +99,7 @@ def __init__( env: Union[GymEnv, str], learning_rate: Union[float, Schedule] = 3e-4, n_steps: int = 128, - batch_size: Optional[int] = 128, + batch_size: int = 128, n_epochs: int = 10, gamma: float = 0.99, gae_lambda: float = 0.95, @@ -146,12 +146,11 @@ def __init__( spaces.MultiBinary, ), ) - # Sanity check, otherwise it will lead to noisy gradient and NaN # because of the advantage normalization if normalize_advantage: assert ( - batch_size is None or batch_size > 1 + batch_size > 1 ), "`batch_size` must be greater than 1. 
See https://github.com/DLR-RM/stable-baselines3/issues/440" if self.env is not None: From 04b990b9a95cb7a94079868f2a0a80720dea5b70 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adri=C3=A0=20Garriga-Alonso?= Date: Wed, 20 Sep 2023 16:31:00 -0700 Subject: [PATCH 17/31] Unsqueeze and re-squeeze around calling policy --- .../common/recurrent/policies.py | 33 ++++++++----------- .../ppo_recurrent/ppo_recurrent.py | 29 +++++++++++----- 2 files changed, 34 insertions(+), 28 deletions(-) diff --git a/stable_baselines3/common/recurrent/policies.py b/stable_baselines3/common/recurrent/policies.py index 824b6ae79..c39719e1b 100644 --- a/stable_baselines3/common/recurrent/policies.py +++ b/stable_baselines3/common/recurrent/policies.py @@ -116,6 +116,7 @@ def __init__( ) self.lstm_kwargs = lstm_kwargs or {} + assert not self.lstm_kwargs.get("batch_first", False) self.shared_lstm = shared_lstm self.enable_critic_lstm = enable_critic_lstm self.lstm_actor = nn.LSTM( @@ -186,38 +187,30 @@ def _process_sequence( :param lstm: LSTM object. :return: LSTM output and updated LSTM states. """ - # LSTM logic - # (sequence length, batch size, features dim) - # (batch size = n_envs for data collection or n_seq when doing gradient update) - n_seq = lstm_states[0].shape[1] - # Batch to sequence - # (padded batch size, features_dim) -> (n_seq, max length, features_dim) -> (max length, n_seq, features_dim) - # note: max length (max sequence length) is always 1 during data collection - features_sequence = features.reshape((n_seq, -1, lstm.input_size)).swapaxes(0, 1) - episode_starts = episode_starts.reshape((n_seq, -1)).swapaxes(0, 1) # If we don't have to reset the state in the middle of a sequence # we can avoid the for loop, which speeds up things - if th.all(episode_starts == 0.0): - lstm_output, lstm_states = lstm(features_sequence, lstm_states) - lstm_output = th.flatten(lstm_output.transpose(0, 1), start_dim=0, end_dim=1) + assert episode_starts.ndim == 2 + if not th.any(episode_starts[1:]): + initial_is_not_reset = (~episode_starts[0]).unsqueeze(-1) + lstm_states = (lstm_states[0] * initial_is_not_reset, lstm_states[1] * initial_is_not_reset) + lstm_output, lstm_states = lstm(features, lstm_states) return lstm_output, lstm_states lstm_output = [] # Iterate over the sequence - for features, episode_start in zip_strict(features_sequence, episode_starts): + for features, episode_start in zip_strict(features, episode_starts): + is_not_reset = (~episode_start).unsqueeze(-1) hidden, lstm_states = lstm( - features.unsqueeze(dim=0), + features, ( # Reset the states at the beginning of a new episode - (~episode_start).view(1, n_seq, 1) * lstm_states[0], - (~episode_start).view(1, n_seq, 1) * lstm_states[1], + is_not_reset * lstm_states[0], + is_not_reset * lstm_states[1], ), ) - lstm_output += [hidden] - # Sequence to batch - # (sequence length, n_seq, lstm_out_dim) -> (batch_size, lstm_out_dim) - lstm_output = th.flatten(th.cat(lstm_output).transpose(0, 1), start_dim=0, end_dim=1) + lstm_output.append(hidden) + lstm_output = th.cat(lstm_output) return lstm_output, lstm_states def forward( diff --git a/stable_baselines3/ppo_recurrent/ppo_recurrent.py b/stable_baselines3/ppo_recurrent/ppo_recurrent.py index b6781b579..b44161618 100644 --- a/stable_baselines3/ppo_recurrent/ppo_recurrent.py +++ b/stable_baselines3/ppo_recurrent/ppo_recurrent.py @@ -5,6 +5,7 @@ from typing import Any, ClassVar, Dict, Optional, Type, TypeVar, Union import numpy as np +import optree as ot import torch as th import torch.nn.functional as F from 
gymnasium import spaces @@ -13,6 +14,7 @@ from stable_baselines3.common.callbacks import BaseCallback from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm from stable_baselines3.common.policies import BasePolicy +from stable_baselines3.common.pytree_dataclass import OT_NAMESPACE as NS from stable_baselines3.common.recurrent.buffers import ( RecurrentRolloutBuffer, index_into_pytree, @@ -289,15 +291,20 @@ def collect_rollouts( with th.no_grad(): # Convert to pytorch tensor or to TensorDict - obs_tensor = obs_as_tensor(self._last_obs, self.device) + obs_tensor = ot.tree_map(lambda x: x.unsqueeze(0), obs_as_tensor(self._last_obs, self.device), namespace=NS) episode_starts = th.as_tensor(self._last_episode_starts).to(dtype=th.bool, device=self.device) - actions, values, log_probs, lstm_states = self.policy.forward(obs_tensor, lstm_states, episode_starts) + actions, values, log_probs, lstm_states = self.policy.forward( + obs_tensor, lstm_states, episode_starts.unsqueeze(0) + ) + actions = actions.squeeze(0) + values = values.squeeze(0) + log_probs = log_probs.squeeze(0) # Rescale and perform action clipped_actions = actions # Clip the actions to avoid out of bound error if isinstance(self.action_space, spaces.Box): - clipped_actions = np.clip(actions, self.action_space.low, self.action_space.high) + clipped_actions = th.clip(actions, th.as_tensor(self.action_space.low), th.as_tensor(self.action_space.high)) new_obs, rewards, dones, infos = env.step(clipped_actions) @@ -323,7 +330,7 @@ def collect_rollouts( and infos[idx].get("terminal_observation") is not None and infos[idx].get("TimeLimit.truncated", False) ): - terminal_obs = self.policy.obs_to_tensor(infos[idx]["terminal_observation"])[0] + terminal_obs, _ = self.policy.obs_to_tensor(infos[idx]["terminal_observation"]) with th.no_grad(): terminal_lstm_state = ( lstm_states.vf[0][:, idx : idx + 1, :].contiguous(), @@ -331,9 +338,11 @@ def collect_rollouts( ) # terminal_lstm_state = None episode_starts = th.zeros((1,), dtype=th.bool, device=self.device) - terminal_value = self.policy.predict_values(terminal_obs, terminal_lstm_state, episode_starts)[ - 0 - ].squeeze(-1) + terminal_value = self.policy.predict_values( + ot.tree_map(lambda x: x.unsqueeze(0), terminal_obs, namespace=NS), + terminal_lstm_state, + episode_starts.unsqueeze(0), + )[0].squeeze(0) rewards[idx] += self.gamma * terminal_value rollout_buffer.add( @@ -355,7 +364,11 @@ def collect_rollouts( with th.no_grad(): # Compute value for the last timestep episode_starts = th.as_tensor(dones).to(dtype=th.bool, device=self.device) - values = self.policy.predict_values(obs_as_tensor(new_obs, self.device), lstm_states.vf, episode_starts) + values = self.policy.predict_values( + ot.tree_map(lambda x: x.unsqueeze(0), obs_as_tensor(new_obs, self.device), namespace=NS), + lstm_states.vf, + episode_starts, + ) rollout_buffer.compute_returns_and_advantage(last_values=values, dones=dones) From 3388a87e59ce075324281c1f9576754274fa906f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adri=C3=A0=20Garriga-Alonso?= Date: Wed, 20 Sep 2023 19:51:33 -0700 Subject: [PATCH 18/31] Start from Numpy --- stable_baselines3/common/recurrent/buffers.py | 513 ++++++++++++------ .../common/recurrent/policies.py | 103 ++-- .../common/recurrent/type_aliases.py | 37 +- .../ppo_recurrent/ppo_recurrent.py | 148 ++--- 4 files changed, 469 insertions(+), 332 deletions(-) diff --git a/stable_baselines3/common/recurrent/buffers.py b/stable_baselines3/common/recurrent/buffers.py index de61eeb30..59f885dfe 
100644 --- a/stable_baselines3/common/recurrent/buffers.py +++ b/stable_baselines3/common/recurrent/buffers.py @@ -1,65 +1,97 @@ -import dataclasses -from typing import Any, Callable, Generator, Optional, Tuple, Union +from functools import partial +from typing import Callable, Generator, Optional, Tuple, Union -import optree as ot +import numpy as np import torch as th from gymnasium import spaces -from optree import PyTree -from stable_baselines3.common.buffers import RolloutBuffer -from stable_baselines3.common.pytree_dataclass import OT_NAMESPACE as NS +from stable_baselines3.common.buffers import DictRolloutBuffer, RolloutBuffer from stable_baselines3.common.recurrent.type_aliases import ( - HiddenState, - PyTreeGeneric, - RecurrentRolloutBufferData, + RecurrentDictRolloutBufferSamples, RecurrentRolloutBufferSamples, + RNNStates, ) from stable_baselines3.common.vec_env import VecNormalize -def index_into_pytree( - idx: Any, - tree: PyTreeGeneric, - is_leaf: Optional[Union[bool, Callable[[PyTreeGeneric], bool]]] = None, - none_is_leaf: bool = False, - namespace: str = NS, -) -> PyTreeGeneric: - return ot.tree_map(lambda x: x[idx], tree, is_leaf=is_leaf, none_is_leaf=none_is_leaf, namespace=namespace) # type: ignore - - -def space_to_example( - batch_shape: Tuple[int, ...], - space: spaces.Space, - *, - device: Optional[th.device] = None, - ensure_non_batch_dim: bool = False, -) -> PyTree[th.Tensor]: - if isinstance(space, spaces.Dict): - return { - k: space_to_example(batch_shape, v, device=device, ensure_non_batch_dim=ensure_non_batch_dim) - for k, v in space.items() - } - if isinstance(space, spaces.Tuple): - return tuple(space_to_example(batch_shape, v, device=device, ensure_non_batch_dim=ensure_non_batch_dim) for v in space) - - if isinstance(space, spaces.Box): - space_shape = space.shape - space_dtype = th.float32 - elif isinstance(space, spaces.Discrete): - space_shape = () - space_dtype = th.long - elif isinstance(space, spaces.MultiDiscrete): - space_shape = (len(space.nvec),) - space_dtype = th.long - elif isinstance(space, spaces.MultiBinary): - space_shape = (space.n,) - space_dtype = th.float32 - else: - raise TypeError(f"Unknown space type {type(space)} for {space}") - - if ensure_non_batch_dim and not space_shape: - space_shape = (1,) - return th.zeros((*batch_shape, *space_shape), dtype=space_dtype, device=device) +def pad( + seq_start_indices: np.ndarray, + seq_end_indices: np.ndarray, + device: th.device, + tensor: np.ndarray, + padding_value: float = 0.0, +) -> th.Tensor: + """ + Chunk sequences and pad them to have constant dimensions. + + :param seq_start_indices: Indices of the transitions that start a sequence + :param seq_end_indices: Indices of the transitions that end a sequence + :param device: PyTorch device + :param tensor: Tensor of shape (batch_size, *tensor_shape) + :param padding_value: Value used to pad sequence to the same length + (zero padding by default) + :return: (n_seq, max_length, *tensor_shape) + """ + # Create sequences given start and end + seq = [th.tensor(tensor[start : end + 1], device=device) for start, end in zip(seq_start_indices, seq_end_indices)] + return th.nn.utils.rnn.pad_sequence(seq, batch_first=True, padding_value=padding_value) + + +def pad_and_flatten( + seq_start_indices: np.ndarray, + seq_end_indices: np.ndarray, + device: th.device, + tensor: np.ndarray, + padding_value: float = 0.0, +) -> th.Tensor: + """ + Pad and flatten the sequences of scalar values, + while keeping the sequence order. 
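# --- illustrative aside, not part of the patch -------------------------------------
# A minimal, self-contained sketch of what the `pad` helper above (and the
# `pad_and_flatten` wrapper described here) do: ragged per-episode chunks are
# stacked and right-padded into one (n_seq, max_length, ...) tensor. All names
# and sizes below are made up for the example only.
import numpy as np
import torch as th

values = np.arange(10, dtype=np.float32)         # 10 transitions from one env
seq_start_indices = np.array([0, 4, 7])          # episodes start at t = 0, 4, 7
seq_end_indices = np.array([3, 6, 9])            # inclusive end of each episode

chunks = [th.tensor(values[s : e + 1]) for s, e in zip(seq_start_indices, seq_end_indices)]
padded = th.nn.utils.rnn.pad_sequence(chunks, batch_first=True, padding_value=0.0)
# padded.shape == (3, 4): n_seq sequences padded to max_length; calling .flatten()
# on it yields the (n_seq * max_length,) layout that `pad_and_flatten` returns.
# ------------------------------------------------------------------------------------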
+ From (batch_size, 1) to (n_seq, max_length, 1) -> (n_seq * max_length,) + + :param seq_start_indices: Indices of the transitions that start a sequence + :param seq_end_indices: Indices of the transitions that end a sequence + :param device: PyTorch device (cpu, gpu, ...) + :param tensor: Tensor of shape (max_length, n_seq, 1) + :param padding_value: Value used to pad sequence to the same length + (zero padding by default) + :return: (n_seq * max_length,) aka (padded_batch_size,) + """ + return pad(seq_start_indices, seq_end_indices, device, tensor, padding_value).flatten() + + +def create_sequencers( + episode_starts: np.ndarray, + env_change: np.ndarray, + device: th.device, +) -> Tuple[np.ndarray, Callable, Callable]: + """ + Create the utility function to chunk data into + sequences and pad them to create fixed size tensors. + + :param episode_starts: Indices where an episode starts + :param env_change: Indices where the data collected + come from a different env (when using multiple env for data collection) + :param device: PyTorch device + :return: Indices of the transitions that start a sequence, + pad and pad_and_flatten utilities tailored for this batch + (sequence starts and ends indices are fixed) + """ + # Create sequence if env changes too + seq_start = np.logical_or(episode_starts, env_change).flatten() + # First index is always the beginning of a sequence + seq_start[0] = True + # Retrieve indices of sequence starts + seq_start_indices = np.where(seq_start == True)[0] # noqa: E712 + # End of sequence are just before sequence starts + # Last index is also always end of a sequence + seq_end_indices = np.concatenate([(seq_start_indices - 1)[1:], np.array([len(episode_starts)])]) + + # Create padding method for this minibatch + # to avoid repeating arguments (seq_start_indices, seq_end_indices) + local_pad = partial(pad, seq_start_indices, seq_end_indices, device) + local_pad_and_flatten = partial(pad_and_flatten, seq_start_indices, seq_end_indices, device) + return seq_start_indices, local_pad, local_pad_and_flatten class RecurrentRolloutBuffer(RolloutBuffer): @@ -78,146 +110,301 @@ class RecurrentRolloutBuffer(RolloutBuffer): :param n_envs: Number of parallel environments """ - advantages: th.Tensor - returns: th.Tensor - data: RecurrentRolloutBufferData - def __init__( self, buffer_size: int, observation_space: spaces.Space, action_space: spaces.Space, - hidden_state_example: HiddenState, + hidden_state_shape: Tuple[int, int, int, int], device: Union[th.device, str] = "auto", gae_lambda: float = 1, gamma: float = 0.99, n_envs: int = 1, ): - super(RolloutBuffer, self).__init__(buffer_size, observation_space, action_space, device=device, n_envs=n_envs) - self.hidden_state_example = hidden_state_example - self.gae_lambda = gae_lambda - self.gamma = gamma + self.hidden_state_shape = hidden_state_shape + self.seq_start_indices, self.seq_end_indices = None, None + super().__init__(buffer_size, observation_space, action_space, device, gae_lambda, gamma, n_envs) - batch_shape = (self.buffer_size, self.n_envs) - device = self.device + def reset(self): + super().reset() + self.hidden_states_pi = np.zeros(self.hidden_state_shape, dtype=np.float32) + self.cell_states_pi = np.zeros(self.hidden_state_shape, dtype=np.float32) + self.hidden_states_vf = np.zeros(self.hidden_state_shape, dtype=np.float32) + self.cell_states_vf = np.zeros(self.hidden_state_shape, dtype=np.float32) - self.hidden_state_example = ot.tree_map( - lambda x: th.zeros((), dtype=x.dtype, device=device).expand_as(x), 
hidden_state_example, namespace=NS - ) - self.advantages = th.zeros(batch_shape, dtype=th.float32, device=device) - self.returns = th.zeros(batch_shape, dtype=th.float32, device=device) - self.data = RecurrentRolloutBufferData( - observations=space_to_example(batch_shape, self.observation_space, device=device, ensure_non_batch_dim=True), - actions=th.zeros( - (*batch_shape, self.action_dim), - dtype=th.long if isinstance(self.action_space, (spaces.Discrete, spaces.MultiDiscrete)) else th.float32, - device=device, - ), - rewards=th.zeros(batch_shape, dtype=th.float32, device=device), - episode_starts=th.zeros(batch_shape, dtype=th.bool, device=device), - values=th.zeros(batch_shape, dtype=th.float32, device=device), - log_probs=th.zeros(batch_shape, dtype=th.float32, device=device), - hidden_states=ot.tree_map( - lambda x: th.zeros(self._reshape_hidden_state_shape(batch_shape, x.shape), dtype=x.dtype, device=device), - hidden_state_example, - namespace=NS, - ), - ) + def add(self, *args, lstm_states: RNNStates, **kwargs) -> None: + """ + :param hidden_states: LSTM cell and hidden state + """ + self.hidden_states_pi[self.pos] = np.array(lstm_states.pi[0].cpu().numpy()) + self.cell_states_pi[self.pos] = np.array(lstm_states.pi[1].cpu().numpy()) + self.hidden_states_vf[self.pos] = np.array(lstm_states.vf[0].cpu().numpy()) + self.cell_states_vf[self.pos] = np.array(lstm_states.vf[1].cpu().numpy()) + + super().add(*(th.as_tensor(a) for a in args), **kwargs) @staticmethod - def _reshape_hidden_state_shape(batch_shape: Tuple[int, ...], state_shape: Tuple[int, ...]) -> Tuple[int, ...]: - if len(state_shape) < 2: - raise NotImplementedError("State shape must be 2+ dimensions currently") - return (*batch_shape[:-1], state_shape[0], batch_shape[-1], *state_shape[1:]) + def swap_and_flatten(arr: np.ndarray) -> np.ndarray: + """ + Swap and then flatten axes 0 (buffer_size) and 1 (n_envs) + to convert shape from [n_steps, n_envs, ...] (when ... is the shape of the features) + to [n_steps * n_envs, ...] (which maintain the order) - # Expose attributes of the RecurrentRolloutBufferData in the top-level to conform to the RolloutBuffer interface - @property - def episode_starts(self) -> th.Tensor: - return self.data.episode_starts + :param arr: + :return: + """ + shape = arr.shape + if len(shape) < 3: + shape = (*shape, 1) + return arr.swapaxes(0, 1).reshape(shape[0] * shape[1], *shape[2:]) - @property - def values(self) -> th.Tensor: - return self.data.values + def get(self, batch_size: Optional[int] = None) -> Generator[RecurrentRolloutBufferSamples, None, None]: + assert self.full, "Rollout buffer must be full before sampling from it" - @property - def rewards(self) -> th.Tensor: - assert self.data.rewards is not None, "RecurrentRolloutBufferData should store rewards" - return self.data.rewards + # Prepare the data + if not self.generator_ready: + # hidden_state_shape = (self.n_steps, lstm.num_layers, self.n_envs, lstm.hidden_size) + # swap first to (self.n_steps, self.n_envs, lstm.num_layers, lstm.hidden_size) + for tensor in ["hidden_states_pi", "cell_states_pi", "hidden_states_vf", "cell_states_vf"]: + self.__dict__[tensor] = self.__dict__[tensor].swapaxes(1, 2) - def reset(self): - self.returns.zero_() - self.advantages.zero_() - ot.tree_map(lambda x: x.zero_(), self.data, namespace=NS) - super(RolloutBuffer, self).reset() + # flatten but keep the sequence order + # 1. (n_steps, n_envs, *tensor_shape) -> (n_envs, n_steps, *tensor_shape) + # 2. 
(n_envs, n_steps, *tensor_shape) -> (n_envs * n_steps, *tensor_shape) + for tensor in [ + "observations", + "actions", + "values", + "log_probs", + "advantages", + "returns", + "hidden_states_pi", + "cell_states_pi", + "hidden_states_vf", + "cell_states_vf", + "episode_starts", + ]: + self.__dict__[tensor] = self.swap_and_flatten(self.__dict__[tensor]) + self.generator_ready = True - def extend(self, *args) -> None: - """ - Add a new batch of transitions to the buffer - """ + # Return everything, don't create minibatches + if batch_size is None: + batch_size = self.buffer_size * self.n_envs + + # Sampling strategy that allows any mini batch size but requires + # more complexity and use of padding + # Trick to shuffle a bit: keep the sequence order + # but split the indices in two + split_index = np.random.randint(self.buffer_size * self.n_envs) + indices = np.arange(self.buffer_size * self.n_envs) + indices = np.concatenate((indices[split_index:], indices[:split_index])) + + env_change = np.zeros(self.buffer_size * self.n_envs).reshape(self.buffer_size, self.n_envs) + # Flag first timestep as change of environment + env_change[0, :] = 1.0 + env_change = self.swap_and_flatten(env_change) + + start_idx = 0 + while start_idx < self.buffer_size * self.n_envs: + batch_inds = indices[start_idx : start_idx + batch_size] + yield self._get_samples(batch_inds, env_change) + start_idx += batch_size + + def _get_samples( + self, + batch_inds: np.ndarray, + env_change: np.ndarray, + env: Optional[VecNormalize] = None, + ) -> RecurrentRolloutBufferSamples: + # Retrieve sequence starts and utility function + self.seq_start_indices, self.pad, self.pad_and_flatten = create_sequencers( + self.episode_starts[batch_inds], env_change[batch_inds], self.device + ) + + # Number of sequences + n_seq = len(self.seq_start_indices) + max_length = self.pad(self.actions[batch_inds]).shape[1] + padded_batch_size = n_seq * max_length + # We retrieve the lstm hidden states that will allow + # to properly initialize the LSTM at the beginning of each sequence + lstm_states_pi = ( + # 1. (n_envs * n_steps, n_layers, dim) -> (batch_size, n_layers, dim) + # 2. (batch_size, n_layers, dim) -> (n_seq, n_layers, dim) + # 3. 
(n_seq, n_layers, dim) -> (n_layers, n_seq, dim) + self.hidden_states_pi[batch_inds][self.seq_start_indices].swapaxes(0, 1), + self.cell_states_pi[batch_inds][self.seq_start_indices].swapaxes(0, 1), + ) + lstm_states_vf = ( + # (n_envs * n_steps, n_layers, dim) -> (n_layers, n_seq, dim) + self.hidden_states_vf[batch_inds][self.seq_start_indices].swapaxes(0, 1), + self.cell_states_vf[batch_inds][self.seq_start_indices].swapaxes(0, 1), + ) + lstm_states_pi = ( + self.to_device(th.from_numpy(lstm_states_pi[0])).contiguous(), + self.to_device(th.from_numpy(lstm_states_pi[1])).contiguous(), + ) + lstm_states_vf = ( + self.to_device(th.from_numpy(lstm_states_vf[0])).contiguous(), + self.to_device(th.from_numpy(lstm_states_vf[1])).contiguous(), + ) + return RecurrentRolloutBufferSamples( + # (batch_size, obs_dim) -> (n_seq, max_length, obs_dim) -> (n_seq * max_length, obs_dim) + observations=self.pad(self.observations[batch_inds]).reshape((padded_batch_size, *self.obs_shape)), + actions=self.pad(self.actions[batch_inds]).reshape((padded_batch_size,) + self.actions.shape[1:]), + old_values=self.pad_and_flatten(self.values[batch_inds]), + old_log_prob=self.pad_and_flatten(self.log_probs[batch_inds]), + advantages=self.pad_and_flatten(self.advantages[batch_inds]), + returns=self.pad_and_flatten(self.returns[batch_inds]), + lstm_states=RNNStates(lstm_states_pi, lstm_states_vf), + episode_starts=self.pad_and_flatten(self.episode_starts[batch_inds]), + mask=self.pad_and_flatten(np.ones_like(self.returns[batch_inds])), + ) + + +class RecurrentDictRolloutBuffer(DictRolloutBuffer): + """ + Dict Rollout buffer used in on-policy algorithms like A2C/PPO. + Extends the RecurrentRolloutBuffer to use dictionary observations + + :param buffer_size: Max number of element in the buffer + :param observation_space: Observation space + :param action_space: Action space + :param hidden_state_shape: Shape of the buffer that will collect lstm states + :param device: PyTorch device + :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator + Equivalent to classic advantage when set to 1. + :param gamma: Discount factor + :param n_envs: Number of parallel environments + """ - # Do a for loop along the batch axis. - # Treat lists as leaves to avoid flattening the infos. 
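# --- illustrative aside, not part of the patch -------------------------------------
# A small sketch (toy sizes, made-up names) of the shape manipulation used in
# `_get_samples` above: after `swap_and_flatten`, hidden states are stored as
# (n_transitions, n_layers, dim); selecting the rows at the sequence starts and
# swapping the first two axes gives the (n_layers, n_seq, dim) layout that
# torch.nn.LSTM expects as its initial state.
import numpy as np

n_transitions, n_layers, dim = 12, 2, 8
hidden_states = np.zeros((n_transitions, n_layers, dim), dtype=np.float32)
seq_start_indices = np.array([0, 5, 9])            # first transition of each sequence

init_states = hidden_states[seq_start_indices]     # (n_seq, n_layers, dim)
init_states = init_states.swapaxes(0, 1)           # (n_layers, n_seq, dim)
assert init_states.shape == (2, 3, 8)
# ------------------------------------------------------------------------------------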
- def _is_list(t): - return isinstance(t, list) + def __init__( + self, + buffer_size: int, + observation_space: spaces.Space, + action_space: spaces.Space, + hidden_state_shape: Tuple[int, int, int, int], + device: Union[th.device, str] = "auto", + gae_lambda: float = 1, + gamma: float = 0.99, + n_envs: int = 1, + ): + self.hidden_state_shape = hidden_state_shape + self.seq_start_indices, self.seq_end_indices = None, None + super().__init__(buffer_size, observation_space, action_space, device, gae_lambda, gamma, n_envs=n_envs) - tensors, _ = ot.tree_flatten(args, is_leaf=_is_list, namespace=NS) - len_tensors = len(tensors[0]) - assert all(len(t) == len_tensors for t in tensors), "All tensors must have the same batch size" - for i in range(len_tensors): - self.add(*index_into_pytree(i, args, is_leaf=_is_list, namespace=NS)) + def reset(self): + super().reset() + self.hidden_states_pi = np.zeros(self.hidden_state_shape, dtype=np.float32) + self.cell_states_pi = np.zeros(self.hidden_state_shape, dtype=np.float32) + self.hidden_states_vf = np.zeros(self.hidden_state_shape, dtype=np.float32) + self.cell_states_vf = np.zeros(self.hidden_state_shape, dtype=np.float32) - def add(self, data: RecurrentRolloutBufferData, **kwargs) -> None: + def add(self, *args, lstm_states: RNNStates, **kwargs) -> None: """ :param hidden_states: LSTM cell and hidden state """ - if data.rewards is None: - raise ValueError("Recorded samples must contain a reward") - new_data = dataclasses.replace(data, actions=data.actions.reshape((self.n_envs, self.action_dim))) - - ot.tree_map( - lambda buf, x: buf[self.pos].copy_(x if x.ndim + 1 == buf.ndim else x.unsqueeze(-1), non_blocking=True), - self.data, - new_data, - namespace=NS, - ) - # Increment pos - self.pos += 1 - if self.pos == self.buffer_size: - self.full = True + self.hidden_states_pi[self.pos] = np.array(lstm_states.pi[0].cpu().numpy()) + self.cell_states_pi[self.pos] = np.array(lstm_states.pi[1].cpu().numpy()) + self.hidden_states_vf[self.pos] = np.array(lstm_states.vf[0].cpu().numpy()) + self.cell_states_vf[self.pos] = np.array(lstm_states.vf[1].cpu().numpy()) + + super().add(*args, **kwargs) - def get(self, batch_size: Optional[int] = None) -> Generator[RecurrentRolloutBufferSamples, None, None]: # type: ignore[signature-mismatch] #FIXME + def get(self, batch_size: Optional[int] = None) -> Generator[RecurrentDictRolloutBufferSamples, None, None]: assert self.full, "Rollout buffer must be full before sampling from it" + # Prepare the data + if not self.generator_ready: + # hidden_state_shape = (self.n_steps, lstm.num_layers, self.n_envs, lstm.hidden_size) + # swap first to (self.n_steps, self.n_envs, lstm.num_layers, lstm.hidden_size) + for tensor in ["hidden_states_pi", "cell_states_pi", "hidden_states_vf", "cell_states_vf"]: + self.__dict__[tensor] = self.__dict__[tensor].swapaxes(1, 2) + + for key, obs in self.observations.items(): + self.observations[key] = self.swap_and_flatten(obs) + + for tensor in [ + "actions", + "values", + "log_probs", + "advantages", + "returns", + "hidden_states_pi", + "cell_states_pi", + "hidden_states_vf", + "cell_states_vf", + "episode_starts", + ]: + self.__dict__[tensor] = self.swap_and_flatten(self.__dict__[tensor]) + self.generator_ready = True + # Return everything, don't create minibatches - if batch_size is None or batch_size == self.buffer_size * self.n_envs: - yield self._get_samples(slice(None)) - return - - if batch_size % self.n_envs != 0 or batch_size < self.n_envs: - raise ValueError( - f"The batch size must be a 
multiple of the number of environments (n_envs={self.n_envs}) ", - f"but batch_size={batch_size}.", - ) - batch_size //= self.n_envs - - for start_idx in range(0, self.buffer_size, batch_size): - out = self._get_samples(slice(start_idx, start_idx + batch_size, None)) - assert len(out.observations) != 0 - yield out - - def _get_samples( # type: ignore[override] + if batch_size is None: + batch_size = self.buffer_size * self.n_envs + + # Trick to shuffle a bit: keep the sequence order + # but split the indices in two + split_index = np.random.randint(self.buffer_size * self.n_envs) + indices = np.arange(self.buffer_size * self.n_envs) + indices = np.concatenate((indices[split_index:], indices[:split_index])) + + env_change = np.zeros(self.buffer_size * self.n_envs).reshape(self.buffer_size, self.n_envs) + # Flag first timestep as change of environment + env_change[0, :] = 1.0 + env_change = self.swap_and_flatten(env_change) + + start_idx = 0 + while start_idx < self.buffer_size * self.n_envs: + batch_inds = indices[start_idx : start_idx + batch_size] + yield self._get_samples(batch_inds, env_change) + start_idx += batch_size + + def _get_samples( self, - batch_inds: Union[slice, th.Tensor], + batch_inds: np.ndarray, + env_change: np.ndarray, env: Optional[VecNormalize] = None, - ) -> RecurrentRolloutBufferSamples: - samples = RecurrentRolloutBufferSamples( - observations=self.data.observations, - actions=self.data.actions, - episode_starts=self.data.episode_starts, - old_values=self.data.values, - old_log_prob=self.data.log_probs, - advantages=self.advantages, - returns=self.returns, - hidden_states=self.data.hidden_states, + ) -> RecurrentDictRolloutBufferSamples: + # Retrieve sequence starts and utility function + self.seq_start_indices, self.pad, self.pad_and_flatten = create_sequencers( + self.episode_starts[batch_inds], env_change[batch_inds], self.device + ) + + n_seq = len(self.seq_start_indices) + max_length = self.pad(self.actions[batch_inds]).shape[1] + padded_batch_size = n_seq * max_length + # We retrieve the lstm hidden states that will allow + # to properly initialize the LSTM at the beginning of each sequence + lstm_states_pi = ( + # (n_envs * n_steps, n_layers, dim) -> (n_layers, n_seq, dim) + self.hidden_states_pi[batch_inds][self.seq_start_indices].swapaxes(0, 1), + self.cell_states_pi[batch_inds][self.seq_start_indices].swapaxes(0, 1), + ) + lstm_states_vf = ( + # (n_envs * n_steps, n_layers, dim) -> (n_layers, n_seq, dim) + self.hidden_states_vf[batch_inds][self.seq_start_indices].swapaxes(0, 1), + self.cell_states_vf[batch_inds][self.seq_start_indices].swapaxes(0, 1), + ) + lstm_states_pi = ( + self.to_device(th.from_numpy(lstm_states_pi[0])).contiguous(), + self.to_device(th.from_numpy(lstm_states_pi[1])).contiguous(), + ) + lstm_states_vf = ( + self.to_device(th.from_numpy(lstm_states_vf[0])).contiguous(), + self.to_device(th.from_numpy(lstm_states_vf[1])).contiguous(), + ) + + observations = {key: self.pad(obs[batch_inds]) for (key, obs) in self.observations.items()} + observations = {key: obs.reshape((padded_batch_size,) + self.obs_shape[key]) for (key, obs) in observations.items()} + + return RecurrentDictRolloutBufferSamples( + observations=observations, + actions=self.pad(self.actions[batch_inds]).reshape((padded_batch_size,) + self.actions.shape[1:]), + old_values=self.pad_and_flatten(self.values[batch_inds]), + old_log_prob=self.pad_and_flatten(self.log_probs[batch_inds]), + advantages=self.pad_and_flatten(self.advantages[batch_inds]), + 
returns=self.pad_and_flatten(self.returns[batch_inds]), + lstm_states=RNNStates(lstm_states_pi, lstm_states_vf), + episode_starts=self.pad_and_flatten(self.episode_starts[batch_inds]), + mask=self.pad_and_flatten(np.ones_like(self.returns[batch_inds])), ) - return ot.tree_map(lambda tens: self.to_device(tens[batch_inds]), samples, namespace=NS) diff --git a/stable_baselines3/common/recurrent/policies.py b/stable_baselines3/common/recurrent/policies.py index c39719e1b..e45772190 100644 --- a/stable_baselines3/common/recurrent/policies.py +++ b/stable_baselines3/common/recurrent/policies.py @@ -1,15 +1,13 @@ from typing import Any, Dict, List, Optional, Tuple, Type, Union -import optree as ot +import numpy as np import torch as th from gymnasium import spaces from torch import nn from stable_baselines3.common.distributions import Distribution from stable_baselines3.common.policies import ActorCriticPolicy -from stable_baselines3.common.pytree_dataclass import OT_NAMESPACE as NS -from stable_baselines3.common.pytree_dataclass import dataclass_frozen_pytree -from stable_baselines3.common.recurrent.buffers import space_to_example +from stable_baselines3.common.recurrent.type_aliases import RNNStates from stable_baselines3.common.torch_layers import ( BaseFeaturesExtractor, CombinedExtractor, @@ -17,16 +15,10 @@ MlpExtractor, NatureCNN, ) -from stable_baselines3.common.type_aliases import Schedule, TorchGymObs +from stable_baselines3.common.type_aliases import Schedule from stable_baselines3.common.utils import zip_strict -@dataclass_frozen_pytree -class LSTMStates: - pi: th.Tensor - vf: th.Tensor - - class RecurrentActorCriticPolicy(ActorCriticPolicy): """ Recurrent policy class for actor-critic algorithms (has both policy and value prediction). @@ -116,7 +108,6 @@ def __init__( ) self.lstm_kwargs = lstm_kwargs or {} - assert not self.lstm_kwargs.get("batch_first", False) self.shared_lstm = shared_lstm self.enable_critic_lstm = enable_critic_lstm self.lstm_actor = nn.LSTM( @@ -168,8 +159,6 @@ def _build_mlp_extractor(self) -> None: device=self.device, ) - self.observation_example = space_to_example((), self.observation_space) - @staticmethod def _process_sequence( features: th.Tensor, @@ -187,39 +176,47 @@ def _process_sequence( :param lstm: LSTM object. :return: LSTM output and updated LSTM states. 
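# --- illustrative aside, not part of the patch -------------------------------------
# A standalone sketch of the episode-start handling that `_process_sequence`
# (restored just below) performs: wherever an environment has just reset, the
# LSTM hidden and cell states are zeroed before the next timestep is fed in.
# Sizes and variable names here are assumptions made up for the example.
import torch as th

n_layers, n_envs, hidden_size = 1, 3, 4
lstm = th.nn.LSTM(input_size=2, hidden_size=hidden_size, num_layers=n_layers)
h = th.randn(n_layers, n_envs, hidden_size)
c = th.randn(n_layers, n_envs, hidden_size)

episode_start = th.tensor([False, True, False])    # env 1 just reset
not_reset = (~episode_start).view(1, n_envs, 1).float()
h, c = h * not_reset, c * not_reset                # only env 1's state is zeroed

features = th.randn(1, n_envs, 2)                  # one timestep: (seq_len, batch, input_size)
out, (h, c) = lstm(features, (h, c))
assert out.shape == (1, n_envs, hidden_size)
# ------------------------------------------------------------------------------------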
""" + # LSTM logic + # (sequence length, batch size, features dim) + # (batch size = n_envs for data collection or n_seq when doing gradient update) + n_seq = lstm_states[0].shape[1] + # Batch to sequence + # (padded batch size, features_dim) -> (n_seq, max length, features_dim) -> (max length, n_seq, features_dim) + # note: max length (max sequence length) is always 1 during data collection + features_sequence = features.reshape((n_seq, -1, lstm.input_size)).swapaxes(0, 1) + episode_starts = episode_starts.reshape((n_seq, -1)).swapaxes(0, 1) # If we don't have to reset the state in the middle of a sequence # we can avoid the for loop, which speeds up things - assert episode_starts.ndim == 2 - if not th.any(episode_starts[1:]): - initial_is_not_reset = (~episode_starts[0]).unsqueeze(-1) - lstm_states = (lstm_states[0] * initial_is_not_reset, lstm_states[1] * initial_is_not_reset) - lstm_output, lstm_states = lstm(features, lstm_states) + if th.all(episode_starts == 0.0): + lstm_output, lstm_states = lstm(features_sequence, lstm_states) + lstm_output = th.flatten(lstm_output.transpose(0, 1), start_dim=0, end_dim=1) return lstm_output, lstm_states lstm_output = [] # Iterate over the sequence - for features, episode_start in zip_strict(features, episode_starts): - is_not_reset = (~episode_start).unsqueeze(-1) + for features, episode_start in zip_strict(features_sequence, episode_starts): hidden, lstm_states = lstm( - features, + features.unsqueeze(dim=0), ( # Reset the states at the beginning of a new episode - is_not_reset * lstm_states[0], - is_not_reset * lstm_states[1], + (~episode_start).view(1, n_seq, 1) * lstm_states[0], + (~episode_start).view(1, n_seq, 1) * lstm_states[1], ), ) - lstm_output.append(hidden) - lstm_output = th.cat(lstm_output) + lstm_output += [hidden] + # Sequence to batch + # (sequence length, n_seq, lstm_out_dim) -> (batch_size, lstm_out_dim) + lstm_output = th.flatten(th.cat(lstm_output).transpose(0, 1), start_dim=0, end_dim=1) return lstm_output, lstm_states def forward( self, obs: th.Tensor, - lstm_states: LSTMStates, + lstm_states: RNNStates, episode_starts: th.Tensor, deterministic: bool = False, - ) -> Tuple[th.Tensor, th.Tensor, th.Tensor, LSTMStates]: + ) -> Tuple[th.Tensor, th.Tensor, th.Tensor, RNNStates]: """ Forward pass in all the networks (actor and critic) @@ -257,7 +254,7 @@ def forward( distribution = self._get_action_dist_from_latent(latent_pi) actions = distribution.get_actions(deterministic=deterministic) log_prob = distribution.log_prob(actions) - return actions, values, log_prob, LSTMStates(lstm_states_pi, lstm_states_vf) + return actions, values, log_prob, RNNStates(lstm_states_pi, lstm_states_vf) def get_distribution( self, @@ -310,20 +307,8 @@ def predict_values( latent_vf = self.mlp_extractor.forward_critic(latent_vf) return self.value_net(latent_vf) - def extract_features(self, obs: TorchGymObs) -> Union[th.Tensor, Tuple[th.Tensor, th.Tensor]]: - obs_flat = ot.tree_map(lambda x, x_nobatch: x.view(-1, *x_nobatch.shape), obs, self.observation_example, namespace=NS) - obs_batch_shapes = ot.tree_map( - lambda x, x_nobatch: x.shape[: x.ndim - x_nobatch.ndim], obs, self.observation_example, namespace=NS - ) - - (obs_batch_shape, *_), _ = ot.tree_flatten(obs_batch_shapes, namespace=NS) - - features_flat = super().extract_features(obs_flat) - features = ot.tree_map(lambda x: x.view(*obs_batch_shape, *x.shape[1:]), features_flat, namespace=NS) - return features - def evaluate_actions( - self, obs: th.Tensor, actions: th.Tensor, lstm_states: 
LSTMStates, episode_starts: th.Tensor + self, obs: th.Tensor, actions: th.Tensor, lstm_states: RNNStates, episode_starts: th.Tensor ) -> Tuple[th.Tensor, th.Tensor, th.Tensor]: """ Evaluate actions according to the current policy, @@ -354,13 +339,10 @@ def evaluate_actions( latent_pi = self.mlp_extractor.forward_actor(latent_pi) latent_vf = self.mlp_extractor.forward_critic(latent_vf) - action_batch_shape = actions.shape[:-1] distribution = self._get_action_dist_from_latent(latent_pi) - distribution_shape = distribution.mode().shape # FIXME: don't instantiate a tensor for a shape - log_prob = distribution.log_prob(actions.view(distribution_shape)) + log_prob = distribution.log_prob(actions) values = self.value_net(latent_vf) - entropy = distribution.entropy() - return (values.view(action_batch_shape), log_prob.view(action_batch_shape), entropy.view(action_batch_shape)) + return values, log_prob, distribution.entropy() def _predict( self, @@ -384,11 +366,11 @@ def _predict( def predict( self, - observation: Union[th.Tensor, Dict[str, th.Tensor]], - state: Optional[Tuple[th.Tensor, ...]] = None, - episode_start: Optional[th.Tensor] = None, + observation: Union[np.ndarray, Dict[str, np.ndarray]], + state: Optional[Tuple[np.ndarray, ...]] = None, + episode_start: Optional[np.ndarray] = None, deterministic: bool = False, - ) -> Tuple[th.Tensor, Optional[Tuple[th.Tensor, ...]]]: + ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]: """ Get the policy action from an observation (and optional hidden state). Includes sugar-coating to handle different observations (e.g. normalizing images). @@ -413,16 +395,25 @@ def predict( # state : (n_layers, n_envs, dim) if state is None: # Initialize hidden states to zeros - state_component = th.cat([th.zeros(self.lstm_hidden_state_shape)] * n_envs, axis=1) - state = (state_component, state_component) + state = np.concatenate([np.zeros(self.lstm_hidden_state_shape) for _ in range(n_envs)], axis=1) + state = (state, state) if episode_start is None: - episode_start = th.zeros((n_envs,), dtype=th.bool) + episode_start = np.array([False for _ in range(n_envs)]) with th.no_grad(): + # Convert to PyTorch tensors + states = th.tensor(state[0], dtype=th.float32, device=self.device), th.tensor( + state[1], dtype=th.float32, device=self.device + ) + episode_starts = th.tensor(episode_start, dtype=th.bool, device=self.device) actions, states = self._predict( - observation, lstm_states=state, episode_starts=episode_start, deterministic=deterministic + observation, lstm_states=states, episode_starts=episode_starts, deterministic=deterministic ) + states = (states[0].cpu().numpy(), states[1].cpu().numpy()) + + # Convert to numpy + actions = actions.cpu().numpy() if isinstance(self.action_space, spaces.Box): if self.squash_output: @@ -431,7 +422,7 @@ def predict( else: # Actions could be on arbitrary scale, so clip the actions to avoid # out of bound error (e.g. 
if sampling from a Gaussian distribution) - actions = th.clip(actions, th.as_tensor(self.action_space.low), th.as_tensor(self.action_space.high)) + actions = np.clip(actions, self.action_space.low, self.action_space.high) # Remove batch dimension if needed if not vectorized_env: diff --git a/stable_baselines3/common/recurrent/type_aliases.py b/stable_baselines3/common/recurrent/type_aliases.py index 56be23b9d..21ac0e0d9 100644 --- a/stable_baselines3/common/recurrent/type_aliases.py +++ b/stable_baselines3/common/recurrent/type_aliases.py @@ -1,34 +1,33 @@ -from typing import TypeVar +from typing import NamedTuple, Tuple import torch as th -from optree import PyTree +from stable_baselines3.common.type_aliases import TensorDict -from stable_baselines3.common.pytree_dataclass import dataclass_frozen_pytree -HiddenState = PyTree[th.Tensor] +class RNNStates(NamedTuple): + pi: Tuple[th.Tensor, ...] + vf: Tuple[th.Tensor, ...] -PyTreeGeneric = TypeVar("PyTreeGeneric", bound=PyTree) - - -@dataclass_frozen_pytree -class RecurrentRolloutBufferData: - observations: PyTree[th.Tensor] +class RecurrentRolloutBufferSamples(NamedTuple): + observations: th.Tensor actions: th.Tensor - rewards: th.Tensor + old_values: th.Tensor + old_log_prob: th.Tensor + advantages: th.Tensor + returns: th.Tensor + lstm_states: RNNStates episode_starts: th.Tensor - values: th.Tensor - log_probs: th.Tensor - hidden_states: HiddenState + mask: th.Tensor -@dataclass_frozen_pytree -class RecurrentRolloutBufferSamples: - observations: PyTree[th.Tensor] +class RecurrentDictRolloutBufferSamples(NamedTuple): + observations: TensorDict actions: th.Tensor - episode_starts: th.Tensor old_values: th.Tensor old_log_prob: th.Tensor - hidden_states: HiddenState advantages: th.Tensor returns: th.Tensor + lstm_states: RNNStates + episode_starts: th.Tensor + mask: th.Tensor diff --git a/stable_baselines3/ppo_recurrent/ppo_recurrent.py b/stable_baselines3/ppo_recurrent/ppo_recurrent.py index b44161618..522667b67 100644 --- a/stable_baselines3/ppo_recurrent/ppo_recurrent.py +++ b/stable_baselines3/ppo_recurrent/ppo_recurrent.py @@ -1,29 +1,22 @@ import sys import time -import warnings from copy import deepcopy from typing import Any, ClassVar, Dict, Optional, Type, TypeVar, Union import numpy as np -import optree as ot import torch as th -import torch.nn.functional as F from gymnasium import spaces from stable_baselines3.common.buffers import RolloutBuffer from stable_baselines3.common.callbacks import BaseCallback from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm from stable_baselines3.common.policies import BasePolicy -from stable_baselines3.common.pytree_dataclass import OT_NAMESPACE as NS from stable_baselines3.common.recurrent.buffers import ( + RecurrentDictRolloutBuffer, RecurrentRolloutBuffer, - index_into_pytree, ) -from stable_baselines3.common.recurrent.policies import ( - LSTMStates, - RecurrentActorCriticPolicy, -) -from stable_baselines3.common.recurrent.type_aliases import RecurrentRolloutBufferData +from stable_baselines3.common.recurrent.policies import RecurrentActorCriticPolicy +from stable_baselines3.common.recurrent.type_aliases import RNNStates from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule from stable_baselines3.common.utils import ( explained_variance, @@ -101,7 +94,7 @@ def __init__( env: Union[GymEnv, str], learning_rate: Union[float, Schedule] = 3e-4, n_steps: int = 128, - batch_size: int = 128, + batch_size: Optional[int] = 128, n_epochs: int = 10, gamma: 
float = 0.99, gae_lambda: float = 0.95, @@ -148,31 +141,7 @@ def __init__( spaces.MultiBinary, ), ) - # Sanity check, otherwise it will lead to noisy gradient and NaN - # because of the advantage normalization - if normalize_advantage: - assert ( - batch_size > 1 - ), "`batch_size` must be greater than 1. See https://github.com/DLR-RM/stable-baselines3/issues/440" - - if self.env is not None: - # Check that `n_steps * n_envs > 1` to avoid NaN - # when doing advantage normalization - buffer_size = self.env.num_envs * self.n_steps - assert buffer_size > 1 or ( - not normalize_advantage - ), f"`n_steps * n_envs` must be greater than 1. Currently n_steps={self.n_steps} and n_envs={self.env.num_envs}" - # Check that the rollout buffer size is a multiple of the mini-batch size - untruncated_batches = buffer_size // batch_size - if buffer_size % batch_size > 0: - warnings.warn( - f"You have specified a mini-batch size of {batch_size}," - f" but because the `RolloutBuffer` is of size `n_steps * n_envs = {buffer_size}`," - f" after every {untruncated_batches} untruncated mini-batches," - f" there will be a truncated mini-batch of size {buffer_size % batch_size}\n" - f"We recommend using a `batch_size` that is a factor of `n_steps * n_envs`.\n" - f"Info: (n_steps={self.n_steps} and n_envs={self.env.num_envs})" - ) + self.batch_size = batch_size self.n_epochs = n_epochs self.clip_range = clip_range @@ -188,6 +157,8 @@ def _setup_model(self) -> None: self._setup_lr_schedule() self.set_random_seed(self.seed) + buffer_cls = RecurrentDictRolloutBuffer if isinstance(self.observation_space, spaces.Dict) else RecurrentRolloutBuffer + self.policy = self.policy_class( self.observation_space, self.action_space, @@ -206,7 +177,7 @@ def _setup_model(self) -> None: single_hidden_state_shape = (lstm.num_layers, self.n_envs, lstm.hidden_size) # hidden and cell states for actor and critic - self._last_lstm_states = LSTMStates( + self._last_lstm_states = RNNStates( ( th.zeros(single_hidden_state_shape, device=self.device), th.zeros(single_hidden_state_shape, device=self.device), @@ -217,23 +188,13 @@ def _setup_model(self) -> None: ), ) - single_1envhidden_state_shape = (lstm.num_layers, lstm.hidden_size) - example_lstm_states = LSTMStates( - ( - th.zeros(single_1envhidden_state_shape, device=self.device), - th.zeros(single_1envhidden_state_shape, device=self.device), - ), - ( - th.zeros(single_1envhidden_state_shape, device=self.device), - th.zeros(single_1envhidden_state_shape, device=self.device), - ), - ) + hidden_state_buffer_shape = (self.n_steps, lstm.num_layers, self.n_envs, lstm.hidden_size) - self.rollout_buffer = RecurrentRolloutBuffer( + self.rollout_buffer = buffer_cls( self.n_steps, self.observation_space, self.action_space, - example_lstm_states, + hidden_state_buffer_shape, self.device, gamma=self.gamma, gae_lambda=self.gae_lambda, @@ -268,7 +229,9 @@ def collect_rollouts( :return: True if function returned with at least `n_rollout_steps` collected, False if callback terminated rollout prematurely. 
""" - assert isinstance(rollout_buffer, RecurrentRolloutBuffer), f"{rollout_buffer} doesn't support recurrent policy" + assert isinstance( + rollout_buffer, (RecurrentRolloutBuffer, RecurrentDictRolloutBuffer) + ), f"{rollout_buffer} doesn't support recurrent policy" assert self._last_obs is not None, "No previous observation was provided" # Switch to eval mode (this affects batch norm / dropout) @@ -291,20 +254,17 @@ def collect_rollouts( with th.no_grad(): # Convert to pytorch tensor or to TensorDict - obs_tensor = ot.tree_map(lambda x: x.unsqueeze(0), obs_as_tensor(self._last_obs, self.device), namespace=NS) - episode_starts = th.as_tensor(self._last_episode_starts).to(dtype=th.bool, device=self.device) - actions, values, log_probs, lstm_states = self.policy.forward( - obs_tensor, lstm_states, episode_starts.unsqueeze(0) - ) - actions = actions.squeeze(0) - values = values.squeeze(0) - log_probs = log_probs.squeeze(0) + obs_tensor = obs_as_tensor(self._last_obs, self.device) + episode_starts = th.tensor(self._last_episode_starts, dtype=th.bool, device=self.device) + actions, values, log_probs, lstm_states = self.policy.forward(obs_tensor, lstm_states, episode_starts) + + actions = actions.cpu().numpy() # Rescale and perform action clipped_actions = actions # Clip the actions to avoid out of bound error if isinstance(self.action_space, spaces.Box): - clipped_actions = th.clip(actions, th.as_tensor(self.action_space.low), th.as_tensor(self.action_space.high)) + clipped_actions = np.clip(actions, self.action_space.low, self.action_space.high) new_obs, rewards, dones, infos = env.step(clipped_actions) @@ -330,31 +290,27 @@ def collect_rollouts( and infos[idx].get("terminal_observation") is not None and infos[idx].get("TimeLimit.truncated", False) ): - terminal_obs, _ = self.policy.obs_to_tensor(infos[idx]["terminal_observation"]) + terminal_obs = self.policy.obs_to_tensor(infos[idx]["terminal_observation"])[0] with th.no_grad(): terminal_lstm_state = ( lstm_states.vf[0][:, idx : idx + 1, :].contiguous(), lstm_states.vf[1][:, idx : idx + 1, :].contiguous(), ) # terminal_lstm_state = None - episode_starts = th.zeros((1,), dtype=th.bool, device=self.device) - terminal_value = self.policy.predict_values( - ot.tree_map(lambda x: x.unsqueeze(0), terminal_obs, namespace=NS), - terminal_lstm_state, - episode_starts.unsqueeze(0), - )[0].squeeze(0) + episode_starts = th.tensor([False], dtype=th.bool, device=self.device) + terminal_value = self.policy.predict_values(terminal_obs, terminal_lstm_state, episode_starts)[ + 0 + ].squeeze() rewards[idx] += self.gamma * terminal_value rollout_buffer.add( - RecurrentRolloutBufferData( - self._last_obs, - actions, - rewards, - self._last_episode_starts, - values.squeeze(-1), - log_probs, - hidden_states=self._last_lstm_states, - ) + self._last_obs, + actions, + rewards, + self._last_episode_starts, + values, + log_probs, + lstm_states=self._last_lstm_states, ) self._last_obs = new_obs @@ -363,12 +319,8 @@ def collect_rollouts( with th.no_grad(): # Compute value for the last timestep - episode_starts = th.as_tensor(dones).to(dtype=th.bool, device=self.device) - values = self.policy.predict_values( - ot.tree_map(lambda x: x.unsqueeze(0), obs_as_tensor(new_obs, self.device), namespace=NS), - lstm_states.vf, - episode_starts, - ) + episode_starts = th.tensor(dones, dtype=th.bool, device=self.device) + values = self.policy.predict_values(obs_as_tensor(new_obs, self.device), lstm_states.vf, episode_starts) 
rollout_buffer.compute_returns_and_advantage(last_values=values, dones=dones) @@ -402,6 +354,12 @@ def train(self) -> None: # Do a complete pass on the rollout buffer for rollout_data in self.rollout_buffer.get(self.batch_size): actions = rollout_data.actions + if isinstance(self.action_space, spaces.Discrete): + # Convert discrete action from float to long + actions = rollout_data.actions.long().flatten() + + # Convert mask from float to bool + mask = rollout_data.mask > 1e-8 # Re-sample the noise matrix because the log_std has changed if self.use_sde: @@ -410,15 +368,15 @@ def train(self) -> None: values, log_prob, entropy = self.policy.evaluate_actions( rollout_data.observations, actions, - index_into_pytree(0, rollout_data.hidden_states), + rollout_data.lstm_states, rollout_data.episode_starts, ) + values = values.flatten() # Normalize advantage advantages = rollout_data.advantages - # Normalization does not make sense if mini batchsize == 1, see GH issue #325 - if self.normalize_advantage and len(advantages) > 1: - advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8) + if self.normalize_advantage: + advantages = (advantages - advantages[mask].mean()) / (advantages[mask].std() + 1e-8) # ratio between old and new policy, should be one at the first iteration ratio = th.exp(log_prob - rollout_data.old_log_prob) @@ -426,32 +384,34 @@ def train(self) -> None: # clipped surrogate loss policy_loss_1 = advantages * ratio policy_loss_2 = advantages * th.clamp(ratio, 1 - clip_range, 1 + clip_range) - policy_loss = -th.min(policy_loss_1, policy_loss_2).mean() + policy_loss = -th.mean(th.min(policy_loss_1, policy_loss_2)[mask]) # Logging pg_losses.append(policy_loss.item()) - clip_fraction = th.mean((th.abs(ratio - 1) > clip_range).float()).item() + clip_fraction = th.mean((th.abs(ratio - 1) > clip_range).float()[mask]).item() clip_fractions.append(clip_fraction) if self.clip_range_vf is None: # No clipping values_pred = values else: - # Clip the difference between old and new value + # Clip the different between old and new value # NOTE: this depends on the reward scaling values_pred = rollout_data.old_values + th.clamp( values - rollout_data.old_values, -clip_range_vf, clip_range_vf ) # Value loss using the TD(gae_lambda) target - value_loss = F.mse_loss(rollout_data.returns, values_pred) + # Mask padded sequences + value_loss = th.mean(((rollout_data.returns - values_pred) ** 2)[mask]) + value_losses.append(value_loss.item()) # Entropy loss favor exploration if entropy is None: # Approximate entropy when no analytical form - entropy_loss = -th.mean(-log_prob) + entropy_loss = -th.mean(-log_prob[mask]) else: - entropy_loss = -th.mean(entropy) + entropy_loss = -th.mean(entropy[mask]) entropy_losses.append(entropy_loss.item()) @@ -463,7 +423,7 @@ def train(self) -> None: # and Schulman blog: http://joschu.net/blog/kl-approx.html with th.no_grad(): log_ratio = log_prob - rollout_data.old_log_prob - approx_kl_div = th.mean((th.exp(log_ratio) - 1) - log_ratio).cpu().numpy() + approx_kl_div = th.mean(((th.exp(log_ratio) - 1) - log_ratio)[mask]).cpu().numpy() approx_kl_divs.append(approx_kl_div) if self.target_kl is not None and approx_kl_div > 1.5 * self.target_kl: @@ -492,7 +452,7 @@ def train(self) -> None: self.logger.record("train/approx_kl", np.mean(approx_kl_divs)) self.logger.record("train/clip_fraction", np.mean(clip_fractions)) self.logger.record("train/loss", loss.item()) - self.logger.record("train/explained_variance", explained_var.item()) + 
self.logger.record("train/explained_variance", explained_var) if hasattr(self.policy, "log_std"): self.logger.record("train/std", th.exp(self.policy.log_std).mean().item()) From 44bc3b5cc665fae6b6538987bbb3d79c1dd42de1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adri=C3=A0=20Garriga-Alonso?= Date: Wed, 20 Sep 2023 19:56:17 -0700 Subject: [PATCH 19/31] Fix lack of data --- tests/test_buffers.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/tests/test_buffers.py b/tests/test_buffers.py index dd6cad5f7..2bbbdaca6 100644 --- a/tests/test_buffers.py +++ b/tests/test_buffers.py @@ -14,10 +14,7 @@ from stable_baselines3.common.env_checker import check_env from stable_baselines3.common.env_util import make_vec_env from stable_baselines3.common.pytree_dataclass import OT_NAMESPACE -from stable_baselines3.common.recurrent.buffers import ( - RecurrentRolloutBuffer, - RecurrentRolloutBufferData, -) +from stable_baselines3.common.recurrent.buffers import RecurrentRolloutBuffer from stable_baselines3.common.type_aliases import ( DictReplayBufferSamples, ReplayBufferSamples, @@ -159,7 +156,7 @@ def test_device_buffer(replay_buffer_cls, device): elif replay_buffer_cls == RecurrentRolloutBuffer: episode_start, values, log_prob = th.zeros(1), th.zeros(1), th.ones(1) hidden_states = {"a": {"b": th.zeros(2, buffer.n_envs, 4)}} - buffer.add(RecurrentRolloutBufferData(obs, action, reward, episode_start, values, log_prob, hidden_states)) + buffer.add(obs, action, reward, episode_start, values, log_prob, hidden_states) else: buffer.add(obs, next_obs, action, reward, done, info) obs = next_obs From fd7f9f0df0c37139d04479953fabf70dfb9c27aa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adri=C3=A0=20Garriga-Alonso?= Date: Wed, 20 Sep 2023 20:13:12 -0700 Subject: [PATCH 20/31] Some tweaking --- stable_baselines3/common/recurrent/buffers.py | 51 ++++++++++--------- .../ppo_recurrent/ppo_recurrent.py | 2 - tests/test_identity.py | 2 +- 3 files changed, 27 insertions(+), 28 deletions(-) diff --git a/stable_baselines3/common/recurrent/buffers.py b/stable_baselines3/common/recurrent/buffers.py index 59f885dfe..8d095d9f0 100644 --- a/stable_baselines3/common/recurrent/buffers.py +++ b/stable_baselines3/common/recurrent/buffers.py @@ -2,6 +2,7 @@ from typing import Callable, Generator, Optional, Tuple, Union import numpy as np +import optree as ot import torch as th from gymnasium import spaces @@ -127,19 +128,19 @@ def __init__( def reset(self): super().reset() - self.hidden_states_pi = np.zeros(self.hidden_state_shape, dtype=np.float32) - self.cell_states_pi = np.zeros(self.hidden_state_shape, dtype=np.float32) - self.hidden_states_vf = np.zeros(self.hidden_state_shape, dtype=np.float32) - self.cell_states_vf = np.zeros(self.hidden_state_shape, dtype=np.float32) + self.hidden_states_pi = th.zeros(self.hidden_state_shape, dtype=th.float32) + self.cell_states_pi = th.zeros(self.hidden_state_shape, dtype=th.float32) + self.hidden_states_vf = th.zeros(self.hidden_state_shape, dtype=th.float32) + self.cell_states_vf = th.zeros(self.hidden_state_shape, dtype=th.float32) def add(self, *args, lstm_states: RNNStates, **kwargs) -> None: """ :param hidden_states: LSTM cell and hidden state """ - self.hidden_states_pi[self.pos] = np.array(lstm_states.pi[0].cpu().numpy()) - self.cell_states_pi[self.pos] = np.array(lstm_states.pi[1].cpu().numpy()) - self.hidden_states_vf[self.pos] = np.array(lstm_states.vf[0].cpu().numpy()) - self.cell_states_vf[self.pos] = np.array(lstm_states.vf[1].cpu().numpy()) + 
self.hidden_states_pi[self.pos].copy_(lstm_states.pi[0], non_blocking=True) + self.cell_states_pi[self.pos].copy_(lstm_states.pi[1], non_blocking=True) + self.hidden_states_vf[self.pos].copy_(lstm_states.vf[0], non_blocking=True) + self.cell_states_vf[self.pos].copy_(lstm_states.vf[1], non_blocking=True) super().add(*(th.as_tensor(a) for a in args), **kwargs) @@ -240,12 +241,12 @@ def _get_samples( self.cell_states_vf[batch_inds][self.seq_start_indices].swapaxes(0, 1), ) lstm_states_pi = ( - self.to_device(th.from_numpy(lstm_states_pi[0])).contiguous(), - self.to_device(th.from_numpy(lstm_states_pi[1])).contiguous(), + self.to_device((lstm_states_pi[0])).contiguous(), + self.to_device((lstm_states_pi[1])).contiguous(), ) lstm_states_vf = ( - self.to_device(th.from_numpy(lstm_states_vf[0])).contiguous(), - self.to_device(th.from_numpy(lstm_states_vf[1])).contiguous(), + self.to_device((lstm_states_vf[0])).contiguous(), + self.to_device((lstm_states_vf[1])).contiguous(), ) return RecurrentRolloutBufferSamples( # (batch_size, obs_dim) -> (n_seq, max_length, obs_dim) -> (n_seq * max_length, obs_dim) @@ -294,21 +295,21 @@ def __init__( def reset(self): super().reset() - self.hidden_states_pi = np.zeros(self.hidden_state_shape, dtype=np.float32) - self.cell_states_pi = np.zeros(self.hidden_state_shape, dtype=np.float32) - self.hidden_states_vf = np.zeros(self.hidden_state_shape, dtype=np.float32) - self.cell_states_vf = np.zeros(self.hidden_state_shape, dtype=np.float32) + self.hidden_states_pi = th.zeros(self.hidden_state_shape, dtype=th.float32) + self.cell_states_pi = th.zeros(self.hidden_state_shape, dtype=th.float32) + self.hidden_states_vf = th.zeros(self.hidden_state_shape, dtype=th.float32) + self.cell_states_vf = th.zeros(self.hidden_state_shape, dtype=th.float32) def add(self, *args, lstm_states: RNNStates, **kwargs) -> None: """ :param hidden_states: LSTM cell and hidden state """ - self.hidden_states_pi[self.pos] = np.array(lstm_states.pi[0].cpu().numpy()) - self.cell_states_pi[self.pos] = np.array(lstm_states.pi[1].cpu().numpy()) - self.hidden_states_vf[self.pos] = np.array(lstm_states.vf[0].cpu().numpy()) - self.cell_states_vf[self.pos] = np.array(lstm_states.vf[1].cpu().numpy()) + self.hidden_states_pi[self.pos].copy_(lstm_states.pi[0], non_blocking=True) + self.cell_states_pi[self.pos].copy_(lstm_states.pi[1], non_blocking=True) + self.hidden_states_vf[self.pos].copy_(lstm_states.vf[0], non_blocking=True) + self.cell_states_vf[self.pos].copy_(lstm_states.vf[1], non_blocking=True) - super().add(*args, **kwargs) + super().add(*ot.tree_map(th.as_tensor, args), **kwargs) def get(self, batch_size: Optional[int] = None) -> Generator[RecurrentDictRolloutBufferSamples, None, None]: assert self.full, "Rollout buffer must be full before sampling from it" @@ -386,12 +387,12 @@ def _get_samples( self.cell_states_vf[batch_inds][self.seq_start_indices].swapaxes(0, 1), ) lstm_states_pi = ( - self.to_device(th.from_numpy(lstm_states_pi[0])).contiguous(), - self.to_device(th.from_numpy(lstm_states_pi[1])).contiguous(), + self.to_device((lstm_states_pi[0])).contiguous(), + self.to_device((lstm_states_pi[1])).contiguous(), ) lstm_states_vf = ( - self.to_device(th.from_numpy(lstm_states_vf[0])).contiguous(), - self.to_device(th.from_numpy(lstm_states_vf[1])).contiguous(), + self.to_device((lstm_states_vf[0])).contiguous(), + self.to_device((lstm_states_vf[1])).contiguous(), ) observations = {key: self.pad(obs[batch_inds]) for (key, obs) in self.observations.items()} diff --git 
a/stable_baselines3/ppo_recurrent/ppo_recurrent.py b/stable_baselines3/ppo_recurrent/ppo_recurrent.py index 522667b67..2611438b9 100644 --- a/stable_baselines3/ppo_recurrent/ppo_recurrent.py +++ b/stable_baselines3/ppo_recurrent/ppo_recurrent.py @@ -258,8 +258,6 @@ def collect_rollouts( episode_starts = th.tensor(self._last_episode_starts, dtype=th.bool, device=self.device) actions, values, log_probs, lstm_states = self.policy.forward(obs_tensor, lstm_states, episode_starts) - actions = actions.cpu().numpy() - # Rescale and perform action clipped_actions = actions # Clip the actions to avoid out of bound error diff --git a/tests/test_identity.py b/tests/test_identity.py index aa65182e2..6d795561f 100644 --- a/tests/test_identity.py +++ b/tests/test_identity.py @@ -21,7 +21,7 @@ def test_discrete(model_class, env): env_ = DummyVecEnv([lambda: env]) kwargs = {} - n_steps = 10000 + n_steps = 20000 if model_class == RecurrentPPO else 10000 if model_class == DQN: kwargs = dict(learning_starts=0) # DQN only support discrete actions From 2c8974dadceb8318a289069b4a536a279d5ac3b5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adri=C3=A0=20Garriga-Alonso?= Date: Wed, 20 Sep 2023 20:13:57 -0700 Subject: [PATCH 21/31] No swapaxes overwrite --- stable_baselines3/common/recurrent/buffers.py | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/stable_baselines3/common/recurrent/buffers.py b/stable_baselines3/common/recurrent/buffers.py index 8d095d9f0..cdec36d25 100644 --- a/stable_baselines3/common/recurrent/buffers.py +++ b/stable_baselines3/common/recurrent/buffers.py @@ -144,21 +144,6 @@ def add(self, *args, lstm_states: RNNStates, **kwargs) -> None: super().add(*(th.as_tensor(a) for a in args), **kwargs) - @staticmethod - def swap_and_flatten(arr: np.ndarray) -> np.ndarray: - """ - Swap and then flatten axes 0 (buffer_size) and 1 (n_envs) - to convert shape from [n_steps, n_envs, ...] (when ... is the shape of the features) - to [n_steps * n_envs, ...] 
(which maintain the order) - - :param arr: - :return: - """ - shape = arr.shape - if len(shape) < 3: - shape = (*shape, 1) - return arr.swapaxes(0, 1).reshape(shape[0] * shape[1], *shape[2:]) - def get(self, batch_size: Optional[int] = None) -> Generator[RecurrentRolloutBufferSamples, None, None]: assert self.full, "Rollout buffer must be full before sampling from it" From 2eae72a338f220f76969cc5ef82ec861967368b4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adri=C3=A0=20Garriga-Alonso?= Date: Wed, 20 Sep 2023 20:21:37 -0700 Subject: [PATCH 22/31] Basic torchification of buffers --- stable_baselines3/common/recurrent/buffers.py | 95 +++++++++---------- tests/test_identity.py | 2 +- 2 files changed, 48 insertions(+), 49 deletions(-) diff --git a/stable_baselines3/common/recurrent/buffers.py b/stable_baselines3/common/recurrent/buffers.py index cdec36d25..68dbf1a79 100644 --- a/stable_baselines3/common/recurrent/buffers.py +++ b/stable_baselines3/common/recurrent/buffers.py @@ -16,10 +16,9 @@ def pad( - seq_start_indices: np.ndarray, - seq_end_indices: np.ndarray, - device: th.device, - tensor: np.ndarray, + seq_start_indices: th.Tensor, + seq_end_indices: th.Tensor, + tensor: th.Tensor, padding_value: float = 0.0, ) -> th.Tensor: """ @@ -27,22 +26,20 @@ def pad( :param seq_start_indices: Indices of the transitions that start a sequence :param seq_end_indices: Indices of the transitions that end a sequence - :param device: PyTorch device :param tensor: Tensor of shape (batch_size, *tensor_shape) :param padding_value: Value used to pad sequence to the same length (zero padding by default) :return: (n_seq, max_length, *tensor_shape) """ # Create sequences given start and end - seq = [th.tensor(tensor[start : end + 1], device=device) for start, end in zip(seq_start_indices, seq_end_indices)] + seq = [tensor[start : end + 1] for start, end in zip(seq_start_indices, seq_end_indices)] return th.nn.utils.rnn.pad_sequence(seq, batch_first=True, padding_value=padding_value) def pad_and_flatten( - seq_start_indices: np.ndarray, - seq_end_indices: np.ndarray, - device: th.device, - tensor: np.ndarray, + seq_start_indices: th.Tensor, + seq_end_indices: th.Tensor, + tensor: th.Tensor, padding_value: float = 0.0, ) -> th.Tensor: """ @@ -52,20 +49,18 @@ def pad_and_flatten( :param seq_start_indices: Indices of the transitions that start a sequence :param seq_end_indices: Indices of the transitions that end a sequence - :param device: PyTorch device (cpu, gpu, ...) :param tensor: Tensor of shape (max_length, n_seq, 1) :param padding_value: Value used to pad sequence to the same length (zero padding by default) :return: (n_seq * max_length,) aka (padded_batch_size,) """ - return pad(seq_start_indices, seq_end_indices, device, tensor, padding_value).flatten() + return pad(seq_start_indices, seq_end_indices, tensor, padding_value).flatten() def create_sequencers( - episode_starts: np.ndarray, - env_change: np.ndarray, - device: th.device, -) -> Tuple[np.ndarray, Callable, Callable]: + episode_starts: th.Tensor, + env_change: th.Tensor, +) -> Tuple[th.Tensor, Callable, Callable]: """ Create the utility function to chunk data into sequences and pad them to create fixed size tensors. 
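
With this patch, pad() and pad_and_flatten() consume torch tensors directly instead of round-tripping through NumPy, so minibatch data can stay on the rollout buffer's device. A toy example of the reshaping they perform, with illustrative values and shapes:

import torch as th

# Three transitions with feature dim 2, split into two sequences: [0, 1] and [2]
tensor = th.arange(6, dtype=th.float32).reshape(3, 2)
seq_start_indices = th.tensor([0, 2])
seq_end_indices = th.tensor([1, 2])

# Same logic as pad(): slice out each sequence, then zero-pad to a common length
seqs = [tensor[start : end + 1] for start, end in zip(seq_start_indices, seq_end_indices)]
padded = th.nn.utils.rnn.pad_sequence(seqs, batch_first=True, padding_value=0.0)
assert padded.shape == (2, 2, 2)  # (n_seq, max_length, feature_dim)

# pad_and_flatten() applies the same padding to per-step scalars and flattens the result
values = th.tensor([10.0, 11.0, 12.0])
flat = th.nn.utils.rnn.pad_sequence([values[0:2], values[2:3]], batch_first=True).flatten()
assert flat.shape == (4,)  # (n_seq * max_length,); the padded entry is zero and is masked out in train()
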
@@ -73,25 +68,29 @@ def create_sequencers( :param episode_starts: Indices where an episode starts :param env_change: Indices where the data collected come from a different env (when using multiple env for data collection) - :param device: PyTorch device :return: Indices of the transitions that start a sequence, pad and pad_and_flatten utilities tailored for this batch (sequence starts and ends indices are fixed) """ # Create sequence if env changes too - seq_start = np.logical_or(episode_starts, env_change).flatten() + seq_start = (episode_starts | env_change).flatten() # First index is always the beginning of a sequence seq_start[0] = True # Retrieve indices of sequence starts - seq_start_indices = np.where(seq_start == True)[0] # noqa: E712 + seq_start_indices = th.argwhere(seq_start).squeeze(1) # End of sequence are just before sequence starts # Last index is also always end of a sequence - seq_end_indices = np.concatenate([(seq_start_indices - 1)[1:], np.array([len(episode_starts)])]) + seq_end_indices = th.cat( + [ + (seq_start_indices - 1)[1:], + th.tensor([len(episode_starts)], device=seq_start_indices.device, dtype=seq_start_indices.dtype), + ] + ) # Create padding method for this minibatch # to avoid repeating arguments (seq_start_indices, seq_end_indices) - local_pad = partial(pad, seq_start_indices, seq_end_indices, device) - local_pad_and_flatten = partial(pad_and_flatten, seq_start_indices, seq_end_indices, device) + local_pad = partial(pad, seq_start_indices, seq_end_indices) + local_pad_and_flatten = partial(pad_and_flatten, seq_start_indices, seq_end_indices) return seq_start_indices, local_pad, local_pad_and_flatten @@ -128,10 +127,10 @@ def __init__( def reset(self): super().reset() - self.hidden_states_pi = th.zeros(self.hidden_state_shape, dtype=th.float32) - self.cell_states_pi = th.zeros(self.hidden_state_shape, dtype=th.float32) - self.hidden_states_vf = th.zeros(self.hidden_state_shape, dtype=th.float32) - self.cell_states_vf = th.zeros(self.hidden_state_shape, dtype=th.float32) + self.hidden_states_pi = th.zeros(self.hidden_state_shape, dtype=th.float32, device=self.device) + self.cell_states_pi = th.zeros(self.hidden_state_shape, dtype=th.float32, device=self.device) + self.hidden_states_vf = th.zeros(self.hidden_state_shape, dtype=th.float32, device=self.device) + self.cell_states_vf = th.zeros(self.hidden_state_shape, dtype=th.float32, device=self.device) def add(self, *args, lstm_states: RNNStates, **kwargs) -> None: """ @@ -181,13 +180,13 @@ def get(self, batch_size: Optional[int] = None) -> Generator[RecurrentRolloutBuf # more complexity and use of padding # Trick to shuffle a bit: keep the sequence order # but split the indices in two - split_index = np.random.randint(self.buffer_size * self.n_envs) - indices = np.arange(self.buffer_size * self.n_envs) - indices = np.concatenate((indices[split_index:], indices[:split_index])) + split_index = int(np.random.randint(self.buffer_size * self.n_envs)) + indices = th.arange(self.buffer_size * self.n_envs) + indices = th.cat((indices[split_index:], indices[:split_index])) - env_change = np.zeros(self.buffer_size * self.n_envs).reshape(self.buffer_size, self.n_envs) + env_change = th.zeros((self.buffer_size, self.n_envs), dtype=th.bool) # Flag first timestep as change of environment - env_change[0, :] = 1.0 + env_change[0, :] = True env_change = self.swap_and_flatten(env_change) start_idx = 0 @@ -198,13 +197,13 @@ def get(self, batch_size: Optional[int] = None) -> Generator[RecurrentRolloutBuf def _get_samples( self, 
- batch_inds: np.ndarray, - env_change: np.ndarray, + batch_inds: th.Tensor, + env_change: th.Tensor, env: Optional[VecNormalize] = None, ) -> RecurrentRolloutBufferSamples: # Retrieve sequence starts and utility function self.seq_start_indices, self.pad, self.pad_and_flatten = create_sequencers( - self.episode_starts[batch_inds], env_change[batch_inds], self.device + self.episode_starts[batch_inds], env_change[batch_inds] ) # Number of sequences @@ -243,7 +242,7 @@ def _get_samples( returns=self.pad_and_flatten(self.returns[batch_inds]), lstm_states=RNNStates(lstm_states_pi, lstm_states_vf), episode_starts=self.pad_and_flatten(self.episode_starts[batch_inds]), - mask=self.pad_and_flatten(np.ones_like(self.returns[batch_inds])), + mask=self.pad_and_flatten(th.ones_like(self.returns[batch_inds])), ) @@ -280,10 +279,10 @@ def __init__( def reset(self): super().reset() - self.hidden_states_pi = th.zeros(self.hidden_state_shape, dtype=th.float32) - self.cell_states_pi = th.zeros(self.hidden_state_shape, dtype=th.float32) - self.hidden_states_vf = th.zeros(self.hidden_state_shape, dtype=th.float32) - self.cell_states_vf = th.zeros(self.hidden_state_shape, dtype=th.float32) + self.hidden_states_pi = th.zeros(self.hidden_state_shape, dtype=th.float32, device=self.device) + self.cell_states_pi = th.zeros(self.hidden_state_shape, dtype=th.float32, device=self.device) + self.hidden_states_vf = th.zeros(self.hidden_state_shape, dtype=th.float32, device=self.device) + self.cell_states_vf = th.zeros(self.hidden_state_shape, dtype=th.float32, device=self.device) def add(self, *args, lstm_states: RNNStates, **kwargs) -> None: """ @@ -330,13 +329,13 @@ def get(self, batch_size: Optional[int] = None) -> Generator[RecurrentDictRollou # Trick to shuffle a bit: keep the sequence order # but split the indices in two - split_index = np.random.randint(self.buffer_size * self.n_envs) - indices = np.arange(self.buffer_size * self.n_envs) - indices = np.concatenate((indices[split_index:], indices[:split_index])) + split_index = int(np.random.randint(self.buffer_size * self.n_envs)) + indices = th.arange(self.buffer_size * self.n_envs) + indices = th.cat((indices[split_index:], indices[:split_index])) - env_change = np.zeros(self.buffer_size * self.n_envs).reshape(self.buffer_size, self.n_envs) + env_change = th.zeros((self.buffer_size, self.n_envs), dtype=th.bool) # Flag first timestep as change of environment - env_change[0, :] = 1.0 + env_change[0, :] = True env_change = self.swap_and_flatten(env_change) start_idx = 0 @@ -347,13 +346,13 @@ def get(self, batch_size: Optional[int] = None) -> Generator[RecurrentDictRollou def _get_samples( self, - batch_inds: np.ndarray, - env_change: np.ndarray, + batch_inds: th.Tensor, + env_change: th.Tensor, env: Optional[VecNormalize] = None, ) -> RecurrentDictRolloutBufferSamples: # Retrieve sequence starts and utility function self.seq_start_indices, self.pad, self.pad_and_flatten = create_sequencers( - self.episode_starts[batch_inds], env_change[batch_inds], self.device + self.episode_starts[batch_inds], env_change[batch_inds] ) n_seq = len(self.seq_start_indices) @@ -392,5 +391,5 @@ def _get_samples( returns=self.pad_and_flatten(self.returns[batch_inds]), lstm_states=RNNStates(lstm_states_pi, lstm_states_vf), episode_starts=self.pad_and_flatten(self.episode_starts[batch_inds]), - mask=self.pad_and_flatten(np.ones_like(self.returns[batch_inds])), + mask=self.pad_and_flatten(th.ones_like(self.returns[batch_inds])), ) diff --git a/tests/test_identity.py 
b/tests/test_identity.py index 6d795561f..78d924748 100644 --- a/tests/test_identity.py +++ b/tests/test_identity.py @@ -21,7 +21,7 @@ def test_discrete(model_class, env): env_ = DummyVecEnv([lambda: env]) kwargs = {} - n_steps = 20000 if model_class == RecurrentPPO else 10000 + n_steps = 25000 if model_class == RecurrentPPO else 10000 if model_class == DQN: kwargs = dict(learning_starts=0) # DQN only support discrete actions From 13828a4094783ad6e8031236504f52108bbf8797 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adri=C3=A0=20Garriga-Alonso?= Date: Wed, 20 Sep 2023 20:27:38 -0700 Subject: [PATCH 23/31] Torchify policies and PPORecurrent --- stable_baselines3/common/recurrent/policies.py | 16 +++++++++------- stable_baselines3/ppo_recurrent/ppo_recurrent.py | 6 ++++-- 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/stable_baselines3/common/recurrent/policies.py b/stable_baselines3/common/recurrent/policies.py index e45772190..eab3f7fb9 100644 --- a/stable_baselines3/common/recurrent/policies.py +++ b/stable_baselines3/common/recurrent/policies.py @@ -366,11 +366,11 @@ def _predict( def predict( self, - observation: Union[np.ndarray, Dict[str, np.ndarray]], - state: Optional[Tuple[np.ndarray, ...]] = None, - episode_start: Optional[np.ndarray] = None, + observation: Union[th.Tensor, Dict[str, th.Tensor]], + state: Optional[Tuple[th.Tensor, ...]] = None, + episode_start: Optional[th.Tensor] = None, deterministic: bool = False, - ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]: + ) -> Tuple[th.Tensor, Optional[Tuple[th.Tensor, ...]]]: """ Get the policy action from an observation (and optional hidden state). Includes sugar-coating to handle different observations (e.g. normalizing images). @@ -395,11 +395,11 @@ def predict( # state : (n_layers, n_envs, dim) if state is None: # Initialize hidden states to zeros - state = np.concatenate([np.zeros(self.lstm_hidden_state_shape) for _ in range(n_envs)], axis=1) + state = th.cat([th.zeros(self.lstm_hidden_state_shape) for _ in range(n_envs)], axis=1) state = (state, state) if episode_start is None: - episode_start = np.array([False for _ in range(n_envs)]) + episode_start = th.zeros(n_envs, dtype=th.bool) with th.no_grad(): # Convert to PyTorch tensors @@ -422,7 +422,9 @@ def predict( else: # Actions could be on arbitrary scale, so clip the actions to avoid # out of bound error (e.g. 
if sampling from a Gaussian distribution) - actions = np.clip(actions, self.action_space.low, self.action_space.high) + actions = th.clip( + actions, th.as_tensor(self.action_space.low).to(actions), th.as_tensor(self.action_space.high).to(actions) + ) # Remove batch dimension if needed if not vectorized_env: diff --git a/stable_baselines3/ppo_recurrent/ppo_recurrent.py b/stable_baselines3/ppo_recurrent/ppo_recurrent.py index 2611438b9..4424b6f00 100644 --- a/stable_baselines3/ppo_recurrent/ppo_recurrent.py +++ b/stable_baselines3/ppo_recurrent/ppo_recurrent.py @@ -262,7 +262,9 @@ def collect_rollouts( clipped_actions = actions # Clip the actions to avoid out of bound error if isinstance(self.action_space, spaces.Box): - clipped_actions = np.clip(actions, self.action_space.low, self.action_space.high) + clipped_actions = th.clip( + actions, th.as_tensor(self.action_space.low).to(actions), th.as_tensor(self.action_space.high).to(actions) + ) new_obs, rewards, dones, infos = env.step(clipped_actions) @@ -450,7 +452,7 @@ def train(self) -> None: self.logger.record("train/approx_kl", np.mean(approx_kl_divs)) self.logger.record("train/clip_fraction", np.mean(clip_fractions)) self.logger.record("train/loss", loss.item()) - self.logger.record("train/explained_variance", explained_var) + self.logger.record("train/explained_variance", explained_var.item()) if hasattr(self.policy, "log_std"): self.logger.record("train/std", th.exp(self.policy.log_std).mean().item()) From 18b9f4101dba62f5e73187fd89b1b2a9b649a8fc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adri=C3=A0=20Garriga-Alonso?= Date: Wed, 20 Sep 2023 20:38:47 -0700 Subject: [PATCH 24/31] Don't use numpy in poilcy predict --- stable_baselines3/common/recurrent/policies.py | 14 +++----------- 1 file changed, 3 insertions(+), 11 deletions(-) diff --git a/stable_baselines3/common/recurrent/policies.py b/stable_baselines3/common/recurrent/policies.py index eab3f7fb9..87372a1d9 100644 --- a/stable_baselines3/common/recurrent/policies.py +++ b/stable_baselines3/common/recurrent/policies.py @@ -403,17 +403,9 @@ def predict( with th.no_grad(): # Convert to PyTorch tensors - states = th.tensor(state[0], dtype=th.float32, device=self.device), th.tensor( - state[1], dtype=th.float32, device=self.device + actions, state = self._predict( + observation, lstm_states=state, episode_starts=episode_start, deterministic=deterministic ) - episode_starts = th.tensor(episode_start, dtype=th.bool, device=self.device) - actions, states = self._predict( - observation, lstm_states=states, episode_starts=episode_starts, deterministic=deterministic - ) - states = (states[0].cpu().numpy(), states[1].cpu().numpy()) - - # Convert to numpy - actions = actions.cpu().numpy() if isinstance(self.action_space, spaces.Box): if self.squash_output: @@ -430,7 +422,7 @@ def predict( if not vectorized_env: actions = actions.squeeze(axis=0) - return actions, states + return actions, state class RecurrentActorCriticCnnPolicy(RecurrentActorCriticPolicy): From 43a04a7709c067f9013ee48dd8992c54197c486d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adri=C3=A0=20Garriga-Alonso?= Date: Wed, 20 Sep 2023 20:46:19 -0700 Subject: [PATCH 25/31] The pytree that's actually used --- stable_baselines3/common/pytree_dataclass.py | 47 -------------------- tests/test_buffers.py | 2 +- 2 files changed, 1 insertion(+), 48 deletions(-) diff --git a/stable_baselines3/common/pytree_dataclass.py b/stable_baselines3/common/pytree_dataclass.py index b5303284e..8fbe75fcb 100644 --- 
a/stable_baselines3/common/pytree_dataclass.py +++ b/stable_baselines3/common/pytree_dataclass.py @@ -1,48 +1 @@ -import dataclasses -from typing import Optional, Sequence, Type - -import optree as ot -from typing_extensions import dataclass_transform - -__all__ = ["register_dataclass_as_pytree", "dataclass_frozen_pytree", "tree_empty", "OT_NAMESPACE"] - OT_NAMESPACE = "stable-baselines3" - - -def register_dataclass_as_pytree(Cls, whitelist: Optional[Sequence[str]] = None): - """Register a dataclass as a pytree, using the given whitelist of field names. - - :param Cls: The dataclass to register. - :param whitelist: The names of the fields to include in the pytree. If None, all fields are included. - :return: The dataclass, with the pytree registration applied. This is useful to be able to register a decorator. - """ - - assert dataclasses.is_dataclass(Cls) - - names = tuple(f.name for f in dataclasses.fields(Cls) if whitelist is None or f.name in whitelist) - - def flatten_fn(inst): - return (getattr(inst, n) for n in names), None, names - - def unflatten_fn(context, values): - return Cls(**dict(zip(names, values))) - - ot.register_pytree_node(Cls, flatten_fn, unflatten_fn, namespace=OT_NAMESPACE) - - Cls.__iter__ = lambda self: iter(getattr(self, n) for n in names) - return Cls - - -@dataclass_transform() -def dataclass_frozen_pytree(Cls: Type, **kwargs) -> Type[ot.PyTree]: - """Decorator to make a frozen dataclass and register it as a PyTree.""" - true_kwargs = dict(frozen=True, slots=True) - true_kwargs.update(kwargs) - dataCls = dataclasses.dataclass(**true_kwargs)(Cls) - register_dataclass_as_pytree(dataCls) - return dataCls - - -def tree_empty(tree: ot.PyTree) -> bool: - flattened_state, _ = ot.tree_flatten(tree, namespace=OT_NAMESPACE) - return not bool(len(flattened_state)) diff --git a/tests/test_buffers.py b/tests/test_buffers.py index 2bbbdaca6..1b01630cc 100644 --- a/tests/test_buffers.py +++ b/tests/test_buffers.py @@ -13,7 +13,7 @@ ) from stable_baselines3.common.env_checker import check_env from stable_baselines3.common.env_util import make_vec_env -from stable_baselines3.common.pytree_dataclass import OT_NAMESPACE +from stable_baselines3.common.pytree_dataclass import OT_NAMESPACE as NS from stable_baselines3.common.recurrent.buffers import RecurrentRolloutBuffer from stable_baselines3.common.type_aliases import ( DictReplayBufferSamples, From d89d269752da7868aff98c61e32b7b7a1d29b9fe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adri=C3=A0=20Garriga-Alonso?= Date: Wed, 20 Sep 2023 20:52:41 -0700 Subject: [PATCH 26/31] Make the silly test pass --- tests/test_buffers.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/tests/test_buffers.py b/tests/test_buffers.py index 1b01630cc..37b2c7a0f 100644 --- a/tests/test_buffers.py +++ b/tests/test_buffers.py @@ -13,8 +13,12 @@ ) from stable_baselines3.common.env_checker import check_env from stable_baselines3.common.env_util import make_vec_env -from stable_baselines3.common.pytree_dataclass import OT_NAMESPACE as NS -from stable_baselines3.common.recurrent.buffers import RecurrentRolloutBuffer +from stable_baselines3.common.pytree_dataclass import OT_NAMESPACE +from stable_baselines3.common.recurrent.buffers import ( + RecurrentDictRolloutBuffer, + RecurrentRolloutBuffer, +) +from stable_baselines3.common.recurrent.type_aliases import RNNStates from stable_baselines3.common.type_aliases import ( DictReplayBufferSamples, ReplayBufferSamples, @@ -132,14 +136,14 @@ def 
test_device_buffer(replay_buffer_cls, device): DictRolloutBuffer: DummyDictEnv, ReplayBuffer: DummyEnv, DictReplayBuffer: DummyDictEnv, - RecurrentRolloutBuffer: DummyDictEnv, + RecurrentRolloutBuffer: DummyEnv, + RecurrentDictRolloutBuffer: DummyDictEnv, }[replay_buffer_cls] env = make_vec_env(env) if replay_buffer_cls == RecurrentRolloutBuffer: - hidden_states = {"a": {"b": th.zeros(2, 4)}} buffer = RecurrentRolloutBuffer( - 100, env.observation_space, env.action_space, hidden_state_example=hidden_states, device=device + 100, env.observation_space, env.action_space, hidden_state_shape=(100, 1, env.num_envs, 4), device=device ) else: buffer = replay_buffer_cls(100, env.observation_space, env.action_space, device=device) @@ -155,8 +159,9 @@ def test_device_buffer(replay_buffer_cls, device): buffer.add(obs, action, reward, episode_start, values, log_prob) elif replay_buffer_cls == RecurrentRolloutBuffer: episode_start, values, log_prob = th.zeros(1), th.zeros(1), th.ones(1) - hidden_states = {"a": {"b": th.zeros(2, buffer.n_envs, 4)}} - buffer.add(obs, action, reward, episode_start, values, log_prob, hidden_states) + one_lstm_states = (th.zeros((1, env.num_envs, 4)), th.zeros((1, env.num_envs, 4))) + hidden_states = RNNStates(one_lstm_states, one_lstm_states) + buffer.add(obs, action, reward, episode_start, values, log_prob, lstm_states=hidden_states) else: buffer.add(obs, next_obs, action, reward, done, info) obs = next_obs From a6f9ed3b73297657b68e658065b5485ca4a2a83c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adri=C3=A0=20Garriga-Alonso?= Date: Thu, 21 Sep 2023 14:57:25 -0700 Subject: [PATCH 27/31] Are tests actually faster? --- .circleci/config.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index 2c5dae597..70f122604 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -71,7 +71,7 @@ jobs: password: "$GHCR_DOCKER_TOKEN" resource_class: medium working_directory: /workspace/third_party/stable-baselines3 - parallelism: 24 + parallelism: 16 steps: - checkout - run: From dc8932aed964520bca47835e152adc4b872e82fd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adri=C3=A0=20Garriga-Alonso?= Date: Thu, 21 Sep 2023 15:08:48 -0700 Subject: [PATCH 28/31] next_is_non_terminal --- stable_baselines3/common/buffers.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/stable_baselines3/common/buffers.py b/stable_baselines3/common/buffers.py index 188717a89..a23156106 100644 --- a/stable_baselines3/common/buffers.py +++ b/stable_baselines3/common/buffers.py @@ -454,13 +454,13 @@ def compute_returns_and_advantage(self, last_values: th.Tensor, dones: th.Tensor last_gae_lam: Union[float, th.Tensor] = 0.0 for step in reversed(range(self.buffer_size)): if step == self.buffer_size - 1: - next_non_terminal = ~dones + next_is_non_terminal = ~dones next_values = last_values else: - next_non_terminal = ~self.episode_starts[step + 1] + next_is_non_terminal = ~self.episode_starts[step + 1] next_values = self.values[step + 1] - delta = self.rewards[step] + self.gamma * next_values * next_non_terminal - self.values[step] - last_gae_lam = delta + self.gamma * self.gae_lambda * next_non_terminal * last_gae_lam + delta = self.rewards[step] + self.gamma * next_values * next_is_non_terminal - self.values[step] + last_gae_lam = delta + self.gamma * self.gae_lambda * next_is_non_terminal * last_gae_lam self.advantages[step] = last_gae_lam # TD(lambda) estimator, see Github PR #375 or "Telescoping in TD(lambda)" # in David Silver 
Lecture 4: https://www.youtube.com/watch?v=PnHCvfgC_ZA From 7d46e47cf82b614de0907debef8114cfac2dc9b1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adri=C3=A0=20Garriga-Alonso?= Date: Thu, 21 Sep 2023 15:08:59 -0700 Subject: [PATCH 29/31] use sb3_namespace by default --- stable_baselines3/common/pytree_dataclass.py | 22 +++++++++++++++++++- tests/test_buffers.py | 6 ++++-- 2 files changed, 25 insertions(+), 3 deletions(-) diff --git a/stable_baselines3/common/pytree_dataclass.py b/stable_baselines3/common/pytree_dataclass.py index 8fbe75fcb..a9b75b87d 100644 --- a/stable_baselines3/common/pytree_dataclass.py +++ b/stable_baselines3/common/pytree_dataclass.py @@ -1 +1,21 @@ -OT_NAMESPACE = "stable-baselines3" +from typing import Callable, TypeVar + +import optree as ot +from optree import PyTree as PyTree + +__all__ = ["tree_flatten", "PyTree"] + +T = TypeVar("T") + +SB3_NAMESPACE = "stable-baselines3" + + +def tree_flatten( + tree: ot.PyTree[T], + is_leaf: Callable[[T], bool] | None = None, + *, + none_is_leaf: bool = False, + namespace: str = SB3_NAMESPACE +) -> tuple[list[T], ot.PyTreeSpec]: + """optree.tree_flatten(...) but the default namespace is SB3_NAMESPACE""" + return ot.tree_flatten(tree, is_leaf, none_is_leaf=none_is_leaf, namespace=namespace) diff --git a/tests/test_buffers.py b/tests/test_buffers.py index 37b2c7a0f..12fed2b3e 100644 --- a/tests/test_buffers.py +++ b/tests/test_buffers.py @@ -4,6 +4,9 @@ import pytest import torch as th from gymnasium import spaces +from third_party.stable_baselines3.stable_baselines3.common.pytree_dataclass import ( + tree_flatten, +) from stable_baselines3.common.buffers import ( DictReplayBuffer, @@ -13,7 +16,6 @@ ) from stable_baselines3.common.env_checker import check_env from stable_baselines3.common.env_util import make_vec_env -from stable_baselines3.common.pytree_dataclass import OT_NAMESPACE from stable_baselines3.common.recurrent.buffers import ( RecurrentDictRolloutBuffer, RecurrentRolloutBuffer, @@ -175,7 +177,7 @@ def test_device_buffer(replay_buffer_cls, device): # Check that all data are on the desired device desired_device = get_device(device).type for minibatch in list(data): - flattened_tensors, _ = ot.tree_flatten(minibatch, namespace=OT_NAMESPACE) + flattened_tensors, _ = tree_flatten(minibatch) assert len(flattened_tensors) > 3 for value in flattened_tensors: assert isinstance(value, th.Tensor) From 30ccd416c771359f2b8f0b4d0a76aa254e1997d6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adri=C3=A0=20Garriga-Alonso?= Date: Thu, 21 Sep 2023 15:51:48 -0700 Subject: [PATCH 30/31] correct importing --- tests/test_buffers.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/test_buffers.py b/tests/test_buffers.py index 12fed2b3e..918988ffb 100644 --- a/tests/test_buffers.py +++ b/tests/test_buffers.py @@ -4,9 +4,6 @@ import pytest import torch as th from gymnasium import spaces -from third_party.stable_baselines3.stable_baselines3.common.pytree_dataclass import ( - tree_flatten, -) from stable_baselines3.common.buffers import ( DictReplayBuffer, @@ -16,6 +13,7 @@ ) from stable_baselines3.common.env_checker import check_env from stable_baselines3.common.env_util import make_vec_env +from stable_baselines3.common.pytree_dataclass import tree_flatten from stable_baselines3.common.recurrent.buffers import ( RecurrentDictRolloutBuffer, RecurrentRolloutBuffer, From fdc4370d5289f8cd59024fb9a5599057cbf07f9e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adri=C3=A0=20Garriga-Alonso?= Date: Fri, 6 Oct 2023 20:27:20 -0400 Subject: 
[PATCH 31/31] Make 100 not a magic number --- tests/test_buffers.py | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/tests/test_buffers.py b/tests/test_buffers.py index 918988ffb..e69328bf1 100644 --- a/tests/test_buffers.py +++ b/tests/test_buffers.py @@ -26,6 +26,8 @@ from stable_baselines3.common.utils import get_device from stable_baselines3.common.vec_env import VecNormalize +EP_LENGTH: int = 100 + class DummyEnv(gym.Env): """ @@ -38,7 +40,7 @@ def __init__(self): self._observations = np.array([[1.0], [2.0], [3.0], [4.0], [5.0]], dtype=np.float32) self._rewards = [1, 2, 3, 4, 5] self._t = 0 - self._ep_length = 100 + self._ep_length = EP_LENGTH def reset(self, *, seed=None, options=None): self._t = 0 @@ -68,7 +70,7 @@ def __init__(self): self._observations = np.array([[1.0], [2.0], [3.0], [4.0], [5.0]], dtype=np.float32) self._rewards = [1, 2, 3, 4, 5] self._t = 0 - self._ep_length = 100 + self._ep_length = EP_LENGTH def reset(self, seed=None, options=None): self._t = 0 @@ -98,12 +100,12 @@ def test_replay_buffer_normalization(replay_buffer_cls): env = make_vec_env(env) env = VecNormalize(env) - buffer = replay_buffer_cls(100, env.observation_space, env.action_space, device="cpu") + buffer = replay_buffer_cls(EP_LENGTH, env.observation_space, env.action_space, device="cpu") # Interract and store transitions env.reset() obs = env.get_original_obs() - for _ in range(100): + for _ in range(EP_LENGTH): action = th.as_tensor(env.action_space.sample()) _, _, done, info = env.step(action) next_obs = env.get_original_obs() @@ -143,14 +145,18 @@ def test_device_buffer(replay_buffer_cls, device): if replay_buffer_cls == RecurrentRolloutBuffer: buffer = RecurrentRolloutBuffer( - 100, env.observation_space, env.action_space, hidden_state_shape=(100, 1, env.num_envs, 4), device=device + EP_LENGTH, + env.observation_space, + env.action_space, + hidden_state_shape=(EP_LENGTH, 1, env.num_envs, 4), + device=device, ) else: - buffer = replay_buffer_cls(100, env.observation_space, env.action_space, device=device) + buffer = replay_buffer_cls(EP_LENGTH, env.observation_space, env.action_space, device=device) # Interract and store transitions obs = env.reset() - for _ in range(100): + for _ in range(EP_LENGTH): action = th.as_tensor(env.action_space.sample()) next_obs, reward, done, info = env.step(action)
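
The RecurrentPPO.train() changes earlier in this series take every loss term as a mean over real transitions only, so timesteps introduced by sequence padding cannot bias the gradients. A condensed sketch of that masking, assuming mask is the boolean tensor obtained from rollout_data.mask > 1e-8; the function name is illustrative:

import torch as th

def masked_ppo_losses(advantages, ratio, returns, values_pred, log_prob, mask, clip_range):
    """PPO loss terms averaged only over non-padded timesteps (mask == True)."""
    # Normalize advantages using statistics of the real transitions only
    # (applied when normalize_advantage is set)
    advantages = (advantages - advantages[mask].mean()) / (advantages[mask].std() + 1e-8)
    # Clipped surrogate objective, masked before taking the mean
    policy_loss_1 = advantages * ratio
    policy_loss_2 = advantages * th.clamp(ratio, 1.0 - clip_range, 1.0 + clip_range)
    policy_loss = -th.mean(th.min(policy_loss_1, policy_loss_2)[mask])
    # Value loss: masked mean-squared error against the TD(gae_lambda) returns
    value_loss = th.mean(((returns - values_pred) ** 2)[mask])
    # Entropy approximated from -log_prob when no analytical entropy is available
    entropy_loss = -th.mean(-log_prob[mask])
    return policy_loss, value_loss, entropy_loss

As in standard PPO, these terms are then combined as policy_loss + ent_coef * entropy_loss + vf_coef * value_loss before the gradient step.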