Skip to content

Commit 005ab7f

Browse files
Copilot and araffin committed
Refactor VecEnv checker: move to separate files and add documentation
Co-authored-by: araffin <1973948+araffin@users.noreply.github.com>
1 parent e3f2190 commit 005ab7f

File tree

6 files changed

+606
-446
lines changed

6 files changed

+606
-446
lines changed

docs/guide/vec_envs.rst

Lines changed: 53 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -183,6 +183,59 @@ This callback can then be used to safely modify environment attributes during tr
183183
it calls the environment setter method.
184184

185185

186+
Checking VecEnv Implementation
187+
------------------------------
188+
189+
When implementing custom vectorized environments, it's easy to make mistakes that can lead to hard-to-debug issues.
190+
To help with this, Stable-Baselines3 provides a ``check_vecenv`` function that validates your VecEnv implementation
191+
and checks for common issues.
192+
193+
The ``check_vecenv`` function verifies:
194+
195+
* The VecEnv properly inherits from ``stable_baselines3.common.vec_env.VecEnv``
196+
* Required attributes (``num_envs``, ``observation_space``, ``action_space``) are present and valid
197+
* The ``reset()`` method returns observations with the correct vectorized shape (batch dimension first)
198+
* The ``step()`` method returns properly shaped observations, rewards, dones, and infos
199+
* All return values have the expected types and dimensions
200+
* Compatibility with Stable-Baselines3 algorithms
201+
202+
**Usage:**
203+
204+
.. code-block:: python
205+
206+
from stable_baselines3.common.vec_env import DummyVecEnv
207+
from stable_baselines3.common.vec_env_checker import check_vecenv
208+
import gymnasium as gym
209+
210+
def make_env():
211+
return gym.make('CartPole-v1')
212+
213+
# Create your VecEnv
214+
vec_env = DummyVecEnv([make_env for _ in range(4)])
215+
216+
# Check the VecEnv implementation
217+
check_vecenv(vec_env, warn=True)
218+
219+
vec_env.close()
220+
221+
**When to use:**
222+
223+
* When implementing a custom VecEnv class
224+
* When debugging issues with vectorized environments
225+
* When contributing new VecEnv implementations to ensure they follow the API
226+
* As a sanity check before training to catch potential issues early
227+
228+
**Note:** Similar to ``check_env`` for single environments, ``check_vecenv`` is particularly useful during development
229+
and debugging. It helps catch common vectorization mistakes like incorrect batch dimensions, wrong return types, or
230+
missing required methods.
231+
232+
233+
VecEnv Checker
234+
~~~~~~~~~~~~~~
235+
236+
.. autofunction:: stable_baselines3.common.vec_env_checker.check_vecenv
237+
238+
186239
Vectorized Environments Wrappers
187240
--------------------------------
188241

stable_baselines3/common/env_checker.py

Lines changed: 0 additions & 255 deletions
Original file line number | Diff line number | Diff line change
@@ -7,7 +7,6 @@
77

88
from stable_baselines3.common.preprocessing import check_for_nested_spaces, is_image_space_channels_first
99
from stable_baselines3.common.vec_env import DummyVecEnv, VecCheckNan
10-
from stable_baselines3.common.vec_env.base_vec_env import VecEnv
1110

1211

1312
def _is_oneof_space(space: spaces.Space) -> bool:
@@ -538,257 +537,3 @@ def check_env(env: gym.Env, warn: bool = True, skip_render_check: bool = True) -
538537
_check_nan(env)
539538
except NotImplementedError:
540539
pass
541-
542-
543-
def _check_vecenv_spaces(vec_env: VecEnv) -> None:
544-
"""
545-
Check that the VecEnv has valid observation and action spaces.
546-
"""
547-
assert hasattr(vec_env, "observation_space"), "VecEnv must have an observation_space attribute"
548-
assert hasattr(vec_env, "action_space"), "VecEnv must have an action_space attribute"
549-
assert hasattr(vec_env, "num_envs"), "VecEnv must have a num_envs attribute"
550-
551-
assert isinstance(
552-
vec_env.observation_space, spaces.Space
553-
), "The observation space must inherit from gymnasium.spaces"
554-
assert isinstance(vec_env.action_space, spaces.Space), "The action space must inherit from gymnasium.spaces"
555-
assert isinstance(vec_env.num_envs, int) and vec_env.num_envs > 0, "num_envs must be a positive integer"
556-
557-
558-
def _check_vecenv_reset(vec_env: VecEnv) -> Any:
559-
"""
560-
Check that VecEnv reset method works correctly and returns properly shaped observations.
561-
"""
562-
try:
563-
obs = vec_env.reset()
564-
except Exception as e:
565-
raise RuntimeError(f"VecEnv reset() failed: {e}") from e
566-
567-
# Check observation shape matches expected vectorized shape
568-
if isinstance(vec_env.observation_space, spaces.Box):
569-
assert isinstance(obs, np.ndarray), f"For Box observation space, reset() must return np.ndarray, got {type(obs)}"
570-
expected_shape = (vec_env.num_envs,) + vec_env.observation_space.shape
571-
assert obs.shape == expected_shape, (
572-
f"Expected observation shape {expected_shape}, got {obs.shape}. "
573-
f"VecEnv observations should have batch dimension first."
574-
)
575-
elif isinstance(vec_env.observation_space, spaces.Dict):
576-
assert isinstance(obs, dict), f"For Dict observation space, reset() must return dict, got {type(obs)}"
577-
for key, space in vec_env.observation_space.spaces.items():
578-
assert key in obs, f"Missing key '{key}' in observation dict"
579-
if isinstance(space, spaces.Box):
580-
expected_shape = (vec_env.num_envs,) + space.shape
581-
assert obs[key].shape == expected_shape, (
582-
f"Expected observation['{key}'] shape {expected_shape}, got {obs[key].shape}"
583-
)
584-
elif isinstance(vec_env.observation_space, spaces.Discrete):
585-
assert isinstance(obs, np.ndarray), f"For Discrete observation space, reset() must return np.ndarray, got {type(obs)}"
586-
expected_shape = (vec_env.num_envs,)
587-
assert obs.shape == expected_shape, f"Expected observation shape {expected_shape}, got {obs.shape}"
588-
589-
return obs
590-
591-
592-
def _check_vecenv_step(vec_env: VecEnv, obs: Any) -> None:
593-
"""
594-
Check that VecEnv step method works correctly and returns properly shaped values.
595-
"""
596-
# Generate valid actions
597-
if isinstance(vec_env.action_space, spaces.Box):
598-
actions = np.array([vec_env.action_space.sample() for _ in range(vec_env.num_envs)])
599-
elif isinstance(vec_env.action_space, spaces.Discrete):
600-
actions = np.array([vec_env.action_space.sample() for _ in range(vec_env.num_envs)])
601-
elif isinstance(vec_env.action_space, spaces.MultiDiscrete):
602-
actions = np.array([vec_env.action_space.sample() for _ in range(vec_env.num_envs)])
603-
elif isinstance(vec_env.action_space, spaces.MultiBinary):
604-
actions = np.array([vec_env.action_space.sample() for _ in range(vec_env.num_envs)])
605-
else:
606-
actions = np.array([vec_env.action_space.sample() for _ in range(vec_env.num_envs)])
607-
608-
try:
609-
obs, rewards, dones, infos = vec_env.step(actions)
610-
except Exception as e:
611-
raise RuntimeError(f"VecEnv step() failed: {e}") from e
612-
613-
# Check rewards
614-
assert isinstance(rewards, np.ndarray), f"step() must return rewards as np.ndarray, got {type(rewards)}"
615-
assert rewards.shape == (vec_env.num_envs,), f"Expected rewards shape ({vec_env.num_envs},), got {rewards.shape}"
616-
617-
# Check dones
618-
assert isinstance(dones, np.ndarray), f"step() must return dones as np.ndarray, got {type(dones)}"
619-
assert dones.shape == (vec_env.num_envs,), f"Expected dones shape ({vec_env.num_envs},), got {dones.shape}"
620-
assert dones.dtype == bool, f"dones must have dtype bool, got {dones.dtype}"
621-
622-
# Check infos
623-
assert isinstance(infos, (list, tuple)), f"step() must return infos as list or tuple, got {type(infos)}"
624-
assert len(infos) == vec_env.num_envs, f"Expected infos length {vec_env.num_envs}, got {len(infos)}"
625-
for i, info in enumerate(infos):
626-
assert isinstance(info, dict), f"infos[{i}] must be dict, got {type(info)}"
627-
628-
# Check observation shape consistency (similar to reset)
629-
if isinstance(vec_env.observation_space, spaces.Box):
630-
assert isinstance(obs, np.ndarray), f"For Box observation space, step() must return np.ndarray, got {type(obs)}"
631-
expected_shape = (vec_env.num_envs,) + vec_env.observation_space.shape
632-
assert obs.shape == expected_shape, (
633-
f"Expected observation shape {expected_shape}, got {obs.shape}. "
634-
f"VecEnv observations should have batch dimension first."
635-
)
636-
elif isinstance(vec_env.observation_space, spaces.Dict):
637-
assert isinstance(obs, dict), f"For Dict observation space, step() must return dict, got {type(obs)}"
638-
for key, space in vec_env.observation_space.spaces.items():
639-
assert key in obs, f"Missing key '{key}' in observation dict"
640-
if isinstance(space, spaces.Box):
641-
expected_shape = (vec_env.num_envs,) + space.shape
642-
assert obs[key].shape == expected_shape, (
643-
f"Expected observation['{key}'] shape {expected_shape}, got {obs[key].shape}"
644-
)
645-
646-
647-
def _check_vecenv_unsupported_spaces(observation_space: spaces.Space, action_space: spaces.Space) -> bool:
648-
"""
649-
Emit warnings when the observation space or action space used is not supported by Stable-Baselines3
650-
for VecEnv. This is a VecEnv-specific version of _check_unsupported_spaces.
651-
652-
:return: True if return value tests should be skipped.
653-
"""
654-
should_skip = graph_space = sequence_space = False
655-
if isinstance(observation_space, spaces.Dict):
656-
nested_dict = False
657-
for key, space in observation_space.spaces.items():
658-
if isinstance(space, spaces.Dict):
659-
nested_dict = True
660-
elif isinstance(space, spaces.Graph):
661-
graph_space = True
662-
elif isinstance(space, spaces.Sequence):
663-
sequence_space = True
664-
_check_non_zero_start(space, "observation", key)
665-
666-
if nested_dict:
667-
warnings.warn(
668-
"Nested observation spaces are not supported by Stable Baselines3 "
669-
"(Dict spaces inside Dict space). "
670-
"You should flatten it to have only one level of keys."
671-
"For example, `dict(space1=dict(space2=Box(), space3=Box()), spaces4=Discrete())` "
672-
"is not supported but `dict(space2=Box(), spaces3=Box(), spaces4=Discrete())` is."
673-
)
674-
675-
if isinstance(observation_space, spaces.MultiDiscrete) and len(observation_space.nvec.shape) > 1:
676-
warnings.warn(
677-
f"The MultiDiscrete observation space uses a multidimensional array {observation_space.nvec} "
678-
"which is currently not supported by Stable-Baselines3. "
679-
"Please convert it to a 1D array using a wrapper: "
680-
"https://github.com/DLR-RM/stable-baselines3/issues/1836."
681-
)
682-
683-
if isinstance(observation_space, spaces.Tuple):
684-
warnings.warn(
685-
"The observation space is a Tuple, "
686-
"this is currently not supported by Stable Baselines3. "
687-
"However, you can convert it to a Dict observation space "
688-
"(cf. https://gymnasium.farama.org/api/spaces/composite/#dict). "
689-
"which is supported by SB3."
690-
)
691-
# Check for Sequence spaces inside Tuple
692-
for space in observation_space.spaces:
693-
if isinstance(space, spaces.Sequence):
694-
sequence_space = True
695-
elif isinstance(space, spaces.Graph):
696-
graph_space = True
697-
698-
# OneOf observation spaces are not supported: warn and skip the checks on returned values
699-
if _is_oneof_space(observation_space):
700-
warnings.warn(
701-
"OneOf observation space is not supported by Stable-Baselines3. "
702-
"Note: The checks for returned values are skipped."
703-
)
704-
should_skip = True
705-
706-
_check_non_zero_start(observation_space, "observation")
707-
708-
if isinstance(observation_space, spaces.Sequence) or sequence_space:
709-
warnings.warn(
710-
"Sequence observation space is not supported by Stable-Baselines3. "
711-
"You can pad your observation to have a fixed size instead.\n"
712-
"Note: The checks for returned values are skipped."
713-
)
714-
should_skip = True
715-
716-
if isinstance(observation_space, spaces.Graph) or graph_space:
717-
warnings.warn(
718-
"Graph observation space is not supported by Stable-Baselines3. "
719-
"Note: The checks for returned values are skipped."
720-
)
721-
should_skip = True
722-
723-
_check_non_zero_start(action_space, "action")
724-
725-
if not _is_numpy_array_space(action_space):
726-
warnings.warn(
727-
"The action space is not based off a numpy array. Typically this means it's either a Dict or Tuple space. "
728-
"This type of action space is currently not supported by Stable Baselines 3. You should try to flatten the "
729-
"action using a wrapper."
730-
)
731-
return should_skip
732-
733-
734-
def check_vecenv(vec_env: VecEnv, warn: bool = True) -> None:
735-
"""
736-
Check that a VecEnv follows the VecEnv API and is compatible with Stable-Baselines3.
737-
738-
This checker verifies that:
739-
- The VecEnv has proper observation_space, action_space, and num_envs attributes
740-
- The reset() method returns observations with correct vectorized shape
741-
- The step() method returns observations, rewards, dones, and infos with correct shapes
742-
- All return values have the expected types and dimensions
743-
744-
:param vec_env: The vectorized environment to check
745-
:param warn: Whether to output additional warnings mainly related to
746-
the interaction with Stable Baselines
747-
"""
748-
assert isinstance(vec_env, VecEnv), (
749-
"Your environment must inherit from stable_baselines3.common.vec_env.VecEnv"
750-
)
751-
752-
# ============= Check basic VecEnv attributes ================
753-
_check_vecenv_spaces(vec_env)
754-
755-
# Define aliases for convenience
756-
observation_space = vec_env.observation_space
757-
action_space = vec_env.action_space
758-
759-
# Warn the user if needed - reuse existing space checking logic
760-
if warn:
761-
should_skip = _check_vecenv_unsupported_spaces(observation_space, action_space)
762-
if should_skip:
763-
warnings.warn("VecEnv contains unsupported spaces, skipping some checks")
764-
return
765-
766-
obs_spaces = observation_space.spaces if isinstance(observation_space, spaces.Dict) else {"": observation_space}
767-
for key, space in obs_spaces.items():
768-
if isinstance(space, spaces.Box):
769-
_check_box_obs(space, key)
770-
771-
# Check for the action space
772-
if isinstance(action_space, spaces.Box) and (
773-
np.any(np.abs(action_space.low) != np.abs(action_space.high))
774-
or np.any(action_space.low != -1)
775-
or np.any(action_space.high != 1)
776-
):
777-
warnings.warn(
778-
"We recommend you to use a symmetric and normalized Box action space (range=[-1, 1]) "
779-
"cf. https://stable-baselines3.readthedocs.io/en/master/guide/rl_tips.html"
780-
)
781-
782-
if isinstance(action_space, spaces.Box):
783-
assert np.all(
784-
np.isfinite(np.array([action_space.low, action_space.high]))
785-
), "Continuous action space must have a finite lower and upper bound"
786-
787-
if isinstance(action_space, spaces.Box) and action_space.dtype != np.dtype(np.float32):
788-
warnings.warn(
789-
f"Your action space has dtype {action_space.dtype}, we recommend using np.float32 to avoid cast errors."
790-
)
791-
792-
# ============ Check the VecEnv methods ===============
793-
obs = _check_vecenv_reset(vec_env)
794-
_check_vecenv_step(vec_env, obs)

stable_baselines3/common/vec_env/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -13,6 +13,9 @@
1313
from stable_baselines3.common.vec_env.vec_transpose import VecTransposeImage
1414
from stable_baselines3.common.vec_env.vec_video_recorder import VecVideoRecorder
1515

16+
# Imported last to avoid a circular import (vec_env_checker presumably needs the VecEnv classes imported above)
17+
from stable_baselines3.common.vec_env_checker import check_vecenv
18+
1619
VecEnvWrapperT = TypeVar("VecEnvWrapperT", bound=VecEnvWrapper)
1720

1821

@@ -98,6 +101,7 @@ def sync_envs_normalization(env: VecEnv, eval_env: VecEnv) -> None:
98101
"VecNormalize",
99102
"VecTransposeImage",
100103
"VecVideoRecorder",
104+
"check_vecenv",
101105
"is_vecenv_wrapped",
102106
"sync_envs_normalization",
103107
"unwrap_vec_normalize",

0 commit comments

Comments
 (0)