[bug-fix] Make Python able to deal with 0-step episodes #3671
Changes from all commits
```diff
@@ -1,5 +1,5 @@
 import sys
-from typing import List, Dict, Deque, TypeVar, Generic, Tuple, Set
+from typing import List, Dict, Deque, TypeVar, Generic, Tuple, Any
 from collections import defaultdict, Counter, deque

 from mlagents_envs.base_env import BatchedStepResult, StepResult
```
```diff
@@ -66,7 +66,6 @@ def add_experiences(
         for _entropy in take_action_outputs["entropy"]:
             self.stats_reporter.add_stat("Policy/Entropy", _entropy)

-        terminated_agents: Set[str] = set()
         # Make unique agent_ids that are global across workers
         action_global_agent_ids = [
             get_global_agent_id(worker_id, ag_id) for ag_id in previous_action.agent_ids
```
```diff
@@ -85,6 +84,7 @@ def add_experiences(
             stored_take_action_outputs = self.last_take_action_outputs.get(
                 global_id, None
             )
+
             if stored_agent_step is not None and stored_take_action_outputs is not None:
                 # We know the step is from the same worker, so use the local agent id.
                 obs = stored_agent_step.obs
```
```diff
@@ -143,11 +143,12 @@ def add_experiences(
                         traj_queue.put(trajectory)
                     self.experience_buffers[global_id] = []
                     if curr_agent_step.done:
+                        # Record episode length for agents which have had at least
+                        # 1 step. Done after reset ignored.
                         self.stats_reporter.add_stat(
                             "Environment/Episode Length",
                             self.episode_steps.get(global_id, 0),
                         )
-                        terminated_agents.add(global_id)
                 elif not curr_agent_step.done:
                     self.episode_steps[global_id] += 1

```
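The key scenario here is the 0-step episode: an agent reports done on the very first step after a reset, so nothing has been accumulated in `episode_steps` for its global id. A minimal sketch of the bookkeeping, using a plain dict and hypothetical agent ids rather than the processor's real types:

```python
# Sketch of the episode-length bookkeeping, assuming a plain dict in place
# of the processor's internal episode_steps mapping. Agent ids are made up.
episode_steps = {}

def on_agent_step(global_id, done):
    if done:
        # An agent that terminated immediately after a reset has no entry
        # yet; .get(global_id, 0) records a 0-length episode instead of
        # raising KeyError (or silently creating an entry, if the mapping
        # were a defaultdict).
        length = episode_steps.get(global_id, 0)
        print(f"Environment/Episode Length for {global_id}: {length}")
        episode_steps.pop(global_id, None)  # cleanup, mirroring _clean_agent_data
    else:
        episode_steps[global_id] = episode_steps.get(global_id, 0) + 1

on_agent_step("worker0-agent1", done=False)  # step 1
on_agent_step("worker0-agent1", done=True)   # records length 1
on_agent_step("worker0-agent2", done=True)   # 0-step episode, records 0
```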
```diff
@@ -156,9 +157,9 @@ def add_experiences(
                 curr_agent_step,
                 batched_step_result.agent_id_to_index[_id],
             )
-
-        for terminated_id in terminated_agents:
-            self._clean_agent_data(terminated_id)
+            # Delete all done agents, regardless of if they had a 0-length episode.
+            if curr_agent_step.done:
+                self._clean_agent_data(global_id)

         for _gid in action_global_agent_ids:
             # If the ID doesn't have a last step result, the agent just reset,
```

Review comment: This is definitely simpler.
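Structurally, the cleanup moves from a deferred pass over a `terminated_agents` set to an inline call for each done agent. A simplified sketch of the before/after shapes, with hypothetical stand-ins for the processor's loop and state:

```python
# Hypothetical stand-ins for the processor's per-agent loop; AgentStep,
# clean_agent_data, and steps are illustrative only.
from typing import NamedTuple, Set

class AgentStep(NamedTuple):
    global_id: str
    done: bool

def clean_agent_data(global_id: str) -> None:
    print(f"cleaning {global_id}")

steps = [AgentStep("w0-a1", False), AgentStep("w0-a2", True)]

# Before: done agents were gathered into a set and cleaned after the loop.
terminated_agents: Set[str] = set()
for step in steps:
    if step.done:
        terminated_agents.add(step.global_id)
for terminated_id in terminated_agents:
    clean_agent_data(terminated_id)

# After: every done agent is cleaned inline, even one with a 0-length episode.
for step in steps:
    if step.done:
        clean_agent_data(step.global_id)
```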
```diff
@@ -173,14 +174,22 @@ def _clean_agent_data(self, global_id: str) -> None:
         """
         Removes the data for an Agent.
         """
-        del self.experience_buffers[global_id]
-        del self.last_take_action_outputs[global_id]
-        del self.last_step_result[global_id]
-        del self.episode_steps[global_id]
-        del self.episode_rewards[global_id]
+        self._safe_delete(self.experience_buffers, global_id)
+        self._safe_delete(self.last_take_action_outputs, global_id)
+        self._safe_delete(self.last_step_result, global_id)
+        self._safe_delete(self.episode_steps, global_id)
+        self._safe_delete(self.episode_rewards, global_id)
         self.policy.remove_previous_action([global_id])
         self.policy.remove_memories([global_id])

+    def _safe_delete(self, my_dictionary: Dict[Any, Any], key: Any) -> None:
+        """
+        Safe removes data from a dictionary. If not found,
+        don't delete.
+        """
+        if key in my_dictionary:
+            del my_dictionary[key]
+
     def publish_trajectory_queue(
         self, trajectory_queue: "AgentManagerQueue[Trajectory]"
     ) -> None:
```
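The new `_safe_delete` helper makes this cleanup idempotent: with 0-step episodes, an agent can be done before any per-agent state was ever created, and an unconditional `del` would raise `KeyError`. A sketch of the behavior; note that `dict.pop(key, None)` from the standard dict API would behave the same way (shown only as a comparison, not as part of the PR):

```python
# Sketch: _safe_delete's behavior compared with plain `del` and dict.pop.
# The `state` dict and keys are hypothetical.
state = {"agent-1": [1, 2, 3]}

def safe_delete(my_dictionary, key):
    # Mirrors the PR's helper: delete only if the key is present.
    if key in my_dictionary:
        del my_dictionary[key]

safe_delete(state, "agent-1")     # removes the entry
safe_delete(state, "never-seen")  # no-op; plain `del` would raise KeyError

# An equivalent one-liner using the standard dict API:
state.pop("never-seen", None)
```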
```diff
@@ -174,17 +174,6 @@ def get_action(
         if batched_step_result.n_agents() == 0:
             return ActionInfo.empty()

-        agents_done = [
-            agent
-            for agent, done in zip(
-                batched_step_result.agent_id, batched_step_result.done
-            )
-            if done
-        ]
-
-        self.remove_memories(agents_done)
-        self.remove_previous_action(agents_done)
-
         global_agent_ids = [
             get_global_agent_id(worker_id, int(agent_id))
             for agent_id in batched_step_result.agent_id
```

Review comment: This was left over from when we used to do it here (and not in AgentProcessor).
```diff
@@ -379,9 +368,11 @@ def _initialize_tensorflow_references(self):

     def create_input_placeholders(self):
         with self.graph.as_default():
-            self.global_step, self.increment_step_op, self.steps_to_increment = (
-                ModelUtils.create_global_steps()
-            )
+            (
+                self.global_step,
+                self.increment_step_op,
+                self.steps_to_increment,
+            ) = ModelUtils.create_global_steps()
             self.visual_in = ModelUtils.create_visual_input_placeholders(
                 self.brain.camera_resolutions
             )
```

Review comment: Thanks,
Review comment: Otherwise we'd have a bunch of 0-length episodes if people change models.