setuptools-based plugin for StatsWriters #4788
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
# Customizing Training via Plugins | ||
|
||
ML-Agents can run your own Python implementations of specific interfaces during the training | ||
process. These interfaces are currently fairly limited, but will be expanded in the future. | ||
|
||
**Note:** Plugin interfaces should currently be considered "in beta" and may change in future releases. | ||
|
||
## How to Write Your Own Plugin | ||
[This video](https://www.youtube.com/watch?v=fY3Y_xPKWNA) explains the basics of building a plugin system with | ||
setuptools; ML-Agents' plugin system is based on the same approach. | ||
|
||
The `ml-agents-plugin-examples` directory contains a reference implementation of each plugin interface, so it's a good | ||
starting point. | ||
|
||
### setup.py | ||
If you don't already have a `setup.py` file for your Python code, you'll need to add one. `ml-agents-plugin-examples` | ||
has a [minimal example](../ml-agents-plugin-examples/setup.py) of this. | ||
|
||
In the call to `setup()`, add an entry to the `entry_points` dictionary for each plugin interface that you | ||
implement. Each entry has the form `{entry point name}={plugin module}:{plugin function}`. For example, in | ||
`ml-agents-plugin-examples`: | ||
```python | ||
entry_points={ | ||
    ML_AGENTS_STATS_WRITER: [ | ||
        "example=mlagents_plugin_examples.example_stats_writer:get_example_stats_writer" | ||
    ] | ||
} | ||
``` | ||
* `ML_AGENTS_STATS_WRITER` (which is a string constant, `mlagents.stats_writer`) is the name of the plugin interface. | ||
This must be one of the provided interfaces ([see below](#plugin-interfaces)). | ||
* `example` is the plugin implementation name. This can be anything. | ||
* `mlagents_plugin_examples.example_stats_writer` is the plugin module. This points to the module where the | ||
plugin registration function is defined. | ||
* `get_example_stats_writer` is the plugin registration function. This is called when running `mlagents-learn`. The | ||
arguments and expected return type for this are different for each plugin interface. | ||
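To make the entry point format concrete, here is a minimal sketch that splits a string of the form `{entry point name}={plugin module}:{plugin function}` into its three parts. The helper `split_entry_point` is hypothetical and exists only for illustration; setuptools does its own parsing, and ML-Agents does not provide this function.

```python
from typing import Tuple


def split_entry_point(spec: str) -> Tuple[str, str, str]:
    """Split 'name=module:function' into (name, module, function).

    Hypothetical helper for illustration only; not part of ML-Agents.
    """
    name, sep, target = spec.partition("=")
    if not sep:
        raise ValueError(f"missing '=' in entry point: {spec!r}")
    module, sep, function = target.partition(":")
    if not sep:
        raise ValueError(f"missing ':' in entry point: {spec!r}")
    return name.strip(), module.strip(), function.strip()


spec = "example=mlagents_plugin_examples.example_stats_writer:get_example_stats_writer"
name, module, function = split_entry_point(spec)
print(name)      # example
print(module)    # mlagents_plugin_examples.example_stats_writer
print(function)  # get_example_stats_writer
```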
|
||
### Local Installation | ||
Once you've defined `entry_points` in your `setup.py`, you will need to run | ||
``` | ||
pip install -e [path to your plugin code] | ||
``` | ||
in the same Python virtual environment in which `mlagents` is installed. | ||
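One way to check that the installation worked is to list the entry points registered under the `mlagents.stats_writer` group. The sketch below assumes Python 3.8+ and uses the standard library's `importlib.metadata`; the helper name `installed_plugin_names` is hypothetical and not part of ML-Agents. It handles both the older dictionary-style return value of `entry_points()` and the newer `select()` API.

```python
from importlib import metadata
from typing import List


def installed_plugin_names(group: str) -> List[str]:
    """Return the sorted names of all installed entry points in `group`.

    Hypothetical helper for illustration only; not part of ML-Agents.
    """
    entry_points = metadata.entry_points()
    if hasattr(entry_points, "select"):
        # Newer Python versions expose a select() method.
        selected = entry_points.select(group=group)
    else:
        # Older versions return a plain dict keyed by group name.
        selected = entry_points.get(group, [])
    return sorted({ep.name for ep in selected})


# "mlagents.stats_writer" is the value of ML_AGENTS_STATS_WRITER.
print(installed_plugin_names("mlagents.stats_writer"))
```

If the example plugin has been installed into the active environment, `example` should appear in the printed list.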
|
||
## Plugin Interfaces | ||
|
||
### StatsWriter | ||
The StatsWriter class receives various information from the training process, such as the average Agent reward in | ||
each summary period. By default, we log this information to the console and write it to | ||
[TensorBoard](Using-Tensorboard.md). | ||
|
||
#### Interface | ||
The `StatsWriter.write_stats()` method must be implemented in any derived class. It takes a `category` parameter, | ||
which is typically the behavior name of the Agents being trained, a dictionary of `StatsSummary` values with | ||
string keys, and an integer `step`. | ||
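As a rough illustration of the interface, the following self-contained sketch implements a writer that remembers the latest mean of each stat per category. `StatsWriter` and `StatsSummary` here are simplified stand-ins, not the real classes from `mlagents.trainers.stats`; the `mean` field, the `LatestMeanWriter` class, and the sample category and key names are all assumptions made for this example. Only the `write_stats(category, values, step)` signature is taken from the example plugin.

```python
import abc
from typing import Dict, NamedTuple


class StatsSummary(NamedTuple):
    # Stand-in for mlagents.trainers.stats.StatsSummary; the `mean`
    # field is an assumption made for this sketch.
    mean: float


class StatsWriter(abc.ABC):
    # Stand-in for mlagents.trainers.stats.StatsWriter.
    @abc.abstractmethod
    def write_stats(
        self, category: str, values: Dict[str, StatsSummary], step: int
    ) -> None:
        ...


class LatestMeanWriter(StatsWriter):
    """Hypothetical writer that remembers the latest mean of each stat."""

    def __init__(self) -> None:
        self.latest: Dict[str, Dict[str, float]] = {}

    def write_stats(
        self, category: str, values: Dict[str, StatsSummary], step: int
    ) -> None:
        # Overwrite any earlier value stored for the same category and key.
        per_category = self.latest.setdefault(category, {})
        for key, summary in values.items():
            per_category[key] = summary.mean


writer = LatestMeanWriter()
writer.write_stats("MyBehavior", {"reward": StatsSummary(mean=1.5)}, step=1000)
print(writer.latest)  # {'MyBehavior': {'reward': 1.5}}
```

In a real plugin you would import `StatsWriter` and `StatsSummary` from `mlagents.trainers.stats` instead of defining stand-ins.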
|
||
#### Registration | ||
The `StatsWriter` registration function takes a `RunOptions` argument and returns a list of `StatsWriter`s. An | ||
example implementation is provided in [`mlagents_plugin_examples`](../ml-agents-plugin-examples/mlagents_plugin_examples/example_stats_writer.py). |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
# ML-Agents Plugins | ||
|
||
See the [Plugins documentation](../docs/Training-Plugins.md) for more information. |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
from typing import Dict, List | ||
from mlagents.trainers.settings import RunOptions | ||
from mlagents.trainers.stats import StatsWriter, StatsSummary | ||
|
||
|
||
class ExampleStatsWriter(StatsWriter): | ||
    """ | ||
    Example implementation of the StatsWriter abstract class. | ||
    This doesn't do anything interesting, just prints the stats that it gets. | ||
    """ | ||
|
||
    def write_stats( | ||
        self, category: str, values: Dict[str, StatsSummary], step: int | ||
    ) -> None: | ||
        print(f"ExampleStatsWriter category: {category} values: {values}") | ||
|
||
|
||
def get_example_stats_writer(run_options: RunOptions) -> List[StatsWriter]: | ||
    """ | ||
    Registration function. This is referenced in setup.py and will | ||
    be called by mlagents-learn when it starts to determine the | ||
    list of StatsWriters to use. | ||
|
||
    It must return a list of StatsWriters. | ||
    """ | ||
    print("Creating a new stats writer! This is so exciting!") | ||
    return [ExampleStatsWriter()] |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
import pytest | ||
|
||
from mlagents.plugins.stats_writer import register_stats_writer_plugins | ||
from mlagents.trainers.settings import RunOptions | ||
|
||
from mlagents_plugin_examples.example_stats_writer import ExampleStatsWriter | ||
|
||
|
||
@pytest.mark.check_environment_trains | ||
def test_register_stats_writers(): | ||
    # Make sure that the ExampleStatsWriter gets returned from the list of all StatsWriters | ||
    stats_writers = register_stats_writer_plugins(RunOptions()) | ||
    assert any(isinstance(sw, ExampleStatsWriter) for sw in stats_writers) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
from setuptools import setup | ||
Review comment: No plans to publish this package, but we could use it to set up tests, e.g. for bad imports. |
||
from mlagents.plugins import ML_AGENTS_STATS_WRITER | ||
|
||
setup( | ||
    name="mlagents_plugin_examples", | ||
    version="0.0.1", | ||
    # Example of how to add your own registration functions that will be called | ||
    # by mlagents-learn. | ||
    # | ||
    # Here, the get_example_stats_writer() function in mlagents_plugin_examples/example_stats_writer.py | ||
    # will get registered with the ML_AGENTS_STATS_WRITER plugin interface. | ||
    entry_points={ | ||
        ML_AGENTS_STATS_WRITER: [ | ||
            "example=mlagents_plugin_examples.example_stats_writer:get_example_stats_writer" | ||
Review comment: The form of this is |
||
        ] | ||
    }, | ||
) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
ML_AGENTS_STATS_WRITER = "mlagents.stats_writer" |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
import sys | ||
from typing import List | ||
|
||
# importlib.metadata is new in python3.8 | ||
# We use the backport for older python versions. | ||
if sys.version_info < (3, 8): | ||
    import importlib_metadata | ||
else: | ||
    import importlib.metadata as importlib_metadata  # pylint: disable=E0611 | ||
|
||
from mlagents.trainers.stats import StatsWriter | ||
|
||
from mlagents_envs import logging_util | ||
from mlagents.plugins import ML_AGENTS_STATS_WRITER | ||
from mlagents.trainers.settings import RunOptions | ||
from mlagents.trainers.stats import TensorboardWriter, GaugeWriter, ConsoleWriter | ||
|
||
|
||
logger = logging_util.get_logger(__name__) | ||
|
||
|
||
def get_default_stats_writers(run_options: RunOptions) -> List[StatsWriter]: | ||
    """ | ||
    The StatsWriters that mlagents-learn always uses: | ||
    * A TensorboardWriter to write information to TensorBoard | ||
    * A GaugeWriter to record our internal stats | ||
    * A ConsoleWriter to output to stdout. | ||
    """ | ||
    checkpoint_settings = run_options.checkpoint_settings | ||
    return [ | ||
        TensorboardWriter( | ||
            checkpoint_settings.write_path, | ||
            clear_past_data=not checkpoint_settings.resume, | ||
        ), | ||
        GaugeWriter(), | ||
        ConsoleWriter(), | ||
    ] | ||
|
||
|
||
def register_stats_writer_plugins(run_options: RunOptions) -> List[StatsWriter]: | ||
    """ | ||
    Registers all StatsWriter plugins (including the default one), evaluates them, | ||
    and returns the list of all the resulting StatsWriter implementations. | ||
    """ | ||
    all_stats_writers: List[StatsWriter] = [] | ||
    entry_points = importlib_metadata.entry_points()[ML_AGENTS_STATS_WRITER] | ||
|
||
    for entry_point in entry_points: | ||
|
||
        try: | ||
            logger.debug(f"Initializing StatsWriter plugins: {entry_point.name}") | ||
            plugin_func = entry_point.load() | ||
            plugin_stats_writers = plugin_func(run_options) | ||
            logger.debug( | ||
                f"Found {len(plugin_stats_writers)} StatsWriters for plugin {entry_point.name}" | ||
            ) | ||
            all_stats_writers += plugin_stats_writers | ||
        except BaseException: | ||
            # Catch all exceptions from setting up the plugin, so that bad user code doesn't break things. | ||
            logger.exception( | ||
                f"Error initializing StatsWriter plugins for {entry_point.name}. This plugin will not be used." | ||
            ) | ||
    return all_stats_writers |
Original file line number | Diff line number | Diff line change | ||
---|---|---|---|---|
|
@@ -189,6 +189,9 @@ def _create_parser() -> argparse.ArgumentParser: | |||
        action=RaiseRemovedWarning, | ||||
        help="(Removed) Use the TensorFlow framework.", | ||||
    ) | ||||
    argparser.add_argument( | ||||
        "--results-dir", default="results", help="Results base directory" | ||||
    ) | ||||
Comment on lines +192 to +194
Review comment: Is this relevant to this PR? Why are we adding this flag?
Review comment: This was previously hardcoded in learn.py and is now used in CheckpointSettings.maybe_init_path() I can keep it hardcoded if you prefer, but I do think moving all the path logic to part of RunOptions is a win (and basically a pre-req to make TensorBoardWriter act like a plugin)
Review comment: If a user changes this from the default
Review comment: If a user can set it in the run_options then yes this could mess with cloud since we don't already take it into account. I don't think it is a reason not to make the change though. We can add logic to overwrite this field since we already modify run_options. Just a couple of questions to confirm my understanding:
Review comment: Thanks, we can check for it and only overwrite it if already set then. When you check this in can you file a ticket against me to make sure we get this added at some point?
Review comment: Will do |
||||
|
||||
    eng_conf = argparser.add_argument_group(title="Engine Configuration") | ||||
    eng_conf.add_argument( | ||||
|