Skip to content

Commit 6150aa9

Browse files
author
Ervin T
authored
[bug-fix] Fix when group terminal steps are deleted, robust test (#5441)
* Fix when terminal steps are deleted, robust test * Update changelog * Fix test comment
1 parent 5921162 commit 6150aa9

File tree

4 files changed

+42
-10
lines changed

4 files changed

+42
-10
lines changed

com.unity.ml-agents/CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ and this project adheres to
1616
### Bug Fixes
1717
#### com.unity.ml-agents / com.unity.ml-agents.extensions (C#)
1818
#### ml-agents / ml-agents-envs / gym-unity (Python)
19+
- Fixed a bug in multi-agent cooperative training where agents might not receive all of the states of
20+
terminated teammates. (#5441)
1921

2022
## [2.1.0-exp.1] - 2021-06-09
2123
### Minor Changes

ml-agents/mlagents/trainers/agent_processor.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -122,8 +122,6 @@ def add_experiences(
122122
self._process_step(
123123
terminal_step, worker_id, terminal_steps.agent_id_to_index[local_id]
124124
)
125-
# Clear the last seen group obs when agents die.
126-
self._clear_group_status_and_obs(global_id)
127125

128126
# Iterate over all the decision steps, first gather all the group obs
129127
# and then create the trajectories. _add_to_group_status
@@ -135,6 +133,12 @@ def add_experiences(
135133
self._process_step(
136134
ongoing_step, worker_id, decision_steps.agent_id_to_index[local_id]
137135
)
136+
# Clear the last seen group obs when agents die, but only after all of the group
137+
# statuses were added to the trajectory.
138+
for terminal_step in terminal_steps.values():
139+
local_id = terminal_step.agent_id
140+
global_id = get_global_agent_id(worker_id, local_id)
141+
self._clear_group_status_and_obs(global_id)
138142

139143
for _gid in action_global_agent_ids:
140144
# If the ID doesn't have a last step result, the agent just reset,

ml-agents/mlagents/trainers/tests/mock_brain.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import List, Tuple
1+
from typing import List, Optional, Tuple
22
import numpy as np
33

44
from mlagents.trainers.buffer import AgentBuffer, AgentBufferKey
@@ -21,6 +21,7 @@ def create_mock_steps(
2121
action_spec: ActionSpec,
2222
done: bool = False,
2323
grouped: bool = False,
24+
agent_ids: Optional[List[int]] = None,
2425
) -> Tuple[DecisionSteps, TerminalSteps]:
2526
"""
2627
Creates a mock Tuple[DecisionSteps, TerminalSteps] with observations.
@@ -43,7 +44,10 @@ def create_mock_steps(
4344

4445
reward = np.array(num_agents * [1.0], dtype=np.float32)
4546
interrupted = np.array(num_agents * [False], dtype=np.bool)
46-
agent_id = np.arange(num_agents, dtype=np.int32)
47+
if agent_ids is not None:
48+
agent_id = np.array(agent_ids, dtype=np.int32)
49+
else:
50+
agent_id = np.arange(num_agents, dtype=np.int32)
4751
_gid = 1 if grouped else 0
4852
group_id = np.array(num_agents * [_gid], dtype=np.int32)
4953
group_reward = np.array(num_agents * [0.0], dtype=np.float32)

ml-agents/mlagents/trainers/tests/test_agent_processor.py

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -137,32 +137,54 @@ def test_group_statuses():
137137
)
138138

139139
# Make terminal steps for some dead agents
140-
mock_decision_steps_2, mock_terminal_steps_2 = mb.create_mock_steps(
140+
_, mock_terminal_steps_2 = mb.create_mock_steps(
141141
num_agents=2,
142142
observation_specs=create_observation_specs_with_shapes([(8,)]),
143143
action_spec=ActionSpec.create_continuous(2),
144144
done=True,
145145
grouped=True,
146+
agent_ids=[2, 3],
147+
)
148+
# Make decision steps continue for other agents
149+
mock_decision_steps_2, _ = mb.create_mock_steps(
150+
num_agents=2,
151+
observation_specs=create_observation_specs_with_shapes([(8,)]),
152+
action_spec=ActionSpec.create_continuous(2),
153+
done=False,
154+
grouped=True,
155+
agent_ids=[0, 1],
146156
)
147157

148158
processor.add_experiences(
149159
mock_decision_steps_2, mock_terminal_steps_2, 0, fake_action_info
150160
)
151-
fake_action_info = _create_action_info(4, mock_decision_steps.agent_id)
161+
# Continue to add for remaining live agents
162+
fake_action_info = _create_action_info(4, mock_decision_steps_2.agent_id)
152163
for _ in range(3):
153164
processor.add_experiences(
154-
mock_decision_steps, mock_terminal_steps, 0, fake_action_info
165+
mock_decision_steps_2, mock_terminal_steps, 0, fake_action_info
155166
)
156167

157168
# Assert that four trajectories have been added to the Trainer
158169
assert len(tqueue.put.call_args_list) == 4
159-
# Last trajectory should be the longest
170+
171+
# Get the first trajectory, which should have been agent 2 (one of the killed agents)
160172
trajectory = tqueue.put.call_args_list[0][0][-1]
173+
assert len(trajectory.steps) == 3
174+
# Make sure trajectory has the right Groupmate Experiences.
175+
# All three steps should contain all agents
176+
for step in trajectory.steps:
177+
assert len(step.group_status) == 3
178+
179+
# Last trajectory should be the longest. It should be that of agent 1, one of the surviving agents.
180+
trajectory = tqueue.put.call_args_list[-1][0][-1]
181+
assert len(trajectory.steps) == 5
161182

162-
# Make sure trajectory has the right Groupmate Experiences
183+
# Make sure trajectory has the right Groupmate Experiences.
184+
# THe first 3 steps should contain all of the obs (that 3rd step is also the terminal step of 2 of the agents)
163185
for step in trajectory.steps[0:3]:
164186
assert len(step.group_status) == 3
165-
# After 2 agents has died
187+
# After 2 agents has died, there should only be 1 group status.
166188
for step in trajectory.steps[3:]:
167189
assert len(step.group_status) == 1
168190

0 commit comments

Comments
 (0)