forked from denisyarats/exorl
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgym_wrapper.py
148 lines (121 loc) · 4.56 KB
/
gym_wrapper.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
from typing import Any, Dict, List, Optional
import gym
import dm_env
import numpy as np
import tree
from . import types
from . import specs
from gym import spaces
class GymWrapper(dm_env.Environment):
    """Environment wrapper for OpenAI Gym environments.

    Adapts a `gym.Env` to the `dm_env.Environment` interface: `reset` and
    `step` return `dm_env.TimeStep` objects and the Gym spaces are exposed as
    specs. The `info` dict returned by the most recent `step` is retained and
    can be read back with `get_info`.
    """

    # Note: we don't inherit from base.EnvironmentWrapper because that class
    # assumes that the wrapped environment is a dm_env.Environment.

    def __init__(self, environment: gym.Env):
        """Wraps `environment` and builds its observation and action specs.

        Args:
          environment: The Gym environment to wrap.
        """
        self._environment = environment
        # Start in the "needs reset" state so that a first call to `step`
        # resets the environment instead of stepping it.
        self._reset_next_step = True
        self._last_info = None

        # Convert action and observation specs.
        obs_space = self._environment.observation_space
        act_space = self._environment.action_space
        # NOTE: the observation spec is built directly from the space's shape
        # and dtype instead of going through `_convert_to_spec`, so any bounds
        # on the observation space are not carried into the spec.
        self._observation_spec = specs.Array(
            shape=obs_space.shape, dtype=obs_space.dtype, name='observation')
        self._action_spec = _convert_to_spec(act_space, name='action')

    def reset(self) -> dm_env.TimeStep:
        """Resets the episode.

        Returns:
          The first `TimeStep` of the new episode, carrying the observation
          returned by the wrapped environment's `reset`.
        """
        self._reset_next_step = False
        observation = self._environment.reset()
        # Reset the diagnostic information.
        self._last_info = None
        return dm_env.restart(observation)

    def step(self, action: types.NestedArray) -> dm_env.TimeStep:
        """Steps the environment.

        If the previous step ended the episode, or `reset` has not been called
        yet, the environment is reset instead and `action` is ignored.

        Args:
          action: The action passed to the wrapped environment's `step`.

        Returns:
          A `TimeStep` for the transition: a truncation when the episode is
          done and `info['TimeLimit.truncated']` is truthy, a termination when
          it is done otherwise, and a mid-episode transition when not done.
        """
        if self._reset_next_step:
            return self.reset()

        observation, reward, done, info = self._environment.step(action)
        self._reset_next_step = done
        self._last_info = info

        # Convert the type of the reward based on the spec, respecting the
        # scalar or array property.
        reward = tree.map_structure(
            lambda x, t: (  # pylint: disable=g-long-lambda
                t.dtype.type(x)
                if np.isscalar(x) else np.asarray(x, dtype=t.dtype)),
            reward,
            self.reward_spec())

        if done:
            # A done flag accompanied by this info key is reported as a
            # truncation rather than a termination.
            truncated = info.get('TimeLimit.truncated', False)
            if truncated:
                return dm_env.truncation(reward, observation)
            return dm_env.termination(reward, observation)
        return dm_env.transition(reward, observation)

    def observation_spec(self) -> types.NestedSpec:
        """Returns the observation spec built in `__init__`."""
        return self._observation_spec

    def action_spec(self) -> types.NestedSpec:
        """Returns the action spec built in `__init__`."""
        return self._action_spec

    def get_info(self) -> Optional[Dict[str, Any]]:
        """Returns the last info returned from env.step(action).

        Returns:
          info: dictionary of diagnostic information from the last environment
            step, or None if no step has been taken since the last reset.
        """
        return self._last_info

    @property
    def environment(self) -> gym.Env:
        """Returns the wrapped environment."""
        return self._environment

    def __getattr__(self, name: str):
        """Forwards lookups of missing attributes to the wrapped environment.

        Python only calls this when normal attribute lookup on the wrapper
        fails.

        Args:
          name: Name of the attribute being looked up.

        Returns:
          The attribute of the same name on the wrapped environment.

        Raises:
          AttributeError: If `name` starts with a double underscore, or if the
            wrapper has no `_environment` to forward to.
        """
        if name.startswith('__'):
            raise AttributeError(
                "attempted to get missing private attribute '{}'".format(name))
        if name == '_environment':
            # Reaching here means `_environment` itself is not set (e.g. the
            # instance was created without running `__init__`). Forwarding
            # would look up `self._environment` again and recurse until
            # RecursionError, so fail with a regular AttributeError instead.
            raise AttributeError(
                "'{}' object has no attribute '{}'".format(
                    type(self).__name__, name))
        return getattr(self._environment, name)

    def close(self):
        """Closes the wrapped environment."""
        self._environment.close()
def _convert_to_spec(space: gym.Space,
                     name: Optional[str] = None) -> types.NestedSpec:
    """Builds the spec (or nested structure of specs) matching a Gym space.

    Mapping applied:
      * Discrete -> DiscreteArray with `space.n` values.
      * Box -> BoundedArray bounded by `space.low` and `space.high`.
      * MultiBinary -> BoundedArray bounded by 0 and 1.
      * MultiDiscrete -> BoundedArray bounded by 0 and `space.nvec - 1`.
      * Tuple -> tuple with each member space converted recursively.
      * Dict -> dict with each member space converted recursively.

    Args:
      space: The Gym space to convert.
      name: Optional name given to the returned spec(s). Members of a Dict
        space are named after their key instead.

    Returns:
      A spec, or a tuple/dict of specs, corresponding to the input space.

    Raises:
      ValueError: If `space` is not one of the space types listed above.
    """
    if isinstance(space, spaces.Discrete):
        return specs.DiscreteArray(
            num_values=space.n, dtype=space.dtype, name=name)

    # The three array-like spaces differ only in their bounds, so work those
    # out first and build the BoundedArray in one place.
    if isinstance(space, spaces.Box):
        bounds = (space.low, space.high)
    elif isinstance(space, spaces.MultiBinary):
        bounds = (0.0, 1.0)
    elif isinstance(space, spaces.MultiDiscrete):
        bounds = (np.zeros(space.shape), space.nvec - 1)
    else:
        bounds = None

    if bounds is not None:
        lower, upper = bounds
        return specs.BoundedArray(
            shape=space.shape,
            dtype=space.dtype,
            minimum=lower,
            maximum=upper,
            name=name)

    # Container spaces: convert every member recursively.
    if isinstance(space, spaces.Tuple):
        return tuple(_convert_to_spec(member, name) for member in space.spaces)

    if isinstance(space, spaces.Dict):
        converted = {}
        for key, member in space.spaces.items():
            converted[key] = _convert_to_spec(member, key)
        return converted

    raise ValueError('Unexpected gym space: {}'.format(space))
if __name__ == '__main__':
    # Smoke test: wrap CartPole, then print the wrapper and its first TimeStep.
    env = gym.make("CartPole-v1")
    env = GymWrapper(env)
    try:
        print(env)
        print(env.reset())
    finally:
        # Release the underlying Gym environment even if a print/reset fails.
        env.close()