Merged

Commits (51)
33f01bb
wip
ericl Aug 18, 2018
d6b5b5d
stats
ericl Aug 18, 2018
56081a2
fix assert
ericl Aug 18, 2018
33c4871
fix par opt
ericl Aug 18, 2018
10cb4ac
clip
ericl Aug 19, 2018
517b1ab
fix
ericl Aug 19, 2018
c84906e
gpu queu
ericl Aug 19, 2018
0a01023
multiple loader threadsst
ericl Aug 19, 2018
a33bff0
thread safe
ericl Aug 19, 2018
50843da
fix
ericl Aug 19, 2018
763f165
icr threads
ericl Aug 19, 2018
2d4e332
depth
ericl Aug 19, 2018
e5f9095
Merge remote-tracking branch 'upstream/master' into impala-multigpu
ericl Aug 29, 2018
91df50b
revert local
ericl Aug 29, 2018
60f2d4b
fix
ericl Aug 29, 2018
9676459
frac gpu
ericl Aug 29, 2018
84cc7db
doc
ericl Aug 29, 2018
2295a15
Update rllib-training.rst
ericl Aug 31, 2018
10ad1f7
yapf
ericl Aug 31, 2018
79d9677
Merge branch 'test-frac-gpu' of github.com:ericl/ray into test-frac-gpu
ericl Aug 31, 2018
ee3b599
remove xray
ericl Aug 31, 2018
6cc13b1
replay params
ericl Sep 3, 2018
6070872
fix assert
ericl Sep 3, 2018
e6fe182
Merge remote-tracking branch 'upstream/master' into impala-multigpu
ericl Sep 3, 2018
0c4db6d
rep
ericl Sep 3, 2018
54acce9
note
ericl Sep 3, 2018
c1468a1
update
ericl Sep 3, 2018
9d16999
fix ceil
ericl Sep 3, 2018
8f5e704
Merge branch 'test-frac-gpu' into impala-multigpu
ericl Sep 3, 2018
d56768e
Merge remote-tracking branch 'upstream/master' into impala-multigpu
ericl Oct 5, 2018
113a968
Merge remote-tracking branch 'upstream/master' into impala-multigpu
ericl Oct 13, 2018
5e194c4
fix up names
ericl Oct 13, 2018
df26220
no debug
ericl Oct 13, 2018
14f2339
clean up apex stats too
ericl Oct 13, 2018
49c38ea
test
ericl Oct 13, 2018
dacda0f
fix replay sampling
ericl Oct 13, 2018
473f3fe
doc
ericl Oct 13, 2018
0dd2bae
comments 0
ericl Oct 14, 2018
9d2b535
fix lstm case
ericl Oct 14, 2018
daebf65
set to 99999
ericl Oct 14, 2018
8d66ded
fix assert
ericl Oct 14, 2018
8979a54
doc
ericl Oct 14, 2018
7c81fe9
impala lstm test
ericl Oct 14, 2018
d94fe31
remove unneeded assert
ericl Oct 14, 2018
bae28f6
ppo mask
ericl Oct 14, 2018
76640e3
doc
ericl Oct 14, 2018
507ae93
lint
ericl Oct 15, 2018
73005ca
fix mask
ericl Oct 15, 2018
6b3f6f3
fix non lstm case
ericl Oct 15, 2018
c7d893d
fix impala
ericl Oct 15, 2018
b76c0c4
damnit flake8
ericl Oct 15, 2018
doc/source/rllib-algorithms.rst (1 addition, 1 deletion)
@@ -43,7 +43,7 @@ Importance Weighted Actor-Learner Architecture (IMPALA)

`[paper] <https://arxiv.org/abs/1802.01561>`__
`[implementation] <https://github.com/ray-project/ray/blob/master/python/ray/rllib/agents/impala/impala.py>`__
In IMPALA, a central learner runs SGD in a tight loop while asynchronously pulling sample batches from many actor processes. RLlib's IMPALA implementation uses DeepMind's reference `V-trace code <https://github.com/deepmind/scalable_agent/blob/master/vtrace.py>`__. Note that we do not provide a deep residual network out of the box, but one can be plugged in as a `custom model <rllib-models.html#custom-models>`__.
In IMPALA, a central learner runs SGD in a tight loop while asynchronously pulling sample batches from many actor processes. RLlib's IMPALA implementation uses DeepMind's reference `V-trace code <https://github.com/deepmind/scalable_agent/blob/master/vtrace.py>`__. Note that we do not provide a deep residual network out of the box, but one can be plugged in as a `custom model <rllib-models.html#custom-models>`__. Multiple learner GPUs and experience replay are also supported.
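Both features are driven by the new config keys added below in impala.py; a minimal Tune sketch, assuming those keys (the environment and the particular values here are illustrative, not taken from this PR):

import ray
from ray import tune

ray.init()
tune.run_experiments({
    "impala_multigpu": {
        "run": "IMPALA",
        "env": "PongNoFrameskip-v4",
        "config": {
            "num_gpus": 2,                   # use two learner GPUs
            "num_parallel_data_loaders": 2,  # overlap host-to-GPU data loading
            "replay_proportion": 0.5,        # replay ~1 old batch per 2 new batches
            "replay_buffer_num_slots": 100,  # buffer of 100 sample batches
            "num_workers": 32,
        },
    },
})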

Tuned examples: `PongNoFrameskip-v4 <https://github.com/ray-project/ray/blob/master/python/ray/rllib/tuned_examples/pong-impala.yaml>`__, `vectorized configuration <https://github.com/ray-project/ray/blob/master/python/ray/rllib/tuned_examples/pong-impala-vectorized.yaml>`__, `{BeamRider,Breakout,Qbert,SpaceInvaders}NoFrameskip-v4 <https://github.com/ray-project/ray/blob/master/python/ray/rllib/tuned_examples/atari-impala.yaml>`__

python/ray/rllib/agents/impala/impala.py (22 additions, 2 deletions)
@@ -11,8 +11,16 @@
from ray.tune.trial import Resources

OPTIMIZER_SHARED_CONFIGS = [
"lr",
"num_envs_per_worker",
"num_gpus",
"sample_batch_size",
"train_batch_size",
"replay_buffer_num_slots",
"replay_proportion",
"num_parallel_data_loaders",
"grad_clip",
"max_sample_requests_in_flight_per_worker",
]

DEFAULT_CONFIG = with_common_config({
@@ -25,10 +33,22 @@
"sample_batch_size": 50,
"train_batch_size": 500,
"min_iter_time_s": 10,
"gpu": True,
"num_workers": 2,
"num_cpus_per_worker": 1,
"num_gpus_per_worker": 0,
# number of GPUs the learner should use.
"num_gpus": 1,
# set >1 to load data into GPUs in parallel. Increases GPU memory usage
# proportionally with the number of loaders.
"num_parallel_data_loaders": 1,
# level of queuing for sampling.
"max_sample_requests_in_flight_per_worker": 2,
# set >0 to enable experience replay. Saved samples will be replayed with
# a p:1 proportion to new data samples.
"replay_proportion": 0.0,
# number of sample batches to store for replay. The number of transitions
# saved total will be (replay_buffer_num_slots * sample_batch_size).
"replay_buffer_num_slots": 100,

# Learning params.
"grad_clip": 40.0,
@@ -65,7 +85,7 @@ def default_resource_request(cls, config):
cf = dict(cls._default_config, **config)
return Resources(
cpu=1,
gpu=cf["gpu"] and cf["gpu_fraction"] or 0,
gpu=cf["num_gpus"] and cf["num_gpus"] * cf["gpu_fraction"] or 0,
Review comment (Contributor):
Will this work when num_gpus * gpu_fraction > 1?

Reply (Contributor Author):
I don't know what will happen if gpu_fraction > 1, but for < 1 I believe it should work.

extra_cpu=cf["num_cpus_per_worker"] * cf["num_workers"],
extra_gpu=cf["num_gpus_per_worker"] * cf["num_workers"])
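To make the reviewer's question concrete, this is how the and/or idiom in the gpu= expression evaluates for a couple of toy settings (illustrative values only, not a claim about what the scheduler will accept):

num_gpus, gpu_fraction = 2, 0.5
num_gpus and num_gpus * gpu_fraction or 0   # -> 1.0 (one full GPU requested)

num_gpus, gpu_fraction = 0, 0.5
num_gpus and num_gpus * gpu_fraction or 0   # -> 0 (short-circuits when num_gpus is 0)

Whether a request larger than 1.0 (e.g. num_gpus=4, gpu_fraction=0.5 -> 2.0) behaves as intended is exactly the open question in the thread above.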

python/ray/rllib/agents/impala/vtrace_policy_graph.py (58 additions, 23 deletions)
@@ -31,6 +31,7 @@ def __init__(self,
rewards,
values,
bootstrap_value,
valid_mask,
vf_loss_coeff=0.5,
entropy_coeff=-0.01,
clip_rho_threshold=1.0,
@@ -52,6 +53,7 @@
rewards: A float32 tensor of shape [T, B].
values: A float32 tensor of shape [T, B].
bootstrap_value: A float32 tensor of shape [B].
valid_mask: A bool tensor of valid RNN input elements (#2992).
"""

# Compute vtrace on the CPU for better perf.
@@ -70,55 +72,73 @@

# The policy gradients loss
self.pi_loss = -tf.reduce_sum(
actions_logp * self.vtrace_returns.pg_advantages)
tf.boolean_mask(actions_logp * self.vtrace_returns.pg_advantages,
valid_mask))

# The baseline loss
delta = values - self.vtrace_returns.vs
delta = tf.boolean_mask(values - self.vtrace_returns.vs, valid_mask)
self.vf_loss = 0.5 * tf.reduce_sum(tf.square(delta))

# The entropy loss
self.entropy = tf.reduce_sum(actions_entropy)
self.entropy = tf.reduce_sum(
tf.boolean_mask(actions_entropy, valid_mask))

# The summed weighted loss
self.total_loss = (self.pi_loss + self.vf_loss * vf_loss_coeff +
self.entropy * entropy_coeff)
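The valid_mask threaded through the three terms above exists to keep zero-padded RNN timesteps out of the sums; a toy NumPy analogue of tf.boolean_mask + tf.reduce_sum (illustrative only, not RLlib code):

import numpy as np

# Per-timestep entropy for one padded rollout: the last two steps are padding,
# but padded logits still produce nonzero entropy, so an unmasked sum over-counts.
per_step_entropy = np.array([0.3, 0.7, 0.69, 0.69])
valid_mask = np.array([True, True, False, False])

per_step_entropy.sum()               # 2.38 -- padding leaks into the loss
per_step_entropy[valid_mask].sum()   # 1.0  -- what the masked reduction computes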


class VTracePolicyGraph(LearningRateSchedule, TFPolicyGraph):
def __init__(self, observation_space, action_space, config):
def __init__(self,
observation_space,
action_space,
config,
existing_inputs=None):
config = dict(ray.rllib.agents.impala.impala.DEFAULT_CONFIG, **config)
assert config["batch_mode"] == "truncate_episodes", \
"Must use `truncate_episodes` batch mode with V-trace."
self.config = config
self.sess = tf.get_default_session()

# Create input placeholders
if existing_inputs:
actions, dones, behaviour_logits, rewards, observations = \
existing_inputs[:5]
existing_state_in = existing_inputs[5:-1]
existing_seq_lens = existing_inputs[-1]
else:
if isinstance(action_space, gym.spaces.Discrete):
ac_size = action_space.n
actions = tf.placeholder(tf.int64, [None], name="ac")
else:
raise UnsupportedSpaceException(
"Action space {} is not supported for IMPALA.".format(
action_space))
dones = tf.placeholder(tf.bool, [None], name="dones")
rewards = tf.placeholder(tf.float32, [None], name="rewards")
behaviour_logits = tf.placeholder(
tf.float32, [None, ac_size], name="behaviour_logits")
observations = tf.placeholder(
tf.float32, [None] + list(observation_space.shape))
existing_state_in = None
existing_seq_lens = None

# Setup the policy
self.observations = tf.placeholder(
tf.float32, [None] + list(observation_space.shape))
dist_class, logit_dim = ModelCatalog.get_action_dist(
action_space, self.config["model"])
self.model = ModelCatalog.get_model(self.observations, logit_dim,
self.config["model"])
self.model = ModelCatalog.get_model(
observations,
logit_dim,
self.config["model"],
state_in=existing_state_in,
seq_lens=existing_seq_lens)
action_dist = dist_class(self.model.outputs)
values = tf.reshape(
linear(self.model.last_layer, 1, "value", normc_initializer(1.0)),
[-1])
self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
tf.get_variable_scope().name)

# Setup the policy loss
if isinstance(action_space, gym.spaces.Discrete):
ac_size = action_space.n
actions = tf.placeholder(tf.int64, [None], name="ac")
else:
raise UnsupportedSpaceException(
"Action space {} is not supported for IMPALA.".format(
action_space))
dones = tf.placeholder(tf.bool, [None], name="dones")
rewards = tf.placeholder(tf.float32, [None], name="rewards")
behaviour_logits = tf.placeholder(
tf.float32, [None, ac_size], name="behaviour_logits")

def to_batches(tensor):
if self.config["model"]["use_lstm"]:
B = tf.shape(self.model.seq_lens)[0]
@@ -135,6 +155,13 @@ def to_batches(tensor):
rs,
[1, 0] + list(range(2, 1 + int(tf.shape(tensor).shape[0]))))

if self.model.state_in:
max_seq_len = tf.reduce_max(self.model.seq_lens) - 1
mask = tf.sequence_mask(self.model.seq_lens, max_seq_len)
mask = tf.reshape(mask, [-1])
else:
mask = tf.ones_like(rewards)
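For intuition, the mask built here marks which of the flattened [B * T] entries are real timesteps; a NumPy analogue with toy values (illustrative only):

import numpy as np

seq_lens = np.array([3, 1])          # batch of two sequences, padded to length 3
max_seq_len = seq_lens.max() - 1     # 2; one timestep shorter, matching the [:-1] slices below
# Equivalent of tf.sequence_mask(seq_lens, max_seq_len):
mask = np.arange(max_seq_len) < seq_lens[:, None]
# [[ True,  True],
#  [ True, False]]   <- only the first step of the length-1 sequence is valid
mask.reshape(-1)                     # flattened to line up with the [B * T] tensors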

# Inputs are reshaped from [B * T] => [T - 1, B] for V-trace calc.
self.loss = VTraceLoss(
actions=to_batches(actions)[:-1],
@@ -147,6 +174,7 @@ def to_batches(tensor):
rewards=to_batches(rewards)[:-1],
values=to_batches(values)[:-1],
bootstrap_value=to_batches(values)[-1],
valid_mask=to_batches(mask)[:-1],
vf_loss_coeff=self.config["vf_loss_coeff"],
entropy_coeff=self.config["entropy_coeff"],
clip_rho_threshold=self.config["vtrace_clip_rho_threshold"],
@@ -158,7 +186,7 @@ def to_batches(tensor):
("dones", dones),
("behaviour_logits", behaviour_logits),
("rewards", rewards),
("obs", self.observations),
("obs", observations),
]
LearningRateSchedule.__init__(self, self.config["lr"],
self.config["lr_schedule"])
Expand All @@ -167,7 +195,7 @@ def to_batches(tensor):
observation_space,
action_space,
self.sess,
obs_input=self.observations,
obs_input=observations,
action_sampler=action_dist.sample(),
loss=self.loss.total_loss,
loss_inputs=loss_in,
@@ -218,3 +246,10 @@ def postprocess_trajectory(self, sample_batch, other_agent_batches=None):

def get_initial_state(self):
return self.model.state_init

def copy(self, existing_inputs):
return VTracePolicyGraph(
self.observation_space,
self.action_space,
self.config,
existing_inputs=existing_inputs)
python/ray/rllib/agents/ppo/ppo_policy_graph.py (25 additions, 10 deletions)
@@ -24,6 +24,7 @@ def __init__(self,
curr_action_dist,
value_fn,
cur_kl_coeff,
valid_mask,
entropy_coeff=0,
clip_param=0.1,
vf_clip_param=0.1,
@@ -48,43 +49,49 @@
value_fn (Tensor): Current value function output Tensor.
cur_kl_coeff (Variable): Variable holding the current PPO KL
coefficient.
valid_mask (Tensor): A bool mask of valid input elements (#2992).
entropy_coeff (float): Coefficient of the entropy regularizer.
clip_param (float): Clip parameter
vf_clip_param (float): Clip parameter for the value function
vf_loss_coeff (float): Coefficient of the value function loss
use_gae (bool): If true, use the Generalized Advantage Estimator.
"""

def reduce_mean_valid(t):
return tf.reduce_mean(tf.boolean_mask(t, valid_mask))

dist_cls, _ = ModelCatalog.get_action_dist(action_space, {})
prev_dist = dist_cls(logits)
# Make loss functions.
logp_ratio = tf.exp(
curr_action_dist.logp(actions) - prev_dist.logp(actions))
action_kl = prev_dist.kl(curr_action_dist)
self.mean_kl = tf.reduce_mean(action_kl)
self.mean_kl = reduce_mean_valid(action_kl)

curr_entropy = curr_action_dist.entropy()
self.mean_entropy = tf.reduce_mean(curr_entropy)
self.mean_entropy = reduce_mean_valid(curr_entropy)

surrogate_loss = tf.minimum(
advantages * logp_ratio,
advantages * tf.clip_by_value(logp_ratio, 1 - clip_param,
1 + clip_param))
self.mean_policy_loss = tf.reduce_mean(-surrogate_loss)
self.mean_policy_loss = reduce_mean_valid(-surrogate_loss)

if use_gae:
vf_loss1 = tf.square(value_fn - value_targets)
vf_clipped = vf_preds + tf.clip_by_value(
value_fn - vf_preds, -vf_clip_param, vf_clip_param)
vf_loss2 = tf.square(vf_clipped - value_targets)
vf_loss = tf.maximum(vf_loss1, vf_loss2)
self.mean_vf_loss = tf.reduce_mean(vf_loss)
loss = tf.reduce_mean(-surrogate_loss + cur_kl_coeff * action_kl +
vf_loss_coeff * vf_loss -
entropy_coeff * curr_entropy)
self.mean_vf_loss = reduce_mean_valid(vf_loss)
loss = reduce_mean_valid(
-surrogate_loss + cur_kl_coeff * action_kl +
vf_loss_coeff * vf_loss - entropy_coeff * curr_entropy)
else:
self.mean_vf_loss = tf.constant(0.0)
loss = tf.reduce_mean(-surrogate_loss + cur_kl_coeff * action_kl -
entropy_coeff * curr_entropy)
loss = reduce_mean_valid(-surrogate_loss +
cur_kl_coeff * action_kl -
entropy_coeff * curr_entropy)
self.loss = loss
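The reduce_mean_valid helper replaces plain tf.reduce_mean so that LSTM padding cannot shrink the reported statistics and losses; a toy NumPy analogue (illustrative values only):

import numpy as np

action_kl = np.array([0.2, 0.4, 0.0, 0.0])        # last two entries are padding
valid_mask = np.array([True, True, False, False])

action_kl.mean()              # 0.15 -- padding drags the mean toward zero
action_kl[valid_mask].mean()  # 0.30 -- what reduce_mean_valid computes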


@@ -179,6 +186,13 @@ def __init__(self,
else:
self.value_function = tf.zeros(shape=tf.shape(obs_ph)[:1])

if self.model.state_in:
max_seq_len = tf.reduce_max(self.model.seq_lens)
mask = tf.sequence_mask(self.model.seq_lens, max_seq_len)
mask = tf.reshape(mask, [-1])
else:
mask = tf.ones_like(adv_ph)

self.loss_obj = PPOLoss(
action_space,
value_targets_ph,
@@ -189,6 +203,7 @@ def __init__(self,
curr_action_dist,
self.value_function,
self.kl_coeff,
mask,
entropy_coeff=self.config["entropy_coeff"],
clip_param=self.config["clip_param"],
vf_clip_param=self.config["vf_clip_param"],
@@ -227,7 +242,7 @@ def __init__(self,
def copy(self, existing_inputs):
"""Creates a copy of self using existing input placeholders."""
return PPOPolicyGraph(
None,
self.observation_space,
self.action_space,
self.config,
existing_inputs=existing_inputs)
python/ray/rllib/examples/cartpole_lstm.py (16 additions, 4 deletions)
@@ -14,6 +14,7 @@

parser = argparse.ArgumentParser()
parser.add_argument("--stop", type=int, default=200)
parser.add_argument("--run", type=str, default="PPO")


class CartPoleStatelessEnv(gym.Env):
@@ -163,18 +164,29 @@ def close(self):
tune.register_env("cartpole_stateless", lambda _: CartPoleStatelessEnv())

ray.init()

configs = {
"PPO": {
"num_sgd_iter": 5,
},
"IMPALA": {
"num_workers": 2,
"num_gpus": 0,
"vf_loss_coeff": 0.01,
},
}

tune.run_experiments({
"test": {
"env": "cartpole_stateless",
"run": "PPO",
"run": args.run,
"stop": {
"episode_reward_mean": args.stop
},
"config": {
"num_sgd_iter": 5,
"config": dict(configs[args.run], **{
"model": {
"use_lstm": True,
},
},
}),
}
})
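With the new --run flag, the same stateless-CartPole script exercises either algorithm; assuming it is invoked directly, usage would look like:

python cartpole_lstm.py --run PPO --stop 200
python cartpole_lstm.py --run IMPALA --stop 200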