-from typing import List, Union
+import sys
+from typing import List, Dict
+from collections import defaultdict, Counter
 
-from mlagents.trainers.buffer import AgentBuffer, BufferException
+from mlagents.trainers.trainer import Trainer
+from mlagents.trainers.trajectory import Trajectory, AgentExperience
+from mlagents.trainers.brain import BrainInfo
+from mlagents.trainers.tf_policy import TFPolicy
+from mlagents.trainers.action_info import ActionInfoOutputs
+from mlagents.trainers.stats import StatsReporter
 
 
-class ProcessingBuffer(dict):
+class AgentProcessor:
     """
-    ProcessingBuffer contains a dictionary of AgentBuffer. The AgentBuffers are indexed by agent_id.
+    AgentProcessor contains a dictionary of per-agent trajectory buffers. The buffers are indexed by agent_id.
+    Buffered experiences are assembled into a Trajectory and handed to the Trainer when an episode ends
+    or a buffer reaches max_trajectory_length. One AgentProcessor should be created per agent group.
     """
 
-    def __str__(self):
-        return "local_buffers :\n{0}".format(
-            "\n".join(["\tagent {0} :{1}".format(k, str(self[k])) for k in self.keys()])
-        )
-
-    def __getitem__(self, key):
-        if key not in self.keys():
-            self[key] = AgentBuffer()
-        return super().__getitem__(key)
-
-    def reset_local_buffers(self) -> None:
+    def __init__(
+        self,
+        trainer: Trainer,
+        policy: TFPolicy,
+        stats_reporter: StatsReporter,
+        max_trajectory_length: int = sys.maxsize,
+    ):
         """
-        Resets all the local AgentBuffers.
+        Create an AgentProcessor.
+        :param trainer: Trainer instance connected to this AgentProcessor. The Trainer is given each trajectory
+        when it is finished.
+        :param policy: Policy instance associated with this AgentProcessor.
+        :param max_trajectory_length: Maximum length of a trajectory before it is added to the trainer.
+        :param stats_reporter: StatsReporter object used to write the stats. Usually, this comes from the Trainer.
         """
-        for buf in self.values():
-            buf.reset_agent()
+        self.experience_buffers: Dict[str, List[AgentExperience]] = defaultdict(list)
+        self.last_brain_info: Dict[str, BrainInfo] = {}
+        self.last_take_action_outputs: Dict[str, ActionInfoOutputs] = {}
+        # Note: this is needed until we switch to AgentExperiences as the data input type.
+        # We still need some info from the policy (memories, previous actions)
+        # that really should be gathered by the env-manager.
+        self.policy = policy
+        self.episode_steps: Counter = Counter()
+        self.episode_rewards: Dict[str, float] = defaultdict(float)
+        self.stats_reporter = stats_reporter
+        self.trainer = trainer
+        self.max_trajectory_length = max_trajectory_length
 
-    def append_to_update_buffer(
+    def add_experiences(
         self,
-        update_buffer: AgentBuffer,
-        agent_id: Union[int, str],
-        key_list: List[str] = None,
-        batch_size: int = None,
-        training_length: int = None,
+        curr_info: BrainInfo,
+        next_info: BrainInfo,
+        take_action_outputs: ActionInfoOutputs,
     ) -> None:
         """
-        Appends the buffer of an agent to the update buffer.
-        :param update_buffer: A reference to an AgentBuffer to append the agent's buffer to
-        :param agent_id: The id of the agent which data will be appended
-        :param key_list: The fields that must be added. If None: all fields will be appended.
-        :param batch_size: The number of elements that must be appended. If None: All of them will be.
-        :param training_length: The length of the samples that must be appended. If None: only takes one element.
+        Adds experiences to each agent's experience history.
+        :param curr_info: current BrainInfo.
+        :param next_info: next BrainInfo.
+        :param take_action_outputs: The outputs of the Policy's get_action method.
         """
-        if key_list is None:
-            key_list = self[agent_id].keys()
-        if not self[agent_id].check_length(key_list):
-            raise BufferException(
-                "The length of the fields {0} for agent {1} were not of same length".format(
-                    key_list, agent_id
-                )
+        if take_action_outputs:
+            self.stats_reporter.add_stat(
+                "Policy/Entropy", take_action_outputs["entropy"].mean()
             )
-        for field_key in key_list:
-            update_buffer[field_key].extend(
-                self[agent_id][field_key].get_batch(
-                    batch_size=batch_size, training_length=training_length
-                )
+            self.stats_reporter.add_stat(
+                "Policy/Learning Rate", take_action_outputs["learning_rate"]
            )
 
-    def append_all_agent_batch_to_update_buffer(
-        self,
-        update_buffer: AgentBuffer,
-        key_list: List[str] = None,
-        batch_size: int = None,
-        training_length: int = None,
-    ) -> None:
-        """
-        Appends the buffer of all agents to the update buffer.
-        :param key_list: The fields that must be added. If None: all fields will be appended.
-        :param batch_size: The number of elements that must be appended. If None: All of them will be.
-        :param training_length: The length of the samples that must be appended. If None: only takes one element.
-        """
-        for agent_id in self.keys():
-            self.append_to_update_buffer(
-                update_buffer, agent_id, key_list, batch_size, training_length
-            )
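+        # Cache the pre-step BrainInfo and the policy outputs for each agent so they
+        # can be paired with the post-step results in next_info below.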
+        for agent_id in curr_info.agents:
+            self.last_brain_info[agent_id] = curr_info
+            self.last_take_action_outputs[agent_id] = take_action_outputs
+
+        # Store the environment reward
+        tmp_environment_reward = next_info.rewards
+
+        for next_idx, agent_id in enumerate(next_info.agents):
+            stored_info = self.last_brain_info.get(agent_id, None)
+            if stored_info is not None:
+                stored_take_action_outputs = self.last_take_action_outputs[agent_id]
+                idx = stored_info.agents.index(agent_id)
+                obs = []
+                if not stored_info.local_done[idx]:
+                    for i, _ in enumerate(stored_info.visual_observations):
+                        obs.append(stored_info.visual_observations[i][idx])
+                    if self.policy.use_vec_obs:
+                        obs.append(stored_info.vector_observations[idx])
+                    if self.policy.use_recurrent:
+                        memory = self.policy.retrieve_memories([agent_id])[0, :]
+                    else:
+                        memory = None
+
+                    done = next_info.local_done[next_idx]
+                    max_step = next_info.max_reached[next_idx]
+
+                    # Add the outputs of the last eval
+                    action = stored_take_action_outputs["action"][idx]
+                    if self.policy.use_continuous_act:
+                        action_pre = stored_take_action_outputs["pre_action"][idx]
+                    else:
+                        action_pre = None
+                    action_probs = stored_take_action_outputs["log_probs"][idx]
+                    action_masks = stored_info.action_masks[idx]
+                    prev_action = self.policy.retrieve_previous_action([agent_id])[0, :]
+
+                    experience = AgentExperience(
+                        obs=obs,
+                        reward=tmp_environment_reward[next_idx],
+                        done=done,
+                        action=action,
+                        action_probs=action_probs,
+                        action_pre=action_pre,
+                        action_mask=action_masks,
+                        prev_action=prev_action,
+                        max_step=max_step,
+                        memory=memory,
+                    )
+                    # Add the value outputs if needed
+                    self.experience_buffers[agent_id].append(experience)
+                    self.episode_rewards[agent_id] += tmp_environment_reward[next_idx]
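+                # Cut the trajectory if the agent is done or its buffer has reached
+                # max_trajectory_length, and hand the trajectory to the trainer.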
+                if (
+                    next_info.local_done[next_idx]
+                    or (
+                        len(self.experience_buffers[agent_id])
+                        >= self.max_trajectory_length
+                    )
+                ) and len(self.experience_buffers[agent_id]) > 0:
+                    # Make next AgentExperience
+                    next_obs = []
+                    for i, _ in enumerate(next_info.visual_observations):
+                        next_obs.append(next_info.visual_observations[i][next_idx])
+                    if self.policy.use_vec_obs:
+                        next_obs.append(next_info.vector_observations[next_idx])
+                    trajectory = Trajectory(
+                        steps=self.experience_buffers[agent_id],
+                        agent_id=agent_id,
+                        next_obs=next_obs,
+                    )
+                    # This will eventually be replaced with a queue
+                    self.trainer.process_trajectory(trajectory)
+                    self.experience_buffers[agent_id] = []
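+                # At the end of an episode, report cumulative reward and episode length,
+                # then reset the per-agent counters.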
+                if next_info.local_done[next_idx]:
+                    self.stats_reporter.add_stat(
+                        "Environment/Cumulative Reward",
+                        self.episode_rewards.get(agent_id, 0),
+                    )
+                    self.stats_reporter.add_stat(
+                        "Environment/Episode Length",
+                        self.episode_steps.get(agent_id, 0),
+                    )
+                    del self.episode_steps[agent_id]
+                    del self.episode_rewards[agent_id]
+                elif not next_info.local_done[next_idx]:
+                    self.episode_steps[agent_id] += 1
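+        # Remember the actions taken this step so they can be retrieved as prev_action
+        # when building the next AgentExperience.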
+        self.policy.save_previous_action(
+            curr_info.agents, take_action_outputs["action"]
+        )
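
For orientation, here is a minimal sketch of how the new AgentProcessor might be driven from a
training loop. It is an illustration only: trainer, policy, stats_reporter, num_steps, and the
environment helpers (reset_environment, step_environment) are hypothetical placeholders, and it
assumes policy.get_action returns an object whose action and outputs fields carry the
ActionInfoOutputs expected by add_experiences (as the docstring above describes).

    processor = AgentProcessor(
        trainer=trainer,
        policy=policy,
        stats_reporter=stats_reporter,
        max_trajectory_length=64,
    )

    curr_info = reset_environment()  # hypothetical helper returning an initial BrainInfo
    for _ in range(num_steps):
        action_info = policy.get_action(curr_info)        # policy outputs for this step
        next_info = step_environment(action_info.action)  # hypothetical environment step
        # Pair the pre-step BrainInfo with the post-step result; completed trajectories
        # are handed to trainer.process_trajectory() inside add_experiences.
        processor.add_experiences(curr_info, next_info, action_info.outputs)
        curr_info = next_info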