[RLlib] AlphaZero uses training_iteration API. (ray-project#24507)
sven1977 authored May 18, 2022
1 parent 012a4c8 commit 8f50087
Showing 25 changed files with 284 additions and 122 deletions.
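
For context: training_iteration is RLlib's imperative per-iteration training method, which a Trainer can override instead of supplying an execution plan. This commit moves AlphaZero onto that pattern and relocates it from rllib/contrib to rllib/algorithms. Below is a minimal, self-contained sketch of the pattern (toy code, not RLlib's actual Trainer class).

# Toy illustration of the training_iteration pattern (not RLlib's real Trainer):
# each train() call runs exactly one sample-then-learn step and returns a results dict.
class TinyTrainer:
    def __init__(self):
        self.iterations = 0

    def training_iteration(self):
        batch = self.sample()              # 1) collect experience
        return self.learn_on_batch(batch)  # 2) one learner update, results dict out

    def train(self):
        return self.training_iteration()

    def sample(self):
        return [0.0] * 32  # placeholder sample batch

    def learn_on_batch(self, batch):
        self.iterations += 1
        return {"training_iteration": self.iterations,
                "num_env_steps_trained": len(batch)}

print(TinyTrainer().train())
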
6 changes: 3 additions & 3 deletions doc/source/rllib/index.rst
@@ -126,7 +126,7 @@ click on the dropdowns below:

* Model-based / Meta-learning / Offline

- |pytorch| :ref:`Single-Player AlphaZero (contrib/AlphaZero) <alphazero>`
- |pytorch| :ref:`Single-Player AlphaZero (AlphaZero) <alphazero>`

- |pytorch| |tensorflow| :ref:`Model-Agnostic Meta-Learning (MAML) <maml>`

@@ -147,8 +147,8 @@ click on the dropdowns below:

* Contextual bandits

- |pytorch| :ref:`Linear Upper Confidence Bound (contrib/LinUCB) <lin-ucb>`
- |pytorch| :ref:`Linear Thompson Sampling (contrib/LinTS) <lints>`
- |pytorch| :ref:`Linear Upper Confidence Bound (LinUCB) <lin-ucb>`
- |pytorch| :ref:`Linear Thompson Sampling (LinTS) <lints>`

* Exploration-based plug-ins (can be combined with any algo)

11 changes: 6 additions & 5 deletions doc/source/rllib/rllib-algorithms.rst
@@ -752,18 +752,19 @@ Tuned examples:

.. _alphazero:

Single-Player Alpha Zero (contrib/AlphaZero)
--------------------------------------------
Single-Player Alpha Zero (AlphaZero)
------------------------------------
|pytorch|
`[paper] <https://arxiv.org/abs/1712.01815>`__ `[implementation] <https://github.com/ray-project/ray/blob/master/rllib/contrib/alpha_zero>`__ AlphaZero is an RL agent originally designed for two-player games. This version adapts it to handle single-player games. The code can be scaled to any number of workers. It also implements the ranked rewards `(R2) <https://arxiv.org/abs/1807.01672>`__ strategy to enable self-play even in the one-player setting. The code is mainly intended for combinatorial optimization.
`[paper] <https://arxiv.org/abs/1712.01815>`__ `[implementation] <https://github.com/ray-project/ray/blob/master/rllib/algorithms/alpha_zero>`__ AlphaZero is an RL agent originally designed for two-player games. This version adapts it to handle single-player games. The code can be scaled to any number of workers. It also implements the ranked rewards `(R2) <https://arxiv.org/abs/1807.01672>`__ strategy to enable self-play even in the one-player setting. The code is mainly intended for combinatorial optimization.

Tuned examples: `CartPole-v0 <https://github.com/ray-project/ray/blob/master/rllib/contrib/alpha_zero/examples/train_cartpole.py>`__
Tuned examples: `Sparse reward CartPole <https://github.com/ray-project/ray/blob/master/rllib/tuned_examples/alpha_zero/cartpole-sparse-rewards-alpha-zero.yaml>`__

**AlphaZero-specific configs** (see also `common configs <rllib-training.html#common-parameters>`__):

.. literalinclude:: ../../../rllib/contrib/alpha_zero/core/alpha_zero_trainer.py
.. literalinclude:: ../../../rllib/algorithms/alpha_zero/alpha_zero.py
:language: python
:start-after: __sphinx_doc_begin__
:end-before: __sphinx_doc_end__
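
With the module now living under rllib/algorithms/alpha_zero (see the new __init__.py further down in this commit), a usage sketch might look as follows. This is illustrative, not part of the commit: the environment class and its import path are assumptions based on the tuned example above, since AlphaZero's MCTS needs an environment that exposes get_state()/set_state() to simulate moves, which plain Gym CartPole does not.

# Illustrative sketch only; env class/path and config details are assumptions.
import ray
from ray.rllib.algorithms.alpha_zero import AlphaZeroTrainer
from ray.rllib.examples.env.cartpole_sparse_rewards import CartPoleSparseRewards  # assumed path

ray.init()
trainer = AlphaZeroTrainer(
    env=CartPoleSparseRewards,
    config={
        "framework": "torch",  # AlphaZero is PyTorch-only
        "num_workers": 2,
        # The tuned example additionally sets an AlphaZero-specific custom model and
        # mcts_config; both are omitted here for brevity.
    },
)
for _ in range(3):
    print(trainer.train())  # each call runs one training_iteration()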


Multi-Agent Methods
28 changes: 19 additions & 9 deletions rllib/BUILD
@@ -124,6 +124,17 @@ py_test(
args = ["--yaml-dir=tuned_examples/alpha_star", "--num-cpus=10"]
)

# AlphaZero
py_test(
name = "learning_tests_cartpole_sparse_rewards_alpha_zero",
tags = ["team:ml", "torch_only", "learning_tests", "learning_tests_discrete"],
main = "tests/run_regression_tests.py",
size = "large",
srcs = ["tests/run_regression_tests.py"],
data = ["tuned_examples/alpha_zero/cartpole-sparse-rewards-alpha-zero.yaml"],
args = ["--yaml-dir=tuned_examples/alpha_zero", "--num-cpus=8"]
)

# APEX-DQN
# py_test(
# name = "learning_tests_cartpole_apex",
@@ -650,6 +661,14 @@ py_test(
srcs = ["algorithms/alpha_star/tests/test_alpha_star.py"]
)

# AlphaZero
py_test(
name = "test_alpha_zero",
tags = ["team:ml", "trainers_dir"],
size = "large",
srcs = ["algorithms/alpha_zero/tests/test_alpha_zero.py"]
)

# APEXTrainer (DQN)
py_test(
name = "test_apex_dqn",
@@ -872,15 +891,6 @@ py_test(
srcs = ["contrib/random_agent/random_agent.py"]
)

py_test(
name = "alpha_zero_cartpole",
tags = ["team:ml", "trainers_dir"],
main = "contrib/alpha_zero/examples/train_cartpole.py",
size = "large",
srcs = ["contrib/alpha_zero/examples/train_cartpole.py"],
args = ["--training-iteration=1", "--num-workers=2", "--ray-num-cpus=3"]
)


# --------------------------------------------------------------------
# Memory leak tests
13 changes: 0 additions & 13 deletions rllib/agents/impala/impala.py
@@ -585,21 +585,8 @@ def training_iteration(self) -> ResultDict:

self.update_workers_if_necessary()

# Callback for APPO to use to update KL, target network periodically.
# The input to the callback is the learner fetches dict.
self.after_train_step(train_results)

return train_results

def after_train_step(self, train_results: ResultDict) -> None:
"""Called by the training_iteration method after each train step.
Args:
train_results: The train results dict.
"""
# By default, do nothing.
pass

@staticmethod
@override(Trainer)
def execution_plan(workers, config, **kwargs):
49 changes: 38 additions & 11 deletions rllib/agents/ppo/appo.py
@@ -10,10 +10,11 @@
https://docs.ray.io/en/master/rllib-algorithms.html#appo
"""
from typing import Optional, Type
import logging

from ray.rllib.agents.ppo.appo_tf_policy import AsyncPPOTFPolicy
from ray.rllib.agents.ppo.ppo import UpdateKL
from ray.rllib.agents import impala
from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY
from ray.rllib.agents.impala import ImpalaTrainer, ImpalaConfig
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.annotations import override
from ray.rllib.utils.deprecation import Deprecated
@@ -29,8 +30,10 @@
TrainerConfigDict,
)

logger = logging.getLogger(__name__)

class APPOConfig(impala.ImpalaConfig):

class APPOConfig(ImpalaConfig):
"""Defines a APPOTrainer configuration class from which a new Trainer can be built.
Example:
@@ -107,7 +110,7 @@ def __init__(self, trainer_class=None):
# __sphinx_doc_end__
# fmt: on

@override(impala.ImpalaConfig)
@override(ImpalaConfig)
def training(
self,
*,
@@ -164,18 +167,15 @@ def training(
return self


class APPOTrainer(impala.ImpalaTrainer):
class APPOTrainer(ImpalaTrainer):
def __init__(self, config, *args, **kwargs):
super().__init__(config, *args, **kwargs)

self.update_kl = UpdateKL(self.workers)

# After init: Initialize target net.
self.workers.local_worker().foreach_policy_to_train(
lambda p, _: p.update_target()
)

@override(impala.ImpalaTrainer)
def after_train_step(self, train_results: ResultDict) -> None:
"""Updates the target network and the KL coefficient for the APPO-loss.
@@ -207,14 +207,41 @@ def after_train_step(self, train_results: ResultDict) -> None:

# Also update the KL-coefficient for the APPO loss, if necessary.
if self.config["use_kl_loss"]:
self.update_kl(train_results)

def update(pi, pi_id):
assert LEARNER_STATS_KEY not in train_results, (
"{} should be nested under policy id key".format(
LEARNER_STATS_KEY
),
train_results,
)
if pi_id in train_results:
kl = train_results[pi_id][LEARNER_STATS_KEY].get("kl")
assert kl is not None, (train_results, pi_id)
# Make the actual `Policy.update_kl()` call.
pi.update_kl(kl)
else:
logger.warning("No data for {}, not updating kl".format(pi_id))

# Update KL on all trainable policies within the local (trainer)
# Worker.
self.workers.local_worker().foreach_policy_to_train(update)

@override(ImpalaTrainer)
def training_iteration(self) -> ResultDict:
train_results = super().training_iteration()

# Update KL, target network periodically.
self.after_train_step(train_results)

return train_results

@classmethod
@override(impala.ImpalaTrainer)
@override(ImpalaTrainer)
def get_default_config(cls) -> TrainerConfigDict:
return APPOConfig().to_dict()

@override(impala.ImpalaTrainer)
@override(ImpalaTrainer)
def get_default_policy_class(
self, config: PartialTrainerConfigDict
) -> Optional[Type[Policy]]:
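
Net effect of the impala.py and appo.py changes above: IMPALA's training_iteration() no longer invokes an after_train_step() hook, and APPO instead overrides training_iteration() itself, doing its post-step work (periodic target-network sync and KL-coefficient update) there. A stripped-down illustration of that control flow, using toy classes rather than the real RLlib ones:

# Toy classes that mirror the control flow of the impala.py/appo.py diff above.
class ImpalaLike:
    def training_iteration(self):
        # sample + learn, then return a results dict (no post-step hook anymore)
        return {"default_policy": {"learner_stats": {"kl": 0.02}}}


class AppoLike(ImpalaLike):
    def training_iteration(self):
        train_results = super().training_iteration()
        self.after_train_step(train_results)  # the subclass owns its post-step logic
        return train_results

    def after_train_step(self, train_results):
        # e.g. periodic target-net sync and KL-coefficient update (see appo.py)
        kl = train_results["default_policy"]["learner_stats"]["kl"]
        print("updating KL coefficient with kl =", kl)


AppoLike().training_iteration()
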
12 changes: 11 additions & 1 deletion rllib/agents/registry.py
@@ -26,6 +26,15 @@ def _import_alpha_star():
return AlphaStarTrainer, DEFAULT_CONFIG


def _import_alpha_zero():
from ray.rllib.algorithms.alpha_zero.alpha_zero import (
AlphaZeroTrainer,
DEFAULT_CONFIG,
)

return AlphaZeroTrainer, DEFAULT_CONFIG


def _import_apex():
from ray.rllib.agents import dqn

@@ -191,6 +200,8 @@ def _import_td3():
ALGORITHMS = {
"A2C": _import_a2c,
"A3C": _import_a3c,
"AlphaStar": _import_alpha_star,
"AlphaZero": _import_alpha_zero,
"APPO": _import_appo,
"APEX": _import_apex,
"APEX_DDPG": _import_apex_ddpg,
Expand Down Expand Up @@ -218,7 +229,6 @@ def _import_td3():
"SimpleQ": _import_simple_q,
"SlateQ": _import_slate_q,
"TD3": _import_td3,
"AlphaStar": _import_alpha_star,
}


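
With the registry entry added above, the trainer can also be resolved by its string name. A sketch; the lookup helper name (get_trainer_class) is assumed for the registry module of this era rather than taken from the diff.

# Assumed helper: rllib/agents/registry.py provided a string-to-class lookup at the time.
from ray.rllib.agents.registry import get_trainer_class

trainer_cls = get_trainer_class("AlphaZero")
print(trainer_cls)  # the relocated AlphaZeroTrainer class
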
File renamed without changes.
11 changes: 11 additions & 0 deletions rllib/algorithms/alpha_zero/__init__.py
@@ -0,0 +1,11 @@
from ray.rllib.algorithms.alpha_zero.alpha_zero import (
AlphaZeroTrainer,
DEFAULT_CONFIG,
)
from ray.rllib.algorithms.alpha_zero.alpha_zero_policy import AlphaZeroPolicy

__all__ = [
"AlphaZeroPolicy",
"AlphaZeroTrainer",
"DEFAULT_CONFIG",
]