Commit 9578f53

add done flag
1 parent: 3b68d92

3 files changed, +15 -15 lines


rl2/atari/dqn_tf.py (+5 -5)

@@ -123,9 +123,9 @@ def train(self, target_network):
 
     # randomly select a batch
     sample = random.sample(self.experience, self.batch_sz)
-    states, actions, rewards, next_states = map(np.array, zip(*sample))
+    states, actions, rewards, next_states, dones = map(np.array, zip(*sample))
     next_Q = np.max(target_network.predict(next_states), axis=1)
-    targets = [r + self.gamma*next_q for r, next_q in zip(rewards, next_Q)]
+    targets = [r + self.gamma*next_q if done is False else r for r, next_q, done in zip(rewards, next_Q, dones)]
 
     # call optimizer
     self.session.run(
@@ -137,12 +137,12 @@ def train(self, target_network):
       }
     )
 
-  def add_experience(self, s, a, r, s2):
+  def add_experience(self, s, a, r, s2, done):
     if len(self.experience) >= self.max_experiences:
       self.experience.pop(0)
     if len(s) != 4 or len(s2) != 4:
       print("BAD STATE")
-    self.experience.append((s, a, r, s2))
+    self.experience.append((s, a, r, s2, done))
 
   def sample_action(self, x, eps):
     if np.random.random() < eps:
@@ -192,7 +192,7 @@ def play_one(env, model, tmodel, eps, eps_step, gamma, copy_period):
 
     # update the model
     if len(state) == 4 and len(prev_state) == 4:
-      model.add_experience(prev_state, action, reward, state)
+      model.add_experience(prev_state, action, reward, state, done)
       model.train(tmodel)
 
     iters += 1
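
Note (a sketch for reference, not part of the commit): the new target computation stops bootstrapping at episode boundaries, i.e. the target is r when done is true, and r + gamma * max_a' Q_target(s', a') otherwise. The same idea can be written in vectorized NumPy, assuming rewards, next_Q, and dones are the batch arrays built in the first hunk above:

import numpy as np

def compute_targets(rewards, next_Q, dones, gamma):
  # rewards: (batch,) floats; next_Q: (batch,) max_a' Q_target(s', a'); dones: (batch,) bools
  # terminal transitions keep only the immediate reward; others add the discounted bootstrap
  return np.where(dones, rewards, rewards + gamma * next_Q)

A boolean mask like this also avoids the identity test "done is False", which can behave unexpectedly once dones has passed through np.array, since its elements are then np.bool_ scalars rather than the built-in False object.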

rl2/atari/dqn_tf_alt.py (+5 -5)

@@ -174,9 +174,9 @@ def train(self, target_network):
 
     # randomly select a batch
     sample = random.sample(self.experience, self.batch_sz)
-    states, actions, rewards, next_states = map(np.array, zip(*sample))
+    states, actions, rewards, next_states, dones = map(np.array, zip(*sample))
     next_Q = np.max(target_network.predict(next_states), axis=1)
-    targets = [r + self.gamma*next_q for r, next_q in zip(rewards, next_Q)]
+    targets = [r + self.gamma*next_q if done is False else r for r, next_q, done in zip(rewards, next_Q, dones)]
 
     # print("train start")
     # call optimizer
@@ -190,12 +190,12 @@ def train(self, target_network):
     )
     # print("train end")
 
-  def add_experience(self, s, a, r, s2):
+  def add_experience(self, s, a, r, s2, done):
     if len(self.experience) >= self.max_experiences:
       self.experience.pop(0)
     if len(s) != 4 or len(s2) != 4:
       print("BAD STATE")
-    self.experience.append((s, a, r, s2))
+    self.experience.append((s, a, r, s2, done))
 
   def sample_action(self, x, eps):
     if np.random.random() < eps:
@@ -245,7 +245,7 @@ def play_one(env, model, tmodel, eps, eps_step, gamma, copy_period):
 
     # update the model
     if len(state) == 4 and len(prev_state) == 4:
-      model.add_experience(prev_state, action, reward, state)
+      model.add_experience(prev_state, action, reward, state, done)
       model.train(tmodel)
 
     iters += 1
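
The hunks above only show where done is consumed; a minimal sketch of where it would typically come from in the episode loop, assuming a Gym-style env.step API (this fragment is illustrative and not taken from the commit):

# hypothetical fragment of the play_one loop (illustration only)
observation, reward, done, info = env.step(action)  # done flags episode termination
# ... state / prev_state bookkeeping elided ...
if len(state) == 4 and len(prev_state) == 4:
  model.add_experience(prev_state, action, reward, state, done)
  model.train(tmodel)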

rl2/atari/dqn_theano.py (+5 -5)

@@ -198,19 +198,19 @@ def train(self, target_network):
 
     # randomly select a batch
     sample = random.sample(self.experience, self.batch_sz)
-    states, actions, rewards, next_states = map(np.array, zip(*sample))
+    states, actions, rewards, next_states, dones = map(np.array, zip(*sample))
     next_Q = np.max(target_network.predict(next_states), axis=1)
-    targets = [r + self.gamma*next_q for r, next_q in zip(rewards, next_Q)]
+    targets = [r + self.gamma*next_q if done is False else r for r, next_q, done in zip(rewards, next_Q, dones)]
 
     # call optimizer
     self.train_op(states, targets, actions)
 
-  def add_experience(self, s, a, r, s2):
+  def add_experience(self, s, a, r, s2, done):
     if len(self.experience) >= self.max_experiences:
       self.experience.pop(0)
     if len(s) != 4 or len(s2) != 4:
       print("BAD STATE")
-    self.experience.append((s, a, r, s2))
+    self.experience.append((s, a, r, s2, done))
 
   def sample_action(self, x, eps):
     if np.random.random() < eps:
@@ -260,7 +260,7 @@ def play_one(env, model, tmodel, eps, eps_step, gamma, copy_period):
 
     # update the model
    if len(state) == 4 and len(prev_state) == 4:
-      model.add_experience(prev_state, action, reward, state)
+      model.add_experience(prev_state, action, reward, state, done)
       model.train(tmodel)
 
     iters += 1
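
A toy calculation (made-up numbers, not from the repo) showing why the flag matters: without it, the target for a terminal transition bootstraps off the target network's estimate for a post-terminal state, which carries no real future return.

# illustrative numbers only
gamma, r, next_q = 0.99, 1.0, 5.0           # 5.0 = max_a' Q_target(s', a') for the next frame
target_without_flag = r + gamma * next_q    # 5.95 -- bootstraps past the end of the episode
target_with_flag    = r                     # 1.00 -- episode is over, so no future return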
