
[RLlib] PyTorch version of ES (Evolution Strategies). (#8104)
PyTorch version of the Evolution Strategies (ES) algorithm.
sven1977 authored Apr 20, 2020
1 parent 9f3e9e7 commit 3812bfe
Showing 17 changed files with 276 additions and 121 deletions.
4 changes: 2 additions & 2 deletions doc/source/rllib-algorithms.rst
@@ -13,7 +13,7 @@ Algorithm Frameworks Discrete Actions Continuous Actions Multi-
=================== ========== ======================= ================== =========== =====================
`A2C, A3C`_ tf + torch **Yes** `+parametric`_ **Yes** **Yes** `+RNN`_, `+autoreg`_
`ARS`_ tf **Yes** **Yes** No
`ES`_ tf **Yes** **Yes** No
`ES`_ tf + torch **Yes** **Yes** No
`DDPG`_, `TD3`_ tf + torch No **Yes** **Yes**
`APEX-DDPG`_ tf No **Yes** **Yes**
`DQN`_, `Rainbow`_ tf + torch **Yes** `+parametric`_ No **Yes**
@@ -422,7 +422,7 @@ Tuned examples: `CartPole-v0 <https://github.com/ray-project/ray/blob/master/rll

Evolution Strategies
--------------------
|tensorflow|
|pytorch| |tensorflow|
`[paper] <https://arxiv.org/abs/1703.03864>`__ `[implementation] <https://github.com/ray-project/ray/blob/master/rllib/agents/es/es.py>`__
Code here is adapted from https://github.com/openai/evolution-strategies-starter to execute in the distributed setting with Ray.
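With this commit, the PyTorch backend is selected through the existing `use_pytorch` config key (see `get_policy_class` in `rllib/agents/es/es.py` below). As a rough usage sketch, assuming the RLlib Python API of this release (`ESTrainer`, `DEFAULT_CONFIG`) and the `CartPole-v0` environment from the tuned examples:

# Sketch only: running ES with either backend; use_pytorch=False keeps the TF policy.
import ray
from ray.rllib.agents.es import ESTrainer, DEFAULT_CONFIG

ray.init()
config = DEFAULT_CONFIG.copy()
config["num_workers"] = 2
config["use_pytorch"] = True  # new in this commit: dispatches to ESTorchPolicy

trainer = ESTrainer(config=config, env="CartPole-v0")
for i in range(3):
    result = trainer.train()
    print(i, result["episode_reward_mean"])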

8 changes: 4 additions & 4 deletions doc/source/rllib-toc.rst
@@ -101,21 +101,21 @@ Algorithms

- |pytorch| |tensorflow| :ref:`Advantage Actor-Critic (A2C, A3C) <a3c>`

- |tensorflow| :ref:`Deep Deterministic Policy Gradients (DDPG, TD3) <ddpg>`
- |pytorch| |tensorflow| :ref:`Deep Deterministic Policy Gradients (DDPG, TD3) <ddpg>`

- |pytorch| |tensorflow| :ref:`Deep Q Networks (DQN, Rainbow, Parametric DQN) <dqn>`

- |pytorch| |tensorflow| :ref:`Policy Gradients <pg>`

- |pytorch| |tensorflow| :ref:`Proximal Policy Optimization (PPO) <ppo>`

- |tensorflow| :ref:`Soft Actor Critic (SAC) <sac>`
- |pytorch| |tensorflow| :ref:`Soft Actor Critic (SAC) <sac>`

* Derivative-free

- |tensorflow| :ref:`Augmented Random Search (ARS) <ars>`

- |tensorflow| :ref:`Evolution Strategies <es>`
- |pytorch| |tensorflow| :ref:`Evolution Strategies <es>`

* Multi-agent specific

@@ -124,7 +124,7 @@ Algorithms

* Offline

- |tensorflow| :ref:`Advantage Re-Weighted Imitation Learning (MARWIL) <marwil>`
- |pytorch| |tensorflow| :ref:`Advantage Re-Weighted Imitation Learning (MARWIL) <marwil>`

* Contextual bandits

2 changes: 1 addition & 1 deletion rllib/BUILD
@@ -103,7 +103,7 @@ py_test(
py_test(
name = "test_apex",
tags = ["agents_dir"],
size = "medium",
size = "large",
srcs = ["agents/dqn/tests/test_apex.py"]
)

9 changes: 4 additions & 5 deletions rllib/agents/es/__init__.py
@@ -1,6 +1,5 @@
from ray.rllib.agents.es.es import (ESTrainer, DEFAULT_CONFIG)
from ray.rllib.utils import renamed_agent
from ray.rllib.agents.es.es import ESTrainer, DEFAULT_CONFIG
from ray.rllib.agents.es.es_tf_policy import ESTFPolicy
from ray.rllib.agents.es.es_torch_policy import ESTorchPolicy

ESAgent = renamed_agent(ESTrainer)

__all__ = ["ESAgent", "ESTrainer", "DEFAULT_CONFIG"]
__all__ = ["ESTFPolicy", "ESTorchPolicy", "ESTrainer", "DEFAULT_CONFIG"]
73 changes: 38 additions & 35 deletions rllib/agents/es/es.py
@@ -8,7 +8,8 @@

import ray
from ray.rllib.agents import Trainer, with_common_config
from ray.rllib.agents.es import optimizers, policies, utils
from ray.rllib.agents.es import optimizers, utils
from ray.rllib.agents.es.es_tf_policy import ESTFPolicy, rollout
from ray.rllib.env.env_context import EnvContext
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.utils import FilterManager
@@ -72,7 +73,8 @@ def __init__(self,
min_task_runtime=0.2):
self.min_task_runtime = min_task_runtime
self.config = config
self.policy_params = policy_params
self.config.update(policy_params)
self.config["single_threaded"] = True
self.noise = SharedNoiseTable(noise)

env_context = EnvContext(config["env_config"] or {}, worker_index)
@@ -81,15 +83,13 @@ def __init__(self,
self.preprocessor = models.ModelCatalog.get_preprocessor(
self.env, config["model"])

self.sess = utils.make_session(single_threaded=True)
self.policy = policies.GenericPolicy(
self.sess, self.env.action_space, self.env.observation_space,
self.preprocessor, config["observation_filter"], config["model"],
**policy_params)
policy_cls = get_policy_class(config)
self.policy = policy_cls(self.env.observation_space,
self.env.action_space, config)

@property
def filters(self):
return {DEFAULT_POLICY_ID: self.policy.get_filter()}
return {DEFAULT_POLICY_ID: self.policy.observation_filter}

def sync_filters(self, new_filters):
for k in self.filters:
@@ -104,7 +104,7 @@ def get_filters(self, flush_after=False):
return return_filters

def rollout(self, timestep_limit, add_noise=True):
rollout_rewards, rollout_fragment_length = policies.rollout(
rollout_rewards, rollout_fragment_length = rollout(
self.policy,
self.env,
timestep_limit=timestep_limit,
@@ -113,7 +113,7 @@ def rollout(self, timestep_limit, add_noise=True):

def do_rollouts(self, params, timestep_limit=None):
# Set the network weights.
self.policy.set_weights(params)
self.policy.set_flat_weights(params)

noise_indices, returns, sign_returns, lengths = [], [], [], []
eval_returns, eval_lengths = [], []
@@ -125,7 +125,7 @@ def do_rollouts(self, params, timestep_limit=None):

if np.random.uniform() < self.config["eval_prob"]:
# Do an evaluation run with no perturbation.
self.policy.set_weights(params)
self.policy.set_flat_weights(params)
rewards, length = self.rollout(timestep_limit, add_noise=False)
eval_returns.append(rewards.sum())
eval_lengths.append(length)
@@ -138,10 +138,10 @@ def do_rollouts(self, params, timestep_limit=None):

# These two sampling steps could be done in parallel on
# different actors letting us update twice as frequently.
self.policy.set_weights(params + perturbation)
self.policy.set_flat_weights(params + perturbation)
rewards_pos, lengths_pos = self.rollout(timestep_limit)

self.policy.set_weights(params - perturbation)
self.policy.set_flat_weights(params - perturbation)
rewards_neg, lengths_neg = self.rollout(timestep_limit)

noise_indices.append(noise_index)
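The paired rollouts at `params + perturbation` and `params - perturbation` above are the antithetic (mirrored) samples from the ES paper linked in the docs. As a simplified sketch of the gradient estimate they feed into (the actual `_train` below additionally rank-processes the returns and batches the weighted sum, which is not visible in this diff; `noise.get(index, dim)` is the shared-noise-table lookup used elsewhere in this file):

# Simplified sketch of the antithetic ES gradient estimate (Salimans et al., 2017),
# not the exact computation in es.py.
import numpy as np

def es_gradient_sketch(noise, noise_indices, returns_pos, returns_neg,
                       noise_stdev, num_params):
    g = np.zeros(num_params, dtype=np.float32)
    for index, r_pos, r_neg in zip(noise_indices, returns_pos, returns_neg):
        eps = noise.get(index, num_params)  # perturbation direction for this pair
        g += (r_pos - r_neg) * eps
    # One pair of rollouts per index; the applied perturbation was noise_stdev * eps.
    return g / (2 * len(noise_indices) * noise_stdev)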
@@ -160,6 +160,15 @@ def do_rollouts(self, params, timestep_limit=None):
eval_lengths=eval_lengths)


def get_policy_class(config):
if config["use_pytorch"]:
from ray.rllib.agents.es.es_torch_policy import ESTorchPolicy
policy_cls = ESTorchPolicy
else:
policy_cls = ESTFPolicy
return policy_cls


class ESTrainer(Trainer):
"""Large-scale implementation of Evolution Strategies in Ray."""

@@ -168,22 +177,15 @@ class ESTrainer(Trainer):

@override(Trainer)
def _init(self, config, env_creator):
# PyTorch check.
if config["use_pytorch"]:
raise ValueError(
"ES does not support PyTorch yet! Use tf instead.")

policy_params = {"action_noise_std": 0.01}

config.update(policy_params)
env_context = EnvContext(config["env_config"] or {}, worker_index=0)
env = env_creator(env_context)
from ray.rllib import models
preprocessor = models.ModelCatalog.get_preprocessor(env)

self.sess = utils.make_session(single_threaded=False)
self.policy = policies.GenericPolicy(
self.sess, env.action_space, env.observation_space, preprocessor,
config["observation_filter"], config["model"], **policy_params)
policy_cls = get_policy_class(config)
self.policy = policy_cls(
obs_space=env.observation_space,
action_space=env.action_space,
config=config)
self.optimizer = optimizers.Adam(self.policy, config["stepsize"])
self.report_length = config["report_length"]

@@ -207,8 +209,9 @@ def _init(self, config, env_creator):
def _train(self):
config = self.config

theta = self.policy.get_weights()
theta = self.policy.get_flat_weights()
assert theta.dtype == np.float32
assert len(theta.shape) == 1

# Put the current policy weights in the object store.
theta_id = ray.put(theta)
@@ -264,14 +267,14 @@ def _train(self):
theta, update_ratio = self.optimizer.update(-g +
config["l2_coeff"] * theta)
# Set the new weights in the local copy of the policy.
self.policy.set_weights(theta)
self.policy.set_flat_weights(theta)
# Store the rewards
if len(all_eval_returns) > 0:
self.reward_list.append(np.mean(eval_returns))

# Now sync the filters
FilterManager.synchronize({
DEFAULT_POLICY_ID: self.policy.get_filter()
DEFAULT_POLICY_ID: self.policy.observation_filter
}, self._workers)

info = {
@@ -293,7 +296,7 @@ def _train(self):

@override(Trainer)
def compute_action(self, observation, *args, **kwargs):
return self.policy.compute(observation, update=False)[0]
return self.policy.compute_actions(observation, update=False)[0]

@override(Trainer)
def _stop(self):
@@ -325,15 +328,15 @@ def _collect_results(self, theta_id, min_episodes, min_timesteps):

def __getstate__(self):
return {
"weights": self.policy.get_weights(),
"filter": self.policy.get_filter(),
"weights": self.policy.get_flat_weights(),
"filter": self.policy.observation_filter,
"episodes_so_far": self.episodes_so_far,
}

def __setstate__(self, state):
self.episodes_so_far = state["episodes_so_far"]
self.policy.set_weights(state["weights"])
self.policy.set_filter(state["filter"])
self.policy.set_flat_weights(state["weights"])
self.policy.observation_filter = state["filter"]
FilterManager.synchronize({
DEFAULT_POLICY_ID: self.policy.get_filter()
DEFAULT_POLICY_ID: self.policy.observation_filter
}, self._workers)
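Since the `compute_action` override above now delegates to the policy's renamed `compute_actions`, a trained trainer can be evaluated in a plain gym loop. A minimal sketch (assumes a `trainer` built as in the docs example earlier; the old 4-tuple gym step API matches the `rollout` helper in this codebase):

# Sketch: greedy evaluation of a trained ESTrainer via compute_action.
import gym

env = gym.make("CartPole-v0")
obs = env.reset()
done, total_reward = False, 0.0
while not done:
    # compute_action passes update=False, so the observation filter is not
    # updated, and add_noise defaults to False (no exploration noise).
    action = trainer.compute_action(obs)
    obs, reward, done, _ = env.step(action)
    total_reward += reward
print("episode reward:", total_reward)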
43 changes: 23 additions & 20 deletions rllib/agents/es/policies.py → rllib/agents/es/es_tf_policy.py
@@ -8,6 +8,7 @@
import ray.experimental.tf_utils
from ray.rllib.evaluation.sampler import _unbatch_tuple_actions
from ray.rllib.models import ModelCatalog
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.filter import get_filter
from ray.rllib.utils import try_import_tf

@@ -30,7 +31,7 @@ def rollout(policy, env, timestep_limit=None, add_noise=False):
t = 0
observation = env.reset()
for _ in range(timestep_limit or max_timestep_limit):
ac = policy.compute(observation, add_noise=add_noise)[0]
ac = policy.compute_actions(observation, add_noise=add_noise)[0]
observation, rew, done, _ = env.step(ac)
rews.append(rew)
t += 1
@@ -40,24 +41,32 @@ def rollout(policy, env, timestep_limit=None, add_noise=False):
return rews, t


class GenericPolicy:
def __init__(self, sess, action_space, obs_space, preprocessor,
observation_filter, model_options, action_noise_std):
self.sess = sess
def make_session(single_threaded):
if not single_threaded:
return tf.Session()
return tf.Session(
config=tf.ConfigProto(
inter_op_parallelism_threads=1, intra_op_parallelism_threads=1))


class ESTFPolicy:
def __init__(self, obs_space, action_space, config):
self.action_space = action_space
self.action_noise_std = action_noise_std
self.preprocessor = preprocessor
self.observation_filter = get_filter(observation_filter,
self.action_noise_std = config["action_noise_std"]
self.preprocessor = ModelCatalog.get_preprocessor_for_space(obs_space)
self.observation_filter = get_filter(config["observation_filter"],
self.preprocessor.shape)
self.single_threaded = config.get("single_threaded", False)
self.sess = make_session(single_threaded=self.single_threaded)
self.inputs = tf.placeholder(tf.float32,
[None] + list(self.preprocessor.shape))

# Policy network.
dist_class, dist_dim = ModelCatalog.get_action_dist(
self.action_space, model_options, dist_type="deterministic")
self.action_space, config["model"], dist_type="deterministic")
model = ModelCatalog.get_model({
"obs": self.inputs
}, obs_space, action_space, dist_dim, model_options)
SampleBatch.CUR_OBS: self.inputs
}, obs_space, action_space, dist_dim, config["model"])
dist = dist_class(model.outputs, model)
self.sampler = dist.sample()

@@ -69,7 +78,7 @@ def __init__(self, sess, action_space, obs_space, preprocessor,
for _, variable in self.variables.variables.items())
self.sess.run(tf.global_variables_initializer())

def compute(self, observation, add_noise=False, update=True):
def compute_actions(self, observation, add_noise=False, update=True):
observation = self.preprocessor.transform(observation)
observation = self.observation_filter(observation[None], update=update)
action = self.sess.run(
@@ -79,14 +88,8 @@ def __init__(self, sess, action_space, obs_space, preprocessor,
action += np.random.randn(*action.shape) * self.action_noise_std
return action

def set_weights(self, x):
def set_flat_weights(self, x):
self.variables.set_flat(x)

def get_weights(self):
def get_flat_weights(self):
return self.variables.get_flat()

def get_filter(self):
return self.observation_filter

def set_filter(self, observation_filter):
self.observation_filter = observation_filter
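The new `rllib/agents/es/es_torch_policy.py` is among the files not loaded in this view. Judging only from how the trainer and workers above call it (a constructor taking `obs_space, action_space, config`, plus `compute_actions`, `get_flat_weights`, `set_flat_weights`, and an `observation_filter` attribute), a rough illustrative sketch of that interface in PyTorch could look as follows; it uses a plain MLP over a Box action space instead of RLlib's ModelCatalog model, so it shows the contract, not the actual file contents:

# Illustrative sketch only -- NOT the contents of es_torch_policy.py.
# It mirrors the interface the trainer/worker code above relies on and
# assumes a flat Box (continuous) action space.
import numpy as np
import torch
import torch.nn as nn

from ray.rllib.models import ModelCatalog
from ray.rllib.utils.filter import get_filter


class SketchESTorchPolicy:
    def __init__(self, obs_space, action_space, config):
        self.action_space = action_space
        self.action_noise_std = config["action_noise_std"]
        self.preprocessor = ModelCatalog.get_preprocessor_for_space(obs_space)
        self.observation_filter = get_filter(config["observation_filter"],
                                             self.preprocessor.shape)
        in_dim = int(np.prod(self.preprocessor.shape))
        out_dim = int(np.prod(action_space.shape))
        self.model = nn.Sequential(
            nn.Linear(in_dim, 64), nn.Tanh(), nn.Linear(64, out_dim))

    def compute_actions(self, observation, add_noise=False, update=True):
        observation = self.preprocessor.transform(observation)
        observation = self.observation_filter(observation[None], update=update)
        with torch.no_grad():
            action = self.model(
                torch.from_numpy(observation).float().reshape(1, -1)).numpy()
        if add_noise:
            action += np.random.randn(*action.shape) * self.action_noise_std
        return action

    def get_flat_weights(self):
        # 1-D float32 vector, matching the asserts in ESTrainer._train().
        return np.concatenate(
            [p.detach().numpy().ravel() for p in self.model.parameters()])

    def set_flat_weights(self, flat):
        i = 0
        for p in self.model.parameters():
            n = p.numel()
            p.data.copy_(torch.from_numpy(
                flat[i:i + n].astype(np.float32)).view_as(p))
            i += n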
(Remaining changed files were not loaded in this view.)
