Commit a3ee123

Merge remote-tracking branch 'origin/master'
2 parents f71dc71 + 4abdde1

File tree: 2 files changed, +200 -1 lines

contents/12_Proximal_Policy_Optimization/DPPO.py

Lines changed: 1 addition & 1 deletion
@@ -124,7 +124,7 @@ def work(self):
                     s = s_
                     ep_r += r
 
-                    GLOBAL_UPDATE_COUNTER += 1              # count to minimum batch size, no need to wait other workers
+                    GLOBAL_UPDATE_COUNTER += 1              # count to minimum batch size, no need to wait other workers
                     if t == EP_LEN - 1 or GLOBAL_UPDATE_COUNTER >= MIN_BATCH_SIZE:
                         v_s_ = self.ppo.get_v(s_)
                         discounted_r = []                   # compute discounted reward

Lines changed: 199 additions & 0 deletions
@@ -0,0 +1,199 @@
"""
A simple version of OpenAI's Proximal Policy Optimization (PPO). [https://arxiv.org/abs/1707.06347]

Workers are distributed in parallel to collect data; their roll-outs are then stopped and PPO is trained on the collected data.
Workers restart once PPO is updated.

The global PPO updating rule is adopted from DeepMind's paper (DPPO):
Emergence of Locomotion Behaviours in Rich Environments (Google DeepMind): [https://arxiv.org/abs/1707.02286]

View more on my tutorial website: https://morvanzhou.github.io/tutorials

Dependencies:
tensorflow 1.8.0
gym 0.9.2
"""

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import gym, threading, queue

EP_MAX = 1000               # maximum number of training episodes
EP_LEN = 500                # maximum steps per episode
N_WORKER = 4                # parallel workers
GAMMA = 0.9                 # reward discount factor
A_LR = 0.0001               # learning rate for actor
C_LR = 0.0001               # learning rate for critic
MIN_BATCH_SIZE = 64         # minimum batch size for updating PPO
UPDATE_STEP = 15            # loop update operation n-steps
EPSILON = 0.2               # for clipping the surrogate objective
GAME = 'CartPole-v0'

env = gym.make(GAME)
S_DIM = env.observation_space.shape[0]
A_DIM = env.action_space.n


class PPONet(object):
    def __init__(self):
        self.sess = tf.Session()
        self.tfs = tf.placeholder(tf.float32, [None, S_DIM], 'state')

        # critic
        w_init = tf.random_normal_initializer(0., .1)
        lc = tf.layers.dense(self.tfs, 200, tf.nn.relu, kernel_initializer=w_init, name='lc')
        self.v = tf.layers.dense(lc, 1)
        self.tfdc_r = tf.placeholder(tf.float32, [None, 1], 'discounted_r')
        self.advantage = self.tfdc_r - self.v
        self.closs = tf.reduce_mean(tf.square(self.advantage))
        self.ctrain_op = tf.train.AdamOptimizer(C_LR).minimize(self.closs)

        # actor
        self.pi, pi_params = self._build_anet('pi', trainable=True)
        oldpi, oldpi_params = self._build_anet('oldpi', trainable=False)

        self.update_oldpi_op = [oldp.assign(p) for p, oldp in zip(pi_params, oldpi_params)]

        self.tfa = tf.placeholder(tf.int32, [None, ], 'action')
        self.tfadv = tf.placeholder(tf.float32, [None, 1], 'advantage')

        a_indices = tf.stack([tf.range(tf.shape(self.tfa)[0], dtype=tf.int32), self.tfa], axis=1)
        pi_prob = tf.gather_nd(params=self.pi, indices=a_indices)       # shape=(None,)
        oldpi_prob = tf.gather_nd(params=oldpi, indices=a_indices)      # shape=(None,)
        ratio = pi_prob / oldpi_prob
        surr = ratio * self.tfadv                                       # surrogate objective

        self.aloss = -tf.reduce_mean(tf.minimum(                        # clipped surrogate objective
            surr,
            tf.clip_by_value(ratio, 1. - EPSILON, 1. + EPSILON) * self.tfadv))
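        # The tf.minimum above implements the clipped surrogate objective from the
        # PPO paper linked in the header:
        #     L_CLIP = E[ min(ratio * A, clip(ratio, 1 - EPSILON, 1 + EPSILON) * A) ]
        # where ratio = pi(a|s) / pi_old(a|s) and A is the advantage estimate;
        # the negative mean turns the maximization objective into a loss to minimize.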

        self.atrain_op = tf.train.AdamOptimizer(A_LR).minimize(self.aloss)
        self.sess.run(tf.global_variables_initializer())

    def update(self):
        global GLOBAL_UPDATE_COUNTER
        while not COORD.should_stop():
            if GLOBAL_EP < EP_MAX:
                UPDATE_EVENT.wait()                     # wait until a batch of data has been collected
                self.sess.run(self.update_oldpi_op)     # copy pi to oldpi
                data = [QUEUE.get() for _ in range(QUEUE.qsize())]      # collect data from all workers
                data = np.vstack(data)
                s, a, r = data[:, :S_DIM], data[:, S_DIM: S_DIM + 1].ravel(), data[:, -1:]
                adv = self.sess.run(self.advantage, {self.tfs: s, self.tfdc_r: r})
                # update actor and critic in an update loop
                [self.sess.run(self.atrain_op, {self.tfs: s, self.tfa: a, self.tfadv: adv}) for _ in range(UPDATE_STEP)]
                [self.sess.run(self.ctrain_op, {self.tfs: s, self.tfdc_r: r}) for _ in range(UPDATE_STEP)]
                UPDATE_EVENT.clear()        # updating finished
                GLOBAL_UPDATE_COUNTER = 0   # reset counter
                ROLLING_EVENT.set()         # roll-out available again

    def _build_anet(self, name, trainable):
        with tf.variable_scope(name):
            l_a = tf.layers.dense(self.tfs, 200, tf.nn.relu, trainable=trainable)
            a_prob = tf.layers.dense(l_a, A_DIM, tf.nn.softmax, trainable=trainable)
        params = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=name)
        return a_prob, params

    def choose_action(self, s):     # run by a local worker
        prob_weights = self.sess.run(self.pi, feed_dict={self.tfs: s[None, :]})
        action = np.random.choice(range(prob_weights.shape[1]),
                                  p=prob_weights.ravel())   # sample an action w.r.t. the action probabilities
        return action

    def get_v(self, s):
        if s.ndim < 2: s = s[np.newaxis, :]
        return self.sess.run(self.v, {self.tfs: s})[0, 0]


class Worker(object):
    def __init__(self, wid):
        self.wid = wid
        self.env = gym.make(GAME).unwrapped
        self.ppo = GLOBAL_PPO

    def work(self):
        global GLOBAL_EP, GLOBAL_RUNNING_R, GLOBAL_UPDATE_COUNTER
        while not COORD.should_stop():
            s = self.env.reset()
            ep_r = 0
            buffer_s, buffer_a, buffer_r = [], [], []
            for t in range(EP_LEN):
                if not ROLLING_EVENT.is_set():                  # while the global PPO is updating
                    ROLLING_EVENT.wait()                        # wait until PPO has been updated
                    buffer_s, buffer_a, buffer_r = [], [], []   # clear the history buffer; collect data with the new policy
                a = self.ppo.choose_action(s)
                s_, r, done, _ = self.env.step(a)
                if done: r = -10
                buffer_s.append(s)
                buffer_a.append(a)
                buffer_r.append(r - 1)                  # 0 while the pole stays up, -11 when it falls. Reward engineering
                s = s_
                ep_r += r

                GLOBAL_UPDATE_COUNTER += 1              # count toward the minimum batch size; no need to wait for other workers
                if t == EP_LEN - 1 or GLOBAL_UPDATE_COUNTER >= MIN_BATCH_SIZE or done:
                    if done:
                        v_s_ = 0                        # end of episode
                    else:
                        v_s_ = self.ppo.get_v(s_)

                    discounted_r = []                   # compute discounted reward
                    for r in buffer_r[::-1]:
                        v_s_ = r + GAMMA * v_s_
                        discounted_r.append(v_s_)
                    discounted_r.reverse()
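                    # The loop above forms bootstrapped discounted returns backwards
                    # through the buffer, R_t = r_t + GAMMA * R_{t+1}, seeded with the
                    # critic's estimate v(s_) (or 0 at the end of an episode), so each
                    # critic target already includes the value of the truncated tail.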

                    bs, ba, br = np.vstack(buffer_s), np.vstack(buffer_a), np.array(discounted_r)[:, None]
                    buffer_s, buffer_a, buffer_r = [], [], []
                    QUEUE.put(np.hstack((bs, ba, br)))          # put the data in the queue
                    if GLOBAL_UPDATE_COUNTER >= MIN_BATCH_SIZE:
                        ROLLING_EVENT.clear()       # stop collecting data
                        UPDATE_EVENT.set()          # trigger the global PPO update

                    if GLOBAL_EP >= EP_MAX:         # stop training
                        COORD.request_stop()
                        break

                if done: break

            # record reward changes, plot later
            if len(GLOBAL_RUNNING_R) == 0: GLOBAL_RUNNING_R.append(ep_r)
            else: GLOBAL_RUNNING_R.append(GLOBAL_RUNNING_R[-1] * 0.9 + ep_r * 0.1)
            GLOBAL_EP += 1
            print('{0:.1f}%'.format(GLOBAL_EP / EP_MAX * 100), '|W%i' % self.wid, '|Ep_r: %.2f' % ep_r)


if __name__ == '__main__':
    GLOBAL_PPO = PPONet()
    UPDATE_EVENT, ROLLING_EVENT = threading.Event(), threading.Event()
    UPDATE_EVENT.clear()            # no update at the start
    ROLLING_EVENT.set()             # start rolling out
    workers = [Worker(wid=i) for i in range(N_WORKER)]

    GLOBAL_UPDATE_COUNTER, GLOBAL_EP = 0, 0
    GLOBAL_RUNNING_R = []
    COORD = tf.train.Coordinator()
    QUEUE = queue.Queue()           # workers put collected data in this queue
    threads = []
    for worker in workers:          # worker threads
        t = threading.Thread(target=worker.work, args=())
        t.start()                   # start training
        threads.append(t)
    # add a PPO updating thread
    threads.append(threading.Thread(target=GLOBAL_PPO.update,))
    threads[-1].start()
    COORD.join(threads)

    # plot the reward change, then test
    plt.plot(np.arange(len(GLOBAL_RUNNING_R)), GLOBAL_RUNNING_R)
    plt.xlabel('Episode'); plt.ylabel('Moving reward'); plt.ion(); plt.show()
    env = gym.make('CartPole-v0')
    while True:
        s = env.reset()
        for t in range(1000):
            env.render()
            s, r, done, info = env.step(GLOBAL_PPO.choose_action(s))
            if done:
                break
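The synchronization scheme is the part that is easiest to miss when reading the file top to bottom: ROLLING_EVENT gates the workers, UPDATE_EVENT gates the updater, GLOBAL_UPDATE_COUNTER decides when control is handed over, and QUEUE carries the collected transitions. The stripped-down sketch below shows the same handshake using only the standard library; the names (rollout_worker, updater, MIN_BATCH, DATA_QUEUE) and the dummy data are illustrative stand-ins, not part of the committed file.

# Minimal, self-contained sketch of the worker/updater handshake used above.
# Names (rollout_worker, updater, MIN_BATCH, DATA_QUEUE) are illustrative only.
import threading, queue, random, time

MIN_BATCH = 8
UPDATE_EVENT, ROLLING_EVENT = threading.Event(), threading.Event()
ROLLING_EVENT.set()                          # workers may roll out at the start
DATA_QUEUE = queue.Queue()
COUNTER_LOCK = threading.Lock()
counter = 0
stop = threading.Event()

def rollout_worker(wid):
    global counter
    while not stop.is_set():
        ROLLING_EVENT.wait()                     # blocked here while the updater runs
        DATA_QUEUE.put((wid, random.random()))   # stand-in for one transition
        with COUNTER_LOCK:
            counter += 1
            if counter >= MIN_BATCH:             # enough data: freeze roll-outs, wake the updater
                ROLLING_EVENT.clear()
                UPDATE_EVENT.set()
        time.sleep(0.01)

def updater(n_updates=5):
    global counter
    for i in range(n_updates):
        UPDATE_EVENT.wait()                      # wait for a full batch
        batch = [DATA_QUEUE.get() for _ in range(DATA_QUEUE.qsize())]
        print('update %d on %d samples' % (i, len(batch)))
        with COUNTER_LOCK:
            counter = 0                          # reset the shared counter
        UPDATE_EVENT.clear()                     # updating finished
        ROLLING_EVENT.set()                      # let the workers roll out again
    stop.set()
    ROLLING_EVENT.set()                          # release any worker still waiting so it can exit

threads = [threading.Thread(target=rollout_worker, args=(i,)) for i in range(4)]
threads.append(threading.Thread(target=updater))
for t in threads: t.start()
for t in threads: t.join()

In the committed script the same two roles are played by Worker.work and PPONet.update, with tf.train.Coordinator handling shutdown instead of the stop event used in this sketch.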
