Commit a1d2e17

[rllib] Autoregressive action distributions (#5304)
1 parent 8b6f0d3 commit a1d2e17

31 files changed, +553 -230 lines

ci/jenkins_tests/run_rllib_tests.sh

Lines changed: 3 additions & 0 deletions

@@ -446,6 +446,9 @@ docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
 docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
     /ray/ci/suppress_output python /ray/rllib/examples/twostep_game.py --stop=2000 --run=APEX_QMIX
 
+docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
+    /ray/ci/suppress_output python /ray/rllib/examples/autoregressive_action_dist.py --stop=150
+
 docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
     /ray/ci/suppress_output /ray/rllib/train.py \
     --env PongDeterministic-v4 \

doc/source/rllib-algorithms.rst

Lines changed: 2 additions & 2 deletions

@@ -336,8 +336,8 @@ Tuned examples: `Two-step game <https://github.com/ray-project/ray/blob/master/r
     :start-after: __sphinx_doc_begin__
     :end-before: __sphinx_doc_end__
 
-Multi-Agent Actor Critic (contrib/MADDPG)
------------------------------------------
+Multi-Agent Deep Deterministic Policy Gradient (contrib/MADDPG)
+---------------------------------------------------------------
 `[paper] <https://arxiv.org/abs/1706.02275>`__ `[implementation] <https://github.com/ray-project/ray/blob/master/rllib/contrib/maddpg/maddpg.py>`__ MADDPG is a specialized multi-agent algorithm. Code here is adapted from https://github.com/openai/maddpg to integrate with RLlib multi-agent APIs. Please check `wsjeon/maddpg-rllib <https://github.com/wsjeon/maddpg-rllib>`__ for examples and more information.
 
 **MADDPG-specific configs** (see also `common configs <rllib-training.html#common-parameters>`__):

doc/source/rllib-components.svg

Lines changed: 1 addition & 4 deletions
(image diff not shown)

doc/source/rllib-concepts.rst

Lines changed: 1 addition & 1 deletion

@@ -407,7 +407,7 @@ The action sampler is straightforward, it just takes the q_model, runs a forward
                 config):
         # do max over Q values...
         ...
-        return action, action_prob
+        return action, action_logp
 
 The remainder of DQN is similar to other algorithms. Target updates are handled by a ``after_optimizer_step`` callback that periodically copies the weights of the Q network to the target.
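The rename above reflects that the action sampler in this commit returns a log-probability (``action_logp``) rather than a probability. As a minimal illustration of that return convention for a greedy Q-value sampler (sketch only; these names are not the actual RLlib DQN policy code):

    import tensorflow as tf

    def greedy_q_action_sampler(q_values):
        # Pick the argmax action over the Q-values for each batch element.
        action = tf.argmax(q_values, axis=1)
        # A deterministic greedy policy puts probability 1 on the chosen
        # action, so its log-probability is 0 for every batch element.
        action_logp = tf.zeros_like(action, dtype=tf.float32)
        return action, action_logp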

doc/source/rllib-env.rst

Lines changed: 25 additions & 22 deletions

@@ -5,27 +5,33 @@ RLlib works with several different types of environments, including `OpenAI Gym
 
 .. image:: rllib-envs.svg
 
-**Compatibility matrix**:
-
-============= ======================= ================== =========== ==================
-Algorithm     Discrete Actions        Continuous Actions Multi-Agent Recurrent Policies
-============= ======================= ================== =========== ==================
-A2C, A3C      **Yes** `+parametric`_  **Yes**            **Yes**     **Yes**
-PPO, APPO     **Yes** `+parametric`_  **Yes**            **Yes**     **Yes**
-PG            **Yes** `+parametric`_  **Yes**            **Yes**     **Yes**
-IMPALA        **Yes** `+parametric`_  **Yes**            **Yes**     **Yes**
-DQN, Rainbow  **Yes** `+parametric`_  No                 **Yes**     No
-DDPG, TD3     No                      **Yes**            **Yes**     No
-APEX-DQN      **Yes** `+parametric`_  No                 **Yes**     No
-APEX-DDPG     No                      **Yes**            **Yes**     No
-SAC           (todo)                  **Yes**            **Yes**     No
-ES            **Yes**                 **Yes**            No          No
-ARS           **Yes**                 **Yes**            No          No
-QMIX          **Yes**                 No                 **Yes**     **Yes**
-MARWIL        **Yes** `+parametric`_  **Yes**            **Yes**     **Yes**
-============= ======================= ================== =========== ==================
+Feature Compatibility Matrix
+----------------------------
+
+============= ======================= ================== =========== ===========================
+Algorithm     Discrete Actions        Continuous         Multi-Agent Model Support
+============= ======================= ================== =========== ===========================
+A2C, A3C      **Yes** `+parametric`_  **Yes**            **Yes**     `+RNN`_, `+autoreg`_
+PPO, APPO     **Yes** `+parametric`_  **Yes**            **Yes**     `+RNN`_, `+autoreg`_
+PG            **Yes** `+parametric`_  **Yes**            **Yes**     `+RNN`_, `+autoreg`_
+IMPALA        **Yes** `+parametric`_  **Yes**            **Yes**     `+RNN`_, `+autoreg`_
+DQN, Rainbow  **Yes** `+parametric`_  No                 **Yes**
+DDPG, TD3     No                      **Yes**            **Yes**
+APEX-DQN      **Yes** `+parametric`_  No                 **Yes**
+APEX-DDPG     No                      **Yes**            **Yes**
+SAC           (todo)                  **Yes**            **Yes**
+ES            **Yes**                 **Yes**            No
+ARS           **Yes**                 **Yes**            No
+QMIX          **Yes**                 No                 **Yes**     `+RNN`_
+MARWIL        **Yes** `+parametric`_  **Yes**            **Yes**     `+RNN`_
+============= ======================= ================== =========== ===========================
 
 .. _`+parametric`: rllib-models.html#variable-length-parametric-action-spaces
+.. _`+RNN`: rllib-models.html#recurrent-models
+.. _`+autoreg`: rllib-models.html#autoregressive-action-distributions
+
+Configuring Environments
+------------------------
 
 You can pass either a string name or a Python class to specify an environment. By default, strings will be interpreted as a gym `environment name <https://gym.openai.com/envs>`__. Custom env classes passed directly to the trainer must take a single ``env_config`` parameter in their constructor:
 
@@ -69,9 +75,6 @@ For a full runnable code example using the custom environment API, see `custom_e
 
 The gym registry is not compatible with Ray. Instead, always use the registration flows documented above to ensure Ray workers can access the environment.
 
-Configuring Environments
-------------------------
-
 In the above example, note that the ``env_creator`` function takes in an ``env_config`` object. This is a dict containing options passed in through your trainer. You can also access ``env_config.worker_index`` and ``env_config.vector_index`` to get the worker id and env id within the worker (if ``num_envs_per_worker > 0``). This can be useful if you want to train over an ensemble of different environments, for example:
 
 .. code-block:: python
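The hunk above references the doc's custom env pattern: a class whose constructor takes a single ``env_config`` dict, passed through the trainer config. A minimal self-contained sketch of that pattern (the corridor env and its config key are illustrative, not part of this commit):

    import gym
    from gym.spaces import Box, Discrete

    import ray
    import ray.rllib.agents.ppo as ppo

    class SimpleCorridor(gym.Env):
        """Walk right along a corridor of configurable length to reach the goal."""

        def __init__(self, env_config):
            # env_config is the dict passed via the trainer's "env_config" key.
            self.end_pos = env_config.get("corridor_length", 5)
            self.cur_pos = 0
            self.action_space = Discrete(2)  # 0 = left, 1 = right
            self.observation_space = Box(0.0, float(self.end_pos), shape=(1,))

        def reset(self):
            self.cur_pos = 0
            return [self.cur_pos]

        def step(self, action):
            if action == 0 and self.cur_pos > 0:
                self.cur_pos -= 1
            elif action == 1:
                self.cur_pos += 1
            done = self.cur_pos >= self.end_pos
            return [self.cur_pos], 1.0 if done else -0.1, done, {}

    ray.init()
    trainer = ppo.PPOTrainer(env=SimpleCorridor, config={
        "env_config": {"corridor_length": 10},
    })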

doc/source/rllib-models.rst

Lines changed: 146 additions & 18 deletions

@@ -1,5 +1,5 @@
-RLlib Models and Preprocessors
-==============================
+RLlib Models, Preprocessors, and Action Distributions
+=====================================================
 
 The following diagram provides a conceptual overview of data flow between different components in RLlib. We start with an ``Environment``, which given an action produces an observation. The observation is preprocessed by a ``Preprocessor`` and ``Filter`` (e.g. for running mean normalization) before being sent to a neural network ``Model``. The model output is in turn interpreted by an ``ActionDistribution`` to determine the next action.
 
@@ -145,6 +145,7 @@ Custom preprocessors should subclass the RLlib `preprocessor class <https://gith
 
     import ray
     import ray.rllib.agents.ppo as ppo
+    from ray.rllib.models import ModelCatalog
     from ray.rllib.models.preprocessors import Preprocessor
 
     class MyPreprocessorClass(Preprocessor):
@@ -164,6 +165,40 @@ Custom preprocessors should subclass the RLlib `preprocessor class <https://gith
         },
     })
 
+Custom Action Distributions
+---------------------------
+
+Similar to custom models and preprocessors, you can also specify a custom action distribution class as follows. The action dist class is passed a reference to the ``model``, which you can use to access ``model.model_config`` or other attributes of the model. This is commonly used to implement `autoregressive action outputs <#autoregressive-action-distributions>`__.
+
+.. code-block:: python
+
+    import ray
+    import ray.rllib.agents.ppo as ppo
+    from ray.rllib.models import ModelCatalog
+    from ray.rllib.models.preprocessors import Preprocessor
+
+    class MyActionDist(ActionDistribution):
+        @staticmethod
+        def required_model_output_shape(action_space, model_config):
+            return 7  # controls model output feature vector size
+
+        def __init__(self, inputs, model):
+            super(MyActionDist, self).__init__(inputs, model)
+            assert model.num_outputs == 7
+
+        def sample(self): ...
+        def logp(self, actions): ...
+        def entropy(self): ...
+
+    ModelCatalog.register_custom_action_dist("my_dist", MyActionDist)
+
+    ray.init()
+    trainer = ppo.PPOTrainer(env="CartPole-v0", config={
+        "model": {
+            "custom_action_dist": "my_dist",
+        },
+    })
+
 Supervised Model Losses
 -----------------------
 
@@ -231,26 +266,119 @@ Custom models can be used to work with environments where (1) the set of valid a
         return action_logits + inf_mask, state
 
 
-Depending on your use case it may make sense to use just the masking, just action embeddings, or both. For a runnable example of this in code, check out `parametric_action_cartpole.py <https://github.com/ray-project/ray/blob/master/rllib/examples/parametric_action_cartpole.py>`__. Note that since masking introduces ``tf.float32.min`` values into the model output, this technique might not work with all algorithm options. For example, algorithms might crash if they incorrectly process the ``tf.float32.min`` values. The cartpole example has working configurations for DQN (must set ``hiddens=[]``), PPO (must disable running mean and set ``vf_share_layers=True``), and several other algorithms.
+Depending on your use case it may make sense to use just the masking, just action embeddings, or both. For a runnable example of this in code, check out `parametric_action_cartpole.py <https://github.com/ray-project/ray/blob/master/rllib/examples/parametric_action_cartpole.py>`__. Note that since masking introduces ``tf.float32.min`` values into the model output, this technique might not work with all algorithm options. For example, algorithms might crash if they incorrectly process the ``tf.float32.min`` values. The cartpole example has working configurations for DQN (must set ``hiddens=[]``), PPO (must disable running mean and set ``vf_share_layers=True``), and several other algorithms. Not all algorithms support parametric actions; see the `feature compatibility matrix <rllib-env.html#feature-compatibility-matrix>`__.
 
-Model-Based Rollouts
-~~~~~~~~~~~~~~~~~~~~
 
-With a custom policy, you can also perform model-based rollouts and optionally incorporate the results of those rollouts as training data. For example, suppose you wanted to extend PGPolicy for model-based rollouts. This involves overriding the ``compute_actions`` method of that policy:
+Autoregressive Action Distributions
+-----------------------------------
+
+In an action space with multiple components (e.g., ``Tuple(a1, a2)``), you might want ``a2`` to be conditioned on the sampled value of ``a1``, i.e., ``a2_sampled ~ P(a2 | a1_sampled, obs)``. Normally, ``a1`` and ``a2`` would be sampled independently, reducing the expressivity of the policy.
+
+To do this, you need both a custom model that implements the autoregressive pattern, and a custom action distribution class that leverages that model. The `autoregressive_action_dist.py <https://github.com/ray-project/ray/blob/master/rllib/examples/autoregressive_action_dist.py>`__ example shows how this can be implemented for a simple binary action space. For a more complex space, a more efficient architecture such as a `MADE <https://arxiv.org/abs/1502.03509>`__ is recommended. Note that sampling a `N-part` action requires `N` forward passes through the model, however computing the log probability of an action can be done in one pass:
 
 .. code-block:: python
 
-    class ModelBasedPolicy(PGPolicy):
-        def compute_actions(self,
-                            obs_batch,
-                            state_batches,
-                            prev_action_batch=None,
-                            prev_reward_batch=None,
-                            episodes=None):
-            # compute a batch of actions based on the current obs_batch
-            # and state of each episode (i.e., for multiagent). You can do
-            # whatever is needed here, e.g., MCTS rollouts.
-            return action_batch
+    class BinaryAutoregressiveOutput(ActionDistribution):
+        """Action distribution P(a1, a2) = P(a1) * P(a2 | a1)"""
+
+        @staticmethod
+        def required_model_output_shape(self, model_config):
+            return 16  # controls model output feature vector size
+
+        def sample(self):
+            # first, sample a1
+            a1_dist = self._a1_distribution()
+            a1 = a1_dist.sample()
+
+            # sample a2 conditioned on a1
+            a2_dist = self._a2_distribution(a1)
+            a2 = a2_dist.sample()
+
+            # return the action tuple
+            return TupleActions([a1, a2])
+
+        def logp(self, actions):
+            a1, a2 = actions[:, 0], actions[:, 1]
+            a1_vec = tf.expand_dims(tf.cast(a1, tf.float32), 1)
+            a1_logits, a2_logits = self.model.action_model([self.inputs, a1_vec])
+            return (Categorical(a1_logits, None).logp(a1) + Categorical(
+                a2_logits, None).logp(a2))
+
+        def _a1_distribution(self):
+            BATCH = tf.shape(self.inputs)[0]
+            a1_logits, _ = self.model.action_model(
+                [self.inputs, tf.zeros((BATCH, 1))])
+            a1_dist = Categorical(a1_logits, None)
+            return a1_dist
+
+        def _a2_distribution(self, a1):
+            a1_vec = tf.expand_dims(tf.cast(a1, tf.float32), 1)
+            _, a2_logits = self.model.action_model([self.inputs, a1_vec])
+            a2_dist = Categorical(a2_logits, None)
+            return a2_dist
+
+    class AutoregressiveActionsModel(TFModelV2):
+        """Implements the `.action_model` branch required above."""
+
+        def __init__(self, obs_space, action_space, num_outputs, model_config,
+                     name):
+            super(AutoregressiveActionsModel, self).__init__(
+                obs_space, action_space, num_outputs, model_config, name)
+            if action_space != Tuple([Discrete(2), Discrete(2)]):
+                raise ValueError(
+                    "This model only supports the [2, 2] action space")
+
+            # Inputs
+            obs_input = tf.keras.layers.Input(
+                shape=obs_space.shape, name="obs_input")
+            a1_input = tf.keras.layers.Input(shape=(1, ), name="a1_input")
+            ctx_input = tf.keras.layers.Input(
+                shape=(num_outputs, ), name="ctx_input")
+
+            # Output of the model (normally 'logits', but for an autoregressive
+            # dist this is more like a context/feature layer encoding the obs)
+            context = tf.keras.layers.Dense(
+                num_outputs,
+                name="hidden",
+                activation=tf.nn.tanh,
+                kernel_initializer=normc_initializer(1.0))(obs_input)
+
+            # P(a1 | obs)
+            a1_logits = tf.keras.layers.Dense(
+                2,
+                name="a1_logits",
+                activation=None,
+                kernel_initializer=normc_initializer(0.01))(ctx_input)
+
+            # P(a2 | a1)
+            # --note: typically you'd want to implement P(a2 | a1, obs) as follows:
+            # a2_context = tf.keras.layers.Concatenate(axis=1)(
+            #     [ctx_input, a1_input])
+            a2_context = a1_input
+            a2_hidden = tf.keras.layers.Dense(
+                16,
+                name="a2_hidden",
+                activation=tf.nn.tanh,
+                kernel_initializer=normc_initializer(1.0))(a2_context)
+            a2_logits = tf.keras.layers.Dense(
+                2,
+                name="a2_logits",
+                activation=None,
+                kernel_initializer=normc_initializer(0.01))(a2_hidden)
+
+            # Base layers
+            self.base_model = tf.keras.Model(obs_input, context)
+            self.register_variables(self.base_model.variables)
+            self.base_model.summary()
+
+            # Autoregressive action sampler
+            self.action_model = tf.keras.Model([ctx_input, a1_input],
+                                               [a1_logits, a2_logits])
+            self.action_model.summary()
+            self.register_variables(self.action_model.variables)
+
+
 
+.. note::
 
-If you want take this rollouts data and append it to the sample batch, use the ``add_extra_batch()`` method of the `episode objects <https://github.com/ray-project/ray/blob/master/rllib/evaluation/episode.py>`__ passed in. For an example of this, see the ``testReturningModelBasedRolloutsData`` `unit test <https://github.com/ray-project/ray/blob/master/rllib/tests/test_multi_agent_env.py>`__.
+   Not all algorithms support autoregressive action distributions; see the `feature compatibility matrix <rllib-env.html#feature-compatibility-matrix>`__.
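To see how the two classes added above fit together, here is a rough sketch of registering them and pointing a trainer at them (the registration strings and the env placeholder are illustrative; the linked autoregressive_action_dist.py example is the complete, runnable version):

    import ray
    import ray.rllib.agents.ppo as ppo
    from ray.rllib.models import ModelCatalog

    # Register the custom model and action distribution shown in the diff above.
    ModelCatalog.register_custom_model("autoreg_model", AutoregressiveActionsModel)
    ModelCatalog.register_custom_action_dist("binary_autoreg_dist",
                                             BinaryAutoregressiveOutput)

    ray.init()
    trainer = ppo.PPOTrainer(
        env=MyTupleActionEnv,  # placeholder: any env with Tuple([Discrete(2), Discrete(2)]) actions
        config={
            "model": {
                "custom_model": "autoreg_model",
                "custom_action_dist": "binary_autoreg_dist",
            },
        })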

doc/source/rllib.rst

Lines changed: 7 additions & 4 deletions

@@ -35,20 +35,23 @@ Training APIs
 Environments
 ------------
 * `RLlib Environments Overview <rllib-env.html>`__
+* `Feature Compatibility Matrix <rllib-env.html#feature-compatibility-matrix>`__
 * `OpenAI Gym <rllib-env.html#openai-gym>`__
 * `Vectorized <rllib-env.html#vectorized>`__
 * `Multi-Agent and Hierarchical <rllib-env.html#multi-agent-and-hierarchical>`__
 * `Interfacing with External Agents <rllib-env.html#interfacing-with-external-agents>`__
 * `Advanced Integrations <rllib-env.html#advanced-integrations>`__
 
-Models and Preprocessors
-------------------------
-* `RLlib Models and Preprocessors Overview <rllib-models.html>`__
+Models, Preprocessors, and Action Distributions
+-----------------------------------------------
+* `RLlib Models, Preprocessors, and Action Distributions Overview <rllib-models.html>`__
 * `TensorFlow Models <rllib-models.html#tensorflow-models>`__
 * `PyTorch Models <rllib-models.html#pytorch-models>`__
 * `Custom Preprocessors <rllib-models.html#custom-preprocessors>`__
+* `Custom Action Distributions <rllib-models.html#custom-action-distributions>`__
 * `Supervised Model Losses <rllib-models.html#supervised-model-losses>`__
 * `Variable-length / Parametric Action Spaces <rllib-models.html#variable-length-parametric-action-spaces>`__
+* `Autoregressive Action Distributions <rllib-models.html#autoregressive-action-distributions>`__
 
 Algorithms
 ----------
@@ -84,7 +87,7 @@ Algorithms
 * Multi-agent specific
 
   - `QMIX Monotonic Value Factorisation (QMIX, VDN, IQN) <rllib-algorithms.html#qmix-monotonic-value-factorisation-qmix-vdn-iqn>`__
-  - `Multi-Agent Actor Critic (contrib/MADDPG) <rllib-algorithms.html#multi-agent-actor-critic-contrib-maddpg>`__
+  - `Multi-Agent Deep Deterministic Policy Gradient (contrib/MADDPG) <rllib-algorithms.html#multi-agent-deep-deterministic-policy-gradient-contrib-maddpg>`__
 
 * Offline

rllib/agents/a3c/a3c_torch_policy.py

Lines changed: 1 addition & 1 deletion

@@ -18,7 +18,7 @@ def actor_critic_loss(policy, batch_tensors):
         SampleBatch.CUR_OBS: batch_tensors[SampleBatch.CUR_OBS]
     })  # TODO(ekl) seq lens shouldn't be None
     values = policy.model.value_function()
-    dist = policy.dist_class(logits, policy.config["model"])
+    dist = policy.dist_class(logits, policy.model)
     log_probs = dist.logp(batch_tensors[SampleBatch.ACTIONS])
     policy.entropy = dist.entropy().mean()
     policy.pi_err = -batch_tensors[Postprocessing.ADVANTAGES].dot(

rllib/agents/ars/policies.py

Lines changed: 1 addition & 1 deletion

@@ -81,7 +81,7 @@ def __init__(self,
         model = ModelCatalog.get_model({
             "obs": self.inputs
         }, obs_space, action_space, dist_dim, model_config)
-        dist = dist_class(model.outputs, model_config=model_config)
+        dist = dist_class(model.outputs, model)
         self.sampler = dist.sample()
 
         self.variables = ray.experimental.tf_utils.TensorFlowVariables(
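This change and the a3c_torch_policy.py change above both follow the new ``ActionDistribution(inputs, model)`` constructor used throughout this commit: call sites now pass the model object instead of the model config. A tiny sketch of the new construction pattern (the ``Categorical`` import path is an assumption about the RLlib layout around this release):

    import numpy as np
    import tensorflow as tf

    # Assumed import path; adjust to wherever Categorical lives in your version.
    from ray.rllib.models.tf.tf_action_dist import Categorical

    logits = tf.constant(np.array([[1.0, 2.0, 0.5]], dtype=np.float32))

    # The second constructor argument is the model object (None is fine when the
    # distribution does not need to call back into the model); previously the
    # call sites above passed the model *config* here instead.
    dist = Categorical(logits, None)
    action = dist.sample()
    logp = dist.logp(action)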
