Skip to content

Commit ab59de6

Browse files
authored
mpi-less baselines (openai#689)
* make baselines run without mpi wip
* squash-merged latest master
* further removing MPI references where unnecessary
* more MPI removal
* syntax and flake8
* MpiAdam becomes regular Adam if Mpi not present
* autopep8
* add assertion to test in mpi_adam; fix trpo_mpi failure without MPI on cartpole
* mpiless ddpg
1 parent a071fa7 commit ab59de6

File tree

8 files changed

+124
-45
lines changed

8 files changed

+124
-45
lines changed

Dockerfile

+2-11
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,7 @@
1-
FROM ubuntu:16.04
1+
FROM python:3.6
22

3-
RUN apt-get -y update && apt-get -y install git wget python-dev python3-dev libopenmpi-dev python-pip zlib1g-dev cmake python-opencv
3+
# RUN apt-get -y update && apt-get -y install git wget python-dev python3-dev libopenmpi-dev python-pip zlib1g-dev cmake python-opencv
44
ENV CODE_DIR /root/code
5-
ENV VENV /root/venv
6-
7-
RUN \
8-
pip install virtualenv && \
9-
virtualenv $VENV --python=python3 && \
10-
. $VENV/bin/activate && \
11-
pip install --upgrade pip
12-
13-
ENV PATH=$VENV/bin:$PATH
145

156
COPY . $CODE_DIR/baselines
167
WORKDIR $CODE_DIR/baselines

baselines/common/mpi_adam.py

+32-8
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
1-
from mpi4py import MPI
21
import baselines.common.tf_util as U
32
import tensorflow as tf
43
import numpy as np
4+
try:
5+
from mpi4py import MPI
6+
except ImportError:
7+
MPI = None
8+
59

610
class MpiAdam(object):
711
def __init__(self, var_list, *, beta1=0.9, beta2=0.999, epsilon=1e-08, scale_grad_by_procs=True, comm=None):
@@ -16,16 +20,19 @@ def __init__(self, var_list, *, beta1=0.9, beta2=0.999, epsilon=1e-08, scale_gra
1620
self.t = 0
1721
self.setfromflat = U.SetFromFlat(var_list)
1822
self.getflat = U.GetFlat(var_list)
19-
self.comm = MPI.COMM_WORLD if comm is None else comm
23+
self.comm = MPI.COMM_WORLD if comm is None and MPI is not None else comm
2024

2125
def update(self, localg, stepsize):
2226
if self.t % 100 == 0:
2327
self.check_synced()
2428
localg = localg.astype('float32')
25-
globalg = np.zeros_like(localg)
26-
self.comm.Allreduce(localg, globalg, op=MPI.SUM)
27-
if self.scale_grad_by_procs:
28-
globalg /= self.comm.Get_size()
29+
if self.comm is not None:
30+
globalg = np.zeros_like(localg)
31+
self.comm.Allreduce(localg, globalg, op=MPI.SUM)
32+
if self.scale_grad_by_procs:
33+
globalg /= self.comm.Get_size()
34+
else:
35+
globalg = np.copy(localg)
2936

3037
self.t += 1
3138
a = stepsize * np.sqrt(1 - self.beta2**self.t)/(1 - self.beta1**self.t)
@@ -35,11 +42,15 @@ def update(self, localg, stepsize):
3542
self.setfromflat(self.getflat() + step)
3643

3744
def sync(self):
45+
if self.comm is None:
46+
return
3847
theta = self.getflat()
3948
self.comm.Bcast(theta, root=0)
4049
self.setfromflat(theta)
4150

4251
def check_synced(self):
52+
if self.comm is None:
53+
return
4354
if self.comm.Get_rank() == 0: # this is root
4455
theta = self.getflat()
4556
self.comm.Bcast(theta, root=0)
@@ -63,17 +74,30 @@ def test_MpiAdam():
6374
do_update = U.function([], loss, updates=[update_op])
6475

6576
tf.get_default_session().run(tf.global_variables_initializer())
77+
losslist_ref = []
6678
for i in range(10):
67-
print(i,do_update())
79+
l = do_update()
80+
print(i, l)
81+
losslist_ref.append(l)
82+
83+
6884

6985
tf.set_random_seed(0)
7086
tf.get_default_session().run(tf.global_variables_initializer())
7187

7288
var_list = [a,b]
73-
lossandgrad = U.function([], [loss, U.flatgrad(loss, var_list)], updates=[update_op])
89+
lossandgrad = U.function([], [loss, U.flatgrad(loss, var_list)])
7490
adam = MpiAdam(var_list)
7591

92+
losslist_test = []
7693
for i in range(10):
7794
l,g = lossandgrad()
7895
adam.update(g, stepsize)
7996
print(i,l)
97+
losslist_test.append(l)
98+
99+
np.testing.assert_allclose(np.array(losslist_ref), np.array(losslist_test), atol=1e-4)
100+
101+
102+
if __name__ == '__main__':
103+
test_MpiAdam()

baselines/common/mpi_running_mean_std.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
1-
from mpi4py import MPI
1+
try:
2+
from mpi4py import MPI
3+
except ImportError:
4+
MPI = None
5+
26
import tensorflow as tf, baselines.common.tf_util as U, numpy as np
37

48
class RunningMeanStd(object):
@@ -39,7 +43,8 @@ def update(self, x):
3943
n = int(np.prod(self.shape))
4044
totalvec = np.zeros(n*2+1, 'float64')
4145
addvec = np.concatenate([x.sum(axis=0).ravel(), np.square(x).sum(axis=0).ravel(), np.array([len(x)],dtype='float64')])
42-
MPI.COMM_WORLD.Allreduce(addvec, totalvec, op=MPI.SUM)
46+
if MPI is not None:
47+
MPI.COMM_WORLD.Allreduce(addvec, totalvec, op=MPI.SUM)
4348
self.incfiltparams(totalvec[0:n].reshape(self.shape), totalvec[n:2*n].reshape(self.shape), totalvec[2*n])
4449

4550
@U.in_session

baselines/ddpg/ddpg.py

+18-4
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,11 @@
1212

1313
from baselines import logger
1414
import numpy as np
15-
from mpi4py import MPI
1615

16+
try:
17+
from mpi4py import MPI
18+
except ImportError:
19+
MPI = None
1720

1821
def learn(network, env,
1922
seed=None,
@@ -49,7 +52,11 @@ def learn(network, env,
4952
else:
5053
nb_epochs = 500
5154

52-
rank = MPI.COMM_WORLD.Get_rank()
55+
if MPI is not None:
56+
rank = MPI.COMM_WORLD.Get_rank()
57+
else:
58+
rank = 0
59+
5360
nb_actions = env.action_space.shape[-1]
5461
assert (np.abs(env.action_space.low) == env.action_space.high).all() # we assume symmetric actions.
5562

@@ -199,7 +206,11 @@ def learn(network, env,
199206
eval_episode_rewards_history.append(eval_episode_reward[d])
200207
eval_episode_reward[d] = 0.0
201208

202-
mpi_size = MPI.COMM_WORLD.Get_size()
209+
if MPI is not None:
210+
mpi_size = MPI.COMM_WORLD.Get_size()
211+
else:
212+
mpi_size = 1
213+
203214
# Log stats.
204215
# XXX shouldn't call np.mean on variable length lists
205216
duration = time.time() - start_time
@@ -233,7 +244,10 @@ def as_scalar(x):
233244
else:
234245
raise ValueError('expected scalar, got %s'%x)
235246

236-
combined_stats_sums = MPI.COMM_WORLD.allreduce(np.array([ np.array(x).flatten()[0] for x in combined_stats.values()]))
247+
combined_stats_sums = np.array([ np.array(x).flatten()[0] for x in combined_stats.values()])
248+
if MPI is not None:
249+
combined_stats_sums = MPI.COMM_WORLD.allreduce(combined_stats_sums)
250+
237251
combined_stats = {k : v / mpi_size for (k,v) in zip(combined_stats.keys(), combined_stats_sums)}
238252

239253
# Total statistics.

baselines/ddpg/ddpg_learner.py

+19-2
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,10 @@
99
from baselines.common.mpi_adam import MpiAdam
1010
import baselines.common.tf_util as U
1111
from baselines.common.mpi_running_mean_std import RunningMeanStd
12-
from mpi4py import MPI
12+
try:
13+
from mpi4py import MPI
14+
except ImportError:
15+
MPI = None
1316

1417
def normalize(x, stats):
1518
if stats is None:
@@ -358,6 +361,11 @@ def get_stats(self):
358361
return stats
359362

360363
def adapt_param_noise(self):
364+
try:
365+
from mpi4py import MPI
366+
except ImportError:
367+
MPI = None
368+
361369
if self.param_noise is None:
362370
return 0.
363371

@@ -371,7 +379,16 @@ def adapt_param_noise(self):
371379
self.param_noise_stddev: self.param_noise.current_stddev,
372380
})
373381

374-
mean_distance = MPI.COMM_WORLD.allreduce(distance, op=MPI.SUM) / MPI.COMM_WORLD.Get_size()
382+
if MPI is not None:
383+
mean_distance = MPI.COMM_WORLD.allreduce(distance, op=MPI.SUM) / MPI.COMM_WORLD.Get_size()
384+
else:
385+
mean_distance = distance
386+
387+
if MPI is not None:
388+
mean_distance = MPI.COMM_WORLD.allreduce(distance, op=MPI.SUM) / MPI.COMM_WORLD.Get_size()
389+
else:
390+
mean_distance = distance
391+
375392
self.param_noise.adapt(mean_distance)
376393
return mean_distance
377394

baselines/ppo2/ppo2.py

+17-8
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,15 @@
1010
from baselines.common.policies import build_policy
1111
from baselines.common.runners import AbstractEnvRunner
1212
from baselines.common.tf_util import get_session, save_variables, load_variables
13-
from baselines.common.mpi_adam_optimizer import MpiAdamOptimizer
1413

15-
from mpi4py import MPI
14+
try:
15+
from baselines.common.mpi_adam_optimizer import MpiAdamOptimizer
16+
from mpi4py import MPI
17+
from baselines.common.mpi_util import sync_from_root
18+
except ImportError:
19+
MPI = None
20+
1621
from baselines.common.tf_util import initialize
17-
from baselines.common.mpi_util import sync_from_root
1822

1923
class Model(object):
2024
"""
@@ -93,7 +97,10 @@ def __init__(self, *, policy, ob_space, ac_space, nbatch_act, nbatch_train,
9397
# 1. Get the model parameters
9498
params = tf.trainable_variables('ppo2_model')
9599
# 2. Build our trainer
96-
trainer = MpiAdamOptimizer(MPI.COMM_WORLD, learning_rate=LR, epsilon=1e-5)
100+
if MPI is not None:
101+
trainer = MpiAdamOptimizer(MPI.COMM_WORLD, learning_rate=LR, epsilon=1e-5)
102+
else:
103+
trainer = tf.train.AdamOptimizer(learning_rate=LR, epsilon=1e-5)
97104
# 3. Calculate the gradients
98105
grads_and_var = trainer.compute_gradients(loss, params)
99106
grads, var = zip(*grads_and_var)
@@ -136,10 +143,12 @@ def train(lr, cliprange, obs, returns, masks, actions, values, neglogpacs, state
136143
self.save = functools.partial(save_variables, sess=sess)
137144
self.load = functools.partial(load_variables, sess=sess)
138145

139-
if MPI.COMM_WORLD.Get_rank() == 0:
146+
if MPI is None or MPI.COMM_WORLD.Get_rank() == 0:
140147
initialize()
141148
global_variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="")
142-
sync_from_root(sess, global_variables) #pylint: disable=E1101
149+
150+
if MPI is not None:
151+
sync_from_root(sess, global_variables) #pylint: disable=E1101
143152

144153
class Runner(AbstractEnvRunner):
145154
"""
@@ -392,9 +401,9 @@ def learn(*, network, env, total_timesteps, eval_env = None, seed=None, nsteps=2
392401
logger.logkv('time_elapsed', tnow - tfirststart)
393402
for (lossval, lossname) in zip(lossvals, model.loss_names):
394403
logger.logkv(lossname, lossval)
395-
if MPI.COMM_WORLD.Get_rank() == 0:
404+
if MPI is None or MPI.COMM_WORLD.Get_rank() == 0:
396405
logger.dumpkvs()
397-
if save_interval and (update % save_interval == 0 or update == 1) and logger.get_dir() and MPI.COMM_WORLD.Get_rank() == 0:
406+
if save_interval and (update % save_interval == 0 or update == 1) and logger.get_dir() and (MPI is None or MPI.COMM_WORLD.Get_rank() == 0):
398407
checkdir = osp.join(logger.get_dir(), 'checkpoints')
399408
os.makedirs(checkdir, exist_ok=True)
400409
savepath = osp.join(checkdir, '%.5i'%update)

baselines/trpo_mpi/trpo_mpi.py

+26-9
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import tensorflow as tf, numpy as np
55
import time
66
from baselines.common import colorize
7-
from mpi4py import MPI
87
from collections import deque
98
from baselines.common import set_global_seeds
109
from baselines.common.mpi_adam import MpiAdam
@@ -13,6 +12,11 @@
1312
from baselines.common.policies import build_policy
1413
from contextlib import contextmanager
1514

15+
try:
16+
from mpi4py import MPI
17+
except ImportError:
18+
MPI = None
19+
1620
def traj_segment_generator(pi, env, horizon, stochastic):
1721
# Initialize state variables
1822
t = 0
@@ -146,9 +150,12 @@ def learn(*,
146150
147151
'''
148152

149-
150-
nworkers = MPI.COMM_WORLD.Get_size()
151-
rank = MPI.COMM_WORLD.Get_rank()
153+
if MPI is not None:
154+
nworkers = MPI.COMM_WORLD.Get_size()
155+
rank = MPI.COMM_WORLD.Get_rank()
156+
else:
157+
nworkers = 1
158+
rank = 0
152159

153160
cpus_per_worker = 1
154161
U.get_session(config=tf.ConfigProto(
@@ -237,17 +244,23 @@ def timed(msg):
237244

238245
def allmean(x):
239246
assert isinstance(x, np.ndarray)
240-
out = np.empty_like(x)
241-
MPI.COMM_WORLD.Allreduce(x, out, op=MPI.SUM)
242-
out /= nworkers
247+
if MPI is not None:
248+
out = np.empty_like(x)
249+
MPI.COMM_WORLD.Allreduce(x, out, op=MPI.SUM)
250+
out /= nworkers
251+
else:
252+
out = np.copy(x)
253+
243254
return out
244255

245256
U.initialize()
246257
if load_path is not None:
247258
pi.load(load_path)
248259

249260
th_init = get_flat()
250-
MPI.COMM_WORLD.Bcast(th_init, root=0)
261+
if MPI is not None:
262+
MPI.COMM_WORLD.Bcast(th_init, root=0)
263+
251264
set_from_flat(th_init)
252265
vfadam.sync()
253266
print("Init param sum", th_init.sum(), flush=True)
@@ -353,7 +366,11 @@ def fisher_vector_product(p):
353366
logger.record_tabular("ev_tdlam_before", explained_variance(vpredbefore, tdlamret))
354367

355368
lrlocal = (seg["ep_lens"], seg["ep_rets"]) # local values
356-
listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal) # list of tuples
369+
if MPI is not None:
370+
listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal) # list of tuples
371+
else:
372+
listoflrpairs = [lrlocal]
373+
357374
lens, rews = map(flatten_lists, zip(*listoflrpairs))
358375
lenbuffer.extend(lens)
359376
rewbuffer.extend(rews)

setup.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@
1515
],
1616
'bullet': [
1717
'pybullet',
18+
],
19+
'mpi': [
20+
'mpi4py'
1821
]
1922
}
2023

@@ -34,7 +37,6 @@
3437
'joblib',
3538
'dill',
3639
'progressbar2',
37-
'mpi4py',
3840
'cloudpickle',
3941
'click',
4042
'opencv-python'

0 commit comments

Comments (0)