[rllib] clarify train batch size for PPO (ray-project#2793)
It's possible to configure PPO in a way that ends up discarding most of the samples (they are treated as "stragglers"). Add a warning when this happens, and raise an exception if the waste is particularly egregious.
ericl authored Sep 5, 2018
1 parent c87a911 commit 995ac24
Showing 24 changed files with 83 additions and 63 deletions.
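As a rough, illustrative sketch of the situation this commit guards against (the numbers below are made up, not taken from the diff): eight workers returning 200-step sample batches collect 1600 steps per round, but only a 1000-step train batch is kept, so roughly 40% of the collected steps would be wasted and the new check fails fast.

config = {
    "num_workers": 8,
    "sample_batch_size": 200,
    "train_batch_size": 1000,
}

# Same ratio that the new check in ppo.py (below) computes.
waste_ratio = (
    config["sample_batch_size"] * config["num_workers"] /
    config["train_batch_size"])   # 1.6 here

if waste_ratio > 1.5:
    raise ValueError("sample_batch_size * num_workers >> train_batch_size")
elif waste_ratio > 1:
    print("Warning: some sampled steps will be discarded")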
1 change: 1 addition & 0 deletions doc/source/rllib-config.svg
3 changes: 3 additions & 0 deletions doc/source/rllib-training.rst
@@ -50,10 +50,13 @@ In an example below, we train A2C by specifying 8 workers through the config flag
python ray/python/ray/rllib/train.py --env=PongDeterministic-v4 \
--run=A2C --config '{"num_workers": 8, "monitor": true}'
.. image:: rllib-config.svg

Specifying Resources
~~~~~~~~~~~~~~~~~~~~

You can control the degree of parallelism used by setting the ``num_workers`` hyperparameter for most agents. Many agents also provide a ``num_gpus`` or ``gpu`` option. In addition, you can allocate a fraction of a GPU by setting ``gpu_fraction: f``. For example, with DQN you can pack five agents onto one GPU by setting ``gpu_fraction: 0.2``. Note that fractional GPU support requires enabling the experimental Xray backend by setting the environment variable ``RAY_USE_XRAY=1``.
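For instance, a hypothetical invocation in the same style as the training example above (the flag values here are illustrative, not taken from the docs) would start a DQN trainer that claims only a fifth of a GPU, so five such trainers can share one device::

    RAY_USE_XRAY=1 python ray/python/ray/rllib/train.py --env=PongDeterministic-v4 \
        --run=DQN --config '{"num_workers": 2, "gpu_fraction": 0.2}'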

Evaluating Trained Agents
~~~~~~~~~~~~~~~~~~~~~~~~~
3 changes: 3 additions & 0 deletions python/ray/rllib/agents/agent.py
@@ -26,6 +26,9 @@
"num_workers": 2,
# Default sample batch size
"sample_batch_size": 200,
# Training batch size, if applicable. Should be >= sample_batch_size.
# Sample batches will be concatenated together to this size for training.
"train_batch_size": 200,
# Whether to rollout "complete_episodes" or "truncate_episodes"
"batch_mode": "truncate_episodes",
# Whether to use a background thread for sampling (slightly off-policy)
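As a hedged illustration of the train_batch_size comment above (the defaults come from this diff, but the arithmetic itself is mine, not part of the change): with two workers emitting 200-step sample batches and PPO's train_batch_size of 4000 (see ppo.py below), the sync optimizer keeps collecting until it can concatenate a full training batch.

import math

sample_batch_size = 200    # default above
num_workers = 2            # default above
train_batch_size = 4000    # PPO default below

steps_per_collection = sample_batch_size * num_workers                # 400
collections_per_train_batch = math.ceil(train_batch_size / steps_per_collection)
print(collections_per_train_batch)  # 10 collection rounds -> one 4000-step train batch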
5 changes: 2 additions & 3 deletions python/ray/rllib/agents/es/es.py
@@ -30,7 +30,7 @@
"l2_coeff": 0.005,
"noise_stdev": 0.02,
"episodes_per_batch": 1000,
"timesteps_per_batch": 10000,
"train_batch_size": 10000,
"eval_prob": 0.003,
"return_proc_mode": "centered_rank",
"num_workers": 10,
@@ -213,8 +213,7 @@ def _train(self):
# Use the actors to do rollouts, note that we pass in the ID of the
# policy weights.
results, num_episodes, num_timesteps = self._collect_results(
theta_id, config["episodes_per_batch"],
config["timesteps_per_batch"])
theta_id, config["episodes_per_batch"], config["train_batch_size"])

all_noise_indices = []
all_training_returns = []
27 changes: 20 additions & 7 deletions python/ray/rllib/agents/ppo/ppo.py
@@ -20,18 +20,20 @@
"lambda": 1.0,
# Initial coefficient for KL divergence
"kl_coeff": 0.2,
# Size of batches collected from each worker
"sample_batch_size": 200,
# Number of timesteps collected for each SGD round
"timesteps_per_batch": 4000,
"train_batch_size": 4000,
# Total SGD batch size across all devices for SGD (multi-gpu only)
"sgd_minibatch_size": 128,
# Number of SGD iterations in each outer loop
"num_sgd_iter": 30,
# Stepsize of SGD
"sgd_stepsize": 5e-5,
"lr": 5e-5,
# Learning rate schedule
"lr_schedule": None,
# Share layers for value function
"vf_share_layers": False,
# Total SGD batch size across all devices for SGD (multi-gpu only)
"sgd_batchsize": 128,
# Coefficient of the value function loss
"vf_loss_coeff": 1.0,
# Coefficient of the entropy regularizer
@@ -79,6 +81,17 @@ def default_resource_request(cls, config):
extra_gpu=cf["num_gpus_per_worker"] * cf["num_workers"])

def _init(self):
waste_ratio = (
self.config["sample_batch_size"] * self.config["num_workers"] /
self.config["train_batch_size"])
if waste_ratio > 1:
msg = ("sample_batch_size * num_workers >> train_batch_size. "
"This means that many steps will be discarded. Consider "
"reducing sample_batch_size, or increase train_batch_size.")
if waste_ratio > 1.5:
raise ValueError(msg)
else:
print("Warning: " + msg)
self.local_evaluator = self.make_local_evaluator(
self.env_creator, self._policy_graph)
self.remote_evaluators = self.make_remote_evaluators(
@@ -90,15 +103,15 @@ def _init(self):
self.optimizer = SyncSamplesOptimizer(
self.local_evaluator, self.remote_evaluators, {
"num_sgd_iter": self.config["num_sgd_iter"],
"timesteps_per_batch": self.config["timesteps_per_batch"]
"train_batch_size": self.config["train_batch_size"]
})
else:
self.optimizer = LocalMultiGPUOptimizer(
self.local_evaluator, self.remote_evaluators, {
"sgd_batch_size": self.config["sgd_batchsize"],
"sgd_batch_size": self.config["sgd_minibatch_size"],
"num_sgd_iter": self.config["num_sgd_iter"],
"num_gpus": self.config["num_gpus"],
"timesteps_per_batch": self.config["timesteps_per_batch"],
"train_batch_size": self.config["train_batch_size"],
"standardize_fields": ["advantages"],
})

2 changes: 1 addition & 1 deletion python/ray/rllib/agents/ppo/ppo_policy_graph.py
@@ -191,7 +191,7 @@ def __init__(self,
vf_loss_coeff=self.config["vf_loss_coeff"],
use_gae=self.config["use_gae"])

LearningRateSchedule.__init__(self, self.config["sgd_stepsize"],
LearningRateSchedule.__init__(self, self.config["lr"],
self.config["lr_schedule"])
TFPolicyGraph.__init__(
self,
4 changes: 2 additions & 2 deletions python/ray/rllib/agents/ppo/rollout.py
@@ -6,7 +6,7 @@
from ray.rllib.evaluation.sample_batch import SampleBatch


def collect_samples(agents, timesteps_per_batch):
def collect_samples(agents, train_batch_size):
num_timesteps_so_far = 0
trajectories = []
# This variable maps the object IDs of trajectories that are currently
@@ -19,7 +19,7 @@ def collect_samples(agents, timesteps_per_batch):
fut_sample = agent.sample.remote()
agent_dict[fut_sample] = agent

while num_timesteps_so_far < timesteps_per_batch:
while num_timesteps_so_far < train_batch_size:
# TODO(pcm): Make wait support arbitrary iterators and remove the
# conversion to list here.
[fut_sample], _ = ray.wait(list(agent_dict))
@@ -41,7 +41,7 @@ def create_env(env_config):
num_cpus = 4
ray.init(num_cpus=num_cpus, redirect_output=True)
config["num_workers"] = num_cpus
config["timesteps_per_batch"] = 10
config["train_batch_size"] = 1000
config["num_sgd_iter"] = 10
config["gamma"] = 0.999
config["horizon"] = horizon
@@ -41,8 +41,8 @@ def create_env(env_config):
num_cpus = 4
ray.init(num_cpus=num_cpus, redirect_output=True)
config["num_workers"] = num_cpus
config["timesteps_per_batch"] = 10
config["sgd_batchsize"] = 10
config["train_batch_size"] = 1000
config["sgd_minibatch_size"] = 10
config["num_sgd_iter"] = 10
config["gamma"] = 0.999
config["horizon"] = horizon
6 changes: 3 additions & 3 deletions python/ray/rllib/optimizers/multi_gpu_optimizer.py
@@ -33,12 +33,12 @@ class LocalMultiGPUOptimizer(PolicyOptimizer):
def _init(self,
sgd_batch_size=128,
num_sgd_iter=10,
timesteps_per_batch=1024,
train_batch_size=1024,
num_gpus=0,
standardize_fields=[]):
self.batch_size = sgd_batch_size
self.num_sgd_iter = num_sgd_iter
self.timesteps_per_batch = timesteps_per_batch
self.train_batch_size = train_batch_size
if not num_gpus:
self.devices = ["/cpu:0"]
else:
@@ -99,7 +99,7 @@ def step(self):
# TODO(rliaw): remove when refactoring
from ray.rllib.agents.ppo.rollout import collect_samples
samples = collect_samples(self.remote_evaluators,
self.timesteps_per_batch)
self.train_batch_size)
else:
samples = self.local_evaluator.sample()
self._check_not_multiagent(samples)
6 changes: 3 additions & 3 deletions python/ray/rllib/optimizers/sync_samples_optimizer.py
@@ -17,13 +17,13 @@ class SyncSamplesOptimizer(PolicyOptimizer):
model weights are then broadcast to all remote evaluators.
"""

def _init(self, num_sgd_iter=1, timesteps_per_batch=1):
def _init(self, num_sgd_iter=1, train_batch_size=1):
self.update_weights_timer = TimerStat()
self.sample_timer = TimerStat()
self.grad_timer = TimerStat()
self.throughput = RunningStat()
self.num_sgd_iter = num_sgd_iter
self.timesteps_per_batch = timesteps_per_batch
self.train_batch_size = train_batch_size
self.learner_stats = {}

def step(self):
@@ -35,7 +35,7 @@ def step(self):

with self.sample_timer:
samples = []
while sum(s.count for s in samples) < self.timesteps_per_batch:
while sum(s.count for s in samples) < self.train_batch_size:
if self.remote_evaluators:
samples.extend(
ray.get([
4 changes: 2 additions & 2 deletions python/ray/rllib/test/test_checkpoint_restore.py
@@ -22,7 +22,7 @@ def get_mean_action(alg, obs):
CONFIGS = {
"ES": {
"episodes_per_batch": 10,
"timesteps_per_batch": 100,
"train_batch_size": 100,
"num_workers": 2
},
"DQN": {},
@@ -40,7 +40,7 @@ def get_mean_action(alg, obs):
},
"PPO": {
"num_sgd_iter": 5,
"timesteps_per_batch": 1000,
"train_batch_size": 1000,
"num_workers": 2
},
"A3C": {
7 changes: 4 additions & 3 deletions python/ray/rllib/test/test_supported_spaces.py
@@ -107,15 +107,16 @@ def testAll(self):
"PPO", {
"num_workers": 1,
"num_sgd_iter": 1,
"timesteps_per_batch": 1,
"sgd_batchsize": 1
"train_batch_size": 10,
"sample_batch_size": 10,
"sgd_minibatch_size": 1
}, stats)
check_support(
"ES", {
"num_workers": 1,
"noise_size": 10000000,
"episodes_per_batch": 1,
"timesteps_per_batch": 1
"train_batch_size": 1
}, stats)
check_support(
"ARS", {
4 changes: 2 additions & 2 deletions python/ray/rllib/tuned_examples/atari-ppo.yaml
@@ -13,9 +13,9 @@ atari-ppo:
kl_coeff: 0.5
clip_param: 0.1
entropy_coeff: 0.01
timesteps_per_batch: 5000
train_batch_size: 5000
sample_batch_size: 500
sgd_batchsize: 500
sgd_minibatch_size: 500
num_sgd_iter: 10
num_workers: 10
num_envs_per_worker: 5
@@ -8,5 +8,5 @@ cartpole-ppo:
num_workers: 2
num_sgd_iter:
grid_search: [1, 4]
sgd_batchsize:
sgd_minibatch_size:
grid_search: [128, 256, 512]
6 changes: 3 additions & 3 deletions python/ray/rllib/tuned_examples/hopper-ppo.yaml
@@ -5,8 +5,8 @@ hopper-ppo:
gamma: 0.995
kl_coeff: 1.0
num_sgd_iter: 20
sgd_stepsize: .0001
sgd_batchsize: 32768
timesteps_per_batch: 160000
lr: .0001
sgd_minibatch_size: 32768
train_batch_size: 160000
num_workers: 64
num_gpus: 4
6 changes: 3 additions & 3 deletions python/ray/rllib/tuned_examples/humanoid-ppo-gae.yaml
@@ -9,10 +9,10 @@ humanoid-ppo-gae:
clip_param: 0.2
kl_coeff: 1.0
num_sgd_iter: 20
sgd_stepsize: .0001
sgd_batchsize: 32768
lr: .0001
sgd_minibatch_size: 32768
horizon: 5000
timesteps_per_batch: 320000
train_batch_size: 320000
model:
free_log_std: true
num_workers: 64
6 changes: 3 additions & 3 deletions python/ray/rllib/tuned_examples/humanoid-ppo.yaml
@@ -7,9 +7,9 @@ humanoid-ppo:
gamma: 0.995
kl_coeff: 1.0
num_sgd_iter: 20
sgd_stepsize: .0001
sgd_batchsize: 32768
timesteps_per_batch: 320000
lr: .0001
sgd_minibatch_size: 32768
train_batch_size: 320000
model:
free_log_std: true
use_gae: false
2 changes: 1 addition & 1 deletion python/ray/rllib/tuned_examples/hyperband-cartpole.yaml
@@ -9,5 +9,5 @@ cartpole-ppo:
num_workers: 1
num_sgd_iter:
grid_search: [1, 4]
sgd_batchsize:
sgd_minibatch_size:
grid_search: [128, 256, 512]
6 changes: 3 additions & 3 deletions python/ray/rllib/tuned_examples/pendulum-ppo.yaml
@@ -3,12 +3,12 @@ pendulum-ppo:
env: Pendulum-v0
run: PPO
config:
timesteps_per_batch: 2048
train_batch_size: 2048
num_workers: 4
lambda: 0.1
gamma: 0.95
sgd_stepsize: 0.0003
sgd_batchsize: 64
lr: 0.0003
sgd_minibatch_size: 64
num_sgd_iter: 10
model:
fcnet_hiddens: [64, 64]
@@ -6,12 +6,12 @@ pendulum-ppo:
# expect -140 within 300-500k steps
timesteps_total: 600000
config:
timesteps_per_batch: 2048
train_batch_size: 2048
num_workers: 4
lambda: 0.1
gamma: 0.95
sgd_stepsize: 0.0003
sgd_batchsize: 64
lr: 0.0003
sgd_minibatch_size: 64
num_sgd_iter: 10
model:
fcnet_hiddens: [64, 64]
6 changes: 3 additions & 3 deletions python/ray/rllib/tuned_examples/walker2d-ppo.yaml
@@ -4,8 +4,8 @@ walker2d-v1-ppo:
config:
kl_coeff: 1.0
num_sgd_iter: 20
sgd_stepsize: .0001
sgd_batchsize: 32768
timesteps_per_batch: 320000
lr: .0001
sgd_minibatch_size: 32768
train_batch_size: 320000
num_workers: 64
num_gpus: 4
16 changes: 8 additions & 8 deletions python/ray/tune/examples/pbt_ppo_example.py
@@ -21,8 +21,8 @@
# Postprocess the perturbed config to ensure it's still valid
def explore(config):
# ensure we collect enough timesteps to do sgd
if config["timesteps_per_batch"] < config["sgd_batchsize"] * 2:
config["timesteps_per_batch"] = config["sgd_batchsize"] * 2
if config["train_batch_size"] < config["sgd_minibatch_size"] * 2:
config["train_batch_size"] = config["sgd_minibatch_size"] * 2
# ensure we run at least one sgd iter
if config["num_sgd_iter"] < 1:
config["num_sgd_iter"] = 1
@@ -37,10 +37,10 @@ def explore(config):
hyperparam_mutations={
"lambda": lambda: random.uniform(0.9, 1.0),
"clip_param": lambda: random.uniform(0.01, 0.5),
"sgd_stepsize": [1e-3, 5e-4, 1e-4, 5e-5, 1e-5],
"lr": [1e-3, 5e-4, 1e-4, 5e-5, 1e-5],
"num_sgd_iter": lambda: random.randint(1, 30),
"sgd_batchsize": lambda: random.randint(128, 16384),
"timesteps_per_batch": lambda: random.randint(2000, 160000),
"sgd_minibatch_size": lambda: random.randint(128, 16384),
"train_batch_size": lambda: random.randint(2000, 160000),
},
custom_explore_fn=explore)

@@ -61,13 +61,13 @@ def explore(config):
# These params are tuned from a fixed starting value.
"lambda": 0.95,
"clip_param": 0.2,
"sgd_stepsize": 1e-4,
"lr": 1e-4,
# These params start off randomly drawn from a set.
"num_sgd_iter":
lambda spec: random.choice([10, 20, 30]),
"sgd_batchsize":
"sgd_minibatch_size":
lambda spec: random.choice([128, 512, 2048]),
"timesteps_per_batch":
"train_batch_size":
lambda spec: random.choice([10000, 20000, 40000])
},
},