Commit 5ce6272

Author: Ervin T
[add-fire] Add LSTM to SAC, LSTM fixes and initializations (#4324)
1 parent 83e17bb commit 5ce6272
10 files changed: +263 -108 lines changed

ml-agents/mlagents/trainers/optimizer/torch_optimizer.py

Lines changed: 14 additions & 39 deletions
@@ -1,16 +1,15 @@
 from typing import Dict, Optional, Tuple, List
 import torch
 import numpy as np
-from mlagents_envs.base_env import DecisionSteps

 from mlagents.trainers.buffer import AgentBuffer
+from mlagents.trainers.trajectory import SplitObservations
 from mlagents.trainers.torch.components.bc.module import BCModule
 from mlagents.trainers.torch.components.reward_providers import create_reward_provider

 from mlagents.trainers.policy.torch_policy import TorchPolicy
 from mlagents.trainers.optimizer import Optimizer
 from mlagents.trainers.settings import TrainerSettings
-from mlagents.trainers.trajectory import SplitObservations
 from mlagents.trainers.torch.utils import ModelUtils


@@ -50,35 +49,6 @@ def create_reward_signals(self, reward_signal_configs):
             reward_signal, self.policy.behavior_spec, settings
         )

-    def get_value_estimates(
-        self, decision_requests: DecisionSteps, idx: int, done: bool
-    ) -> Dict[str, float]:
-        """
-        Generates value estimates for bootstrapping.
-        :param decision_requests:
-        :param idx: Index in BrainInfo of agent.
-        :param done: Whether or not this is the last element of the episode,
-        in which case the value estimate will be 0.
-        :return: The value estimate dictionary with key being the name of the reward signal
-        and the value the corresponding value estimate.
-        """
-        vec_vis_obs = SplitObservations.from_observations(decision_requests.obs)
-
-        value_estimates = self.policy.actor_critic.critic_pass(
-            np.expand_dims(vec_vis_obs.vector_observations[idx], 0),
-            np.expand_dims(vec_vis_obs.visual_observations[idx], 0),
-        )
-
-        value_estimates = {k: float(v) for k, v in value_estimates.items()}
-
-        # If we're done, reassign all of the value estimates that need terminal states.
-        if done:
-            for k in value_estimates:
-                if not self.reward_signals[k].ignore_done:
-                    value_estimates[k] = 0.0
-
-        return value_estimates
-
     def get_trajectory_value_estimates(
         self, batch: AgentBuffer, next_obs: List[np.ndarray], done: bool
     ) -> Tuple[Dict[str, np.ndarray], Dict[str, float]]:
@@ -93,18 +63,23 @@ def get_trajectory_value_estimates(
         else:
             visual_obs = []

-        memory = torch.zeros([1, len(vector_obs[0]), self.policy.m_size])
+        memory = torch.zeros([1, 1, self.policy.m_size])

-        next_obs = np.concatenate(next_obs, axis=-1)
-        next_obs = [ModelUtils.list_to_tensor(next_obs).unsqueeze(0)]
-        next_memory = torch.zeros([1, 1, self.policy.m_size])
+        vec_vis_obs = SplitObservations.from_observations(next_obs)
+        next_vec_obs = [
+            ModelUtils.list_to_tensor(vec_vis_obs.vector_observations).unsqueeze(0)
+        ]
+        next_vis_obs = [
+            ModelUtils.list_to_tensor(_vis_ob).unsqueeze(0)
+            for _vis_ob in vec_vis_obs.visual_observations
+        ]

-        value_estimates = self.policy.actor_critic.critic_pass(
-            vector_obs, visual_obs, memory
+        value_estimates, next_memory = self.policy.actor_critic.critic_pass(
+            vector_obs, visual_obs, memory, sequence_length=batch.num_experiences
         )

-        next_value_estimate = self.policy.actor_critic.critic_pass(
-            next_obs, next_obs, next_memory
+        next_value_estimate, _ = self.policy.actor_critic.critic_pass(
+            next_vec_obs, next_vis_obs, next_memory, sequence_length=1
         )

         for name, estimate in value_estimates.items():
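
Note on the change above: get_trajectory_value_estimates now runs the critic over the whole trajectory as one recurrent sequence (sequence_length=batch.num_experiences) starting from a zeroed memory of shape [1, 1, m_size], and the memory returned by that pass seeds the single-step bootstrap on next_obs. Below is a minimal, self-contained sketch of that hand-off using a plain torch.nn.LSTM; m_size, the feature sizes, and the packing of hidden/cell state into one memory tensor are illustrative assumptions, not the ml-agents critic implementation.

import torch
from torch import nn

m_size = 128                                   # assumed policy memory size
critic_lstm = nn.LSTM(input_size=64, hidden_size=m_size // 2, batch_first=True)

trajectory = torch.randn(1, 32, 64)            # (batch=1, seq_len=num_experiences, features)
memory = torch.zeros(1, 1, m_size)             # zeroed initial memory, as in the diff
h0 = memory[..., : m_size // 2].contiguous()   # assume first half holds the hidden state
c0 = memory[..., m_size // 2 :].contiguous()   # and the second half the cell state

# One pass over the full trajectory; the final LSTM state becomes next_memory.
_, (hn, cn) = critic_lstm(trajectory, (h0, c0))
next_memory = torch.cat([hn, cn], dim=-1)

# Single-step bootstrap on the final observation, seeded with next_memory.
last_obs = torch.randn(1, 1, 64)
_, _ = critic_lstm(last_obs, (next_memory[..., : m_size // 2].contiguous(),
                              next_memory[..., m_size // 2 :].contiguous()))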

ml-agents/mlagents/trainers/policy/torch_policy.py

Lines changed: 1 addition & 1 deletion
@@ -186,7 +186,7 @@ def evaluate(
         run_out["value"] = np.mean(list(run_out["value_heads"].values()), 0)
         run_out["learning_rate"] = 0.0
         if self.use_recurrent:
-            run_out["memories"] = memories.detach().cpu().numpy()
+            run_out["memory_out"] = memories.detach().cpu().numpy().squeeze(0)
         return run_out

     def get_action(
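
The renamed key also changes shape: the memory tensor carries a leading axis of size 1, and squeeze(0) drops it so one memory row per agent can be stored. A tiny sketch with made-up sizes:

import torch

memories = torch.zeros(1, 4, 256)                      # assumed (1, n_agents, m_size)
memory_out = memories.detach().cpu().numpy().squeeze(0)
assert memory_out.shape == (4, 256)                    # one row per agent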

ml-agents/mlagents/trainers/ppo/optimizer_torch.py

Lines changed: 31 additions & 13 deletions
@@ -61,12 +61,15 @@ def ppo_value_loss(
         old_values: Dict[str, torch.Tensor],
         returns: Dict[str, torch.Tensor],
         epsilon: float,
+        loss_masks: torch.Tensor,
     ) -> torch.Tensor:
         """
-        Creates training-specific Tensorflow ops for PPO models.
-        :param returns:
-        :param old_values:
-        :param values:
+        Evaluates value loss for PPO.
+        :param values: Value output of the current network.
+        :param old_values: Value stored with experiences in buffer.
+        :param returns: Computed returns.
+        :param epsilon: Clipping value for value estimate.
+        :param loss_mask: Mask for losses. Used with LSTM to ignore 0'ed out experiences.
         """
         value_losses = []
         for name, head in values.items():
@@ -77,18 +80,24 @@
             )
             v_opt_a = (returns_tensor - head) ** 2
             v_opt_b = (returns_tensor - clipped_value_estimate) ** 2
-            value_loss = torch.mean(torch.max(v_opt_a, v_opt_b))
+            value_loss = ModelUtils.masked_mean(torch.max(v_opt_a, v_opt_b), loss_masks)
             value_losses.append(value_loss)
         value_loss = torch.mean(torch.stack(value_losses))
         return value_loss

-    def ppo_policy_loss(self, advantages, log_probs, old_log_probs, masks):
+    def ppo_policy_loss(
+        self,
+        advantages: torch.Tensor,
+        log_probs: torch.Tensor,
+        old_log_probs: torch.Tensor,
+        loss_masks: torch.Tensor,
+    ) -> torch.Tensor:
         """
-        Creates training-specific Tensorflow ops for PPO models.
-        :param masks:
-        :param advantages:
+        Evaluate PPO policy loss.
+        :param advantages: Computed advantages.
         :param log_probs: Current policy probabilities
         :param old_log_probs: Past policy probabilities
+        :param loss_masks: Mask for losses. Used with LSTM to ignore 0'ed out experiences.
         """
         advantage = advantages.unsqueeze(-1)

@@ -99,7 +108,9 @@ def ppo_policy_loss(self, advantages, log_probs, old_log_probs, masks):
         p_opt_b = (
             torch.clamp(r_theta, 1.0 - decay_epsilon, 1.0 + decay_epsilon) * advantage
         )
-        policy_loss = -torch.mean(torch.min(p_opt_a, p_opt_b))
+        policy_loss = -1 * ModelUtils.masked_mean(
+            torch.min(p_opt_a, p_opt_b).flatten(), loss_masks
+        )
         return policy_loss

     @timed
@@ -153,14 +164,21 @@ def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:
             memories=memories,
             seq_len=self.policy.sequence_length,
         )
-        value_loss = self.ppo_value_loss(values, old_values, returns, decay_eps)
+        loss_masks = ModelUtils.list_to_tensor(batch["masks"], dtype=torch.bool)
+        value_loss = self.ppo_value_loss(
+            values, old_values, returns, decay_eps, loss_masks
+        )
         policy_loss = self.ppo_policy_loss(
             ModelUtils.list_to_tensor(batch["advantages"]),
             log_probs,
             ModelUtils.list_to_tensor(batch["action_probs"]),
-            ModelUtils.list_to_tensor(batch["masks"], dtype=torch.int32),
+            loss_masks,
+        )
+        loss = (
+            policy_loss
+            + 0.5 * value_loss
+            - decay_bet * ModelUtils.masked_mean(entropy.flatten(), loss_masks)
         )
-        loss = policy_loss + 0.5 * value_loss - decay_bet * torch.mean(entropy)

         # Set optimizer learning rate
         ModelUtils.update_learning_rate(self.optimizer, decay_lr)
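
Every PPO loss term now goes through ModelUtils.masked_mean with a boolean mask built from batch["masks"], so the zero-padded steps that fill out fixed-length LSTM sequences no longer drag the averages toward zero. The helper's body is not part of this commit; the sketch below is one plausible way such a masked mean could look, with hypothetical names and toy values.

import torch

def masked_mean(tensor: torch.Tensor, masks: torch.Tensor) -> torch.Tensor:
    # Average only over entries whose mask is True (real experiences).
    masks = masks.to(torch.bool)
    return tensor[masks].mean() if masks.any() else torch.zeros(())

losses = torch.tensor([0.5, 0.7, 0.0, 0.0])            # last two entries are padding
masks = torch.tensor([True, True, False, False])
print(masked_mean(losses, masks))                       # tensor(0.6000) rather than 0.3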

ml-agents/mlagents/trainers/sac/optimizer_torch.py

Lines changed: 96 additions & 29 deletions
@@ -1,7 +1,8 @@
 import numpy as np
-from typing import Dict, List, Mapping, cast, Tuple
+from typing import Dict, List, Mapping, cast, Tuple, Optional
 import torch
 from torch import nn
+import attr

 from mlagents_envs.logging_util import get_logger
 from mlagents_envs.base_env import ActionType
@@ -56,10 +57,24 @@ def forward(
             self,
             vec_inputs: List[torch.Tensor],
             vis_inputs: List[torch.Tensor],
-            actions: torch.Tensor = None,
+            actions: Optional[torch.Tensor] = None,
+            memories: Optional[torch.Tensor] = None,
+            sequence_length: int = 1,
         ) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
-            q1_out, _ = self.q1_network(vec_inputs, vis_inputs, actions=actions)
-            q2_out, _ = self.q2_network(vec_inputs, vis_inputs, actions=actions)
+            q1_out, _ = self.q1_network(
+                vec_inputs,
+                vis_inputs,
+                actions=actions,
+                memories=memories,
+                sequence_length=sequence_length,
+            )
+            q2_out, _ = self.q2_network(
+                vec_inputs,
+                vis_inputs,
+                actions=actions,
+                memories=memories,
+                sequence_length=sequence_length,
+            )
             return q1_out, q2_out

     def __init__(self, policy: TorchPolicy, trainer_params: TrainerSettings):
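
PolicyValueNetwork.forward now threads memories and sequence_length through to both Q networks. The sequence length matters because the optimizer hands the networks a flat batch of num_sequences * sequence_length experiences, which a recurrent encoder has to fold back into sequences before its LSTM. The sketch below illustrates that reshaping with stand-in names and sizes; it is not the ml-agents NetworkBody.

import torch
from torch import nn

class RecurrentQEncoder(nn.Module):
    def __init__(self, obs_size: int, hidden: int, half_mem: int):
        super().__init__()
        self.body = nn.Linear(obs_size, hidden)
        self.lstm = nn.LSTM(hidden, half_mem, batch_first=True)

    def forward(self, flat_obs, memories, sequence_length):
        hidden = torch.relu(self.body(flat_obs))                      # (N*T, hidden)
        seqs = hidden.reshape(-1, sequence_length, hidden.shape[-1])  # (N, T, hidden)
        half = memories.shape[-1] // 2
        h0 = memories[..., :half].contiguous()                        # assumed h/c packing
        c0 = memories[..., half:].contiguous()
        out, (hn, cn) = self.lstm(seqs, (h0, c0))
        return out.reshape(-1, out.shape[-1]), torch.cat([hn, cn], dim=-1)

enc = RecurrentQEncoder(obs_size=8, hidden=32, half_mem=16)
flat = torch.randn(4 * 16, 8)                  # 4 sequences of length 16
mem = torch.zeros(1, 4, 32)                    # zeroed, like q_memories in update()
q_feats, next_mem = enc(flat, mem, sequence_length=16)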
@@ -87,17 +102,28 @@ def __init__(self, policy: TorchPolicy, trainer_params: TrainerSettings):
             for name in self.stream_names
         }

+        # Critics should have 1/2 of the memory of the policy
+        critic_memory = policy_network_settings.memory
+        if critic_memory is not None:
+            critic_memory = attr.evolve(
+                critic_memory, memory_size=critic_memory.memory_size // 2
+            )
+        value_network_settings = attr.evolve(
+            policy_network_settings, memory=critic_memory
+        )
+
         self.value_network = TorchSACOptimizer.PolicyValueNetwork(
             self.stream_names,
             self.policy.behavior_spec.observation_shapes,
-            policy_network_settings,
+            value_network_settings,
             self.policy.behavior_spec.action_type,
             self.act_size,
         )
+
         self.target_network = ValueNetwork(
             self.stream_names,
             self.policy.behavior_spec.observation_shapes,
-            policy_network_settings,
+            value_network_settings,
         )
         self.soft_update(self.policy.actor_critic.critic, self.target_network, 1.0)
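
The critic-sized memory is derived from the policy settings with attr.evolve, which copies an attrs instance while overriding selected fields. A short sketch of the pattern with stand-in settings classes (the real NetworkSettings fields may differ):

from typing import Optional
import attr

@attr.s(auto_attribs=True)
class MemorySettings:
    memory_size: int = 256
    sequence_length: int = 64

@attr.s(auto_attribs=True)
class NetworkSettings:
    hidden_units: int = 128
    memory: Optional[MemorySettings] = None

policy_settings = NetworkSettings(memory=MemorySettings(memory_size=256))

# Same settings, but the critic gets half-sized memory.
critic_memory = attr.evolve(
    policy_settings.memory, memory_size=policy_settings.memory.memory_size // 2
)
value_settings = attr.evolve(policy_settings, memory=critic_memory)
assert value_settings.memory.memory_size == 128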

@@ -168,11 +194,11 @@ def sac_q_loss(
                 * self.gammas[i]
                 * target_values[name]
             )
-            _q1_loss = 0.5 * torch.mean(
-                loss_masks * torch.nn.functional.mse_loss(q_backup, q1_stream)
+            _q1_loss = 0.5 * ModelUtils.masked_mean(
+                torch.nn.functional.mse_loss(q_backup, q1_stream), loss_masks
             )
-            _q2_loss = 0.5 * torch.mean(
-                loss_masks * torch.nn.functional.mse_loss(q_backup, q2_stream)
+            _q2_loss = 0.5 * ModelUtils.masked_mean(
+                torch.nn.functional.mse_loss(q_backup, q2_stream), loss_masks
             )

             q1_losses.append(_q1_loss)
@@ -232,9 +258,8 @@ def sac_value_loss(
                 v_backup = min_policy_qs[name] - torch.sum(
                     _ent_coef * log_probs, dim=1
                 )
-                # print(log_probs, v_backup, _ent_coef, loss_masks)
-                value_loss = 0.5 * torch.mean(
-                    loss_masks * torch.nn.functional.mse_loss(values[name], v_backup)
+                value_loss = 0.5 * ModelUtils.masked_mean(
+                    torch.nn.functional.mse_loss(values[name], v_backup), loss_masks
                 )
                 value_losses.append(value_loss)
         else:
@@ -253,9 +278,9 @@
                 v_backup = min_policy_qs[name] - torch.mean(
                     branched_ent_bonus, axis=0
                 )
-                value_loss = 0.5 * torch.mean(
-                    loss_masks
-                    * torch.nn.functional.mse_loss(values[name], v_backup.squeeze())
+                value_loss = 0.5 * ModelUtils.masked_mean(
+                    torch.nn.functional.mse_loss(values[name], v_backup.squeeze()),
+                    loss_masks,
                 )
                 value_losses.append(value_loss)
         value_loss = torch.mean(torch.stack(value_losses))
@@ -275,7 +300,7 @@ def sac_policy_loss(
         if not discrete:
             mean_q1 = mean_q1.unsqueeze(1)
             batch_policy_loss = torch.mean(_ent_coef * log_probs - mean_q1, dim=1)
-            policy_loss = torch.mean(loss_masks * batch_policy_loss)
+            policy_loss = ModelUtils.masked_mean(batch_policy_loss, loss_masks)
         else:
             action_probs = log_probs.exp()
             branched_per_action_ent = ModelUtils.break_into_branches(
@@ -322,9 +347,8 @@ def sac_entropy_loss(
             target_current_diff = torch.squeeze(
                 target_current_diff_branched, axis=2
             )
-            entropy_loss = -torch.mean(
-                loss_masks
-                * torch.mean(self._log_ent_coef * target_current_diff, axis=1)
+            entropy_loss = -1 * ModelUtils.masked_mean(
+                torch.mean(self._log_ent_coef * target_current_diff, axis=1), loss_masks
             )

         return entropy_loss
@@ -369,12 +393,28 @@ def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:
         else:
             actions = ModelUtils.list_to_tensor(batch["actions"], dtype=torch.long)

-        memories = [
+        memories_list = [
             ModelUtils.list_to_tensor(batch["memory"][i])
             for i in range(0, len(batch["memory"]), self.policy.sequence_length)
         ]
-        if len(memories) > 0:
-            memories = torch.stack(memories).unsqueeze(0)
+        # LSTM shouldn't have sequence length <1, but stop it from going out of the index if true.
+        offset = 1 if self.policy.sequence_length > 1 else 0
+        next_memories_list = [
+            ModelUtils.list_to_tensor(
+                batch["memory"][i][self.policy.m_size // 2 :]
+            )  # only pass value part of memory to target network
+            for i in range(offset, len(batch["memory"]), self.policy.sequence_length)
+        ]
+
+        if len(memories_list) > 0:
+            memories = torch.stack(memories_list).unsqueeze(0)
+            next_memories = torch.stack(next_memories_list).unsqueeze(0)
+        else:
+            memories = None
+            next_memories = None
+        # Q network memories are 0'ed out, since we don't have them during inference.
+        q_memories = torch.zeros_like(next_memories)
+
         vis_obs: List[torch.Tensor] = []
         next_vis_obs: List[torch.Tensor] = []
         if self.policy.use_vis_obs:
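
The update above stores one memory vector per experience in batch["memory"] and slices it two ways: the memory saved at the first step of each training sequence becomes that sequence's initial state, while the target network starts one step later (it sees next_obs) and only receives the second half of the vector, which the m_size // 2 slicing suggests is the value/critic part. A toy illustration of that bookkeeping, with made-up sizes:

import numpy as np
import torch

m_size, seq_len = 8, 4
# One saved memory vector per experience step; here, 2 sequences of length 4.
buffer_memory = [np.arange(m_size, dtype=np.float32) + 100 * t for t in range(2 * seq_len)]

# Initial memory of each sequence: the vector saved at its first step.
memories_list = [torch.as_tensor(buffer_memory[i]) for i in range(0, len(buffer_memory), seq_len)]

# Target network starts one step later and keeps only the value half.
offset = 1 if seq_len > 1 else 0
next_memories_list = [
    torch.as_tensor(buffer_memory[i][m_size // 2 :])
    for i in range(offset, len(buffer_memory), seq_len)
]

memories = torch.stack(memories_list).unsqueeze(0)            # (1, num_seqs, m_size)
next_memories = torch.stack(next_memories_list).unsqueeze(0)  # (1, num_seqs, m_size // 2)
print(memories.shape, next_memories.shape)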
@@ -415,19 +455,46 @@ def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:
         )
         if self.policy.use_continuous_act:
             squeezed_actions = actions.squeeze(-1)
-            q1p_out, q2p_out = self.value_network(vec_obs, vis_obs, sampled_actions)
-            q1_out, q2_out = self.value_network(vec_obs, vis_obs, squeezed_actions)
+            q1p_out, q2p_out = self.value_network(
+                vec_obs,
+                vis_obs,
+                sampled_actions,
+                memories=q_memories,
+                sequence_length=self.policy.sequence_length,
+            )
+            q1_out, q2_out = self.value_network(
+                vec_obs,
+                vis_obs,
+                squeezed_actions,
+                memories=q_memories,
+                sequence_length=self.policy.sequence_length,
+            )
             q1_stream, q2_stream = q1_out, q2_out
         else:
             with torch.no_grad():
-                q1p_out, q2p_out = self.value_network(vec_obs, vis_obs)
-                q1_out, q2_out = self.value_network(vec_obs, vis_obs)
+                q1p_out, q2p_out = self.value_network(
+                    vec_obs,
+                    vis_obs,
+                    memories=q_memories,
+                    sequence_length=self.policy.sequence_length,
+                )
+                q1_out, q2_out = self.value_network(
+                    vec_obs,
+                    vis_obs,
+                    memories=q_memories,
+                    sequence_length=self.policy.sequence_length,
+                )
             q1_stream = self._condense_q_streams(q1_out, actions)
             q2_stream = self._condense_q_streams(q2_out, actions)

         with torch.no_grad():
-            target_values, _ = self.target_network(next_vec_obs, next_vis_obs)
-        masks = ModelUtils.list_to_tensor(batch["masks"], dtype=torch.int32)
+            target_values, _ = self.target_network(
+                next_vec_obs,
+                next_vis_obs,
+                memories=next_memories,
+                sequence_length=self.policy.sequence_length,
+            )
+        masks = ModelUtils.list_to_tensor(batch["masks"], dtype=torch.bool)
         use_discrete = not self.policy.use_continuous_act
         dones = ModelUtils.list_to_tensor(batch["done"])
