Move add_experiences out of trainer, add Trajectories #3067
Changes from all commits
@@ -1,75 +1,155 @@
from typing import List, Union
import sys
from typing import List, Dict
from collections import defaultdict, Counter

from mlagents.trainers.buffer import AgentBuffer, BufferException
from mlagents.trainers.trainer import Trainer
from mlagents.trainers.trajectory import Trajectory, AgentExperience
from mlagents.trainers.brain import BrainInfo
from mlagents.trainers.tf_policy import TFPolicy
from mlagents.trainers.action_info import ActionInfoOutputs
from mlagents.trainers.stats import StatsReporter


class ProcessingBuffer(dict):
class AgentProcessor:
    """
    ProcessingBuffer contains a dictionary of AgentBuffer. The AgentBuffers are indexed by agent_id.
    AgentProcessor contains a dictionary of per-agent trajectory buffers. The buffers are indexed by agent_id.
    Buffer also contains an update_buffer that corresponds to the buffer used when updating the model.
    One AgentProcessor should be created per agent group.
    """

    def __str__(self):
        return "local_buffers :\n{0}".format(
            "\n".join(["\tagent {0} :{1}".format(k, str(self[k])) for k in self.keys()])
        )

    def __getitem__(self, key):
        if key not in self.keys():
            self[key] = AgentBuffer()
        return super().__getitem__(key)

    def reset_local_buffers(self) -> None:
    def __init__(
        self,
        trainer: Trainer,
        policy: TFPolicy,
        stats_reporter: StatsReporter,
        max_trajectory_length: int = sys.maxsize,
    ):
""" | ||
Resets all the local AgentBuffers. | ||
Create an AgentProcessor. | ||
:param trainer: Trainer instance connected to this AgentProcessor. Trainer is given trajectory | ||
when it is finished. | ||
:param policy: Policy instance associated with this AgentProcessor. | ||
:param max_trajectory_length: Maximum length of a trajectory before it is added to the trainer. | ||
:param stats_category: The category under which to write the stats. Usually, this comes from the Trainer. | ||
""" | ||
for buf in self.values(): | ||
buf.reset_agent() | ||
self.experience_buffers: Dict[str, List[AgentExperience]] = defaultdict(list) | ||
self.last_brain_info: Dict[str, BrainInfo] = {} | ||
Review comment: Is this the last_brain_info per agent?
Reply: Yes.
        self.last_take_action_outputs: Dict[str, ActionInfoOutputs] = {}
        # Note: this is needed until we switch to AgentExperiences as the data input type.
        # We still need some info from the policy (memories, previous actions)
        # that really should be gathered by the env-manager.
        self.policy = policy
        self.episode_steps: Counter = Counter()
        self.episode_rewards: Dict[str, float] = defaultdict(float)
        self.stats_reporter = stats_reporter
        self.trainer = trainer
        self.max_trajectory_length = max_trajectory_length

    def append_to_update_buffer(
    def add_experiences(
        self,
        update_buffer: AgentBuffer,
        agent_id: Union[int, str],
        key_list: List[str] = None,
        batch_size: int = None,
        training_length: int = None,
        curr_info: BrainInfo,
        next_info: BrainInfo,
        take_action_outputs: ActionInfoOutputs,
    ) -> None:
        """
        Appends the buffer of an agent to the update buffer.
        :param update_buffer: A reference to an AgentBuffer to append the agent's buffer to
        :param agent_id: The id of the agent which data will be appended
        :param key_list: The fields that must be added. If None: all fields will be appended.
        :param batch_size: The number of elements that must be appended. If None: All of them will be.
        :param training_length: The length of the samples that must be appended. If None: only takes one element.
        Adds experiences to each agent's experience history.
        :param curr_info: current BrainInfo.
        :param next_info: next BrainInfo.
        :param take_action_outputs: The outputs of the Policy's get_action method.
        """
        if key_list is None:
            key_list = self[agent_id].keys()
        if not self[agent_id].check_length(key_list):
            raise BufferException(
                "The length of the fields {0} for agent {1} were not of same length".format(
                    key_list, agent_id
                )
            )
        if take_action_outputs:
            self.stats_reporter.add_stat(
                "Policy/Entropy", take_action_outputs["entropy"].mean()
            )
        for field_key in key_list:
            update_buffer[field_key].extend(
                self[agent_id][field_key].get_batch(
                    batch_size=batch_size, training_length=training_length
                )
            )
            self.stats_reporter.add_stat(
                "Policy/Learning Rate", take_action_outputs["learning_rate"]
            )

    def append_all_agent_batch_to_update_buffer(
        self,
        update_buffer: AgentBuffer,
        key_list: List[str] = None,
        batch_size: int = None,
        training_length: int = None,
    ) -> None:
        """
        Appends the buffer of all agents to the update buffer.
        :param key_list: The fields that must be added. If None: all fields will be appended.
        :param batch_size: The number of elements that must be appended. If None: All of them will be.
        :param training_length: The length of the samples that must be appended. If None: only takes one element.
        """
        for agent_id in self.keys():
            self.append_to_update_buffer(
                update_buffer, agent_id, key_list, batch_size, training_length
            )
        for agent_id in curr_info.agents:
            self.last_brain_info[agent_id] = curr_info
            self.last_take_action_outputs[agent_id] = take_action_outputs

        # Store the environment reward
        tmp_environment_reward = next_info.rewards

        for next_idx, agent_id in enumerate(next_info.agents):
            stored_info = self.last_brain_info.get(agent_id, None)
            if stored_info is not None:
                stored_take_action_outputs = self.last_take_action_outputs[agent_id]
                idx = stored_info.agents.index(agent_id)
Review comment: Little worried about the O(N) lookup here since we're doing it N times.
Review comment: You might want to do something like precomputing the agent-id-to-index lookup outside the loop.
Reply: The tricky bit here is that the stored_info might be different per iteration of the loop (some agents in next_info might not have been in the previous info and vice-versa), so the index might change as well. To make matters worse, we do this indexing twice (once here, and once in the LL-Python API to convert BatchedState -> BrainInfo). Long-term we will be removing BrainInfo (today: BatchedState -> BrainInfo -> AgentExperience; end goal: BatchedState -> AgentExperience), so I think we will be able to get away with simply adding to trajectories agent-by-agent. We won't have to store the stored_info anymore. In this case, we will only have to do the indexing once.
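A minimal sketch of the kind of precomputation the comment describes (the helper name and usage are illustrative, not the reviewer's original code):

```python
from typing import Any, Dict, List


def build_agent_index_map(agent_ids: List[Any]) -> Dict[Any, int]:
    """Map each agent_id to its position once, so per-agent lookups are O(1)."""
    return {a_id: i for i, a_id in enumerate(agent_ids)}


# Hypothetical use against a BrainInfo-like object:
#     agent_index_map = build_agent_index_map(stored_info.agents)
#     idx = agent_index_map[agent_id]  # replaces stored_info.agents.index(agent_id)
```

As the reply notes, stored_info can differ between loop iterations, so such a map would have to be built per distinct stored_info rather than once per add_experiences call.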
                obs = []
                if not stored_info.local_done[idx]:
                    for i, _ in enumerate(stored_info.visual_observations):
                        obs.append(stored_info.visual_observations[i][idx])
                    if self.policy.use_vec_obs:
                        obs.append(stored_info.vector_observations[idx])
                    if self.policy.use_recurrent:
                        memory = self.policy.retrieve_memories([agent_id])[0, :]
                    else:
                        memory = None

                    done = next_info.local_done[next_idx]
                    max_step = next_info.max_reached[next_idx]

                    # Add the outputs of the last eval
                    action = stored_take_action_outputs["action"][idx]
                    if self.policy.use_continuous_act:
                        action_pre = stored_take_action_outputs["pre_action"][idx]
                    else:
                        action_pre = None
                    action_probs = stored_take_action_outputs["log_probs"][idx]
                    action_masks = stored_info.action_masks[idx]
                    prev_action = self.policy.retrieve_previous_action([agent_id])[0, :]

                    experience = AgentExperience(
                        obs=obs,
                        reward=tmp_environment_reward[next_idx],
                        done=done,
                        action=action,
                        action_probs=action_probs,
                        action_pre=action_pre,
                        action_mask=action_masks,
                        prev_action=prev_action,
                        max_step=max_step,
                        memory=memory,
                    )
                    # Add the value outputs if needed
                    self.experience_buffers[agent_id].append(experience)
                    self.episode_rewards[agent_id] += tmp_environment_reward[next_idx]
                if (
                    next_info.local_done[next_idx]
                    or (
                        len(self.experience_buffers[agent_id])
                        >= self.max_trajectory_length
                    )
                ) and len(self.experience_buffers[agent_id]) > 0:
                    # Make next AgentExperience
                    next_obs = []
                    for i, _ in enumerate(next_info.visual_observations):
                        next_obs.append(next_info.visual_observations[i][next_idx])
                    if self.policy.use_vec_obs:
                        next_obs.append(next_info.vector_observations[next_idx])
                    trajectory = Trajectory(
                        steps=self.experience_buffers[agent_id],
                        agent_id=agent_id,
                        next_obs=next_obs,
                    )
                    # This will eventually be replaced with a queue
                    self.trainer.process_trajectory(trajectory)
                    self.experience_buffers[agent_id] = []
                if next_info.local_done[next_idx]:
                    self.stats_reporter.add_stat(
                        "Environment/Cumulative Reward",
                        self.episode_rewards.get(agent_id, 0),
                    )
                    self.stats_reporter.add_stat(
                        "Environment/Episode Length",
                        self.episode_steps.get(agent_id, 0),
                    )
                    del self.episode_steps[agent_id]
                    del self.episode_rewards[agent_id]
                elif not next_info.local_done[next_idx]:
                    self.episode_steps[agent_id] += 1
        self.policy.save_previous_action(
            curr_info.agents, take_action_outputs["action"]
        )
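To make the new data flow concrete, here is a hedged sketch of how a caller might drive an AgentProcessor, based only on the interfaces visible in this diff. The import path, the policy.get_action(...) call, and the .outputs attribute are assumptions; the real wiring lives in the trainer controller and env manager.

```python
from mlagents.trainers.agent_processor import AgentProcessor  # assumed module path


def make_processor(trainer, policy, stats_reporter):
    # One AgentProcessor per agent group, as the class docstring requires.
    return AgentProcessor(
        trainer=trainer,
        policy=policy,
        stats_reporter=stats_reporter,
        max_trajectory_length=64,  # illustrative value; defaults to sys.maxsize
    )


def on_environment_step(processor, policy, curr_info, next_info):
    # get_action is assumed to return an object whose .outputs is the
    # ActionInfoOutputs dict that add_experiences expects.
    action_info = policy.get_action(curr_info)
    # When an agent finishes an episode, or its buffer reaches max_trajectory_length,
    # the processor assembles a Trajectory and hands it to trainer.process_trajectory().
    processor.add_experiences(curr_info, next_info, action_info.outputs)
```

This is the point of the PR: the trainer no longer implements add_experiences itself and only consumes finished Trajectory objects, for now via a direct call and eventually via a queue per the inline comment.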
@@ -255,6 +255,35 @@ def truncate(self, max_length: int, sequence_length: int = 1) -> None:
        for _key in self.keys():
            self[_key] = self[_key][current_length - max_length :]

    def resequence_and_append(
        self,
        target_buffer: "AgentBuffer",
Review comment: Why do we have this magic string here? Will it cause problems if a user names a brain AgentBuffer?
Reply: This isn't a magic string - it's a type annotation :P -> https://mypy.readthedocs.io/en/latest/cheat_sheet_py3.html#miscellaneous
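For readers unfamiliar with the pattern the reply points to: quoting a class's own name in an annotation is a standard forward reference, used when the class is not fully defined at the point the annotation is written. A stand-alone illustration (the Node class is invented for this example, not from ml-agents):

```python
from typing import List, Optional


class Node:
    # "Node" is quoted because the class object does not exist yet while its own
    # body is being evaluated; mypy resolves the string to this class.
    def __init__(self, children: Optional[List["Node"]] = None) -> None:
        self.children: List["Node"] = children or []

    def add_child(self, child: "Node") -> None:
        self.children.append(child)
```

So target_buffer: "AgentBuffer" only constrains the argument's type for the type checker; it is unrelated to brain names at runtime.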
        key_list: List[str] = None,
        batch_size: int = None,
        training_length: int = None,
    ) -> None:
        """
        Takes in a batch size and training length (sequence length), and appends this AgentBuffer to target_buffer
        properly padded for LSTM use. Optionally, use key_list to restrict which fields are inserted into the new
        buffer.
        :param target_buffer: The buffer which to append the samples to.
        :param key_list: The fields that must be added. If None: all fields will be appended.
        :param batch_size: The number of elements that must be appended. If None: All of them will be.
        :param training_length: The length of the samples that must be appended. If None: only takes one element.
        """
        if key_list is None:
            key_list = list(self.keys())
        if not self.check_length(key_list):
            raise BufferException(
                "The length of the fields {0} were not of same length".format(key_list)
            )
        for field_key in key_list:
            target_buffer[field_key].extend(
                self[field_key].get_batch(
                    batch_size=batch_size, training_length=training_length
                )
            )

    @property
    def num_experiences(self) -> int:
        """
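The diff shows resequence_and_append itself but no call site; a hedged usage sketch follows. The field names ("actions", "rewards") and the step and sequence counts are invented for illustration, and the expected count assumes get_batch keeps all six elements when the buffer length is a multiple of training_length:

```python
from mlagents.trainers.buffer import AgentBuffer

# Build a small per-agent buffer with two illustrative fields of equal length.
local_buffer = AgentBuffer()
for step in range(6):
    local_buffer["actions"].append([float(step)])
    local_buffer["rewards"].append(1.0)

# Append every field into the shared update buffer, resequenced into
# LSTM-friendly sequences of length 3.
update_buffer = AgentBuffer()
local_buffer.resequence_and_append(target_buffer=update_buffer, training_length=3)

print(update_buffer.num_experiences)  # expected: 6 (two sequences of length 3)
```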