Commit ef03bf6

Chris Elion authored
setuptools-based plugin for StatsWriters (#4788)
1 parent 2352955 commit ef03bf6

File tree

16 files changed

+234
-29
lines changed

.github/workflows/pytest.yml

Lines changed: 1 addition & 0 deletions
@@ -45,6 +45,7 @@ jobs:
         python -m pip install --progress-bar=off -e ./ml-agents
         python -m pip install --progress-bar=off -r test_requirements.txt
         python -m pip install --progress-bar=off -e ./gym-unity
+        python -m pip install --progress-bar=off -e ./ml-agents-plugin-examples
     - name: Save python dependencies
       run: |
         pip freeze > pip_versions-${{ matrix.python-version }}.txt

com.unity.ml-agents/CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -12,6 +12,9 @@ and this project adheres to
 #### com.unity.ml-agents (C#)
 #### ml-agents / ml-agents-envs / gym-unity (Python)
 - TensorFlow trainers have been removed, please use the Torch trainers instead. (#4707)
+- A plugin system for `mlagents-learn` has been added. You can now define custom
+  `StatsWriter` implementations and register them to be called during training.
+  More types of plugins will be added in the future. (#4788)
 
 ### Minor Changes
 #### com.unity.ml-agents / com.unity.ml-agents.extensions (C#)

docs/Training-Plugins.md

Lines changed: 58 additions & 0 deletions
@@ -0,0 +1,58 @@
+# Customizing Training via Plugins
+
+ML-Agents provides support for running your own Python implementations of specific interfaces during the training
+process. These interfaces are currently fairly limited, but will be expanded in the future.
+
+**Note:** Plugin interfaces should currently be considered "in beta", and they may change in future releases.
+
+## How to Write Your Own Plugin
+[This video](https://www.youtube.com/watch?v=fY3Y_xPKWNA) explains the basics of how to create a plugin system using
+setuptools, and is the same approach that ML-Agents' plugin system is based on.
+
+The `ml-agents-plugin-examples` directory contains a reference implementation of each plugin interface, so it's a good
+starting point.
+
+### setup.py
+If you don't already have a `setup.py` file for your Python code, you'll need to add one. `ml-agents-plugin-examples`
+has a [minimal example](../ml-agents-plugin-examples/setup.py) of this.
+
+In the call to `setup()`, you'll need to add to the `entry_points` dictionary for each plugin interface that you
+implement. The form of this is `{entry point name}={plugin module}:{plugin function}`. For example, in
+`ml-agents-plugin-examples`:
+```python
+entry_points={
+    ML_AGENTS_STATS_WRITER: [
+        "example=mlagents_plugin_examples.example_stats_writer:get_example_stats_writer"
+    ]
+}
+```
+* `ML_AGENTS_STATS_WRITER` (which is a string constant, `mlagents.stats_writer`) is the name of the plugin interface.
+  This must be one of the provided interfaces ([see below](#plugin-interfaces)).
+* `example` is the plugin implementation name. This can be anything.
+* `mlagents_plugin_examples.example_stats_writer` is the plugin module. This points to the module where the
+  plugin registration function is defined.
+* `get_example_stats_writer` is the plugin registration function. This is called when running `mlagents-learn`. The
+  arguments and expected return type for this are different for each plugin interface.
+
+### Local Installation
+Once you've defined `entry_points` in your `setup.py`, you will need to run
+```
+pip install -e [path to your plugin code]
+```
+in the same Python virtual environment in which `mlagents` is installed.
+
+## Plugin Interfaces
+
+### StatsWriter
+The StatsWriter class receives various information from the training process, such as the average Agent reward in
+each summary period. By default, we log this information to the console and write it to
+[TensorBoard](Using-Tensorboard.md).
+
+#### Interface
+The `StatsWriter.write_stats()` method must be implemented in any derived classes. It takes a `category` parameter,
+which is typically the behavior name of the Agents being trained, a dictionary of `StatsSummary` values with
+string keys, and the current training `step`.
+
+#### Registration
+The `StatsWriter` registration function takes a `RunOptions` argument and returns a list of `StatsWriter`s. An
+example implementation is provided in [`mlagents_plugin_examples`](../ml-agents-plugin-examples/mlagents_plugin_examples/example_stats_writer.py).
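
As an illustration of the `StatsWriter` interface and registration function documented above, here is a minimal sketch of a plugin that emits stats as CSV rows. To keep it runnable without ML-Agents installed, `StatsSummary` and `StatsWriter` are local stand-ins whose fields and signature are assumptions modeled on this commit; `CsvStatsWriter` and `get_csv_stats_writer` are hypothetical names. In a real plugin you would import both base types from `mlagents.trainers.stats` instead.

```python
import abc
import csv
import io
import sys
from typing import Dict, List, NamedTuple, TextIO


class StatsSummary(NamedTuple):
    # Stand-in for mlagents.trainers.stats.StatsSummary.
    # The field names here are an assumption, not taken from this commit.
    mean: float
    std: float
    num: int


class StatsWriter(abc.ABC):
    # Stand-in for mlagents.trainers.stats.StatsWriter, using the
    # write_stats() signature shown in this commit.
    @abc.abstractmethod
    def write_stats(
        self, category: str, values: Dict[str, StatsSummary], step: int
    ) -> None:
        ...


class CsvStatsWriter(StatsWriter):
    """Hypothetical writer that emits one CSV row per statistic."""

    def __init__(self, stream: TextIO) -> None:
        self._writer = csv.writer(stream, lineterminator="\n")
        self._writer.writerow(["step", "category", "stat", "mean", "std", "num"])

    def write_stats(
        self, category: str, values: Dict[str, StatsSummary], step: int
    ) -> None:
        # Sort by stat name so the row order is deterministic.
        for name, summary in sorted(values.items()):
            self._writer.writerow(
                [step, category, name, summary.mean, summary.std, summary.num]
            )


def get_csv_stats_writer(run_options: object) -> List[StatsWriter]:
    # Registration function: receives the run options and returns a list
    # of writers. This sketch ignores the options and writes to stdout.
    return [CsvStatsWriter(sys.stdout)]
```

Such a function would be referenced from `setup.py` in the same `{entry point name}={plugin module}:{plugin function}` form described above.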

ml-agents-plugin-examples/README.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# ML-Agents Plugins
2+
3+
See the [Plugins documentation](../docs/Training-Plugins.md) for more information.

ml-agents-plugin-examples/mlagents_plugin_examples/__init__.py

Whitespace-only changes.
Lines changed: 27 additions & 0 deletions
@@ -0,0 +1,27 @@
+from typing import Dict, List
+from mlagents.trainers.settings import RunOptions
+from mlagents.trainers.stats import StatsWriter, StatsSummary
+
+
+class ExampleStatsWriter(StatsWriter):
+    """
+    Example implementation of the StatsWriter abstract class.
+    This doesn't do anything interesting, just prints the stats that it gets.
+    """
+
+    def write_stats(
+        self, category: str, values: Dict[str, StatsSummary], step: int
+    ) -> None:
+        print(f"ExampleStatsWriter category: {category} values: {values}")
+
+
+def get_example_stats_writer(run_options: RunOptions) -> List[StatsWriter]:
+    """
+    Registration function. This is referenced in setup.py and will
+    be called by mlagents-learn when it starts to determine the
+    list of StatsWriters to use.
+
+    It must return a list of StatsWriters.
+    """
+    print("Creating a new stats writer! This is so exciting!")
+    return [ExampleStatsWriter()]

ml-agents-plugin-examples/mlagents_plugin_examples/tests/__init__.py

Whitespace-only changes.
Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
+import pytest
+
+from mlagents.plugins.stats_writer import register_stats_writer_plugins
+from mlagents.trainers.settings import RunOptions
+
+from mlagents_plugin_examples.example_stats_writer import ExampleStatsWriter
+
+
+@pytest.mark.check_environment_trains
+def test_register_stats_writers():
+    # Make sure that the ExampleStatsWriter gets returned from the list of all StatsWriters
+    stats_writers = register_stats_writer_plugins(RunOptions())
+    assert any(isinstance(sw, ExampleStatsWriter) for sw in stats_writers)

ml-agents-plugin-examples/setup.py

Lines changed: 17 additions & 0 deletions
@@ -0,0 +1,17 @@
+from setuptools import setup
+from mlagents.plugins import ML_AGENTS_STATS_WRITER
+
+setup(
+    name="mlagents_plugin_examples",
+    version="0.0.1",
+    # Example of how to add your own registration functions that will be called
+    # by mlagents-learn.
+    #
+    # Here, the get_example_stats_writer() function in mlagents_plugin_examples/example_stats_writer.py
+    # will get registered with the ML_AGENTS_STATS_WRITER plugin interface.
+    entry_points={
+        ML_AGENTS_STATS_WRITER: [
+            "example=mlagents_plugin_examples.example_stats_writer:get_example_stats_writer"
+        ]
+    },
+)
Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+ML_AGENTS_STATS_WRITER = "mlagents.stats_writer"
Lines changed: 63 additions & 0 deletions
@@ -0,0 +1,63 @@
+import sys
+from typing import List
+
+# importlib.metadata is new in python3.8
+# We use the backport for older python versions.
+if sys.version_info < (3, 8):
+    import importlib_metadata
+else:
+    import importlib.metadata as importlib_metadata  # pylint: disable=E0611
+
+from mlagents.trainers.stats import StatsWriter
+
+from mlagents_envs import logging_util
+from mlagents.plugins import ML_AGENTS_STATS_WRITER
+from mlagents.trainers.settings import RunOptions
+from mlagents.trainers.stats import TensorboardWriter, GaugeWriter, ConsoleWriter
+
+
+logger = logging_util.get_logger(__name__)
+
+
+def get_default_stats_writers(run_options: RunOptions) -> List[StatsWriter]:
+    """
+    The StatsWriters that mlagents-learn always uses:
+    * A TensorboardWriter to write information to TensorBoard
+    * A GaugeWriter to record our internal stats
+    * A ConsoleWriter to output to stdout.
+    """
+    checkpoint_settings = run_options.checkpoint_settings
+    return [
+        TensorboardWriter(
+            checkpoint_settings.write_path,
+            clear_past_data=not checkpoint_settings.resume,
+        ),
+        GaugeWriter(),
+        ConsoleWriter(),
+    ]
+
+
+def register_stats_writer_plugins(run_options: RunOptions) -> List[StatsWriter]:
+    """
+    Registers all StatsWriter plugins (including the default one),
+    and evaluates them, and returns the list of all the StatsWriter implementations.
+    """
+    all_stats_writers: List[StatsWriter] = []
+    entry_points = importlib_metadata.entry_points()[ML_AGENTS_STATS_WRITER]
+
+    for entry_point in entry_points:
+
+        try:
+            logger.debug(f"Initializing StatsWriter plugins: {entry_point.name}")
+            plugin_func = entry_point.load()
+            plugin_stats_writers = plugin_func(run_options)
+            logger.debug(
+                f"Found {len(plugin_stats_writers)} StatsWriters for plugin {entry_point.name}"
+            )
+            all_stats_writers += plugin_stats_writers
+        except BaseException:
+            # Catch all exceptions from setting up the plugin, so that bad user code doesn't break things.
+            logger.exception(
+                f"Error initializing StatsWriter plugins for {entry_point.name}. This plugin will not be used."
+            )
+    return all_stats_writers
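
The discovery loop in this file looks up a group by indexing the result of `entry_points()`. As a hedged sketch of the same technique using only the standard library (assuming Python 3.8 or later), the version below also tolerates a group with no registered plugins and works with the selection interface that newer Python versions provide, where dict-style lookup is no longer available. `iter_entry_points` and `load_plugins` are hypothetical helper names, not part of ML-Agents.

```python
import importlib.metadata
import logging
from typing import Any, List

logger = logging.getLogger(__name__)


def iter_entry_points(group: str) -> List[Any]:
    """Return the entry points registered under `group`, or an empty list."""
    eps = importlib.metadata.entry_points()
    if hasattr(eps, "select"):
        # Newer Pythons expose a selection interface.
        return list(eps.select(group=group))
    # Older Pythons return a plain dict keyed by group name.
    return list(eps.get(group, []))


def load_plugins(group: str, *args: Any) -> List[Any]:
    """
    Load every registration function in `group`, call it with `args`,
    and collect the objects each one returns.
    """
    loaded: List[Any] = []
    for entry_point in iter_entry_points(group):
        try:
            factory = entry_point.load()
            loaded.extend(factory(*args))
        except Exception:
            # A failing plugin is logged and skipped, not fatal.
            logger.exception("Plugin %s failed to load; skipping.", entry_point.name)
    return loaded
```

One design difference from the code in this commit: the sketch catches `Exception` rather than `BaseException`, so `KeyboardInterrupt` and `SystemExit` raised while a plugin is loading still propagate. Whether that is preferable here is a judgment call.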

ml-agents/mlagents/trainers/cli_utils.py

Lines changed: 3 additions & 0 deletions
@@ -189,6 +189,9 @@ def _create_parser() -> argparse.ArgumentParser:
         action=RaiseRemovedWarning,
         help="(Removed) Use the TensorFlow framework.",
     )
+    argparser.add_argument(
+        "--results-dir", default="results", help="Results base directory"
+    )
 
     eng_conf = argparser.add_argument_group(title="Engine Configuration")
     eng_conf.add_argument(

ml-agents/mlagents/trainers/learn.py

Lines changed: 13 additions & 28 deletions
@@ -14,12 +14,7 @@
 from mlagents.trainers.environment_parameter_manager import EnvironmentParameterManager
 from mlagents.trainers.trainer import TrainerFactory
 from mlagents.trainers.directory_utils import validate_existing_directories
-from mlagents.trainers.stats import (
-    TensorboardWriter,
-    StatsReporter,
-    GaugeWriter,
-    ConsoleWriter,
-)
+from mlagents.trainers.stats import StatsReporter
 from mlagents.trainers.cli_utils import parser
 from mlagents_envs.environment import UnityEnvironment
 from mlagents.trainers.settings import RunOptions
@@ -34,6 +29,7 @@
     add_metadata as add_timer_metadata,
 )
 from mlagents_envs import logging_util
+from mlagents.plugins.stats_writer import register_stats_writer_plugins
 
 logger = logging_util.get_logger(__name__)
 
@@ -65,21 +61,15 @@ def run_training(run_seed: int, options: RunOptions) -> None:
         checkpoint_settings = options.checkpoint_settings
         env_settings = options.env_settings
         engine_settings = options.engine_settings
-        base_path = "results"
-        write_path = os.path.join(base_path, checkpoint_settings.run_id)
-        maybe_init_path = (
-            os.path.join(base_path, checkpoint_settings.initialize_from)
-            if checkpoint_settings.initialize_from is not None
-            else None
-        )
-        run_logs_dir = os.path.join(write_path, "run_logs")
+
+        run_logs_dir = checkpoint_settings.run_logs_dir
         port: Optional[int] = env_settings.base_port
         # Check if directory exists
         validate_existing_directories(
-            write_path,
+            checkpoint_settings.write_path,
             checkpoint_settings.resume,
             checkpoint_settings.force,
-            maybe_init_path,
+            checkpoint_settings.maybe_init_path,
         )
         # Make run logs directory
         os.makedirs(run_logs_dir, exist_ok=True)
@@ -90,14 +80,9 @@ def run_training(run_seed: int, options: RunOptions) -> None:
             )
 
         # Configure Tensorboard Writers and StatsReporter
-        tb_writer = TensorboardWriter(
-            write_path, clear_past_data=not checkpoint_settings.resume
-        )
-        gauge_write = GaugeWriter()
-        console_writer = ConsoleWriter()
-        StatsReporter.add_writer(tb_writer)
-        StatsReporter.add_writer(gauge_write)
-        StatsReporter.add_writer(console_writer)
+        stats_writers = register_stats_writer_plugins(options)
+        for sw in stats_writers:
+            StatsReporter.add_writer(sw)
 
         if env_settings.env_path is None:
             port = None
@@ -117,18 +102,18 @@ def run_training(run_seed: int, options: RunOptions) -> None:
 
         trainer_factory = TrainerFactory(
             trainer_config=options.behaviors,
-            output_path=write_path,
+            output_path=checkpoint_settings.write_path,
             train_model=not checkpoint_settings.inference,
             load_model=checkpoint_settings.resume,
             seed=run_seed,
             param_manager=env_parameter_manager,
-            init_path=maybe_init_path,
+            init_path=checkpoint_settings.maybe_init_path,
             multi_gpu=False,
         )
         # Create controller and begin training.
         tc = TrainerController(
             trainer_factory,
-            write_path,
+            checkpoint_settings.write_path,
             checkpoint_settings.run_id,
             env_parameter_manager,
             not checkpoint_settings.inference,
@@ -140,7 +125,7 @@ def run_training(run_seed: int, options: RunOptions) -> None:
         tc.start_learning(env_manager)
     finally:
         env_manager.close()
-        write_run_options(write_path, options)
+        write_run_options(checkpoint_settings.write_path, options)
         write_timing_tree(run_logs_dir)
         write_training_status(run_logs_dir)
 

ml-agents/mlagents/trainers/settings.py

Lines changed: 18 additions & 0 deletions
@@ -1,3 +1,4 @@
+import os.path
 import warnings
 
 import attr
@@ -706,6 +707,23 @@ class CheckpointSettings:
     force: bool = parser.get_default("force")
     train_model: bool = parser.get_default("train_model")
     inference: bool = parser.get_default("inference")
+    results_dir: str = parser.get_default("results_dir")
+
+    @property
+    def write_path(self) -> str:
+        return os.path.join(self.results_dir, self.run_id)
+
+    @property
+    def maybe_init_path(self) -> Optional[str]:
+        return (
+            os.path.join(self.results_dir, self.initialize_from)
+            if self.initialize_from is not None
+            else None
+        )
+
+    @property
+    def run_logs_dir(self) -> str:
+        return os.path.join(self.write_path, "run_logs")
 
 
 @attr.s(auto_attribs=True)
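
To make the derived paths in the `CheckpointSettings` hunk above concrete, here is a small self-contained sketch. `CheckpointPaths` is a hypothetical stand-in that uses a plain dataclass in place of the attrs-based class and hard-codes the `"results"` default from the new `--results-dir` option. It only illustrates how the three properties compose.

```python
import os.path
from dataclasses import dataclass
from typing import Optional


@dataclass
class CheckpointPaths:
    # Hypothetical stand-in for the path-related fields of CheckpointSettings.
    run_id: str
    results_dir: str = "results"
    initialize_from: Optional[str] = None

    @property
    def write_path(self) -> str:
        # Everything for a run is written under <results_dir>/<run_id>.
        return os.path.join(self.results_dir, self.run_id)

    @property
    def maybe_init_path(self) -> Optional[str]:
        # Only set when the run is initialized from a previous run.
        if self.initialize_from is None:
            return None
        return os.path.join(self.results_dir, self.initialize_from)

    @property
    def run_logs_dir(self) -> str:
        return os.path.join(self.write_path, "run_logs")


paths = CheckpointPaths(run_id="walker-01", initialize_from="walker-00")
print(paths.write_path)
print(paths.maybe_init_path)
print(paths.run_logs_dir)
```

Because the paths are properties and not stored fields, changing `results_dir` automatically changes every derived location, which is what lets plugins read `checkpoint_settings.write_path` from the `RunOptions` they receive.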

ml-agents/mlagents/trainers/stats.py

Lines changed: 7 additions & 0 deletions
@@ -87,6 +87,13 @@ class StatsWriter(abc.ABC):
     def write_stats(
         self, category: str, values: Dict[str, StatsSummary], step: int
     ) -> None:
+        """
+        Callback to record training information
+        :param category: Category of the statistics. Usually this is the behavior name.
+        :param values: Dictionary of statistics.
+        :param step: The current training step.
+        :return:
+        """
         pass
 
     def add_property(

ml-agents/setup.py

Lines changed: 7 additions & 1 deletion
@@ -3,6 +3,7 @@
 
 from setuptools import setup, find_packages
 from setuptools.command.install import install
+from mlagents.plugins import ML_AGENTS_STATS_WRITER
 import mlagents.trainers
 
 VERSION = mlagents.trainers.__version__
@@ -71,13 +72,18 @@ def run(self):
         "cattrs>=1.0.0,<1.1.0",
         "attrs>=19.3.0",
         'pypiwin32==223;platform_system=="Windows"',
+        "importlib_metadata; python_version<'3.8'",
     ],
     python_requires=">=3.6.1",
     entry_points={
         "console_scripts": [
             "mlagents-learn=mlagents.trainers.learn:main",
             "mlagents-run-experiment=mlagents.trainers.run_experiment:main",
-        ]
+        ],
+        # Plugins - each plugin type should have an entry here for the default behavior
+        ML_AGENTS_STATS_WRITER: [
+            "default=mlagents.plugins.stats_writer:get_default_stats_writers"
+        ],
     },
     cmdclass={"verify": VerifyVersionCommand},
 )
