diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 4cff7277f..8f08ccca4 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -4,7 +4,7 @@ Changelog ========== -Release 2.1.0a3 (WIP) +Release 2.1.0a4 (WIP) -------------------------- Breaking Changes: @@ -27,7 +27,8 @@ Bug Fixes: ^^^^^^^^^^ - Relaxed check in logger, that was causing issue on Windows with colorama - Fixed off-policy algorithms with continuous float64 actions (see #1145) (@tobirohrer) - +- Fixed env_checker.py warning messages for out of bounds in complex observation spaces (@Gabo-Tor) + Deprecations: ^^^^^^^^^^^^^ @@ -1398,7 +1399,7 @@ And all the contributors: @eleurent @ac-93 @cove9988 @theDebugger811 @hsuehch @Demetrio92 @thomasgubler @IperGiove @ScheiklP @simoninithomas @armandpl @manuel-delverme @Gautam-J @gianlucadecola @buoyancy99 @caburu @xy9485 @Gregwar @ycheng517 @quantitative-technologies @bcollazo @git-thor @TibiGG @cool-RR @MWeltevrede -@carlosluis @arjun-kg @tlpss @JonathanKuelz +@carlosluis @arjun-kg @tlpss @JonathanKuelz @Gabo-Tor @Melanol @qgallouedec @francescoluciano @jlp-ue @burakdmb @timothe-chaumont @honglu2875 @anand-bala @hughperkins @sidney-tio @AlexPasqua @dominicgkerr @Akhilez @Rocamonde @tobirohrer @ZikangXiong @DavyMorgan @luizapozzobon @Bonifatius94 @theSquaredError @harveybellini @DavyMorgan @FieteO @jonasreiher @npit @WeberSamuel @troiganto diff --git a/stable_baselines3/common/env_checker.py b/stable_baselines3/common/env_checker.py index 516d7ba61..8b8da7f44 100644 --- a/stable_baselines3/common/env_checker.py +++ b/stable_baselines3/common/env_checker.py @@ -203,18 +203,24 @@ def _check_obs(obs: Union[tuple, dict, np.ndarray, int], observation_space: spac f"Expected: {observation_space.dtype}, actual dtype: {obs.dtype}" ) if isinstance(observation_space, spaces.Box): - assert np.all(obs >= observation_space.low), ( - f"The observation returned by the `{method_name}()` method does not match the lower bound " - f"of the given observation space {observation_space}." - f"Expected: obs >= {np.min(observation_space.low)}, " - f"actual min value: {np.min(obs)} at index {np.argmin(obs)}" - ) - assert np.all(obs <= observation_space.high), ( - f"The observation returned by the `{method_name}()` method does not match the upper bound " - f"of the given observation space {observation_space}. " - f"Expected: obs <= {np.max(observation_space.high)}, " - f"actual max value: {np.max(obs)} at index {np.argmax(obs)}" - ) + lower_bounds, upper_bounds = observation_space.low, observation_space.high + # Expose all invalid indices at once + invalid_indices = np.where(np.logical_or(obs < lower_bounds, obs > upper_bounds)) + if (obs > upper_bounds).any() or (obs < lower_bounds).any(): + message = ( + f"The observation returned by the `{method_name}()` method does not match the bounds " + f"of the given observation space {observation_space}. \n" + ) + message += f"{len(invalid_indices[0])} invalid indices: \n" + + for index in zip(*invalid_indices): + index_str = ",".join(map(str, index)) + message += ( + f"Expected: {lower_bounds[index]} <= obs[{index_str}] <= {upper_bounds[index]}, " + f"actual value: {obs[index]} \n" + ) + + raise AssertionError(message) assert observation_space.contains(obs), ( f"The observation returned by the `{method_name}()` method " diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index a4a6a877e..736f6c84a 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -2.1.0a3 +2.1.0a4 diff --git a/tests/test_env_checker.py b/tests/test_env_checker.py index c0a5e0610..87cc177b7 100644 --- a/tests/test_env_checker.py +++ b/tests/test_env_checker.py @@ -42,15 +42,28 @@ def test_check_env_dict_action(): [ # Above upper bound ( - spaces.Box(low=0.0, high=1.0, shape=(3,), dtype=np.float32), + spaces.Box(low=np.array([0.0, 0.0, 0.0]), high=np.array([2.0, 1.0, 1.0]), shape=(3,), dtype=np.float32), np.array([1.0, 1.5, 0.5], dtype=np.float32), - r"Expected: obs <= 1\.0, actual max value: 1\.5 at index 1", + r"Expected: 0\.0 <= obs\[1] <= 1\.0, actual value: 1\.5", + ), + # Above upper bound (multi-dim) + ( + spaces.Box(low=-1.0, high=2.0, shape=(2, 3, 3, 1), dtype=np.float32), + 3.0 * np.ones((2, 3, 3, 1), dtype=np.float32), + # Note: this is one of the 18 invalid indices + r"Expected: -1\.0 <= obs\[1,2,1,0\] <= 2\.0, actual value: 3\.0", ), # Below lower bound ( - spaces.Box(low=0.0, high=2.0, shape=(3,), dtype=np.float32), + spaces.Box(low=np.array([0.0, -10.0, 0.0]), high=np.array([2.0, 1.0, 1.0]), shape=(3,), dtype=np.float32), np.array([-1.0, 1.5, 0.5], dtype=np.float32), - r"Expected: obs >= 0\.0, actual min value: -1\.0 at index 0", + r"Expected: 0\.0 <= obs\[0] <= 2\.0, actual value: -1\.0", + ), + # Below lower bound (multi-dim) + ( + spaces.Box(low=-1.0, high=2.0, shape=(2, 3, 3, 1), dtype=np.float32), + -2 * np.ones((2, 3, 3, 1), dtype=np.float32), + r"18 invalid indices:", ), # Wrong dtype ( @@ -111,7 +124,7 @@ def step(self, action): test_env = TestEnv() with pytest.raises(AssertionError, match=error_message): - check_env(env=test_env) + check_env(env=test_env, warn=False) class LimitedStepsTestEnv(gym.Env):