Commit 2cd69c0: [MLA-1172] Reduce calls to training_behaviors (#4259)

Author: Chris Elion
Parent: 400975b

5 files changed, +49 -20 lines

com.unity.ml-agents/CHANGELOG.md (4 additions, 1 deletion)

@@ -11,7 +11,10 @@ and this project adheres to
 ### Major Changes
 #### com.unity.ml-agents (C#)
 #### ml-agents / ml-agents-envs / gym-unity (Python)
-The minimum supported python version for ml-agents-envs was changed to 3.6.1. (#4244)
+- The minimum supported python version for ml-agents-envs was changed to 3.6.1. (#4244)
+- The interaction between EnvManager and TrainerController was changed; EnvManager.advance() was split into two stages,
+  and TrainerController now uses the results from the first stage to handle new behavior names. This change speeds up
+  Python training by approximately 5-10%. (#4259)

 ### Minor Changes
 #### com.unity.ml-agents (C#)
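
To make the two-stage split concrete, here is a small runnable sketch of the new call sequence. ToyEnvManager is a hypothetical stand-in used only for illustration; only the get_steps/process_steps names and the ordering come from this commit:

from typing import List

class ToyEnvManager:
    def get_steps(self) -> List[str]:
        # The real get_steps() updates policies, steps the environments,
        # and returns the EnvironmentSteps; the toy returns placeholders.
        return ["step_0", "step_1"]

    def process_steps(self, step_infos: List[str]) -> int:
        # The real process_steps() forwards the steps to the AgentProcessors.
        return len(step_infos)

env_manager = ToyEnvManager()
step_infos = env_manager.get_steps()
# The caller can inspect step_infos here (e.g. to spot new behavior names)
# before any processing happens, which is what lets TrainerController stop
# querying training_behaviors on every loop iteration.
num_steps = env_manager.process_steps(step_infos)
print(num_steps)  # 2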

ml-agents/mlagents/trainers/env_manager.py (16 additions, 7 deletions)

@@ -38,7 +38,7 @@ class EnvManager(ABC):
     def __init__(self):
         self.policies: Dict[BehaviorName, TFPolicy] = {}
         self.agent_managers: Dict[BehaviorName, AgentManager] = {}
-        self.first_step_infos: List[EnvironmentStep] = None
+        self.first_step_infos: List[EnvironmentStep] = []

     def set_policy(self, brain_name: BehaviorName, policy: TFPolicy) -> None:
         self.policies[brain_name] = policy
@@ -84,15 +84,20 @@ def training_behaviors(self) -> Dict[BehaviorName, BehaviorSpec]:
     def close(self):
         pass

-    def advance(self):
+    def get_steps(self) -> List[EnvironmentStep]:
+        """
+        Updates the policies, steps the environments, and returns the step information from the environments.
+        Calling code should pass the returned EnvironmentSteps to process_steps() after calling this.
+        :return: The list of EnvironmentSteps
+        """
         # If we had just reset, process the first EnvironmentSteps.
         # Note that we do it here instead of in reset() so that on the very first reset(),
         # we can create the needed AgentManagers before calling advance() and processing the EnvironmentSteps.
-        if self.first_step_infos is not None:
+        if self.first_step_infos:
             self._process_step_infos(self.first_step_infos)
-            self.first_step_infos = None
+            self.first_step_infos = []
         # Get new policies if found. Always get the latest policy.
-        for brain_name in self.training_behaviors:
+        for brain_name in self.agent_managers.keys():
             _policy = None
             try:
                 # We make sure to empty the policy queue before continuing to produce steps.
@@ -101,9 +106,13 @@ def advance(self):
                     _policy = self.agent_managers[brain_name].policy_queue.get_nowait()
             except AgentManagerQueue.Empty:
                 if _policy is not None:
-                    self.set_policy(brain_name, _policy)
-        # Step the environment
+                    # policy_queue contains Policy, but we need a TFPolicy here
+                    self.set_policy(brain_name, _policy)  # type: ignore
+        # Step the environments
         new_step_infos = self._step()
+        return new_step_infos
+
+    def process_steps(self, new_step_infos: List[EnvironmentStep]) -> int:
         # Add to AgentProcessor
         num_step_infos = self._process_step_infos(new_step_infos)
         return num_step_infos
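
A side note on the sentinel change above: initializing first_step_infos to an empty list instead of None keeps the attribute consistent with its List[EnvironmentStep] annotation and turns the pending-work guard into a plain truthiness test. A minimal illustration with toy types, not the real classes:

from typing import List

class ToyManager:
    def __init__(self) -> None:
        # [] means "nothing pending" and is always a valid list, so callers
        # never need to distinguish None from an empty list.
        self.first_step_infos: List[str] = []

    def drain_pending(self) -> List[str]:
        pending, self.first_step_infos = self.first_step_infos, []
        return pending

m = ToyManager()
m.first_step_infos = ["reset_step"]
print(m.drain_pending())  # ['reset_step']
print(m.drain_pending())  # []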

ml-agents/mlagents/trainers/tests/test_subprocess_env_manager.py (1 addition, 1 deletion)

@@ -166,7 +166,7 @@ def test_advance(self, mock_create_worker, training_behaviors_mock, step_mock):
         }
         step_info = EnvironmentStep(step_info_dict, 0, action_info_dict, env_stats)
         step_mock.return_value = [step_info]
-        env_manager.advance()
+        env_manager.process_steps(env_manager.get_steps())

         # Test add_experiences
         env_manager._step.assert_called_once()

ml-agents/mlagents/trainers/tests/test_trainer_controller.py (2 additions, 1 deletion)

@@ -141,6 +141,7 @@ def test_advance_adds_experiences_to_trainer_and_trains(
     tc.advance(env_mock)

     env_mock.reset.assert_not_called()
-    env_mock.advance.assert_called_once()
+    env_mock.get_steps.assert_called_once()
+    env_mock.process_steps.assert_called_once()
     # May have been called many times due to thread
     trainer_mock.advance.call_count > 0

ml-agents/mlagents/trainers/trainer_controller.py (26 additions, 10 deletions)

@@ -11,7 +11,7 @@
 from mlagents.tf_utils import tf

 from mlagents_envs.logging_util import get_logger
-from mlagents.trainers.env_manager import EnvManager
+from mlagents.trainers.env_manager import EnvManager, EnvironmentStep
 from mlagents_envs.exception import (
     UnityEnvironmentException,
     UnityCommunicationException,
@@ -59,6 +59,7 @@ def __init__(
         self.train_model = train
         self.param_manager = param_manager
         self.ghost_controller = self.trainer_factory.ghost_controller
+        self.registered_behavior_ids: Set[str] = set()

         self.trainer_threads: List[threading.Thread] = []
         self.kill_trainers = False
@@ -101,15 +102,17 @@ def _create_output_path(output_path):
         )

     @timed
-    def _reset_env(self, env: EnvManager) -> None:
+    def _reset_env(self, env_manager: EnvManager) -> None:
         """Resets the environment.

         Returns:
             A Data structure corresponding to the initial reset state of the
             environment.
         """
         new_config = self.param_manager.get_current_samplers()
-        env.reset(config=new_config)
+        env_manager.reset(config=new_config)
+        # Register any new behavior ids that were generated on the reset.
+        self._register_new_behaviors(env_manager, env_manager.first_step_infos)

     def _not_done_training(self) -> bool:
         return (
@@ -169,15 +172,10 @@ def _create_trainers_and_managers(
     def start_learning(self, env_manager: EnvManager) -> None:
         self._create_output_path(self.output_path)
         tf.reset_default_graph()
-        last_brain_behavior_ids: Set[str] = set()
         try:
             # Initial reset
             self._reset_env(env_manager)
             while self._not_done_training():
-                external_brain_behavior_ids = set(env_manager.training_behaviors.keys())
-                new_behavior_ids = external_brain_behavior_ids - last_brain_behavior_ids
-                self._create_trainers_and_managers(env_manager, new_behavior_ids)
-                last_brain_behavior_ids = external_brain_behavior_ids
                 n_steps = self.advance(env_manager)
                 for _ in range(n_steps):
                     self.reset_env_if_ready(env_manager)
@@ -233,10 +231,12 @@ def reset_env_if_ready(self, env: EnvManager) -> None:
             env.set_env_parameters(self.param_manager.get_current_samplers())

     @timed
-    def advance(self, env: EnvManager) -> int:
+    def advance(self, env_manager: EnvManager) -> int:
         # Get steps
         with hierarchical_timer("env_step"):
-            num_steps = env.advance()
+            new_step_infos = env_manager.get_steps()
+            self._register_new_behaviors(env_manager, new_step_infos)
+            num_steps = env_manager.process_steps(new_step_infos)

         # Report current lesson for each environment parameter
         for (
@@ -255,6 +255,22 @@ def advance(self, env: EnvManager) -> int:

         return num_steps

+    def _register_new_behaviors(
+        self, env_manager: EnvManager, step_infos: List[EnvironmentStep]
+    ) -> None:
+        """
+        Handle registration (adding trainers and managers) of new behavior ids.
+        :param env_manager:
+        :param step_infos:
+        :return:
+        """
+        step_behavior_ids: Set[str] = set()
+        for s in step_infos:
+            step_behavior_ids |= set(s.name_behavior_ids)
+        new_behavior_ids = step_behavior_ids - self.registered_behavior_ids
+        self._create_trainers_and_managers(env_manager, new_behavior_ids)
+        self.registered_behavior_ids |= step_behavior_ids
+
     def join_threads(self, timeout_seconds: float = 1.0) -> None:
         """
         Wait for threads to finish, and merge their timer information into the main thread.
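
The new _register_new_behaviors helper is plain set bookkeeping over the returned steps. A self-contained sketch of the same logic, with ToyStep standing in for EnvironmentStep (only the name_behavior_ids field is modeled):

from typing import List, Set

class ToyStep:
    def __init__(self, name_behavior_ids: List[str]) -> None:
        self.name_behavior_ids = name_behavior_ids

registered_behavior_ids: Set[str] = set()

def register_new_behaviors(step_infos: List[ToyStep]) -> Set[str]:
    # Collect every behavior id seen in this batch of steps...
    step_behavior_ids: Set[str] = set()
    for s in step_infos:
        step_behavior_ids |= set(s.name_behavior_ids)
    # ...then only the not-yet-registered ids need trainers/managers created.
    new_behavior_ids = step_behavior_ids - registered_behavior_ids
    registered_behavior_ids.update(step_behavior_ids)
    return new_behavior_ids

print(register_new_behaviors([ToyStep(["Walker"]), ToyStep(["Walker", "Crawler"])]))
# -> {'Walker', 'Crawler'} (set order may vary)
print(register_new_behaviors([ToyStep(["Walker"])]))
# -> set()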
