18
18
import numpy as np
19
19
import matplotlib .pyplot as plt
20
20
import gym , threading , queue
21
+ import time
21
22
22
- EP_MAX = 4000
23
+ EP_MAX = 1000
23
24
EP_LEN = 500
24
25
N_WORKER = 4 # parallel workers
25
26
GAMMA = 0.9 # reward discount factor
@@ -171,14 +172,15 @@ def work(self):
171
172
172
173
QUEUE .put (q_in )
173
174
174
- if GLOBAL_UPDATE_COUNTER >= MIN_BATCH_SIZE or done :
175
+ if GLOBAL_UPDATE_COUNTER >= MIN_BATCH_SIZE :
175
176
ROLLING_EVENT .clear () # stop collecting data
176
177
UPDATE_EVENT .set () # globalPPO update
177
- break
178
178
179
179
if GLOBAL_EP >= EP_MAX : # stop training
180
180
COORD .request_stop ()
181
181
break
182
+
183
+ if done :break
182
184
183
185
# record reward changes, plot later
184
186
if len (GLOBAL_RUNNING_R ) == 0 : GLOBAL_RUNNING_R .append (ep_r )
@@ -187,12 +189,6 @@ def work(self):
187
189
print ("EP" , GLOBAL_EP ,'|W%i' % self .wid , '|step %i' % t , '|Ep_r: %.2f' % ep_r ,)
188
190
np .save ("Global_return" ,GLOBAL_RUNNING_R )
189
191
np .savez ("PI_PARA" ,self .ppo .sess .run (GLOBAL_PPO .pi_params ))
190
- # np.savez("tfa",self.ppo.sess.run(GLOBAL_PPO.tfa))
191
- # np.savez("tfadv",self.ppo.sess.run(GLOBAL_PPO.tfadv))
192
- # np.savez("val1",self.ppo.sess.run(GLOBAL_PPO.val1))
193
- # np.savez("val2",self.ppo.sess.run(GLOBAL_PPO.val2))
194
- # print self.ppo.sess.run(GLOBAL_PPO.val2)
195
-
196
192
197
193
198
194
if __name__ == '__main__' :
@@ -202,6 +198,8 @@ def work(self):
202
198
ROLLING_EVENT .set () # start to roll out
203
199
workers = [Worker (wid = i ) for i in range (N_WORKER )]
204
200
201
+ start = time .time ()
202
+
205
203
GLOBAL_UPDATE_COUNTER , GLOBAL_EP = 0 , 0
206
204
GLOBAL_RUNNING_R = []
207
205
COORD = tf .train .Coordinator ()
@@ -216,6 +214,9 @@ def work(self):
216
214
threads [- 1 ].start ()
217
215
COORD .join (threads )
218
216
217
+ end = time .time ()
218
+ print "Total time " , (end - start )
219
+
219
220
# plot reward change and test
220
221
plt .plot (np .arange (len (GLOBAL_RUNNING_R )), GLOBAL_RUNNING_R )
221
222
plt .xlabel ('Episode' ); plt .ylabel ('Moving reward' ); plt .ion (); plt .show ()
@@ -224,6 +225,7 @@ def work(self):
224
225
s = env .reset ()
225
226
for t in range (1000 ):
226
227
env .render ()
227
- s = env .step (GLOBAL_PPO .choose_action (s ))[ 0 ]
228
+ s , r , done , info = env .step (GLOBAL_PPO .choose_action (s ))
228
229
if done :
229
230
break
231
+
0 commit comments