5 files changed: +22 -30 lines changed
5.2_Prioritized_Replay_DQN Expand file tree Collapse file tree 5 files changed +22
-30
lines changed Original file line number Diff line number Diff line change @@ -47,6 +47,10 @@ def __init__(
47
47
self .learn_step_counter = 0
48
48
self .memory = np .zeros ((self .memory_size , n_features * 2 + 2 ))
49
49
self ._build_net ()
50
+ t_params = tf .get_collection ('target_net_params' )
51
+ e_params = tf .get_collection ('eval_net_params' )
52
+ self .replace_target_op = [tf .assign (t , e ) for t , e in zip (t_params , e_params )]
53
+
50
54
if sess is None :
51
55
self .sess = tf .Session ()
52
56
self .sess .run (tf .global_variables_initializer ())
@@ -114,14 +118,9 @@ def choose_action(self, observation):
114
118
action = np .random .randint (0 , self .n_actions )
115
119
return action
116
120
117
- def _replace_target_params (self ):
118
- t_params = tf .get_collection ('target_net_params' )
119
- e_params = tf .get_collection ('eval_net_params' )
120
- self .sess .run ([tf .assign (t , e ) for t , e in zip (t_params , e_params )])
121
-
122
121
def learn (self ):
123
122
if self .learn_step_counter % self .replace_target_iter == 0 :
124
- self ._replace_target_params ( )
123
+ self .sess . run ( self . replace_target_op )
125
124
print ('\n target_params_replaced\n ' )
126
125
127
126
if self .memory_counter > self .memory_size :
Original file line number Diff line number Diff line change @@ -176,6 +176,9 @@ def __init__(
176
176
self .learn_step_counter = 0
177
177
178
178
self ._build_net ()
179
+ t_params = tf .get_collection ('target_net_params' )
180
+ e_params = tf .get_collection ('eval_net_params' )
181
+ self .replace_target_op = [tf .assign (t , e ) for t , e in zip (t_params , e_params )]
179
182
180
183
if self .prioritized :
181
184
self .memory = Memory (capacity = memory_size )
@@ -254,14 +257,9 @@ def choose_action(self, observation):
254
257
action = np .random .randint (0 , self .n_actions )
255
258
return action
256
259
257
- def _replace_target_params (self ):
258
- t_params = tf .get_collection ('target_net_params' )
259
- e_params = tf .get_collection ('eval_net_params' )
260
- self .sess .run ([tf .assign (t , e ) for t , e in zip (t_params , e_params )])
261
-
262
260
def learn (self ):
263
261
if self .learn_step_counter % self .replace_target_iter == 0 :
264
- self ._replace_target_params ( )
262
+ self .sess . run ( self . replace_target_op )
265
263
print ('\n target_params_replaced\n ' )
266
264
267
265
if self .prioritized :
Original file line number Diff line number Diff line change @@ -47,6 +47,10 @@ def __init__(
47
47
self .learn_step_counter = 0
48
48
self .memory = np .zeros ((self .memory_size , n_features * 2 + 2 ))
49
49
self ._build_net ()
50
+ t_params = tf .get_collection ('target_net_params' )
51
+ e_params = tf .get_collection ('eval_net_params' )
52
+ self .replace_target_op = [tf .assign (t , e ) for t , e in zip (t_params , e_params )]
53
+
50
54
if sess is None :
51
55
self .sess = tf .Session ()
52
56
self .sess .run (tf .global_variables_initializer ())
@@ -124,14 +128,9 @@ def choose_action(self, observation):
124
128
action = np .random .randint (0 , self .n_actions )
125
129
return action
126
130
127
- def _replace_target_params (self ):
128
- t_params = tf .get_collection ('target_net_params' )
129
- e_params = tf .get_collection ('eval_net_params' )
130
- self .sess .run ([tf .assign (t , e ) for t , e in zip (t_params , e_params )])
131
-
132
131
def learn (self ):
133
132
if self .learn_step_counter % self .replace_target_iter == 0 :
134
- self ._replace_target_params ( )
133
+ self .sess . run ( self . replace_target_op )
135
134
print ('\n target_params_replaced\n ' )
136
135
137
136
sample_index = np .random .choice (self .memory_size , size = self .batch_size )
Original file line number Diff line number Diff line change @@ -52,6 +52,9 @@ def __init__(
52
52
53
53
# consist of [target_net, evaluate_net]
54
54
self ._build_net ()
55
+ t_params = tf .get_collection ('target_net_params' )
56
+ e_params = tf .get_collection ('eval_net_params' )
57
+ self .replace_target_op = [tf .assign (t , e ) for t , e in zip (t_params , e_params )]
55
58
56
59
self .sess = tf .Session ()
57
60
@@ -132,15 +135,10 @@ def choose_action(self, observation):
132
135
action = np .random .randint (0 , self .n_actions )
133
136
return action
134
137
135
- def _replace_target_params (self ):
136
- t_params = tf .get_collection ('target_net_params' )
137
- e_params = tf .get_collection ('eval_net_params' )
138
- self .sess .run ([tf .assign (t , e ) for t , e in zip (t_params , e_params )])
139
-
140
138
def learn (self ):
141
139
# check to replace target parameters
142
140
if self .learn_step_counter % self .replace_target_iter == 0 :
143
- self ._replace_target_params ( )
141
+ self .sess . run ( self . replace_target_op )
144
142
print ('\n target_params_replaced\n ' )
145
143
146
144
# sample batch memory from all memory
Original file line number Diff line number Diff line change @@ -52,6 +52,9 @@ def __init__(
52
52
53
53
# consist of [target_net, evaluate_net]
54
54
self ._build_net ()
55
+ t_params = tf .get_collection ('target_net_params' )
56
+ e_params = tf .get_collection ('eval_net_params' )
57
+ self .replace_target_op = [tf .assign (t , e ) for t , e in zip (t_params , e_params )]
55
58
56
59
self .sess = tf .Session ()
57
60
@@ -132,15 +135,10 @@ def choose_action(self, observation):
132
135
action = np .random .randint (0 , self .n_actions )
133
136
return action
134
137
135
- def _replace_target_params (self ):
136
- t_params = tf .get_collection ('target_net_params' )
137
- e_params = tf .get_collection ('eval_net_params' )
138
- self .sess .run ([tf .assign (t , e ) for t , e in zip (t_params , e_params )])
139
-
140
138
def learn (self ):
141
139
# check to replace target parameters
142
140
if self .learn_step_counter % self .replace_target_iter == 0 :
143
- self ._replace_target_params ( )
141
+ self .sess . run ( self . replace_target_op )
144
142
print ('\n target_params_replaced\n ' )
145
143
146
144
# sample batch memory from all memory
You can’t perform that action at this time.
0 commit comments