[refactor] Remove duplicated logic from AgentProcessor #3728

Merged
merged 1 commit into from Apr 3, 2020
207 changes: 82 additions & 125 deletions ml-agents/mlagents/trainers/agent_processor.py
@@ -1,8 +1,13 @@
 import sys
-from typing import List, Dict, Deque, TypeVar, Generic, Tuple, Any
+from typing import List, Dict, Deque, TypeVar, Generic, Tuple, Any, Union
 from collections import defaultdict, Counter, deque

-from mlagents_envs.base_env import DecisionSteps, DecisionStep, TerminalSteps
+from mlagents_envs.base_env import (
+    DecisionSteps,
+    DecisionStep,
+    TerminalSteps,
+    TerminalStep,
+)
 from mlagents_envs.side_channel.stats_side_channel import StatsAggregationMethod
 from mlagents.trainers.trajectory import Trajectory, AgentExperience
 from mlagents.trainers.policy.tf_policy import TFPolicy
@@ -78,138 +83,21 @@ def add_experiences(
         self.last_take_action_outputs[global_id] = take_action_outputs

         # Iterate over all the terminal steps
-        for terminated_step in terminal_steps.values():
-            local_id = terminated_step.agent_id
+        for terminal_step in terminal_steps.values():
+            local_id = terminal_step.agent_id
             global_id = get_global_agent_id(worker_id, local_id)
-            stored_decision_step, idx = self.last_step_result.get(
-                global_id, (None, None)
+            self._process_step(
+                terminal_step, global_id, terminal_steps.agent_id_to_index[local_id]
             )
-            stored_take_action_outputs = self.last_take_action_outputs.get(
-                global_id, None
-            )
-            # We do not need to store this step as this agent is terminated
-            # This state is the consequence of a past action
-            if (
-                stored_decision_step is not None
-                and stored_take_action_outputs is not None
-            ):
-                obs = stored_decision_step.obs
-                if self.policy.use_recurrent:
-                    memory = self.policy.retrieve_memories([global_id])[0, :]
-                else:
-                    memory = None
-                done = True  # The agent is terminated
-                max_step = terminated_step.max_step
-                # Add the outputs of the last eval
-                action = stored_take_action_outputs["action"][idx]
-                if self.policy.use_continuous_act:
-                    action_pre = stored_take_action_outputs["pre_action"][idx]
-                else:
-                    action_pre = None
-                action_probs = stored_take_action_outputs["log_probs"][idx]
-                action_mask = stored_decision_step.action_mask
-                prev_action = self.policy.retrieve_previous_action([global_id])[0, :]
-                experience = AgentExperience(
-                    obs=obs,
-                    reward=terminated_step.reward,
-                    done=done,
-                    action=action,
-                    action_probs=action_probs,
-                    action_pre=action_pre,
-                    action_mask=action_mask,
-                    prev_action=prev_action,
-                    max_step=max_step,
-                    memory=memory,
-                )
-                # Add the value outputs if needed
-                self.experience_buffers[global_id].append(experience)
-                self.episode_rewards[global_id] += terminated_step.reward
-
-                # Since the Agent is done, we must generate the trajectory
-                # Make next AgentExperience
-                next_obs = terminated_step.obs
-                trajectory = Trajectory(
-                    steps=self.experience_buffers[global_id],
-                    agent_id=global_id,
-                    next_obs=next_obs,
-                    behavior_id=self.behavior_id,
-                )
-                for traj_queue in self.trajectory_queues:
-                    traj_queue.put(trajectory)
-                self.experience_buffers[global_id] = []
-            self._clean_agent_data(global_id)

         # Iterate over all the decision steps
         for ongoing_step in decision_steps.values():
             local_id = ongoing_step.agent_id
             global_id = get_global_agent_id(worker_id, local_id)
-            stored_decision_step, idx = self.last_step_result.get(
-                global_id, (None, None)
-            )
-            stored_take_action_outputs = self.last_take_action_outputs.get(
-                global_id, None
-            )
-            # Index is needed to grab from last_take_action_outputs
-            self.last_step_result[global_id] = (
-                ongoing_step,
-                decision_steps.agent_id_to_index[local_id],
+            self._process_step(
+                ongoing_step, global_id, decision_steps.agent_id_to_index[local_id]
             )
-
-            # This state is the consequence of a past action
-            if (
-                stored_decision_step is not None
-                and stored_take_action_outputs is not None
-            ):
-                obs = stored_decision_step.obs
-                if self.policy.use_recurrent:
-                    memory = self.policy.retrieve_memories([global_id])[0, :]
-                else:
-                    memory = None
-                done = False  # Since this is an ongoing step
-                max_step = False
-                # Add the outputs of the last eval
-                action = stored_take_action_outputs["action"][idx]
-                if self.policy.use_continuous_act:
-                    action_pre = stored_take_action_outputs["pre_action"][idx]
-                else:
-                    action_pre = None
-                action_probs = stored_take_action_outputs["log_probs"][idx]
-                action_mask = stored_decision_step.action_mask
-                prev_action = self.policy.retrieve_previous_action([global_id])[0, :]
-                experience = AgentExperience(
-                    obs=obs,
-                    reward=ongoing_step.reward,
-                    done=done,
-                    action=action,
-                    action_probs=action_probs,
-                    action_pre=action_pre,
-                    action_mask=action_mask,
-                    prev_action=prev_action,
-                    max_step=max_step,
-                    memory=memory,
-                )
-                # Add the value outputs if needed
-                self.experience_buffers[global_id].append(experience)
-                self.episode_rewards[global_id] += ongoing_step.reward
-                self.episode_steps[global_id] += 1
-
-                # if the trajectory is too long, we truncate it
-                if (
-                    len(self.experience_buffers[global_id])
-                    >= self.max_trajectory_length
-                ):
-                    # Make next AgentExperience
-                    next_obs = ongoing_step.obs
-                    trajectory = Trajectory(
-                        steps=self.experience_buffers[global_id],
-                        agent_id=global_id,
-                        next_obs=next_obs,
-                        behavior_id=self.behavior_id,
-                    )
-                    for traj_queue in self.trajectory_queues:
-                        traj_queue.put(trajectory)
-                    self.experience_buffers[global_id] = []

         for _gid in action_global_agent_ids:
             # If the ID doesn't have a last step result, the agent just reset,
             # don't store the action.
@@ -219,6 +107,75 @@ def add_experiences(
                     [_gid], take_action_outputs["action"]
                 )

+    def _process_step(
+        self, step: Union[TerminalStep, DecisionStep], global_id: str, index: int
+    ) -> None:
+        terminated = isinstance(step, TerminalStep)
+        stored_decision_step, idx = self.last_step_result.get(global_id, (None, None))
+        stored_take_action_outputs = self.last_take_action_outputs.get(global_id, None)
+        if not terminated:
+            # Index is needed to grab from last_take_action_outputs
+            self.last_step_result[global_id] = (step, index)
+
+        # This state is the consequence of a past action
+        if stored_decision_step is not None and stored_take_action_outputs is not None:
+            obs = stored_decision_step.obs
+            if self.policy.use_recurrent:
+                memory = self.policy.retrieve_memories([global_id])[0, :]
+            else:
+                memory = None
+            done = terminated  # The agent is done only on a terminal step
+            max_step = step.max_step if terminated else False
+            # Add the outputs of the last eval
+            action = stored_take_action_outputs["action"][idx]
+            if self.policy.use_continuous_act:
+                action_pre = stored_take_action_outputs["pre_action"][idx]
+            else:
+                action_pre = None
+            action_probs = stored_take_action_outputs["log_probs"][idx]
+            action_mask = stored_decision_step.action_mask
+            prev_action = self.policy.retrieve_previous_action([global_id])[0, :]
+            experience = AgentExperience(
+                obs=obs,
+                reward=step.reward,
+                done=done,
+                action=action,
+                action_probs=action_probs,
+                action_pre=action_pre,
+                action_mask=action_mask,
+                prev_action=prev_action,
+                max_step=max_step,
+                memory=memory,
+            )
+            # Add the value outputs if needed
+            self.experience_buffers[global_id].append(experience)
+            self.episode_rewards[global_id] += step.reward
+            if not terminated:
+                self.episode_steps[global_id] += 1
+
+            # If the trajectory is too long or the agent terminated, flush it
+            if (
+                len(self.experience_buffers[global_id]) >= self.max_trajectory_length
+                or terminated
+            ):
+                # Make next AgentExperience
+                next_obs = step.obs
+                trajectory = Trajectory(
+                    steps=self.experience_buffers[global_id],
+                    agent_id=global_id,
+                    next_obs=next_obs,
+                    behavior_id=self.behavior_id,
+                )
+                for traj_queue in self.trajectory_queues:
+                    traj_queue.put(trajectory)
+                self.experience_buffers[global_id] = []
+            if terminated:
+                # Record episode length.
+                self.stats_reporter.add_stat(
+                    "Environment/Episode Length", self.episode_steps.get(global_id, 0)
+                )
+                self._clean_agent_data(global_id)
+
     def _clean_agent_data(self, global_id: str) -> None:
         """
         Removes the data for an Agent.
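The core of the refactor: both the terminal-step and decision-step loops now funnel into a single _process_step helper that branches on isinstance(step, TerminalStep), instead of maintaining two near-identical bodies. Below is a minimal, self-contained sketch of that pattern, not part of the PR, using simplified hypothetical types (DecisionStep/TerminalStep here are stand-ins for the real mlagents_envs.base_env classes, and MiniProcessor for AgentProcessor):

from typing import Dict, List, NamedTuple, Union


class DecisionStep(NamedTuple):  # hypothetical: agent requests another action
    agent_id: int
    reward: float


class TerminalStep(NamedTuple):  # hypothetical: agent's episode just ended
    agent_id: int
    reward: float


class MiniProcessor:
    """Illustrative stand-in for AgentProcessor."""

    def __init__(self) -> None:
        self.episode_rewards: Dict[int, float] = {}

    def add_experiences(
        self,
        decision_steps: List[DecisionStep],
        terminal_steps: List[TerminalStep],
    ) -> None:
        # Both loops now delegate to one shared code path.
        for t_step in terminal_steps:
            self._process_step(t_step)
        for d_step in decision_steps:
            self._process_step(d_step)

    def _process_step(self, step: Union[TerminalStep, DecisionStep]) -> None:
        # A single isinstance check replaces two duplicated loop bodies.
        terminated = isinstance(step, TerminalStep)
        self.episode_rewards[step.agent_id] = (
            self.episode_rewards.get(step.agent_id, 0.0) + step.reward
        )
        if terminated:
            # On termination, flush per-agent state (cf. _clean_agent_data).
            self.episode_rewards.pop(step.agent_id)

The payoff of this structure is that any future change to experience construction (e.g. new fields on AgentExperience) is made once in _process_step rather than kept in sync across two loops.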