
Commit 0f08718

Authored by Ervin T
Add stats reporter class and re-enable missing stats (#3076)
1 parent 6a1f275 commit 0f08718

18 files changed: +306, -104 lines

ml-agents/mlagents/trainers/agent_processor.py

Lines changed: 28 additions & 12 deletions
@@ -7,6 +7,7 @@
 from mlagents.trainers.brain import BrainInfo
 from mlagents.trainers.tf_policy import TFPolicy
 from mlagents.trainers.action_info import ActionInfoOutputs
+from mlagents.trainers.stats import StatsReporter


 class AgentProcessor:
@@ -16,24 +17,31 @@ class AgentProcessor:
     One AgentProcessor should be created per agent group.
     """

-    def __init__(self, trainer: Trainer, policy: TFPolicy, max_trajectory_length: int):
+    def __init__(
+        self,
+        trainer: Trainer,
+        policy: TFPolicy,
+        max_trajectory_length: int,
+        stats_reporter: StatsReporter,
+    ):
         """
         Create an AgentProcessor.
         :param trainer: Trainer instance connected to this AgentProcessor. Trainer is given trajectory
         when it is finished.
         :param policy: Policy instance associated with this AgentProcessor.
         :param max_trajectory_length: Maximum length of a trajectory before it is added to the trainer.
+        :param stats_category: The category under which to write the stats. Usually, this comes from the Trainer.
         """
         self.experience_buffers: Dict[str, List[AgentExperience]] = defaultdict(list)
         self.last_brain_info: Dict[str, BrainInfo] = {}
         self.last_take_action_outputs: Dict[str, ActionInfoOutputs] = {}
-        self.stats: Dict[str, List[float]] = defaultdict(list)
         # Note: this is needed until we switch to AgentExperiences as the data input type.
         # We still need some info from the policy (memories, previous actions)
         # that really should be gathered by the env-manager.
         self.policy = policy
         self.episode_steps: Counter = Counter()
-        self.episode_rewards: Dict[str, float] = defaultdict(lambda: 0.0)
+        self.episode_rewards: Dict[str, float] = defaultdict(float)
+        self.stats_reporter = stats_reporter
         if max_trajectory_length:
             self.max_trajectory_length = max_trajectory_length
             self.ignore_max_length = False
@@ -55,12 +63,12 @@ def add_experiences(
         :param take_action_outputs: The outputs of the Policy's get_action method.
         """
         if take_action_outputs:
-            self.stats["Policy/Entropy"].append(take_action_outputs["entropy"].mean())
-            self.stats["Policy/Learning Rate"].append(
-                take_action_outputs["learning_rate"]
+            self.stats_reporter.add_stat(
+                "Policy/Entropy", take_action_outputs["entropy"].mean()
+            )
+            self.stats_reporter.add_stat(
+                "Policy/Learning Rate", take_action_outputs["learning_rate"]
             )
-            for name, values in take_action_outputs["value_heads"].items():
-                self.stats[name].append(np.mean(values))

         for agent_id in curr_info.agents:
             self.last_brain_info[agent_id] = curr_info
@@ -99,7 +107,6 @@ def add_experiences(
                 action_masks = stored_info.action_masks[idx]
                 prev_action = self.policy.retrieve_previous_action([agent_id])[0, :]

-                values = stored_take_action_outputs["value_heads"]
                 experience = AgentExperience(
                     obs=obs,
                     reward=tmp_environment_reward[next_idx],
@@ -114,7 +121,7 @@
                 )
                 # Add the value outputs if needed
                 self.experience_buffers[agent_id].append(experience)
-
+                self.episode_rewards[agent_id] += tmp_environment_reward[next_idx]
                 if (
                     next_info.local_done[next_idx]
                     or (
@@ -137,9 +144,18 @@ def add_experiences(
                     # This will eventually be replaced with a queue
                     self.trainer.process_trajectory(trajectory)
                     self.experience_buffers[agent_id] = []
+                    if next_info.local_done[next_idx]:
+                        self.stats_reporter.add_stat(
+                            "Environment/Cumulative Reward",
+                            self.episode_rewards.get(agent_id, 0),
+                        )
+                        self.stats_reporter.add_stat(
+                            "Environment/Episode Length",
+                            self.episode_steps.get(agent_id, 0),
+                        )
+                        del self.episode_steps[agent_id]
+                        del self.episode_rewards[agent_id]
                 elif not next_info.local_done[next_idx]:
-                    if agent_id not in self.episode_steps:
-                        self.episode_steps[agent_id] = 0
                     self.episode_steps[agent_id] += 1
                 self.policy.save_previous_action(
                     curr_info.agents, take_action_outputs["action"]
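
The processor now does its per-episode bookkeeping against the shared StatsReporter instead of a local stats dict: rewards and step counts accumulate per agent and are reported when that agent's episode ends. A stripped-down sketch of that pattern, assuming ml-agents at this commit is installed; the reporter category, the standalone dictionaries, and record_step are placeholders standing in for the processor's real state and add_experiences logic:

from collections import Counter, defaultdict
from typing import Dict

from mlagents.trainers.stats import StatsReporter

reporter = StatsReporter("SomeBehavior")  # placeholder category
episode_rewards: Dict[str, float] = defaultdict(float)
episode_steps: Counter = Counter()

def record_step(agent_id: str, reward: float, done: bool) -> None:
    # Accumulate per-agent reward, as add_experiences() now does on every step.
    episode_rewards[agent_id] += reward
    if not done:
        episode_steps[agent_id] += 1
        return
    # On episode end, push the totals to the reporter and reset per-agent state.
    reporter.add_stat("Environment/Cumulative Reward", episode_rewards.get(agent_id, 0))
    reporter.add_stat("Environment/Episode Length", episode_steps.get(agent_id, 0))
    episode_rewards.pop(agent_id, None)
    episode_steps.pop(agent_id, None)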

ml-agents/mlagents/trainers/curriculum.py

Lines changed: 7 additions & 6 deletions
@@ -1,6 +1,7 @@
 import os
 import json
 import math
+from typing import Dict, Any, TextIO

 from .exception import CurriculumConfigError, CurriculumLoadingError

@@ -51,14 +52,14 @@ def __init__(self, location):
         )

     @property
-    def lesson_num(self):
+    def lesson_num(self) -> int:
         return self._lesson_num

     @lesson_num.setter
-    def lesson_num(self, lesson_num):
+    def lesson_num(self, lesson_num: int) -> None:
         self._lesson_num = max(0, min(lesson_num, self.max_lesson_num))

-    def increment_lesson(self, measure_val):
+    def increment_lesson(self, measure_val: float) -> bool:
         """
         Increments the lesson number depending on the progress given.
         :param measure_val: Measure of progress (either reward or percentage
@@ -87,7 +88,7 @@ def increment_lesson(self, measure_val):
             return True
         return False

-    def get_config(self, lesson=None):
+    def get_config(self, lesson: int = None) -> Dict[str, Any]:
         """
         Returns reset parameters which correspond to the lesson.
         :param lesson: The lesson you want to get the config of. If None, the
@@ -106,7 +107,7 @@ def get_config(self, lesson=None):
         return config

     @staticmethod
-    def load_curriculum_file(location):
+    def load_curriculum_file(location: str) -> None:
         try:
             with open(location) as data_file:
                 return Curriculum._load_curriculum(data_file)
@@ -120,7 +121,7 @@ def load_curriculum_file(location):
             )

     @staticmethod
-    def _load_curriculum(fp):
+    def _load_curriculum(fp: TextIO) -> None:
         try:
             return json.load(fp)
         except json.decoder.JSONDecodeError as e:

ml-agents/mlagents/trainers/learn.py

Lines changed: 6 additions & 0 deletions
@@ -17,6 +17,7 @@
 from mlagents.trainers.exception import TrainerError
 from mlagents.trainers.meta_curriculum import MetaCurriculum
 from mlagents.trainers.trainer_util import load_config, TrainerFactory
+from mlagents.trainers.stats import TensorboardWriter, StatsReporter
 from mlagents.envs.environment import UnityEnvironment
 from mlagents.trainers.sampler_class import SamplerManager
 from mlagents.trainers.exception import SamplerException
@@ -248,6 +249,11 @@ def run_training(
     )
     trainer_config = load_config(trainer_config_path)
     port = options.base_port + (sub_id * options.num_envs)
+
+    # Configure Tensorboard Writers and StatsReporter
+    tb_writer = TensorboardWriter(summaries_dir)
+    StatsReporter.add_writer(tb_writer)
+
     if options.env_path is None:
         port = 5004  # This is the in Editor Training Port
     env_factory = create_environment_factory(
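
This is the only place writers get registered: learn.py adds one TensorboardWriter per run, and because StatsReporter keeps its writer list at class level, every reporter created later (by trainers and agent processors) fans out to it. A minimal sketch of that flow, assuming ml-agents at this commit is installed; the summaries directory, category name, stat values, and step are placeholders:

from mlagents.trainers.stats import StatsReporter, TensorboardWriter

# Register one writer for the whole run; the list is shared by all reporters.
StatsReporter.add_writer(TensorboardWriter("./summaries"))  # placeholder directory

# Each reporter is scoped to a category (run name + behavior name in practice).
reporter = StatsReporter("3DBall_run1")  # placeholder category
reporter.add_stat("Environment/Cumulative Reward", 1.5)
reporter.add_stat("Environment/Cumulative Reward", 2.5)

# write_stats averages the buffered values per key, hands the single scalar to
# every registered writer at the given step, and clears the buffer.
reporter.write_stats(step=1000)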

ml-agents/mlagents/trainers/ppo/policy.py

Lines changed: 0 additions & 2 deletions
@@ -104,8 +104,6 @@ def create_model(
             {
                 "action": self.model.output,
                 "log_probs": self.model.all_log_probs,
-                "value_heads": self.model.value_heads,
-                "value": self.model.value,
                 "entropy": self.model.entropy,
                 "learning_rate": self.model.learning_rate,
             }

ml-agents/mlagents/trainers/ppo/trainer.py

Lines changed: 5 additions & 3 deletions
@@ -99,7 +99,9 @@ def process_trajectory(self, trajectory: Trajectory) -> None:
         )
         for name, v in value_estimates.items():
             agent_buffer_trajectory["{}_value_estimates".format(name)].extend(v)
-            self.stats[self.policy.reward_signals[name].value_name].append(np.mean(v))
+            self.stats_reporter.add_stat(
+                self.policy.reward_signals[name].value_name, np.mean(v)
+            )

         value_next = self.policy.get_value_estimates(
             trajectory.next_obs,
@@ -212,12 +214,12 @@ def update_policy(self):
                 batch_update_stats[stat_name].append(value)

         for stat, stat_list in batch_update_stats.items():
-            self.stats[stat].append(np.mean(stat_list))
+            self.stats_reporter.add_stat(stat, np.mean(stat_list))

         if self.policy.bc_module:
             update_stats = self.policy.bc_module.update()
             for stat, val in update_stats.items():
-                self.stats[stat].append(val)
+                self.stats_reporter.add_stat(stat, val)
         self.clear_update_buffer()
         self.trainer_metrics.end_policy_update()

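
Note that two levels of averaging remain: update_policy still takes np.mean(stat_list) over the minibatches of one policy update, and StatsReporter averages those per-update values again when the summary is eventually written or queried. A small sketch of that second aggregation, assuming ml-agents at this commit is installed; the category, key, and loss values are placeholders:

import numpy as np

from mlagents.trainers.stats import StatsReporter

reporter = StatsReporter("PPO_demo")  # placeholder category

# One add_stat call per policy update, each already a mean over minibatches.
for minibatch_losses in ([0.9, 1.1], [0.7, 0.9], [0.5, 0.7]):
    reporter.add_stat("Losses/Value Loss", float(np.mean(minibatch_losses)))

# get_stats_summaries aggregates everything buffered since the last write:
# mean of the per-update means (0.8), their std (~0.16), and the count (3).
summary = reporter.get_stats_summaries("Losses/Value Loss")
print(summary.mean, summary.std, summary.num)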

ml-agents/mlagents/trainers/rl_trainer.py

Lines changed: 2 additions & 8 deletions
@@ -46,23 +46,17 @@ def end_episode(self) -> None:
                 rewards[agent_id] = 0

     def _update_end_episode_stats(self, agent_id: str) -> None:
-        self.stats["Environment/Episode Length"].append(
-            self.episode_steps.get(agent_id, 0)
-        )
         self.episode_steps[agent_id] = 0
         for name, rewards in self.collected_rewards.items():
             if name == "environment":
                 self.cumulative_returns_since_policy_update.append(
                     rewards.get(agent_id, 0)
                 )
-                self.stats["Environment/Cumulative Reward"].append(
-                    rewards.get(agent_id, 0)
-                )
                 self.reward_buffer.appendleft(rewards.get(agent_id, 0))
                 rewards[agent_id] = 0
             else:
-                self.stats[self.policy.reward_signals[name].stat_name].append(
-                    rewards.get(agent_id, 0)
+                self.stats_reporter.add_stat(
+                    self.policy.reward_signals[name].stat_name, rewards.get(agent_id, 0)
                 )
                 rewards[agent_id] = 0

ml-agents/mlagents/trainers/sac/policy.py

Lines changed: 0 additions & 2 deletions
@@ -124,8 +124,6 @@ def create_model(
             {
                 "action": self.model.output,
                 "log_probs": self.model.all_log_probs,
-                "value_heads": self.model.value_heads,
-                "value": self.model.value,
                 "entropy": self.model.entropy,
                 "learning_rate": self.model.learning_rate,
             }

ml-agents/mlagents/trainers/sac/trainer.py

Lines changed: 6 additions & 4 deletions
@@ -166,7 +166,9 @@ def process_trajectory(self, trajectory: Trajectory) -> None:
             agent_buffer_trajectory
         )
         for name, v in value_estimates.items():
-            self.stats[self.policy.reward_signals[name].value_name].append(np.mean(v))
+            self.stats_reporter.add_stat(
+                self.policy.reward_signals[name].value_name, np.mean(v)
+            )

         # Bootstrap using the last step rather than the bootstrap step if max step is reached.
         # Set last element to duplicate obs and remove dones.
@@ -258,13 +260,13 @@ def update_sac_policy(self) -> None:
         )

         for stat, stat_list in batch_update_stats.items():
-            self.stats[stat].append(np.mean(stat_list))
+            self.stats_reporter.add_stat(stat, np.mean(stat_list))

         bc_module = self.sac_policy.bc_module
         if bc_module:
             update_stats = bc_module.update()
             for stat, val in update_stats.items():
-                self.stats[stat].append(val)
+                self.stats_reporter.add_stat(stat, val)

     def update_reward_signals(self) -> None:
         """
@@ -299,4 +301,4 @@ def update_reward_signals(self) -> None:
             for stat_name, value in update_stats.items():
                 batch_update_stats[stat_name].append(value)
         for stat, stat_list in batch_update_stats.items():
-            self.stats[stat].append(np.mean(stat_list))
+            self.stats_reporter.add_stat(stat, np.mean(stat_list))

ml-agents/mlagents/trainers/stats.py

Lines changed: 119 additions & 0 deletions
@@ -0,0 +1,119 @@
+from collections import defaultdict
+from typing import List, Dict, NamedTuple
+import numpy as np
+import abc
+import os
+
+from mlagents.tf_utils import tf
+
+
+class StatsWriter(abc.ABC):
+    """
+    A StatsWriter abstract class. A StatsWriter takes in a category, key, scalar value, and step
+    and writes it out by some method.
+    """
+
+    @abc.abstractmethod
+    def write_stats(self, category: str, key: str, value: float, step: int) -> None:
+        pass
+
+    @abc.abstractmethod
+    def write_text(self, category: str, text: str, step: int) -> None:
+        pass
+
+
+class TensorboardWriter(StatsWriter):
+    def __init__(self, base_dir: str):
+        self.summary_writers: Dict[str, tf.summary.FileWriter] = {}
+        self.base_dir: str = base_dir
+
+    def write_stats(self, category: str, key: str, value: float, step: int) -> None:
+        self._maybe_create_summary_writer(category)
+        summary = tf.Summary()
+        summary.value.add(tag="{}".format(key), simple_value=value)
+        self.summary_writers[category].add_summary(summary, step)
+        self.summary_writers[category].flush()
+
+    def _maybe_create_summary_writer(self, category: str) -> None:
+        if category not in self.summary_writers:
+            filewriter_dir = "{basedir}/{category}".format(
+                basedir=self.base_dir, category=category
+            )
+            if not os.path.exists(filewriter_dir):
+                os.makedirs(filewriter_dir)
+            self.summary_writers[category] = tf.summary.FileWriter(filewriter_dir)
+
+    def write_text(self, category: str, text: str, step: int) -> None:
+        self._maybe_create_summary_writer(category)
+        self.summary_writers[category].add_summary(text, step)
+
+
+class StatsSummary(NamedTuple):
+    mean: float
+    std: float
+    num: int
+
+
+class StatsReporter:
+    writers: List[StatsWriter] = []
+    stats_dict: Dict[str, Dict[str, List]] = defaultdict(lambda: defaultdict(list))
+
+    def __init__(self, category):
+        """
+        Generic StatsReporter. A category is the broadest type of storage (would
+        correspond the run name and trainer name, e.g. 3DBalltest_3DBall. A key is the
+        type of stat it is (e.g. Environment/Reward). Finally the Value is the float value
+        attached to this stat.
+        """
+        self.category: str = category
+
+    @staticmethod
+    def add_writer(writer: StatsWriter) -> None:
+        StatsReporter.writers.append(writer)
+
+    def add_stat(self, key: str, value: float) -> None:
+        """
+        Add a float value stat to the StatsReporter.
+        :param category: The highest categorization of the statistic, e.g. behavior name.
+        :param key: The type of statistic, e.g. Environment/Reward.
+        :param value: the value of the statistic.
+        """
+        StatsReporter.stats_dict[self.category][key].append(value)
+
+    def write_stats(self, step: int) -> None:
+        """
+        Write out all stored statistics that fall under the category specified.
+        The currently stored values will be averaged, written out as a single value,
+        and the buffer cleared.
+        :param category: The category which to write out the stats.
+        :param step: Training step which to write these stats as.
+        """
+        for key in StatsReporter.stats_dict[self.category]:
+            if len(StatsReporter.stats_dict[self.category][key]) > 0:
+                stat_mean = float(np.mean(StatsReporter.stats_dict[self.category][key]))
+                for writer in StatsReporter.writers:
+                    writer.write_stats(self.category, key, stat_mean, step)
+        del StatsReporter.stats_dict[self.category]
+
+    def write_text(self, text: str, step: int) -> None:
+        """
+        Write out some text.
+        :param category: The highest categorization of the statistic, e.g. behavior name.
+        :param text: The text to write out.
+        :param step: Training step which to write these stats as.
+        """
+        for writer in StatsReporter.writers:
+            writer.write_text(self.category, text, step)
+
+    def get_stats_summaries(self, key: str) -> StatsSummary:
+        """
+        Get the mean, std, and count of a particular statistic, since last write.
+        :param category: The highest categorization of the statistic, e.g. behavior name.
+        :param key: The type of statistic, e.g. Environment/Reward.
+        :returns: A StatsSummary NamedTuple containing (mean, std, count).
+        """
+        return StatsSummary(
+            mean=np.mean(StatsReporter.stats_dict[self.category][key]),
+            std=np.std(StatsReporter.stats_dict[self.category][key]),
+            num=len(StatsReporter.stats_dict[self.category][key]),
+        )
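
Since StatsWriter is an abstract base class and the reporter fans out to every registered writer, backends other than Tensorboard can be plugged in without touching the trainers. A minimal sketch of a hypothetical console writer (not part of this commit) that satisfies the interface, assuming ml-agents at this commit is installed; the category, stat values, and step are placeholders:

from mlagents.trainers.stats import StatsReporter, StatsWriter


class ConsoleStatsWriter(StatsWriter):
    """Hypothetical writer that prints summarized stats instead of logging to Tensorboard."""

    def write_stats(self, category: str, key: str, value: float, step: int) -> None:
        print("[{}] step {}: {} = {:.3f}".format(category, step, key, value))

    def write_text(self, category: str, text: str, step: int) -> None:
        print("[{}] step {}: {}".format(category, step, text))


StatsReporter.add_writer(ConsoleStatsWriter())

reporter = StatsReporter("DemoBehavior")  # placeholder category
reporter.add_stat("Policy/Entropy", 1.2)
reporter.add_stat("Policy/Entropy", 1.0)
reporter.write_stats(step=500)  # prints the mean (1.1) and clears the category buffer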
