
Commit f9273bb: Pytorch ghost trainer (#4370)
1 parent 71e7b17

File tree: 9 files changed, +217 -19 lines

ml-agents/mlagents/trainers/ghost/trainer.py

Lines changed: 7 additions & 5 deletions
@@ -304,7 +304,10 @@ def save_model(self) -> None:
         self.trainer.save_model()
 
     def create_policy(
-        self, parsed_behavior_id: BehaviorIdentifiers, behavior_spec: BehaviorSpec
+        self,
+        parsed_behavior_id: BehaviorIdentifiers,
+        behavior_spec: BehaviorSpec,
+        create_graph: bool = False,
     ) -> Policy:
         """
         Creates policy with the wrapped trainer's create_policy function
@@ -313,10 +316,10 @@ def create_policy(
         team are grouped. All policies associated with this team are added to the
         wrapped trainer to be trained.
         """
-        policy = self.trainer.create_policy(parsed_behavior_id, behavior_spec)
-        policy.create_tf_graph()
+        policy = self.trainer.create_policy(
+            parsed_behavior_id, behavior_spec, create_graph=True
+        )
         self.trainer.saver.initialize_or_load(policy)
-        policy.init_load_weights()
         team_id = parsed_behavior_id.team_id
         self.controller.subscribe_team_id(team_id, self)
 
@@ -326,7 +329,6 @@ def create_policy(
             parsed_behavior_id, behavior_spec
         )
         self.trainer.add_policy(parsed_behavior_id, internal_trainer_policy)
-        internal_trainer_policy.init_load_weights()
         self.current_policy_snapshot[
             parsed_behavior_id.brain_name
         ] = internal_trainer_policy.get_weights()
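Net effect: the ghost trainer no longer calls the TF-specific create_tf_graph() and init_load_weights() itself; it only requires that create_policy(..., create_graph=True) return a policy whose weights can be read and written right away. A minimal sketch of the snapshot pattern the self-play logic relies on, using a hypothetical StubPolicy rather than ML-Agents classes:

import copy

class StubPolicy:
    """Hypothetical stand-in exposing only the interface the ghost trainer needs."""
    def __init__(self):
        self.params = {"w": [0.0, 0.0]}

    def get_weights(self):
        # Return a deep copy so later training steps don't mutate the snapshot.
        return copy.deepcopy(self.params)

    def load_weights(self, values):
        self.params = copy.deepcopy(values)

learning_policy = StubPolicy()
ghost_policy = StubPolicy()

snapshot = learning_policy.get_weights()   # snapshot of the learning team
learning_policy.params["w"] = [1.0, 2.0]   # learning team keeps training
ghost_policy.load_weights(snapshot)        # opponent plays against the old snapshot
assert ghost_policy.get_weights() == {"w": [0.0, 0.0]}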

ml-agents/mlagents/trainers/policy/tf_policy.py

Lines changed: 2 additions & 0 deletions
@@ -152,6 +152,8 @@ def create_tf_graph(self) -> None:
         # We do an initialize to make the Policy usable out of the box. If an optimizer is needed,
         # it will re-load the full graph
         self.initialize()
+        # Create assignment ops for Ghost Trainer
+        self.init_load_weights()
 
     def _create_encoder(
         self,

ml-agents/mlagents/trainers/policy/torch_policy.py

Lines changed: 3 additions & 2 deletions
@@ -1,6 +1,7 @@
 from typing import Any, Dict, List, Tuple, Optional
 import numpy as np
 import torch
+import copy
 
 from mlagents.trainers.action_info import ActionInfo
 from mlagents.trainers.behavior_id_utils import get_global_agent_id
@@ -256,13 +257,13 @@ def increment_step(self, n_steps):
         return self.get_current_step()
 
     def load_weights(self, values: List[np.ndarray]) -> None:
-        pass
+        self.actor_critic.load_state_dict(values)
 
     def init_load_weights(self) -> None:
         pass
 
     def get_weights(self) -> List[np.ndarray]:
-        return []
+        return copy.deepcopy(self.actor_critic.state_dict())
 
     def get_modules(self):
         return {"Policy": self.actor_critic, "global_step": self.global_step}

ml-agents/mlagents/trainers/ppo/trainer.py

Lines changed: 6 additions & 1 deletion
@@ -210,18 +210,23 @@ def _update_policy(self):
         return True
 
     def create_tf_policy(
-        self, parsed_behavior_id: BehaviorIdentifiers, behavior_spec: BehaviorSpec
+        self,
+        parsed_behavior_id: BehaviorIdentifiers,
+        behavior_spec: BehaviorSpec,
+        create_graph: bool = False,
     ) -> TFPolicy:
         """
         Creates a PPO policy to trainers list of policies.
         :param behavior_spec: specifications for policy construction
+        :param create_graph: whether to create the graph when policy is constructed
         :return policy
         """
         policy = TFPolicy(
             self.seed,
             behavior_spec,
             self.trainer_settings,
             condition_sigma_on_obs=False,  # Faster training for PPO
+            create_tf_graph=create_graph,
         )
         return policy

ml-agents/mlagents/trainers/sac/trainer.py

Lines changed: 5 additions & 2 deletions
@@ -228,15 +228,18 @@ def maybe_load_replay_buffer(self):
         )
 
     def create_tf_policy(
-        self, parsed_behavior_id: BehaviorIdentifiers, behavior_spec: BehaviorSpec
+        self,
+        parsed_behavior_id: BehaviorIdentifiers,
+        behavior_spec: BehaviorSpec,
+        create_graph: bool = False,
     ) -> TFPolicy:
         policy = TFPolicy(
             self.seed,
             behavior_spec,
             self.trainer_settings,
             tanh_squash=True,
             reparameterize=True,
-            create_tf_graph=False,
+            create_tf_graph=create_graph,
         )
         self.maybe_load_replay_buffer()
         return policy
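In both TF trainers the graph-construction choice is now threaded through from the caller: create_graph defaults to False, so the normal training path (where the optimizer rebuilds the full graph anyway) keeps the lazy behaviour, while the ghost trainer passes True to get a policy that is usable immediately. A tiny, hypothetical illustration of this build-now-or-later default (LazyGraphPolicy is not ML-Agents code):

class LazyGraphPolicy:
    """Hypothetical example of deferring expensive graph construction behind a flag."""
    def __init__(self, create_graph: bool = False):
        self.graph = None
        if create_graph:
            self.build()          # caller wants a ready-to-use policy right away

    def build(self):
        self.graph = object()     # placeholder for expensive graph construction

# Default: the trainer/optimizer will call build() itself later.
deferred = LazyGraphPolicy()
assert deferred.graph is None

# Ghost-trainer path: build up front so weights can be read immediately.
ready = LazyGraphPolicy(create_graph=True)
assert ready.graph is not None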

ml-agents/mlagents/trainers/tests/test_ghost.py

Lines changed: 2 additions & 5 deletions
@@ -38,12 +38,9 @@ def test_load_and_set(dummy_config, use_discrete):
     trainer_params = dummy_config
     trainer = PPOTrainer("test", 0, trainer_params, True, False, 0, "0")
     trainer.seed = 1
-    policy = trainer.create_policy("test", mock_specs)
-    policy.create_tf_graph()
+    policy = trainer.create_policy("test", mock_specs, create_graph=True)
     trainer.seed = 20  # otherwise graphs are the same
-    to_load_policy = trainer.create_policy("test", mock_specs)
-    to_load_policy.create_tf_graph()
-    to_load_policy.init_load_weights()
+    to_load_policy = trainer.create_policy("test", mock_specs, create_graph=True)
 
     weights = policy.get_weights()
     load_weights = to_load_policy.get_weights()
Lines changed: 177 additions & 0 deletions
@@ -0,0 +1,177 @@
+import pytest
+
+import numpy as np
+
+from mlagents.trainers.ghost.trainer import GhostTrainer
+from mlagents.trainers.ghost.controller import GhostController
+from mlagents.trainers.behavior_id_utils import BehaviorIdentifiers
+from mlagents.trainers.ppo.trainer import PPOTrainer
+from mlagents.trainers.agent_processor import AgentManagerQueue
+from mlagents.trainers.tests import mock_brain as mb
+from mlagents.trainers.tests.test_trajectory import make_fake_trajectory
+from mlagents.trainers.settings import TrainerSettings, SelfPlaySettings, FrameworkType
+
+
+@pytest.fixture
+def dummy_config():
+    return TrainerSettings(
+        self_play=SelfPlaySettings(), framework=FrameworkType.PYTORCH
+    )
+
+
+VECTOR_ACTION_SPACE = 1
+VECTOR_OBS_SPACE = 8
+DISCRETE_ACTION_SPACE = [3, 3, 3, 2]
+BUFFER_INIT_SAMPLES = 513
+NUM_AGENTS = 12
+
+
+@pytest.mark.parametrize("use_discrete", [True, False])
+def test_load_and_set(dummy_config, use_discrete):
+    mock_specs = mb.setup_test_behavior_specs(
+        use_discrete,
+        False,
+        vector_action_space=DISCRETE_ACTION_SPACE
+        if use_discrete
+        else VECTOR_ACTION_SPACE,
+        vector_obs_space=VECTOR_OBS_SPACE,
+    )
+
+    trainer_params = dummy_config
+    trainer = PPOTrainer("test", 0, trainer_params, True, False, 0, "0")
+    trainer.seed = 1
+    policy = trainer.create_policy("test", mock_specs)
+    trainer.seed = 20  # otherwise graphs are the same
+    to_load_policy = trainer.create_policy("test", mock_specs)
+
+    weights = policy.get_weights()
+    load_weights = to_load_policy.get_weights()
+    try:
+        for w, lw in zip(weights, load_weights):
+            np.testing.assert_array_equal(w, lw)
+    except AssertionError:
+        pass
+
+    to_load_policy.load_weights(weights)
+    load_weights = to_load_policy.get_weights()
+
+    for w, lw in zip(weights, load_weights):
+        np.testing.assert_array_equal(w, lw)
+
+
+def test_process_trajectory(dummy_config):
+    mock_specs = mb.setup_test_behavior_specs(
+        True, False, vector_action_space=[2], vector_obs_space=1
+    )
+    behavior_id_team0 = "test_brain?team=0"
+    behavior_id_team1 = "test_brain?team=1"
+    brain_name = BehaviorIdentifiers.from_name_behavior_id(behavior_id_team0).brain_name
+
+    ppo_trainer = PPOTrainer(brain_name, 0, dummy_config, True, False, 0, "0")
+    controller = GhostController(100)
+    trainer = GhostTrainer(
+        ppo_trainer, brain_name, controller, 0, dummy_config, True, "0"
+    )
+
+    # first policy encountered becomes policy trained by wrapped PPO
+    parsed_behavior_id0 = BehaviorIdentifiers.from_name_behavior_id(behavior_id_team0)
+    policy = trainer.create_policy(parsed_behavior_id0, mock_specs)
+    trainer.add_policy(parsed_behavior_id0, policy)
+    trajectory_queue0 = AgentManagerQueue(behavior_id_team0)
+    trainer.subscribe_trajectory_queue(trajectory_queue0)
+
+    # Ghost trainer should ignore this queue because off policy
+    parsed_behavior_id1 = BehaviorIdentifiers.from_name_behavior_id(behavior_id_team1)
+    policy = trainer.create_policy(parsed_behavior_id1, mock_specs)
+    trainer.add_policy(parsed_behavior_id1, policy)
+    trajectory_queue1 = AgentManagerQueue(behavior_id_team1)
+    trainer.subscribe_trajectory_queue(trajectory_queue1)
+
+    time_horizon = 15
+    trajectory = make_fake_trajectory(
+        length=time_horizon,
+        max_step_complete=True,
+        observation_shapes=[(1,)],
+        action_space=[2],
+    )
+    trajectory_queue0.put(trajectory)
+    trainer.advance()
+
+    # Check that trainer put trajectory in update buffer
+    assert trainer.trainer.update_buffer.num_experiences == 15
+
+    trajectory_queue1.put(trajectory)
+    trainer.advance()
+
+    # Check that ghost trainer ignored off policy queue
+    assert trainer.trainer.update_buffer.num_experiences == 15
+    # Check that it emptied the queue
+    assert trajectory_queue1.empty()
+
+
+def test_publish_queue(dummy_config):
+    mock_specs = mb.setup_test_behavior_specs(
+        True, False, vector_action_space=[1], vector_obs_space=8
+    )
+
+    behavior_id_team0 = "test_brain?team=0"
+    behavior_id_team1 = "test_brain?team=1"
+
+    parsed_behavior_id0 = BehaviorIdentifiers.from_name_behavior_id(behavior_id_team0)
+
+    brain_name = parsed_behavior_id0.brain_name
+
+    ppo_trainer = PPOTrainer(brain_name, 0, dummy_config, True, False, 0, "0")
+    controller = GhostController(100)
+    trainer = GhostTrainer(
+        ppo_trainer, brain_name, controller, 0, dummy_config, True, "0"
+    )
+
+    # First policy encountered becomes policy trained by wrapped PPO
+    # This queue should remain empty after swap snapshot
+    policy = trainer.create_policy(parsed_behavior_id0, mock_specs)
+    trainer.add_policy(parsed_behavior_id0, policy)
+    policy_queue0 = AgentManagerQueue(behavior_id_team0)
+    trainer.publish_policy_queue(policy_queue0)
+
+    # Ghost trainer should use this queue for ghost policy swap
+    parsed_behavior_id1 = BehaviorIdentifiers.from_name_behavior_id(behavior_id_team1)
+    policy = trainer.create_policy(parsed_behavior_id1, mock_specs)
+    trainer.add_policy(parsed_behavior_id1, policy)
+    policy_queue1 = AgentManagerQueue(behavior_id_team1)
+    trainer.publish_policy_queue(policy_queue1)
+
+    # check ghost trainer swap pushes to ghost queue and not trainer
+    assert policy_queue0.empty() and policy_queue1.empty()
+    trainer._swap_snapshots()
+    assert policy_queue0.empty() and not policy_queue1.empty()
+    # clear
+    policy_queue1.get_nowait()
+
+    mock_specs = mb.setup_test_behavior_specs(
+        False,
+        False,
+        vector_action_space=VECTOR_ACTION_SPACE,
+        vector_obs_space=VECTOR_OBS_SPACE,
+    )
+
+    buffer = mb.simulate_rollout(BUFFER_INIT_SAMPLES, mock_specs)
+    # Mock out reward signal eval
+    buffer["extrinsic_rewards"] = buffer["environment_rewards"]
+    buffer["extrinsic_returns"] = buffer["environment_rewards"]
+    buffer["extrinsic_value_estimates"] = buffer["environment_rewards"]
+    buffer["curiosity_rewards"] = buffer["environment_rewards"]
+    buffer["curiosity_returns"] = buffer["environment_rewards"]
+    buffer["curiosity_value_estimates"] = buffer["environment_rewards"]
+    buffer["advantages"] = buffer["environment_rewards"]
+    trainer.trainer.update_buffer = buffer
+
+    # when ghost trainer advance and wrapped trainer buffers full
+    # the wrapped trainer pushes updated policy to correct queue
+    assert policy_queue0.empty() and policy_queue1.empty()
+    trainer.advance()
+    assert not policy_queue0.empty() and policy_queue1.empty()
+
+
+if __name__ == "__main__":
+    pytest.main()

ml-agents/mlagents/trainers/trainer/rl_trainer.py

Lines changed: 11 additions & 3 deletions
@@ -113,7 +113,10 @@ def _is_ready_update(self):
         return False
 
     def create_policy(
-        self, parsed_behavior_id: BehaviorIdentifiers, behavior_spec: BehaviorSpec
+        self,
+        parsed_behavior_id: BehaviorIdentifiers,
+        behavior_spec: BehaviorSpec,
+        create_graph: bool = False,
     ) -> Policy:
         if self.framework == FrameworkType.PYTORCH and TorchPolicy is None:
             raise UnityTrainerException(
@@ -122,7 +125,9 @@ def create_policy(
         elif self.framework == FrameworkType.PYTORCH:
             return self.create_torch_policy(parsed_behavior_id, behavior_spec)
         else:
-            return self.create_tf_policy(parsed_behavior_id, behavior_spec)
+            return self.create_tf_policy(
+                parsed_behavior_id, behavior_spec, create_graph=create_graph
+            )
 
     @abc.abstractmethod
     def create_torch_policy(
@@ -135,7 +140,10 @@ def create_torch_policy(
 
     @abc.abstractmethod
     def create_tf_policy(
-        self, parsed_behavior_id: BehaviorIdentifiers, behavior_spec: BehaviorSpec
+        self,
+        parsed_behavior_id: BehaviorIdentifiers,
+        behavior_spec: BehaviorSpec,
+        create_graph: bool = False,
     ) -> TFPolicy:
         """
         Create a Policy object that uses the TensorFlow backend.
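create_policy remains the single entry point the ghost trainer calls; only the TensorFlow branch forwards create_graph, since a Torch policy has no separate graph-building step. A compact, hypothetical sketch of that dispatch shape (free functions and lambdas stand in for the real RLTrainer methods):

from enum import Enum

class FrameworkType(Enum):
    TENSORFLOW = "tensorflow"
    PYTORCH = "pytorch"

def create_policy(framework, make_torch_policy, make_tf_policy, create_graph=False):
    """Hypothetical free-function version of the RLTrainer.create_policy dispatch."""
    if framework == FrameworkType.PYTORCH:
        # Torch modules are usable as soon as they are constructed.
        return make_torch_policy()
    # TF policies may defer graph construction; forward the caller's preference.
    return make_tf_policy(create_graph=create_graph)

policy = create_policy(
    FrameworkType.PYTORCH,
    make_torch_policy=lambda: "torch-policy",
    make_tf_policy=lambda create_graph: "tf-policy",
)
assert policy == "torch-policy"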

ml-agents/mlagents/trainers/trainer/trainer.py

Lines changed: 4 additions & 1 deletion
@@ -125,7 +125,10 @@ def end_episode(self):
 
     @abc.abstractmethod
     def create_policy(
-        self, parsed_behavior_id: BehaviorIdentifiers, behavior_spec: BehaviorSpec
+        self,
+        parsed_behavior_id: BehaviorIdentifiers,
+        behavior_spec: BehaviorSpec,
+        create_graph: bool = False,
     ) -> Policy:
         """
         Creates policy
