Make RecurrentRolloutBuffer generic to observations and states #8

Merged 14 commits on Oct 13, 2023. The changes below are from all commits.
4 changes: 2 additions & 2 deletions .circleci/config.yml
@@ -51,7 +51,7 @@ jobs:
command: ruff .
- run:
name: Typecheck (mypy)
command: mypy stable_baselines3/common/pytree_dataclass.py tests/test_pytree_dataclass.py # TODO: remove, in PR#4.
command: mypy --exclude '^stable_baselines3/common/recurrent/policies\.py$' stable_baselines3/common tests # TODO: remove, in PR#4.
pytype:
docker:
- image: ghcr.io/alignmentresearch/learned-planners:<< pipeline.parameters.docker_img_version >>
@@ -62,7 +62,7 @@
working_directory: /workspace/third_party/stable-baselines3
steps:
- checkout
- run: pytype -j 4 stable_baselines3/common/pytree_dataclass.py tests/test_pytree_dataclass.py # TODO: remove, in PR#4.
- run: pytype --keep-going -x stable_baselines3/common/recurrent/policies.py stable_baselines3/common tests # TODO: remove, in PR#4.
py-tests:
docker:
- image: ghcr.io/alignmentresearch/learned-planners:<< pipeline.parameters.docker_img_version >>
2 changes: 1 addition & 1 deletion Makefile
@@ -5,7 +5,7 @@ pytest:
./scripts/run_tests.sh

pytype:
pytype -j auto
pytype -j auto --keep-going

mypy:
mypy ${LINT_PATHS}
4 changes: 2 additions & 2 deletions stable_baselines3/common/buffers.py
@@ -58,7 +58,7 @@ def __init__(
self.n_envs = n_envs

@staticmethod
def swap_and_flatten(arr: np.ndarray) -> np.ndarray:
def swap_and_flatten(arr: th.Tensor) -> th.Tensor:
"""
Swap and then flatten axes 0 (buffer_size) and 1 (n_envs)
to convert shape from [n_steps, n_envs, ...] (when ... is the shape of the features)
@@ -766,7 +766,7 @@ class DictRolloutBuffer(RolloutBuffer):
:param n_envs: Number of parallel environments
"""

observations: Dict[str, np.ndarray]
observations: Dict[str, th.Tensor]

def __init__(
self,
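For context, `swap_and_flatten` merges the time and environment axes of a buffer tensor. A minimal sketch of that operation for torch tensors, consistent with the docstring above but not necessarily the library's exact implementation:

```python
import torch as th


def swap_and_flatten(arr: th.Tensor) -> th.Tensor:
    # [n_steps, n_envs, *feature_shape] -> [n_steps * n_envs, *feature_shape],
    # grouping all timesteps of each environment together.
    shape = arr.shape
    return arr.transpose(0, 1).reshape(shape[0] * shape[1], *shape[2:])


# A buffer of 5 steps for 3 envs with 2-dim features flattens to 15 rows.
assert swap_and_flatten(th.zeros(5, 3, 2)).shape == (15, 2)
```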
389 changes: 155 additions & 234 deletions stable_baselines3/common/recurrent/buffers.py

Large diffs are not rendered by default.
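The rewritten recurrent buffer itself is not rendered in this view. Conceptually, making it "generic to observations and states" means allocating storage by recursing over an example pytree of tensors instead of assuming one fixed hidden-state shape. A self-contained sketch of that idea (the helper name, type alias, and exact axis layout are illustrative, not taken from the buffer code):

```python
from typing import Dict, Tuple, Union

import torch as th

# Purely illustrative type alias; the real TensorTree lives in
# stable_baselines3.common.pytree_dataclass.
TreeSketch = Union[th.Tensor, Dict[str, "TreeSketch"], Tuple["TreeSketch", ...]]


def allocate_like(example: TreeSketch, n_steps: int, n_envs: int) -> TreeSketch:
    # Allocate (n_steps, n_envs, *leaf_shape) storage for every tensor leaf,
    # preserving the nesting structure of the example.
    if isinstance(example, th.Tensor):
        return th.zeros((n_steps, n_envs, *example.shape), dtype=example.dtype)
    if isinstance(example, dict):
        return {k: allocate_like(v, n_steps, n_envs) for k, v in example.items()}
    return tuple(allocate_like(v, n_steps, n_envs) for v in example)


hidden_state_example = {"a": {"b": th.zeros(2, 4)}}
storage = allocate_like(hidden_state_example, n_steps=128, n_envs=8)
assert storage["a"]["b"].shape == (128, 8, 2, 4)
```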

39 changes: 24 additions & 15 deletions stable_baselines3/common/recurrent/type_aliases.py
@@ -1,34 +1,43 @@
from typing import NamedTuple, Tuple
from typing import Optional, Tuple, TypeVar

import torch as th

from stable_baselines3.common.type_aliases import TensorDict
from stable_baselines3.common.pytree_dataclass import FrozenPyTreeDataclass, TensorTree

T = TypeVar("T")

class RNNStates(NamedTuple):

def non_null(v: Optional[T]) -> T:
if v is None:
raise ValueError("Expected a value, got None")
return v


LSTMStates = Tuple[th.Tensor, th.Tensor]


class RNNStates(FrozenPyTreeDataclass[th.Tensor]):
pi: Tuple[th.Tensor, ...]
vf: Tuple[th.Tensor, ...]


class RecurrentRolloutBufferSamples(NamedTuple):
observations: th.Tensor
class RecurrentRolloutBufferData(FrozenPyTreeDataclass[th.Tensor]):
observations: TensorTree
actions: th.Tensor
old_values: th.Tensor
old_log_prob: th.Tensor
advantages: th.Tensor
returns: th.Tensor
lstm_states: RNNStates
rewards: th.Tensor
episode_starts: th.Tensor
mask: th.Tensor
values: th.Tensor
log_probs: th.Tensor
hidden_states: TensorTree


class RecurrentDictRolloutBufferSamples(NamedTuple):
observations: TensorDict
class RecurrentRolloutBufferSamples(FrozenPyTreeDataclass[th.Tensor]):
observations: TensorTree
actions: th.Tensor
old_values: th.Tensor
old_log_prob: th.Tensor
hidden_states: TensorTree
episode_starts: th.Tensor
advantages: th.Tensor
returns: th.Tensor
lstm_states: RNNStates
episode_starts: th.Tensor
mask: th.Tensor
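
A short usage sketch of the new `RNNStates` pytree dataclass and the `non_null` helper defined above, assuming the `FrozenPyTreeDataclass` constructor accepts its fields by keyword like a regular dataclass:

```python
from typing import Optional

import torch as th

from stable_baselines3.common.recurrent.type_aliases import RNNStates, non_null

# RNNStates is now a frozen pytree dataclass holding the actor and critic LSTM states.
states = RNNStates(
    pi=(th.zeros(1, 4), th.zeros(1, 4)),
    vf=(th.zeros(1, 4), th.zeros(1, 4)),
)

# non_null narrows an Optional value and fails loudly instead of letting None
# propagate into tensor code.
maybe_states: Optional[RNNStates] = states
checked = non_null(maybe_states)
assert checked is states
```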
59 changes: 38 additions & 21 deletions stable_baselines3/ppo_recurrent/ppo_recurrent.py
@@ -7,7 +7,6 @@
import torch as th
from gymnasium import spaces

from stable_baselines3.common.buffers import RolloutBuffer
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
from stable_baselines3.common.policies import BasePolicy
@@ -16,7 +15,11 @@
RecurrentRolloutBuffer,
)
from stable_baselines3.common.recurrent.policies import RecurrentActorCriticPolicy
from stable_baselines3.common.recurrent.type_aliases import RNNStates
from stable_baselines3.common.recurrent.type_aliases import (
RecurrentRolloutBufferData,
RNNStates,
non_null,
)
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.common.utils import (
explained_variance,
@@ -88,13 +91,17 @@ class RecurrentPPO(OnPolicyAlgorithm):
"MultiInputPolicy": MultiInputLstmPolicy,
}

policy: RecurrentActorCriticPolicy
policy_class: Type[RecurrentActorCriticPolicy]
rollout_buffer: RecurrentRolloutBuffer

def __init__(
self,
policy: Union[str, Type[RecurrentActorCriticPolicy]],
env: Union[GymEnv, str],
learning_rate: Union[float, Schedule] = 3e-4,
n_steps: int = 128,
batch_size: Optional[int] = 128,
batch_size: int = 128,
n_epochs: int = 10,
gamma: float = 0.99,
gae_lambda: float = 0.95,
@@ -175,9 +182,21 @@ def _setup_model(self) -> None:
if not isinstance(self.policy, RecurrentActorCriticPolicy):
raise ValueError("Policy must subclass RecurrentActorCriticPolicy")

single_hidden_state_shape = (lstm.num_layers, self.n_envs, lstm.hidden_size)
per_timestep_hidden_state_shape = (lstm.num_layers, self.n_envs, lstm.hidden_size)
# hidden and cell states for actor and critic
self._last_lstm_states = RNNStates(
(
th.zeros(per_timestep_hidden_state_shape, device=self.device),
th.zeros(per_timestep_hidden_state_shape, device=self.device),
),
(
th.zeros(per_timestep_hidden_state_shape, device=self.device),
th.zeros(per_timestep_hidden_state_shape, device=self.device),
),
)

single_hidden_state_shape = (lstm.num_layers, lstm.hidden_size)
hidden_state_example = RNNStates(
(
th.zeros(single_hidden_state_shape, device=self.device),
th.zeros(single_hidden_state_shape, device=self.device),
@@ -188,13 +207,11 @@ def _setup_model(self) -> None:
),
)

hidden_state_buffer_shape = (self.n_steps, lstm.num_layers, self.n_envs, lstm.hidden_size)

self.rollout_buffer = buffer_cls(
self.n_steps,
self.observation_space,
self.action_space,
hidden_state_buffer_shape,
hidden_state_example,
self.device,
gamma=self.gamma,
gae_lambda=self.gae_lambda,
@@ -209,11 +226,11 @@

self.clip_range_vf = get_schedule_fn(self.clip_range_vf)

def collect_rollouts(
def collect_rollouts( # type: ignore[override]
self,
env: VecEnv,
callback: BaseCallback,
rollout_buffer: RolloutBuffer,
rollout_buffer: RecurrentRolloutBuffer,
n_rollout_steps: int,
) -> bool:
"""
@@ -229,9 +246,7 @@ def collect_rollouts(
:return: True if function returned with at least `n_rollout_steps`
collected, False if callback terminated rollout prematurely.
"""
assert isinstance(
rollout_buffer, (RecurrentRolloutBuffer, RecurrentDictRolloutBuffer)
), f"{rollout_buffer} doesn't support recurrent policy"
assert isinstance(rollout_buffer, RecurrentRolloutBuffer), f"{rollout_buffer} doesn't support recurrent policy"

assert self._last_obs is not None, "No previous observation was provided"
# Switch to eval mode (this affects batch norm / dropout)
@@ -255,7 +270,7 @@ def collect_rollouts(
with th.no_grad():
# Convert to pytorch tensor or to TensorDict
obs_tensor = obs_as_tensor(self._last_obs, self.device)
episode_starts = th.tensor(self._last_episode_starts, dtype=th.bool, device=self.device)
episode_starts = non_null(self._last_episode_starts)
actions, values, log_probs, lstm_states = self.policy.forward(obs_tensor, lstm_states, episode_starts)

# Rescale and perform action
@@ -304,13 +319,15 @@ def collect_rollouts(
rewards[idx] += self.gamma * terminal_value

rollout_buffer.add(
self._last_obs,
actions,
rewards,
self._last_episode_starts,
values,
log_probs,
lstm_states=self._last_lstm_states,
RecurrentRolloutBufferData(
self._last_obs,
actions,
rewards,
non_null(self._last_episode_starts),
values.squeeze(-1),
log_probs,
hidden_states=non_null(self._last_lstm_states),
)
)

self._last_obs = new_obs
@@ -368,7 +385,7 @@ def train(self) -> None:
values, log_prob, entropy = self.policy.evaluate_actions(
rollout_data.observations,
actions,
rollout_data.lstm_states,
rollout_data.hidden_states,
rollout_data.episode_starts,
)

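The two shapes introduced in `_setup_model` serve different roles: `per_timestep_hidden_state_shape` includes the environment axis because the algorithm steps all environments at once, while `single_hidden_state_shape` is only a per-environment example passed to the buffer, which allocates its own batched storage. An illustrative sketch with made-up sizes (the buffer's internal axis order is an assumption, not taken from the diff):

```python
import torch as th

num_layers, hidden_size, n_envs, n_steps = 1, 64, 8, 128  # illustrative values

# Kept on the algorithm and advanced by the policy at every environment step:
per_timestep_hidden_state_shape = (num_layers, n_envs, hidden_size)

# Passed to the buffer once as an example; the buffer adds the batch axes itself:
single_hidden_state_shape = (num_layers, hidden_size)

# Assumed storage layout: one example-shaped leaf per step and per environment.
buffer_storage_shape = (n_steps, n_envs, *single_hidden_state_shape)
assert th.zeros(per_timestep_hidden_state_shape).shape == (1, 8, 64)
assert th.zeros(buffer_storage_shape).shape == (128, 8, 1, 64)
```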
2 changes: 1 addition & 1 deletion stable_baselines3/version.txt
@@ -1 +1 @@
2.2.0a3
2.3.0+learned-planners-a1
26 changes: 11 additions & 15 deletions tests/test_buffers.py
@@ -13,11 +13,8 @@
from stable_baselines3.common.env_checker import check_env
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.pytree_dataclass import tree_flatten
from stable_baselines3.common.recurrent.buffers import (
RecurrentDictRolloutBuffer,
RecurrentRolloutBuffer,
)
from stable_baselines3.common.recurrent.type_aliases import RNNStates
from stable_baselines3.common.recurrent.buffers import RecurrentRolloutBuffer
from stable_baselines3.common.recurrent.type_aliases import RecurrentRolloutBufferData
from stable_baselines3.common.type_aliases import (
DictReplayBufferSamples,
ReplayBufferSamples,
@@ -124,6 +121,9 @@ def test_replay_buffer_normalization(replay_buffer_cls):
assert np.allclose(sample.rewards.mean(0), np.zeros(1), atol=1)


HIDDEN_STATES_EXAMPLE = {"a": {"b": th.zeros(2, 4)}}


@pytest.mark.parametrize(
"replay_buffer_cls", [DictReplayBuffer, DictRolloutBuffer, ReplayBuffer, RolloutBuffer, RecurrentRolloutBuffer]
)
@@ -137,22 +137,20 @@ def test_device_buffer(replay_buffer_cls, device):
DictRolloutBuffer: DummyDictEnv,
ReplayBuffer: DummyEnv,
DictReplayBuffer: DummyDictEnv,
RecurrentRolloutBuffer: DummyEnv,
RecurrentDictRolloutBuffer: DummyDictEnv,
RecurrentRolloutBuffer: DummyDictEnv,
}[replay_buffer_cls]
env = make_vec_env(env)

if replay_buffer_cls == RecurrentRolloutBuffer:
buffer = RecurrentRolloutBuffer(
EP_LENGTH,
env.observation_space,
env.action_space,
hidden_state_shape=(EP_LENGTH, 1, env.num_envs, 4),
device=device,
EP_LENGTH, env.observation_space, env.action_space, hidden_state_example=HIDDEN_STATES_EXAMPLE, device=device
)
else:
buffer = replay_buffer_cls(EP_LENGTH, env.observation_space, env.action_space, device=device)

hidden_states_shape = HIDDEN_STATES_EXAMPLE["a"]["b"].shape
N_ENVS_HIDDEN_STATES = {"a": {"b": th.zeros((hidden_states_shape[0], env.num_envs, *hidden_states_shape[1:]))}}

# Interact and store transitions
obs = env.reset()
for _ in range(EP_LENGTH):
@@ -164,9 +162,7 @@
buffer.add(obs, action, reward, episode_start, values, log_prob)
elif replay_buffer_cls == RecurrentRolloutBuffer:
episode_start, values, log_prob = th.zeros(1), th.zeros(1), th.ones(1)
one_lstm_states = (th.zeros((1, env.num_envs, 4)), th.zeros((1, env.num_envs, 4)))
hidden_states = RNNStates(one_lstm_states, one_lstm_states)
buffer.add(obs, action, reward, episode_start, values, log_prob, lstm_states=hidden_states)
buffer.add(RecurrentRolloutBufferData(obs, action, reward, episode_start, values, log_prob, N_ENVS_HIDDEN_STATES))
else:
buffer.add(obs, next_obs, action, reward, done, info)
obs = next_obs