 """
 A simple version of OpenAI's Proximal Policy Optimization (PPO). [http://adsabs.harvard.edu/abs/2017arXiv170706347S]
+
 Distribute workers in parallel to collect data, then stop the workers' roll-outs and train PPO on the collected data.
-Restart workers once PPO is updated. I think A3C may be faster than this version of PPO, because this PPO has to stop
-parallel data collection for training.
+Restart workers once PPO is updated.
+
+The global PPO updating rule is adopted from DeepMind's paper (DPPO):
+Emergence of Locomotion Behaviours in Rich Environments (Google DeepMind): [http://adsabs.harvard.edu/abs/2017arXiv170702286H]
 
 View more on my tutorial website: https://morvanzhou.github.io/tutorials
 
@@ -15,28 +18,26 @@
 from tensorflow.contrib.distributions import Normal
 import numpy as np
 import matplotlib.pyplot as plt
-import gym, threading
-from queue import Queue
+import gym, threading, queue
 
-EP_MAX = 600
+EP_MAX = 1000
 EP_LEN = 200
-N_WORKER = 3
-GAMMA = 0.9
-A_LR = 0.0001
-C_LR = 0.0002
-ROLL_OUT_STEP = 32
-UPDATE_STEP = 10
-EPSILON = 0.2               # Clipped surrogate objective
-S_DIM, A_DIM = 3, 1
+N_WORKER = 4                # parallel workers
+GAMMA = 0.9                 # reward discount factor
+A_LR = 0.0001               # learning rate for actor
+C_LR = 0.001                # learning rate for critic
+MIN_BATCH_SIZE = 64         # minimum batch size for updating PPO
+UPDATE_STEP = 5             # loop update operation n-steps
+EPSILON = 0.2               # for clipping the surrogate objective
+GAME = 'Pendulum-v0'
+S_DIM, A_DIM = 3, 1         # state and action dimension
 
 
 class PPO(object):
-    def __init__(self, s_dim, a_dim,):
-        self.a_dim = a_dim
-        self.s_dim = s_dim
+    def __init__(self):
         self.sess = tf.Session()
 
-        self.tfs = tf.placeholder(tf.float32, [None, s_dim], 'state')
+        self.tfs = tf.placeholder(tf.float32, [None, S_DIM], 'state')
 
         # critic
         l1 = tf.layers.dense(self.tfs, 100, tf.nn.relu)
@@ -52,7 +53,7 @@ def __init__(self, s_dim, a_dim,):
         self.sample_op = tf.squeeze(pi.sample(1), axis=0)       # choosing action
         self.update_oldpi_op = [oldp.assign(p) for p, oldp in zip(pi_params, oldpi_params)]
 
-        self.tfa = tf.placeholder(tf.float32, [None, a_dim], 'action')
+        self.tfa = tf.placeholder(tf.float32, [None, A_DIM], 'action')
         self.tfadv = tf.placeholder(tf.float32, [None, 1], 'advantage')
         # ratio = tf.exp(pi.log_prob(self.tfa) - oldpi.log_prob(self.tfa))
         ratio = pi.prob(self.tfa) / (oldpi.prob(self.tfa) + 1e-5)
@@ -65,25 +66,27 @@ def __init__(self, s_dim, a_dim,):
         self.atrain_op = tf.train.AdamOptimizer(A_LR).minimize(self.aloss)
         self.sess.run(tf.global_variables_initializer())
 
-    def update(self, coord, queue, rolling_events):
-        while not coord.should_stop():
-            if queue.full():
+    def update(self):
+        global GLOBAL_UPDATE_COUNTER
+        while not COORD.should_stop():
+            if GLOBAL_EP < EP_MAX:
+                UPDATE_EVENT.wait()                     # wait until a batch of data has been collected
                 self.sess.run(self.update_oldpi_op)     # copy pi to oldpi
-
-                data = [queue.get() for _ in range(queue.qsize())]
+                data = [QUEUE.get() for _ in range(QUEUE.qsize())]
                 data = np.vstack(data)
-                s, a, r = data[:, :self.s_dim], data[:, self.s_dim: self.s_dim + self.a_dim], data[:, -1:]
+                s, a, r = data[:, :S_DIM], data[:, S_DIM: S_DIM + A_DIM], data[:, -1:]
                 adv = self.sess.run(self.advantage, {self.tfs: s, self.tfdc_r: r})
                 [self.sess.run(self.atrain_op, {self.tfs: s, self.tfa: a, self.tfadv: adv}) for _ in range(UPDATE_STEP)]
                 [self.sess.run(self.ctrain_op, {self.tfs: s, self.tfdc_r: r}) for _ in range(UPDATE_STEP)]
-
-                [re.set() for re in rolling_events]     # set roll-out available
+                UPDATE_EVENT.clear()                    # updating finished
+                GLOBAL_UPDATE_COUNTER = 0               # reset counter
+                ROLLING_EVENT.set()                     # set roll-out available
 
     def _build_anet(self, name, trainable):
         with tf.variable_scope(name):
             l1 = tf.layers.dense(self.tfs, 200, tf.nn.relu, trainable=trainable)
-            mu = 2 * tf.layers.dense(l1, self.a_dim, tf.nn.tanh, trainable=trainable)
-            sigma = tf.layers.dense(l1, self.a_dim, tf.nn.softplus, trainable=trainable)
+            mu = 2 * tf.layers.dense(l1, A_DIM, tf.nn.tanh, trainable=trainable)
+            sigma = tf.layers.dense(l1, A_DIM, tf.nn.softplus, trainable=trainable)
             norm_dist = Normal(loc=mu, scale=sigma)
         params = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=name)
         return norm_dist, params
@@ -99,83 +102,83 @@ def get_v(self, s):
 
 
 class Worker(object):
-    def __init__(self, globalPPO, roll_out_steps, wid, game, ep_len, rolling_event):
-        self.roll_out_steps = roll_out_steps
+    def __init__(self, wid):
         self.wid = wid
-        self.ep_len = ep_len
-        self.rolling_event = rolling_event
-        self.env = gym.make(game).unwrapped
-        self.ppo = globalPPO
-
-    def work(self, coord, queue,):
-        global GLOBAL_EP, GLOBAL_RUNNING_R
-        while not coord.should_stop():
+        self.env = gym.make(GAME).unwrapped
+        self.ppo = GLOBAL_PPO
+
+    def work(self):
+        global GLOBAL_EP, GLOBAL_RUNNING_R, GLOBAL_UPDATE_COUNTER
+        while not COORD.should_stop():
             s = self.env.reset()
             ep_r = 0
             buffer_s, buffer_a, buffer_r = [], [], []
-            for t in range(self.ep_len):
+            for t in range(EP_LEN):
+                if not ROLLING_EVENT.is_set():                  # while the global PPO is updating
+                    ROLLING_EVENT.wait()                        # wait until PPO is updated
+                    buffer_s, buffer_a, buffer_r = [], [], []   # clear the history buffer, collect data with the new policy
                 a = self.ppo.choose_action(s)
                 s_, r, done, _ = self.env.step(a)
                 buffer_s.append(s)
                 buffer_a.append(a)
-                buffer_r.append((r + 8) / 8)                    # normalize reward, found to be useful
+                buffer_r.append((r + 8) / 8)                    # normalize reward, found to be useful
                 s = s_
                 ep_r += r
 
-                # get update buffer
-                if (t + 1) % self.roll_out_steps == 0 or t == self.ep_len - 1:
+                GLOBAL_UPDATE_COUNTER += 1                      # count towards the minimum batch size
+                if t == EP_LEN - 1 or GLOBAL_UPDATE_COUNTER >= MIN_BATCH_SIZE:
                     v_s_ = self.ppo.get_v(s_)
-                    discounted_r = []                           # compute discounted reward
+                    discounted_r = []                           # compute discounted reward
                     for r in buffer_r[::-1]:
                         v_s_ = r + GAMMA * v_s_
                         discounted_r.append(v_s_)
                     discounted_r.reverse()
 
                     bs, ba, br = np.vstack(buffer_s), np.vstack(buffer_a), np.array(discounted_r)[:, np.newaxis]
                     buffer_s, buffer_a, buffer_r = [], [], []
-                    queue.put(np.hstack((bs, ba, br)))
-                    if GLOBAL_EP >= EP_MAX:                     # stop training
-                        coord.request_stop()
+                    QUEUE.put(np.hstack((bs, ba, br)))
+                    if GLOBAL_UPDATE_COUNTER >= MIN_BATCH_SIZE:
+                        ROLLING_EVENT.clear()                   # stop collecting data
+                        UPDATE_EVENT.set()                      # trigger the global PPO update
+
+                    if GLOBAL_EP >= EP_MAX:                     # stop training
+                        COORD.request_stop()
                         break
-                    else:
-                        self.rolling_event.clear()              # stop roll-out
-                        self.rolling_event.wait()               # stop and wait until the network is updated
 
             # record reward changes, plot later
             if len(GLOBAL_RUNNING_R) == 0: GLOBAL_RUNNING_R.append(ep_r)
             else: GLOBAL_RUNNING_R.append(GLOBAL_RUNNING_R[-1]*0.9 + ep_r*0.1)
             GLOBAL_EP += 1
-            print('W%i' % self.wid, '|Ep: %i' % GLOBAL_EP, '|Ep_r: %.2f' % ep_r,)
+            print('{0:.1f}%'.format(GLOBAL_EP/EP_MAX*100), '|W %i' % self.wid, '|Ep_r: %.2f' % ep_r,)
 
 
 if __name__ == '__main__':
-    globalPPO = PPO(S_DIM, A_DIM)
-    workers = [Worker(
-        globalPPO=globalPPO, roll_out_steps=ROLL_OUT_STEP, wid=i, game='Pendulum-v0',
-        ep_len=EP_LEN, rolling_event=threading.Event()) for i in range(N_WORKER)]
-
-    GLOBAL_EP = 0
+    GLOBAL_PPO = PPO()
+    UPDATE_EVENT, ROLLING_EVENT = threading.Event(), threading.Event()
+    UPDATE_EVENT.clear()            # no update at the start
+    ROLLING_EVENT.set()             # start to roll out
+    workers = [Worker(wid=i) for i in range(N_WORKER)]
+
+    GLOBAL_UPDATE_COUNTER, GLOBAL_EP = 0, 0
     GLOBAL_RUNNING_R = []
     COORD = tf.train.Coordinator()
-    QUEUE = Queue(maxsize=N_WORKER)
+    QUEUE = queue.Queue()
     threads = []
     for worker in workers:          # worker threads
-        t = threading.Thread(target=worker.work, args=(COORD, QUEUE))
+        t = threading.Thread(target=worker.work, args=())
         t.start()
         threads.append(t)
-    # update thread for network
-    threads.append(threading.Thread(target=globalPPO.update, args=(COORD, QUEUE, [w.rolling_event for w in workers])))
+    # add a PPO updating thread
+    threads.append(threading.Thread(target=GLOBAL_PPO.update,))
     threads[-1].start()
     COORD.join(threads)
 
-    # plot reward change
+    # plot reward change and testing
     plt.plot(np.arange(len(GLOBAL_RUNNING_R)), GLOBAL_RUNNING_R)
     plt.xlabel('Episode'); plt.ylabel('Moving reward'); plt.ion(); plt.show()
-
-    env = gym.make('Pendulum-v0')   # testing
+    env = gym.make('Pendulum-v0')
     while True:
         s = env.reset()
         for t in range(400):
             env.render()
-            a = globalPPO.choose_action(s)
-            s = env.step(a)[0]
+            s = env.step(GLOBAL_PPO.choose_action(s))[0]
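
The worker/updater handshake introduced in this commit (workers roll out while ROLLING_EVENT is set, pause once MIN_BATCH_SIZE samples are queued, and resume after the global PPO update drains the queue) can be studied on its own. The following minimal sketch shows just that coordination pattern with the environment and TensorFlow parts stripped out; the names and the fake "update" step are illustrative and are not part of this repository.

# Minimal sketch of the DPPO-style coordination used above: workers collect
# samples while ROLLING_EVENT is set; once enough samples are queued they wake
# the updater via UPDATE_EVENT and pause; the updater drains the queue and
# lets them roll out again. Names and the fake update are illustrative only.
import threading, queue, time

N_WORKER = 2
MIN_BATCH_SIZE = 8
N_UPDATES = 3                       # run a few fake updates, then stop

QUEUE = queue.Queue()
UPDATE_EVENT, ROLLING_EVENT = threading.Event(), threading.Event()
ROLLING_EVENT.set()                 # workers may roll out
STOP = threading.Event()
LOCK = threading.Lock()
counter = 0                         # samples collected since the last update

def worker(wid):
    global counter
    while not STOP.is_set():
        ROLLING_EVENT.wait()        # blocks here while the updater is running
        QUEUE.put('sample from worker %i' % wid)
        with LOCK:
            counter += 1
            if counter >= MIN_BATCH_SIZE:
                ROLLING_EVENT.clear()   # stop data collection
                UPDATE_EVENT.set()      # wake the updater
        time.sleep(0.01)

def updater():
    global counter
    for _ in range(N_UPDATES):
        UPDATE_EVENT.wait()         # wait until a batch is ready
        batch = [QUEUE.get() for _ in range(QUEUE.qsize())]
        print('updating on %i samples' % len(batch))
        with LOCK:
            counter = 0
        UPDATE_EVENT.clear()
        ROLLING_EVENT.set()         # let the workers roll out again
    STOP.set()
    ROLLING_EVENT.set()             # release any workers still waiting

threads = [threading.Thread(target=worker, args=(i,)) for i in range(N_WORKER)]
threads.append(threading.Thread(target=updater))
[t.start() for t in threads]
[t.join() for t in threads]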
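
For reference, EPSILON above parameterises PPO's clipped surrogate objective, mean(min(ratio * adv, clip(ratio, 1 - EPSILON, 1 + EPSILON) * adv)); the aloss line that builds it falls outside the hunks shown in this diff. Below is a small NumPy sketch of that objective, written for illustration rather than taken from the repository.

# Illustrative NumPy version of the clipped surrogate objective (to be maximised).
# In the file above, `ratio` comes from pi.prob / oldpi.prob and `adv` from the
# critic's advantage estimate; the arrays below are made-up toy values.
import numpy as np

EPSILON = 0.2

def clipped_surrogate(ratio, adv):
    unclipped = ratio * adv
    clipped = np.clip(ratio, 1. - EPSILON, 1. + EPSILON) * adv
    return np.mean(np.minimum(unclipped, clipped))

ratio = np.array([0.7, 1.0, 1.4])
adv = np.array([1.0, -0.5, 2.0])
print(clipped_surrogate(ratio, adv))    # the actor loss is the negative of this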