@@ -248,10 +248,10 @@ def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:
248
248
current_obs = ObsUtil .from_buffer (batch , n_obs )
249
249
# Convert to tensors
250
250
current_obs = [ModelUtils .list_to_tensor (obs ) for obs in current_obs ]
251
- group_obs = GroupObsUtil .from_buffer (batch , n_obs )
252
- group_obs = [
251
+ groupmate_obs = GroupObsUtil .from_buffer (batch , n_obs )
252
+ groupmate_obs = [
253
253
[ModelUtils .list_to_tensor (obs ) for obs in _groupmate_obs ]
254
- for _groupmate_obs in group_obs
254
+ for _groupmate_obs in groupmate_obs
255
255
]
256
256
257
257
act_masks = ModelUtils .list_to_tensor (batch [BufferKey .ACTION_MASK ])
@@ -289,15 +289,15 @@ def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:
289
289
memories = memories ,
290
290
seq_len = self .policy .sequence_length ,
291
291
)
292
- all_obs = [current_obs ] + group_obs
292
+ all_obs = [current_obs ] + groupmate_obs
293
293
values , _ = self .critic .critic_pass (
294
294
all_obs ,
295
295
memories = value_memories ,
296
296
sequence_length = self .policy .sequence_length ,
297
297
)
298
298
baselines , _ = self .critic .baseline (
299
299
[current_obs ],
300
- group_obs ,
300
+ groupmate_obs ,
301
301
group_actions ,
302
302
memories = baseline_memories ,
303
303
sequence_length = self .policy .sequence_length ,
@@ -538,7 +538,7 @@ def get_trajectory_and_baseline_value_estimates(
538
538
self ,
539
539
batch : AgentBuffer ,
540
540
next_obs : List [np .ndarray ],
541
- next_group_obs : List [List [np .ndarray ]],
541
+ next_groupmate_obs : List [List [np .ndarray ]],
542
542
done : bool ,
543
543
agent_id : str = "" ,
544
544
) -> Tuple [
@@ -553,7 +553,7 @@ def get_trajectory_and_baseline_value_estimates(
553
553
:param batch: An AgentBuffer that consists of a trajectory.
554
554
:param next_obs: the next observation (after the trajectory). Used for bootstrapping
555
555
if this is not a terminal trajectory.
556
- :param next_group_obs : the next observations from other members of the group.
556
+ :param next_groupmate_obs : the next observations from other members of the group.
557
557
:param done: Set true if this is a terminal trajectory.
558
558
:param agent_id: Agent ID of the agent that this trajectory belongs to.
559
559
:returns: A Tuple of the Value Estimates as a Dict of [name, np.ndarray(trajectory_len)],
@@ -578,12 +578,14 @@ def get_trajectory_and_baseline_value_estimates(
578
578
next_obs = [ModelUtils .list_to_tensor (obs ) for obs in next_obs ]
579
579
next_obs = [obs .unsqueeze (0 ) for obs in next_obs ]
580
580
581
- next_group_obs = [
582
- ModelUtils .list_to_tensor_list (_list_obs ) for _list_obs in next_group_obs
581
+ next_groupmate_obs = [
582
+ ModelUtils .list_to_tensor_list (_list_obs )
583
+ for _list_obs in next_groupmate_obs
583
584
]
584
585
# Expand dimensions of next critic obs
585
- next_group_obs = [
586
- [_obs .unsqueeze (0 ) for _obs in _list_obs ] for _list_obs in next_group_obs
586
+ next_groupmate_obs = [
587
+ [_obs .unsqueeze (0 ) for _obs in _list_obs ]
588
+ for _list_obs in next_groupmate_obs
587
589
]
588
590
589
591
if agent_id in self .value_memory_dict :
@@ -638,7 +640,9 @@ def get_trajectory_and_baseline_value_estimates(
638
640
self .baseline_memory_dict [agent_id ] = next_baseline_mem
639
641
640
642
all_next_obs = (
641
- [next_obs ] + next_group_obs if next_group_obs is not None else [next_obs ]
643
+ [next_obs ] + next_groupmate_obs
644
+ if next_groupmate_obs is not None
645
+ else [next_obs ]
642
646
)
643
647
644
648
next_value_estimates , _ = self .critic .critic_pass (
0 commit comments