@@ -269,9 +269,9 @@ def get_screen():
269
269
# Convert to float, rescale, convert to torch tensor
270
270
# (this doesn't require a copy)
271
271
screen = np .ascontiguousarray (screen , dtype = np .float32 ) / 255
272
- screen = torch .from_numpy (screen ). type ( Tensor )
272
+ screen = torch .from_numpy (screen )
273
273
# Resize, and add a batch dimension (BCHW)
274
- return resize (screen ).unsqueeze (0 )
274
+ return resize (screen ).unsqueeze (0 ). type ( Tensor )
275
275
276
276
env .reset ()
277
277
plt .figure ()
@@ -353,6 +353,8 @@ def plot_durations():
353
353
means = durations_t .unfold (0 , 100 , 1 ).mean (1 ).view (- 1 )
354
354
means = torch .cat ((torch .zeros (99 ), means ))
355
355
plt .plot (means .numpy ())
356
+
357
+ plt .pause (0.001 ) # pause a bit so that plots are updated
356
358
if is_ipython :
357
359
display .clear_output (wait = True )
358
360
display .display (plt .gcf ())
@@ -403,7 +405,7 @@ def optimize_model():
403
405
state_action_values = model (state_batch ).gather (1 , action_batch )
404
406
405
407
# Compute V(s_{t+1}) for all next states.
406
- next_state_values = Variable (torch .zeros (BATCH_SIZE ))
408
+ next_state_values = Variable (torch .zeros (BATCH_SIZE ). type ( Tensor ) )
407
409
next_state_values [non_final_mask ] = model (non_final_next_states ).max (1 )[0 ]
408
410
# Now, we don't want to mess up the loss with a volatile flag, so let's
409
411
# clear it. After this, we'll just end up with a Variable that has
@@ -468,6 +470,7 @@ def optimize_model():
468
470
break
469
471
470
472
print ('Complete' )
473
+ env .render (close = True )
471
474
env .close ()
472
475
plt .ioff ()
473
476
plt .show ()
0 commit comments