|
10 | 10 | from baselines.common.policies import build_policy
|
11 | 11 | from baselines.common.runners import AbstractEnvRunner
|
12 | 12 | from baselines.common.tf_util import get_session, save_variables, load_variables
|
13 |
| -from baselines.common.mpi_adam_optimizer import MpiAdamOptimizer |
14 | 13 |
|
15 |
| -from mpi4py import MPI |
| 14 | +try: |
| 15 | + from baselines.common.mpi_adam_optimizer import MpiAdamOptimizer |
| 16 | + from mpi4py import MPI |
| 17 | + from baselines.common.mpi_util import sync_from_root |
| 18 | +except ImportError: |
| 19 | + MPI = None |
| 20 | + |
16 | 21 | from baselines.common.tf_util import initialize
|
17 |
| -from baselines.common.mpi_util import sync_from_root |
18 | 22 |
|
19 | 23 | class Model(object):
|
20 | 24 | """
|
@@ -93,7 +97,10 @@ def __init__(self, *, policy, ob_space, ac_space, nbatch_act, nbatch_train,
|
93 | 97 | # 1. Get the model parameters
|
94 | 98 | params = tf.trainable_variables('ppo2_model')
|
95 | 99 | # 2. Build our trainer
|
96 |
| - trainer = MpiAdamOptimizer(MPI.COMM_WORLD, learning_rate=LR, epsilon=1e-5) |
| 100 | + if MPI is not None: |
| 101 | + trainer = MpiAdamOptimizer(MPI.COMM_WORLD, learning_rate=LR, epsilon=1e-5) |
| 102 | + else: |
| 103 | + trainer = tf.train.AdamOptimizer(learning_rate=LR, epsilon=1e-5) |
97 | 104 | # 3. Calculate the gradients
|
98 | 105 | grads_and_var = trainer.compute_gradients(loss, params)
|
99 | 106 | grads, var = zip(*grads_and_var)
|
@@ -136,10 +143,12 @@ def train(lr, cliprange, obs, returns, masks, actions, values, neglogpacs, state
|
136 | 143 | self.save = functools.partial(save_variables, sess=sess)
|
137 | 144 | self.load = functools.partial(load_variables, sess=sess)
|
138 | 145 |
|
139 |
| - if MPI.COMM_WORLD.Get_rank() == 0: |
| 146 | + if MPI is None or MPI.COMM_WORLD.Get_rank() == 0: |
140 | 147 | initialize()
|
141 | 148 | global_variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="")
|
142 |
| - sync_from_root(sess, global_variables) #pylint: disable=E1101 |
| 149 | + |
| 150 | + if MPI is not None: |
| 151 | + sync_from_root(sess, global_variables) #pylint: disable=E1101 |
143 | 152 |
|
144 | 153 | class Runner(AbstractEnvRunner):
|
145 | 154 | """
|
@@ -392,9 +401,9 @@ def learn(*, network, env, total_timesteps, eval_env = None, seed=None, nsteps=2
|
392 | 401 | logger.logkv('time_elapsed', tnow - tfirststart)
|
393 | 402 | for (lossval, lossname) in zip(lossvals, model.loss_names):
|
394 | 403 | logger.logkv(lossname, lossval)
|
395 |
| - if MPI.COMM_WORLD.Get_rank() == 0: |
| 404 | + if MPI is None or MPI.COMM_WORLD.Get_rank() == 0: |
396 | 405 | logger.dumpkvs()
|
397 |
| - if save_interval and (update % save_interval == 0 or update == 1) and logger.get_dir() and MPI.COMM_WORLD.Get_rank() == 0: |
| 406 | + if save_interval and (update % save_interval == 0 or update == 1) and logger.get_dir() and (MPI is None or MPI.COMM_WORLD.Get_rank() == 0): |
398 | 407 | checkdir = osp.join(logger.get_dir(), 'checkpoints')
|
399 | 408 | os.makedirs(checkdir, exist_ok=True)
|
400 | 409 | savepath = osp.join(checkdir, '%.5i'%update)
|
|
0 commit comments