Skip to content

Commit 5d3e500

Browse files
committed
rename group obs to groupmate obs
1 parent d50b873 commit 5d3e500

File tree

1 file changed

+16
-12
lines changed

1 file changed

+16
-12
lines changed

ml-agents/mlagents/trainers/poca/optimizer_torch.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -248,10 +248,10 @@ def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:
248248
current_obs = ObsUtil.from_buffer(batch, n_obs)
249249
# Convert to tensors
250250
current_obs = [ModelUtils.list_to_tensor(obs) for obs in current_obs]
251-
group_obs = GroupObsUtil.from_buffer(batch, n_obs)
252-
group_obs = [
251+
groupmate_obs = GroupObsUtil.from_buffer(batch, n_obs)
252+
groupmate_obs = [
253253
[ModelUtils.list_to_tensor(obs) for obs in _groupmate_obs]
254-
for _groupmate_obs in group_obs
254+
for _groupmate_obs in groupmate_obs
255255
]
256256

257257
act_masks = ModelUtils.list_to_tensor(batch[BufferKey.ACTION_MASK])
@@ -289,15 +289,15 @@ def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:
289289
memories=memories,
290290
seq_len=self.policy.sequence_length,
291291
)
292-
all_obs = [current_obs] + group_obs
292+
all_obs = [current_obs] + groupmate_obs
293293
values, _ = self.critic.critic_pass(
294294
all_obs,
295295
memories=value_memories,
296296
sequence_length=self.policy.sequence_length,
297297
)
298298
baselines, _ = self.critic.baseline(
299299
[current_obs],
300-
group_obs,
300+
groupmate_obs,
301301
group_actions,
302302
memories=baseline_memories,
303303
sequence_length=self.policy.sequence_length,
@@ -538,7 +538,7 @@ def get_trajectory_and_baseline_value_estimates(
538538
self,
539539
batch: AgentBuffer,
540540
next_obs: List[np.ndarray],
541-
next_group_obs: List[List[np.ndarray]],
541+
next_groupmate_obs: List[List[np.ndarray]],
542542
done: bool,
543543
agent_id: str = "",
544544
) -> Tuple[
@@ -553,7 +553,7 @@ def get_trajectory_and_baseline_value_estimates(
553553
:param batch: An AgentBuffer that consists of a trajectory.
554554
:param next_obs: the next observation (after the trajectory). Used for boostrapping
555555
if this is not a termiinal trajectory.
556-
:param next_group_obs: the next observations from other members of the group.
556+
:param next_groupmate_obs: the next observations from other members of the group.
557557
:param done: Set true if this is a terminal trajectory.
558558
:param agent_id: Agent ID of the agent that this trajectory belongs to.
559559
:returns: A Tuple of the Value Estimates as a Dict of [name, np.ndarray(trajectory_len)],
@@ -578,12 +578,14 @@ def get_trajectory_and_baseline_value_estimates(
578578
next_obs = [ModelUtils.list_to_tensor(obs) for obs in next_obs]
579579
next_obs = [obs.unsqueeze(0) for obs in next_obs]
580580

581-
next_group_obs = [
582-
ModelUtils.list_to_tensor_list(_list_obs) for _list_obs in next_group_obs
581+
next_groupmate_obs = [
582+
ModelUtils.list_to_tensor_list(_list_obs)
583+
for _list_obs in next_groupmate_obs
583584
]
584585
# Expand dimensions of next critic obs
585-
next_group_obs = [
586-
[_obs.unsqueeze(0) for _obs in _list_obs] for _list_obs in next_group_obs
586+
next_groupmate_obs = [
587+
[_obs.unsqueeze(0) for _obs in _list_obs]
588+
for _list_obs in next_groupmate_obs
587589
]
588590

589591
if agent_id in self.value_memory_dict:
@@ -638,7 +640,9 @@ def get_trajectory_and_baseline_value_estimates(
638640
self.baseline_memory_dict[agent_id] = next_baseline_mem
639641

640642
all_next_obs = (
641-
[next_obs] + next_group_obs if next_group_obs is not None else [next_obs]
643+
[next_obs] + next_groupmate_obs
644+
if next_groupmate_obs is not None
645+
else [next_obs]
642646
)
643647

644648
next_value_estimates, _ = self.critic.critic_pass(

0 commit comments

Comments
 (0)