-
Notifications
You must be signed in to change notification settings - Fork 6.8k
[rllib] Autoregressive action distributions #5304
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
41 commits
Select commit
Hold shift + click to select a range
556985a
wip
ericl 515f7ce
wip
ericl dee364f
fix
ericl e2d4fcc
doc
ericl 3a51e24
doc
ericl 81c731f
Update dqn_policy.py
ericl a9e5e14
none
ericl 05083c3
Merge branch 'autoregressive' of github.com:ericl/ray into autoregres…
ericl 292d1ba
lint
ericl 6e6059d
Update rllib-models.rst
ericl 368188e
docs update
ericl c1980d7
Merge branch 'autoregressive' of github.com:ericl/ray into autoregres…
ericl ca4cbbc
doc update
ericl b469b47
move matrix
ericl a4e3069
model
ericl f5e5d0b
env
ericl 2c34ebc
update
ericl 67a2ae0
fix shuffle
ericl e223e85
Merge remote-tracking branch 'upstream/master' into autoregressive
ericl 1c7c0b3
remove keras
ericl eca623f
Merge remote-tracking branch 'upstream/master' into autoregressive
ericl 19b91da
update docs
ericl 59a29f6
docs
ericl b1ed891
switch to logp for stability
ericl 6e584a3
remove override
ericl ba1b531
fix op leak
ericl b551fcb
fix
ericl 4c4786a
fix
ericl ba007f1
lint
ericl ef30c39
Merge remote-tracking branch 'upstream/master' into autoregressive
ericl afc5002
Merge remote-tracking branch 'upstream/master' into autoregressive
ericl 1134d17
doc
ericl 2d946f1
cateogrical
ericl da85071
fix
ericl 7a66f09
fix vtrace
ericl b10b749
Merge remote-tracking branch 'upstream/master' into autoregressive
ericl 406620d
fix appo
ericl fe8ecff
to note
ericl 7e3f040
comments
ericl 3ffb4b7
Merge remote-tracking branch 'upstream/master' into autoregressive
ericl 979fd5a
fix merge
ericl File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,5 @@ | ||
RLlib Models and Preprocessors | ||
============================== | ||
RLlib Models, Preprocessors, and Action Distributions | ||
===================================================== | ||
|
||
The following diagram provides a conceptual overview of data flow between different components in RLlib. We start with an ``Environment``, which given an action produces an observation. The observation is preprocessed by a ``Preprocessor`` and ``Filter`` (e.g. for running mean normalization) before being sent to a neural network ``Model``. The model output is in turn interpreted by an ``ActionDistribution`` to determine the next action. | ||
|
||
|
@@ -145,6 +145,7 @@ Custom preprocessors should subclass the RLlib `preprocessor class <https://gith | |
|
||
import ray | ||
import ray.rllib.agents.ppo as ppo | ||
from ray.rllib.models import ModelCatalog | ||
from ray.rllib.models.preprocessors import Preprocessor | ||
|
||
class MyPreprocessorClass(Preprocessor): | ||
|
@@ -164,6 +165,40 @@ Custom preprocessors should subclass the RLlib `preprocessor class <https://gith | |
}, | ||
}) | ||
|
||
Custom Action Distributions | ||
--------------------------- | ||
|
||
Similar to custom models and preprocessors, you can also specify a custom action distribution class as follows. The action dist class is passed a reference to the ``model``, which you can use to access ``model.model_config`` or other attributes of the model. This is commonly used to implement `autoregressive action outputs <#autoregressive-action-distributions>`__. | ||
|
||
.. code-block:: python | ||
|
||
import ray | ||
import ray.rllib.agents.ppo as ppo | ||
from ray.rllib.models import ModelCatalog | ||
from ray.rllib.models.preprocessors import Preprocessor | ||
|
||
class MyActionDist(ActionDistribution): | ||
@staticmethod | ||
def required_model_output_shape(action_space, model_config): | ||
return 7 # controls model output feature vector size | ||
|
||
def __init__(self, inputs, model): | ||
super(MyActionDist, self).__init__(inputs, model) | ||
assert model.num_outputs == 7 | ||
|
||
def sample(self): ... | ||
def logp(self, actions): ... | ||
def entropy(self): ... | ||
|
||
ModelCatalog.register_custom_action_dist("my_dist", MyActionDist) | ||
|
||
ray.init() | ||
trainer = ppo.PPOTrainer(env="CartPole-v0", config={ | ||
"model": { | ||
"custom_action_dist": "my_dist", | ||
}, | ||
}) | ||
|
||
Supervised Model Losses | ||
----------------------- | ||
|
||
|
@@ -231,26 +266,119 @@ Custom models can be used to work with environments where (1) the set of valid a | |
return action_logits + inf_mask, state | ||
|
||
|
||
Depending on your use case it may make sense to use just the masking, just action embeddings, or both. For a runnable example of this in code, check out `parametric_action_cartpole.py <https://github.com/ray-project/ray/blob/master/rllib/examples/parametric_action_cartpole.py>`__. Note that since masking introduces ``tf.float32.min`` values into the model output, this technique might not work with all algorithm options. For example, algorithms might crash if they incorrectly process the ``tf.float32.min`` values. The cartpole example has working configurations for DQN (must set ``hiddens=[]``), PPO (must disable running mean and set ``vf_share_layers=True``), and several other algorithms. | ||
Depending on your use case it may make sense to use just the masking, just action embeddings, or both. For a runnable example of this in code, check out `parametric_action_cartpole.py <https://github.com/ray-project/ray/blob/master/rllib/examples/parametric_action_cartpole.py>`__. Note that since masking introduces ``tf.float32.min`` values into the model output, this technique might not work with all algorithm options. For example, algorithms might crash if they incorrectly process the ``tf.float32.min`` values. The cartpole example has working configurations for DQN (must set ``hiddens=[]``), PPO (must disable running mean and set ``vf_share_layers=True``), and several other algorithms. Not all algorithms support parametric actions; see the `feature compatibility matrix <rllib-env.html#feature-compatibility-matrix>`__. | ||
|
||
Model-Based Rollouts | ||
~~~~~~~~~~~~~~~~~~~~ | ||
|
||
With a custom policy, you can also perform model-based rollouts and optionally incorporate the results of those rollouts as training data. For example, suppose you wanted to extend PGPolicy for model-based rollouts. This involves overriding the ``compute_actions`` method of that policy: | ||
Autoregressive Action Distributions | ||
----------------------------------- | ||
|
||
In an action space with multiple components (e.g., ``Tuple(a1, a2)``), you might want ``a2`` to be conditioned on the sampled value of ``a1``, i.e., ``a2_sampled ~ P(a2 | a1_sampled, obs)``. Normally, ``a1`` and ``a2`` would be sampled independently, reducing the expressivity of the policy. | ||
|
||
To do this, you need both a custom model that implements the autoregressive pattern, and a custom action distribution class that leverages that model. The `autoregressive_action_dist.py <https://github.com/ray-project/ray/blob/master/rllib/examples/autoregressive_action_dist.py>`__ example shows how this can be implemented for a simple binary action space. For a more complex space, a more efficient architecture such as a `MADE <https://arxiv.org/abs/1502.03509>`__ is recommended. Note that sampling a `N-part` action requires `N` forward passes through the model, however computing the log probability of an action can be done in one pass: | ||
richardliaw marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
.. code-block:: python | ||
|
||
class ModelBasedPolicy(PGPolicy): | ||
def compute_actions(self, | ||
obs_batch, | ||
state_batches, | ||
prev_action_batch=None, | ||
prev_reward_batch=None, | ||
episodes=None): | ||
# compute a batch of actions based on the current obs_batch | ||
# and state of each episode (i.e., for multiagent). You can do | ||
# whatever is needed here, e.g., MCTS rollouts. | ||
return action_batch | ||
class BinaryAutoregressiveOutput(ActionDistribution): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why not use |
||
"""Action distribution P(a1, a2) = P(a1) * P(a2 | a1)""" | ||
|
||
@staticmethod | ||
def required_model_output_shape(self, model_config): | ||
return 16 # controls model output feature vector size | ||
|
||
def sample(self): | ||
# first, sample a1 | ||
a1_dist = self._a1_distribution() | ||
a1 = a1_dist.sample() | ||
|
||
# sample a2 conditioned on a1 | ||
a2_dist = self._a2_distribution(a1) | ||
a2 = a2_dist.sample() | ||
|
||
# return the action tuple | ||
return TupleActions([a1, a2]) | ||
|
||
def logp(self, actions): | ||
a1, a2 = actions[:, 0], actions[:, 1] | ||
a1_vec = tf.expand_dims(tf.cast(a1, tf.float32), 1) | ||
a1_logits, a2_logits = self.model.action_model([self.inputs, a1_vec]) | ||
return (Categorical(a1_logits, None).logp(a1) + Categorical( | ||
a2_logits, None).logp(a2)) | ||
|
||
def _a1_distribution(self): | ||
BATCH = tf.shape(self.inputs)[0] | ||
a1_logits, _ = self.model.action_model( | ||
[self.inputs, tf.zeros((BATCH, 1))]) | ||
a1_dist = Categorical(a1_logits, None) | ||
return a1_dist | ||
|
||
def _a2_distribution(self, a1): | ||
a1_vec = tf.expand_dims(tf.cast(a1, tf.float32), 1) | ||
_, a2_logits = self.model.action_model([self.inputs, a1_vec]) | ||
a2_dist = Categorical(a2_logits, None) | ||
return a2_dist | ||
|
||
class AutoregressiveActionsModel(TFModelV2): | ||
"""Implements the `.action_model` branch required above.""" | ||
|
||
def __init__(self, obs_space, action_space, num_outputs, model_config, | ||
name): | ||
super(AutoregressiveActionsModel, self).__init__( | ||
obs_space, action_space, num_outputs, model_config, name) | ||
if action_space != Tuple([Discrete(2), Discrete(2)]): | ||
raise ValueError( | ||
"This model only supports the [2, 2] action space") | ||
|
||
# Inputs | ||
obs_input = tf.keras.layers.Input( | ||
shape=obs_space.shape, name="obs_input") | ||
a1_input = tf.keras.layers.Input(shape=(1, ), name="a1_input") | ||
ctx_input = tf.keras.layers.Input( | ||
shape=(num_outputs, ), name="ctx_input") | ||
|
||
# Output of the model (normally 'logits', but for an autoregressive | ||
# dist this is more like a context/feature layer encoding the obs) | ||
context = tf.keras.layers.Dense( | ||
num_outputs, | ||
name="hidden", | ||
activation=tf.nn.tanh, | ||
kernel_initializer=normc_initializer(1.0))(obs_input) | ||
|
||
# P(a1 | obs) | ||
a1_logits = tf.keras.layers.Dense( | ||
2, | ||
name="a1_logits", | ||
activation=None, | ||
kernel_initializer=normc_initializer(0.01))(ctx_input) | ||
|
||
# P(a2 | a1) | ||
# --note: typically you'd want to implement P(a2 | a1, obs) as follows: | ||
# a2_context = tf.keras.layers.Concatenate(axis=1)( | ||
# [ctx_input, a1_input]) | ||
a2_context = a1_input | ||
a2_hidden = tf.keras.layers.Dense( | ||
16, | ||
name="a2_hidden", | ||
activation=tf.nn.tanh, | ||
kernel_initializer=normc_initializer(1.0))(a2_context) | ||
a2_logits = tf.keras.layers.Dense( | ||
2, | ||
name="a2_logits", | ||
activation=None, | ||
kernel_initializer=normc_initializer(0.01))(a2_hidden) | ||
|
||
# Base layers | ||
self.base_model = tf.keras.Model(obs_input, context) | ||
self.register_variables(self.base_model.variables) | ||
self.base_model.summary() | ||
|
||
# Autoregressive action sampler | ||
self.action_model = tf.keras.Model([ctx_input, a1_input], | ||
[a1_logits, a2_logits]) | ||
self.action_model.summary() | ||
self.register_variables(self.action_model.variables) | ||
|
||
|
||
|
||
.. note:: | ||
|
||
If you want take this rollouts data and append it to the sample batch, use the ``add_extra_batch()`` method of the `episode objects <https://github.com/ray-project/ray/blob/master/rllib/evaluation/episode.py>`__ passed in. For an example of this, see the ``testReturningModelBasedRolloutsData`` `unit test <https://github.com/ray-project/ray/blob/master/rllib/tests/test_multi_agent_env.py>`__. | ||
Not all algorithms support autoregressive action distributions; see the `feature compatibility matrix <rllib-env.html#feature-compatibility-matrix>`__. |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.