@@ -71,27 +71,25 @@ def update_normalization(self, buffer: AgentBuffer) -> None:
 
     def baseline(
         self,
-        self_obs: List[List[torch.Tensor]],
-        obs: List[List[torch.Tensor]],
-        actions: List[AgentAction],
+        obs_without_actions: List[torch.Tensor],
+        obs_with_actions: Tuple[List[List[torch.Tensor]], List[AgentAction]],
         memories: Optional[torch.Tensor] = None,
         sequence_length: int = 1,
     ) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]:
         """
         The POCA baseline marginalizes the action of the agent associated with self_obs.
         It calls the forward pass of the MultiAgentNetworkBody with the state action
         pairs of groupmates but just the state of the agent in question.
-        :param self_obs: The obs of the agent for w.r.t. which to compute the baseline
-        :param obs: List of observations for all groupmates. Should be the same length
-            as actions.
-        :param actions: List of actions for all groupmates. Should be the same length
-            as obs.
+        :param obs_without_actions: The obs of the agent for which to compute the baseline.
+        :param obs_with_actions: Tuple of observations and actions for all groupmates.
         :param memories: If using memory, a Tensor of initial memories.
         :param sequence_length: If using memory, the sequence length.
+
         :return: A Tuple of Dict of reward stream to tensor and critic memories.
         """
+        (obs, actions) = obs_with_actions
         encoding, memories = self.network_body(
-            obs_only=self_obs,
+            obs_only=[obs_without_actions],
             obs=obs,
             actions=actions,
             memories=memories,
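For reference, this collapses the three positional arguments (`self_obs`, `obs`, `actions`) into two: the focal agent's observation list and a single `(groupmate_obs, groupmate_actions)` tuple that the method unpacks before wrapping the focal obs in a list for `obs_only`. A minimal sketch of the new call shape, using a stand-in critic and made-up tensor shapes rather than the real ml-agents classes:

```python
from typing import Dict, List, Optional, Tuple

import torch


class StubCritic:
    """Stand-in for the POCA value network; only the call shape matters here."""

    def baseline(
        self,
        obs_without_actions: List[torch.Tensor],
        obs_with_actions: Tuple[List[List[torch.Tensor]], List[object]],
        memories: Optional[torch.Tensor] = None,
        sequence_length: int = 1,
    ) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]:
        (obs, actions) = obs_with_actions  # unpack exactly as the new method does
        # The real method would forward obs_only=[obs_without_actions], obs, and
        # actions through MultiAgentNetworkBody; return dummies here instead.
        return {"extrinsic": torch.zeros(sequence_length)}, torch.zeros(1)


# One focal agent with two obs tensors, plus two groupmates with obs and actions.
self_obs = [torch.zeros(1, 8), torch.zeros(1, 4)]
groupmate_obs = [[torch.zeros(1, 8), torch.zeros(1, 4)] for _ in range(2)]
groupmate_actions = [object(), object()]  # placeholders for AgentAction instances

baselines, _ = StubCritic().baseline(self_obs, (groupmate_obs, groupmate_actions))
```

Pairing the groupmate observations and actions in one tuple presumably also removes the old requirement that the two separate lists be kept the same length by every caller.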
@@ -256,7 +254,7 @@ def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:
 
         act_masks = ModelUtils.list_to_tensor(batch[BufferKey.ACTION_MASK])
         actions = AgentAction.from_buffer(batch)
-        group_actions = AgentAction.group_from_buffer(batch)
+        groupmate_actions = AgentAction.group_from_buffer(batch)
 
         memories = [
             ModelUtils.list_to_tensor(batch[BufferKey.MEMORY][i])
@@ -295,10 +293,10 @@ def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:
             memories=value_memories,
             sequence_length=self.policy.sequence_length,
         )
+        groupmate_obs_and_actions = (groupmate_obs, groupmate_actions)
         baselines, _ = self.critic.baseline(
-            [current_obs],
-            groupmate_obs,
-            group_actions,
+            current_obs,
+            groupmate_obs_and_actions,
             memories=baseline_memories,
             sequence_length=self.policy.sequence_length,
         )
@@ -391,22 +389,22 @@ def _evaluate_by_sequence_team(
         first_seq_len = leftover if leftover > 0 else self.policy.sequence_length
 
         self_seq_obs = []
-        team_seq_obs = []
-        team_seq_act = []
+        groupmate_seq_obs = []
+        groupmate_seq_act = []
         seq_obs = []
         for _self_obs in self_obs:
             first_seq_obs = _self_obs[0:first_seq_len]
             seq_obs.append(first_seq_obs)
         self_seq_obs.append(seq_obs)
 
-        for team_obs, team_action in zip(obs, actions):
+        for groupmate_obs, groupmate_action in zip(obs, actions):
             seq_obs = []
-            for _obs in team_obs:
+            for _obs in groupmate_obs:
                 first_seq_obs = _obs[0:first_seq_len]
                 seq_obs.append(first_seq_obs)
-            team_seq_obs.append(seq_obs)
-            _act = team_action.slice(0, first_seq_len)
-            team_seq_act.append(_act)
+            groupmate_seq_obs.append(seq_obs)
+            _act = groupmate_action.slice(0, first_seq_len)
+            groupmate_seq_act.append(_act)
 
         # For the first sequence, the initial memory should be the one at the
         # beginning of this trajectory.
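The slicing above follows the convention the surrounding context lines suggest: when the trajectory length is not a multiple of `sequence_length`, the leftover steps form the first, shorter sequence, and the rest are split into full-length sequences. A small illustration with made-up numbers (not ml-agents code):

```python
import torch

# Hypothetical numbers: a 10-step trajectory evaluated with sequence_length = 4.
trajectory = torch.arange(10)
sequence_length = 4
leftover = trajectory.shape[0] % sequence_length           # 2
first_seq_len = leftover if leftover > 0 else sequence_length

# The first (possibly shorter) sequence absorbs the leftover steps...
sequences = [trajectory[0:first_seq_len]]
# ...and the remaining steps are split into full-length sequences.
for start in range(first_seq_len, trajectory.shape[0], sequence_length):
    sequences.append(trajectory[start:start + sequence_length])

print([s.tolist() for s in sequences])  # [[0, 1], [2, 3, 4, 5], [6, 7, 8, 9]]
```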
@@ -425,10 +423,10 @@ def _evaluate_by_sequence_team(
             for signal_name in init_values.keys()
         }
 
+        groupmate_obs_and_actions = (groupmate_seq_obs, groupmate_seq_act)
         init_baseline, _baseline_mem = self.critic.baseline(
-            self_seq_obs,
-            team_seq_obs,
-            team_seq_act,
+            self_seq_obs[0],
+            groupmate_obs_and_actions,
             init_baseline_mem,
             sequence_length=first_seq_len,
         )
@@ -456,34 +454,34 @@ def _evaluate_by_sequence_team(
             )
 
             self_seq_obs = []
-            team_seq_obs = []
-            team_seq_act = []
+            groupmate_seq_obs = []
+            groupmate_seq_act = []
             seq_obs = []
             for _self_obs in self_obs:
                 seq_obs.append(_obs[start:end])
             self_seq_obs.append(seq_obs)
 
-            for team_obs, team_action in zip(obs, actions):
+            for groupmate_obs, team_action in zip(obs, actions):
                 seq_obs = []
-                for (_obs,) in team_obs:
+                for (_obs,) in groupmate_obs:
                     first_seq_obs = _obs[start:end]
                     seq_obs.append(first_seq_obs)
-                team_seq_obs.append(seq_obs)
+                groupmate_seq_obs.append(seq_obs)
                 _act = team_action.slice(start, end)
-                team_seq_act.append(_act)
+                groupmate_seq_act.append(_act)
 
-            all_seq_obs = self_seq_obs + team_seq_obs
+            all_seq_obs = self_seq_obs + groupmate_seq_obs
             values, _value_mem = self.critic.critic_pass(
                 all_seq_obs, _value_mem, sequence_length=self.policy.sequence_length
             )
             all_values = {
                 signal_name: [init_values[signal_name]] for signal_name in values.keys()
             }
 
+            groupmate_obs_and_actions = (groupmate_seq_obs, groupmate_seq_act)
             baselines, _baseline_mem = self.critic.baseline(
-                self_seq_obs,
-                team_seq_obs,
-                team_seq_act,
+                self_seq_obs[0],
+                groupmate_obs_and_actions,
                 _baseline_mem,
                 sequence_length=first_seq_len,
             )
@@ -565,15 +563,15 @@ def get_trajectory_and_baseline_value_estimates(
         n_obs = len(self.policy.behavior_spec.observation_specs)
 
         current_obs = ObsUtil.from_buffer(batch, n_obs)
-        team_obs = GroupObsUtil.from_buffer(batch, n_obs)
+        groupmate_obs = GroupObsUtil.from_buffer(batch, n_obs)
 
         current_obs = [ModelUtils.list_to_tensor(obs) for obs in current_obs]
-        team_obs = [
-            [ModelUtils.list_to_tensor(obs) for obs in _teammate_obs]
-            for _teammate_obs in team_obs
+        groupmate_obs = [
+            [ModelUtils.list_to_tensor(obs) for obs in _groupmate_obs]
+            for _groupmate_obs in groupmate_obs
         ]
 
-        team_actions = AgentAction.group_from_buffer(batch)
+        groupmate_actions = AgentAction.group_from_buffer(batch)
 
         next_obs = [ModelUtils.list_to_tensor(obs) for obs in next_obs]
         next_obs = [obs.unsqueeze(0) for obs in next_obs]
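After `GroupObsUtil.from_buffer`, `groupmate_obs` is nested as groupmates → observation specs → per-step arrays, and the comprehension above converts only the leaves to tensors. A tiny sketch of that nesting with hypothetical shapes (not the real buffer utilities):

```python
import numpy as np
import torch

# Hypothetical buffer contents: 2 groupmates, each with 2 observation specs,
# stored as numpy arrays of shape (num_steps, obs_dim).
groupmate_obs = [
    [np.zeros((5, 8), dtype=np.float32), np.zeros((5, 4), dtype=np.float32)]
    for _ in range(2)
]

# Mirror of the nested conversion in the diff: outer list = groupmates,
# inner list = observation specs, leaves = torch tensors.
groupmate_obs = [
    [torch.as_tensor(obs) for obs in _groupmate_obs]
    for _groupmate_obs in groupmate_obs
]

print(len(groupmate_obs), len(groupmate_obs[0]), groupmate_obs[0][0].shape)
# 2 2 torch.Size([5, 8])
```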
@@ -604,7 +602,11 @@ def get_trajectory_and_baseline_value_estimates(
             else None
         )
 
-        all_obs = [current_obs] + team_obs if team_obs is not None else [current_obs]
+        all_obs = (
+            [current_obs] + groupmate_obs
+            if groupmate_obs is not None
+            else [current_obs]
+        )
         all_next_value_mem: Optional[AgentBufferField] = None
         all_next_baseline_mem: Optional[AgentBufferField] = None
         with torch.no_grad():
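`all_obs` simply prepends the focal agent's observation list to the groupmate lists, so the value head sees one entry per agent with the focal agent first, falling back to the focal agent alone when there are no groupmates. Roughly, with made-up shapes:

```python
import torch

# Hypothetical shapes: focal agent plus two groupmates, two obs specs each.
current_obs = [torch.zeros(5, 8), torch.zeros(5, 4)]
groupmate_obs = [[torch.zeros(5, 8), torch.zeros(5, 4)] for _ in range(2)]

all_obs = (
    [current_obs] + groupmate_obs if groupmate_obs is not None else [current_obs]
)
print(len(all_obs))  # 3: one obs list per agent, focal agent first
```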
@@ -618,20 +620,19 @@ def get_trajectory_and_baseline_value_estimates(
                     next_baseline_mem,
                 ) = self._evaluate_by_sequence_team(
                     current_obs,
-                    team_obs,
-                    team_actions,
+                    groupmate_obs,
+                    groupmate_actions,
                     _init_value_mem,
                     _init_baseline_mem,
                 )
             else:
                 value_estimates, next_value_mem = self.critic.critic_pass(
                     all_obs, _init_value_mem, sequence_length=batch.num_experiences
                 )
-
+                groupmate_obs_and_actions = (groupmate_obs, groupmate_actions)
                 baseline_estimates, next_baseline_mem = self.critic.baseline(
-                    [current_obs],
-                    team_obs,
-                    team_actions,
+                    current_obs,
+                    groupmate_obs_and_actions,
                     _init_baseline_mem,
                     sequence_length=batch.num_experiences,
                 )