diff --git a/docs/README.md b/docs/README.md index 169a5e3db..1fc4d762e 100644 --- a/docs/README.md +++ b/docs/README.md @@ -8,7 +8,7 @@ This folder contains documentation for the RL baselines. #### Install Sphinx and Theme Execute this command in the project root: ``` -pip install -e .[docs] +pip install -e ".[docs]" ``` #### Building the Docs diff --git a/docs/conf.py b/docs/conf.py index 29077dcb3..9138aaa6e 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -11,6 +11,7 @@ # add these directories to sys.path here. If the directory is relative to the # documentation root, use os.path.abspath to make it absolute, like shown here. # +import datetime import os import sys from typing import Dict @@ -43,7 +44,7 @@ # -- Project information ----------------------------------------------------- project = "Stable Baselines3" -copyright = "2022, Stable Baselines3" +copyright = f"2021-{datetime.date.today().year}, Stable Baselines3" author = "Stable Baselines3 Contributors" # The short X.Y version diff --git a/docs/guide/custom_policy.rst b/docs/guide/custom_policy.rst index a136bfa59..0807498e4 100644 --- a/docs/guide/custom_policy.rst +++ b/docs/guide/custom_policy.rst @@ -215,7 +215,7 @@ downsampling and "vector" with a single linear layer. from stable_baselines3.common.torch_layers import BaseFeaturesExtractor class CustomCombinedExtractor(BaseFeaturesExtractor): - def __init__(self, observation_space: spaces.Dict): + def __init__(self, observation_space: gym.spaces.Dict): # We do not know features-dim here before going over all the items, # so put something dummy for now. PyTorch requires calling # nn.Module.__init__ before adding modules diff --git a/docs/guide/examples.rst b/docs/guide/examples.rst index 8670b2fce..919a351ad 100644 --- a/docs/guide/examples.rst +++ b/docs/guide/examples.rst @@ -320,7 +320,7 @@ You can control the evaluation frequency with ``eval_freq`` to monitor your agen from stable_baselines3 import SAC from stable_baselines3.common.callbacks import EvalCallback - from stable-baselines3.common.env_util import make_vec_env + from stable_baselines3.common.env_util import make_vec_env env_id = "Pendulum-v1" n_training_envs = 1 @@ -331,7 +331,7 @@ You can control the evaluation frequency with ``eval_freq`` to monitor your agen os.makedirs(eval_log_dir, exist_ok=True) # Initialize a vectorized training environment with default parameters - train_env = make_vec_env(env_id, n_env=n_training_envs, seed=0) + train_env = make_vec_env(env_id, n_envs=n_training_envs, seed=0) # Separate evaluation env, with different parameters passed via env_kwargs # Eval environments can be vectorized to speed up evaluation. diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 7bb1f16a1..fd209dea8 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -3,10 +3,11 @@ Changelog ========== - -Release 2.1.0a0 (WIP) +Release 2.1.0 (2023-08-17) -------------------------- +**Float64 actions , Gymnasium 0.29 support and bug fixes** + Breaking Changes: ^^^^^^^^^^^^^^^^^ - Removed Python 3.7 support @@ -15,15 +16,24 @@ Breaking Changes: New Features: ^^^^^^^^^^^^^ - Added Python 3.11 support +- Added Gymnasium 0.29 support (@pseudo-rnd-thoughts) `SB3-Contrib`_ ^^^^^^^^^^^^^^ +- Fixed MaskablePPO ignoring ``stats_window_size`` argument +- Added Python 3.11 support `RL Zoo`_ ^^^^^^^^^ +- Upgraded to Huggingface-SB3 >= 2.3 +- Added Python 3.11 support + Bug Fixes: ^^^^^^^^^^ +- Relaxed check in logger, that was causing issue on Windows with colorama +- Fixed off-policy algorithms with continuous float64 actions (see #1145) (@tobirohrer) +- Fixed ``env_checker.py`` warning messages for out of bounds in complex observation spaces (@Gabo-Tor) Deprecations: ^^^^^^^^^^^^^ @@ -32,9 +42,14 @@ Others: ^^^^^^^ - Updated GitHub issue templates - Fix typo in gym patch error message (@lukashass) +- Refactor ``test_spaces.py`` tests Documentation: ^^^^^^^^^^^^^^ +- Fixed callback example (@BertrandDecoster) +- Fixed policy network example (@kyle-he) +- Added mobile-env as new community project (@stefanbschneider) +- Added [DeepNetSlice](https://github.com/AlexPasqua/DeepNetSlice) to community projects (@AlexPasqua) Release 2.0.0 (2023-06-22) @@ -1391,8 +1406,8 @@ And all the contributors: @eleurent @ac-93 @cove9988 @theDebugger811 @hsuehch @Demetrio92 @thomasgubler @IperGiove @ScheiklP @simoninithomas @armandpl @manuel-delverme @Gautam-J @gianlucadecola @buoyancy99 @caburu @xy9485 @Gregwar @ycheng517 @quantitative-technologies @bcollazo @git-thor @TibiGG @cool-RR @MWeltevrede -@carlosluis @arjun-kg @tlpss @JonathanKuelz +@carlosluis @arjun-kg @tlpss @JonathanKuelz @Gabo-Tor @Melanol @qgallouedec @francescoluciano @jlp-ue @burakdmb @timothe-chaumont @honglu2875 @anand-bala @hughperkins @sidney-tio @AlexPasqua @dominicgkerr @Akhilez @Rocamonde @tobirohrer @ZikangXiong @DavyMorgan @luizapozzobon @Bonifatius94 @theSquaredError @harveybellini @DavyMorgan @FieteO @jonasreiher @npit @WeberSamuel @troiganto -@lutogniew @lbergmann1 @lukashass +@lutogniew @lbergmann1 @lukashass @BertrandDecoster @pseudo-rnd-thoughts @stefanbschneider @kyle-he diff --git a/docs/misc/projects.rst b/docs/misc/projects.rst index 9d05d54c1..1c3ba95fc 100644 --- a/docs/misc/projects.rst +++ b/docs/misc/projects.rst @@ -197,3 +197,35 @@ A simple library for pink noise exploration with deterministic (DDPG / TD3) and | Authors: Onno Eberhard, Jakob Hollenstein, Cristina Pinneri, Georg Martius | Github: https://github.com/martius-lab/pink-noise-rl | Paper: https://openreview.net/forum?id=hQ9V5QN27eS (Oral at ICLR 2023) + + +mobile-env +---------- + +An open, minimalist Gymnasium environment for autonomous coordination in wireless mobile networks. +It allows simulating various scenarios with moving users in a cellular network with multiple base stations. + +- Written in pure Python, easy to modify and extend, and can be installed directly via PyPI. +- Implements the standard Gymnasium interface such that it can be used with all common frameworks for reinforcement learning. +- There are examples for both single-agent and multi-agent RL using either `stable-baselines3` or Ray RLlib. + +| Authors: Stefan Schneider, Stefan Werner +| Github: https://github.com/stefanbschneider/mobile-env +| Paper: https://ris.uni-paderborn.de/download/30236/30237 (2022 IEEE/IFIP Network Operations and Management Symposium (NOMS)) + + +DeepNetSlice +------------ + +A Deep Reinforcement Learning Open-Source Toolkit for Network Slice Placement (NSP). + +NSP is the problem of deciding which physical servers in a network should host the virtual network functions (VNFs) that make up a network slice, as well as managing the mapping of the virtual links between the VNFs onto the physical infrastructure. +It is a complex optimization problem, as it involves considering the requirements of the network slice and the available resources on the physical network. +The goal is generally to maximize the utilization of the physical resources while ensuring that the network slices meet their performance requirements. + +The toolkit includes a customizable simulation environments, as well as some ready-to-use demos for training +intelligent agents to perform network slice placement. + +| Author: Alex Pasquali +| Github: https://github.com/AlexPasqua/DeepNetSlice +| Paper: **under review** (citation instructions on the project's README.md) -> see this Master's Thesis for the moment: https://etd.adm.unipi.it/theses/available/etd-01182023-110038/unrestricted/Tesi_magistrale_Pasquali_Alex.pdf diff --git a/setup.py b/setup.py index 6ddb9b65d..deb9f5498 100644 --- a/setup.py +++ b/setup.py @@ -100,7 +100,7 @@ packages=[package for package in find_packages() if package.startswith("stable_baselines3")], package_data={"stable_baselines3": ["py.typed", "version.txt"]}, install_requires=[ - "gymnasium==0.28.1", + "gymnasium>=0.28.1,<0.30", "numpy>=1.20", "torch>=1.13", # For saving models diff --git a/stable_baselines3/common/buffers.py b/stable_baselines3/common/buffers.py index fe633e1af..576e10a8b 100644 --- a/stable_baselines3/common/buffers.py +++ b/stable_baselines3/common/buffers.py @@ -207,7 +207,9 @@ def __init__( else: self.next_observations = np.zeros((self.buffer_size, self.n_envs, *self.obs_shape), dtype=observation_space.dtype) - self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=action_space.dtype) + self.actions = np.zeros( + (self.buffer_size, self.n_envs, self.action_dim), dtype=self._maybe_cast_dtype(action_space.dtype) + ) self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) self.dones = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) @@ -311,6 +313,21 @@ def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = Non ) return ReplayBufferSamples(*tuple(map(self.to_torch, data))) + @staticmethod + def _maybe_cast_dtype(dtype: np.typing.DTypeLike) -> np.typing.DTypeLike: + """ + Cast `np.float64` action datatype to `np.float32`, + keep the others dtype unchanged. + See GH#1572 for more information. + + :param dtype: The original action space dtype + :return: ``np.float32`` if the dtype was float64, + the original dtype otherwise. + """ + if dtype == np.float64: + return np.float32 + return dtype + class RolloutBuffer(BaseBuffer): """ @@ -543,7 +560,9 @@ def __init__( for key, _obs_shape in self.obs_shape.items() } - self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=action_space.dtype) + self.actions = np.zeros( + (self.buffer_size, self.n_envs, self.action_dim), dtype=self._maybe_cast_dtype(action_space.dtype) + ) self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) self.dones = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) diff --git a/stable_baselines3/common/env_checker.py b/stable_baselines3/common/env_checker.py index 516d7ba61..8b8da7f44 100644 --- a/stable_baselines3/common/env_checker.py +++ b/stable_baselines3/common/env_checker.py @@ -203,18 +203,24 @@ def _check_obs(obs: Union[tuple, dict, np.ndarray, int], observation_space: spac f"Expected: {observation_space.dtype}, actual dtype: {obs.dtype}" ) if isinstance(observation_space, spaces.Box): - assert np.all(obs >= observation_space.low), ( - f"The observation returned by the `{method_name}()` method does not match the lower bound " - f"of the given observation space {observation_space}." - f"Expected: obs >= {np.min(observation_space.low)}, " - f"actual min value: {np.min(obs)} at index {np.argmin(obs)}" - ) - assert np.all(obs <= observation_space.high), ( - f"The observation returned by the `{method_name}()` method does not match the upper bound " - f"of the given observation space {observation_space}. " - f"Expected: obs <= {np.max(observation_space.high)}, " - f"actual max value: {np.max(obs)} at index {np.argmax(obs)}" - ) + lower_bounds, upper_bounds = observation_space.low, observation_space.high + # Expose all invalid indices at once + invalid_indices = np.where(np.logical_or(obs < lower_bounds, obs > upper_bounds)) + if (obs > upper_bounds).any() or (obs < lower_bounds).any(): + message = ( + f"The observation returned by the `{method_name}()` method does not match the bounds " + f"of the given observation space {observation_space}. \n" + ) + message += f"{len(invalid_indices[0])} invalid indices: \n" + + for index in zip(*invalid_indices): + index_str = ",".join(map(str, index)) + message += ( + f"Expected: {lower_bounds[index]} <= obs[{index_str}] <= {upper_bounds[index]}, " + f"actual value: {obs[index]} \n" + ) + + raise AssertionError(message) assert observation_space.contains(obs), ( f"The observation returned by the `{method_name}()` method " diff --git a/stable_baselines3/common/logger.py b/stable_baselines3/common/logger.py index 4d0d3461e..3955131c5 100644 --- a/stable_baselines3/common/logger.py +++ b/stable_baselines3/common/logger.py @@ -164,8 +164,10 @@ def __init__(self, filename_or_file: Union[str, TextIO], max_length: int = 36): if isinstance(filename_or_file, str): self.file = open(filename_or_file, "w") self.own_file = True - elif isinstance(filename_or_file, TextIOBase): - self.file = filename_or_file + elif isinstance(filename_or_file, TextIOBase) or hasattr(filename_or_file, "write"): + # Note: in theory `TextIOBase` check should be sufficient, + # in practice, libraries don't always inherit from it, see GH#1598 + self.file = filename_or_file # type: ignore[assignment] self.own_file = False else: raise ValueError(f"Expected file or str, got {filename_or_file}") diff --git a/stable_baselines3/common/monitor.py b/stable_baselines3/common/monitor.py index 5253954e8..f421a4df2 100644 --- a/stable_baselines3/common/monitor.py +++ b/stable_baselines3/common/monitor.py @@ -193,7 +193,9 @@ def __init__( mode = "w" if override_existing else "a" # Prevent newline issue on Windows, see GH issue #692 self.file_handler = open(filename, f"{mode}t", newline="\n") - self.logger = csv.DictWriter(self.file_handler, fieldnames=("r", "l", "t", *extra_keys)) + self.logger = csv.DictWriter( + self.file_handler, fieldnames=("r", "l", "t", *extra_keys) + ) # pytype: disable=wrong-arg-types if override_existing: self.file_handler.write(f"#{json.dumps(header)}\n") self.logger.writeheader() diff --git a/stable_baselines3/dqn/dqn.py b/stable_baselines3/dqn/dqn.py index 6b44254cc..42e3d0df0 100644 --- a/stable_baselines3/dqn/dqn.py +++ b/stable_baselines3/dqn/dqn.py @@ -245,7 +245,7 @@ def predict( if not deterministic and np.random.rand() < self.exploration_rate: if self.policy.is_vectorized_observation(observation): if isinstance(observation, dict): - n_batch = observation[list(observation.keys())[0]].shape[0] + n_batch = observation[next(iter(observation.keys()))].shape[0] else: n_batch = observation.shape[0] action = np.array([self.action_space.sample() for _ in range(n_batch)]) diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index ecaf4eea7..7ec1d6db4 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -2.1.0a0 +2.1.0 diff --git a/tests/test_env_checker.py b/tests/test_env_checker.py index c0a5e0610..87cc177b7 100644 --- a/tests/test_env_checker.py +++ b/tests/test_env_checker.py @@ -42,15 +42,28 @@ def test_check_env_dict_action(): [ # Above upper bound ( - spaces.Box(low=0.0, high=1.0, shape=(3,), dtype=np.float32), + spaces.Box(low=np.array([0.0, 0.0, 0.0]), high=np.array([2.0, 1.0, 1.0]), shape=(3,), dtype=np.float32), np.array([1.0, 1.5, 0.5], dtype=np.float32), - r"Expected: obs <= 1\.0, actual max value: 1\.5 at index 1", + r"Expected: 0\.0 <= obs\[1] <= 1\.0, actual value: 1\.5", + ), + # Above upper bound (multi-dim) + ( + spaces.Box(low=-1.0, high=2.0, shape=(2, 3, 3, 1), dtype=np.float32), + 3.0 * np.ones((2, 3, 3, 1), dtype=np.float32), + # Note: this is one of the 18 invalid indices + r"Expected: -1\.0 <= obs\[1,2,1,0\] <= 2\.0, actual value: 3\.0", ), # Below lower bound ( - spaces.Box(low=0.0, high=2.0, shape=(3,), dtype=np.float32), + spaces.Box(low=np.array([0.0, -10.0, 0.0]), high=np.array([2.0, 1.0, 1.0]), shape=(3,), dtype=np.float32), np.array([-1.0, 1.5, 0.5], dtype=np.float32), - r"Expected: obs >= 0\.0, actual min value: -1\.0 at index 0", + r"Expected: 0\.0 <= obs\[0] <= 2\.0, actual value: -1\.0", + ), + # Below lower bound (multi-dim) + ( + spaces.Box(low=-1.0, high=2.0, shape=(2, 3, 3, 1), dtype=np.float32), + -2 * np.ones((2, 3, 3, 1), dtype=np.float32), + r"18 invalid indices:", ), # Wrong dtype ( @@ -111,7 +124,7 @@ def step(self, action): test_env = TestEnv() with pytest.raises(AssertionError, match=error_message): - check_env(env=test_env) + check_env(env=test_env, warn=False) class LimitedStepsTestEnv(gym.Env): diff --git a/tests/test_envs.py b/tests/test_envs.py index e6c973852..e82ef5768 100644 --- a/tests/test_envs.py +++ b/tests/test_envs.py @@ -156,8 +156,6 @@ def patched_step(_action): spaces.Box(low=-1000, high=1000, shape=(3,), dtype=np.float32), # Too small range spaces.Box(low=-0.1, high=0.1, shape=(2,), dtype=np.float32), - # Inverted boundaries - spaces.Box(low=1, high=-1, shape=(2,), dtype=np.float32), # Same boundaries spaces.Box(low=1, high=1, shape=(2,), dtype=np.float32), # Unbounded action space diff --git a/tests/test_logger.py b/tests/test_logger.py index 9d275a2ec..05bf196a3 100644 --- a/tests/test_logger.py +++ b/tests/test_logger.py @@ -437,8 +437,9 @@ def test_ep_buffers_stats_window_size(algo, stats_window_size): assert model.ep_success_buffer.maxlen == stats_window_size -def test_human_output_format_custom_test_io(): - class DummyTextIO(TextIOBase): +@pytest.mark.parametrize("base_class", [object, TextIOBase]) +def test_human_output_format_custom_test_io(base_class): + class DummyTextIO(base_class): def __init__(self) -> None: super().__init__() self.lines = [[]] diff --git a/tests/test_spaces.py b/tests/test_spaces.py index fb70d0a33..e4a933976 100644 --- a/tests/test_spaces.py +++ b/tests/test_spaces.py @@ -1,63 +1,67 @@ +from dataclasses import dataclass from typing import Dict, Optional import gymnasium as gym import numpy as np import pytest from gymnasium import spaces +from gymnasium.spaces.space import Space from stable_baselines3 import A2C, DDPG, DQN, PPO, SAC, TD3 from stable_baselines3.common.env_checker import check_env from stable_baselines3.common.env_util import make_vec_env from stable_baselines3.common.evaluation import evaluate_policy +BOX_SPACE_FLOAT64 = spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float64) +BOX_SPACE_FLOAT32 = spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32) -class DummyMultiDiscreteSpace(gym.Env): - def __init__(self, nvec): - super().__init__() - self.observation_space = spaces.MultiDiscrete(nvec) - self.action_space = spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32) - def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None): - if seed is not None: - super().reset(seed=seed) - return self.observation_space.sample(), {} +@dataclass +class DummyEnv(gym.Env): + observation_space: Space + action_space: Space def step(self, action): return self.observation_space.sample(), 0.0, False, False, {} - -class DummyMultiBinary(gym.Env): - def __init__(self, n): - super().__init__() - self.observation_space = spaces.MultiBinary(n) - self.action_space = spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32) - def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None): if seed is not None: super().reset(seed=seed) return self.observation_space.sample(), {} - def step(self, action): - return self.observation_space.sample(), 0.0, False, False, {} - -class DummyMultidimensionalAction(gym.Env): +class DummyMultidimensionalAction(DummyEnv): def __init__(self): - super().__init__() - self.observation_space = spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32) - self.action_space = spaces.Box(low=-1, high=1, shape=(2, 2), dtype=np.float32) + super().__init__( + BOX_SPACE_FLOAT32, + spaces.Box(low=-1, high=1, shape=(2, 2), dtype=np.float32), + ) - def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None): - if seed is not None: - super().reset(seed=seed) - return self.observation_space.sample(), {} - def step(self, action): - return self.observation_space.sample(), 0.0, False, False, {} +class DummyMultiBinary(DummyEnv): + def __init__(self, n): + super().__init__( + spaces.MultiBinary(n), + BOX_SPACE_FLOAT32, + ) + + +class DummyMultiDiscreteSpace(DummyEnv): + def __init__(self, nvec): + super().__init__( + spaces.MultiDiscrete(nvec), + BOX_SPACE_FLOAT32, + ) @pytest.mark.parametrize( - "env", [DummyMultiDiscreteSpace([4, 3]), DummyMultiBinary(8), DummyMultiBinary((3, 2)), DummyMultidimensionalAction()] + "env", + [ + DummyMultiDiscreteSpace([4, 3]), + DummyMultiBinary(8), + DummyMultiBinary((3, 2)), + DummyMultidimensionalAction(), + ], ) def test_env(env): # Check the env used for testing @@ -127,3 +131,40 @@ def test_discrete_obs_space(model_class, env): else: kwargs = dict(n_steps=256) model_class("MlpPolicy", env, **kwargs).learn(256) + + +@pytest.mark.parametrize("model_class", [SAC, TD3, PPO, DDPG, A2C]) +@pytest.mark.parametrize( + "obs_space", + [ + BOX_SPACE_FLOAT32, + BOX_SPACE_FLOAT64, + spaces.Dict({"a": BOX_SPACE_FLOAT32, "b": BOX_SPACE_FLOAT32}), + spaces.Dict({"a": BOX_SPACE_FLOAT32, "b": BOX_SPACE_FLOAT64}), + ], +) +@pytest.mark.parametrize( + "action_space", + [ + BOX_SPACE_FLOAT32, + BOX_SPACE_FLOAT64, + ], +) +def test_float64_action_space(model_class, obs_space, action_space): + env = DummyEnv(obs_space, action_space) + env = gym.wrappers.TimeLimit(env, max_episode_steps=200) + if isinstance(env.observation_space, spaces.Dict): + policy = "MultiInputPolicy" + else: + policy = "MlpPolicy" + + if model_class in [PPO, A2C]: + kwargs = dict(n_steps=64, policy_kwargs=dict(net_arch=[12])) + else: + kwargs = dict(learning_starts=60, policy_kwargs=dict(net_arch=[12])) + + model = model_class(policy, env, **kwargs) + model.learn(64) + initial_obs, _ = env.reset() + action, _ = model.predict(initial_obs, deterministic=False) + assert action.dtype == env.action_space.dtype