
Commit 7562a25

Gym no longer uses brain infos (#3060)
1 parent 1e5e422 commit 7562a25

4 files changed (+61, -64 lines)

gym-unity/README.md

Lines changed: 1 addition & 1 deletion
@@ -69,7 +69,7 @@ For more on using the gym interface, see our
   observations by using the `allow_multiple_visual_obs=True` option in the gym
   parameters. If set to `True`, you will receive a list of `observation` instead
   of only the first one.
-* All `BrainInfo` output from the environment can still be accessed from the
+* The `BatchedStepResult` output from the environment can still be accessed from the
   `info` provided by `env.step(action)`.
 * Stacked vector observations are not supported.
 * Environment registration for use with `gym.make()` is currently not supported.
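
For illustration, a minimal sketch (not part of this diff) of what the updated bullet describes: after this change the `info` returned by `env.step(action)` is the `BatchedStepResult` itself rather than a dict wrapping a `BrainInfo`. The environment path below is a placeholder.

    from gym_unity.envs import UnityEnv

    # Placeholder path to a built Unity environment binary.
    env = UnityEnv("path/to/UnityEnvironmentBinary", worker_id=0)
    obs = env.reset()
    obs, reward, done, info = env.step(env.action_space.sample())

    # `info` is the BatchedStepResult for the agent group, so its fields are
    # read directly instead of through info["brain_info"].
    print(info.reward[0], info.done[0])
    print([o.shape for o in info.obs])
    env.close()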

gym-unity/gym_unity/envs/__init__.py

Lines changed: 59 additions & 59 deletions
@@ -4,10 +4,6 @@
 import numpy as np
 from mlagents.envs.environment import UnityEnvironment
 from gym import error, spaces
-from mlagents.envs.brain_conversion_utils import (
-    step_result_to_brain_info,
-    group_spec_to_brain_parameters,
-)
 
 
 class UnityGymException(error.Error):
@@ -78,16 +74,14 @@ def __init__(
 
         self.brain_name = self._env.get_agent_groups()[0]
         self.name = self.brain_name
-        brain = group_spec_to_brain_parameters(
-            self.brain_name, self._env.get_agent_group_spec(self.brain_name)
-        )
+        self.group_spec = self._env.get_agent_group_spec(self.brain_name)
 
-        if use_visual and brain.number_visual_observations == 0:
+        if use_visual and self._get_n_vis_obs() == 0:
             raise UnityGymException(
                 "`use_visual` was set to True, however there are no"
                 " visual observations as part of this environment."
             )
-        self.use_visual = brain.number_visual_observations >= 1 and use_visual
+        self.use_visual = self._get_n_vis_obs() >= 1 and use_visual
 
         if not use_visual and uint8_visual:
             logger.warning(
@@ -97,7 +91,7 @@ def __init__(
         else:
             self.uint8_visual = uint8_visual
 
-        if brain.number_visual_observations > 1 and not self._allow_multiple_visual_obs:
+        if self._get_n_vis_obs() > 1 and not self._allow_multiple_visual_obs:
             logger.warning(
                 "The environment contains more than one visual observation. "
                 "You must define allow_multiple_visual_obs=True to received them all. "
@@ -106,41 +100,32 @@ def __init__(
 
         # Check for number of agents in scene.
         self._env.reset()
-        initial_info = step_result_to_brain_info(
-            self._env.get_step_result(self.brain_name),
-            self._env.get_agent_group_spec(self.brain_name),
-        )
-        self._check_agents(len(initial_info.agents))
+        step_result = self._env.get_step_result(self.brain_name)
+        self._check_agents(step_result.n_agents())
 
         # Set observation and action spaces
-        if brain.vector_action_space_type == "discrete":
-            if len(brain.vector_action_space_size) == 1:
-                self._action_space = spaces.Discrete(brain.vector_action_space_size[0])
+        if self.group_spec.is_action_discrete():
+            branches = self.group_spec.discrete_action_branches
+            if self.group_spec.action_shape == 1:
+                self._action_space = spaces.Discrete(branches[0])
             else:
                 if flatten_branched:
-                    self._flattener = ActionFlattener(brain.vector_action_space_size)
+                    self._flattener = ActionFlattener(branches)
                     self._action_space = self._flattener.action_space
                 else:
-                    self._action_space = spaces.MultiDiscrete(
-                        brain.vector_action_space_size
-                    )
+                    self._action_space = spaces.MultiDiscrete(branches)
 
         else:
             if flatten_branched:
                 logger.warning(
                     "The environment has a non-discrete action space. It will "
                     "not be flattened."
                 )
-            high = np.array([1] * brain.vector_action_space_size[0])
+            high = np.array([1] * self.group_spec.action_shape)
             self._action_space = spaces.Box(-high, high, dtype=np.float32)
-        high = np.array([np.inf] * brain.vector_observation_space_size)
-        self.action_meanings = brain.vector_action_descriptions
+        high = np.array([np.inf] * self._get_vec_obs_size())
         if self.use_visual:
-            shape = (
-                brain.camera_resolutions[0].height,
-                brain.camera_resolutions[0].width,
-                brain.camera_resolutions[0].num_channels,
-            )
+            shape = self._get_vis_obs_shape()
             if uint8_visual:
                 self._observation_space = spaces.Box(
                     0, 255, dtype=np.uint8, shape=shape
@@ -160,11 +145,8 @@ def reset(self):
         space.
         """
         self._env.reset()
-        info = step_result_to_brain_info(
-            self._env.get_step_result(self.brain_name),
-            self._env.get_agent_group_spec(self.brain_name),
-        )
-        n_agents = len(info.agents)
+        info = self._env.get_step_result(self.brain_name)
+        n_agents = info.n_agents()
         self._check_agents(n_agents)
         self.game_over = False
 
@@ -211,14 +193,12 @@ def step(self, action):
             # Translate action into list
             action = self._flattener.lookup_action(action)
 
-        spec = self._env.get_agent_group_spec(self.brain_name)
+        spec = self.group_spec
         action = np.array(action).reshape((self._n_agents, spec.action_size))
         self._env.set_actions(self.brain_name, action)
         self._env.step()
-        info = step_result_to_brain_info(
-            self._env.get_step_result(self.brain_name), spec
-        )
-        n_agents = len(info.agents)
+        info = self._env.get_step_result(self.brain_name)
+        n_agents = info.n_agents()
         self._check_agents(n_agents)
         self._current_state = info
 
@@ -232,7 +212,7 @@ def step(self, action):
 
     def _single_step(self, info):
         if self.use_visual:
-            visual_obs = info.visual_observations
+            visual_obs = self._get_vis_obs_list(info)
 
             if self._allow_multiple_visual_obs:
                 visual_obs_list = []
@@ -244,14 +224,9 @@ def _single_step(self, info):
 
             default_observation = self.visual_obs
         else:
-            default_observation = info.vector_observations[0, :]
+            default_observation = self._get_vector_obs(info)[0, :]
 
-        return (
-            default_observation,
-            info.rewards[0],
-            info.local_done[0],
-            {"text_observation": None, "brain_info": info},
-        )
+        return (default_observation, info.reward[0], info.done[0], info)
 
     def _preprocess_single(self, single_visual_obs):
         if self.uint8_visual:
@@ -261,16 +236,44 @@ def _preprocess_single(self, single_visual_obs):
 
     def _multi_step(self, info):
         if self.use_visual:
-            self.visual_obs = self._preprocess_multi(info.visual_observations)
+            self.visual_obs = self._preprocess_multi(self._get_vis_obs_list(info))
             default_observation = self.visual_obs
         else:
-            default_observation = info.vector_observations
-        return (
-            list(default_observation),
-            info.rewards,
-            info.local_done,
-            {"text_observation": None, "brain_info": info},
-        )
+            default_observation = self._get_vector_obs(info)
+        return (list(default_observation), list(info.reward), list(info.done), info)
+
+    def _get_n_vis_obs(self) -> int:
+        result = 0
+        for shape in self.group_spec.observation_shapes:
+            if len(shape) == 3:
+                result += 1
+        return result
+
+    def _get_vis_obs_shape(self):
+        for shape in self.group_spec.observation_shapes:
+            if len(shape) == 3:
+                return shape
+
+    def _get_vis_obs_list(self, step_result):
+        result = []
+        for obs in step_result.obs:
+            if len(obs.shape) == 4:
+                result += [obs]
+        return result
+
+    def _get_vector_obs(self, step_result):
+        result = []
+        for obs in step_result.obs:
+            if len(obs.shape) == 2:
+                result += [obs]
+        return np.concatenate(result, axis=1)
+
+    def _get_vec_obs_size(self) -> int:
+        result = 0
+        for shape in self.group_spec.observation_shapes:
+            if len(shape) == 1:
+                result += shape[0]
+        return result
 
     def _preprocess_multi(self, multiple_visual_obs):
         if self.uint8_visual:
@@ -291,9 +294,6 @@ def close(self):
         """
         self._env.close()
 
-    def get_action_meanings(self):
-        return self.action_meanings
-
     def seed(self, seed=None):
         """Sets the seed for this env's random number generator(s).
         Currently not implemented.
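
The new `_get_vector_obs` and `_get_vis_obs_list` helpers replace the old `BrainInfo` fields by sorting the step result's observation arrays by rank. A self-contained sketch of that convention (dummy data; the agent count and image shape are illustrative assumptions):

    import numpy as np

    # In a BatchedStepResult, vector observations arrive batched as 2-D arrays of
    # shape (n_agents, obs_size) and visual observations as 4-D arrays of shape
    # (n_agents, height, width, channels). Dummy observations for three agents:
    obs = [
        np.zeros((3, 8), dtype=np.float32),          # vector observation
        np.zeros((3, 84, 84, 3), dtype=np.float32),  # visual observation
    ]

    # Mirrors _get_vector_obs: concatenate every rank-2 array along the feature axis.
    vector_obs = np.concatenate([o for o in obs if len(o.shape) == 2], axis=1)
    # Mirrors _get_vis_obs_list: collect every rank-4 array.
    visual_obs = [o for o in obs if len(o.shape) == 4]

    print(vector_obs.shape)     # (3, 8)
    print(visual_obs[0].shape)  # (3, 84, 84, 3)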

gym-unity/gym_unity/tests/test_gym.py

Lines changed: 0 additions & 3 deletions
@@ -23,7 +23,6 @@ def test_gym_wrapper(mock_env):
     assert isinstance(obs, np.ndarray)
     assert isinstance(rew, float)
     assert isinstance(done, (bool, np.bool_))
-    assert isinstance(info, dict)
 
 
 @mock.patch("gym_unity.envs.UnityEnvironment")
@@ -42,7 +41,6 @@ def test_multi_agent(mock_env):
     assert isinstance(obs, list)
     assert isinstance(rew, list)
     assert isinstance(done, list)
-    assert isinstance(info, dict)
 
 
 @mock.patch("gym_unity.envs.UnityEnvironment")
@@ -81,7 +79,6 @@ def test_gym_wrapper_visual(mock_env, use_uint8):
     assert isinstance(obs, np.ndarray)
     assert isinstance(rew, float)
    assert isinstance(done, (bool, np.bool_))
-    assert isinstance(info, dict)
 
 
 # Helper methods

notebooks/getting-started-gym.ipynb

Lines changed: 1 addition & 1 deletion
@@ -63,7 +63,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "env_name = \"../envs/3DBall\" # Name of the Unity environment binary to launch\n",
+    "env_name = \"../envs/GridWorld\" # Name of the Unity environment binary to launch\n",
     "env = UnityEnv(env_name, worker_id=0, use_visual=True)\n",
     "\n",
     "# Examine environment parameters\n",
