Skip to content

setuptools-based plugin for StatsWriters #4788

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
Feb 5, 2021
1 change: 1 addition & 0 deletions .github/workflows/pytest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ jobs:
python -m pip install --progress-bar=off -e ./ml-agents
python -m pip install --progress-bar=off -r test_requirements.txt
python -m pip install --progress-bar=off -e ./gym-unity
python -m pip install --progress-bar=off -e ./ml-agents-plugin-examples
- name: Save python dependencies
run: |
pip freeze > pip_versions-${{ matrix.python-version }}.txt
Expand Down
3 changes: 3 additions & 0 deletions com.unity.ml-agents/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ and this project adheres to
#### com.unity.ml-agents (C#)
#### ml-agents / ml-agents-envs / gym-unity (Python)
- TensorFlow trainers have been removed, please use the Torch trainers instead. (#4707)
- A plugin system for `mlagents-learn` has been added. You can now define custom
`StatsWriter` implementations and register them to be called during training.
More types of plugins will be added in the future. (#4788)

### Minor Changes
#### com.unity.ml-agents / com.unity.ml-agents.extensions (C#)
Expand Down
58 changes: 58 additions & 0 deletions docs/Training-Plugins.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# Customizing Training via Plugins

ML-Agents provides support for running your own python implementations of specific interfaces during the training
process. These interfaces are currently fairly limited, but will be expanded in the future.

** Note ** Plugin interfaces should currently be considered "in beta", and they may change in future releases.

## How to Write Your Own Plugin
[This video](https://www.youtube.com/watch?v=fY3Y_xPKWNA) explains the basics of how to create a plugin system using
setuptools, and is the same approach that ML-Agents' plugin system is based on.

The `ml-agents-plugin-examples` directory contains a reference implementation of each plugin interface, so it's a good
starting point.

### setup.py
If you don't already have a `setup.py` file for your python code, you'll need to add one. `ml-agents-plugin-examples`
has a [minimal example](../ml-agents-plugin-examples/setup.py) of this.

In the call to `setup()`, you'll need to add to the `entry_points` dictionary for each plugin interface that you
implement. The form of this is `{entry point name}={plugin module}:{plugin function}`. For example, in
`ml-agents-plugin-examples`:
```python
entry_points={
ML_AGENTS_STATS_WRITER: [
"example=mlagents_plugin_examples.example_stats_writer:get_example_stats_writer"
]
}
```
* `ML_AGENTS_STATS_WRITER` (which is a string constant, `mlagents.stats_writer`) is the name of the plugin interface.
This must be one of the provided interfaces ([see below](#plugin-interfaces)).
* `example` is the plugin implementation name. This can be anything.
* `mlagents_plugin_examples.example_stats_writer` is the plugin module. This points to the module where the
plugin registration function is defined.
* `get_example_stats_writer` is the plugin registration function. This is called when running `mlagents-learn`. The
arguments and expected return type for this are different for each plugin interface.

### Local Installation
Once you've defined `entry_points` in your `setup.py`, you will need to run
```
pip install -e [path to your plugin code]
```
in the same python virtual environment that you have `mlagents` installed.

## Plugin Interfaces

### StatsWriter
The StatsWriter class receives various information from the training process, such as the average Agent reward in
each summary period. By default, we log this information to the console and write it to
[TensorBoard](Using-Tensorboard.md).

#### Interface
The `StatsWriter.write_stats()` method must be implemented in any derived classes. It takes a "category" parameter,
which typically is the behavior name of the Agents being trained, and a dictionary of `StatSummary` values with
string keys.

#### Registration
The `StatsWriter` registration function takes a `RunOptions` argument and returns a list of `StatsWriter`s. An
example implementation is provided in [`mlagents_plugin_examples`](../ml-agents-plugin-examples/mlagents_plugin_examples/example_stats_writer.py)
3 changes: 3 additions & 0 deletions ml-agents-plugin-examples/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# ML-Agents Plugins

See the [Plugins documentation](../docs/Training-Plugins.md) for more information.
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from typing import Dict, List
from mlagents.trainers.settings import RunOptions
from mlagents.trainers.stats import StatsWriter, StatsSummary


class ExampleStatsWriter(StatsWriter):
"""
Example implementation of the StatsWriter abstract class.
This doesn't do anything interesting, just prints the stats that it gets.
"""

def write_stats(
self, category: str, values: Dict[str, StatsSummary], step: int
) -> None:
print(f"ExampleStatsWriter category: {category} values: {values}")


def get_example_stats_writer(run_options: RunOptions) -> List[StatsWriter]:
"""
Registration function. This is referenced in setup.py and will
be called by mlagents-learn when it starts to determine the
list of StatsWriters to use.

It must return a list of StatsWriters.
"""
print("Creating a new stats writer! This is so exciting!")
return [ExampleStatsWriter()]
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import pytest

from mlagents.plugins.stats_writer import register_stats_writer_plugins
from mlagents.trainers.settings import RunOptions

from mlagents_plugin_examples.example_stats_writer import ExampleStatsWriter


@pytest.mark.check_environment_trains
def test_register_stats_writers():
# Make sure that the ExampleStatsWriter gets returned from the list of all StatsWriters
stats_writers = register_stats_writer_plugins(RunOptions())
assert any(isinstance(sw, ExampleStatsWriter) for sw in stats_writers)
17 changes: 17 additions & 0 deletions ml-agents-plugin-examples/setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from setuptools import setup
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No plans to publish this package, but we could use it to set up tests, e.g. for bad imports.

from mlagents.plugins import ML_AGENTS_STATS_WRITER

setup(
name="mlagents_plugin_examples",
version="0.0.1",
# Example of how to add your own registration functions that will be called
# by mlagents-learn.
#
# Here, the get_example_stats_writer() function in mlagents_plugin_examples/example_stats_writer.py
# will get registered with the ML_AGENTS_STATS_WRITER plugin interface.
entry_points={
ML_AGENTS_STATS_WRITER: [
"example=mlagents_plugin_examples.example_stats_writer:get_example_stats_writer"
Copy link
Contributor Author

@chriselion chriselion Dec 21, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The form of this is {entry point name}={plugin module}:{plugin_function}

]
},
)
1 change: 1 addition & 0 deletions ml-agents/mlagents/plugins/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
ML_AGENTS_STATS_WRITER = "mlagents.stats_writer"
63 changes: 63 additions & 0 deletions ml-agents/mlagents/plugins/stats_writer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import sys
from typing import List

# importlib.metadata is new in python3.8
# We use the backport for older python versions.
if sys.version_info < (3, 8):
import importlib_metadata
else:
import importlib.metadata as importlib_metadata # pylint: disable=E0611

from mlagents.trainers.stats import StatsWriter

from mlagents_envs import logging_util
from mlagents.plugins import ML_AGENTS_STATS_WRITER
from mlagents.trainers.settings import RunOptions
from mlagents.trainers.stats import TensorboardWriter, GaugeWriter, ConsoleWriter


logger = logging_util.get_logger(__name__)


def get_default_stats_writers(run_options: RunOptions) -> List[StatsWriter]:
"""
The StatsWriters that mlagents-learn always uses:
* A TensorboardWriter to write information to TensorBoard
* A GaugeWriter to record our internal stats
* A ConsoleWriter to output to stdout.
"""
checkpoint_settings = run_options.checkpoint_settings
return [
TensorboardWriter(
checkpoint_settings.write_path,
clear_past_data=not checkpoint_settings.resume,
),
GaugeWriter(),
ConsoleWriter(),
]


def register_stats_writer_plugins(run_options: RunOptions) -> List[StatsWriter]:
"""
Registers all StatsWriter plugins (including the default one),
and evaluates them, and returns the list of all the StatsWriter implementations.
"""
all_stats_writers: List[StatsWriter] = []
entry_points = importlib_metadata.entry_points()[ML_AGENTS_STATS_WRITER]

for entry_point in entry_points:

try:
logger.debug(f"Initializing StatsWriter plugins: {entry_point.name}")
plugin_func = entry_point.load()
plugin_stats_writers = plugin_func(run_options)
logger.debug(
f"Found {len(plugin_stats_writers)} StatsWriters for plugin {entry_point.name}"
)
all_stats_writers += plugin_stats_writers
except BaseException:
# Catch all exceptions from setting up the plugin, so that bad user code doesn't break things.
logger.exception(
f"Error initializing StatsWriter plugins for {entry_point.name}. This plugin will not be used."
)
return all_stats_writers
3 changes: 3 additions & 0 deletions ml-agents/mlagents/trainers/cli_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,9 @@ def _create_parser() -> argparse.ArgumentParser:
action=RaiseRemovedWarning,
help="(Removed) Use the TensorFlow framework.",
)
argparser.add_argument(
"--results-dir", default="results", help="Results base directory"
)
Comment on lines +192 to +194
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this relevant to this PR? Why are we adding this flag ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was previously hardcoded in learn.py
https://github.com/Unity-Technologies/ml-agents/pull/4788/files#diff-4e0d7b8aef419254e6fc3c63f5baa53f222afda7c908705e857a22f7d4792c47L68

and is now used in CheckpointSettings.maybe_init_path()
https://github.com/Unity-Technologies/ml-agents/pull/4788/files#diff-546e90789e914f8707fd97391f78c4bba39ae69a965858bbb69c2d324db1ec51R719

I can keep it hardcoded if you prefer, but I do think moving all the path logic to part of RunOptions is a win (and basically a pre-req to make TensorBoardWriter act like a plugin)

Copy link
Contributor

@ervteng ervteng Feb 5, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If a user changes this from the default results will it mess with how they're saved/synced in cloud? cc: @hvpeteet

Copy link
Contributor

@hvpeteet hvpeteet Feb 5, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If a user can set it in the run_options then yes this could mess with cloud since we don't already take it into account. I don't think it is a reason not to make the change though. We can add logic to overwrite this field since we already modify run_options.

Just a couple of questions to confirm my understanding:

  1. Would this be set under checkpoint_settings --> write_path?
  2. Could we (cloud) blindly set this for every experiment and it would be ignored by older mlagents versions?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. It would be checkpoint_settings -> results_dir (at the top level)
  2. This would trigger an error here
    def check_and_structure(key: str, value: Any, class_type: type) -> Any:

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, we can check for it and only overwrite it if already set then. When you check this in can you file a ticket against me to make sure we get this added at some point?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will do


eng_conf = argparser.add_argument_group(title="Engine Configuration")
eng_conf.add_argument(
Expand Down
41 changes: 13 additions & 28 deletions ml-agents/mlagents/trainers/learn.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,7 @@
from mlagents.trainers.environment_parameter_manager import EnvironmentParameterManager
from mlagents.trainers.trainer import TrainerFactory
from mlagents.trainers.directory_utils import validate_existing_directories
from mlagents.trainers.stats import (
TensorboardWriter,
StatsReporter,
GaugeWriter,
ConsoleWriter,
)
from mlagents.trainers.stats import StatsReporter
from mlagents.trainers.cli_utils import parser
from mlagents_envs.environment import UnityEnvironment
from mlagents.trainers.settings import RunOptions
Expand All @@ -34,6 +29,7 @@
add_metadata as add_timer_metadata,
)
from mlagents_envs import logging_util
from mlagents.plugins.stats_writer import register_stats_writer_plugins

logger = logging_util.get_logger(__name__)

Expand Down Expand Up @@ -65,21 +61,15 @@ def run_training(run_seed: int, options: RunOptions) -> None:
checkpoint_settings = options.checkpoint_settings
env_settings = options.env_settings
engine_settings = options.engine_settings
base_path = "results"
write_path = os.path.join(base_path, checkpoint_settings.run_id)
maybe_init_path = (
os.path.join(base_path, checkpoint_settings.initialize_from)
if checkpoint_settings.initialize_from is not None
else None
)
run_logs_dir = os.path.join(write_path, "run_logs")

run_logs_dir = checkpoint_settings.run_logs_dir
port: Optional[int] = env_settings.base_port
# Check if directory exists
validate_existing_directories(
write_path,
checkpoint_settings.write_path,
checkpoint_settings.resume,
checkpoint_settings.force,
maybe_init_path,
checkpoint_settings.maybe_init_path,
)
# Make run logs directory
os.makedirs(run_logs_dir, exist_ok=True)
Expand All @@ -90,14 +80,9 @@ def run_training(run_seed: int, options: RunOptions) -> None:
)

# Configure Tensorboard Writers and StatsReporter
tb_writer = TensorboardWriter(
write_path, clear_past_data=not checkpoint_settings.resume
)
gauge_write = GaugeWriter()
console_writer = ConsoleWriter()
StatsReporter.add_writer(tb_writer)
StatsReporter.add_writer(gauge_write)
StatsReporter.add_writer(console_writer)
stats_writers = register_stats_writer_plugins(options)
for sw in stats_writers:
StatsReporter.add_writer(sw)

if env_settings.env_path is None:
port = None
Expand All @@ -117,18 +102,18 @@ def run_training(run_seed: int, options: RunOptions) -> None:

trainer_factory = TrainerFactory(
trainer_config=options.behaviors,
output_path=write_path,
output_path=checkpoint_settings.write_path,
train_model=not checkpoint_settings.inference,
load_model=checkpoint_settings.resume,
seed=run_seed,
param_manager=env_parameter_manager,
init_path=maybe_init_path,
init_path=checkpoint_settings.maybe_init_path,
multi_gpu=False,
)
# Create controller and begin training.
tc = TrainerController(
trainer_factory,
write_path,
checkpoint_settings.write_path,
checkpoint_settings.run_id,
env_parameter_manager,
not checkpoint_settings.inference,
Expand All @@ -140,7 +125,7 @@ def run_training(run_seed: int, options: RunOptions) -> None:
tc.start_learning(env_manager)
finally:
env_manager.close()
write_run_options(write_path, options)
write_run_options(checkpoint_settings.write_path, options)
write_timing_tree(run_logs_dir)
write_training_status(run_logs_dir)

Expand Down
18 changes: 18 additions & 0 deletions ml-agents/mlagents/trainers/settings.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os.path
import warnings

import attr
Expand Down Expand Up @@ -706,6 +707,23 @@ class CheckpointSettings:
force: bool = parser.get_default("force")
train_model: bool = parser.get_default("train_model")
inference: bool = parser.get_default("inference")
results_dir: str = parser.get_default("results_dir")

@property
def write_path(self) -> str:
return os.path.join(self.results_dir, self.run_id)

@property
def maybe_init_path(self) -> Optional[str]:
return (
os.path.join(self.results_dir, self.initialize_from)
if self.initialize_from is not None
else None
)

@property
def run_logs_dir(self) -> str:
return os.path.join(self.write_path, "run_logs")


@attr.s(auto_attribs=True)
Expand Down
7 changes: 7 additions & 0 deletions ml-agents/mlagents/trainers/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,13 @@ class StatsWriter(abc.ABC):
def write_stats(
self, category: str, values: Dict[str, StatsSummary], step: int
) -> None:
"""
Callback to record training information
:param category: Category of the statistics. Usually this is the behavior name.
:param values: Dictionary of statistics.
:param step: The current training step.
:return:
"""
pass

def add_property(
Expand Down
8 changes: 7 additions & 1 deletion ml-agents/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from setuptools import setup, find_packages
from setuptools.command.install import install
from mlagents.plugins import ML_AGENTS_STATS_WRITER
import mlagents.trainers

VERSION = mlagents.trainers.__version__
Expand Down Expand Up @@ -71,13 +72,18 @@ def run(self):
"cattrs>=1.0.0,<1.1.0",
"attrs>=19.3.0",
'pypiwin32==223;platform_system=="Windows"',
"importlib_metadata; python_version<'3.8'",
],
python_requires=">=3.6.1",
entry_points={
"console_scripts": [
"mlagents-learn=mlagents.trainers.learn:main",
"mlagents-run-experiment=mlagents.trainers.run_experiment:main",
]
],
# Plugins - each plugin type should have an entry here for the default behavior
ML_AGENTS_STATS_WRITER: [
"default=mlagents.plugins.stats_writer:get_default_stats_writers"
],
},
cmdclass={"verify": VerifyVersionCommand},
)