Skip to content

Commit 7ae5a2c

Browse files
committed
fix more cuda bugs in RL tutorial
1 parent 6aa8b5e commit 7ae5a2c

File tree

1 file changed

+6
-3
lines changed

1 file changed

+6
-3
lines changed

intermediate_source/reinforcement_q_learning.py

+6-3
Original file line number | Diff line number | Diff line change
@@ -269,9 +269,9 @@ def get_screen():
269269
# Convert to float, rescale, convert to torch tensor
270270
# (this doesn't require a copy)
271271
screen = np.ascontiguousarray(screen, dtype=np.float32) / 255
272-
screen = torch.from_numpy(screen).type(Tensor)
272+
screen = torch.from_numpy(screen)
273273
# Resize, and add a batch dimension (BCHW)
274-
return resize(screen).unsqueeze(0)
274+
return resize(screen).unsqueeze(0).type(Tensor)
275275

276276
env.reset()
277277
plt.figure()
@@ -353,6 +353,8 @@ def plot_durations():
353353
means = durations_t.unfold(0, 100, 1).mean(1).view(-1)
354354
means = torch.cat((torch.zeros(99), means))
355355
plt.plot(means.numpy())
356+
357+
plt.pause(0.001) # pause a bit so that plots are updated
356358
if is_ipython:
357359
display.clear_output(wait=True)
358360
display.display(plt.gcf())
@@ -403,7 +405,7 @@ def optimize_model():
403405
state_action_values = model(state_batch).gather(1, action_batch)
404406

405407
# Compute V(s_{t+1}) for all next states.
406-
next_state_values = Variable(torch.zeros(BATCH_SIZE))
408+
next_state_values = Variable(torch.zeros(BATCH_SIZE).type(Tensor))
407409
next_state_values[non_final_mask] = model(non_final_next_states).max(1)[0]
408410
# Now, we don't want to mess up the loss with a volatile flag, so let's
409411
# clear it. After this, we'll just end up with a Variable that has
@@ -468,6 +470,7 @@ def optimize_model():
468470
break
469471

470472
print('Complete')
473+
env.render(close=True)
471474
env.close()
472475
plt.ioff()
473476
plt.show()

0 commit comments

Comments
 (0)