
Commit 10d63ae

rename to groupmate obs (#5094)

1 parent 100a7ac commit 10d63ae
1 file changed: 47 additions & 46 deletions

ml-agents/mlagents/trainers/poca/optimizer_torch.py

@@ -71,27 +71,25 @@ def update_normalization(self, buffer: AgentBuffer) -> None:
 
     def baseline(
         self,
-        self_obs: List[List[torch.Tensor]],
-        obs: List[List[torch.Tensor]],
-        actions: List[AgentAction],
+        obs_without_actions: List[torch.Tensor],
+        obs_with_actions: Tuple[List[List[torch.Tensor]], List[AgentAction]],
         memories: Optional[torch.Tensor] = None,
         sequence_length: int = 1,
     ) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]:
         """
         The POCA baseline marginalizes the action of the agent associated with self_obs.
         It calls the forward pass of the MultiAgentNetworkBody with the state action
         pairs of groupmates but just the state of the agent in question.
-        :param self_obs: The obs of the agent for w.r.t. which to compute the baseline
-        :param obs: List of observations for all groupmates. Should be the same length
-            as actions.
-        :param actions: List of actions for all groupmates. Should be the same length
-            as obs.
+        :param obs_without_actions: The obs of the agent for which to compute the baseline.
+        :param obs_with_actions: Tuple of observations and actions for all groupmates.
         :param memories: If using memory, a Tensor of initial memories.
         :param sequence_length: If using memory, the sequence length.
+
         :return: A Tuple of Dict of reward stream to tensor and critic memories.
         """
+        (obs, actions) = obs_with_actions
         encoding, memories = self.network_body(
-            obs_only=self_obs,
+            obs_only=[obs_without_actions],
             obs=obs,
             actions=actions,
             memories=memories,
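
The new signature collapses three parallel arguments into two: the observations of the agent whose action is marginalized, and a single tuple carrying the groupmates' state-action pairs. Below is a minimal, self-contained sketch of the calling convention; the stub forward pass and the plain tensors standing in for AgentAction are illustrative assumptions, not ml-agents code.

from typing import Dict, List, Tuple

import torch


def baseline_sketch(
    obs_without_actions: List[torch.Tensor],
    obs_with_actions: Tuple[List[List[torch.Tensor]], List[torch.Tensor]],
) -> Dict[str, torch.Tensor]:
    # Unpack exactly as the new method body does.
    (obs, actions) = obs_with_actions
    # Stand-in for the MultiAgentNetworkBody forward pass: the agent in
    # question contributes state only; each groupmate contributes a
    # state-action pair.
    parts = [torch.cat(obs_without_actions, dim=-1)]
    for groupmate_obs, groupmate_act in zip(obs, actions):
        parts.append(torch.cat(groupmate_obs + [groupmate_act], dim=-1))
    pooled = torch.stack([p.mean(dim=-1) for p in parts]).mean(dim=0)
    return {"extrinsic": pooled}


# Two groupmates, each with one 8-dim observation and a 2-dim action.
self_obs = [torch.zeros(1, 8)]
mates_obs = [[torch.zeros(1, 8)], [torch.zeros(1, 8)]]
mates_act = [torch.zeros(1, 2), torch.zeros(1, 2)]
baselines = baseline_sketch(self_obs, (mates_obs, mates_act))

Packing the groupmates' observations and actions into one argument couples the two lists that the removed docstring could only ask callers to keep "the same length".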
@@ -256,7 +254,7 @@ def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:
 
         act_masks = ModelUtils.list_to_tensor(batch[BufferKey.ACTION_MASK])
         actions = AgentAction.from_buffer(batch)
-        group_actions = AgentAction.group_from_buffer(batch)
+        groupmate_actions = AgentAction.group_from_buffer(batch)
 
         memories = [
             ModelUtils.list_to_tensor(batch[BufferKey.MEMORY][i])
@@ -295,10 +293,10 @@ def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:
                 memories=value_memories,
                 sequence_length=self.policy.sequence_length,
             )
+            groupmate_obs_and_actions = (groupmate_obs, groupmate_actions)
             baselines, _ = self.critic.baseline(
-                [current_obs],
-                groupmate_obs,
-                group_actions,
+                current_obs,
+                groupmate_obs_and_actions,
                 memories=baseline_memories,
                 sequence_length=self.policy.sequence_length,
             )
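
The same packing happens at every call site: the tuple is built immediately before the call. A hypothetical helper (not part of ml-agents) showing the invariant the tuple makes enforceable:

from typing import List, Tuple

import torch


def pack_groupmates(
    groupmate_obs: List[List[torch.Tensor]],
    groupmate_actions: List[torch.Tensor],
) -> Tuple[List[List[torch.Tensor]], List[torch.Tensor]]:
    # One action per groupmate; the old three-argument baseline() could
    # silently receive obs and action lists of different lengths.
    assert len(groupmate_obs) == len(groupmate_actions)
    return (groupmate_obs, groupmate_actions)


groupmate_obs_and_actions = pack_groupmates(
    [[torch.zeros(1, 8)], [torch.zeros(1, 8)]],
    [torch.zeros(1, 2), torch.zeros(1, 2)],
)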
@@ -391,22 +389,22 @@ def _evaluate_by_sequence_team(
         first_seq_len = leftover if leftover > 0 else self.policy.sequence_length
 
         self_seq_obs = []
-        team_seq_obs = []
-        team_seq_act = []
+        groupmate_seq_obs = []
+        groupmate_seq_act = []
         seq_obs = []
         for _self_obs in self_obs:
             first_seq_obs = _self_obs[0:first_seq_len]
             seq_obs.append(first_seq_obs)
         self_seq_obs.append(seq_obs)
 
-        for team_obs, team_action in zip(obs, actions):
+        for groupmate_obs, groupmate_action in zip(obs, actions):
             seq_obs = []
-            for _obs in team_obs:
+            for _obs in groupmate_obs:
                 first_seq_obs = _obs[0:first_seq_len]
                 seq_obs.append(first_seq_obs)
-            team_seq_obs.append(seq_obs)
-            _act = team_action.slice(0, first_seq_len)
-            team_seq_act.append(_act)
+            groupmate_seq_obs.append(seq_obs)
+            _act = groupmate_action.slice(0, first_seq_len)
+            groupmate_seq_act.append(_act)
 
         # For the first sequence, the initial memory should be the one at the
         # beginning of this trajectory.
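
For the recurrent path, the trajectory is evaluated in fixed-length sequences, and any leftover experiences form a shorter first sequence. A runnable sketch of that first-sequence slicing, with hypothetical sizes and plain tensors in place of the per-sensor observation lists:

import torch

# Hypothetical numbers: 10 experiences, sequences of length 4, so the
# first evaluated sequence holds the 2 leftover steps.
sequence_length = 4
num_experiences = 10
leftover = num_experiences % sequence_length
first_seq_len = leftover if leftover > 0 else sequence_length

# One observation stream for the agent, one groupmate with one stream.
self_obs = [torch.arange(num_experiences, dtype=torch.float32).unsqueeze(-1)]
groupmate_obs = [[torch.arange(num_experiences, dtype=torch.float32).unsqueeze(-1)]]

self_seq_obs = [[_obs[0:first_seq_len] for _obs in self_obs]]
groupmate_seq_obs = [
    [_obs[0:first_seq_len] for _obs in mate_obs] for mate_obs in groupmate_obs
]
print(first_seq_len)             # 2
print(self_seq_obs[0][0].shape)  # torch.Size([2, 1])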
@@ -416,7 +414,7 @@ def _evaluate_by_sequence_team(
             ModelUtils.to_numpy(init_baseline_mem.squeeze())
         )
 
-        all_seq_obs = self_seq_obs + team_seq_obs
+        all_seq_obs = self_seq_obs + groupmate_seq_obs
         init_values, _value_mem = self.critic.critic_pass(
             all_seq_obs, init_value_mem, sequence_length=first_seq_len
         )
@@ -425,10 +423,10 @@ def _evaluate_by_sequence_team(
             for signal_name in init_values.keys()
         }
 
+        groupmate_obs_and_actions = (groupmate_seq_obs, groupmate_seq_act)
         init_baseline, _baseline_mem = self.critic.baseline(
-            self_seq_obs,
-            team_seq_obs,
-            team_seq_act,
+            self_seq_obs[0],
+            groupmate_obs_and_actions,
             init_baseline_mem,
             sequence_length=first_seq_len,
         )
@@ -456,34 +454,34 @@ def _evaluate_by_sequence_team(
             )
 
             self_seq_obs = []
-            team_seq_obs = []
-            team_seq_act = []
+            groupmate_seq_obs = []
+            groupmate_seq_act = []
             seq_obs = []
             for _self_obs in self_obs:
                 seq_obs.append(_obs[start:end])
             self_seq_obs.append(seq_obs)
 
-            for team_obs, team_action in zip(obs, actions):
+            for groupmate_obs, team_action in zip(obs, actions):
                 seq_obs = []
-                for (_obs,) in team_obs:
+                for (_obs,) in groupmate_obs:
                     first_seq_obs = _obs[start:end]
                     seq_obs.append(first_seq_obs)
-                team_seq_obs.append(seq_obs)
+                groupmate_seq_obs.append(seq_obs)
                 _act = team_action.slice(start, end)
-                team_seq_act.append(_act)
+                groupmate_seq_act.append(_act)
 
-            all_seq_obs = self_seq_obs + team_seq_obs
+            all_seq_obs = self_seq_obs + groupmate_seq_obs
             values, _value_mem = self.critic.critic_pass(
                 all_seq_obs, _value_mem, sequence_length=self.policy.sequence_length
             )
             all_values = {
                 signal_name: [init_values[signal_name]] for signal_name in values.keys()
             }
 
+            groupmate_obs_and_actions = (groupmate_seq_obs, groupmate_seq_act)
             baselines, _baseline_mem = self.critic.baseline(
-                self_seq_obs,
-                team_seq_obs,
-                team_seq_act,
+                self_seq_obs[0],
+                groupmate_obs_and_actions,
                 _baseline_mem,
                 sequence_length=first_seq_len,
             )
@@ -565,15 +563,15 @@ def get_trajectory_and_baseline_value_estimates(
         n_obs = len(self.policy.behavior_spec.observation_specs)
 
         current_obs = ObsUtil.from_buffer(batch, n_obs)
-        team_obs = GroupObsUtil.from_buffer(batch, n_obs)
+        groupmate_obs = GroupObsUtil.from_buffer(batch, n_obs)
 
         current_obs = [ModelUtils.list_to_tensor(obs) for obs in current_obs]
-        team_obs = [
-            [ModelUtils.list_to_tensor(obs) for obs in _teammate_obs]
-            for _teammate_obs in team_obs
+        groupmate_obs = [
+            [ModelUtils.list_to_tensor(obs) for obs in _groupmate_obs]
+            for _groupmate_obs in groupmate_obs
         ]
 
-        team_actions = AgentAction.group_from_buffer(batch)
+        groupmate_actions = AgentAction.group_from_buffer(batch)
 
         next_obs = [ModelUtils.list_to_tensor(obs) for obs in next_obs]
         next_obs = [obs.unsqueeze(0) for obs in next_obs]
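
Groupmate observations leave the buffer as one list per groupmate, each holding per-sensor numpy arrays, hence the nested comprehension above. A small sketch of the same conversion, with torch.as_tensor standing in for ModelUtils.list_to_tensor:

import numpy as np
import torch

# Two groupmates, each with two observation streams over 5 steps.
buffer_groupmate_obs = [
    [np.zeros((5, 8), dtype=np.float32), np.zeros((5, 2), dtype=np.float32)]
    for _ in range(2)
]
groupmate_obs = [
    [torch.as_tensor(obs) for obs in _groupmate_obs]
    for _groupmate_obs in buffer_groupmate_obs
]
print(len(groupmate_obs), groupmate_obs[0][0].shape)  # 2 torch.Size([5, 8])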
@@ -604,7 +602,11 @@ def get_trajectory_and_baseline_value_estimates(
             else None
         )
 
-        all_obs = [current_obs] + team_obs if team_obs is not None else [current_obs]
+        all_obs = (
+            [current_obs] + groupmate_obs
+            if groupmate_obs is not None
+            else [current_obs]
+        )
         all_next_value_mem: Optional[AgentBufferField] = None
         all_next_baseline_mem: Optional[AgentBufferField] = None
         with torch.no_grad():
@@ -618,20 +620,19 @@ def get_trajectory_and_baseline_value_estimates(
                     next_baseline_mem,
                 ) = self._evaluate_by_sequence_team(
                     current_obs,
-                    team_obs,
-                    team_actions,
+                    groupmate_obs,
+                    groupmate_actions,
                     _init_value_mem,
                     _init_baseline_mem,
                 )
             else:
                 value_estimates, next_value_mem = self.critic.critic_pass(
                     all_obs, _init_value_mem, sequence_length=batch.num_experiences
                 )
-
+                groupmate_obs_and_actions = (groupmate_obs, groupmate_actions)
                 baseline_estimates, next_baseline_mem = self.critic.baseline(
-                    [current_obs],
-                    team_obs,
-                    team_actions,
+                    current_obs,
+                    groupmate_obs_and_actions,
                     _init_baseline_mem,
                     sequence_length=batch.num_experiences,
                 )
