Commit b66fb13

ruiliLaMeilleure committed:
Relocated the file and fixed issue in ending the training phase
1 parent 7d4b120 commit b66fb13

File tree

1 file changed: +12 -10 lines

discrete_dppo.py renamed to contents/12_Proximal_Policy_Optimization/discrete_DPP0.py

Lines changed: 12 additions & 10 deletions
@@ -18,8 +18,9 @@
 import numpy as np
 import matplotlib.pyplot as plt
 import gym, threading, queue
+import time

-EP_MAX = 4000
+EP_MAX = 1000
 EP_LEN = 500
 N_WORKER = 4  # parallel workers
 GAMMA = 0.9  # reward discount factor
@@ -171,14 +172,15 @@ def work(self):

                     QUEUE.put(q_in)

-                    if GLOBAL_UPDATE_COUNTER >= MIN_BATCH_SIZE or done:
+                    if GLOBAL_UPDATE_COUNTER >= MIN_BATCH_SIZE:
                         ROLLING_EVENT.clear()  # stop collecting data
                         UPDATE_EVENT.set()  # globalPPO update
-                        break

                     if GLOBAL_EP >= EP_MAX:  # stop training
                         COORD.request_stop()
                         break
+
+                    if done:break

             # record reward changes, plot later
             if len(GLOBAL_RUNNING_R) == 0: GLOBAL_RUNNING_R.append(ep_r)
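This hunk is the fix named in the commit message: reaching the batch size no longer breaks the worker out of its step loop. The worker now only clears ROLLING_EVENT and sets UPDATE_EVENT so the global PPO can update, and the episode ends only when the environment reports done (or when EP_MAX episodes have been collected). Below is a minimal, self-contained sketch of that event handshake using toy transitions instead of PPO rollouts; the worker and updater bodies here are simplified assumptions, not the repository's code, but the signalling mirrors the lines above.

# Minimal sketch (toy data, no TensorFlow) of the worker/updater handshake,
# assuming simplified stand-in bodies for the real rollout and PPO update.
import threading, queue

N_WORKER, MIN_BATCH_SIZE, EP_LEN = 4, 8, 20
UPDATE_EVENT, ROLLING_EVENT = threading.Event(), threading.Event()
ROLLING_EVENT.set()                          # workers may collect data
QUEUE = queue.Queue()
LOCK = threading.Lock()
GLOBAL_UPDATE_COUNTER = 0

def worker(wid):
    global GLOBAL_UPDATE_COUNTER
    for t in range(EP_LEN):                  # one toy "episode"
        ROLLING_EVENT.wait()                 # block while the updater runs
        QUEUE.put((wid, t))                  # hand one transition over
        with LOCK:
            GLOBAL_UPDATE_COUNTER += 1
            if GLOBAL_UPDATE_COUNTER >= MIN_BATCH_SIZE:
                ROLLING_EVENT.clear()        # stop collecting data
                UPDATE_EVENT.set()           # wake the updater
        # no break here: the worker keeps stepping until the episode ends

def updater():
    global GLOBAL_UPDATE_COUNTER
    while True:
        UPDATE_EVENT.wait()                  # wait until a batch is ready
        batch = [QUEUE.get() for _ in range(QUEUE.qsize())]
        print("update on %d samples" % len(batch))
        with LOCK:
            GLOBAL_UPDATE_COUNTER = 0
        UPDATE_EVENT.clear()
        ROLLING_EVENT.set()                  # let workers roll out again

threading.Thread(target=updater, daemon=True).start()
workers = [threading.Thread(target=worker, args=(i,)) for i in range(N_WORKER)]
for w in workers: w.start()
for w in workers: w.join()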
@@ -187,12 +189,6 @@ def work(self):
             print("EP", GLOBAL_EP,'|W%i' % self.wid, '|step %i' %t, '|Ep_r: %.2f' % ep_r,)
             np.save("Global_return",GLOBAL_RUNNING_R)
             np.savez("PI_PARA",self.ppo.sess.run(GLOBAL_PPO.pi_params))
-            # np.savez("tfa",self.ppo.sess.run(GLOBAL_PPO.tfa))
-            # np.savez("tfadv",self.ppo.sess.run(GLOBAL_PPO.tfadv))
-            # np.savez("val1",self.ppo.sess.run(GLOBAL_PPO.val1))
-            # np.savez("val2",self.ppo.sess.run(GLOBAL_PPO.val2))
-            # print self.ppo.sess.run(GLOBAL_PPO.val2)
-


 if __name__ == '__main__':
@@ -202,6 +198,8 @@ def work(self):
     ROLLING_EVENT.set()  # start to roll out
     workers = [Worker(wid=i) for i in range(N_WORKER)]

+    start = time.time()
+
     GLOBAL_UPDATE_COUNTER, GLOBAL_EP = 0, 0
     GLOBAL_RUNNING_R = []
     COORD = tf.train.Coordinator()
@@ -216,6 +214,9 @@ def work(self):
     threads[-1].start()
     COORD.join(threads)

+    end = time.time()
+    print("Total time ", (end - start))
+
     # plot reward change and test
     plt.plot(np.arange(len(GLOBAL_RUNNING_R)), GLOBAL_RUNNING_R)
     plt.xlabel('Episode'); plt.ylabel('Moving reward'); plt.ion(); plt.show()
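The start/end pair added in the two hunks above simply brackets the blocking COORD.join(threads) call with wall-clock timestamps so runs with different N_WORKER settings can be compared. A tiny stand-alone illustration of the same pattern, with time.sleep standing in for the training run:

import time

start = time.time()
time.sleep(0.5)                      # placeholder for the training run / COORD.join(threads)
end = time.time()
print("Total time ", (end - start))  # wall-clock seconds, not CPU time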
@@ -224,6 +225,7 @@ def work(self):
     s = env.reset()
     for t in range(1000):
         env.render()
-        s = env.step(GLOBAL_PPO.choose_action(s))[0]
+        s, r, done, info = env.step(GLOBAL_PPO.choose_action(s))
         if done:
             break
+
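The last hunk fixes the test rollout after training: the old line kept only the observation from env.step(), so done was never reassigned and the loop could not detect episode termination. The new line unpacks the full 4-tuple returned by the classic gym API. A stand-alone sketch of the same evaluation loop, with a random policy standing in for GLOBAL_PPO and CartPole-v0 assumed as the environment (the script's actual env id is not shown in this diff):

import gym

env = gym.make('CartPole-v0')        # assumed env id; the script builds its own env
s = env.reset()
for t in range(1000):
    env.render()
    # classic gym returns (observation, reward, done, info)
    s, r, done, info = env.step(env.action_space.sample())  # GLOBAL_PPO.choose_action(s) in the script
    if done:
        break
env.close()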