11
11
from mlagents .tf_utils import tf
12
12
13
13
from mlagents_envs .logging_util import get_logger
14
- from mlagents .trainers .env_manager import EnvManager
14
+ from mlagents .trainers .env_manager import EnvManager , EnvironmentStep
15
15
from mlagents_envs .exception import (
16
16
UnityEnvironmentException ,
17
17
UnityCommunicationException ,
@@ -59,6 +59,7 @@ def __init__(
59
59
self .train_model = train
60
60
self .param_manager = param_manager
61
61
self .ghost_controller = self .trainer_factory .ghost_controller
62
+ self .registered_behavior_ids : Set [str ] = set ()
62
63
63
64
self .trainer_threads : List [threading .Thread ] = []
64
65
self .kill_trainers = False
@@ -101,15 +102,17 @@ def _create_output_path(output_path):
101
102
)
102
103
103
104
@timed
104
- def _reset_env (self , env : EnvManager ) -> None :
105
+ def _reset_env (self , env_manager : EnvManager ) -> None :
105
106
"""Resets the environment.
106
107
107
108
Returns:
108
109
A Data structure corresponding to the initial reset state of the
109
110
environment.
110
111
"""
111
112
new_config = self .param_manager .get_current_samplers ()
112
- env .reset (config = new_config )
113
+ env_manager .reset (config = new_config )
114
+ # Register any new behavior ids that were generated on the reset.
115
+ self ._register_new_behaviors (env_manager , env_manager .first_step_infos )
113
116
114
117
def _not_done_training (self ) -> bool :
115
118
return (
@@ -169,15 +172,10 @@ def _create_trainers_and_managers(
169
172
def start_learning (self , env_manager : EnvManager ) -> None :
170
173
self ._create_output_path (self .output_path )
171
174
tf .reset_default_graph ()
172
- last_brain_behavior_ids : Set [str ] = set ()
173
175
try :
174
176
# Initial reset
175
177
self ._reset_env (env_manager )
176
178
while self ._not_done_training ():
177
- external_brain_behavior_ids = set (env_manager .training_behaviors .keys ())
178
- new_behavior_ids = external_brain_behavior_ids - last_brain_behavior_ids
179
- self ._create_trainers_and_managers (env_manager , new_behavior_ids )
180
- last_brain_behavior_ids = external_brain_behavior_ids
181
179
n_steps = self .advance (env_manager )
182
180
for _ in range (n_steps ):
183
181
self .reset_env_if_ready (env_manager )
@@ -233,10 +231,12 @@ def reset_env_if_ready(self, env: EnvManager) -> None:
233
231
env .set_env_parameters (self .param_manager .get_current_samplers ())
234
232
235
233
@timed
236
- def advance (self , env : EnvManager ) -> int :
234
+ def advance (self , env_manager : EnvManager ) -> int :
237
235
# Get steps
238
236
with hierarchical_timer ("env_step" ):
239
- num_steps = env .advance ()
237
+ new_step_infos = env_manager .get_steps ()
238
+ self ._register_new_behaviors (env_manager , new_step_infos )
239
+ num_steps = env_manager .process_steps (new_step_infos )
240
240
241
241
# Report current lesson for each environment parameter
242
242
for (
@@ -255,6 +255,22 @@ def advance(self, env: EnvManager) -> int:
255
255
256
256
return num_steps
257
257
258
+ def _register_new_behaviors (
259
+ self , env_manager : EnvManager , step_infos : List [EnvironmentStep ]
260
+ ) -> None :
261
+ """
262
+ Handle registration (adding trainers and managers) of new behaviors ids.
263
+ :param env_manager:
264
+ :param step_infos:
265
+ :return:
266
+ """
267
+ step_behavior_ids : Set [str ] = set ()
268
+ for s in step_infos :
269
+ step_behavior_ids |= set (s .name_behavior_ids )
270
+ new_behavior_ids = step_behavior_ids - self .registered_behavior_ids
271
+ self ._create_trainers_and_managers (env_manager , new_behavior_ids )
272
+ self .registered_behavior_ids |= step_behavior_ids
273
+
258
274
def join_threads (self , timeout_seconds : float = 1.0 ) -> None :
259
275
"""
260
276
Wait for threads to finish, and merge their timer information into the main thread.
0 commit comments