
Commit 17bacbb

Author: Ervin T
[refactor] Refactor Actor and Critic classes (#4287)
1 parent 38c3dd1 · commit 17bacbb

10 files changed: 617 additions, 143 deletions


ml-agents/mlagents/trainers/optimizer/torch_optimizer.py

Lines changed: 3 additions & 3 deletions
@@ -62,7 +62,7 @@ def get_value_estimates(
         """
         vec_vis_obs = SplitObservations.from_observations(decision_requests.obs)
 
-        value_estimates, mean_value = self.policy.actor_critic.critic_pass(
+        value_estimates = self.policy.actor_critic.critic_pass(
             np.expand_dims(vec_vis_obs.vector_observations[idx], 0),
             np.expand_dims(vec_vis_obs.visual_observations[idx], 0),
         )
@@ -97,11 +97,11 @@ def get_trajectory_value_estimates(
         next_obs = [ModelUtils.list_to_tensor(next_obs).unsqueeze(0)]
         next_memory = torch.zeros([1, 1, self.policy.m_size])
 
-        value_estimates, mean_value = self.policy.actor_critic.critic_pass(
+        value_estimates = self.policy.actor_critic.critic_pass(
             vector_obs, visual_obs, memory
         )
 
-        next_value_estimate, next_value = self.policy.actor_critic.critic_pass(
+        next_value_estimate = self.policy.actor_critic.critic_pass(
             next_obs, next_obs, next_memory
         )
 
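Note (illustration, not part of the commit): critic_pass now returns only the dict of per-reward-stream value estimates; the mean value that used to come back alongside it is gone, so a caller that still wants that aggregate computes it itself. A minimal sketch with invented stream names and values, not the library's actual return types:

# Illustrative only: stream names and numbers are made up for the sketch.
value_estimates = {"extrinsic": 1.2, "curiosity": 0.4}
mean_value = sum(value_estimates.values()) / len(value_estimates)  # what the old second return value carried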

ml-agents/mlagents/trainers/policy/torch_policy.py

Lines changed: 13 additions & 18 deletions
@@ -1,4 +1,4 @@
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List
 import numpy as np
 import torch
 
@@ -14,7 +14,8 @@
 
 from mlagents.trainers.settings import TrainerSettings, TestingConfiguration
 from mlagents.trainers.trajectory import SplitObservations
-from mlagents.trainers.torch.networks import ActorCritic
+from mlagents.trainers.torch.networks import SharedActorCritic, SeparateActorCritic
+from mlagents.trainers.torch.utils import ModelUtils
 
 EPSILON = 1e-7  # Small value to avoid divide by zero
 
@@ -29,8 +30,8 @@ def __init__(
         load: bool = False,
         tanh_squash: bool = False,
         reparameterize: bool = False,
+        separate_critic: bool = True,
         condition_sigma_on_obs: bool = True,
-        separate_critic: Optional[bool] = None,
     ):
         """
         Policy that uses a multilayer perceptron to map the observations to actions. Could
@@ -69,15 +70,16 @@ def __init__(
             "Losses/Value Loss": "value_loss",
             "Losses/Policy Loss": "policy_loss",
         }
-        self.actor_critic = ActorCritic(
+        if separate_critic:
+            ac_class = SeparateActorCritic
+        else:
+            ac_class = SharedActorCritic
+        self.actor_critic = ac_class(
             observation_shapes=self.behavior_spec.observation_shapes,
             network_settings=trainer_settings.network_settings,
             act_type=behavior_spec.action_type,
             act_size=self.act_size,
             stream_names=reward_signal_names,
-            separate_critic=separate_critic
-            if separate_critic is not None
-            else self.use_continuous_act,
             conditional_sigma=self.condition_sigma_on_obs,
             tanh_squash=tanh_squash,
         )
@@ -117,16 +119,11 @@ def sample_actions(
         """
         :param all_log_probs: Returns (for discrete actions) a tensor of log probs, one for each action.
         """
-        (
-            dists,
-            (value_heads, mean_value),
-            memories,
-        ) = self.actor_critic.get_dist_and_value(
+        dists, value_heads, memories = self.actor_critic.get_dist_and_value(
             vec_obs, vis_obs, masks, memories, seq_len
         )
-
         action_list = self.actor_critic.sample_action(dists)
-        log_probs, entropies, all_logs = self.actor_critic.get_probs_and_entropy(
+        log_probs, entropies, all_logs = ModelUtils.get_probs_and_entropy(
            action_list, dists
         )
         actions = torch.stack(action_list, dim=-1)
@@ -146,15 +143,13 @@ def sample_actions(
     def evaluate_actions(
         self, vec_obs, vis_obs, actions, masks=None, memories=None, seq_len=1
     ):
-        dists, (value_heads, mean_value), _ = self.actor_critic.get_dist_and_value(
+        dists, value_heads, _ = self.actor_critic.get_dist_and_value(
             vec_obs, vis_obs, masks, memories, seq_len
         )
         if len(actions.shape) <= 2:
             actions = actions.unsqueeze(-1)
         action_list = [actions[..., i] for i in range(actions.shape[2])]
-        log_probs, entropies, _ = self.actor_critic.get_probs_and_entropy(
-            action_list, dists
-        )
+        log_probs, entropies, _ = ModelUtils.get_probs_and_entropy(action_list, dists)
 
         return log_probs, entropies, value_heads
 
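Note (illustration, not part of the commit): the single ActorCritic class is replaced by two concrete classes, and the policy now picks one up front instead of passing a separate_critic flag into a single constructor; get_probs_and_entropy likewise moves from the actor-critic onto ModelUtils. A minimal sketch of the selection, reusing the import path shown in the diff; the helper function is purely illustrative:

from mlagents.trainers.torch.networks import SharedActorCritic, SeparateActorCritic


def choose_actor_critic(separate_critic: bool):
    # Assumed from the class names: SeparateActorCritic keeps an independent critic
    # body, while SharedActorCritic reuses one network body for policy and value.
    return SeparateActorCritic if separate_critic else SharedActorCritic


ac_class = choose_actor_critic(separate_critic=True)  # -> SeparateActorCritic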

ml-agents/mlagents/trainers/ppo/trainer.py

Lines changed: 1 addition & 0 deletions
@@ -233,6 +233,7 @@ def create_torch_policy(
             self.artifact_path,
             self.load,
             condition_sigma_on_obs=False,  # Faster training for PPO
+            separate_critic=behavior_spec.is_action_continuous(),
         )
         return policy
 

ml-agents/mlagents/trainers/tests/test_reward_signals.py

Lines changed: 2 additions & 2 deletions
@@ -4,7 +4,7 @@
 import mlagents.trainers.tests.mock_brain as mb
 from mlagents.trainers.policy.tf_policy import TFPolicy
 from mlagents.trainers.sac.optimizer import SACOptimizer
-from mlagents.trainers.ppo.optimizer import PPOOptimizer
+from mlagents.trainers.ppo.optimizer_tf import TFPPOOptimizer
 from mlagents.trainers.tests.test_simple_rl import PPO_CONFIG, SAC_CONFIG
 from mlagents.trainers.settings import (
     GAILSettings,
@@ -75,7 +75,7 @@ def create_optimizer_mock(
     if trainer_settings.trainer_type == TrainerType.SAC:
         optimizer = SACOptimizer(policy, trainer_settings)
     else:
-        optimizer = PPOOptimizer(policy, trainer_settings)
+        optimizer = TFPPOOptimizer(policy, trainer_settings)
     return optimizer
 
 

Lines changed: 208 additions & 0 deletions
@@ -0,0 +1,208 @@
+import pytest
+
+import torch
+from mlagents.trainers.torch.networks import (
+    NetworkBody,
+    ValueNetwork,
+    SimpleActor,
+    SharedActorCritic,
+    SeparateActorCritic,
+)
+from mlagents.trainers.settings import NetworkSettings
+from mlagents_envs.base_env import ActionType
+from mlagents.trainers.torch.distributions import (
+    GaussianDistInstance,
+    CategoricalDistInstance,
+)
+
+
+def test_networkbody_vector():
+    obs_size = 4
+    network_settings = NetworkSettings()
+    obs_shapes = [(obs_size,)]
+
+    networkbody = NetworkBody(obs_shapes, network_settings, encoded_act_size=2)
+    optimizer = torch.optim.Adam(networkbody.parameters(), lr=3e-3)
+    sample_obs = torch.ones((1, obs_size))
+    sample_act = torch.ones((1, 2))
+
+    for _ in range(100):
+        encoded, _ = networkbody([sample_obs], [], sample_act)
+        assert encoded.shape == (1, network_settings.hidden_units)
+        # Try to force output to 1
+        loss = torch.nn.functional.mse_loss(encoded, torch.ones(encoded.shape))
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+    # In the last step, values should be close to 1
+    for _enc in encoded.flatten():
+        assert _enc == pytest.approx(1.0, abs=0.1)
+
+
+def test_networkbody_lstm():
+    obs_size = 4
+    seq_len = 16
+    network_settings = NetworkSettings(
+        memory=NetworkSettings.MemorySettings(sequence_length=seq_len, memory_size=4)
+    )
+    obs_shapes = [(obs_size,)]
+
+    networkbody = NetworkBody(obs_shapes, network_settings)
+    optimizer = torch.optim.Adam(networkbody.parameters(), lr=3e-3)
+    sample_obs = torch.ones((1, seq_len, obs_size))
+
+    for _ in range(100):
+        encoded, _ = networkbody([sample_obs], [], memories=torch.ones(1, seq_len, 4))
+        # Try to force output to 1
+        loss = torch.nn.functional.mse_loss(encoded, torch.ones(encoded.shape))
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+    # In the last step, values should be close to 1
+    for _enc in encoded.flatten():
+        assert _enc == pytest.approx(1.0, abs=0.1)
+
+
+def test_networkbody_visual():
+    vec_obs_size = 4
+    obs_size = (84, 84, 3)
+    network_settings = NetworkSettings()
+    obs_shapes = [(vec_obs_size,), obs_size]
+    torch.random.manual_seed(0)
+
+    networkbody = NetworkBody(obs_shapes, network_settings)
+    optimizer = torch.optim.Adam(networkbody.parameters(), lr=3e-3)
+    sample_obs = torch.ones((1, 84, 84, 3))
+    sample_vec_obs = torch.ones((1, vec_obs_size))
+
+    for _ in range(100):
+        encoded, _ = networkbody([sample_vec_obs], [sample_obs])
+        assert encoded.shape == (1, network_settings.hidden_units)
+        # Try to force output to 1
+        loss = torch.nn.functional.mse_loss(encoded, torch.ones(encoded.shape))
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+    # In the last step, values should be close to 1
+    for _enc in encoded.flatten():
+        assert _enc == pytest.approx(1.0, abs=0.1)
+
+
+def test_valuenetwork():
+    obs_size = 4
+    num_outputs = 2
+    network_settings = NetworkSettings()
+    obs_shapes = [(obs_size,)]
+
+    stream_names = [f"stream_name{n}" for n in range(4)]
+    value_net = ValueNetwork(
+        stream_names, obs_shapes, network_settings, outputs_per_stream=num_outputs
+    )
+    optimizer = torch.optim.Adam(value_net.parameters(), lr=3e-3)
+
+    for _ in range(50):
+        sample_obs = torch.ones((1, obs_size))
+        values, _ = value_net([sample_obs], [])
+        loss = 0
+        for s_name in stream_names:
+            assert values[s_name].shape == (1, num_outputs)
+            # Try to force output to 1
+            loss += torch.nn.functional.mse_loss(
+                values[s_name], torch.ones((1, num_outputs))
+            )
+
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+    # In the last step, values should be close to 1
+    for value in values.values():
+        for _out in value:
+            assert _out[0] == pytest.approx(1.0, abs=0.1)
+
+
+@pytest.mark.parametrize("action_type", [ActionType.DISCRETE, ActionType.CONTINUOUS])
+def test_simple_actor(action_type):
+    obs_size = 4
+    network_settings = NetworkSettings()
+    obs_shapes = [(obs_size,)]
+    act_size = [2]
+    masks = None if action_type == ActionType.CONTINUOUS else torch.ones((1, 1))
+    actor = SimpleActor(obs_shapes, network_settings, action_type, act_size)
+    # Test get_dist
+    sample_obs = torch.ones((1, obs_size))
+    dists, _ = actor.get_dists([sample_obs], [], masks=masks)
+    for dist in dists:
+        if action_type == ActionType.CONTINUOUS:
+            assert isinstance(dist, GaussianDistInstance)
+        else:
+            assert isinstance(dist, CategoricalDistInstance)
+
+    # Test sample_actions
+    actions = actor.sample_action(dists)
+    for act in actions:
+        if action_type == ActionType.CONTINUOUS:
+            assert act.shape == (1, act_size[0])
+        else:
+            assert act.shape == (1, 1)
+
+    # Test forward
+    actions, probs, ver_num, mem_size, is_cont, act_size_vec = actor.forward(
+        [sample_obs], [], masks=masks
+    )
+    for act in actions:
+        if action_type == ActionType.CONTINUOUS:
+            assert act.shape == (
+                act_size[0],
+                1,
+            )  # This is different from above for ONNX export
+        else:
+            assert act.shape == (1, 1)
+
+    # TODO: Once export works properly. fix the shapes here.
+    assert mem_size == 0
+    assert is_cont == int(action_type == ActionType.CONTINUOUS)
+    assert act_size_vec == torch.tensor(act_size)
+
+
+@pytest.mark.parametrize("ac_type", [SharedActorCritic, SeparateActorCritic])
+@pytest.mark.parametrize("lstm", [True, False])
+def test_actor_critic(ac_type, lstm):
+    obs_size = 4
+    network_settings = NetworkSettings(
+        memory=NetworkSettings.MemorySettings() if lstm else None
+    )
+    obs_shapes = [(obs_size,)]
+    act_size = [2]
+    stream_names = [f"stream_name{n}" for n in range(4)]
+    actor = ac_type(
+        obs_shapes, network_settings, ActionType.CONTINUOUS, act_size, stream_names
+    )
+    if lstm:
+        sample_obs = torch.ones((1, network_settings.memory.sequence_length, obs_size))
+        memories = torch.ones(
+            (
+                1,
+                network_settings.memory.sequence_length,
+                network_settings.memory.memory_size,
+            )
+        )
+    else:
+        sample_obs = torch.ones((1, obs_size))
+        memories = None
+    # Test critic pass
+    value_out = actor.critic_pass([sample_obs], [], memories=memories)
+    for stream in stream_names:
+        if lstm:
+            assert value_out[stream].shape == (network_settings.memory.sequence_length,)
+        else:
+            assert value_out[stream].shape == (1,)
+
+    # Test get_dist_and_value
+    dists, value_out, _ = actor.get_dist_and_value([sample_obs], [], memories=memories)
+    for dist in dists:
+        assert isinstance(dist, GaussianDistInstance)
+    for stream in stream_names:
+        if lstm:
+            assert value_out[stream].shape == (network_settings.memory.sequence_length,)
+        else:
+            assert value_out[stream].shape == (1,)
Lines changed: 5 additions & 6 deletions
@@ -1,9 +1,11 @@
+from typing import List, Dict
+
 import torch
 from torch import nn
 
 
 class ValueHeads(nn.Module):
-    def __init__(self, stream_names, input_size, output_size=1):
+    def __init__(self, stream_names: List[str], input_size: int, output_size: int = 1):
         super().__init__()
         self.stream_names = stream_names
         _value_heads = {}
@@ -13,11 +15,8 @@ def __init__(self, stream_names, input_size, output_size=1):
             _value_heads[name] = value
         self.value_heads = nn.ModuleDict(_value_heads)
 
-    def forward(self, hidden):
+    def forward(self, hidden: torch.Tensor) -> Dict[str, torch.Tensor]:
         value_outputs = {}
         for stream_name, head in self.value_heads.items():
             value_outputs[stream_name] = head(hidden).squeeze(-1)
-        return (
-            value_outputs,
-            torch.mean(torch.stack(list(value_outputs.values())), dim=0),
-        )
+        return value_outputs
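
Note (illustration, not part of the commit): ValueHeads.forward now returns just the Dict[str, torch.Tensor] of per-stream values; the mean across streams, previously the second element of the returned tuple, is left to callers that need it. A small sketch with invented stream names:

import torch

value_outputs = {"extrinsic": torch.tensor([0.9]), "gail": torch.tensor([0.3])}  # illustrative
# The same expression the old forward() used for its second return value:
mean_value = torch.mean(torch.stack(list(value_outputs.values())), dim=0)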
