@@ -123,9 +123,9 @@ def train(self, target_network):
 
     # randomly select a batch
     sample = random.sample(self.experience, self.batch_sz)
-    states, actions, rewards, next_states = map(np.array, zip(*sample))
+    states, actions, rewards, next_states, dones = map(np.array, zip(*sample))
     next_Q = np.max(target_network.predict(next_states), axis=1)
-    targets = [r + self.gamma * next_q for r, next_q in zip(rewards, next_Q)]
+    targets = [r + self.gamma * next_q if done is False else r for r, next_q, done in zip(rewards, next_Q, dones)]
 
     # call optimizer
     self.session.run(
@@ -137,12 +137,12 @@ def train(self, target_network):
       }
     )
 
-  def add_experience(self, s, a, r, s2):
+  def add_experience(self, s, a, r, s2, done):
     if len(self.experience) >= self.max_experiences:
       self.experience.pop(0)
     if len(s) != 4 or len(s2) != 4:
       print("BAD STATE")
-    self.experience.append((s, a, r, s2))
+    self.experience.append((s, a, r, s2, done))
 
   def sample_action(self, x, eps):
     if np.random.random() < eps:
@@ -192,7 +192,7 @@ def play_one(env, model, tmodel, eps, eps_step, gamma, copy_period):
 
     # update the model
     if len(state) == 4 and len(prev_state) == 4:
-      model.add_experience(prev_state, action, reward, state)
+      model.add_experience(prev_state, action, reward, state, done)
       model.train(tmodel)
 
     iters += 1
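
The change above truncates the Bellman backup at terminal states: when done is true the target is just the reward r, otherwise it is r + gamma * max_a' Q_target(s', a'). A minimal standalone sketch of that target computation, vectorized with NumPy instead of the list comprehension in the diff (function and array names are illustrative, not taken from the repository):

import numpy as np

def compute_targets(rewards, next_Q, dones, gamma):
    # next_Q holds max_a' Q_target(s', a') for each sampled transition;
    # terminal transitions bootstrap nothing, so their target is just the reward
    return np.where(dones, rewards, rewards + gamma * next_Q)

# example: two transitions, the second one terminal
targets = compute_targets(
    np.array([1.0, 1.0]),     # rewards
    np.array([5.0, 5.0]),     # max target-network Q at next state
    np.array([False, True]),  # done flags
    gamma=0.99,
)
# -> [5.95, 1.0]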