
Commit b71152e

Adds support for Hindsight Experience Replay (HER) (openai#299)
* Add Hindsight Experience Replay (HER)
* Minor improvements
1 parent df2e846 commit b71152e

23 files changed (+1740, −37 lines)

README.md  (+1)

@@ -20,6 +20,7 @@ pip install -e .
 - [DDPG](baselines/ddpg)
 - [DQN](baselines/deepq)
 - [GAIL](baselines/gail)
+- [HER](baselines/her)
 - [PPO1](baselines/ppo1) (Multi-CPU using MPI)
 - [PPO2](baselines/ppo2) (Optimized for GPU)
 - [TRPO](baselines/trpo_mpi)

baselines/a2c/utils.py  (+17, −5)

@@ -39,12 +39,24 @@ def _ortho_init(shape, dtype, partition_info=None):
         return (scale * q[:shape[0], :shape[1]]).astype(np.float32)
     return _ortho_init
 
-def conv(x, scope, *, nf, rf, stride, pad='VALID', init_scale=1.0):
+def conv(x, scope, *, nf, rf, stride, pad='VALID', init_scale=1.0, data_format='NHWC'):
+    if data_format == 'NHWC':
+        channel_ax = 3
+        strides = [1, stride, stride, 1]
+        bshape = [1, 1, 1, nf]
+    elif data_format == 'NCHW':
+        channel_ax = 1
+        strides = [1, 1, stride, stride]
+        bshape = [1, nf, 1, 1]
+    else:
+        raise NotImplementedError
+    nin = x.get_shape()[channel_ax].value
+    wshape = [rf, rf, nin, nf]
     with tf.variable_scope(scope):
-        nin = x.get_shape()[3].value
-        w = tf.get_variable("w", [rf, rf, nin, nf], initializer=ortho_init(init_scale))
-        b = tf.get_variable("b", [nf], initializer=tf.constant_initializer(0.0))
-        return tf.nn.conv2d(x, w, strides=[1, stride, stride, 1], padding=pad)+b
+        w = tf.get_variable("w", wshape, initializer=ortho_init(init_scale))
+        b = tf.get_variable("b", [1, nf, 1, 1], initializer=tf.constant_initializer(0.0))
+        if data_format == 'NHWC': b = tf.reshape(b, bshape)
+        return b + tf.nn.conv2d(x, w, strides=strides, padding=pad, data_format=data_format)
 
 def fc(x, scope, nh, *, init_scale=1.0, init_bias=0.0):
     with tf.variable_scope(scope):
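For context (not part of the commit): a minimal sketch of how the new `data_format` argument might be used with channels-first inputs. The input shape and scope name are illustrative.

```python
# Sketch only: calling the updated conv() with channels-first (NCHW) input.
import tensorflow as tf
from baselines.a2c.utils import conv

# Batch of 84x84 frames with 4 stacked channels, channels-first layout (illustrative shape).
x = tf.placeholder(tf.float32, [None, 4, 84, 84])
h = conv(x, 'c1', nf=32, rf=8, stride=4, data_format='NCHW')  # output stays in NCHW layout
```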

baselines/bench/monitor.py  (+7, −2)

@@ -7,12 +7,13 @@
 import csv
 import os.path as osp
 import json
+import numpy as np
 
 class Monitor(Wrapper):
     EXT = "monitor.csv"
     f = None
 
-    def __init__(self, env, filename, allow_early_resets=False, reset_keywords=()):
+    def __init__(self, env, filename, allow_early_resets=False, reset_keywords=(), info_keywords=()):
         Wrapper.__init__(self, env=env)
         self.tstart = time.time()
         if filename is None:
@@ -26,10 +27,12 @@ def __init__(self, env, filename, allow_early_resets=False, reset_keywords=()):
                     filename = filename + "." + Monitor.EXT
             self.f = open(filename, "wt")
             self.f.write('#%s\n'%json.dumps({"t_start": self.tstart, 'env_id' : env.spec and env.spec.id}))
-            self.logger = csv.DictWriter(self.f, fieldnames=('r', 'l', 't')+reset_keywords)
+            self.logger = csv.DictWriter(self.f, fieldnames=('r', 'l', 't')+reset_keywords+info_keywords)
             self.logger.writeheader()
+            self.f.flush()
 
         self.reset_keywords = reset_keywords
+        self.info_keywords = info_keywords
         self.allow_early_resets = allow_early_resets
         self.rewards = None
         self.needs_reset = True
@@ -61,6 +64,8 @@ def step(self, action):
             eprew = sum(self.rewards)
             eplen = len(self.rewards)
             epinfo = {"r": round(eprew, 6), "l": eplen, "t": round(time.time() - self.tstart, 6)}
+            for k in self.info_keywords:
+                epinfo[k] = info[k]
            self.episode_rewards.append(eprew)
            self.episode_lengths.append(eplen)
            self.episode_times.append(time.time() - self.tstart)
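A minimal sketch (not part of the commit) of the new `info_keywords` argument; the environment id and log path are illustrative, and the wrapped env is assumed to report `is_success` in its `info` dict.

```python
# Sketch only: record a per-episode key from the env's `info` dict in the monitor CSV.
import gym
from baselines.bench import Monitor

env = gym.make('FetchReach-v0')  # illustrative; any env that reports 'is_success' in info
env = Monitor(env, '/tmp/exp/0', info_keywords=('is_success',))
# Each completed episode now logs 'r', 'l', 't', and 'is_success' to /tmp/exp/0.monitor.csv.
```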

baselines/common/cmd_util.py  (+25, −1)

@@ -4,6 +4,7 @@
 
 import os
 import gym
+from gym.wrappers import FlattenDictWrapper
 from baselines import logger
 from baselines.bench import Monitor
 from baselines.common import set_global_seeds
@@ -36,6 +37,19 @@ def make_mujoco_env(env_id, seed):
     env.seed(seed)
     return env
 
+def make_robotics_env(env_id, seed, rank=0):
+    """
+    Create a wrapped, monitored gym.Env for the MuJoCo-based robotics environments.
+    """
+    set_global_seeds(seed)
+    env = gym.make(env_id)
+    env = FlattenDictWrapper(env, ['observation', 'desired_goal'])
+    env = Monitor(
+        env, logger.get_dir() and os.path.join(logger.get_dir(), str(rank)),
+        info_keywords=('is_success',))
+    env.seed(seed)
+    return env
+
 def arg_parser():
     """
     Create an empty argparse.ArgumentParser.
@@ -58,7 +72,17 @@ def mujoco_arg_parser():
     Create an argparse.ArgumentParser for run_mujoco.py.
     """
     parser = arg_parser()
-    parser.add_argument('--env', help='environment ID', type=str, default="Reacher-v1")
+    parser.add_argument('--env', help='environment ID', type=str, default='Reacher-v2')
+    parser.add_argument('--seed', help='RNG seed', type=int, default=0)
+    parser.add_argument('--num-timesteps', type=int, default=int(1e6))
+    return parser
+
+def robotics_arg_parser():
+    """
+    Create an argparse.ArgumentParser for the robotics experiments.
+    """
+    parser = arg_parser()
+    parser.add_argument('--env', help='environment ID', type=str, default='FetchReach-v0')
     parser.add_argument('--seed', help='RNG seed', type=int, default=0)
     parser.add_argument('--num-timesteps', type=int, default=int(1e6))
     return parser
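A brief sketch (not part of the commit) of how the two new helpers could be combined; argument values shown are the parser defaults.

```python
# Sketch only: build a monitored, goal-flattened robotics env from the new helpers.
from baselines.common.cmd_util import make_robotics_env, robotics_arg_parser

args = robotics_arg_parser().parse_args()           # --env defaults to FetchReach-v0
env = make_robotics_env(args.env, args.seed, rank=0)
obs = env.reset()                                   # flat vector: [observation, desired_goal]
```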

baselines/common/console_util.py  (+6, −1)

@@ -16,7 +16,12 @@ def fmt_item(x, l):
     if isinstance(x, np.ndarray):
         assert x.ndim==0
         x = x.item()
-    if isinstance(x, float): rep = "%g"%x
+    if isinstance(x, (float, np.float32, np.float64)):
+        v = abs(x)
+        if (v < 1e-4 or v > 1e+4) and v > 0:
+            rep = "%7.2e" % x
+        else:
+            rep = "%7.5f" % x
     else: rep = str(x)
     return " "*(l - len(rep)) + rep
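For illustration (worked examples, not from the commit): the new branching prints fixed-point output for moderate magnitudes and scientific notation for very small or very large values.

```python
# Sketch only: expected behaviour of the updated fmt_item().
from baselines.common.console_util import fmt_item

fmt_item(3.14159, 10)   # -> '   3.14159'  (moderate magnitude: "%7.5f")
fmt_item(1.2e-05, 10)   # -> '  1.20e-05'  (|x| < 1e-4: "%7.2e")
fmt_item(0.0, 10)       # -> '   0.00000'  (zero stays in fixed-point form)
```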

baselines/common/tf_util.py  (+17)

@@ -261,3 +261,20 @@ def get_placeholder_cached(name):
 
 def flattenallbut0(x):
     return tf.reshape(x, [-1, intprod(x.get_shape().as_list()[1:])])
+
+
+# ================================================================
+# Diagnostics
+# ================================================================
+
+def display_var_info(vars):
+    from baselines import logger
+    count_params = 0
+    for v in vars:
+        name = v.name
+        if "/Adam" in name or "beta1_power" in name or "beta2_power" in name: continue
+        count_params += np.prod(v.shape.as_list())
+        if "/b:" in name: continue    # Wx+b, bias is not interesting to look at => count params, but not print
+        logger.info("   %s%s%s" % (name, " "*(55-len(name)), str(v.shape)))
+    logger.info("Total model parameters: %0.1f million" % (count_params*1e-6))
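A minimal sketch (not part of the commit) of calling the new diagnostic; passing the trainable-variable collection is one common choice, not something the commit mandates.

```python
# Sketch only: print shapes and total parameter count of the trainable variables.
import tensorflow as tf
from baselines.common.tf_util import display_var_info

display_var_info(tf.trainable_variables())
```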

baselines/common/vec_env/dummy_vec_env.py  (+24, −13)

@@ -1,31 +1,42 @@
 import numpy as np
+import gym
 from . import VecEnv
 
 class DummyVecEnv(VecEnv):
     def __init__(self, env_fns):
         self.envs = [fn() for fn in env_fns]
-        env = self.envs[0]
+        env = self.envs[0]
         VecEnv.__init__(self, len(env_fns), env.observation_space, env.action_space)
-        self.ts = np.zeros(len(self.envs), dtype='int')
+
+        obs_spaces = self.observation_space.spaces if isinstance(self.observation_space, gym.spaces.Tuple) else (self.observation_space,)
+        self.buf_obs = [np.zeros((self.num_envs,) + tuple(s.shape), s.dtype) for s in obs_spaces]
+        self.buf_dones = np.zeros((self.num_envs,), dtype=np.bool)
+        self.buf_rews = np.zeros((self.num_envs,), dtype=np.float32)
+        self.buf_infos = [{} for _ in range(self.num_envs)]
         self.actions = None
 
     def step_async(self, actions):
         self.actions = actions
 
     def step_wait(self):
-        results = [env.step(a) for (a,env) in zip(self.actions, self.envs)]
-        obs, rews, dones, infos = map(np.array, zip(*results))
-        self.ts += 1
-        for (i, done) in enumerate(dones):
-            if done:
-                obs[i] = self.envs[i].reset()
-                self.ts[i] = 0
-        self.actions = None
-        return np.array(obs), np.array(rews), np.array(dones), infos
+        for i in range(self.num_envs):
+            obs_tuple, self.buf_rews[i], self.buf_dones[i], self.buf_infos[i] = self.envs[i].step(self.actions[i])
+            if isinstance(obs_tuple, (tuple, list)):
+                for t,x in enumerate(obs_tuple):
+                    self.buf_obs[t][i] = x
+            else:
+                self.buf_obs[0][i] = obs_tuple
+        return self.buf_obs, self.buf_rews, self.buf_dones, self.buf_infos
 
     def reset(self):
-        results = [env.reset() for env in self.envs]
-        return np.array(results)
+        for i in range(self.num_envs):
+            obs_tuple = self.envs[i].reset()
+            if isinstance(obs_tuple, (tuple, list)):
+                for t,x in enumerate(obs_tuple):
+                    self.buf_obs[t][i] = x
+            else:
+                self.buf_obs[0][i] = obs_tuple
+        return self.buf_obs
 
     def close(self):
         return
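A brief sketch (not part of the commit) of the reworked `DummyVecEnv` in use; the environment id is illustrative. With a `Tuple` observation space, `buf_obs` holds one array per sub-space; with a single space it is a one-element list.

```python
# Sketch only: stepping a DummyVecEnv; observations come back as a list of buffers.
import gym
import numpy as np
from baselines.common.vec_env.dummy_vec_env import DummyVecEnv

venv = DummyVecEnv([lambda: gym.make('CartPole-v0') for _ in range(4)])
obs = venv.reset()                       # list with one (4, obs_dim) array for a single space
venv.step_async(np.array([venv.action_space.sample() for _ in range(4)]))
obs, rews, dones, infos = venv.step_wait()
```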

baselines/her/README.md  (+35)

@@ -0,0 +1,35 @@
+# Hindsight Experience Replay
+For details on Hindsight Experience Replay (HER), please read the [paper](https://arxiv.org/pdf/1707.01495.pdf).
+
+## How to use Hindsight Experience Replay
+
+### Getting started
+Training an agent is very simple:
+```bash
+python -m baselines.her.experiment.train
+```
+This will train a DDPG+HER agent on the `FetchReach` environment.
+You should see the success rate go up quickly to `1.0`, which means that the agent achieves the
+desired goal in 100% of the cases.
+The training script logs other diagnostics as well and pickles the best policy so far (w.r.t. its test success rate),
+the latest policy, and, if enabled, a history of policies every K epochs.
+
+To inspect what the agent has learned, use the play script:
+```bash
+python -m baselines.her.experiment.play /path/to/an/experiment/policy_best.pkl
+```
+You can try it right now with the results of the training step (the script prints out the path for you).
+This should visualize the current policy for 10 episodes and will also print statistics.
+
+
+### Advanced usage
+The training script comes with advanced features like MPI support, which allows it to scale across all cores of a single machine.
+To see all available options, simply run this command:
+```bash
+python -m baselines.her.experiment.train --help
+```
+To run on, say, 20 CPU cores, you can use the following command:
+```bash
+python -m baselines.her.experiment.train --num_cpu 20
+```
+That's it, you are now running rollouts using 20 MPI workers and averaging gradients for network updates across all 20 cores.

baselines/her/__init__.py

Whitespace-only changes.

baselines/her/actor_critic.py  (+44)

@@ -0,0 +1,44 @@
+import tensorflow as tf
+from baselines.her.util import store_args, nn
+
+
+class ActorCritic:
+    @store_args
+    def __init__(self, inputs_tf, dimo, dimg, dimu, max_u, o_stats, g_stats, hidden, layers,
+                 **kwargs):
+        """The actor-critic network and related training code.
+
+        Args:
+            inputs_tf (dict of tensors): all necessary inputs for the network: the
+                observation (o), the goal (g), and the action (u)
+            dimo (int): the dimension of the observations
+            dimg (int): the dimension of the goals
+            dimu (int): the dimension of the actions
+            max_u (float): the maximum magnitude of actions; action outputs will be scaled
+                accordingly
+            o_stats (baselines.her.Normalizer): normalizer for observations
+            g_stats (baselines.her.Normalizer): normalizer for goals
+            hidden (int): number of hidden units that should be used in hidden layers
+            layers (int): number of hidden layers
+        """
+        self.o_tf = inputs_tf['o']
+        self.g_tf = inputs_tf['g']
+        self.u_tf = inputs_tf['u']
+
+        # Prepare inputs for actor and critic.
+        o = self.o_stats.normalize(self.o_tf)
+        g = self.g_stats.normalize(self.g_tf)
+        input_pi = tf.concat(axis=1, values=[o, g])  # for actor
+
+        # Networks.
+        with tf.variable_scope('pi'):
+            self.pi_tf = self.max_u * tf.tanh(nn(
+                input_pi, [self.hidden] * self.layers + [self.dimu]))
+        with tf.variable_scope('Q'):
+            # for policy training
+            input_Q = tf.concat(axis=1, values=[o, g, self.pi_tf / self.max_u])
+            self.Q_pi_tf = nn(input_Q, [self.hidden] * self.layers + [1])
+            # for critic training
+            input_Q = tf.concat(axis=1, values=[o, g, self.u_tf / self.max_u])
+            self._input_Q = input_Q  # exposed for tests
+            self.Q_tf = nn(input_Q, [self.hidden] * self.layers + [1], reuse=True)
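The class depends on `store_args` and `nn` from `baselines.her.util`, which are outside this diff. As a rough sketch of the `nn` helper this code assumes (an MLP with ReLU hidden layers and a linear final layer; the committed implementation may differ in details):

```python
# Illustrative sketch only -- the actual baselines.her.util.nn may differ.
import tensorflow as tf

def nn(input, layers_sizes, reuse=None, flatten=False, name=""):
    """Fully-connected network: ReLU on hidden layers, linear final layer."""
    for i, size in enumerate(layers_sizes):
        activation = tf.nn.relu if i < len(layers_sizes) - 1 else None
        input = tf.layers.dense(input, size,
                                reuse=reuse,
                                name=name + '_' + str(i),
                                activation=activation)
    if flatten:
        assert layers_sizes[-1] == 1
        input = tf.reshape(input, [-1])
    return input
```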
