44import numpy as np
55from mlagents .envs .environment import UnityEnvironment
66from gym import error , spaces
7- from mlagents .envs .brain_conversion_utils import (
8- step_result_to_brain_info ,
9- group_spec_to_brain_parameters ,
10- )
117
128
139class UnityGymException (error .Error ):
@@ -78,16 +74,14 @@ def __init__(
7874
7975 self .brain_name = self ._env .get_agent_groups ()[0 ]
8076 self .name = self .brain_name
81- brain = group_spec_to_brain_parameters (
82- self .brain_name , self ._env .get_agent_group_spec (self .brain_name )
83- )
77+ self .group_spec = self ._env .get_agent_group_spec (self .brain_name )
8478
85- if use_visual and brain . number_visual_observations == 0 :
79+ if use_visual and self . _get_n_vis_obs () == 0 :
8680 raise UnityGymException (
8781 "`use_visual` was set to True, however there are no"
8882 " visual observations as part of this environment."
8983 )
90- self .use_visual = brain . number_visual_observations >= 1 and use_visual
84+ self .use_visual = self . _get_n_vis_obs () >= 1 and use_visual
9185
9286 if not use_visual and uint8_visual :
9387 logger .warning (
@@ -97,7 +91,7 @@ def __init__(
9791 else :
9892 self .uint8_visual = uint8_visual
9993
100- if brain . number_visual_observations > 1 and not self ._allow_multiple_visual_obs :
94+ if self . _get_n_vis_obs () > 1 and not self ._allow_multiple_visual_obs :
10195 logger .warning (
10296 "The environment contains more than one visual observation. "
10397 "You must define allow_multiple_visual_obs=True to received them all. "
@@ -106,41 +100,32 @@ def __init__(
106100
107101 # Check for number of agents in scene.
108102 self ._env .reset ()
109- initial_info = step_result_to_brain_info (
110- self ._env .get_step_result (self .brain_name ),
111- self ._env .get_agent_group_spec (self .brain_name ),
112- )
113- self ._check_agents (len (initial_info .agents ))
103+ step_result = self ._env .get_step_result (self .brain_name )
104+ self ._check_agents (step_result .n_agents ())
114105
115106 # Set observation and action spaces
116- if brain .vector_action_space_type == "discrete" :
117- if len (brain .vector_action_space_size ) == 1 :
118- self ._action_space = spaces .Discrete (brain .vector_action_space_size [0 ])
107+ if self .group_spec .is_action_discrete ():
108+ branches = self .group_spec .discrete_action_branches
109+ if self .group_spec .action_shape == 1 :
110+ self ._action_space = spaces .Discrete (branches [0 ])
119111 else :
120112 if flatten_branched :
121- self ._flattener = ActionFlattener (brain . vector_action_space_size )
113+ self ._flattener = ActionFlattener (branches )
122114 self ._action_space = self ._flattener .action_space
123115 else :
124- self ._action_space = spaces .MultiDiscrete (
125- brain .vector_action_space_size
126- )
116+ self ._action_space = spaces .MultiDiscrete (branches )
127117
128118 else :
129119 if flatten_branched :
130120 logger .warning (
131121 "The environment has a non-discrete action space. It will "
132122 "not be flattened."
133123 )
134- high = np .array ([1 ] * brain . vector_action_space_size [ 0 ] )
124+ high = np .array ([1 ] * self . group_spec . action_shape )
135125 self ._action_space = spaces .Box (- high , high , dtype = np .float32 )
136- high = np .array ([np .inf ] * brain .vector_observation_space_size )
137- self .action_meanings = brain .vector_action_descriptions
126+ high = np .array ([np .inf ] * self ._get_vec_obs_size ())
138127 if self .use_visual :
139- shape = (
140- brain .camera_resolutions [0 ].height ,
141- brain .camera_resolutions [0 ].width ,
142- brain .camera_resolutions [0 ].num_channels ,
143- )
128+ shape = self ._get_vis_obs_shape ()
144129 if uint8_visual :
145130 self ._observation_space = spaces .Box (
146131 0 , 255 , dtype = np .uint8 , shape = shape
@@ -160,11 +145,8 @@ def reset(self):
160145 space.
161146 """
162147 self ._env .reset ()
163- info = step_result_to_brain_info (
164- self ._env .get_step_result (self .brain_name ),
165- self ._env .get_agent_group_spec (self .brain_name ),
166- )
167- n_agents = len (info .agents )
148+ info = self ._env .get_step_result (self .brain_name )
149+ n_agents = info .n_agents ()
168150 self ._check_agents (n_agents )
169151 self .game_over = False
170152
@@ -211,14 +193,12 @@ def step(self, action):
211193 # Translate action into list
212194 action = self ._flattener .lookup_action (action )
213195
214- spec = self ._env . get_agent_group_spec ( self . brain_name )
196+ spec = self .group_spec
215197 action = np .array (action ).reshape ((self ._n_agents , spec .action_size ))
216198 self ._env .set_actions (self .brain_name , action )
217199 self ._env .step ()
218- info = step_result_to_brain_info (
219- self ._env .get_step_result (self .brain_name ), spec
220- )
221- n_agents = len (info .agents )
200+ info = self ._env .get_step_result (self .brain_name )
201+ n_agents = info .n_agents ()
222202 self ._check_agents (n_agents )
223203 self ._current_state = info
224204
@@ -232,7 +212,7 @@ def step(self, action):
232212
233213 def _single_step (self , info ):
234214 if self .use_visual :
235- visual_obs = info . visual_observations
215+ visual_obs = self . _get_vis_obs_list ( info )
236216
237217 if self ._allow_multiple_visual_obs :
238218 visual_obs_list = []
@@ -244,14 +224,9 @@ def _single_step(self, info):
244224
245225 default_observation = self .visual_obs
246226 else :
247- default_observation = info . vector_observations [0 , :]
227+ default_observation = self . _get_vector_obs ( info ) [0 , :]
248228
249- return (
250- default_observation ,
251- info .rewards [0 ],
252- info .local_done [0 ],
253- {"text_observation" : None , "brain_info" : info },
254- )
229+ return (default_observation , info .reward [0 ], info .done [0 ], info )
255230
256231 def _preprocess_single (self , single_visual_obs ):
257232 if self .uint8_visual :
@@ -261,16 +236,44 @@ def _preprocess_single(self, single_visual_obs):
261236
def _multi_step(self, info):
    """Build the (observations, rewards, dones, info) tuple for a multi-agent step.

    When ``self.use_visual`` is set, the visual observations extracted from
    ``info`` are passed through ``self._preprocess_multi`` and stored on
    ``self.visual_obs`` before being returned; otherwise the concatenated
    vector observations are returned.

    :param info: step result exposing ``reward`` and ``done`` iterables.
    :return: tuple of (list of observations, list of rewards, list of done
        flags, the unmodified ``info`` object).
    """
    if not self.use_visual:
        observations = self._get_vector_obs(info)
    else:
        visual_inputs = self._get_vis_obs_list(info)
        self.visual_obs = self._preprocess_multi(visual_inputs)
        observations = self.visual_obs

    # Materialise in the same order as the returned tuple.
    observation_list = list(observations)
    reward_list = list(info.reward)
    done_list = list(info.done)
    return (observation_list, reward_list, done_list, info)
244+
def _get_n_vis_obs(self) -> int:
    """Count the visual observations declared by ``self.group_spec``.

    An observation is counted as visual when its entry in
    ``group_spec.observation_shapes`` has exactly three dimensions.

    :return: number of three-dimensional observation shapes.
    """
    return sum(1 for dims in self.group_spec.observation_shapes if len(dims) == 3)
251+
def _get_vis_obs_shape(self):
    """Return the first three-dimensional shape in ``group_spec.observation_shapes``.

    :return: the first shape whose length is 3, or ``None`` when the spec
        declares no such shape.
    """
    candidates = (dims for dims in self.group_spec.observation_shapes if len(dims) == 3)
    return next(candidates, None)
256+
def _get_vis_obs_list(self, step_result):
    """Collect the rank-4 observations from a step result.

    :param step_result: object exposing ``obs``, an iterable of arrays.
    :return: list of the entries of ``step_result.obs`` whose ``shape`` has
        four dimensions, in their original order.
    """
    # NOTE(review): rank 4 is presumably (agent, height, width, channel) - confirm.
    return [entry for entry in step_result.obs if len(entry.shape) == 4]
263+
def _get_vector_obs(self, step_result):
    """Concatenate every rank-2 observation of a step result along axis 1.

    :param step_result: object exposing ``obs`` (an iterable of arrays) and
        ``n_agents()``.
    :return: a single 2-D array made of all rank-2 observations joined
        column-wise. When the step result holds no rank-2 observation, an
        empty ``float32`` array of shape ``(step_result.n_agents(), 0)`` is
        returned instead.
    """
    vector_obs = [obs for obs in step_result.obs if len(obs.shape) == 2]
    if not vector_obs:
        # np.concatenate raises ValueError on an empty sequence. A zero-width
        # array keeps row indexing such as ``[0, :]`` working.
        return np.zeros((step_result.n_agents(), 0), dtype=np.float32)
    return np.concatenate(vector_obs, axis=1)
270+
def _get_vec_obs_size(self) -> int:
    """Return the total length of all one-dimensional observation shapes.

    Only entries of ``group_spec.observation_shapes`` with exactly one
    dimension contribute; their single extent is summed.

    :return: combined size of the one-dimensional observations, 0 if none.
    """
    return sum(dims[0] for dims in self.group_spec.observation_shapes if len(dims) == 1)
274277
275278 def _preprocess_multi (self , multiple_visual_obs ):
276279 if self .uint8_visual :
@@ -291,9 +294,6 @@ def close(self):
291294 """
292295 self ._env .close ()
293296
294- def get_action_meanings (self ):
295- return self .action_meanings
296-
297297 def seed (self , seed = None ):
298298 """Sets the seed for this env's random number generator(s).
299299 Currently not implemented.
0 commit comments