 import numpy as np
-from typing import Dict, List, Mapping, cast, Tuple
+from typing import Dict, List, Mapping, cast, Tuple, Optional
 import torch
 from torch import nn
+import attr

 from mlagents_envs.logging_util import get_logger
 from mlagents_envs.base_env import ActionType
@@ -56,10 +57,24 @@ def forward(
             self,
             vec_inputs: List[torch.Tensor],
             vis_inputs: List[torch.Tensor],
-            actions: torch.Tensor = None,
+            actions: Optional[torch.Tensor] = None,
+            memories: Optional[torch.Tensor] = None,
+            sequence_length: int = 1,
         ) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
-            q1_out, _ = self.q1_network(vec_inputs, vis_inputs, actions=actions)
-            q2_out, _ = self.q2_network(vec_inputs, vis_inputs, actions=actions)
+            q1_out, _ = self.q1_network(
+                vec_inputs,
+                vis_inputs,
+                actions=actions,
+                memories=memories,
+                sequence_length=sequence_length,
+            )
+            q2_out, _ = self.q2_network(
+                vec_inputs,
+                vis_inputs,
+                actions=actions,
+                memories=memories,
+                sequence_length=sequence_length,
+            )
             return q1_out, q2_out

     def __init__(self, policy: TorchPolicy, trainer_params: TrainerSettings):
@@ -87,17 +102,28 @@ def __init__(self, policy: TorchPolicy, trainer_params: TrainerSettings):
             for name in self.stream_names
         }

+        # Critics should have 1/2 of the memory of the policy
+        critic_memory = policy_network_settings.memory
+        if critic_memory is not None:
+            critic_memory = attr.evolve(
+                critic_memory, memory_size=critic_memory.memory_size // 2
+            )
+        value_network_settings = attr.evolve(
+            policy_network_settings, memory=critic_memory
+        )
+
         self.value_network = TorchSACOptimizer.PolicyValueNetwork(
             self.stream_names,
             self.policy.behavior_spec.observation_shapes,
-            policy_network_settings,
+            value_network_settings,
             self.policy.behavior_spec.action_type,
             self.act_size,
         )
+
         self.target_network = ValueNetwork(
             self.stream_names,
             self.policy.behavior_spec.observation_shapes,
-            policy_network_settings,
+            value_network_settings,
         )
         self.soft_update(self.policy.actor_critic.critic, self.target_network, 1.0)

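The memory halving above relies on `attr.evolve`, which returns a copy of a frozen attrs object with only the named fields replaced. A minimal, self-contained sketch of that pattern (the settings classes below are stand-ins for illustration, not the real `NetworkSettings`):

```python
import attr
from typing import Optional


@attr.s(auto_attribs=True, frozen=True)
class MemorySettingsSketch:
    memory_size: int = 128  # total recurrent state size (policy half + value half)
    sequence_length: int = 64


@attr.s(auto_attribs=True, frozen=True)
class NetworkSettingsSketch:
    hidden_units: int = 256
    memory: Optional[MemorySettingsSketch] = None


policy_settings = NetworkSettingsSketch(memory=MemorySettingsSketch(memory_size=128))

# evolve() copies the frozen settings and overrides only the given fields,
# so the critic gets half the memory without mutating the policy's settings.
critic_memory = attr.evolve(
    policy_settings.memory, memory_size=policy_settings.memory.memory_size // 2
)
value_settings = attr.evolve(policy_settings, memory=critic_memory)

assert policy_settings.memory.memory_size == 128
assert value_settings.memory.memory_size == 64
```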
@@ -168,11 +194,11 @@ def sac_q_loss(
                 * self.gammas[i]
                 * target_values[name]
             )
-            _q1_loss = 0.5 * torch.mean(
-                loss_masks * torch.nn.functional.mse_loss(q_backup, q1_stream)
+            _q1_loss = 0.5 * ModelUtils.masked_mean(
+                torch.nn.functional.mse_loss(q_backup, q1_stream), loss_masks
             )
-            _q2_loss = 0.5 * torch.mean(
-                loss_masks * torch.nn.functional.mse_loss(q_backup, q2_stream)
+            _q2_loss = 0.5 * ModelUtils.masked_mean(
+                torch.nn.functional.mse_loss(q_backup, q2_stream), loss_masks
             )

             q1_losses.append(_q1_loss)
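Each loss term in this commit swaps `torch.mean(loss_masks * per_element_loss)` for `ModelUtils.masked_mean(per_element_loss, loss_masks)`, so that padded timesteps in LSTM sequences no longer drag the average toward zero. A functional stand-in (not the actual `ModelUtils` implementation) that shows the difference:

```python
import torch


def masked_mean_sketch(tensor: torch.Tensor, masks: torch.Tensor) -> torch.Tensor:
    """Mean of `tensor` over the elements where `masks` is True."""
    masks = masks.float()
    # Divide by the number of valid steps, not the full batch size,
    # so masked-out (padded) steps do not dilute the loss.
    return (tensor * masks).sum() / masks.sum().clamp(min=1.0)


loss = torch.tensor([1.0, 1.0, 0.0, 0.0])        # two valid steps, two padded steps
mask = torch.tensor([True, True, False, False])

print(torch.mean(loss * mask.float()))  # tensor(0.5000) -- diluted by padding
print(masked_mean_sketch(loss, mask))   # tensor(1.)     -- mean over valid steps only
```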
@@ -232,9 +258,8 @@ def sac_value_loss(
                 v_backup = min_policy_qs[name] - torch.sum(
                     _ent_coef * log_probs, dim=1
                 )
-                # print(log_probs, v_backup, _ent_coef, loss_masks)
-                value_loss = 0.5 * torch.mean(
-                    loss_masks * torch.nn.functional.mse_loss(values[name], v_backup)
+                value_loss = 0.5 * ModelUtils.masked_mean(
+                    torch.nn.functional.mse_loss(values[name], v_backup), loss_masks
                 )
                 value_losses.append(value_loss)
         else:
@@ -253,9 +278,9 @@ def sac_value_loss(
                 v_backup = min_policy_qs[name] - torch.mean(
                     branched_ent_bonus, axis=0
                 )
-                value_loss = 0.5 * torch.mean(
-                    loss_masks
-                    * torch.nn.functional.mse_loss(values[name], v_backup.squeeze())
+                value_loss = 0.5 * ModelUtils.masked_mean(
+                    torch.nn.functional.mse_loss(values[name], v_backup.squeeze()),
+                    loss_masks,
                 )
                 value_losses.append(value_loss)
         value_loss = torch.mean(torch.stack(value_losses))
@@ -275,7 +300,7 @@ def sac_policy_loss(
         if not discrete:
             mean_q1 = mean_q1.unsqueeze(1)
             batch_policy_loss = torch.mean(_ent_coef * log_probs - mean_q1, dim=1)
-            policy_loss = torch.mean(loss_masks * batch_policy_loss)
+            policy_loss = ModelUtils.masked_mean(batch_policy_loss, loss_masks)
         else:
             action_probs = log_probs.exp()
             branched_per_action_ent = ModelUtils.break_into_branches(
@@ -322,9 +347,8 @@ def sac_entropy_loss(
             target_current_diff = torch.squeeze(
                 target_current_diff_branched, axis=2
             )
-            entropy_loss = -torch.mean(
-                loss_masks
-                * torch.mean(self._log_ent_coef * target_current_diff, axis=1)
+            entropy_loss = -1 * ModelUtils.masked_mean(
+                torch.mean(self._log_ent_coef * target_current_diff, axis=1), loss_masks
             )

         return entropy_loss
@@ -369,12 +393,28 @@ def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:
         else:
             actions = ModelUtils.list_to_tensor(batch["actions"], dtype=torch.long)

-        memories = [
+        memories_list = [
             ModelUtils.list_to_tensor(batch["memory"][i])
             for i in range(0, len(batch["memory"]), self.policy.sequence_length)
         ]
-        if len(memories) > 0:
-            memories = torch.stack(memories).unsqueeze(0)
+        # LSTM shouldn't have sequence length < 1; guard the index if it does.
+        offset = 1 if self.policy.sequence_length > 1 else 0
+        next_memories_list = [
+            ModelUtils.list_to_tensor(
+                batch["memory"][i][self.policy.m_size // 2 :]
+            )  # only pass the value half of the memory to the target network
+            for i in range(offset, len(batch["memory"]), self.policy.sequence_length)
+        ]
+
+        if len(memories_list) > 0:
+            memories = torch.stack(memories_list).unsqueeze(0)
+            next_memories = torch.stack(next_memories_list).unsqueeze(0)
+        else:
+            memories = None
+            next_memories = None
+        # Q network memories are zeroed out, since we don't have them during inference.
+        q_memories = (
+            torch.zeros_like(next_memories) if next_memories is not None else None
+        )

         vis_obs: List[torch.Tensor] = []
         next_vis_obs: List[torch.Tensor] = []
         if self.policy.use_vis_obs:
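The indexing above takes one memory per sequence from the flat replay batch (stride `self.policy.sequence_length`), and for the target network it starts one step later and keeps only the value half of each recurrent state. A toy illustration of that slicing, with made-up sizes in place of the real batch:

```python
import torch

m_size = 8           # total recurrent state size; second half belongs to the value network
sequence_length = 3  # steps per stored sequence
# A fake flat batch of per-step memories: 6 steps covering two sequences.
batch_memory = [torch.full((m_size,), float(t)) for t in range(6)]

# One memory per sequence, taken at each sequence start.
memories_list = [batch_memory[i] for i in range(0, len(batch_memory), sequence_length)]

# Target-network memories: one step into each sequence, value half only.
offset = 1 if sequence_length > 1 else 0
next_memories_list = [
    batch_memory[i][m_size // 2 :]
    for i in range(offset, len(batch_memory), sequence_length)
]

memories = torch.stack(memories_list).unsqueeze(0)
next_memories = torch.stack(next_memories_list).unsqueeze(0)
print(memories.shape)       # torch.Size([1, 2, 8])
print(next_memories.shape)  # torch.Size([1, 2, 4])
```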
@@ -415,19 +455,46 @@ def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:
             )
         if self.policy.use_continuous_act:
             squeezed_actions = actions.squeeze(-1)
-            q1p_out, q2p_out = self.value_network(vec_obs, vis_obs, sampled_actions)
-            q1_out, q2_out = self.value_network(vec_obs, vis_obs, squeezed_actions)
+            q1p_out, q2p_out = self.value_network(
+                vec_obs,
+                vis_obs,
+                sampled_actions,
+                memories=q_memories,
+                sequence_length=self.policy.sequence_length,
+            )
+            q1_out, q2_out = self.value_network(
+                vec_obs,
+                vis_obs,
+                squeezed_actions,
+                memories=q_memories,
+                sequence_length=self.policy.sequence_length,
+            )
             q1_stream, q2_stream = q1_out, q2_out
         else:
             with torch.no_grad():
-                q1p_out, q2p_out = self.value_network(vec_obs, vis_obs)
-            q1_out, q2_out = self.value_network(vec_obs, vis_obs)
+                q1p_out, q2p_out = self.value_network(
+                    vec_obs,
+                    vis_obs,
+                    memories=q_memories,
+                    sequence_length=self.policy.sequence_length,
+                )
+            q1_out, q2_out = self.value_network(
+                vec_obs,
+                vis_obs,
+                memories=q_memories,
+                sequence_length=self.policy.sequence_length,
+            )
             q1_stream = self._condense_q_streams(q1_out, actions)
             q2_stream = self._condense_q_streams(q2_out, actions)

         with torch.no_grad():
-            target_values, _ = self.target_network(next_vec_obs, next_vis_obs)
-        masks = ModelUtils.list_to_tensor(batch["masks"], dtype=torch.int32)
+            target_values, _ = self.target_network(
+                next_vec_obs,
+                next_vis_obs,
+                memories=next_memories,
+                sequence_length=self.policy.sequence_length,
+            )
+        masks = ModelUtils.list_to_tensor(batch["masks"], dtype=torch.bool)
         use_discrete = not self.policy.use_continuous_act
         dones = ModelUtils.list_to_tensor(batch["done"])