Skip to content

Commit 1161d33

Browse files
author
Chris Elion
authored
handle multiple dones in a single step (#3700)
* handle multiple dones in a single step
1 parent 25ddb70 commit 1161d33

File tree

3 files changed

+67
-7
lines changed

3 files changed

+67
-7
lines changed

com.unity.ml-agents/CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
1717
- The way that UnityEnvironment decides the port was changed. If no port is specified, the behavior will depend on the `file_name` parameter. If it is `None`, 5004 (the editor port) will be used; otherwise 5005 (the base environment port) will be used.
1818
- Fixed an issue where switching models using `SetModel()` during training would use an excessive amount of memory. (#3664)
1919
- Environment subprocesses now close immediately on timeout or wrong API version. (#3679)
20+
- Fixed an issue in the gym wrapper that would raise an exception if an Agent called EndEpisode multiple times in the same step. (#3700)
2021
- Fixed an issue where exceptions from environments provided a returncode of 0. (#3680)
2122

2223
## [0.15.0-preview] - 2020-03-18

gym-unity/gym_unity/envs/__init__.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -364,9 +364,8 @@ def _check_agents(self, n_agents: int) -> None:
364364

365365
def _sanitize_info(self, step_result: BatchedStepResult) -> BatchedStepResult:
366366
n_extra_agents = step_result.n_agents() - self._n_agents
367-
if n_extra_agents < 0 or n_extra_agents > self._n_agents:
367+
if n_extra_agents < 0:
368368
# In this case, some Agents did not request a decision when expected
369-
# or too many requested a decision
370369
raise UnityGymException(
371370
"The number of agents in the scene does not match the expected number."
372371
)
@@ -386,6 +385,10 @@ def _sanitize_info(self, step_result: BatchedStepResult) -> BatchedStepResult:
386385
# only cares about the ordering.
387386
for index, agent_id in enumerate(step_result.agent_id):
388387
if not self._previous_step_result.contains_agent(agent_id):
388+
if step_result.done[index]:
389+
# If the Agent is already done (e.g. it ended its episode twice in one step)
390+
# Don't try to register it here.
391+
continue
389392
# Register this agent, and get the reward of the previous agent that
390393
# was in its index, so that we can return it to the gym.
391394
last_reward = self.agent_mapper.register_new_agent_id(agent_id)
@@ -528,8 +531,12 @@ def mark_agent_done(self, agent_id: int, reward: float) -> None:
528531
"""
529532
Declare the agent done with the corresponding final reward.
530533
"""
531-
gym_index = self._agent_id_to_gym_index.pop(agent_id)
532-
self._done_agents_index_to_last_reward[gym_index] = reward
534+
if agent_id in self._agent_id_to_gym_index:
535+
gym_index = self._agent_id_to_gym_index.pop(agent_id)
536+
self._done_agents_index_to_last_reward[gym_index] = reward
537+
else:
538+
# Agent was never registered in the first place (e.g. EndEpisode called multiple times)
539+
pass
533540

534541
def register_new_agent_id(self, agent_id: int) -> float:
535542
"""
@@ -581,9 +588,13 @@ def set_initial_agents(self, agent_ids: List[int]) -> None:
581588
self._gym_id_order = list(agent_ids)
582589

583590
def mark_agent_done(self, agent_id: int, reward: float) -> None:
584-
gym_index = self._gym_id_order.index(agent_id)
585-
self._done_agents_index_to_last_reward[gym_index] = reward
586-
self._gym_id_order[gym_index] = -1
591+
try:
592+
gym_index = self._gym_id_order.index(agent_id)
593+
self._done_agents_index_to_last_reward[gym_index] = reward
594+
self._gym_id_order[gym_index] = -1
595+
except ValueError:
596+
# Agent was never registered in the first place (e.g. EndEpisode called multiple times)
597+
pass
587598

588599
def register_new_agent_id(self, agent_id: int) -> float:
589600
original_index = self._gym_id_order.index(-1)

gym-unity/gym_unity/tests/test_gym.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,50 @@ def test_sanitize_action_one_agent_done(mock_env):
129129
assert expected_agent_id == agent_id
130130

131131

132+
@mock.patch("gym_unity.envs.UnityEnvironment")
133+
def test_sanitize_action_new_agent_done(mock_env):
134+
mock_spec = create_mock_group_spec(
135+
vector_action_space_type="discrete", vector_action_space_size=[2, 2, 3]
136+
)
137+
mock_step = create_mock_vector_step_result(num_agents=3)
138+
mock_step.agent_id = np.array(range(5))
139+
setup_mock_unityenvironment(mock_env, mock_spec, mock_step)
140+
env = UnityEnv(" ", use_visual=False, multiagent=True)
141+
142+
received_step_result = create_mock_vector_step_result(num_agents=7)
143+
received_step_result.agent_id = np.array(range(7))
144+
# agent #3 (id = 2) is Done
145+
# so is the "new" agent (id = 5)
146+
done = [False] * 7
147+
done[2] = True
148+
done[5] = True
149+
received_step_result.done = np.array(done)
150+
sanitized_result = env._sanitize_info(received_step_result)
151+
for expected_agent_id, agent_id in zip([0, 1, 6, 3, 4], sanitized_result.agent_id):
152+
assert expected_agent_id == agent_id
153+
154+
155+
@mock.patch("gym_unity.envs.UnityEnvironment")
156+
def test_sanitize_action_single_agent_multiple_done(mock_env):
157+
mock_spec = create_mock_group_spec(
158+
vector_action_space_type="discrete", vector_action_space_size=[2, 2, 3]
159+
)
160+
mock_step = create_mock_vector_step_result(num_agents=1)
161+
mock_step.agent_id = np.array(range(1))
162+
setup_mock_unityenvironment(mock_env, mock_spec, mock_step)
163+
env = UnityEnv(" ", use_visual=False, multiagent=False)
164+
165+
received_step_result = create_mock_vector_step_result(num_agents=3)
166+
received_step_result.agent_id = np.array(range(3))
167+
# original agent (id = 0) is Done
168+
# so is the "new" agent (id = 1)
169+
done = [True, True, False]
170+
received_step_result.done = np.array(done)
171+
sanitized_result = env._sanitize_info(received_step_result)
172+
for expected_agent_id, agent_id in zip([2], sanitized_result.agent_id):
173+
assert expected_agent_id == agent_id
174+
175+
132176
# Helper methods
133177

134178

@@ -200,6 +244,10 @@ def test_agent_id_index_mapper(mapper_cls):
200244
mapper.mark_agent_done(1001, 42.0)
201245
mapper.mark_agent_done(1004, 1337.0)
202246

247+
# Make sure we can handle an unknown agent id being marked done.
248+
# This can happen when an agent ends an episode on the same step it starts.
249+
mapper.mark_agent_done(9999, -1.0)
250+
203251
# Now add new agents, and get the rewards of the agent they replaced.
204252
old_reward1 = mapper.register_new_agent_id(2001)
205253
old_reward2 = mapper.register_new_agent_id(2002)

0 commit comments

Comments
 (0)