
Commit 11a884b

Morvan Zhou authored and committed
update ppo
1 parent: 2d91181

File tree (2 files changed: +259 −65 lines)

  • contents/12_Proximal_Policy_Optimization
  • experiments/Robot_arm

Lines changed: 68 additions & 65 deletions
@@ -1,8 +1,11 @@
 """
 A simple version of OpenAI's Proximal Policy Optimization (PPO). [http://adsabs.harvard.edu/abs/2017arXiv170706347S]
+
 Distributing workers in parallel to collect data, then stop worker's roll-out and train PPO on collected data.
-Restart workers once PPO is updated. I think A3C may be faster than this version of PPO, because this PPO has to stop
-parallel data-collection for training.
+Restart workers once PPO is updated.
+
+The global PPO updating rule is adopted from DeepMind's paper (DPPO):
+Emergence of Locomotion Behaviours in Rich Environments (Google Deepmind): [http://adsabs.harvard.edu/abs/2017arXiv170702286H]
 
 View more on my tutorial website: https://morvanzhou.github.io/tutorials
 
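
The docstring above summarizes the DPPO-style scheme this commit moves to: workers roll out in parallel, pause once enough data has been pooled, the global PPO trains on the pooled batch, and the workers then resume with the updated policy. Below is a minimal, framework-free sketch of that event-based hand-off, not code from this commit; the names (ROLLING, UPDATING, worker, trainer) and the toy payloads are illustrative only.

    import threading, queue, time

    N_WORKER, MIN_BATCH, N_UPDATES = 4, 64, 3

    ROLLING = threading.Event()     # set -> workers may collect data
    UPDATING = threading.Event()    # set -> trainer should run one update
    QUEUE = queue.Queue()
    LOCK = threading.Lock()
    count = 0
    ROLLING.set()

    def worker(wid):
        global count
        while True:
            ROLLING.wait()                          # blocked here while the trainer updates
            QUEUE.put('transition from W%i' % wid)  # stand-in for one (s, a, r) row
            with LOCK:
                count += 1
                if count >= MIN_BATCH:              # enough pooled data: pause roll-out
                    ROLLING.clear()
                    UPDATING.set()
            time.sleep(0.001)

    def trainer():
        global count
        for i in range(N_UPDATES):
            UPDATING.wait()
            batch = [QUEUE.get() for _ in range(QUEUE.qsize())]
            print('update %i on %i samples' % (i, len(batch)))  # stand-in for PPO training
            with LOCK:
                count = 0
            UPDATING.clear()
            ROLLING.set()                           # restart workers with the new policy

    for i in range(N_WORKER):
        threading.Thread(target=worker, args=(i,), daemon=True).start()
    t = threading.Thread(target=trainer); t.start(); t.join()

The real code below uses the same two events (UPDATE_EVENT, ROLLING_EVENT) plus a shared GLOBAL_UPDATE_COUNTER in place of the lock-guarded count.
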
@@ -15,28 +18,26 @@
 from tensorflow.contrib.distributions import Normal
 import numpy as np
 import matplotlib.pyplot as plt
-import gym, threading
-from queue import Queue
+import gym, threading, queue
 
-EP_MAX = 600
+EP_MAX = 1000
 EP_LEN = 200
-N_WORKER = 3
-GAMMA = 0.9
-A_LR = 0.0001
-C_LR = 0.0002
-ROLL_OUT_STEP = 32
-UPDATE_STEP = 10
-EPSILON = 0.2  # Clipped surrogate objective
-S_DIM, A_DIM = 3, 1
+N_WORKER = 4                # parallel workers
+GAMMA = 0.9                 # reward discount factor
+A_LR = 0.0001               # learning rate for actor
+C_LR = 0.001                # learning rate for critic
+MIN_BATCH_SIZE = 64         # minimum batch size for updating PPO
+UPDATE_STEP = 5             # loop update operation n-steps
+EPSILON = 0.2               # for clipping surrogate objective
+GAME = 'Pendulum-v0'
+S_DIM, A_DIM = 3, 1         # state and action dimension
 
 
 class PPO(object):
-    def __init__(self, s_dim, a_dim,):
-        self.a_dim = a_dim
-        self.s_dim = s_dim
+    def __init__(self):
         self.sess = tf.Session()
 
-        self.tfs = tf.placeholder(tf.float32, [None, s_dim], 'state')
+        self.tfs = tf.placeholder(tf.float32, [None, S_DIM], 'state')
 
         # critic
         l1 = tf.layers.dense(self.tfs, 100, tf.nn.relu)
@@ -52,7 +53,7 @@ def __init__(self, s_dim, a_dim,):
         self.sample_op = tf.squeeze(pi.sample(1), axis=0)  # choosing action
         self.update_oldpi_op = [oldp.assign(p) for p, oldp in zip(pi_params, oldpi_params)]
 
-        self.tfa = tf.placeholder(tf.float32, [None, a_dim], 'action')
+        self.tfa = tf.placeholder(tf.float32, [None, A_DIM], 'action')
         self.tfadv = tf.placeholder(tf.float32, [None, 1], 'advantage')
         # ratio = tf.exp(pi.log_prob(self.tfa) - oldpi.log_prob(self.tfa))
         ratio = pi.prob(self.tfa) / (oldpi.prob(self.tfa) + 1e-5)
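
The ratio above (pi / pi_old) feeds PPO's clipped surrogate objective, which the EPSILON constant parameterizes: L = E[ min(ratio * A, clip(ratio, 1 - EPSILON, 1 + EPSILON) * A) ]. A small NumPy illustration of what the clipping does to the per-sample objective; the values are made up and this is not code from the file.

    import numpy as np

    EPSILON = 0.2                                    # same clipping range as above

    def clipped_surrogate(ratio, adv, eps=EPSILON):
        # per-sample PPO-clip objective: min(r * A, clip(r, 1-eps, 1+eps) * A)
        return np.minimum(ratio * adv, np.clip(ratio, 1. - eps, 1. + eps) * adv)

    ratio = np.array([0.5, 0.9, 1.0, 1.3, 2.0])      # pi(a|s) / pi_old(a|s)
    print(clipped_surrogate(ratio, adv=1.0))         # [ 0.5  0.9  1.   1.2  1.2]
    print(clipped_surrogate(ratio, adv=-1.0))        # [-0.8 -0.9 -1.  -1.3 -2. ]

Once the ratio moves past 1 + EPSILON with a positive advantage the objective stops growing, so the gradient no longer pushes the new policy further from the old one; the min always keeps the smaller (more pessimistic) of the clipped and unclipped terms.
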
@@ -65,25 +66,27 @@ def __init__(self, s_dim, a_dim,):
         self.atrain_op = tf.train.AdamOptimizer(A_LR).minimize(self.aloss)
         self.sess.run(tf.global_variables_initializer())
 
-    def update(self, coord, queue, rolling_events):
-        while not coord.should_stop():
-            if queue.full():
+    def update(self):
+        global GLOBAL_UPDATE_COUNTER
+        while not COORD.should_stop():
+            if GLOBAL_EP < EP_MAX:
+                UPDATE_EVENT.wait()                     # wait until get batch of data
                 self.sess.run(self.update_oldpi_op)  # old pi to pi
-
-                data = [queue.get() for _ in range(queue.qsize())]
+                data = [QUEUE.get() for _ in range(QUEUE.qsize())]
                 data = np.vstack(data)
-                s, a, r = data[:, :self.s_dim], data[:, self.s_dim: self.s_dim + self.a_dim], data[:, -1:]
+                s, a, r = data[:, :S_DIM], data[:, S_DIM: S_DIM + A_DIM], data[:, -1:]
                 adv = self.sess.run(self.advantage, {self.tfs: s, self.tfdc_r: r})
                 [self.sess.run(self.atrain_op, {self.tfs: s, self.tfa: a, self.tfadv: adv}) for _ in range(UPDATE_STEP)]
                 [self.sess.run(self.ctrain_op, {self.tfs: s, self.tfdc_r: r}) for _ in range(UPDATE_STEP)]
-
-                [re.set() for re in rolling_events]  # set roll-out available
+                UPDATE_EVENT.clear()                    # updating finished
+                GLOBAL_UPDATE_COUNTER = 0               # reset counter
+                ROLLING_EVENT.set()                     # set roll-out available
 
     def _build_anet(self, name, trainable):
         with tf.variable_scope(name):
             l1 = tf.layers.dense(self.tfs, 200, tf.nn.relu, trainable=trainable)
-            mu = 2 * tf.layers.dense(l1, self.a_dim, tf.nn.tanh, trainable=trainable)
-            sigma = tf.layers.dense(l1, self.a_dim, tf.nn.softplus, trainable=trainable)
+            mu = 2 * tf.layers.dense(l1, A_DIM, tf.nn.tanh, trainable=trainable)
+            sigma = tf.layers.dense(l1, A_DIM, tf.nn.softplus, trainable=trainable)
             norm_dist = Normal(loc=mu, scale=sigma)
             params = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=name)
             return norm_dist, params
@@ -99,83 +102,83 @@ def get_v(self, s):
 
 
 class Worker(object):
-    def __init__(self, globalPPO, roll_out_steps, wid, game, ep_len, rolling_event):
-        self.roll_out_steps = roll_out_steps
+    def __init__(self, wid):
         self.wid = wid
-        self.ep_len = ep_len
-        self.rolling_event = rolling_event
-        self.env = gym.make(game).unwrapped
-        self.ppo = globalPPO
-
-    def work(self, coord, queue,):
-        global GLOBAL_EP, GLOBAL_RUNNING_R
-        while not coord.should_stop():
+        self.env = gym.make(GAME).unwrapped
+        self.ppo = GLOBAL_PPO
+
+    def work(self):
+        global GLOBAL_EP, GLOBAL_RUNNING_R, GLOBAL_UPDATE_COUNTER
+        while not COORD.should_stop():
             s = self.env.reset()
             ep_r = 0
             buffer_s, buffer_a, buffer_r = [], [], []
-            for t in range(self.ep_len):
+            for t in range(EP_LEN):
+                if not ROLLING_EVENT.is_set():                  # while global PPO is updating
+                    ROLLING_EVENT.wait()                        # wait until PPO is updated
+                    buffer_s, buffer_a, buffer_r = [], [], []   # clear history buffer, use new policy to collect data
                 a = self.ppo.choose_action(s)
                 s_, r, done, _ = self.env.step(a)
                 buffer_s.append(s)
                 buffer_a.append(a)
-                buffer_r.append((r + 8) / 8)  # normalize reward, find to be useful
+                buffer_r.append((r + 8) / 8)                    # normalize reward, find to be useful
                 s = s_
                 ep_r += r
 
-                # get update buffer
-                if (t+1) % self.roll_out_steps == 0 or t == self.ep_len - 1:
+                GLOBAL_UPDATE_COUNTER += 1                      # count to minimum batch size
+                if t == EP_LEN - 1 or GLOBAL_UPDATE_COUNTER >= MIN_BATCH_SIZE:
                     v_s_ = self.ppo.get_v(s_)
-                    discounted_r = []  # compute discounted reward
+                    discounted_r = []                           # compute discounted reward
                     for r in buffer_r[::-1]:
                         v_s_ = r + GAMMA * v_s_
                         discounted_r.append(v_s_)
                     discounted_r.reverse()
 
                     bs, ba, br = np.vstack(buffer_s), np.vstack(buffer_a), np.array(discounted_r)[:, np.newaxis]
                     buffer_s, buffer_a, buffer_r = [], [], []
-                    queue.put(np.hstack((bs, ba, br)))
-                    if GLOBAL_EP >= EP_MAX:  # stop training
-                        coord.request_stop()
+                    QUEUE.put(np.hstack((bs, ba, br)))
+                    if GLOBAL_UPDATE_COUNTER >= MIN_BATCH_SIZE:
+                        ROLLING_EVENT.clear()                   # stop collecting data
+                        UPDATE_EVENT.set()                      # globalPPO update
+
+                    if GLOBAL_EP >= EP_MAX:                     # stop training
+                        COORD.request_stop()
                         break
-                else:
-                    self.rolling_event.clear()  # stop roll-out
-                    self.rolling_event.wait()   # stop and wait until network is updated
 
             # record reward changes, plot later
             if len(GLOBAL_RUNNING_R) == 0: GLOBAL_RUNNING_R.append(ep_r)
             else: GLOBAL_RUNNING_R.append(GLOBAL_RUNNING_R[-1]*0.9+ep_r*0.1)
             GLOBAL_EP += 1
-            print('W%i' % self.wid, '|Ep: %i' % GLOBAL_EP, '|Ep_r: %.2f' % ep_r,)
+            print('{0:.1f}%'.format(GLOBAL_EP/EP_MAX*100), '|W%i' % self.wid, '|Ep_r: %.2f' % ep_r,)
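
The discounted_r loop above turns a roll-out into critic targets by bootstrapping from get_v(s_) at the cut-off point and accumulating rewards backwards with GAMMA. A self-contained version of that computation for reference; the function name and the toy numbers are illustrative only.

    import numpy as np

    GAMMA = 0.9

    def discounted_returns(rewards, v_last, gamma=GAMMA):
        # R_t = r_t + gamma * R_{t+1}, seeded with the critic's value of the
        # state reached after the last collected step (the bootstrap)
        discounted, running = [], v_last
        for r in rewards[::-1]:
            running = r + gamma * running
            discounted.append(running)
        discounted.reverse()
        return np.array(discounted)[:, np.newaxis]

    print(discounted_returns([1.0, 1.0, 1.0], v_last=0.0))
    # [[2.71]
    #  [1.9 ]
    #  [1.  ]]
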
 
 
 if __name__ == '__main__':
-    globalPPO = PPO(S_DIM, A_DIM)
-    workers = [Worker(
-        globalPPO=globalPPO, roll_out_steps=ROLL_OUT_STEP, wid=i, game='Pendulum-v0',
-        ep_len=EP_LEN, rolling_event=threading.Event()) for i in range(N_WORKER)]
-
-    GLOBAL_EP = 0
+    GLOBAL_PPO = PPO()
+    UPDATE_EVENT, ROLLING_EVENT = threading.Event(), threading.Event()
+    UPDATE_EVENT.clear()            # no update now
+    ROLLING_EVENT.set()             # start to roll out
+    workers = [Worker(wid=i) for i in range(N_WORKER)]
+
+    GLOBAL_UPDATE_COUNTER, GLOBAL_EP = 0, 0
     GLOBAL_RUNNING_R = []
     COORD = tf.train.Coordinator()
-    QUEUE = Queue(maxsize=N_WORKER)
+    QUEUE = queue.Queue()
     threads = []
     for worker in workers:  # worker threads
-        t = threading.Thread(target=worker.work, args=(COORD, QUEUE))
+        t = threading.Thread(target=worker.work, args=())
         t.start()
         threads.append(t)
-    # update thread for network
-    threads.append(threading.Thread(target=globalPPO.update, args=(COORD, QUEUE, [w.rolling_event for w in workers])))
+    # add a PPO updating thread
+    threads.append(threading.Thread(target=GLOBAL_PPO.update,))
    threads[-1].start()
     COORD.join(threads)
 
-    # plot reward change
+    # plot reward change and testing
     plt.plot(np.arange(len(GLOBAL_RUNNING_R)), GLOBAL_RUNNING_R)
     plt.xlabel('Episode'); plt.ylabel('Moving reward'); plt.ion(); plt.show()
-
-    env = gym.make('Pendulum-v0')  # testing
+    env = gym.make('Pendulum-v0')
     while True:
         s = env.reset()
         for t in range(400):
             env.render()
-            a = globalPPO.choose_action(s)
-            s = env.step(a)[0]
+            s = env.step(GLOBAL_PPO.choose_action(s))[0]
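
The "Moving reward" curve plotted at the end is the exponential moving average kept in GLOBAL_RUNNING_R above (0.9 x previous value + 0.1 x new episode return). A tiny sketch of that bookkeeping with made-up episode returns; the helper name is illustrative.

    def record_running_reward(history, ep_r, alpha=0.1):
        # exponential moving average of episode returns, as in GLOBAL_RUNNING_R above
        if not history:
            history.append(ep_r)
        else:
            history.append(history[-1] * (1 - alpha) + ep_r * alpha)

    running = []
    for ep_r in [-1500., -1200., -900., -400., -200.]:   # made-up Pendulum episode returns
        record_running_reward(running, ep_r)
    print(['%.1f' % r for r in running])
    # ['-1500.0', '-1470.0', '-1413.0', '-1311.7', '-1200.5']
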
