|
7 | 7 |
|
8 | 8 | from stable_baselines3.common.preprocessing import check_for_nested_spaces, is_image_space_channels_first
|
9 | 9 | from stable_baselines3.common.vec_env import DummyVecEnv, VecCheckNan
|
10 |
| -from stable_baselines3.common.vec_env.base_vec_env import VecEnv |
11 | 10 |
|
12 | 11 |
|
13 | 12 | def _is_oneof_space(space: spaces.Space) -> bool:
|
@@ -538,257 +537,3 @@ def check_env(env: gym.Env, warn: bool = True, skip_render_check: bool = True) -
|
538 | 537 | _check_nan(env)
|
539 | 538 | except NotImplementedError:
|
540 | 539 | pass
|
541 |
| - |
542 |
| - |
543 |
| -def _check_vecenv_spaces(vec_env: VecEnv) -> None: |
544 |
| - """ |
545 |
| - Check that the VecEnv has valid observation and action spaces. |
546 |
| - """ |
547 |
| - assert hasattr(vec_env, "observation_space"), "VecEnv must have an observation_space attribute" |
548 |
| - assert hasattr(vec_env, "action_space"), "VecEnv must have an action_space attribute" |
549 |
| - assert hasattr(vec_env, "num_envs"), "VecEnv must have a num_envs attribute" |
550 |
| - |
551 |
| - assert isinstance( |
552 |
| - vec_env.observation_space, spaces.Space |
553 |
| - ), "The observation space must inherit from gymnasium.spaces" |
554 |
| - assert isinstance(vec_env.action_space, spaces.Space), "The action space must inherit from gymnasium.spaces" |
555 |
| - assert isinstance(vec_env.num_envs, int) and vec_env.num_envs > 0, "num_envs must be a positive integer" |
556 |
| - |
557 |
| - |
558 |
| -def _check_vecenv_reset(vec_env: VecEnv) -> Any: |
559 |
| - """ |
560 |
| - Check that VecEnv reset method works correctly and returns properly shaped observations. |
561 |
| - """ |
562 |
| - try: |
563 |
| - obs = vec_env.reset() |
564 |
| - except Exception as e: |
565 |
| - raise RuntimeError(f"VecEnv reset() failed: {e}") from e |
566 |
| - |
567 |
| - # Check observation shape matches expected vectorized shape |
568 |
| - if isinstance(vec_env.observation_space, spaces.Box): |
569 |
| - assert isinstance(obs, np.ndarray), f"For Box observation space, reset() must return np.ndarray, got {type(obs)}" |
570 |
| - expected_shape = (vec_env.num_envs,) + vec_env.observation_space.shape |
571 |
| - assert obs.shape == expected_shape, ( |
572 |
| - f"Expected observation shape {expected_shape}, got {obs.shape}. " |
573 |
| - f"VecEnv observations should have batch dimension first." |
574 |
| - ) |
575 |
| - elif isinstance(vec_env.observation_space, spaces.Dict): |
576 |
| - assert isinstance(obs, dict), f"For Dict observation space, reset() must return dict, got {type(obs)}" |
577 |
| - for key, space in vec_env.observation_space.spaces.items(): |
578 |
| - assert key in obs, f"Missing key '{key}' in observation dict" |
579 |
| - if isinstance(space, spaces.Box): |
580 |
| - expected_shape = (vec_env.num_envs,) + space.shape |
581 |
| - assert obs[key].shape == expected_shape, ( |
582 |
| - f"Expected observation['{key}'] shape {expected_shape}, got {obs[key].shape}" |
583 |
| - ) |
584 |
| - elif isinstance(vec_env.observation_space, spaces.Discrete): |
585 |
| - assert isinstance(obs, np.ndarray), f"For Discrete observation space, reset() must return np.ndarray, got {type(obs)}" |
586 |
| - expected_shape = (vec_env.num_envs,) |
587 |
| - assert obs.shape == expected_shape, f"Expected observation shape {expected_shape}, got {obs.shape}" |
588 |
| - |
589 |
| - return obs |
590 |
| - |
591 |
| - |
592 |
| -def _check_vecenv_step(vec_env: VecEnv, obs: Any) -> None: |
593 |
| - """ |
594 |
| - Check that VecEnv step method works correctly and returns properly shaped values. |
595 |
| - """ |
596 |
| - # Generate valid actions |
597 |
| - if isinstance(vec_env.action_space, spaces.Box): |
598 |
| - actions = np.array([vec_env.action_space.sample() for _ in range(vec_env.num_envs)]) |
599 |
| - elif isinstance(vec_env.action_space, spaces.Discrete): |
600 |
| - actions = np.array([vec_env.action_space.sample() for _ in range(vec_env.num_envs)]) |
601 |
| - elif isinstance(vec_env.action_space, spaces.MultiDiscrete): |
602 |
| - actions = np.array([vec_env.action_space.sample() for _ in range(vec_env.num_envs)]) |
603 |
| - elif isinstance(vec_env.action_space, spaces.MultiBinary): |
604 |
| - actions = np.array([vec_env.action_space.sample() for _ in range(vec_env.num_envs)]) |
605 |
| - else: |
606 |
| - actions = np.array([vec_env.action_space.sample() for _ in range(vec_env.num_envs)]) |
607 |
| - |
608 |
| - try: |
609 |
| - obs, rewards, dones, infos = vec_env.step(actions) |
610 |
| - except Exception as e: |
611 |
| - raise RuntimeError(f"VecEnv step() failed: {e}") from e |
612 |
| - |
613 |
| - # Check rewards |
614 |
| - assert isinstance(rewards, np.ndarray), f"step() must return rewards as np.ndarray, got {type(rewards)}" |
615 |
| - assert rewards.shape == (vec_env.num_envs,), f"Expected rewards shape ({vec_env.num_envs},), got {rewards.shape}" |
616 |
| - |
617 |
| - # Check dones |
618 |
| - assert isinstance(dones, np.ndarray), f"step() must return dones as np.ndarray, got {type(dones)}" |
619 |
| - assert dones.shape == (vec_env.num_envs,), f"Expected dones shape ({vec_env.num_envs},), got {dones.shape}" |
620 |
| - assert dones.dtype == bool, f"dones must have dtype bool, got {dones.dtype}" |
621 |
| - |
622 |
| - # Check infos |
623 |
| - assert isinstance(infos, (list, tuple)), f"step() must return infos as list or tuple, got {type(infos)}" |
624 |
| - assert len(infos) == vec_env.num_envs, f"Expected infos length {vec_env.num_envs}, got {len(infos)}" |
625 |
| - for i, info in enumerate(infos): |
626 |
| - assert isinstance(info, dict), f"infos[{i}] must be dict, got {type(info)}" |
627 |
| - |
628 |
| - # Check observation shape consistency (similar to reset) |
629 |
| - if isinstance(vec_env.observation_space, spaces.Box): |
630 |
| - assert isinstance(obs, np.ndarray), f"For Box observation space, step() must return np.ndarray, got {type(obs)}" |
631 |
| - expected_shape = (vec_env.num_envs,) + vec_env.observation_space.shape |
632 |
| - assert obs.shape == expected_shape, ( |
633 |
| - f"Expected observation shape {expected_shape}, got {obs.shape}. " |
634 |
| - f"VecEnv observations should have batch dimension first." |
635 |
| - ) |
636 |
| - elif isinstance(vec_env.observation_space, spaces.Dict): |
637 |
| - assert isinstance(obs, dict), f"For Dict observation space, step() must return dict, got {type(obs)}" |
638 |
| - for key, space in vec_env.observation_space.spaces.items(): |
639 |
| - assert key in obs, f"Missing key '{key}' in observation dict" |
640 |
| - if isinstance(space, spaces.Box): |
641 |
| - expected_shape = (vec_env.num_envs,) + space.shape |
642 |
| - assert obs[key].shape == expected_shape, ( |
643 |
| - f"Expected observation['{key}'] shape {expected_shape}, got {obs[key].shape}" |
644 |
| - ) |
645 |
| - |
646 |
| - |
647 |
| -def _check_vecenv_unsupported_spaces(observation_space: spaces.Space, action_space: spaces.Space) -> bool: |
648 |
| - """ |
649 |
| - Emit warnings when the observation space or action space used is not supported by Stable-Baselines |
650 |
| - for VecEnv. This is a VecEnv-specific version of _check_unsupported_spaces. |
651 |
| -
|
652 |
| - :return: True if return value tests should be skipped. |
653 |
| - """ |
654 |
| - should_skip = graph_space = sequence_space = False |
655 |
| - if isinstance(observation_space, spaces.Dict): |
656 |
| - nested_dict = False |
657 |
| - for key, space in observation_space.spaces.items(): |
658 |
| - if isinstance(space, spaces.Dict): |
659 |
| - nested_dict = True |
660 |
| - elif isinstance(space, spaces.Graph): |
661 |
| - graph_space = True |
662 |
| - elif isinstance(space, spaces.Sequence): |
663 |
| - sequence_space = True |
664 |
| - _check_non_zero_start(space, "observation", key) |
665 |
| - |
666 |
| - if nested_dict: |
667 |
| - warnings.warn( |
668 |
| - "Nested observation spaces are not supported by Stable Baselines3 " |
669 |
| - "(Dict spaces inside Dict space). " |
670 |
| - "You should flatten it to have only one level of keys." |
671 |
| - "For example, `dict(space1=dict(space2=Box(), space3=Box()), spaces4=Discrete())` " |
672 |
| - "is not supported but `dict(space2=Box(), spaces3=Box(), spaces4=Discrete())` is." |
673 |
| - ) |
674 |
| - |
675 |
| - if isinstance(observation_space, spaces.MultiDiscrete) and len(observation_space.nvec.shape) > 1: |
676 |
| - warnings.warn( |
677 |
| - f"The MultiDiscrete observation space uses a multidimensional array {observation_space.nvec} " |
678 |
| - "which is currently not supported by Stable-Baselines3. " |
679 |
| - "Please convert it to a 1D array using a wrapper: " |
680 |
| - "https://github.com/DLR-RM/stable-baselines3/issues/1836." |
681 |
| - ) |
682 |
| - |
683 |
| - if isinstance(observation_space, spaces.Tuple): |
684 |
| - warnings.warn( |
685 |
| - "The observation space is a Tuple, " |
686 |
| - "this is currently not supported by Stable Baselines3. " |
687 |
| - "However, you can convert it to a Dict observation space " |
688 |
| - "(cf. https://gymnasium.farama.org/api/spaces/composite/#dict). " |
689 |
| - "which is supported by SB3." |
690 |
| - ) |
691 |
| - # Check for Sequence spaces inside Tuple |
692 |
| - for space in observation_space.spaces: |
693 |
| - if isinstance(space, spaces.Sequence): |
694 |
| - sequence_space = True |
695 |
| - elif isinstance(space, spaces.Graph): |
696 |
| - graph_space = True |
697 |
| - |
698 |
| - # Check for Sequence spaces inside OneOf |
699 |
| - if _is_oneof_space(observation_space): |
700 |
| - warnings.warn( |
701 |
| - "OneOf observation space is not supported by Stable-Baselines3. " |
702 |
| - "Note: The checks for returned values are skipped." |
703 |
| - ) |
704 |
| - should_skip = True |
705 |
| - |
706 |
| - _check_non_zero_start(observation_space, "observation") |
707 |
| - |
708 |
| - if isinstance(observation_space, spaces.Sequence) or sequence_space: |
709 |
| - warnings.warn( |
710 |
| - "Sequence observation space is not supported by Stable-Baselines3. " |
711 |
| - "You can pad your observation to have a fixed size instead.\n" |
712 |
| - "Note: The checks for returned values are skipped." |
713 |
| - ) |
714 |
| - should_skip = True |
715 |
| - |
716 |
| - if isinstance(observation_space, spaces.Graph) or graph_space: |
717 |
| - warnings.warn( |
718 |
| - "Graph observation space is not supported by Stable-Baselines3. " |
719 |
| - "Note: The checks for returned values are skipped." |
720 |
| - ) |
721 |
| - should_skip = True |
722 |
| - |
723 |
| - _check_non_zero_start(action_space, "action") |
724 |
| - |
725 |
| - if not _is_numpy_array_space(action_space): |
726 |
| - warnings.warn( |
727 |
| - "The action space is not based off a numpy array. Typically this means it's either a Dict or Tuple space. " |
728 |
| - "This type of action space is currently not supported by Stable Baselines 3. You should try to flatten the " |
729 |
| - "action using a wrapper." |
730 |
| - ) |
731 |
| - return should_skip |
732 |
| - |
733 |
| - |
734 |
| -def check_vecenv(vec_env: VecEnv, warn: bool = True) -> None: |
735 |
| - """ |
736 |
| - Check that a VecEnv follows the VecEnv API and is compatible with Stable-Baselines3. |
737 |
| - |
738 |
| - This checker verifies that: |
739 |
| - - The VecEnv has proper observation_space, action_space, and num_envs attributes |
740 |
| - - The reset() method returns observations with correct vectorized shape |
741 |
| - - The step() method returns observations, rewards, dones, and infos with correct shapes |
742 |
| - - All return values have the expected types and dimensions |
743 |
| - |
744 |
| - :param vec_env: The vectorized environment to check |
745 |
| - :param warn: Whether to output additional warnings mainly related to |
746 |
| - the interaction with Stable Baselines |
747 |
| - """ |
748 |
| - assert isinstance(vec_env, VecEnv), ( |
749 |
| - "Your environment must inherit from stable_baselines3.common.vec_env.VecEnv" |
750 |
| - ) |
751 |
| - |
752 |
| - # ============= Check basic VecEnv attributes ================ |
753 |
| - _check_vecenv_spaces(vec_env) |
754 |
| - |
755 |
| - # Define aliases for convenience |
756 |
| - observation_space = vec_env.observation_space |
757 |
| - action_space = vec_env.action_space |
758 |
| - |
759 |
| - # Warn the user if needed - reuse existing space checking logic |
760 |
| - if warn: |
761 |
| - should_skip = _check_vecenv_unsupported_spaces(observation_space, action_space) |
762 |
| - if should_skip: |
763 |
| - warnings.warn("VecEnv contains unsupported spaces, skipping some checks") |
764 |
| - return |
765 |
| - |
766 |
| - obs_spaces = observation_space.spaces if isinstance(observation_space, spaces.Dict) else {"": observation_space} |
767 |
| - for key, space in obs_spaces.items(): |
768 |
| - if isinstance(space, spaces.Box): |
769 |
| - _check_box_obs(space, key) |
770 |
| - |
771 |
| - # Check for the action space |
772 |
| - if isinstance(action_space, spaces.Box) and ( |
773 |
| - np.any(np.abs(action_space.low) != np.abs(action_space.high)) |
774 |
| - or np.any(action_space.low != -1) |
775 |
| - or np.any(action_space.high != 1) |
776 |
| - ): |
777 |
| - warnings.warn( |
778 |
| - "We recommend you to use a symmetric and normalized Box action space (range=[-1, 1]) " |
779 |
| - "cf. https://stable-baselines3.readthedocs.io/en/master/guide/rl_tips.html" |
780 |
| - ) |
781 |
| - |
782 |
| - if isinstance(action_space, spaces.Box): |
783 |
| - assert np.all( |
784 |
| - np.isfinite(np.array([action_space.low, action_space.high])) |
785 |
| - ), "Continuous action space must have a finite lower and upper bound" |
786 |
| - |
787 |
| - if isinstance(action_space, spaces.Box) and action_space.dtype != np.dtype(np.float32): |
788 |
| - warnings.warn( |
789 |
| - f"Your action space has dtype {action_space.dtype}, we recommend using np.float32 to avoid cast errors." |
790 |
| - ) |
791 |
| - |
792 |
| - # ============ Check the VecEnv methods =============== |
793 |
| - obs = _check_vecenv_reset(vec_env) |
794 |
| - _check_vecenv_step(vec_env, obs) |
0 commit comments