[bug-fix] Make Python able to deal with 0-step episodes #3671

Merged · 4 commits · Mar 23, 2020

31 changes: 20 additions & 11 deletions ml-agents/mlagents/trainers/agent_processor.py
@@ -1,5 +1,5 @@
 import sys
-from typing import List, Dict, Deque, TypeVar, Generic, Tuple, Set
+from typing import List, Dict, Deque, TypeVar, Generic, Tuple, Any
 from collections import defaultdict, Counter, deque
 
 from mlagents_envs.base_env import BatchedStepResult, StepResult

@@ -66,7 +66,6 @@ def add_experiences(
         for _entropy in take_action_outputs["entropy"]:
             self.stats_reporter.add_stat("Policy/Entropy", _entropy)
 
-        terminated_agents: Set[str] = set()
         # Make unique agent_ids that are global across workers
         action_global_agent_ids = [
             get_global_agent_id(worker_id, ag_id) for ag_id in previous_action.agent_ids

@@ -85,6 +84,7 @@ def add_experiences(
             stored_take_action_outputs = self.last_take_action_outputs.get(
                 global_id, None
             )
+
             if stored_agent_step is not None and stored_take_action_outputs is not None:
                 # We know the step is from the same worker, so use the local agent id.
                 obs = stored_agent_step.obs

@@ -143,11 +143,12 @@ def add_experiences(
                         traj_queue.put(trajectory)
                     self.experience_buffers[global_id] = []
                 if curr_agent_step.done:
+                    # Record episode length for agents which have had at least
+                    # 1 step. Done after reset ignored.
                     self.stats_reporter.add_stat(
                         "Environment/Episode Length",
                         self.episode_steps.get(global_id, 0),
                     )
-                    terminated_agents.add(global_id)
                 elif not curr_agent_step.done:
                     self.episode_steps[global_id] += 1

[Review comment by @ervteng (Contributor, Author), Mar 23, 2020]
Otherwise we'd have a bunch of 0-length episodes if people change models
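
To spell out the 0-step case the review comment refers to, here is a sketch inferred from the diff (a local `episode_steps` stands in for `self.episode_steps`): steps are only counted while the agent is not done, so an agent that is done on its very first step reports a length of 0.

```python
from collections import Counter

episode_steps: Counter = Counter()  # stand-in for self.episode_steps

# Normal episode: N non-terminal steps are counted, then done arrives.
for _ in range(3):
    episode_steps["worker0-agent1"] += 1
print(episode_steps.get("worker0-agent1", 0))  # 3

# 0-step episode: done on the very first step, so no increment ever ran.
print(episode_steps.get("worker0-agent7", 0))  # 0
```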

@@ -156,9 +157,9 @@ def add_experiences(
                 curr_agent_step,
                 batched_step_result.agent_id_to_index[_id],
             )
 
-        for terminated_id in terminated_agents:
-            self._clean_agent_data(terminated_id)
+            # Delete all done agents, regardless of if they had a 0-length episode.
+            if curr_agent_step.done:
+                self._clean_agent_data(global_id)
 
         for _gid in action_global_agent_ids:
             # If the ID doesn't have a last step result, the agent just reset,

[Review comment by a Contributor]
This is definitely simpler.
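
Why the unconditional delete closes the leak behind the title bug (an inference from the diff, not a quote from the thread): `terminated_agents` was only populated inside the branch that requires a stored previous step, so an agent that arrived already done, i.e. a 0-step episode, was never cleaned up. A schematic sketch with hypothetical state:

```python
# Hypothetical minimal state after an agent arrived already done:
last_step_result = {"worker0-agent7": "latest step"}  # recorded for every agent seen
terminated_agents: set = set()  # old code: only filled for agents with a stored step

# Old cleanup: the 0-step agent never entered terminated_agents, so it leaks.
for terminated_id in terminated_agents:
    del last_step_result[terminated_id]
assert "worker0-agent7" in last_step_result  # still there

# New cleanup: delete on done, regardless of episode length.
curr_step_done = True  # stands in for curr_agent_step.done
if curr_step_done:
    last_step_result.pop("worker0-agent7", None)  # _safe_delete-style removal
assert "worker0-agent7" not in last_step_result
```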
@@ -173,14 +174,22 @@ def _clean_agent_data(self, global_id: str) -> None:
         """
         Removes the data for an Agent.
         """
-        del self.experience_buffers[global_id]
-        del self.last_take_action_outputs[global_id]
-        del self.last_step_result[global_id]
-        del self.episode_steps[global_id]
-        del self.episode_rewards[global_id]
+        self._safe_delete(self.experience_buffers, global_id)
+        self._safe_delete(self.last_take_action_outputs, global_id)
+        self._safe_delete(self.last_step_result, global_id)
+        self._safe_delete(self.episode_steps, global_id)
+        self._safe_delete(self.episode_rewards, global_id)
         self.policy.remove_previous_action([global_id])
         self.policy.remove_memories([global_id])
 
+    def _safe_delete(self, my_dictionary: Dict[Any, Any], key: Any) -> None:
+        """
+        Safe removes data from a dictionary. If not found,
+        don't delete.
+        """
+        if key in my_dictionary:
+            del my_dictionary[key]
 
     def publish_trajectory_queue(
         self, trajectory_queue: "AgentManagerQueue[Trajectory]"
     ) -> None:
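
For context, a minimal standalone sketch (not part of the diff; the names are hypothetical) of the failure `_safe_delete` guards against: `_clean_agent_data` now runs for every done agent, including agents some of these dictionaries never saw, and a plain `del` on a missing key raises `KeyError` for `defaultdict` and `Counter` just as for a plain `dict`:

```python
from collections import defaultdict

buffers: dict = defaultdict(list)  # hypothetical: agent was done before anything was buffered

try:
    del buffers["worker0-agent7"]  # old behavior: unconditional delete
except KeyError:
    print("plain del raises for a 0-step episode")

def safe_delete(d: dict, key) -> None:
    """Mirror of the new _safe_delete: only delete keys that exist."""
    if key in d:
        del d[key]

safe_delete(buffers, "worker0-agent7")  # no-op instead of KeyError
```
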
19 changes: 5 additions & 14 deletions ml-agents/mlagents/trainers/policy/tf_policy.py
@@ -174,17 +174,6 @@ def get_action(
         if batched_step_result.n_agents() == 0:
             return ActionInfo.empty()
 
-        agents_done = [
-            agent
-            for agent, done in zip(
-                batched_step_result.agent_id, batched_step_result.done
-            )
-            if done
-        ]
-
-        self.remove_memories(agents_done)
-        self.remove_previous_action(agents_done)
-
         global_agent_ids = [
             get_global_agent_id(worker_id, int(agent_id))
             for agent_id in batched_step_result.agent_id

[Review comment by @ervteng (Contributor, Author)]
This was left over from when we used to do it here (and not in AgentProcessor).

@@ -379,9 +368,11 @@ def _initialize_tensorflow_references(self):
 
     def create_input_placeholders(self):
         with self.graph.as_default():
-            self.global_step, self.increment_step_op, self.steps_to_increment = (
-                ModelUtils.create_global_steps()
-            )
+            (
+                self.global_step,
+                self.increment_step_op,
+                self.steps_to_increment,
+            ) = ModelUtils.create_global_steps()
             self.visual_in = ModelUtils.create_visual_input_placeholders(
                 self.brain.camera_resolutions
             )

[Review comment by @ervteng (Contributor, Author)]
Thanks, black.

9 changes: 9 additions & 0 deletions ml-agents/mlagents/trainers/tests/test_agent_processor.py
@@ -152,6 +152,15 @@ def test_agent_deletion():
     assert len(processor.last_take_action_outputs.keys()) == 0
     assert len(processor.episode_steps.keys()) == 0
     assert len(processor.episode_rewards.keys()) == 0
+    assert len(processor.last_step_result.keys()) == 0
+
+    # check that steps with immediate dones don't add to dicts
+    processor.add_experiences(mock_done_step, 0, ActionInfo.empty())
+    assert len(processor.experience_buffers.keys()) == 0
+    assert len(processor.last_take_action_outputs.keys()) == 0
+    assert len(processor.episode_steps.keys()) == 0
+    assert len(processor.episode_rewards.keys()) == 0
+    assert len(processor.last_step_result.keys()) == 0
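
Taken together, the new assertions encode a single invariant: a step that arrives already done must leave no per-agent state behind. A condensed restatement, illustrative only (`processor`, `mock_done_step`, and `ActionInfo` come from the surrounding test):

```python
per_agent_dicts = (
    processor.experience_buffers,
    processor.last_take_action_outputs,
    processor.episode_steps,
    processor.episode_rewards,
    processor.last_step_result,
)

# A 0-step episode: the agent is done on its very first visible step.
processor.add_experiences(mock_done_step, 0, ActionInfo.empty())
assert all(len(d) == 0 for d in per_agent_dicts)
```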


def test_end_episode():