Add stats reporter class and re-enable missing stats #3076
Conversation
# Note: this is needed until we switch to AgentExperiences as the data input type.
# We still need some info from the policy (memories, previous actions)
# that really should be gathered by the env-manager.
self.policy = policy
self.episode_steps: Dict[str, int] = {}
self.episode_steps: Counter = Counter()
self.episode_rewards: Dict[str, float] = defaultdict(lambda: 0.0)
nit: defaultdict(float) is more common I think.
Changed
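For reference, the two defaults behave identically; defaultdict(float) just avoids the throwaway lambda:

```python
from collections import defaultdict

# float() returns 0.0, so both dicts yield 0.0 for missing keys.
rewards_lambda = defaultdict(lambda: 0.0)
rewards_float = defaultdict(float)

assert rewards_lambda["agent_0"] == rewards_float["agent_0"] == 0.0
```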
ml-agents/mlagents/trainers/stats.py (outdated diff)
for writer in self.writers:
    writer.write_text(category, text, step)

def get_mean_stat(self, category: str, key: str) -> float:
What do you think about combining get_mean_stat, get_std_stat, and get_num_stats into something like get_summary_stats() that returns a NamedTuple with count, mean, and stddev? I think that would clean up the usage in write_summary() a bit.
Done
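For context, a minimal sketch of what the combined accessor could look like; the StatsSummary fields and the get_stats_summaries name are assumptions here, not necessarily what the PR landed on:

```python
from collections import defaultdict
from typing import Dict, List, NamedTuple

import numpy as np


class StatsSummary(NamedTuple):
    mean: float
    std: float
    num: int


class StatsReporter:
    def __init__(self) -> None:
        # category -> key -> list of raw values
        self.stats_dict: Dict[str, Dict[str, List[float]]] = defaultdict(
            lambda: defaultdict(list)
        )

    def add_stat(self, category: str, key: str, value: float) -> None:
        self.stats_dict[category][key].append(value)

    def get_stats_summaries(self, category: str, key: str) -> StatsSummary:
        # Replaces get_mean_stat / get_std_stat / get_num_stats with one call,
        # which simplifies write_summary().
        values = self.stats_dict[category][key]
        if not values:
            return StatsSummary(mean=0.0, std=0.0, num=0)
        return StatsSummary(
            mean=float(np.mean(values)),
            std=float(np.std(values)),
            num=len(values),
        )
```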
stats.stats_reporter.add_stat(
    self.summary_path,
    self.policy.reward_signals[name].value_name,
    np.mean(v),
Is v always non-empty? Do you need to guard against NaNs anywhere?
Guess it's same behavior as before...
v shouldn't be NaN unless there are NaNs in the network (dun dun dun)
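If a guard were ever wanted, it could be as small as the helper below; this is purely illustrative (a hypothetical safe_mean, not part of the PR), since the thread settles on keeping the previous behavior:

```python
import numpy as np


def safe_mean(values) -> float:
    # Hypothetical helper: mean of values, or 0.0 when empty or all-NaN.
    arr = np.asarray(values, dtype=np.float64)
    if arr.size == 0 or np.all(np.isnan(arr)):
        return 0.0
    return float(np.nanmean(arr))


assert safe_mean([]) == 0.0
assert safe_mean([1.0, 2.0, 3.0]) == 2.0
```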
self.stats[self.policy.reward_signals[name].stat_name].append(
    rewards.get(agent_id, 0)
stats.stats_reporter.add_stat(
    self.summary_path,
Feels a little weird that the "category" here is a filepath. The "category" seems like something that should just be the filename / behavior name, whereas the base path should be something used to configure the Tensorboard writer.
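One way to read the suggestion, as a sketch: fix the base directory when the Tensorboard writer is constructed, so the category passed around is just the behavior name. This assumes the TF1 tf.summary.FileWriter API and a hypothetical base_dir parameter, not what the PR actually does:

```python
import os

import tensorflow as tf


class TensorboardWriter:
    def __init__(self, base_dir: str):
        # The filesystem root (e.g. "./summaries") is configured once here...
        self.base_dir = base_dir
        self.summary_writers = {}

    def write_stat(self, category: str, key: str, value: float, step: int) -> None:
        # ...so callers can pass a plain behavior name (e.g. "3DBallLearning")
        # as the category instead of a full summary path.
        if category not in self.summary_writers:
            path = os.path.join(self.base_dir, category)
            self.summary_writers[category] = tf.summary.FileWriter(path)
        summary = tf.Summary()
        summary.value.add(tag=key, simple_value=value)
        self.summary_writers[category].add_summary(summary, step)
        self.summary_writers[category].flush()
```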
Couple more comments and a request for tests.
ml-agents/mlagents/trainers/stats.py (outdated diff)
""" | ||
for key in self.stats_dict[category]: | ||
if len(self.stats_dict[category][key]) > 0: | ||
stat_mean = float(np.mean(self.stats_dict[category][key])) |
I don't love that we won't have any method for logging min/max values via this interface. Not sure I have a great solution for this at the moment, though.
Hmm, that's a good idea, I think it might be worth adding.
I'm inclined to get the StatsReporter interface installed into the code and then work on what goes behind it in future PRs - currently it's blocking a larger trainer refactor that's in turn blocking another trainer refactor :P
I've restructured the StatsWriter to share static writers but have different instances per trainer, and added tests.
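Roughly, the resulting shape (a sketch of the pattern described above, not the exact code): the writer list and the backing store are class-level and shared, while each trainer holds its own reporter bound to one category:

```python
from collections import defaultdict
from typing import Dict, List


class StatsWriter:
    # Abstract sink; concrete subclasses write to Tensorboard, CSV, etc.
    def write_stats(self, category: str, key: str, value: float, step: int) -> None:
        raise NotImplementedError


class StatsReporter:
    # Static / class-level: shared by every StatsReporter instance.
    writers: List[StatsWriter] = []
    stats_dict: Dict[str, Dict[str, List[float]]] = defaultdict(
        lambda: defaultdict(list)
    )

    def __init__(self, category: str):
        # Per-instance: each trainer constructs a reporter for its own category.
        self.category = category

    @staticmethod
    def add_writer(writer: StatsWriter) -> None:
        StatsReporter.writers.append(writer)

    def add_stat(self, key: str, value: float) -> None:
        StatsReporter.stats_dict[self.category][key].append(value)
```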
trainer: Trainer,
policy: TFPolicy,
max_trajectory_length: int,
stats_category: str,
If this is coming from the trainer, why not pass a StatsReporter?
Done
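i.e. something along these lines (a sketch; the constructor is abbreviated, Trainer and TFPolicy come from the snippet above, and _record_episode is a hypothetical caller added just to show the usage):

```python
class AgentProcessor:
    def __init__(
        self,
        trainer: "Trainer",
        policy: "TFPolicy",
        max_trajectory_length: int,
        stats_reporter: "StatsReporter",  # instead of stats_category: str
    ):
        self.trainer = trainer
        self.policy = policy
        self.max_trajectory_length = max_trajectory_length
        self.stats_reporter = stats_reporter

    def _record_episode(self, episode_reward: float, episode_steps: int) -> None:
        # The processor no longer needs to know which category (behavior name)
        # it reports under; the trainer's reporter already carries that.
        self.stats_reporter.add_stat("Environment/Cumulative Reward", episode_reward)
        self.stats_reporter.add_stat("Environment/Episode Length", episode_steps)
```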
:param key: The name of the text.
:param input_dict: A dictionary that will be displayed in a table on Tensorboard.
"""
# try:
Remove commented-out code if it won't be used.
Good catch. This shouldn't be here at all.
Two minor comments, otherwise LGTM
This PR adds back certain statistics (entropy, learning rate) that went missing in the AgentProcessor PR because they were only known to the AgentProcessor and not the Trainer.
We do this by creating a global class, StatsReporter, that takes in a category, a key, and a float value. The StatsReporter can then write out the mean of these values on command; currently that write is still triggered by the Trainer.
The StatsReporter also keeps a list of Writer classes. Currently we only have a Tensorboard writer, but we can imagine adding more in the future (e.g. a REST API writer or a CSV writer).
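As a sketch of how an additional writer could plug in, building on the StatsWriter/StatsReporter shape sketched earlier in this thread (the CSVWriter here is hypothetical and not part of this PR):

```python
import csv
import os


class CSVWriter(StatsWriter):
    # Hypothetical extra sink: appends one row per recorded summary value.
    def __init__(self, base_dir: str):
        self.base_dir = base_dir

    def write_stats(self, category: str, key: str, value: float, step: int) -> None:
        os.makedirs(self.base_dir, exist_ok=True)
        path = os.path.join(self.base_dir, f"{category}.csv")
        with open(path, "a", newline="") as f:
            csv.writer(f).writerow([step, key, value])


# Registered once at startup; every StatsReporter instance then fans out to it
# alongside the Tensorboard writer.
StatsReporter.add_writer(CSVWriter(base_dir="./csv_stats"))
```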
Note: Why is this a PR to the AgentProcessor PR and not to Master? The AgentProcessor marks the first time we have multiple sources for stats, and thus requires that certain stats (related to Policy inference, e.g. reward, entropy, episode steps) come from a different place than others (related to Policy training, e.g. loss).