-
Notifications
You must be signed in to change notification settings - Fork 8.7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add a complete compatibility wrapper (#3066)
* Added a new compatibility wrapper along with tests * Fix for 3.6 typing * Fix for 3.6 typing again * Add make integration * Unrelated change that for some reason is necessary to fix pyright * Ignore weird (and very non-critical) type check bug * Adjust old tests * Rename the compatibility argument in make * Rename the compatibility argument in register and envspec * Documentation updates * Remove test envs from the registry * Some rogue renames * Add nicer str and repr to the compatibility layer * Reorder the compatibility layer application * Add metadata to test envs * Add proper handling of automatic human rendering * Add auto human rendering to reset * Enable setting render_mode in gym.make * Documentation update * Fix an unrelated stochastic test
- Loading branch information
1 parent
2f33096
commit d818750
Showing
8 changed files
with
299 additions
and
37 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,130 @@ | ||
"""A compatibility wrapper converting an old-style environment into a valid environment.""" | ||
import sys | ||
from typing import Any, Dict, Optional, Tuple | ||
|
||
import gym | ||
from gym.core import ObsType | ||
from gym.utils.step_api_compatibility import convert_to_terminated_truncated_step_api | ||
|
||
if sys.version_info >= (3, 8): | ||
from typing import Protocol, runtime_checkable | ||
elif sys.version_info >= (3, 7): | ||
from typing_extensions import Protocol, runtime_checkable | ||
else: | ||
Protocol = object | ||
runtime_checkable = lambda x: x # noqa: E731 | ||
|
||
|
||
@runtime_checkable | ||
class LegacyEnv(Protocol): | ||
"""A protocol for environments using the old step API.""" | ||
|
||
observation_space: gym.Space | ||
action_space: gym.Space | ||
|
||
def reset(self) -> Any: | ||
"""Reset the environment and return the initial observation.""" | ||
... | ||
|
||
def step(self, action: Any) -> Tuple[Any, float, bool, Dict]: | ||
"""Run one timestep of the environment's dynamics.""" | ||
... | ||
|
||
def render(self, mode: Optional[str] = "human") -> Any: | ||
"""Render the environment.""" | ||
... | ||
|
||
def close(self): | ||
"""Close the environment.""" | ||
... | ||
|
||
def seed(self, seed: Optional[int] = None): | ||
"""Set the seed for this env's random number generator(s).""" | ||
... | ||
|
||
|
||
class EnvCompatibility(gym.Env): | ||
r"""A wrapper which can transform an environment from the old API to the new API. | ||
Old step API refers to step() method returning (observation, reward, done, info), and reset() only retuning the observation. | ||
New step API refers to step() method returning (observation, reward, terminated, truncated, info) and reset() returning (observation, info). | ||
(Refer to docs for details on the API change) | ||
Known limitations: | ||
- Environments that use `self.np_random` might not work as expected. | ||
""" | ||
|
||
def __init__(self, old_env: LegacyEnv, render_mode: Optional[str] = None): | ||
"""A wrapper which converts old-style envs to valid modern envs. | ||
Some information may be lost in the conversion, so we recommend updating your environment. | ||
Args: | ||
old_env (LegacyEnv): the env to wrap, implemented with the old API | ||
render_mode (str): the render mode to use when rendering the environment, passed automatically to env.render | ||
""" | ||
self.metadata = getattr(old_env, "metadata", {"render_modes": []}) | ||
self.render_mode = render_mode | ||
self.reward_range = getattr(old_env, "reward_range", None) | ||
self.spec = getattr(old_env, "spec", None) | ||
self.env = old_env | ||
|
||
self.observation_space = old_env.observation_space | ||
self.action_space = old_env.action_space | ||
|
||
def reset( | ||
self, seed: Optional[int] = None, options: Optional[dict] = None | ||
) -> Tuple[ObsType, dict]: | ||
"""Resets the environment. | ||
Args: | ||
seed: the seed to reset the environment with | ||
options: the options to reset the environment with | ||
Returns: | ||
(observation, info) | ||
""" | ||
if seed is not None: | ||
self.env.seed(seed) | ||
# Options are ignored | ||
|
||
if self.render_mode == "human": | ||
self.render() | ||
|
||
return self.env.reset(), {} | ||
|
||
def step(self, action: Any) -> Tuple[Any, float, bool, bool, Dict]: | ||
"""Steps through the environment. | ||
Args: | ||
action: action to step through the environment with | ||
Returns: | ||
(observation, reward, terminated, truncated, info) | ||
""" | ||
obs, reward, done, info = self.env.step(action) | ||
|
||
if self.render_mode == "human": | ||
self.render() | ||
|
||
return convert_to_terminated_truncated_step_api((obs, reward, done, info)) | ||
|
||
def render(self) -> Any: | ||
"""Renders the environment. | ||
Returns: | ||
The rendering of the environment, depending on the render mode | ||
""" | ||
return self.env.render(mode=self.render_mode) | ||
|
||
def close(self): | ||
"""Closes the environment.""" | ||
self.env.close() | ||
|
||
def __str__(self): | ||
"""Returns the wrapper name and the unwrapped environment string.""" | ||
return f"<{type(self).__name__}{self.env}>" | ||
|
||
def __repr__(self): | ||
"""Returns the string representation of the wrapper.""" | ||
return str(self) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,130 @@ | ||
import sys | ||
from typing import Any, Dict, Optional, Tuple | ||
|
||
import numpy as np | ||
|
||
import gym | ||
from gym.spaces import Discrete | ||
from gym.wrappers.compatibility import EnvCompatibility, LegacyEnv | ||
|
||
|
||
class LegacyEnvExplicit(LegacyEnv, gym.Env): | ||
"""Legacy env that explicitly implements the old API.""" | ||
|
||
observation_space = Discrete(1) | ||
action_space = Discrete(1) | ||
metadata = {"render.modes": ["human", "rgb_array"]} | ||
|
||
def __init__(self): | ||
pass | ||
|
||
def reset(self): | ||
return 0 | ||
|
||
def step(self, action): | ||
return 0, 0, False, {} | ||
|
||
def render(self, mode="human"): | ||
if mode == "human": | ||
return | ||
elif mode == "rgb_array": | ||
return np.zeros((1, 1, 3), dtype=np.uint8) | ||
|
||
def close(self): | ||
pass | ||
|
||
def seed(self, seed=None): | ||
pass | ||
|
||
|
||
class LegacyEnvImplicit(gym.Env): | ||
"""Legacy env that implicitly implements the old API as a protocol.""" | ||
|
||
observation_space = Discrete(1) | ||
action_space = Discrete(1) | ||
metadata = {"render.modes": ["human", "rgb_array"]} | ||
|
||
def __init__(self): | ||
pass | ||
|
||
def reset(self): # type: ignore | ||
return 0 # type: ignore | ||
|
||
def step(self, action: Any) -> Tuple[int, float, bool, Dict]: | ||
return 0, 0.0, False, {} | ||
|
||
def render(self, mode: Optional[str] = "human") -> Any: | ||
if mode == "human": | ||
return | ||
elif mode == "rgb_array": | ||
return np.zeros((1, 1, 3), dtype=np.uint8) | ||
|
||
def close(self): | ||
pass | ||
|
||
def seed(self, seed: Optional[int] = None): | ||
pass | ||
|
||
|
||
def test_explicit(): | ||
old_env = LegacyEnvExplicit() | ||
assert isinstance(old_env, LegacyEnv) | ||
env = EnvCompatibility(old_env, render_mode="rgb_array") | ||
assert env.observation_space == Discrete(1) | ||
assert env.action_space == Discrete(1) | ||
assert env.reset() == (0, {}) | ||
assert env.reset(seed=0, options={"some": "option"}) == (0, {}) | ||
assert env.step(0) == (0, 0, False, False, {}) | ||
assert env.render().shape == (1, 1, 3) | ||
env.close() | ||
|
||
|
||
def test_implicit(): | ||
old_env = LegacyEnvImplicit() | ||
if sys.version_info >= (3, 7): | ||
# We need to give up on typing in Python 3.6 | ||
assert isinstance(old_env, LegacyEnv) | ||
env = EnvCompatibility(old_env, render_mode="rgb_array") | ||
assert env.observation_space == Discrete(1) | ||
assert env.action_space == Discrete(1) | ||
assert env.reset() == (0, {}) | ||
assert env.reset(seed=0, options={"some": "option"}) == (0, {}) | ||
assert env.step(0) == (0, 0, False, False, {}) | ||
assert env.render().shape == (1, 1, 3) | ||
env.close() | ||
|
||
|
||
def test_make_compatibility_in_spec(): | ||
gym.register( | ||
id="LegacyTestEnv-v0", | ||
entry_point=LegacyEnvExplicit, | ||
apply_api_compatibility=True, | ||
) | ||
env = gym.make("LegacyTestEnv-v0", render_mode="rgb_array") | ||
assert env.observation_space == Discrete(1) | ||
assert env.action_space == Discrete(1) | ||
assert env.reset() == (0, {}) | ||
assert env.reset(seed=0, options={"some": "option"}) == (0, {}) | ||
assert env.step(0) == (0, 0, False, False, {}) | ||
img = env.render() | ||
assert isinstance(img, np.ndarray) | ||
assert img.shape == (1, 1, 3) # type: ignore | ||
env.close() | ||
del gym.envs.registration.registry["LegacyTestEnv-v0"] | ||
|
||
|
||
def test_make_compatibility_in_make(): | ||
gym.register(id="LegacyTestEnv-v0", entry_point=LegacyEnvExplicit) | ||
env = gym.make( | ||
"LegacyTestEnv-v0", apply_api_compatibility=True, render_mode="rgb_array" | ||
) | ||
assert env.observation_space == Discrete(1) | ||
assert env.action_space == Discrete(1) | ||
assert env.reset() == (0, {}) | ||
assert env.reset(seed=0, options={"some": "option"}) == (0, {}) | ||
assert env.step(0) == (0, 0, False, False, {}) | ||
img = env.render() | ||
assert isinstance(img, np.ndarray) | ||
assert img.shape == (1, 1, 3) # type: ignore | ||
env.close() | ||
del gym.envs.registration.registry["LegacyTestEnv-v0"] |
Oops, something went wrong.