Skip to content

Commit

Permalink
Add a complete compatibility wrapper (#3066)
Browse files Browse the repository at this point in the history
* Added a new compatibility wrapper along with tests

* Fix for 3.6 typing

* Fix for 3.6 typing again

* Add make integration

* Unrelated change that for some reason is necessary to fix pyright

* Ignore weird (and very non-critical) type check bug

* Adjust old tests

* Rename the compatibility argument in make

* Rename the compatibility argument in register and envspec

* Documentation updates

* Remove test envs from the registry

* Some rogue renames

* Add nicer str and repr to the compatibility layer

* Reorder the compatibility layer application

* Add metadata to test envs

* Add proper handling of automatic human rendering

* Add auto human rendering to reset

* Enable setting render_mode in gym.make

* Documentation update

* Fix an unrelated stochastic test
  • Loading branch information
RedTachyon authored Sep 6, 2022
1 parent 2f33096 commit d818750
Show file tree
Hide file tree
Showing 8 changed files with 299 additions and 37 deletions.
38 changes: 23 additions & 15 deletions gym/envs/registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@
HumanRendering,
OrderEnforcing,
RenderCollection,
StepAPICompatibility,
TimeLimit,
)
from gym.wrappers.compatibility import EnvCompatibility
from gym.wrappers.env_checker import PassiveEnvChecker

if sys.version_info < (3, 10):
Expand Down Expand Up @@ -141,7 +141,7 @@ class EnvSpec:
order_enforce: bool = field(default=True)
autoreset: bool = field(default=False)
disable_env_checker: bool = field(default=False)
apply_step_compatibility: bool = field(default=False)
apply_api_compatibility: bool = field(default=False)

# Environment arguments
kwargs: dict = field(default_factory=dict)
Expand Down Expand Up @@ -440,7 +440,7 @@ def register(
order_enforce: bool = True,
autoreset: bool = False,
disable_env_checker: bool = False,
apply_step_compatibility: bool = False,
apply_api_compatibility: bool = False,
**kwargs,
):
"""Register an environment with gym.
Expand All @@ -459,7 +459,7 @@ def register(
order_enforce: If to enable the order enforcer wrapper to ensure users run functions in the correct order
autoreset: If to add the autoreset wrapper such that reset does not need to be called.
disable_env_checker: If to disable the environment checker for the environment. Recommended to False.
apply_step_compatibility: If to apply the `StepAPICompatibility` wrapper.
apply_api_compatibility: If to apply the `StepAPICompatibility` wrapper.
**kwargs: arbitrary keyword arguments which are passed to the environment constructor
"""
global registry, current_namespace
Expand Down Expand Up @@ -490,7 +490,7 @@ def register(
order_enforce=order_enforce,
autoreset=autoreset,
disable_env_checker=disable_env_checker,
apply_step_compatibility=apply_step_compatibility,
apply_api_compatibility=apply_api_compatibility,
**kwargs,
)
_check_spec_register(new_spec)
Expand All @@ -503,7 +503,7 @@ def make(
id: Union[str, EnvSpec],
max_episode_steps: Optional[int] = None,
autoreset: bool = False,
apply_step_compatibility: Optional[bool] = None,
apply_api_compatibility: Optional[bool] = None,
disable_env_checker: Optional[bool] = None,
**kwargs,
) -> Env:
Expand All @@ -515,10 +515,10 @@ def make(
id: Name of the environment. Optionally, a module to import can be included, eg. 'module:Env-v0'
max_episode_steps: Maximum length of an episode (TimeLimit wrapper).
autoreset: Whether to automatically reset the environment after each episode (AutoResetWrapper).
apply_step_compatibility: Whether to wrap the environment with the `StepAPICompatibility` wrapper that
apply_api_compatibility: Whether to wrap the environment with the `StepAPICompatibility` wrapper that
converts the environment step from a done bool to return termination and truncation bools.
By default, the argument is None to which the environment specification `apply_step_compatibility` is used
which defaults to False. Otherwise, the value of `apply_step_compatibility` is used.
By default, the argument is None to which the environment specification `apply_api_compatibility` is used
which defaults to False. Otherwise, the value of `apply_api_compatibility` is used.
If `True`, the wrapper is applied otherwise, the wrapper is not applied.
disable_env_checker: If to run the env checker, None will default to the environment specification `disable_env_checker`
(which is by default False, running the environment checker),
Expand Down Expand Up @@ -628,6 +628,14 @@ def make(
f"The environment creator metadata doesn't include `render_modes`, contains: {list(env_creator.metadata.keys())}"
)

if apply_api_compatibility is True or (
apply_api_compatibility is None and spec_.apply_api_compatibility is True
):
# If we use the compatibility layer, we treat the render mode explicitly and don't pass it to the env creator
render_mode = _kwargs.pop("render_mode", None)
else:
render_mode = None

try:
env = env_creator(**_kwargs)
except TypeError as e:
Expand All @@ -648,18 +656,18 @@ def make(
spec_.kwargs = _kwargs
env.unwrapped.spec = spec_

# Add step API wrapper
if apply_api_compatibility is True or (
apply_api_compatibility is None and spec_.apply_api_compatibility is True
):
env = EnvCompatibility(env, render_mode)

# Run the environment checker as the lowest level wrapper
if disable_env_checker is False or (
disable_env_checker is None and spec_.disable_env_checker is False
):
env = PassiveEnvChecker(env)

# Add step API wrapper
if apply_step_compatibility is True or (
apply_step_compatibility is None and spec_.apply_step_compatibility is True
):
env = StepAPICompatibility(env, output_truncation_bool=True)

# Add the order enforcing wrapper
if spec_.order_enforce:
env = OrderEnforcing(env)
Expand Down
130 changes: 130 additions & 0 deletions gym/wrappers/compatibility.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
"""A compatibility wrapper converting an old-style environment into a valid environment."""
import sys
from typing import Any, Dict, Optional, Tuple

import gym
from gym.core import ObsType
from gym.utils.step_api_compatibility import convert_to_terminated_truncated_step_api

if sys.version_info >= (3, 8):
from typing import Protocol, runtime_checkable
elif sys.version_info >= (3, 7):
from typing_extensions import Protocol, runtime_checkable
else:
Protocol = object
runtime_checkable = lambda x: x # noqa: E731


@runtime_checkable
class LegacyEnv(Protocol):
"""A protocol for environments using the old step API."""

observation_space: gym.Space
action_space: gym.Space

def reset(self) -> Any:
"""Reset the environment and return the initial observation."""
...

def step(self, action: Any) -> Tuple[Any, float, bool, Dict]:
"""Run one timestep of the environment's dynamics."""
...

def render(self, mode: Optional[str] = "human") -> Any:
"""Render the environment."""
...

def close(self):
"""Close the environment."""
...

def seed(self, seed: Optional[int] = None):
"""Set the seed for this env's random number generator(s)."""
...


class EnvCompatibility(gym.Env):
r"""A wrapper which can transform an environment from the old API to the new API.
Old step API refers to step() method returning (observation, reward, done, info), and reset() only retuning the observation.
New step API refers to step() method returning (observation, reward, terminated, truncated, info) and reset() returning (observation, info).
(Refer to docs for details on the API change)
Known limitations:
- Environments that use `self.np_random` might not work as expected.
"""

def __init__(self, old_env: LegacyEnv, render_mode: Optional[str] = None):
"""A wrapper which converts old-style envs to valid modern envs.
Some information may be lost in the conversion, so we recommend updating your environment.
Args:
old_env (LegacyEnv): the env to wrap, implemented with the old API
render_mode (str): the render mode to use when rendering the environment, passed automatically to env.render
"""
self.metadata = getattr(old_env, "metadata", {"render_modes": []})
self.render_mode = render_mode
self.reward_range = getattr(old_env, "reward_range", None)
self.spec = getattr(old_env, "spec", None)
self.env = old_env

self.observation_space = old_env.observation_space
self.action_space = old_env.action_space

def reset(
self, seed: Optional[int] = None, options: Optional[dict] = None
) -> Tuple[ObsType, dict]:
"""Resets the environment.
Args:
seed: the seed to reset the environment with
options: the options to reset the environment with
Returns:
(observation, info)
"""
if seed is not None:
self.env.seed(seed)
# Options are ignored

if self.render_mode == "human":
self.render()

return self.env.reset(), {}

def step(self, action: Any) -> Tuple[Any, float, bool, bool, Dict]:
"""Steps through the environment.
Args:
action: action to step through the environment with
Returns:
(observation, reward, terminated, truncated, info)
"""
obs, reward, done, info = self.env.step(action)

if self.render_mode == "human":
self.render()

return convert_to_terminated_truncated_step_api((obs, reward, done, info))

def render(self) -> Any:
"""Renders the environment.
Returns:
The rendering of the environment, depending on the render mode
"""
return self.env.render(mode=self.render_mode)

def close(self):
"""Closes the environment."""
self.env.close()

def __str__(self):
"""Returns the wrapper name and the unwrapped environment string."""
return f"<{type(self).__name__}{self.env}>"

def __repr__(self):
"""Returns the string representation of the wrapper."""
return str(self)
2 changes: 1 addition & 1 deletion gym/wrappers/step_api_compatibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class StepAPICompatibility(gym.Wrapper):
>>> env = gym.make("CartPole-v1")
>>> env # wrapper not applied by default, set to new API
<TimeLimit<OrderEnforcing<PassiveEnvChecker<CartPoleEnv<CartPole-v1>>>>>
>>> env = gym.make("CartPole-v1", apply_step_compatibility=True) # set to old API
>>> env = gym.make("CartPole-v1", apply_api_compatibility=True) # set to old API
<StepAPICompatibility<TimeLimit<OrderEnforcing<PassiveEnvChecker<CartPoleEnv<CartPole-v1>>>>>>
>>> env = StepAPICompatibility(CustomEnv(), apply_step_compatibility=False) # manually using wrapper on unregistered envs
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ cloudpickle>=1.2.0
importlib_metadata>=4.8.0; python_version < '3.10'
gym_notices>=0.0.4
dataclasses==0.8; python_version == '3.6'
typing_extensions==4.3.0; python_version == '3.7'
opencv-python>=3.0
lz4>=3.1.0
matplotlib>=3.0
Expand Down
130 changes: 130 additions & 0 deletions tests/envs/test_compatibility.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
import sys
from typing import Any, Dict, Optional, Tuple

import numpy as np

import gym
from gym.spaces import Discrete
from gym.wrappers.compatibility import EnvCompatibility, LegacyEnv


class LegacyEnvExplicit(LegacyEnv, gym.Env):
"""Legacy env that explicitly implements the old API."""

observation_space = Discrete(1)
action_space = Discrete(1)
metadata = {"render.modes": ["human", "rgb_array"]}

def __init__(self):
pass

def reset(self):
return 0

def step(self, action):
return 0, 0, False, {}

def render(self, mode="human"):
if mode == "human":
return
elif mode == "rgb_array":
return np.zeros((1, 1, 3), dtype=np.uint8)

def close(self):
pass

def seed(self, seed=None):
pass


class LegacyEnvImplicit(gym.Env):
"""Legacy env that implicitly implements the old API as a protocol."""

observation_space = Discrete(1)
action_space = Discrete(1)
metadata = {"render.modes": ["human", "rgb_array"]}

def __init__(self):
pass

def reset(self): # type: ignore
return 0 # type: ignore

def step(self, action: Any) -> Tuple[int, float, bool, Dict]:
return 0, 0.0, False, {}

def render(self, mode: Optional[str] = "human") -> Any:
if mode == "human":
return
elif mode == "rgb_array":
return np.zeros((1, 1, 3), dtype=np.uint8)

def close(self):
pass

def seed(self, seed: Optional[int] = None):
pass


def test_explicit():
old_env = LegacyEnvExplicit()
assert isinstance(old_env, LegacyEnv)
env = EnvCompatibility(old_env, render_mode="rgb_array")
assert env.observation_space == Discrete(1)
assert env.action_space == Discrete(1)
assert env.reset() == (0, {})
assert env.reset(seed=0, options={"some": "option"}) == (0, {})
assert env.step(0) == (0, 0, False, False, {})
assert env.render().shape == (1, 1, 3)
env.close()


def test_implicit():
old_env = LegacyEnvImplicit()
if sys.version_info >= (3, 7):
# We need to give up on typing in Python 3.6
assert isinstance(old_env, LegacyEnv)
env = EnvCompatibility(old_env, render_mode="rgb_array")
assert env.observation_space == Discrete(1)
assert env.action_space == Discrete(1)
assert env.reset() == (0, {})
assert env.reset(seed=0, options={"some": "option"}) == (0, {})
assert env.step(0) == (0, 0, False, False, {})
assert env.render().shape == (1, 1, 3)
env.close()


def test_make_compatibility_in_spec():
gym.register(
id="LegacyTestEnv-v0",
entry_point=LegacyEnvExplicit,
apply_api_compatibility=True,
)
env = gym.make("LegacyTestEnv-v0", render_mode="rgb_array")
assert env.observation_space == Discrete(1)
assert env.action_space == Discrete(1)
assert env.reset() == (0, {})
assert env.reset(seed=0, options={"some": "option"}) == (0, {})
assert env.step(0) == (0, 0, False, False, {})
img = env.render()
assert isinstance(img, np.ndarray)
assert img.shape == (1, 1, 3) # type: ignore
env.close()
del gym.envs.registration.registry["LegacyTestEnv-v0"]


def test_make_compatibility_in_make():
gym.register(id="LegacyTestEnv-v0", entry_point=LegacyEnvExplicit)
env = gym.make(
"LegacyTestEnv-v0", apply_api_compatibility=True, render_mode="rgb_array"
)
assert env.observation_space == Discrete(1)
assert env.action_space == Discrete(1)
assert env.reset() == (0, {})
assert env.reset(seed=0, options={"some": "option"}) == (0, {})
assert env.step(0) == (0, 0, False, False, {})
img = env.render()
assert isinstance(img, np.ndarray)
assert img.shape == (1, 1, 3) # type: ignore
env.close()
del gym.envs.registration.registry["LegacyTestEnv-v0"]
Loading

0 comments on commit d818750

Please sign in to comment.