
Commit 542c810

committed: update distributed tf for a3c
1 parent c141846 commit 542c810

File tree

2 files changed: +211 −2 lines


contents/10_A3C/A3C_discrete_action.py

Lines changed: 2 additions & 2 deletions
@@ -117,8 +117,8 @@ def work(self):
         s = self.env.reset()
         ep_r = 0
         while True:
-            if self.name == 'W_0':
-                self.env.render()
+            # if self.name == 'W_0':
+            #     self.env.render()
             a = self.AC.choose_action(s)
             s_, r, done, info = self.env.step(a)
             if done: r = -5

contents/10_A3C/A3C_distributed_tf.py

Lines changed: 209 additions & 0 deletions
@@ -0,0 +1,209 @@
"""
Asynchronous Advantage Actor Critic (A3C) with discrete action space, Reinforcement Learning.

The Cartpole example using distributed tensorflow + multiprocessing.

View more on my tutorial page: https://morvanzhou.github.io/

"""

import multiprocessing as mp
import tensorflow as tf
import numpy as np
import gym, time
import matplotlib.pyplot as plt


UPDATE_GLOBAL_ITER = 10
GAMMA = 0.9
ENTROPY_BETA = 0.001
LR_A = 0.001    # learning rate for actor
LR_C = 0.001    # learning rate for critic

env = gym.make('CartPole-v0')
N_S = env.observation_space.shape[0]
N_A = env.action_space.n

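# ACNet builds the actor-critic graph: the 'global_net' copy only holds the shared
# parameters, while each worker builds a full local copy with losses and sync ops.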
class ACNet(object):
    sess = None

    def __init__(self, scope, opt_a=None, opt_c=None, global_net=None):
        if scope == 'global_net':   # get global network
            with tf.variable_scope(scope):
                self.s = tf.placeholder(tf.float32, [None, N_S], 'S')
                self.a_params, self.c_params = self._build_net(scope)[-2:]
        else:
            with tf.variable_scope(scope):
                self.s = tf.placeholder(tf.float32, [None, N_S], 'S')
                self.a_his = tf.placeholder(tf.int32, [None, ], 'A')
                self.v_target = tf.placeholder(tf.float32, [None, 1], 'Vtarget')

                self.a_prob, self.v, self.a_params, self.c_params = self._build_net(scope)

                td = tf.subtract(self.v_target, self.v, name='TD_error')
                with tf.name_scope('c_loss'):
                    self.c_loss = tf.reduce_mean(tf.square(td))

                with tf.name_scope('a_loss'):
                    log_prob = tf.reduce_sum(
                        tf.log(self.a_prob) * tf.one_hot(self.a_his, N_A, dtype=tf.float32),
                        axis=1, keep_dims=True)
                    exp_v = log_prob * td
                    entropy = -tf.reduce_sum(self.a_prob * tf.log(self.a_prob + 1e-5),
                                             axis=1, keep_dims=True)  # encourage exploration
                    self.exp_v = ENTROPY_BETA * entropy + exp_v
                    self.a_loss = tf.reduce_mean(-self.exp_v)

                with tf.name_scope('local_grad'):
                    self.a_grads = tf.gradients(self.a_loss, self.a_params)
                    self.c_grads = tf.gradients(self.c_loss, self.c_params)

            self.global_step = tf.train.get_or_create_global_step()
            with tf.name_scope('sync'):
                with tf.name_scope('pull'):
                    self.pull_a_params_op = [l_p.assign(g_p) for l_p, g_p in zip(self.a_params, global_net.a_params)]
                    self.pull_c_params_op = [l_p.assign(g_p) for l_p, g_p in zip(self.c_params, global_net.c_params)]
                with tf.name_scope('push'):
                    self.update_a_op = opt_a.apply_gradients(zip(self.a_grads, global_net.a_params), global_step=self.global_step)
                    self.update_c_op = opt_c.apply_gradients(zip(self.c_grads, global_net.c_params))

    def _build_net(self, scope):
        w_init = tf.random_normal_initializer(0., .1)
        with tf.variable_scope('actor'):
            l_a = tf.layers.dense(self.s, 200, tf.nn.relu6, kernel_initializer=w_init, name='la')
            a_prob = tf.layers.dense(l_a, N_A, tf.nn.softmax, kernel_initializer=w_init, name='ap')
        with tf.variable_scope('critic'):
            l_c = tf.layers.dense(self.s, 100, tf.nn.relu6, kernel_initializer=w_init, name='lc')
            v = tf.layers.dense(l_c, 1, kernel_initializer=w_init, name='v')  # state value
        a_params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=scope + '/actor')
        c_params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=scope + '/critic')
        return a_prob, v, a_params, c_params

    def choose_action(self, s):  # run by a local
        prob_weights = self.sess.run(self.a_prob, feed_dict={self.s: s[np.newaxis, :]})
        action = np.random.choice(range(prob_weights.shape[1]),
                                  p=prob_weights.ravel())  # select action w.r.t. the action probabilities
        return action

    def update_global(self, feed_dict):  # run by a local
        self.sess.run([self.update_a_op, self.update_c_op], feed_dict)  # local grads apply to the global net

    def pull_global(self):  # run by a local
        self.sess.run([self.pull_a_params_op, self.pull_c_params_op])

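# work() is the entry point for every spawned process: 'ps' tasks just host the shared
# variables and block on server.join(), while 'worker' tasks build the graph and train.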
def work(job_name, task_index, global_ep, lock, r_queue, global_running_r):
    # set each job's ip:port
    cluster = tf.train.ClusterSpec({
        "ps": ['localhost:2220', 'localhost:2221',],
        "worker": ['localhost:2222', 'localhost:2223', 'localhost:2224', 'localhost:2225',]
    })
    server = tf.train.Server(cluster, job_name=job_name, task_index=task_index)
    if job_name == 'ps':
        print('Start Parameter Server: ', task_index)
        server.join()
    else:
        t1 = time.time()
        env = gym.make('CartPole-v0').unwrapped
        print('Start Worker: ', task_index)
        with tf.device(tf.train.replica_device_setter(
                worker_device="/job:worker/task:%d" % task_index,
                cluster=cluster)):
            opt_a = tf.train.RMSPropOptimizer(LR_A, name='opt_a')
            opt_c = tf.train.RMSPropOptimizer(LR_C, name='opt_c')
            global_net = ACNet('global_net')

        local_net = ACNet('local_ac%d' % task_index, opt_a, opt_c, global_net)
        # set training steps
        hooks = [tf.train.StopAtStepHook(last_step=100000)]
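        # MonitoredTrainingSession takes care of variable initialization and the stop hook;
        # in this local demo every worker is started with is_chief=True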
        with tf.train.MonitoredTrainingSession(master=server.target,
                                               is_chief=True,
                                               hooks=hooks,) as sess:
            print('Start Worker Session: ', task_index)
            local_net.sess = sess
            total_step = 1
            buffer_s, buffer_a, buffer_r = [], [], []
            while (not sess.should_stop()) and (global_ep.value < 1000):
                s = env.reset()
                ep_r = 0
                while True:
                    # if task_index:
                    #     env.render()
                    a = local_net.choose_action(s)
                    s_, r, done, info = env.step(a)
                    if done: r = -5.
                    ep_r += r
                    buffer_s.append(s)
                    buffer_a.append(a)
                    buffer_r.append(r)

                    if total_step % UPDATE_GLOBAL_ITER == 0 or done:   # update global and assign to local net
                        if done:
                            v_s_ = 0   # terminal
                        else:
                            v_s_ = sess.run(local_net.v, {local_net.s: s_[np.newaxis, :]})[0, 0]
                        buffer_v_target = []
                        for r in buffer_r[::-1]:    # reverse buffer r
                            v_s_ = r + GAMMA * v_s_
                            buffer_v_target.append(v_s_)
                        buffer_v_target.reverse()

                        buffer_s, buffer_a, buffer_v_target = np.vstack(buffer_s), np.array(buffer_a), np.vstack(buffer_v_target)
                        feed_dict = {
                            local_net.s: buffer_s,
                            local_net.a_his: buffer_a,
                            local_net.v_target: buffer_v_target,
                        }
                        local_net.update_global(feed_dict)
                        buffer_s, buffer_a, buffer_r = [], [], []
                        local_net.pull_global()
                    s = s_
                    total_step += 1
                    if done:
                        if r_queue.empty():  # record running episode reward
                            global_running_r.value = ep_r
                        else:
                            global_running_r.value = .99 * global_running_r.value + 0.01 * ep_r
                        r_queue.put(global_running_r.value)

                        print(
                            "Task: %i" % task_index,
                            "| Ep: %i" % global_ep.value,
                            "| Ep_r: %i" % global_running_r.value,
                            "| Global_step: %i" % sess.run(local_net.global_step),
                        )
                        with lock:
                            global_ep.value += 1
                        break

        print('Worker Done: ', task_index, time.time()-t1)


if __name__ == "__main__":
    # use multiprocessing to create a local cluster with 2 parameter servers and 4 workers
    global_ep = mp.Value('i', 0)
    lock = mp.Lock()
    r_queue = mp.Queue()
    global_running_r = mp.Value('d', 0)

    jobs = [
        ('ps', 0), ('ps', 1),
        ('worker', 0), ('worker', 1), ('worker', 2), ('worker', 3)
    ]
    ps = [mp.Process(target=work, args=(j, i, global_ep, lock, r_queue, global_running_r), ) for j, i in jobs]
    [p.start() for p in ps]
    [p.join() for p in ps[2:]]

    ep_r = []
    while not r_queue.empty():
        ep_r.append(r_queue.get())
    plt.plot(np.arange(len(ep_r)), ep_r)
    plt.title('Distributed training')
    plt.xlabel('Step')
    plt.ylabel('Total moving reward')
    plt.show()
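For reference, the same ClusterSpec can also be driven one process per terminal instead of through multiprocessing. Below is a minimal, hypothetical launcher sketch that is not part of this commit: the --job_name/--task_index flags and the launcher file itself are assumptions, and it presumes work() can be imported from A3C_distributed_tf.py with the hard-coded localhost ports unchanged.

import argparse
import multiprocessing as mp

from A3C_distributed_tf import work  # assumed importable; module-level gym.make() runs on import

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--job_name', choices=['ps', 'worker'], required=True)
    parser.add_argument('--task_index', type=int, required=True)
    args = parser.parse_args()

    # bookkeeping objects expected by work(); started this way they are per-process,
    # so episode counts and the moving reward are not shared across terminals
    global_ep = mp.Value('i', 0)
    lock = mp.Lock()
    r_queue = mp.Queue()
    global_running_r = mp.Value('d', 0)

    # e.g.  python launcher.py --job_name=ps --task_index=0
    #       python launcher.py --job_name=worker --task_index=0
    work(args.job_name, args.task_index, global_ep, lock, r_queue, global_running_r)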
