Skip to content

Commit 0f546db

Browse files
MorvanZhouMorvan Zhou
authored andcommitted
move replace target to graph building
1 parent 673ef4c commit 0f546db

File tree

5 files changed

+22
-30
lines changed

5 files changed

+22
-30
lines changed

contents/5.1_Double_DQN/RL_brain.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,10 @@ def __init__(
4747
self.learn_step_counter = 0
4848
self.memory = np.zeros((self.memory_size, n_features*2+2))
4949
self._build_net()
50+
t_params = tf.get_collection('target_net_params')
51+
e_params = tf.get_collection('eval_net_params')
52+
self.replace_target_op = [tf.assign(t, e) for t, e in zip(t_params, e_params)]
53+
5054
if sess is None:
5155
self.sess = tf.Session()
5256
self.sess.run(tf.global_variables_initializer())
@@ -114,14 +118,9 @@ def choose_action(self, observation):
114118
action = np.random.randint(0, self.n_actions)
115119
return action
116120

117-
def _replace_target_params(self):
118-
t_params = tf.get_collection('target_net_params')
119-
e_params = tf.get_collection('eval_net_params')
120-
self.sess.run([tf.assign(t, e) for t, e in zip(t_params, e_params)])
121-
122121
def learn(self):
123122
if self.learn_step_counter % self.replace_target_iter == 0:
124-
self._replace_target_params()
123+
self.sess.run(self.replace_target_op)
125124
print('\ntarget_params_replaced\n')
126125

127126
if self.memory_counter > self.memory_size:

contents/5.2_Prioritized_Replay_DQN/RL_brain.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,9 @@ def __init__(
176176
self.learn_step_counter = 0
177177

178178
self._build_net()
179+
t_params = tf.get_collection('target_net_params')
180+
e_params = tf.get_collection('eval_net_params')
181+
self.replace_target_op = [tf.assign(t, e) for t, e in zip(t_params, e_params)]
179182

180183
if self.prioritized:
181184
self.memory = Memory(capacity=memory_size)
@@ -254,14 +257,9 @@ def choose_action(self, observation):
254257
action = np.random.randint(0, self.n_actions)
255258
return action
256259

257-
def _replace_target_params(self):
258-
t_params = tf.get_collection('target_net_params')
259-
e_params = tf.get_collection('eval_net_params')
260-
self.sess.run([tf.assign(t, e) for t, e in zip(t_params, e_params)])
261-
262260
def learn(self):
263261
if self.learn_step_counter % self.replace_target_iter == 0:
264-
self._replace_target_params()
262+
self.sess.run(self.replace_target_op)
265263
print('\ntarget_params_replaced\n')
266264

267265
if self.prioritized:

contents/5.3_Dueling_DQN/RL_brain.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,10 @@ def __init__(
4747
self.learn_step_counter = 0
4848
self.memory = np.zeros((self.memory_size, n_features*2+2))
4949
self._build_net()
50+
t_params = tf.get_collection('target_net_params')
51+
e_params = tf.get_collection('eval_net_params')
52+
self.replace_target_op = [tf.assign(t, e) for t, e in zip(t_params, e_params)]
53+
5054
if sess is None:
5155
self.sess = tf.Session()
5256
self.sess.run(tf.global_variables_initializer())
@@ -124,14 +128,9 @@ def choose_action(self, observation):
124128
action = np.random.randint(0, self.n_actions)
125129
return action
126130

127-
def _replace_target_params(self):
128-
t_params = tf.get_collection('target_net_params')
129-
e_params = tf.get_collection('eval_net_params')
130-
self.sess.run([tf.assign(t, e) for t, e in zip(t_params, e_params)])
131-
132131
def learn(self):
133132
if self.learn_step_counter % self.replace_target_iter == 0:
134-
self._replace_target_params()
133+
self.sess.run(self.replace_target_op)
135134
print('\ntarget_params_replaced\n')
136135

137136
sample_index = np.random.choice(self.memory_size, size=self.batch_size)

contents/5_Deep_Q_Network/RL_brain.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,9 @@ def __init__(
5252

5353
# consist of [target_net, evaluate_net]
5454
self._build_net()
55+
t_params = tf.get_collection('target_net_params')
56+
e_params = tf.get_collection('eval_net_params')
57+
self.replace_target_op = [tf.assign(t, e) for t, e in zip(t_params, e_params)]
5558

5659
self.sess = tf.Session()
5760

@@ -132,15 +135,10 @@ def choose_action(self, observation):
132135
action = np.random.randint(0, self.n_actions)
133136
return action
134137

135-
def _replace_target_params(self):
136-
t_params = tf.get_collection('target_net_params')
137-
e_params = tf.get_collection('eval_net_params')
138-
self.sess.run([tf.assign(t, e) for t, e in zip(t_params, e_params)])
139-
140138
def learn(self):
141139
# check to replace target parameters
142140
if self.learn_step_counter % self.replace_target_iter == 0:
143-
self._replace_target_params()
141+
self.sess.run(self.replace_target_op)
144142
print('\ntarget_params_replaced\n')
145143

146144
# sample batch memory from all memory

contents/6_OpenAI_gym/RL_brain.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,9 @@ def __init__(
5252

5353
# consist of [target_net, evaluate_net]
5454
self._build_net()
55+
t_params = tf.get_collection('target_net_params')
56+
e_params = tf.get_collection('eval_net_params')
57+
self.replace_target_op = [tf.assign(t, e) for t, e in zip(t_params, e_params)]
5558

5659
self.sess = tf.Session()
5760

@@ -132,15 +135,10 @@ def choose_action(self, observation):
132135
action = np.random.randint(0, self.n_actions)
133136
return action
134137

135-
def _replace_target_params(self):
136-
t_params = tf.get_collection('target_net_params')
137-
e_params = tf.get_collection('eval_net_params')
138-
self.sess.run([tf.assign(t, e) for t, e in zip(t_params, e_params)])
139-
140138
def learn(self):
141139
# check to replace target parameters
142140
if self.learn_step_counter % self.replace_target_iter == 0:
143-
self._replace_target_params()
141+
self.sess.run(self.replace_target_op)
144142
print('\ntarget_params_replaced\n')
145143

146144
# sample batch memory from all memory

0 commit comments

Comments
 (0)