Skip to content

Commit

Permalink
Fix env checker bounds, expose all invalid indices at once (#1638)
Browse files Browse the repository at this point in the history
* Fix bug in env_checker.py bounds warning message

* Fix bug where the Gym environment checker does not output the correct warning message when the observation space has different lower and upper bounds for each dimension

* Update test_env_checker.py with more comprehensive tests

* Make naming consistent

* Update version

* Catch all invalid indices at once

---------

Co-authored-by: gabo_tor <gabriel0torre@gmail.com>
  • Loading branch information
araffin and Gabo-Tor authored Aug 2, 2023
1 parent d43400b commit 17f02a8
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 21 deletions.
7 changes: 4 additions & 3 deletions docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ Changelog
==========


Release 2.1.0a3 (WIP)
Release 2.1.0a4 (WIP)
--------------------------

Breaking Changes:
Expand All @@ -27,7 +27,8 @@ Bug Fixes:
^^^^^^^^^^
- Relaxed a check in the logger that was causing issues on Windows with colorama
- Fixed off-policy algorithms with continuous float64 actions (see #1145) (@tobirohrer)

- Fixed env_checker.py warning messages for out-of-bounds observations in complex observation spaces (@Gabo-Tor)

Deprecations:
^^^^^^^^^^^^^

Expand Down Expand Up @@ -1398,7 +1399,7 @@ And all the contributors:
@eleurent @ac-93 @cove9988 @theDebugger811 @hsuehch @Demetrio92 @thomasgubler @IperGiove @ScheiklP
@simoninithomas @armandpl @manuel-delverme @Gautam-J @gianlucadecola @buoyancy99 @caburu @xy9485
@Gregwar @ycheng517 @quantitative-technologies @bcollazo @git-thor @TibiGG @cool-RR @MWeltevrede
@carlosluis @arjun-kg @tlpss @JonathanKuelz
@carlosluis @arjun-kg @tlpss @JonathanKuelz @Gabo-Tor
@Melanol @qgallouedec @francescoluciano @jlp-ue @burakdmb @timothe-chaumont @honglu2875
@anand-bala @hughperkins @sidney-tio @AlexPasqua @dominicgkerr @Akhilez @Rocamonde @tobirohrer @ZikangXiong
@DavyMorgan @luizapozzobon @Bonifatius94 @theSquaredError @harveybellini @DavyMorgan @FieteO @jonasreiher @npit @WeberSamuel @troiganto
Expand Down
30 changes: 18 additions & 12 deletions stable_baselines3/common/env_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,18 +203,24 @@ def _check_obs(obs: Union[tuple, dict, np.ndarray, int], observation_space: spac
f"Expected: {observation_space.dtype}, actual dtype: {obs.dtype}"
)
if isinstance(observation_space, spaces.Box):
assert np.all(obs >= observation_space.low), (
f"The observation returned by the `{method_name}()` method does not match the lower bound "
f"of the given observation space {observation_space}."
f"Expected: obs >= {np.min(observation_space.low)}, "
f"actual min value: {np.min(obs)} at index {np.argmin(obs)}"
)
assert np.all(obs <= observation_space.high), (
f"The observation returned by the `{method_name}()` method does not match the upper bound "
f"of the given observation space {observation_space}. "
f"Expected: obs <= {np.max(observation_space.high)}, "
f"actual max value: {np.max(obs)} at index {np.argmax(obs)}"
)
lower_bounds, upper_bounds = observation_space.low, observation_space.high
# Expose all invalid indices at once
invalid_indices = np.where(np.logical_or(obs < lower_bounds, obs > upper_bounds))
if (obs > upper_bounds).any() or (obs < lower_bounds).any():
message = (
f"The observation returned by the `{method_name}()` method does not match the bounds "
f"of the given observation space {observation_space}. \n"
)
message += f"{len(invalid_indices[0])} invalid indices: \n"

for index in zip(*invalid_indices):
index_str = ",".join(map(str, index))
message += (
f"Expected: {lower_bounds[index]} <= obs[{index_str}] <= {upper_bounds[index]}, "
f"actual value: {obs[index]} \n"
)

raise AssertionError(message)

assert observation_space.contains(obs), (
f"The observation returned by the `{method_name}()` method "
Expand Down
2 changes: 1 addition & 1 deletion stable_baselines3/version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
2.1.0a3
2.1.0a4
23 changes: 18 additions & 5 deletions tests/test_env_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,28 @@ def test_check_env_dict_action():
[
# Above upper bound
(
spaces.Box(low=0.0, high=1.0, shape=(3,), dtype=np.float32),
spaces.Box(low=np.array([0.0, 0.0, 0.0]), high=np.array([2.0, 1.0, 1.0]), shape=(3,), dtype=np.float32),
np.array([1.0, 1.5, 0.5], dtype=np.float32),
r"Expected: obs <= 1\.0, actual max value: 1\.5 at index 1",
r"Expected: 0\.0 <= obs\[1] <= 1\.0, actual value: 1\.5",
),
# Above upper bound (multi-dim)
(
spaces.Box(low=-1.0, high=2.0, shape=(2, 3, 3, 1), dtype=np.float32),
3.0 * np.ones((2, 3, 3, 1), dtype=np.float32),
# Note: this is one of the 18 invalid indices
r"Expected: -1\.0 <= obs\[1,2,1,0\] <= 2\.0, actual value: 3\.0",
),
# Below lower bound
(
spaces.Box(low=0.0, high=2.0, shape=(3,), dtype=np.float32),
spaces.Box(low=np.array([0.0, -10.0, 0.0]), high=np.array([2.0, 1.0, 1.0]), shape=(3,), dtype=np.float32),
np.array([-1.0, 1.5, 0.5], dtype=np.float32),
r"Expected: obs >= 0\.0, actual min value: -1\.0 at index 0",
r"Expected: 0\.0 <= obs\[0] <= 2\.0, actual value: -1\.0",
),
# Below lower bound (multi-dim)
(
spaces.Box(low=-1.0, high=2.0, shape=(2, 3, 3, 1), dtype=np.float32),
-2 * np.ones((2, 3, 3, 1), dtype=np.float32),
r"18 invalid indices:",
),
# Wrong dtype
(
Expand Down Expand Up @@ -111,7 +124,7 @@ def step(self, action):

test_env = TestEnv()
with pytest.raises(AssertionError, match=error_message):
check_env(env=test_env)
check_env(env=test_env, warn=False)


class LimitedStepsTestEnv(gym.Env):
Expand Down

0 comments on commit 17f02a8

Please sign in to comment.