Correct types and formatting, steps_to_think
rhaps0dy committed Feb 27, 2024
1 parent 6882d98 commit 6d2ecf2
Showing 5 changed files with 74 additions and 31 deletions.
2 changes: 2 additions & 0 deletions stable_baselines3/common/base_class.py
@@ -133,6 +133,8 @@ class BaseAlgorithm(ABC):
n_envs: int
lr_schedule: Schedule
_logger: Logger
_last_obs: Optional[TorchGymObsBasic]
_last_episode_starts: Optional[th.Tensor]

def __init__(
self,
19 changes: 16 additions & 3 deletions stable_baselines3/common/evaluation.py
@@ -4,8 +4,14 @@
import gymnasium as gym
import numpy as np
import torch as th

from stable_baselines3.common import type_aliases
from stable_baselines3.common.vec_env import DummyVecEnv, VecEnv, VecMonitor, is_vecenv_wrapped
from stable_baselines3.common.vec_env import (
DummyVecEnv,
VecEnv,
VecMonitor,
is_vecenv_wrapped,
)
from stable_baselines3.common.vec_env.util import obs_as_tensor


@@ -19,6 +25,7 @@ def evaluate_policy(
reward_threshold: Optional[float] = None,
return_episode_rewards: bool = False,
warn: bool = True,
steps_to_think: Optional[int] = None,
) -> Union[Tuple[float, float], Tuple[List[float], List[int]]]:
"""
Runs policy for ``n_eval_episodes`` episodes and returns average reward.
@@ -50,6 +57,8 @@
per episode will be returned instead of the mean.
:param warn: If True (default), warns user about lack of a Monitor wrapper in the
evaluation environment.
:param steps_to_think: Number of steps the model should think before taking the first action. If None,
use the default stored on ``model``.
:return: Mean reward per episode, std of reward per episode.
Returns ([float], [int]) when ``return_episode_rewards`` is True, first
list containing per-episode rewards and second containing per-episode lengths
@@ -79,6 +88,10 @@
observations = env.reset()
observations = obs_as_tensor(observations, model.device)

if steps_to_think is None:
steps_to_think = getattr(model, "steps_to_think", 0)
assert steps_to_think is not None

# Hardcode episode counts and the reward accumulators to use CPU. They're used for bookkeeping and don't involve
# much computation.

@@ -92,8 +105,8 @@
episode_starts = th.ones((env.num_envs,), dtype=th.bool, device=model.device)
while (episode_counts < episode_count_targets).any():
with th.no_grad():
if model.steps_to_think > 0:
states = model.think_for_n_steps(observations, states, episode_starts)
if hasattr(model, "think_for_n_steps"):
states = model.think_for_n_steps(steps_to_think, observations, states, episode_starts)
actions, states = model.predict(
observations, # type: ignore[arg-type]
state=states,
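For context, a minimal usage sketch of the new ``steps_to_think`` argument; ``model`` and ``eval_env`` are placeholders for a trained recurrent model and its evaluation VecEnv, and the numbers are illustrative:

```python
from stable_baselines3.common.evaluation import evaluate_policy

# With steps_to_think=None (the default), evaluate_policy falls back to
# getattr(model, "steps_to_think", 0); an explicit integer overrides the
# model's default for this evaluation only.
mean_reward, std_reward = evaluate_policy(
    model,
    eval_env,
    n_eval_episodes=10,
    steps_to_think=4,  # let the recurrent state warm up for 4 forward passes
)
print(f"reward: {mean_reward:.2f} +/- {std_reward:.2f}")
```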
6 changes: 5 additions & 1 deletion stable_baselines3/common/on_policy_algorithm.py
@@ -9,7 +9,11 @@
from stable_baselines3.common.buffers import DictRolloutBuffer, RolloutBuffer
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.policies import ActorCriticPolicy
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.common.type_aliases import (
GymEnv,
MaybeCallback,
Schedule,
)
from stable_baselines3.common.utils import safe_mean
from stable_baselines3.common.vec_env import VecEnv
from stable_baselines3.common.vec_env.util import obs_as_tensor
11 changes: 6 additions & 5 deletions stable_baselines3/common/recurrent/policies.py
@@ -12,7 +12,6 @@
from stable_baselines3.common.preprocessing import preprocess_obs
from stable_baselines3.common.pytree_dataclass import tree_flatten
from stable_baselines3.common.recurrent.torch_layers import (
ExtractorInput,
GRUNatureCNNExtractor,
GRUWrappedFeaturesExtractor,
LSTMFlattenExtractor,
@@ -189,7 +188,7 @@ def predict( # type: ignore[override]
return actions, state


class RecurrentActorCriticPolicy(BaseRecurrentActorCriticPolicy):
class RecurrentActorCriticPolicy(BaseRecurrentActorCriticPolicy[ActorCriticStates[LSTMRecurrentState]]):
"""
Recurrent policy class for actor-critic algorithms (has both policy and value prediction).
To be used with A2C, PPO and the likes.
@@ -328,7 +327,9 @@ def _build_mlp_extractor(self) -> None:
device=self.device,
)

def recurrent_initial_state(self, n_envs: Optional[int] = None, *, device: Optional[th.device | str] = None):
def recurrent_initial_state(
self, n_envs: Optional[int] = None, *, device: Optional[th.device | str] = None
) -> ActorCriticStates[LSTMRecurrentState]:
shape: tuple[int, ...]
if n_envs is None:
shape = (self.lstm_hidden_state_shape[0], self.lstm_hidden_state_shape[2])
@@ -628,8 +629,8 @@ def __init__(
)


class RecurrentFeaturesExtractorActorCriticPolicy(BaseRecurrentActorCriticPolicy, Generic[ExtractorInput, RecurrentState]):
features_extractor: RecurrentFeaturesExtractor[ExtractorInput, RecurrentState]
class RecurrentFeaturesExtractorActorCriticPolicy(BaseRecurrentActorCriticPolicy[RecurrentState], Generic[RecurrentState]):
features_extractor: RecurrentFeaturesExtractor[TorchGymObs, RecurrentState]

def __init__(
self,
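A brief sketch of how the now fully-typed ``recurrent_initial_state`` can be consumed downstream; ``policy`` stands in for an already-constructed ``RecurrentActorCriticPolicy`` and the sizes are arbitrary:

```python
import torch as th

from stable_baselines3.common.pytree_dataclass import tree_index

# The return type is now ActorCriticStates[LSTMRecurrentState], so type
# checkers can follow the state through rollout and evaluation code.
states = policy.recurrent_initial_state(n_envs=8, device="cpu")

# The state is a pytree, so a per-environment slice can be taken with
# tree_index, mirroring what RecurrentPPO.think_for_n_steps does below.
episode_starts = th.zeros(8, dtype=th.bool)
episode_starts[0] = True
states_for_new_episodes = tree_index(states, (episode_starts,))
```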
67 changes: 45 additions & 22 deletions stable_baselines3/ppo_recurrent/ppo_recurrent.py
@@ -1,29 +1,49 @@
import sys
import time
import warnings
from typing import Any, ClassVar, Dict, Optional, Type, TypeVar, Union
from typing import Any, ClassVar, Dict, Generic, Optional, Type, TypeVar, Union

import numpy as np
import torch as th
import torch.nn.functional as F
from gymnasium import spaces

from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
from stable_baselines3.common.policies import BasePolicy
from stable_baselines3.common.pytree_dataclass import tree_map
from stable_baselines3.common.pytree_dataclass import tree_index, tree_map
from stable_baselines3.common.recurrent.buffers import RecurrentRolloutBuffer
from stable_baselines3.common.recurrent.policies import BaseRecurrentActorCriticPolicy, RecurrentActorCriticPolicy
from stable_baselines3.common.recurrent.type_aliases import ActorCriticStates, RecurrentRolloutBufferData
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule, non_null
from stable_baselines3.common.utils import explained_variance, get_schedule_fn, safe_mean
from stable_baselines3.common.recurrent.policies import (
BaseRecurrentActorCriticPolicy,
)
from stable_baselines3.common.recurrent.torch_layers import RecurrentState
from stable_baselines3.common.recurrent.type_aliases import (
RecurrentRolloutBufferData,
)
from stable_baselines3.common.type_aliases import (
GymEnv,
MaybeCallback,
Schedule,
TorchGymObs,
non_null,
)
from stable_baselines3.common.utils import (
explained_variance,
get_schedule_fn,
safe_mean,
)
from stable_baselines3.common.vec_env import VecEnv
from stable_baselines3.common.vec_env.util import obs_as_tensor
from stable_baselines3.ppo_recurrent.policies import CnnLstmPolicy, MlpLstmPolicy, MultiInputLstmPolicy
from stable_baselines3.ppo_recurrent.policies import (
CnnLstmPolicy,
MlpLstmPolicy,
MultiInputLstmPolicy,
)

SelfRecurrentPPO = TypeVar("SelfRecurrentPPO", bound="RecurrentPPO")


class RecurrentPPO(OnPolicyAlgorithm):
class RecurrentPPO(OnPolicyAlgorithm, Generic[RecurrentState]):
"""
Proximal Policy Optimization algorithm (PPO) (clip version)
with support for recurrent policies (LSTM).
@@ -78,13 +98,13 @@ class RecurrentPPO(OnPolicyAlgorithm):
"MultiInputPolicy": MultiInputLstmPolicy,
}

policy: RecurrentActorCriticPolicy
policy_class: Type[RecurrentActorCriticPolicy]
policy: BaseRecurrentActorCriticPolicy[RecurrentState]
policy_class: Type[BaseRecurrentActorCriticPolicy[RecurrentState]]
rollout_buffer: RecurrentRolloutBuffer

def __init__(
self,
policy: Union[str, Type[BaseRecurrentActorCriticPolicy]],
policy: Union[str, Type[BaseRecurrentActorCriticPolicy[RecurrentState]]],
env: Union[GymEnv, str],
learning_rate: Union[float, Schedule] = 3e-4,
n_steps: int = 128,
@@ -183,27 +203,31 @@ def __init__(
self.clip_range_vf: Schedule = clip_range_vf # type: ignore
self.normalize_advantage = normalize_advantage
self.target_kl = target_kl
self._last_lstm_states: Optional[ActorCriticStates] = None
self._last_lstm_states: Optional[RecurrentState] = None
self.steps_to_think = steps_to_think

if _init_setup_model:
self._setup_model()

def think_for_n_steps(self, obs_tensor, lstm_states, episode_starts):
def think_for_n_steps(
self, n_steps: int, obs_tensor: TorchGymObs, lstm_states: Optional[RecurrentState], episode_starts: th.Tensor
) -> RecurrentState:
if lstm_states is None:
lstm_states = self.policy.recurrent_initial_state(self.env.num_envs, device=self.device)
lstm_states = self.policy.recurrent_initial_state(episode_starts.size(0), device=self.device)

if not episode_starts.any():
if not episode_starts.any() or n_steps == 0:
return lstm_states
obs_for_start_envs = obs_tensor[episode_starts, ...]
lstm_states_for_start_envs = lstm_states[episode_starts, ...]
for _ in range(self.steps_to_think):
# ignore because TorchGymObs and TensorTree do not match
obs_for_start_envs: TorchGymObs = tree_index(obs_tensor, (episode_starts,)) # type: ignore[type-var]
lstm_states_for_start_envs = tree_index(lstm_states, (episode_starts,))
for _ in range(n_steps):
_, _, _, lstm_states_for_start_envs = self.policy.forward(
obs_for_start_envs,
lstm_states_for_start_envs,
episode_starts[episode_starts],
)
lstm_states[episode_starts] = lstm_states_for_start_envs
lstm_states = tree_map(lambda x, y: x[episode_starts].copy_(y), lstm_states, lstm_states_for_start_envs)
return lstm_states

def _setup_model(self) -> None:
@@ -222,7 +246,7 @@ def _setup_model(self) -> None:
# if not isinstance(self.policy, RecurrentActorCriticPolicy):
# raise ValueError("Policy must subclass RecurrentActorCriticPolicy")

hidden_state_example = self.policy.recurrent_initial_state(n_envs=self.n_envs, device=self.device)
hidden_state_example: RecurrentState = self.policy.recurrent_initial_state(n_envs=self.n_envs, device=self.device)

self.rollout_buffer = RecurrentRolloutBuffer(
self.n_steps,
@@ -292,8 +316,7 @@ def collect_rollouts( # type: ignore[override]
# Convert to pytorch tensor or to TensorDict
obs_tensor = obs_as_tensor(self._last_obs, self.device)
episode_starts = non_null(self._last_episode_starts)
if self.steps_to_think > 0:
lstm_states = self.think_for_n_steps(obs_tensor, lstm_states, episode_starts)
lstm_states = self.think_for_n_steps(self.steps_to_think, obs_tensor, lstm_states, episode_starts)
actions, values, log_probs, lstm_states = self.policy.forward(obs_tensor, lstm_states, episode_starts)

# Rescale and perform action
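Putting the pieces together, a hedged end-to-end sketch of the thinking-steps flow. The environment and hyperparameters are illustrative, and ``steps_to_think`` is assumed to be accepted by the constructor (the diff only shows it being stored as ``self.steps_to_think``):

```python
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.ppo_recurrent.policies import MlpLstmPolicy
from stable_baselines3.ppo_recurrent.ppo_recurrent import RecurrentPPO

# steps_to_think is stored on the algorithm and used by collect_rollouts via
# think_for_n_steps: whenever an environment starts a new episode, the policy
# runs that many extra forward passes to warm up its recurrent state first.
model = RecurrentPPO(
    MlpLstmPolicy,
    "CartPole-v1",
    n_steps=128,
    steps_to_think=4,  # assumed constructor argument, per self.steps_to_think above
    verbose=0,
)
model.learn(total_timesteps=10_000)

# evaluate_policy picks up model.steps_to_think unless explicitly overridden.
mean_reward, std_reward = evaluate_policy(model, model.get_env(), n_eval_episodes=5)
```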
