[RLlib] A2C training_iteration method implementation (`_disable_execution_plan_api=True`) (ray-project#23735)
sven1977 authored Apr 15, 2022
1 parent c38a295 commit 92781c6
Showing 7 changed files with 162 additions and 23 deletions.
10 changes: 10 additions & 0 deletions rllib/BUILD
@@ -83,6 +83,16 @@ py_test(
args = ["--yaml-dir=tuned_examples/a3c"]
)

py_test(
name = "learning_tests_cartpole_a2c_microbatch",
main = "tests/run_regression_tests.py",
tags = ["team:ml", "learning_tests", "learning_tests_cartpole", "learning_tests_discrete"],
size = "large",
srcs = ["tests/run_regression_tests.py"],
data = ["tuned_examples/a3c/cartpole-a2c-microbatch.yaml"],
args = ["--yaml-dir=tuned_examples/a3c"]
)

py_test(
name = "learning_tests_cartpole_a2c_fake_gpus",
main = "tests/run_regression_tests.py",
139 changes: 133 additions & 6 deletions rllib/agents/a3c/a2c.py
@@ -1,21 +1,44 @@
import logging
import math

from ray.util.iter import LocalIterator
from ray.rllib.agents.a3c.a3c import DEFAULT_CONFIG as A3C_CONFIG, A3CTrainer
from ray.rllib.agents.trainer import Trainer
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.execution.common import (
STEPS_TRAINED_COUNTER,
STEPS_TRAINED_THIS_ITER_COUNTER,
)
from ray.rllib.execution.metric_ops import StandardMetricsReporting
from ray.rllib.execution.rollout_ops import ParallelRollouts, ConcatBatches
from ray.rllib.execution.rollout_ops import (
ParallelRollouts,
ConcatBatches,
synchronous_parallel_sample,
)
from ray.rllib.execution.train_ops import (
ComputeGradients,
AverageGradients,
ApplyGradients,
MultiGPUTrainOneStep,
TrainOneStep,
)
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.utils import merge_dicts
from ray.rllib.utils.annotations import override
from ray.rllib.utils.typing import TrainerConfigDict
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.utils.metrics import (
APPLY_GRADS_TIMER,
COMPUTE_GRADS_TIMER,
NUM_AGENT_STEPS_SAMPLED,
NUM_ENV_STEPS_SAMPLED,
WORKER_UPDATE_TIMER,
)
from ray.rllib.utils.typing import (
PartialTrainerConfigDict,
ResultDict,
TrainerConfigDict,
)
from ray.util.iter import LocalIterator

logger = logging.getLogger(__name__)

A2C_DEFAULT_CONFIG = merge_dicts(
A3C_CONFIG,
@@ -28,8 +51,6 @@
# training with batch sizes much larger than can fit in GPU memory.
# To enable, set this to a value less than the train batch size.
"microbatch_size": None,
# Use `execution_plan` for A2C (no `training_iteration` implementation yet).
"_disable_execution_plan_api": False,
},
)
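
With the `_disable_execution_plan_api: False` override removed above, A2C falls back to the trainer-level default and runs the new `training_iteration()` path (per the commit title), while microbatching stays a pure config switch. A minimal user-side sketch (the environment name, the concrete sizes, and the explicit `ray.init()` call are illustrative assumptions, not taken from this diff):

    import ray
    from ray.rllib.agents.a3c import A2CTrainer

    ray.init()
    trainer = A2CTrainer(
        env="CartPole-v0",
        config={
            "train_batch_size": 120,
            # Gradients are computed on 40-sample microbatches and applied
            # together once ceil(120 / 40) = 3 microbatches have accumulated.
            "microbatch_size": 40,
            "rollout_fragment_length": 20,
        },
    )
    print(trainer.train())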

@@ -40,6 +61,112 @@ class A2CTrainer(A3CTrainer):
def get_default_config(cls) -> TrainerConfigDict:
return A2C_DEFAULT_CONFIG

@override(A3CTrainer)
def validate_config(self, config: TrainerConfigDict) -> None:
# Call super's validation method.
super().validate_config(config)

if config["microbatch_size"]:
# Train batch size needs to be significantly larger than microbatch_size.
if config["train_batch_size"] / config["microbatch_size"] < 3:
logger.warning(
"`train_batch_size` should be considerably larger (at least 3x) "
"than `microbatch_size` for a microbatching setup to make sense!"
)
# Rollout fragment length needs to be less than microbatch_size.
if config["rollout_fragment_length"] > config["microbatch_size"]:
logger.warning(
"`rollout_fragment_length` should not be larger than "
"`microbatch_size` (try setting them to the same value)! "
"Otherwise, microbatches of desired size won't be achievable."
)

@override(Trainer)
def setup(self, config: PartialTrainerConfigDict):
super().setup(config)

# Create a microbatch variable for collecting gradients on microbatches.
# These gradients will be accumulated on the fly and applied to the model
# in a single step once a full train batch worth of samples has been collected.
if (
self.config["_disable_execution_plan_api"] is True
and self.config["microbatch_size"]
):
self._microbatches_grads = None
self._microbatches_counts = self._num_microbatches = 0

@override(A3CTrainer)
def training_iteration(self) -> ResultDict:
# Without microbatching: identical to Trainer's default implementation.
# The only difference from a default Trainer is the value function loss term
# and the value computations alongside each action.
if self.config["microbatch_size"] is None:
return Trainer.training_iteration(self)

# In microbatch mode, we want to compute gradients on experience
# microbatches, average a number of these microbatches, and then
# apply the averaged gradient in one SGD step. This conserves GPU
# memory, allowing for extremely large experience batches to be
# used.
if self._by_agent_steps:
train_batch = synchronous_parallel_sample(
worker_set=self.workers, max_agent_steps=self.config["microbatch_size"]
)
else:
train_batch = synchronous_parallel_sample(
worker_set=self.workers, max_env_steps=self.config["microbatch_size"]
)
self._counters[NUM_ENV_STEPS_SAMPLED] += train_batch.env_steps()
self._counters[NUM_AGENT_STEPS_SAMPLED] += train_batch.agent_steps()

with self._timers[COMPUTE_GRADS_TIMER]:
grad, info = self.workers.local_worker().compute_gradients(
train_batch, single_agent=True
)
# New microbatch accumulation phase.
if self._microbatches_grads is None:
self._microbatches_grads = grad
# Existing gradients: Accumulate new gradients on top of existing ones.
else:
for i, g in enumerate(grad):
self._microbatches_grads[i] += g
self._microbatches_counts += train_batch.count
self._num_microbatches += 1

# Once enough microbatches have been collected to cover `train_batch_size`:
# apply the accumulated gradients.
num_microbatches = math.ceil(
self.config["train_batch_size"] / self.config["microbatch_size"]
)
if self._num_microbatches >= num_microbatches:
# Update counters.
self._counters[STEPS_TRAINED_COUNTER] += self._microbatches_counts
self._counters[STEPS_TRAINED_THIS_ITER_COUNTER] = self._microbatches_counts

# Apply gradients.
apply_timer = self._timers[APPLY_GRADS_TIMER]
with apply_timer:
self.workers.local_worker().apply_gradients(self._microbatches_grads)
apply_timer.push_units_processed(self._microbatches_counts)

# Reset microbatch information.
self._microbatches_grads = None
self._microbatches_counts = self._num_microbatches = 0

# Also update global vars of the local worker.
# Create current global vars.
global_vars = {
"timestep": self._counters[NUM_AGENT_STEPS_SAMPLED],
}
with self._timers[WORKER_UPDATE_TIMER]:
self.workers.sync_weights(
policies=self.workers.local_worker().get_policies_to_train(),
global_vars=global_vars,
)

train_results = {DEFAULT_POLICY_ID: info}

return train_results

@staticmethod
@override(Trainer)
def execution_plan(
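
The `training_iteration()` implementation above is a standard gradient-accumulation scheme: compute gradients per microbatch, sum them, and apply them in one optimizer step once roughly `train_batch_size` samples have been processed. A toy, RLlib-free sketch of the same idea on a linear-regression objective (everything below is illustrative; the toy averages the accumulated gradient before the update, whereas the code above applies the summed gradients directly):

    import math
    import numpy as np

    rng = np.random.default_rng(0)
    true_w = np.array([2.0, -1.0])

    def sample(n):
        # Toy data: x ~ N(0, I), y = x @ true_w + noise.
        x = rng.normal(size=(n, 2))
        return x, x @ true_w + 0.1 * rng.normal(size=n)

    def grad(w, batch):
        # Gradient of the mean squared error with respect to w.
        x, y = batch
        return 2.0 * x.T @ (x @ w - y) / len(y)

    def train_step(w, train_batch_size=120, microbatch_size=40, lr=0.05):
        # Accumulate gradients over ceil(120 / 40) = 3 microbatches, then apply once.
        num_microbatches = math.ceil(train_batch_size / microbatch_size)
        acc = np.zeros_like(w)
        for _ in range(num_microbatches):
            acc += grad(w, sample(microbatch_size))
        return w - lr * acc / num_microbatches

    w = np.zeros(2)
    for _ in range(200):
        w = train_step(w)
    print(w)  # approaches [2.0, -1.0]
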
12 changes: 7 additions & 5 deletions rllib/evaluation/rollout_worker.py
@@ -941,11 +941,13 @@ def learn_on_batch(self, samples: SampleBatchType) -> Dict:
info_out.update({pid: builders[pid].get(v) for pid, v in to_fetch.items()})
else:
if self.is_policy_to_train(DEFAULT_POLICY_ID, samples):
info_out = {
DEFAULT_POLICY_ID: self.policy_map[
DEFAULT_POLICY_ID
].learn_on_batch(samples)
}
info_out.update(
{
DEFAULT_POLICY_ID: self.policy_map[
DEFAULT_POLICY_ID
].learn_on_batch(samples)
}
)
if log_once("learn_out"):
logger.debug("Training out:\n\n{}\n".format(summarize(info_out)))
return info_out
8 changes: 4 additions & 4 deletions rllib/evaluation/tests/test_rollout_worker.py
@@ -24,14 +24,14 @@
)
from ray.rllib.examples.env.multi_agent import BasicMultiAgent, MultiAgentCartPole
from ray.rllib.examples.policy.random_policy import RandomPolicy
from ray.rllib.execution.common import STEPS_SAMPLED_COUNTER, STEPS_TRAINED_COUNTER
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.sample_batch import (
DEFAULT_POLICY_ID,
MultiAgentBatch,
SampleBatch,
)
from ray.rllib.utils.annotations import override
from ray.rllib.utils.metrics import NUM_AGENT_STEPS_SAMPLED, NUM_AGENT_STEPS_TRAINED
from ray.rllib.utils.test_utils import check, framework_iterator
from ray.tune.registry import register_env

@@ -159,12 +159,12 @@ def test_global_vars_update(self):
result = agent.train()
print(
"{}={}".format(
STEPS_TRAINED_COUNTER, result["info"][STEPS_TRAINED_COUNTER]
NUM_AGENT_STEPS_TRAINED, result["info"][NUM_AGENT_STEPS_TRAINED]
)
)
print(
"{}={}".format(
STEPS_SAMPLED_COUNTER, result["info"][STEPS_SAMPLED_COUNTER]
NUM_AGENT_STEPS_SAMPLED, result["info"][NUM_AGENT_STEPS_SAMPLED]
)
)
global_timesteps = (
@@ -588,7 +588,7 @@ def test_vector_env_support(self):
batch = ev.sample()
self.assertEqual(batch.count, 10)
result = collect_metrics(ev, [])
self.assertGreater(result["episodes_this_iter"], 7)
self.assertGreater(result["episodes_this_iter"], 6)
ev.stop()

def test_truncate_episodes(self):
4 changes: 0 additions & 4 deletions rllib/examples/random_parametric_agent.py
@@ -66,10 +66,6 @@ def set_weights(self, weights):
pass


# Backward compatibility, just in case users want to use the erroneous old name.
RandomParametriclPolicy = RandomParametricPolicy


class RandomParametricTrainer(Trainer):
"""Trainer with Policy and config defined above and overriding `training_iteration`.
2 changes: 2 additions & 0 deletions rllib/tests/test_exec_api.py
@@ -25,6 +25,7 @@ def test_exec_plan_stats(ray_start_regular):
config={
"min_time_s_per_reporting": 0,
"framework": fw,
"_disable_execution_plan_api": False,
},
)
result = trainer.train()
@@ -47,6 +48,7 @@ def test_exec_plan_save_restore(ray_start_regular):
config={
"min_time_s_per_reporting": 0,
"framework": fw,
"_disable_execution_plan_api": False,
},
)
res1 = trainer.train()
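
With the execution-plan path no longer the default, tests that specifically exercise it opt back in via `_disable_execution_plan_api: False`, as the two hunks above do. A hedged sketch of the same override in user code (the environment name is an illustrative assumption, not part of this commit):

    from ray.rllib.agents.a3c import A2CTrainer

    # Illustrative only: pin the trainer to the legacy execution_plan() path,
    # e.g. to compare it against the new training_iteration() path.
    trainer = A2CTrainer(
        env="CartPole-v0",
        config={"_disable_execution_plan_api": False},
    )
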
10 changes: 6 additions & 4 deletions rllib/tuned_examples/a3c/cartpole-a2c-microbatch.yaml
@@ -3,11 +3,13 @@ cartpole-a2c-microbatch:
run: A2C
stop:
episode_reward_mean: 150
timesteps_total: 100000
timesteps_total: 500000
config:
# Works for both torch and tf.
framework: tf
num_workers: 1
num_workers: 2
gamma: 0.95
microbatch_size: 50
train_batch_size: 100
rollout_fragment_length: 20
microbatch_size: 40
train_batch_size: 120
seed: 13
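
The retuned values presumably line up with the microbatch logic: `rollout_fragment_length: 20` divides `microbatch_size: 40` evenly, and `train_batch_size: 120` is exactly three microbatches, the minimum ratio the new `validate_config()` accepts without a warning. A tuned-example YAML like this can be handed to Tune directly; a rough sketch of what a driver such as `tests/run_regression_tests.py` effectively does (the path and API usage here are assumptions, not part of this commit):

    import yaml

    import ray
    from ray import tune

    # Load the experiment spec ({name: {run, stop, config, ...}}) and run it.
    with open("rllib/tuned_examples/a3c/cartpole-a2c-microbatch.yaml") as f:
        experiments = yaml.safe_load(f)

    ray.init()
    trials = tune.run_experiments(experiments)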
