From 9c338f917a822c8be13b1aa9f7b2319770481b62 Mon Sep 17 00:00:00 2001 From: Kallinteris Andreas <30759571+Kallinteris-Andreas@users.noreply.github.com> Date: Sat, 20 May 2023 11:30:54 +0300 Subject: [PATCH] `vec_env`s fix `seed()` causing a reset (#1486) * `dummy_vec_env` fix `seed()` causing a reset * rename `seed` * fixes * bug fix * fix seed return type * Cleanup seeding, add test and remove compat wrapper * Update env checker and tests * Add deterministic test for make_vec_env --------- Co-authored-by: Antonin Raffin --- docs/misc/changelog.rst | 4 ++- pyproject.toml | 1 - stable_baselines3/common/env_checker.py | 5 +++ stable_baselines3/common/env_util.py | 9 ++++-- stable_baselines3/common/noise.py | 2 +- stable_baselines3/common/utils.py | 16 ---------- .../common/vec_env/base_vec_env.py | 19 +++++++++-- .../common/vec_env/dummy_vec_env.py | 17 +++------- .../common/vec_env/subproc_vec_env.py | 32 ++++++++----------- stable_baselines3/version.txt | 2 +- tests/test_buffers.py | 4 +-- tests/test_env_checker.py | 2 +- tests/test_envs.py | 23 ++++++++++--- tests/test_logger.py | 4 +-- tests/test_predict.py | 2 +- tests/test_spaces.py | 8 +++-- tests/test_vec_check_nan.py | 4 +-- tests/test_vec_envs.py | 16 +++++++--- 18 files changed, 94 insertions(+), 76 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 2edf3adec..27bf79c03 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -3,7 +3,7 @@ Changelog ========== -Release 2.0.0a8 (WIP) +Release 2.0.0a9 (WIP) -------------------------- **Gymnasium support** @@ -22,6 +22,7 @@ Breaking Changes: - Renamed environment output observations in ``evaluate_policy`` to prevent shadowing the input observations during callbacks (@npit) - Upgraded wrappers and custom environment to Gymnasium - Refined the ``HumanOutputFormat`` file check: now it verifies if the object is an instance of ``io.TextIOBase`` instead of only checking for the presence of a ``write`` method. +- Because of new Gym API (0.26+), the random seed passed to ``vec_env.seed(seed=seed)`` will only be effective after then ``env.reset()`` call. New Features: ^^^^^^^^^^^^^ @@ -55,6 +56,7 @@ Others: - Fixed ``stable_baselines3/common/vec_env/base_vec_env.py`` type hints - Fixed ``stable_baselines3/common/vec_env/vec_frame_stack.py`` type hints - Fixed ``stable_baselines3/common/vec_env/dummy_vec_env.py`` type hints +- Fixed ``stable_baselines3/common/vec_env/subproc_vec_env.py`` type hints - Upgraded docker images to use mamba/micromamba and CUDA 11.7 - Updated env checker to reflect what subset of Gymnasium is supported and improve GoalEnv checks - Improve type annotation of wrappers diff --git a/pyproject.toml b/pyproject.toml index 99bd21884..a8645323b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,7 +43,6 @@ exclude = """(?x)( | stable_baselines3/common/save_util.py$ | stable_baselines3/common/utils.py$ | stable_baselines3/common/vec_env/__init__.py$ - | stable_baselines3/common/vec_env/subproc_vec_env.py$ | stable_baselines3/common/vec_env/vec_normalize.py$ | stable_baselines3/common/vec_env/vec_transpose.py$ | stable_baselines3/common/vec_env/vec_video_recorder.py$ diff --git a/stable_baselines3/common/env_checker.py b/stable_baselines3/common/env_checker.py index cc8be48ef..5a7308a59 100644 --- a/stable_baselines3/common/env_checker.py +++ b/stable_baselines3/common/env_checker.py @@ -398,6 +398,11 @@ def check_env(env: gym.Env, warn: bool = True, skip_render_check: bool = True) - observation_space = env.observation_space action_space = env.action_space + try: + env.reset(seed=0) + except TypeError as e: + raise TypeError("The reset() method must accept a `seed` parameter") from e + # Warn the user if needed. # A warning means that the environment may run but not work properly with Stable Baselines algorithms if warn: diff --git a/stable_baselines3/common/env_util.py b/stable_baselines3/common/env_util.py index c3b73909e..0132c32f8 100644 --- a/stable_baselines3/common/env_util.py +++ b/stable_baselines3/common/env_util.py @@ -5,7 +5,6 @@ from stable_baselines3.common.atari_wrappers import AtariWrapper from stable_baselines3.common.monitor import Monitor -from stable_baselines3.common.utils import compat_gym_seed from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv, VecEnv from stable_baselines3.common.vec_env.patch_gym import _patch_env @@ -101,7 +100,8 @@ def _init() -> gym.Env: env = _patch_env(env) if seed is not None: - compat_gym_seed(env, seed=seed + rank) + # Note: here we only seed the action space + # We will seed the env at the next reset env.action_space.seed(seed + rank) # Wrap the env in a Monitor wrapper # to have additional training information @@ -122,7 +122,10 @@ def _init() -> gym.Env: # Default: use a DummyVecEnv vec_env_cls = DummyVecEnv - return vec_env_cls([make_env(i + start_index) for i in range(n_envs)], **vec_env_kwargs) + vec_env = vec_env_cls([make_env(i + start_index) for i in range(n_envs)], **vec_env_kwargs) + # Prepare the seeds for the first reset + vec_env.seed(seed) + return vec_env def make_atari_env( diff --git a/stable_baselines3/common/noise.py b/stable_baselines3/common/noise.py index 944408fb3..01670e6e4 100644 --- a/stable_baselines3/common/noise.py +++ b/stable_baselines3/common/noise.py @@ -131,7 +131,7 @@ def reset(self, indices: Optional[Iterable[int]] = None) -> None: self.noises[index].reset() def __repr__(self) -> str: - return f"VecNoise(BaseNoise={repr(self.base_noise)}), n_envs={len(self.noises)})" + return f"VecNoise(BaseNoise={self.base_noise!r}), n_envs={len(self.noises)})" def __call__(self) -> np.ndarray: """ diff --git a/stable_baselines3/common/utils.py b/stable_baselines3/common/utils.py index 08366bda1..b6fbe59be 100644 --- a/stable_baselines3/common/utils.py +++ b/stable_baselines3/common/utils.py @@ -4,7 +4,6 @@ import random import re from collections import deque -from inspect import signature from itertools import zip_longest from typing import Dict, Iterable, List, Optional, Tuple, Union @@ -549,18 +548,3 @@ def get_system_info(print_info: bool = True) -> Tuple[Dict[str, str], str]: if print_info: print(env_info_str) return env_info, env_info_str - - -def compat_gym_seed(env: GymEnv, seed: int) -> None: - """ - Compatibility helper to seed Gym envs. - - :param env: The Gym environment. - :param seed: The seed for the pseudo random generator - """ - if "seed" in signature(env.unwrapped.reset).parameters: - # gym >= 0.23.1 - env.reset(seed=seed) - else: - # VecEnv and backward compatibility - env.seed(seed) diff --git a/stable_baselines3/common/vec_env/base_vec_env.py b/stable_baselines3/common/vec_env/base_vec_env.py index 4a97026ee..7d7cfc279 100644 --- a/stable_baselines3/common/vec_env/base_vec_env.py +++ b/stable_baselines3/common/vec_env/base_vec_env.py @@ -69,6 +69,14 @@ def __init__( self.render_mode = render_mode # store info returned by the reset method self.reset_infos: List[Dict[str, Any]] = [{} for _ in range(num_envs)] + # seeds to be used in the next call to env.reset() + self._seeds: List[Optional[int]] = [None for _ in range(num_envs)] + + def _reset_seeds(self) -> None: + """ + Reset the seeds that are going to be used at the next reset. + """ + self._seeds = [None for _ in range(self.num_envs)] @abstractmethod def reset(self) -> VecEnvObs: @@ -239,17 +247,24 @@ def render(self, mode: Optional[str] = None) -> Optional[np.ndarray]: self.env_method("render") return None - @abstractmethod def seed(self, seed: Optional[int] = None) -> Sequence[Union[None, int]]: """ Sets the random seeds for all environments, based on a given seed. Each individual environment will still get its own seed, by incrementing the given seed. + WARNING: since gym 0.26, those seeds will only be passed to the environment + at the next reset. :param seed: The random seed. May be None for completely random seeding. :return: Returns a list containing the seeds for each individual env. Note that all list elements may be None, if the env does not return anything when being seeded. """ - pass + if seed is None: + # To ensure that subprocesses have different seeds, + # we still populate the seed variable when no argument is passed + seed = np.random.randint(0, 2**32 - 1) + + self._seeds = [seed + idx for idx in range(self.num_envs)] + return self._seeds @property def unwrapped(self) -> "VecEnv": diff --git a/stable_baselines3/common/vec_env/dummy_vec_env.py b/stable_baselines3/common/vec_env/dummy_vec_env.py index 822025f53..29b4d6320 100644 --- a/stable_baselines3/common/vec_env/dummy_vec_env.py +++ b/stable_baselines3/common/vec_env/dummy_vec_env.py @@ -1,7 +1,7 @@ import warnings from collections import OrderedDict from copy import deepcopy -from typing import Any, Callable, Dict, List, Optional, Sequence, Type, Union +from typing import Any, Callable, Dict, List, Optional, Sequence, Type import gymnasium as gym import numpy as np @@ -71,21 +71,12 @@ def step_wait(self) -> VecEnvStepReturn: self._save_obs(env_idx, obs) return (self._obs_from_buf(), np.copy(self.buf_rews), np.copy(self.buf_dones), deepcopy(self.buf_infos)) - def seed(self, seed: Optional[int] = None) -> Sequence[Union[None, int]]: - # Avoid circular import - from stable_baselines3.common.utils import compat_gym_seed - - if seed is None: - seed = np.random.randint(0, 2**32 - 1) - seeds = [] - for idx, env in enumerate(self.envs): - seeds.append(compat_gym_seed(env, seed=seed + idx)) # type: ignore[func-returns-value] - return seeds - def reset(self) -> VecEnvObs: for env_idx in range(self.num_envs): - obs, self.reset_infos[env_idx] = self.envs[env_idx].reset() + obs, self.reset_infos[env_idx] = self.envs[env_idx].reset(seed=self._seeds[env_idx]) self._save_obs(env_idx, obs) + # Seeds are only used once + self._reset_seeds() return self._obs_from_buf() def close(self) -> None: diff --git a/stable_baselines3/common/vec_env/subproc_vec_env.py b/stable_baselines3/common/vec_env/subproc_vec_env.py index d3fab9821..4d4695492 100644 --- a/stable_baselines3/common/vec_env/subproc_vec_env.py +++ b/stable_baselines3/common/vec_env/subproc_vec_env.py @@ -1,7 +1,7 @@ import multiprocessing as mp import warnings from collections import OrderedDict -from typing import Any, Callable, List, Optional, Sequence, Tuple, Type, Union +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Type, Union import gymnasium as gym import numpy as np @@ -24,11 +24,10 @@ def _worker( ) -> None: # Import here to avoid a circular import from stable_baselines3.common.env_util import is_wrapped - from stable_baselines3.common.utils import compat_gym_seed parent_remote.close() env = _patch_env(env_fn_wrapper.var()) - reset_info = {} + reset_info: Optional[Dict[str, Any]] = {} while True: try: cmd, data = remote.recv() @@ -42,10 +41,8 @@ def _worker( info["terminal_observation"] = observation observation, reset_info = env.reset() remote.send((observation, reward, done, info, reset_info)) - elif cmd == "seed": - remote.send(compat_gym_seed(env, seed=data)) elif cmd == "reset": - observation, reset_info = env.reset() + observation, reset_info = env.reset(seed=data) remote.send((observation, reset_info)) elif cmd == "render": remote.send(env.render()) @@ -61,7 +58,7 @@ def _worker( elif cmd == "get_attr": remote.send(getattr(env, data)) elif cmd == "set_attr": - remote.send(setattr(env, data[0], data[1])) + remote.send(setattr(env, data[0], data[1])) # type: ignore[func-returns-value] elif cmd == "is_wrapped": remote.send(is_wrapped(env, data)) else: @@ -112,7 +109,9 @@ def __init__(self, env_fns: List[Callable[[], gym.Env]], start_method: Optional[ for work_remote, remote, env_fn in zip(self.work_remotes, self.remotes, env_fns): args = (work_remote, remote, CloudpickleWrapper(env_fn)) # daemon=True: if the main process crashes, we should not cause things to hang - process = ctx.Process(target=_worker, args=args, daemon=True) # pytype:disable=attribute-error + # pytype: disable=attribute-error + process = ctx.Process(target=_worker, args=args, daemon=True) # type: ignore[attr-defined] + # pytype: enable=attribute-error process.start() self.processes.append(process) work_remote.close() @@ -135,18 +134,13 @@ def step_wait(self) -> VecEnvStepReturn: obs, rews, dones, infos, self.reset_infos = zip(*results) return _flatten_obs(obs, self.observation_space), np.stack(rews), np.stack(dones), infos - def seed(self, seed: Optional[int] = None) -> Sequence[Union[None, int]]: - if seed is None: - seed = np.random.randint(0, 2**32 - 1) - for idx, remote in enumerate(self.remotes): - remote.send(("seed", seed + idx)) - return [remote.recv() for remote in self.remotes] - def reset(self) -> VecEnvObs: - for remote in self.remotes: - remote.send(("reset", None)) + for env_idx, remote in enumerate(self.remotes): + remote.send(("reset", self._seeds[env_idx])) results = [remote.recv() for remote in self.remotes] obs, self.reset_infos = zip(*results) + # Seeds are only used once + self._reset_seeds() return _flatten_obs(obs, self.observation_space) def close(self) -> None: @@ -235,6 +229,6 @@ def _flatten_obs(obs: Union[List[VecEnvObs], Tuple[VecEnvObs]], space: spaces.Sp elif isinstance(space, spaces.Tuple): assert isinstance(obs[0], tuple), "non-tuple observation for environment with Tuple observation space" obs_len = len(space.spaces) - return tuple(np.stack([o[i] for o in obs]) for i in range(obs_len)) + return tuple(np.stack([o[i] for o in obs]) for i in range(obs_len)) # type: ignore[index] else: - return np.stack(obs) + return np.stack(obs) # type: ignore[arg-type] diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index 55a349a6b..caf8413d2 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -2.0.0a8 +2.0.0a9 diff --git a/tests/test_buffers.py b/tests/test_buffers.py index 825002c92..e7d4a1c57 100644 --- a/tests/test_buffers.py +++ b/tests/test_buffers.py @@ -25,7 +25,7 @@ def __init__(self): self._t = 0 self._ep_length = 100 - def reset(self): + def reset(self, *, seed=None, options=None): self._t = 0 obs = self._observations[0] return obs, {} @@ -55,7 +55,7 @@ def __init__(self): self._t = 0 self._ep_length = 100 - def reset(self): + def reset(self, seed=None, options=None): self._t = 0 obs = {key: self._observations[0] for key in self.observation_space.spaces.keys()} return obs, {} diff --git a/tests/test_env_checker.py b/tests/test_env_checker.py index 1050e866e..e855e2137 100644 --- a/tests/test_env_checker.py +++ b/tests/test_env_checker.py @@ -23,7 +23,7 @@ def step(self, action): info = {} return observation, reward, terminated, truncated, info - def reset(self): + def reset(self, seed=None): return np.array([1.0, 1.5, 0.5], dtype=self.observation_space.dtype), {} def render(self): diff --git a/tests/test_envs.py b/tests/test_envs.py index aeb248fbb..e6c973852 100644 --- a/tests/test_envs.py +++ b/tests/test_envs.py @@ -130,8 +130,12 @@ def patched_step(_action): def test_non_default_spaces(new_obs_space): env = FakeImageEnv() env.observation_space = new_obs_space + # Patch methods to avoid errors - env.reset = lambda: (new_obs_space.sample(), {}) + def patched_reset(seed=None): + return new_obs_space.sample(), {} + + env.reset = patched_reset def patched_step(_action): return new_obs_space.sample(), 0.0, False, False, {} @@ -204,7 +208,7 @@ def check_reset_assert_error(env, new_reset_return): :param new_reset_return: (Any) """ - def wrong_reset(): + def wrong_reset(seed=None): return new_reset_return, {} # Patch the reset method with a wrong one @@ -224,10 +228,21 @@ def test_common_failures_reset(): check_reset_assert_error(env, 1) # Return only obs (gym < 0.26) - env.reset = env.observation_space.sample + def wrong_reset(self, seed=None): + return env.observation_space.sample() + + env.reset = types.MethodType(wrong_reset, env) with pytest.raises(AssertionError): check_env(env) + # No seed parameter (gym < 0.26) + def wrong_reset(self): + return env.observation_space.sample(), {} + + env.reset = types.MethodType(wrong_reset, env) + with pytest.raises(TypeError): + check_env(env) + # Return not only the observation check_reset_assert_error(env, (env.observation_space.sample(), False)) @@ -242,7 +257,7 @@ def test_common_failures_reset(): obs, _ = env.reset() - def wrong_reset(self): + def wrong_reset(self, seed=None): return {"img": obs["img"], "vec": obs["img"]}, {} env.reset = types.MethodType(wrong_reset, env) diff --git a/tests/test_logger.py b/tests/test_logger.py index 7a1c389ba..9d275a2ec 100644 --- a/tests/test_logger.py +++ b/tests/test_logger.py @@ -96,7 +96,7 @@ def read_fn(_format): tb_values_logged = [] for reservoir in [acc.scalars, acc.tensors, acc.images, acc.histograms, acc.compressed_histograms]: for k in reservoir.Keys(): - tb_values_logged.append(f"{k}: {str(reservoir.Items(k))}") + tb_values_logged.append(f"{k}: {reservoir.Items(k)!s}") content = LogContent(_format, tb_values_logged) return content @@ -353,7 +353,7 @@ def __init__(self, delay: float = 0.01): self.observation_space = spaces.Box(low=-20.0, high=20.0, shape=(4,), dtype=np.float32) self.action_space = spaces.Discrete(2) - def reset(self): + def reset(self, seed=None): return self.observation_space.sample(), {} def step(self, action): diff --git a/tests/test_predict.py b/tests/test_predict.py index 247fe9172..aac6b1667 100644 --- a/tests/test_predict.py +++ b/tests/test_predict.py @@ -30,7 +30,7 @@ def __init__(self): self.observation_space = SubClassedBox(-1, 1, shape=(2,), dtype=np.float32) self.action_space = SubClassedBox(-1, 1, shape=(2,), dtype=np.float32) - def reset(self): + def reset(self, seed=None): return self.observation_space.sample(), {} def step(self, action): diff --git a/tests/test_spaces.py b/tests/test_spaces.py index 6d18fcef8..fb70d0a33 100644 --- a/tests/test_spaces.py +++ b/tests/test_spaces.py @@ -47,14 +47,18 @@ def __init__(self): self.observation_space = spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32) self.action_space = spaces.Box(low=-1, high=1, shape=(2, 2), dtype=np.float32) - def reset(self): + def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None): + if seed is not None: + super().reset(seed=seed) return self.observation_space.sample(), {} def step(self, action): return self.observation_space.sample(), 0.0, False, False, {} -@pytest.mark.parametrize("env", [DummyMultiDiscreteSpace([4, 3]), DummyMultiBinary(8), DummyMultiBinary((3, 2))]) +@pytest.mark.parametrize( + "env", [DummyMultiDiscreteSpace([4, 3]), DummyMultiBinary(8), DummyMultiBinary((3, 2)), DummyMultidimensionalAction()] +) def test_env(env): # Check the env used for testing check_env(env, skip_render_check=True) diff --git a/tests/test_vec_check_nan.py b/tests/test_vec_check_nan.py index 1253be6e5..7efd94caa 100644 --- a/tests/test_vec_check_nan.py +++ b/tests/test_vec_check_nan.py @@ -9,7 +9,7 @@ class NanAndInfEnv(gym.Env): """Custom Environment that raised NaNs and Infs""" - metadata = {"render.modes": ["human"]} + metadata = {"render_modes": ["human"]} def __init__(self): super().__init__() @@ -27,7 +27,7 @@ def step(action): return [obs], 0.0, False, False, {} @staticmethod - def reset(): + def reset(seed=None): return [0.0], {} def render(self): diff --git a/tests/test_vec_envs.py b/tests/test_vec_envs.py index 6bc7e74db..36d848a89 100644 --- a/tests/test_vec_envs.py +++ b/tests/test_vec_envs.py @@ -172,7 +172,7 @@ def __init__(self, max_steps): self.max_steps = max_steps self.current_step = 0 - def reset(self): + def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None): self.current_step = 0 return np.array([self.current_step], dtype="int"), {} @@ -476,12 +476,9 @@ def make_monitored_env(): @pytest.mark.parametrize("vec_env_class", VEC_ENV_CLASSES) -def test_backward_compat_seed(vec_env_class): +def test_vec_deterministic(vec_env_class): def make_env(): env = CustomGymEnv(gym.spaces.Box(low=np.zeros(2), high=np.ones(2))) - # Patch reset function to remove seed param - env.reset = lambda: (env.observation_space.sample(), {}) - env.seed = env.observation_space.seed return env vec_env = vec_env_class([make_env for _ in range(N_ENVS)]) @@ -490,6 +487,15 @@ def make_env(): vec_env.seed(3) new_obs = vec_env.reset() assert np.allclose(new_obs, obs) + vec_env.close() + # Similar test but with make_vec_env + vec_env_1 = make_vec_env("Pendulum-v1", n_envs=N_ENVS, vec_env_cls=vec_env_class, seed=0) + vec_env_2 = make_vec_env("Pendulum-v1", n_envs=N_ENVS, vec_env_cls=vec_env_class, seed=0) + assert np.allclose(vec_env_1.reset(), vec_env_2.reset()) + random_actions = [vec_env_1.action_space.sample() for _ in range(N_ENVS)] + assert np.allclose(vec_env_1.step(random_actions)[0], vec_env_2.step(random_actions)[0]) + vec_env_1.close() + vec_env_2.close() @pytest.mark.parametrize("vec_env_class", VEC_ENV_CLASSES)