Skip to content

Commit

Permalink
Merge pull request #5 from AlignmentResearch/start-from-numpy
Browse files Browse the repository at this point in the history
Port in torchified PPO from sb3_contrib
  • Loading branch information
rhaps0dy authored Oct 7, 2023
2 parents 8ccaa72 + fdc4370 commit fc9b730
Show file tree
Hide file tree
Showing 16 changed files with 287 additions and 178 deletions.
2 changes: 2 additions & 0 deletions stable_baselines3/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from stable_baselines3.dqn import DQN
from stable_baselines3.her.her_replay_buffer import HerReplayBuffer
from stable_baselines3.ppo import PPO
from stable_baselines3.ppo_recurrent import RecurrentPPO
from stable_baselines3.sac import SAC
from stable_baselines3.td3 import TD3

Expand All @@ -27,6 +28,7 @@ def HER(*args, **kwargs):
"DDPG",
"DQN",
"PPO",
"RecurrentPPO",
"SAC",
"TD3",
"HerReplayBuffer",
Expand Down
12 changes: 6 additions & 6 deletions stable_baselines3/common/buffers.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,7 +422,7 @@ def reset(self) -> None:
self.actions = th.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=th.float32, device=self.device)
self.rewards = th.zeros((self.buffer_size, self.n_envs), dtype=th.float32, device=self.device)
self.returns = th.zeros((self.buffer_size, self.n_envs), dtype=th.float32, device=self.device)
self.episode_starts = th.zeros((self.buffer_size, self.n_envs), dtype=th.float32, device=self.device)
self.episode_starts = th.zeros((self.buffer_size, self.n_envs), dtype=th.bool, device=self.device)
self.values = th.zeros((self.buffer_size, self.n_envs), dtype=th.float32, device=self.device)
self.log_probs = th.zeros((self.buffer_size, self.n_envs), dtype=th.float32, device=self.device)
self.advantages = th.zeros((self.buffer_size, self.n_envs), dtype=th.float32, device=self.device)
Expand Down Expand Up @@ -454,13 +454,13 @@ def compute_returns_and_advantage(self, last_values: th.Tensor, dones: th.Tensor
last_gae_lam: Union[float, th.Tensor] = 0.0
for step in reversed(range(self.buffer_size)):
if step == self.buffer_size - 1:
next_non_terminal = ~dones
next_is_non_terminal = ~dones
next_values = last_values
else:
next_non_terminal = 1.0 - self.episode_starts[step + 1]
next_is_non_terminal = ~self.episode_starts[step + 1]
next_values = self.values[step + 1]
delta = self.rewards[step] + self.gamma * next_values * next_non_terminal - self.values[step]
last_gae_lam = delta + self.gamma * self.gae_lambda * next_non_terminal * last_gae_lam
delta = self.rewards[step] + self.gamma * next_values * next_is_non_terminal - self.values[step]
last_gae_lam = delta + self.gamma * self.gae_lambda * next_is_non_terminal * last_gae_lam
self.advantages[step] = last_gae_lam
# TD(lambda) estimator, see Github PR #375 or "Telescoping in TD(lambda)"
# in David Silver Lecture 4: https://www.youtube.com/watch?v=PnHCvfgC_ZA
Expand Down Expand Up @@ -791,7 +791,7 @@ def __init__(
self.actions = th.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=th.float32)
self.rewards = th.zeros((self.buffer_size, self.n_envs), dtype=th.float32, device=self.device)
self.returns = th.zeros((self.buffer_size, self.n_envs), dtype=th.float32, device=self.device)
self.episode_starts = th.zeros((self.buffer_size, self.n_envs), dtype=th.float32, device=self.device)
self.episode_starts = th.zeros((self.buffer_size, self.n_envs), dtype=th.bool, device=self.device)
self.values = th.zeros((self.buffer_size, self.n_envs), dtype=th.float32, device=self.device)
self.log_probs = th.zeros((self.buffer_size, self.n_envs), dtype=th.float32, device=self.device)
self.advantages = th.zeros((self.buffer_size, self.n_envs), dtype=th.float32, device=self.device)
Expand Down
21 changes: 21 additions & 0 deletions stable_baselines3/common/pytree_dataclass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from typing import Callable, TypeVar

import optree as ot
from optree import PyTree as PyTree

__all__ = ["tree_flatten", "PyTree"]

T = TypeVar("T")

SB3_NAMESPACE = "stable-baselines3"


def tree_flatten(
tree: ot.PyTree[T],
is_leaf: Callable[[T], bool] | None = None,
*,
none_is_leaf: bool = False,
namespace: str = SB3_NAMESPACE
) -> tuple[list[T], ot.PyTreeSpec]:
"""optree.tree_flatten(...) but the default namespace is SB3_NAMESPACE"""
return ot.tree_flatten(tree, is_leaf, none_is_leaf=none_is_leaf, namespace=namespace)
143 changes: 77 additions & 66 deletions stable_baselines3/common/recurrent/buffers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,46 +2,44 @@
from typing import Callable, Generator, Optional, Tuple, Union

import numpy as np
import optree as ot
import torch as th
from gymnasium import spaces
from stable_baselines3.common.buffers import DictRolloutBuffer, RolloutBuffer
from stable_baselines3.common.vec_env import VecNormalize

from sb3_contrib.common.recurrent.type_aliases import (
from stable_baselines3.common.buffers import DictRolloutBuffer, RolloutBuffer
from stable_baselines3.common.recurrent.type_aliases import (
RecurrentDictRolloutBufferSamples,
RecurrentRolloutBufferSamples,
RNNStates,
)
from stable_baselines3.common.vec_env import VecNormalize


def pad(
seq_start_indices: np.ndarray,
seq_end_indices: np.ndarray,
device: th.device,
tensor: np.ndarray,
seq_start_indices: th.Tensor,
seq_end_indices: th.Tensor,
tensor: th.Tensor,
padding_value: float = 0.0,
) -> th.Tensor:
"""
Chunk sequences and pad them to have constant dimensions.
:param seq_start_indices: Indices of the transitions that start a sequence
:param seq_end_indices: Indices of the transitions that end a sequence
:param device: PyTorch device
:param tensor: Tensor of shape (batch_size, *tensor_shape)
:param padding_value: Value used to pad sequence to the same length
(zero padding by default)
:return: (n_seq, max_length, *tensor_shape)
"""
# Create sequences given start and end
seq = [th.tensor(tensor[start : end + 1], device=device) for start, end in zip(seq_start_indices, seq_end_indices)]
seq = [tensor[start : end + 1] for start, end in zip(seq_start_indices, seq_end_indices)]
return th.nn.utils.rnn.pad_sequence(seq, batch_first=True, padding_value=padding_value)


def pad_and_flatten(
seq_start_indices: np.ndarray,
seq_end_indices: np.ndarray,
device: th.device,
tensor: np.ndarray,
seq_start_indices: th.Tensor,
seq_end_indices: th.Tensor,
tensor: th.Tensor,
padding_value: float = 0.0,
) -> th.Tensor:
"""
Expand All @@ -51,46 +49,48 @@ def pad_and_flatten(
:param seq_start_indices: Indices of the transitions that start a sequence
:param seq_end_indices: Indices of the transitions that end a sequence
:param device: PyTorch device (cpu, gpu, ...)
:param tensor: Tensor of shape (max_length, n_seq, 1)
:param padding_value: Value used to pad sequence to the same length
(zero padding by default)
:return: (n_seq * max_length,) aka (padded_batch_size,)
"""
return pad(seq_start_indices, seq_end_indices, device, tensor, padding_value).flatten()
return pad(seq_start_indices, seq_end_indices, tensor, padding_value).flatten()


def create_sequencers(
episode_starts: np.ndarray,
env_change: np.ndarray,
device: th.device,
) -> Tuple[np.ndarray, Callable, Callable]:
episode_starts: th.Tensor,
env_change: th.Tensor,
) -> Tuple[th.Tensor, Callable, Callable]:
"""
Create the utility function to chunk data into
sequences and pad them to create fixed size tensors.
:param episode_starts: Indices where an episode starts
:param env_change: Indices where the data collected
come from a different env (when using multiple env for data collection)
:param device: PyTorch device
:return: Indices of the transitions that start a sequence,
pad and pad_and_flatten utilities tailored for this batch
(sequence starts and ends indices are fixed)
"""
# Create sequence if env changes too
seq_start = np.logical_or(episode_starts, env_change).flatten()
seq_start = (episode_starts | env_change).flatten()
# First index is always the beginning of a sequence
seq_start[0] = True
# Retrieve indices of sequence starts
seq_start_indices = np.where(seq_start == True)[0] # noqa: E712
seq_start_indices = th.argwhere(seq_start).squeeze(1)
# End of sequence are just before sequence starts
# Last index is also always end of a sequence
seq_end_indices = np.concatenate([(seq_start_indices - 1)[1:], np.array([len(episode_starts)])])
seq_end_indices = th.cat(
[
(seq_start_indices - 1)[1:],
th.tensor([len(episode_starts)], device=seq_start_indices.device, dtype=seq_start_indices.dtype),
]
)

# Create padding method for this minibatch
# to avoid repeating arguments (seq_start_indices, seq_end_indices)
local_pad = partial(pad, seq_start_indices, seq_end_indices, device)
local_pad_and_flatten = partial(pad_and_flatten, seq_start_indices, seq_end_indices, device)
local_pad = partial(pad, seq_start_indices, seq_end_indices)
local_pad_and_flatten = partial(pad_and_flatten, seq_start_indices, seq_end_indices)
return seq_start_indices, local_pad, local_pad_and_flatten


Expand Down Expand Up @@ -127,21 +127,21 @@ def __init__(

def reset(self):
super().reset()
self.hidden_states_pi = np.zeros(self.hidden_state_shape, dtype=np.float32)
self.cell_states_pi = np.zeros(self.hidden_state_shape, dtype=np.float32)
self.hidden_states_vf = np.zeros(self.hidden_state_shape, dtype=np.float32)
self.cell_states_vf = np.zeros(self.hidden_state_shape, dtype=np.float32)
self.hidden_states_pi = th.zeros(self.hidden_state_shape, dtype=th.float32, device=self.device)
self.cell_states_pi = th.zeros(self.hidden_state_shape, dtype=th.float32, device=self.device)
self.hidden_states_vf = th.zeros(self.hidden_state_shape, dtype=th.float32, device=self.device)
self.cell_states_vf = th.zeros(self.hidden_state_shape, dtype=th.float32, device=self.device)

def add(self, *args, lstm_states: RNNStates, **kwargs) -> None:
"""
:param hidden_states: LSTM cell and hidden state
"""
self.hidden_states_pi[self.pos] = np.array(lstm_states.pi[0].cpu().numpy())
self.cell_states_pi[self.pos] = np.array(lstm_states.pi[1].cpu().numpy())
self.hidden_states_vf[self.pos] = np.array(lstm_states.vf[0].cpu().numpy())
self.cell_states_vf[self.pos] = np.array(lstm_states.vf[1].cpu().numpy())
self.hidden_states_pi[self.pos].copy_(lstm_states.pi[0], non_blocking=True)
self.cell_states_pi[self.pos].copy_(lstm_states.pi[1], non_blocking=True)
self.hidden_states_vf[self.pos].copy_(lstm_states.vf[0], non_blocking=True)
self.cell_states_vf[self.pos].copy_(lstm_states.vf[1], non_blocking=True)

super().add(*args, **kwargs)
super().add(*(th.as_tensor(a) for a in args), **kwargs)

def get(self, batch_size: Optional[int] = None) -> Generator[RecurrentRolloutBufferSamples, None, None]:
assert self.full, "Rollout buffer must be full before sampling from it"
Expand Down Expand Up @@ -180,13 +180,13 @@ def get(self, batch_size: Optional[int] = None) -> Generator[RecurrentRolloutBuf
# more complexity and use of padding
# Trick to shuffle a bit: keep the sequence order
# but split the indices in two
split_index = np.random.randint(self.buffer_size * self.n_envs)
indices = np.arange(self.buffer_size * self.n_envs)
indices = np.concatenate((indices[split_index:], indices[:split_index]))
split_index = int(np.random.randint(self.buffer_size * self.n_envs))
indices = th.arange(self.buffer_size * self.n_envs)
indices = th.cat((indices[split_index:], indices[:split_index]))

env_change = np.zeros(self.buffer_size * self.n_envs).reshape(self.buffer_size, self.n_envs)
env_change = th.zeros((self.buffer_size, self.n_envs), dtype=th.bool)
# Flag first timestep as change of environment
env_change[0, :] = 1.0
env_change[0, :] = True
env_change = self.swap_and_flatten(env_change)

start_idx = 0
Expand All @@ -197,13 +197,13 @@ def get(self, batch_size: Optional[int] = None) -> Generator[RecurrentRolloutBuf

def _get_samples(
self,
batch_inds: np.ndarray,
env_change: np.ndarray,
batch_inds: th.Tensor,
env_change: th.Tensor,
env: Optional[VecNormalize] = None,
) -> RecurrentRolloutBufferSamples:
# Retrieve sequence starts and utility function
self.seq_start_indices, self.pad, self.pad_and_flatten = create_sequencers(
self.episode_starts[batch_inds], env_change[batch_inds], self.device
self.episode_starts[batch_inds], env_change[batch_inds]
)

# Number of sequences
Expand All @@ -224,9 +224,14 @@ def _get_samples(
self.hidden_states_vf[batch_inds][self.seq_start_indices].swapaxes(0, 1),
self.cell_states_vf[batch_inds][self.seq_start_indices].swapaxes(0, 1),
)
lstm_states_pi = (self.to_torch(lstm_states_pi[0]).contiguous(), self.to_torch(lstm_states_pi[1]).contiguous())
lstm_states_vf = (self.to_torch(lstm_states_vf[0]).contiguous(), self.to_torch(lstm_states_vf[1]).contiguous())

lstm_states_pi = (
self.to_device((lstm_states_pi[0])).contiguous(),
self.to_device((lstm_states_pi[1])).contiguous(),
)
lstm_states_vf = (
self.to_device((lstm_states_vf[0])).contiguous(),
self.to_device((lstm_states_vf[1])).contiguous(),
)
return RecurrentRolloutBufferSamples(
# (batch_size, obs_dim) -> (n_seq, max_length, obs_dim) -> (n_seq * max_length, obs_dim)
observations=self.pad(self.observations[batch_inds]).reshape((padded_batch_size, *self.obs_shape)),
Expand All @@ -237,7 +242,7 @@ def _get_samples(
returns=self.pad_and_flatten(self.returns[batch_inds]),
lstm_states=RNNStates(lstm_states_pi, lstm_states_vf),
episode_starts=self.pad_and_flatten(self.episode_starts[batch_inds]),
mask=self.pad_and_flatten(np.ones_like(self.returns[batch_inds])),
mask=self.pad_and_flatten(th.ones_like(self.returns[batch_inds])),
)


Expand Down Expand Up @@ -274,21 +279,21 @@ def __init__(

def reset(self):
super().reset()
self.hidden_states_pi = np.zeros(self.hidden_state_shape, dtype=np.float32)
self.cell_states_pi = np.zeros(self.hidden_state_shape, dtype=np.float32)
self.hidden_states_vf = np.zeros(self.hidden_state_shape, dtype=np.float32)
self.cell_states_vf = np.zeros(self.hidden_state_shape, dtype=np.float32)
self.hidden_states_pi = th.zeros(self.hidden_state_shape, dtype=th.float32, device=self.device)
self.cell_states_pi = th.zeros(self.hidden_state_shape, dtype=th.float32, device=self.device)
self.hidden_states_vf = th.zeros(self.hidden_state_shape, dtype=th.float32, device=self.device)
self.cell_states_vf = th.zeros(self.hidden_state_shape, dtype=th.float32, device=self.device)

def add(self, *args, lstm_states: RNNStates, **kwargs) -> None:
"""
:param hidden_states: LSTM cell and hidden state
"""
self.hidden_states_pi[self.pos] = np.array(lstm_states.pi[0].cpu().numpy())
self.cell_states_pi[self.pos] = np.array(lstm_states.pi[1].cpu().numpy())
self.hidden_states_vf[self.pos] = np.array(lstm_states.vf[0].cpu().numpy())
self.cell_states_vf[self.pos] = np.array(lstm_states.vf[1].cpu().numpy())
self.hidden_states_pi[self.pos].copy_(lstm_states.pi[0], non_blocking=True)
self.cell_states_pi[self.pos].copy_(lstm_states.pi[1], non_blocking=True)
self.hidden_states_vf[self.pos].copy_(lstm_states.vf[0], non_blocking=True)
self.cell_states_vf[self.pos].copy_(lstm_states.vf[1], non_blocking=True)

super().add(*args, **kwargs)
super().add(*ot.tree_map(th.as_tensor, args), **kwargs)

def get(self, batch_size: Optional[int] = None) -> Generator[RecurrentDictRolloutBufferSamples, None, None]:
assert self.full, "Rollout buffer must be full before sampling from it"
Expand Down Expand Up @@ -324,13 +329,13 @@ def get(self, batch_size: Optional[int] = None) -> Generator[RecurrentDictRollou

# Trick to shuffle a bit: keep the sequence order
# but split the indices in two
split_index = np.random.randint(self.buffer_size * self.n_envs)
indices = np.arange(self.buffer_size * self.n_envs)
indices = np.concatenate((indices[split_index:], indices[:split_index]))
split_index = int(np.random.randint(self.buffer_size * self.n_envs))
indices = th.arange(self.buffer_size * self.n_envs)
indices = th.cat((indices[split_index:], indices[:split_index]))

env_change = np.zeros(self.buffer_size * self.n_envs).reshape(self.buffer_size, self.n_envs)
env_change = th.zeros((self.buffer_size, self.n_envs), dtype=th.bool)
# Flag first timestep as change of environment
env_change[0, :] = 1.0
env_change[0, :] = True
env_change = self.swap_and_flatten(env_change)

start_idx = 0
Expand All @@ -341,13 +346,13 @@ def get(self, batch_size: Optional[int] = None) -> Generator[RecurrentDictRollou

def _get_samples(
self,
batch_inds: np.ndarray,
env_change: np.ndarray,
batch_inds: th.Tensor,
env_change: th.Tensor,
env: Optional[VecNormalize] = None,
) -> RecurrentDictRolloutBufferSamples:
# Retrieve sequence starts and utility function
self.seq_start_indices, self.pad, self.pad_and_flatten = create_sequencers(
self.episode_starts[batch_inds], env_change[batch_inds], self.device
self.episode_starts[batch_inds], env_change[batch_inds]
)

n_seq = len(self.seq_start_indices)
Expand All @@ -365,8 +370,14 @@ def _get_samples(
self.hidden_states_vf[batch_inds][self.seq_start_indices].swapaxes(0, 1),
self.cell_states_vf[batch_inds][self.seq_start_indices].swapaxes(0, 1),
)
lstm_states_pi = (self.to_torch(lstm_states_pi[0]).contiguous(), self.to_torch(lstm_states_pi[1]).contiguous())
lstm_states_vf = (self.to_torch(lstm_states_vf[0]).contiguous(), self.to_torch(lstm_states_vf[1]).contiguous())
lstm_states_pi = (
self.to_device((lstm_states_pi[0])).contiguous(),
self.to_device((lstm_states_pi[1])).contiguous(),
)
lstm_states_vf = (
self.to_device((lstm_states_vf[0])).contiguous(),
self.to_device((lstm_states_vf[1])).contiguous(),
)

observations = {key: self.pad(obs[batch_inds]) for (key, obs) in self.observations.items()}
observations = {key: obs.reshape((padded_batch_size,) + self.obs_shape[key]) for (key, obs) in observations.items()}
Expand All @@ -380,5 +391,5 @@ def _get_samples(
returns=self.pad_and_flatten(self.returns[batch_inds]),
lstm_states=RNNStates(lstm_states_pi, lstm_states_vf),
episode_starts=self.pad_and_flatten(self.episode_starts[batch_inds]),
mask=self.pad_and_flatten(np.ones_like(self.returns[batch_inds])),
mask=self.pad_and_flatten(th.ones_like(self.returns[batch_inds])),
)
Loading

0 comments on commit fc9b730

Please sign in to comment.