[RLlib] AlphaZero TrainerConfig objects. (#25256)
sven1977 authored May 30, 2022
1 parent d95009a commit 30f6fc3
Showing 4 changed files with 210 additions and 86 deletions.
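In short: this commit replaces AlphaZero's flat DEFAULT_CONFIG dict with a builder-style AlphaZeroConfig (a TrainerConfig subclass). A minimal before/after sketch of the two usage styles (illustrative values, not the tuned defaults):

    import ray.rllib.algorithms.alpha_zero as az

    # Old style (still works, but now logs a deprecation warning):
    config = az.DEFAULT_CONFIG.copy()
    config["train_batch_size"] = 4000
    trainer = az.AlphaZeroTrainer(config=config, env="CartPole-v1")

    # New style introduced by this commit:
    config = az.AlphaZeroConfig().training(train_batch_size=4000)
    trainer = config.build(env="CartPole-v1")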
2 changes: 1 addition & 1 deletion doc/source/rllib/rllib-examples.rst
@@ -55,7 +55,7 @@ Environments and Adapters
  Example demonstrating how to use the SUMO simulator in connection with RLlib.
- `VizDoom example script using RLlib's auto-attention wrapper <https://github.com/ray-project/ray/blob/master/rllib/examples/vizdoom_with_attention_net.py>`__:
  Script showing how to run PPO with an attention net against a VizDoom gym environment.
- `Subprocess environment <https://github.com/ray-project/ray/blob/master/rllib/tests/test_env_with_subprocess.py>`__:
- `Subprocess environment <https://github.com/ray-project/ray/blob/master/rllib/env/tests/test_env_with_subprocess.py>`__:
  Example of how to ensure subprocesses spawned by envs are killed when RLlib exits.
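For context, the pattern that linked test demonstrates is an env that owns a child process and guarantees it dies with the env. A minimal, hypothetical sketch of that idea under the old gym API (not the test's actual code):

    import atexit
    import subprocess

    import gym


    class EnvWithSubprocess(gym.Env):
        """Toy env that spawns a helper process and always cleans it up."""

        def __init__(self, config=None):
            self.action_space = gym.spaces.Discrete(2)
            self.observation_space = gym.spaces.Discrete(2)
            # Hypothetical long-running child process owned by the env.
            self.child = subprocess.Popen(["sleep", "3600"])
            # Kill the child even if close() is never reached (e.g. hard exit).
            atexit.register(self.child.kill)

        def reset(self):
            return 0

        def step(self, action):
            return 0, 1.0, True, {}

        def close(self):
            self.child.kill()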


2 changes: 2 additions & 0 deletions rllib/algorithms/alpha_zero/__init__.py
@@ -1,10 +1,12 @@
from ray.rllib.algorithms.alpha_zero.alpha_zero import (
    AlphaZeroConfig,
    AlphaZeroTrainer,
    DEFAULT_CONFIG,
)
from ray.rllib.algorithms.alpha_zero.alpha_zero_policy import AlphaZeroPolicy

__all__ = [
    "AlphaZeroConfig",
    "AlphaZeroPolicy",
    "AlphaZeroTrainer",
    "DEFAULT_CONFIG",
282 changes: 201 additions & 81 deletions rllib/algorithms/alpha_zero/alpha_zero.py
@@ -1,9 +1,9 @@
import logging
from typing import Type
from typing import List, Optional, Type, Union

from ray.rllib.agents import with_common_config
from ray.rllib.agents.callbacks import DefaultCallbacks
from ray.rllib.agents.trainer import Trainer
from ray.rllib.agents.trainer_config import TrainerConfig
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.execution.replay_ops import (
    SimpleReplayBuffer,
@@ -28,7 +28,7 @@
from ray.rllib.models.torch.torch_action_dist import TorchCategorical
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.annotations import override
from ray.rllib.utils.annotations import Deprecated, override
from ray.rllib.utils.deprecation import DEPRECATED_VALUE
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.metrics import (
@@ -63,83 +63,186 @@ def on_episode_start(self, worker, base_env, policies, episode, **kwargs):
        episode.user_data["initial_state"] = state


# fmt: off
# __sphinx_doc_begin__
DEFAULT_CONFIG = with_common_config({
    # Size of batches collected from each worker
    "rollout_fragment_length": 200,
    # Number of timesteps collected for each SGD round
    "train_batch_size": 4000,
    # Total SGD batch size across all devices for SGD
    "sgd_minibatch_size": 128,
    # Whether to shuffle sequences in the batch when training (recommended)
    "shuffle_sequences": True,
    # Number of SGD iterations in each outer loop
    "num_sgd_iter": 30,
    # In case a buffer optimizer is used
    "learning_starts": 1000,
    # Size of the replay buffer in batches (not timesteps!).
    "buffer_size": DEPRECATED_VALUE,
    "replay_buffer_config": {
        "_enable_replay_buffer_api": True,
        "type": "SimpleReplayBuffer",
        # Size of the replay buffer in batches (not timesteps!).
        "capacity": 1000,
        # When to start returning samples (in batches, not timesteps!).
        "learning_starts": 500,
    },
    # Stepsize of SGD
    "lr": 5e-5,
    # Learning rate schedule
    "lr_schedule": None,
    # Share layers for value function. If you set this to True, it's important
    # to tune vf_loss_coeff.
    "vf_share_layers": False,
    # Whether to rollout "complete_episodes" or "truncate_episodes"
    "batch_mode": "complete_episodes",
    # Which observation filter to apply to the observation
    "observation_filter": "NoFilter",

    # === MCTS ===
    "mcts_config": {
        "puct_coefficient": 1.0,
        "num_simulations": 30,
        "temperature": 1.5,
        "dirichlet_epsilon": 0.25,
        "dirichlet_noise": 0.03,
        "argmax_tree_policy": False,
        "add_dirichlet_noise": True,
    },

    # === Ranked Rewards ===
    # implement the ranked reward (r2) algorithm
    # from: https://arxiv.org/pdf/1807.01672.pdf
    "ranked_rewards": {
        "enable": True,
        "percentile": 75,
        "buffer_max_length": 1000,
        # add rewards obtained from random policy to
        # "warm start" the buffer
        "initialize_buffer": True,
        "num_init_rewards": 100,
    },

    # === Evaluation ===
    # Extra configuration that disables exploration.
    "evaluation_config": {
        "mcts_config": {
            "argmax_tree_policy": True,
            "add_dirichlet_noise": False,
        },
    },

    # === Callbacks ===
    "callbacks": AlphaZeroDefaultCallbacks,

    "framework": "torch",  # Only PyTorch supported so far.
})
# __sphinx_doc_end__
# fmt: on
class AlphaZeroConfig(TrainerConfig):
    """Defines a configuration class from which an AlphaZeroTrainer can be built.

    Example:
        >>> from ray.rllib.algorithms.alpha_zero import AlphaZeroConfig
        >>> config = AlphaZeroConfig().training(sgd_minibatch_size=256)\
        ...     .resources(num_gpus=0)\
        ...     .rollouts(num_workers=4)
        >>> print(config.to_dict())
        >>> # Build a Trainer object from the config and run 1 training iteration.
        >>> trainer = config.build(env="CartPole-v1")
        >>> trainer.train()

    Example:
        >>> from ray.rllib.algorithms.alpha_zero import AlphaZeroConfig
        >>> from ray import tune
        >>> config = AlphaZeroConfig()
        >>> # Print out some default values.
        >>> print(config.shuffle_sequences)
        >>> # Update the config object.
        >>> config.training(lr=tune.grid_search([0.001, 0.0001]))
        >>> # Set the config object's env.
        >>> config.environment(env="CartPole-v1")
        >>> # Use to_dict() to get the old-style python config dict
        >>> # when running with tune.
        >>> tune.run(
        ...     "AlphaZero",
        ...     stop={"episode_reward_mean": 200},
        ...     config=config.to_dict(),
        ... )
    """

    def __init__(self, trainer_class=None):
        """Initializes an AlphaZeroConfig instance."""
        super().__init__(trainer_class=trainer_class or AlphaZeroTrainer)

        # fmt: off
        # __sphinx_doc_begin__
        # AlphaZero specific config settings:
        self.sgd_minibatch_size = 128
        self.shuffle_sequences = True
        self.num_sgd_iter = 30
        self.learning_starts = 1000
        self.replay_buffer_config = {
            "type": "SimpleReplayBuffer",
            # Size of the replay buffer in batches (not timesteps!).
            "capacity": 1000,
            # When to start returning samples (in batches, not timesteps!).
            "learning_starts": 500,
        }
        self.lr_schedule = None
        self.vf_share_layers = False
        self.mcts_config = {
            "puct_coefficient": 1.0,
            "num_simulations": 30,
            "temperature": 1.5,
            "dirichlet_epsilon": 0.25,
            "dirichlet_noise": 0.03,
            "argmax_tree_policy": False,
            "add_dirichlet_noise": True,
        }
        self.ranked_rewards = {
            "enable": True,
            "percentile": 75,
            "buffer_max_length": 1000,
            # add rewards obtained from random policy to
            # "warm start" the buffer
            "initialize_buffer": True,
            "num_init_rewards": 100,
        }

        # Override some of TrainerConfig's default values with AlphaZero-specific
        # values.
        self.framework_str = "torch"
        self.callbacks_class = AlphaZeroDefaultCallbacks
        self.lr = 5e-5
        self.rollout_fragment_length = 200
        self.train_batch_size = 4000
        self.batch_mode = "complete_episodes"
        # Extra configuration that disables exploration.
        self.evaluation_config = {
            "mcts_config": {
                "argmax_tree_policy": True,
                "add_dirichlet_noise": False,
            },
        }
        # __sphinx_doc_end__
        # fmt: on

        self.buffer_size = DEPRECATED_VALUE

    @override(TrainerConfig)
    def training(
        self,
        *,
        sgd_minibatch_size: Optional[int] = None,
        shuffle_sequences: Optional[bool] = None,
        num_sgd_iter: Optional[int] = None,
        replay_buffer_config: Optional[dict] = None,
        lr_schedule: Optional[List[List[Union[int, float]]]] = None,
        vf_share_layers: Optional[bool] = None,
        mcts_config: Optional[dict] = None,
        ranked_rewards: Optional[dict] = None,
        **kwargs,
    ) -> "AlphaZeroConfig":
        """Sets the training related configuration.

        Args:
            sgd_minibatch_size: Total SGD batch size across all devices for SGD.
            shuffle_sequences: Whether to shuffle sequences in the batch when
                training (recommended).
            num_sgd_iter: Number of SGD iterations in each outer loop.
            replay_buffer_config: Replay buffer config.
                Examples:
                {
                    "_enable_replay_buffer_api": True,
                    "type": "MultiAgentReplayBuffer",
                    "learning_starts": 1000,
                    "capacity": 50000,
                    "replay_sequence_length": 1,
                }
                - OR -
                {
                    "_enable_replay_buffer_api": True,
                    "type": "MultiAgentPrioritizedReplayBuffer",
                    "capacity": 50000,
                    "prioritized_replay_alpha": 0.6,
                    "prioritized_replay_beta": 0.4,
                    "prioritized_replay_eps": 1e-6,
                    "replay_sequence_length": 1,
                }
                - Where -
                prioritized_replay_alpha: The alpha parameter controls the degree
                of prioritization in the buffer. In other words, when a buffer
                sample has a higher temporal-difference error, with how much more
                probability it should be drawn to update the parametrized
                Q-network. 0.0 corresponds to uniform probability. Setting this
                much above 1.0 may quickly make the sampling distribution heavily
                “pointy”, with low entropy.
                prioritized_replay_beta: The beta parameter controls the degree of
                importance sampling, which suppresses the influence of gradient
                updates from samples that have a higher probability of being
                drawn via the alpha parameter and the temporal-difference error.
                prioritized_replay_eps: The epsilon parameter sets the baseline
                probability for sampling, so that a sample whose
                temporal-difference error is zero still has some chance of being
                drawn.
            lr_schedule: Learning rate schedule. In the format of
                [[timestep, lr-value], [timestep, lr-value], ...]
                Intermediary timesteps will be assigned to interpolated learning
                rate values. A schedule should normally start from timestep 0.
            vf_share_layers: Share layers for the value function. If you set this
                to True, it's important to tune vf_loss_coeff.
            mcts_config: MCTS specific settings.
            ranked_rewards: Settings for the ranked reward (r2) algorithm
                from: https://arxiv.org/pdf/1807.01672.pdf

        Returns:
            This updated TrainerConfig object.
        """
        # Pass kwargs onto super's `training()` method.
        super().training(**kwargs)

        if sgd_minibatch_size is not None:
            self.sgd_minibatch_size = sgd_minibatch_size
        if shuffle_sequences is not None:
            self.shuffle_sequences = shuffle_sequences
        if num_sgd_iter is not None:
            self.num_sgd_iter = num_sgd_iter
        if replay_buffer_config is not None:
            self.replay_buffer_config = replay_buffer_config
        if lr_schedule is not None:
            self.lr_schedule = lr_schedule
        if vf_share_layers is not None:
            self.vf_share_layers = vf_share_layers
        if mcts_config is not None:
            self.mcts_config = mcts_config
        if ranked_rewards is not None:
            self.ranked_rewards = ranked_rewards

        return self
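A hedged usage sketch of this new setter (values are illustrative; note that dict-valued args such as mcts_config or ranked_rewards replace the entire default dict rather than merging into it, per the assignments above):

    from ray.rllib.algorithms.alpha_zero import AlphaZeroConfig

    config = (
        AlphaZeroConfig()
        .training(
            sgd_minibatch_size=64,
            # 5e-5 at timestep 0, linearly interpolated down to 1e-5 by 100k.
            lr_schedule=[[0, 5e-5], [100_000, 1e-5]],
            # Full dict given, since this overwrites the default wholesale.
            ranked_rewards={
                "enable": True,
                "percentile": 75,
                "buffer_max_length": 1000,
                "initialize_buffer": True,
                "num_init_rewards": 100,
            },
        )
        .rollouts(num_workers=2)
    )
    trainer = config.build(env="CartPole-v1")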


def alpha_zero_loss(policy, model, dist_class, train_batch):
@@ -203,7 +306,7 @@ class AlphaZeroTrainer(Trainer):
    @classmethod
    @override(Trainer)
    def get_default_config(cls) -> TrainerConfigDict:
        return DEFAULT_CONFIG
        return AlphaZeroConfig().to_dict()

    def validate_config(self, config: TrainerConfigDict) -> None:
        """Checks and updates the config based on settings."""
@@ -312,3 +415,20 @@ def execution_plan(
        )

        return StandardMetricsReporting(train_op, workers, config)


# Deprecated: Use ray.rllib.algorithms.alpha_zero.AlphaZeroConfig instead!
class _deprecated_default_config(dict):
    def __init__(self):
        super().__init__(AlphaZeroConfig().to_dict())

    @Deprecated(
        old="ray.rllib.algorithms.alpha_zero.alpha_zero.DEFAULT_CONFIG",
        new="ray.rllib.algorithms.alpha_zero.alpha_zero.AlphaZeroConfig(...)",
        error=False,
    )
    def __getitem__(self, item):
        return super().__getitem__(item)


DEFAULT_CONFIG = _deprecated_default_config()
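Per the wrapper above, old-style dict reads keep working, but each key access now goes through the @Deprecated-decorated __getitem__ and, with error=False, logs a warning instead of raising. A short sketch:

    from ray.rllib.algorithms.alpha_zero.alpha_zero import DEFAULT_CONFIG

    # Still returns the value (5e-5 per the defaults above), but emits a
    # deprecation warning pointing users at AlphaZeroConfig.
    lr = DEFAULT_CONFIG["lr"]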
10 changes: 6 additions & 4 deletions rllib/algorithms/alpha_zero/tests/test_alpha_zero.py
@@ -21,14 +21,16 @@ def tearDownClass(cls) -> None:

    def test_alpha_zero_compilation(self):
        """Test whether an AlphaZeroTrainer can be built with all frameworks."""
        config = az.DEFAULT_CONFIG.copy()
        config["env"] = CartPoleSparseRewards
        config["model"]["custom_model"] = DenseModel
        config = (
            az.AlphaZeroConfig()
            .environment(env=CartPoleSparseRewards)
            .training(model={"custom_model": DenseModel})
        )
        num_iterations = 1

        # Only working for torch right now.
        for _ in framework_iterator(config, frameworks="torch"):
            trainer = az.AlphaZeroTrainer(config)
            trainer = config.build()
            for i in range(num_iterations):
                results = trainer.train()
                check_train_results(results)
