As the agent observes the current state of the environment and chooses
an action, the environment *transitions* to a new state, and also
returns a reward that indicates the consequences of the action. In this
- task, the environment terminates if the pole falls over too far.
+ task, rewards are +1 for every incremental timestep and the environment
+ terminates if the pole falls over too far or the cart moves more than 2.4
+ units away from the center. This means better performing scenarios will run
+ for a longer duration, accumulating larger return.

The CartPole task is designed so that the inputs to the agent are 4 real
values representing the environment state (position, velocity, etc.).
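
For a concrete feel for this reward structure, here is a minimal sketch of one episode under a random policy, assuming the standard ``CartPole-v0`` environment and the classic Gym step API; the episode return simply counts how many steps the cart and pole stayed within bounds.

import gym

env = gym.make('CartPole-v0')
env.reset()
total_reward = 0.0
done = False
while not done:
    action = env.action_space.sample()       # random action, purely for illustration
    _, reward, done, _ = env.step(action)    # reward is +1 for every surviving step
    total_reward += reward                   # episode ends when pole/cart limits are exceeded
print('Episode return:', total_reward)
env.close()
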
# For this, we're going to need two classes:
#
# - ``Transition`` - a named tuple representing a single transition in
- #   our environment
+ #   our environment. It essentially maps (state, action) pairs
+ #   to their (next_state, reward) result, with the state being the
+ #   screen difference image as described later on.
# - ``ReplayMemory`` - a cyclic buffer of bounded size that holds the
#   transitions observed recently. It also implements a ``.sample()``
#   method for selecting a random batch of transitions for training.
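
As a rough, self-contained sketch of how these two pieces fit together (the field names follow the description above; the tutorial's actual class may differ slightly in detail):

import random
from collections import namedtuple

Transition = namedtuple('Transition',
                        ('state', 'action', 'next_state', 'reward'))


class ReplayMemory(object):

    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = []
        self.position = 0

    def push(self, *args):
        """Saves a transition, overwriting the oldest one once full."""
        if len(self.memory) < self.capacity:
            self.memory.append(None)
        self.memory[self.position] = Transition(*args)
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)

A typical usage pattern is ``memory.push(state, action, next_state, reward)`` during the rollout and ``memory.sample(BATCH_SIZE)`` during optimization.
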
@@ -197,22 +202,32 @@ def __len__(self):
# difference between the current and previous screen patches. It has two
# outputs, representing :math:`Q(s, \mathrm{left})` and
# :math:`Q(s, \mathrm{right})` (where :math:`s` is the input to the
- # network). In effect, the network is trying to predict the *quality* of
+ # network). In effect, the network is trying to predict the *expected return* of
# taking each action given the current input.
#

class DQN(nn.Module):

-     def __init__(self):
+     def __init__(self, h, w):
        super(DQN, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=5, stride=2)
        self.bn1 = nn.BatchNorm2d(16)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=5, stride=2)
        self.bn2 = nn.BatchNorm2d(32)
        self.conv3 = nn.Conv2d(32, 32, kernel_size=5, stride=2)
        self.bn3 = nn.BatchNorm2d(32)
-         self.head = nn.Linear(448, 2)

+         # Number of Linear input connections depends on output of conv2d layers
+         # and therefore the input image size, so compute it.
+         def conv2d_size_out(size, kernel_size=5, stride=2):
+             return (size - (kernel_size - 1) - 1) // stride + 1
+         convw = conv2d_size_out(conv2d_size_out(conv2d_size_out(w)))
+         convh = conv2d_size_out(conv2d_size_out(conv2d_size_out(h)))
+         linear_input_size = convw * convh * 32
+         self.head = nn.Linear(linear_input_size, 2)  # 448 or 512
+
+     # Called with either one element to determine next action, or a batch
+     # during optimization. Returns tensor([[left0exp,right0exp]...]).
    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
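
To see where the ``# 448 or 512`` comment comes from, the same size arithmetic can be worked out by hand. Assuming the roughly 40x90 cropped screens produced by ``get_screen()`` below (an approximation; the exact size depends on the render buffer):

def conv2d_size_out(size, kernel_size=5, stride=2):
    return (size - (kernel_size - 1) - 1) // stride + 1

# Height 40 shrinks 40 -> 18 -> 7 -> 2; width 90 shrinks 90 -> 43 -> 20 -> 8.
convh = conv2d_size_out(conv2d_size_out(conv2d_size_out(40)))
convw = conv2d_size_out(conv2d_size_out(conv2d_size_out(90)))
print(convh, convw, convh * convw * 32)  # 2 8 512; a 40x80 screen would give 2 * 7 * 32 = 448
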
@@ -234,23 +249,20 @@ def forward(self, x):
                    T.Resize(40, interpolation=Image.CUBIC),
                    T.ToTensor()])

- # This is based on the code from gym.
- screen_width = 600
-
-
- def get_cart_location():
+ def get_cart_location(screen_width):
    world_width = env.x_threshold * 2
    scale = screen_width / world_width
    return int(env.state[0] * scale + screen_width / 2.0)  # MIDDLE OF CART

-
def get_screen():
-     screen = env.render(mode='rgb_array').transpose(
-         (2, 0, 1))  # transpose into torch order (CHW)
-     # Strip off the top and bottom of the screen
-     screen = screen[:, 160:320]
-     view_width = 320
-     cart_location = get_cart_location()
+     # Returned screen requested by gym is 400x600x3, but is sometimes larger,
+     # such as 800x1200x3. Transpose it into torch order (CHW).
+     screen = env.render(mode='rgb_array').transpose((2, 0, 1))
+     # Cart is in the lower half, so strip off the top and bottom of the screen
+     _, screen_height, screen_width = screen.shape
+     screen = screen[:, int(screen_height * 0.4):int(screen_height * 0.8)]
+     view_width = int(screen_width * 0.6)
+     cart_location = get_cart_location(screen_width)
    if cart_location < view_width // 2:
        slice_range = slice(view_width)
    elif cart_location > (screen_width - view_width // 2):
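
The cropped and resized frames returned by ``get_screen()`` are what later becomes the "screen difference image" mentioned earlier: the state fed to the network is the change between two consecutive frames, which carries some motion information. A minimal sketch of that usage (assuming the environment has already been reset):

env.reset()
last_screen = get_screen()
current_screen = get_screen()
# A single frame says nothing about velocity; the difference of two frames does.
state = current_screen - last_screen
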
@@ -298,15 +310,23 @@ def get_screen():
# episode.
#

- BATCH_SIZE = 128
+ BATCH_SIZE = 196  # 128
GAMMA = 0.999
EPS_START = 0.9
- EPS_END = 0.05
- EPS_DECAY = 200
+ EPS_END = 0.07
+ EPS_DECAY = 300
TARGET_UPDATE = 10

- policy_net = DQN().to(device)
- target_net = DQN().to(device)
+ # Get screen size so that we can initialize layers correctly based on shape
+ # returned from AI gym. Typical dimensions at this point are close to 3x40x90,
+ # which is the result of a clamped and down-scaled render buffer in get_screen().
+ init_screen = get_screen()
+ _, _, screen_height, screen_width = init_screen.shape
+
+ policy_net = DQN(screen_height, screen_width).to(device)
+ target_net = DQN(screen_height, screen_width).to(device)
target_net.load_state_dict(policy_net.state_dict())
target_net.eval()

@@ -325,6 +345,9 @@ def select_action(state):
    steps_done += 1
    if sample > eps_threshold:
        with torch.no_grad():
+             # t.max(1) will return the largest column value of each row.
+             # The second column of the max result is the index of where the max
+             # element was found, so we pick the action with the larger expected reward.
            return policy_net(state).max(1)[1].view(1, 1)
    else:
        return torch.tensor([[random.randrange(2)]], device=device, dtype=torch.long)
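
For intuition on the exploration schedule driving this choice, the epsilon threshold decays exponentially from EPS_START toward EPS_END (assuming the usual form ``EPS_END + (EPS_START - EPS_END) * exp(-steps_done / EPS_DECAY)``); with the values in this revision it looks roughly like this:

import math

EPS_START, EPS_END, EPS_DECAY = 0.9, 0.07, 300

def eps_threshold(steps_done):
    return EPS_END + (EPS_START - EPS_END) * math.exp(-1. * steps_done / EPS_DECAY)

for steps in (0, 300, 1000):
    print(steps, round(eps_threshold(steps), 3))  # 0.9, ~0.375, ~0.1: mostly greedy later on
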
@@ -376,10 +399,12 @@ def optimize_model():
        return
    transitions = memory.sample(BATCH_SIZE)
    # Transpose the batch (see http://stackoverflow.com/a/19343/3343043 for
-     # detailed explanation).
+     # detailed explanation). This converts a batch-array of Transitions
+     # to a Transition of batch-arrays.
    batch = Transition(*zip(*transitions))

    # Compute a mask of non-final states and concatenate the batch elements
+     # (a final state would've been the one after which simulation ended)
    non_final_mask = torch.tensor(tuple(map(lambda s: s is not None,
                                            batch.next_state)), device=device, dtype=torch.uint8)
    non_final_next_states = torch.cat([s for s in batch.next_state
@@ -389,10 +414,15 @@ def optimize_model():
    reward_batch = torch.cat(batch.reward)

    # Compute Q(s_t, a) - the model computes Q(s_t), then we select the
-     # columns of actions taken
+     # columns of actions taken. These are the actions which would've been taken
+     # for each batch state according to policy_net.
    state_action_values = policy_net(state_batch).gather(1, action_batch)

    # Compute V(s_{t+1}) for all next states.
+     # Expected values of actions for non_final_next_states are computed based
+     # on the "older" target_net; selecting their best reward with max(1)[0].
+     # This is merged based on the mask, such that we'll have either the expected
+     # state value or 0 in case the state was final.
    next_state_values = torch.zeros(BATCH_SIZE, device=device)
    next_state_values[non_final_mask] = target_net(non_final_next_states).max(1)[0].detach()
    # Compute the expected Q values
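
What follows in the tutorial is the one-step Bellman target and a Huber loss; a hedged sketch of how the tensors above typically come together:

    # Bellman target: r + GAMMA * max_a' Q_target(s', a'), with 0 for final states.
    expected_state_action_values = (next_state_values * GAMMA) + reward_batch

    # Huber (smooth L1) loss between Q(s_t, a) and the target; less sensitive to outliers than MSE.
    loss = F.smooth_l1_loss(state_action_values,
                            expected_state_action_values.unsqueeze(1))
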
@@ -418,10 +448,11 @@ def optimize_model():
# fails), we restart the loop.
#
# Below, `num_episodes` is set small. You should download
- # the notebook and run lot more epsiodes.
+ # the notebook and run a lot more episodes, such as 300+, for meaningful
+ # duration improvements.
#

- num_episodes = 50
+ num_episodes = 500
for i_episode in range(num_episodes):
    # Initialize the environment and state
    env.reset()
@@ -454,7 +485,7 @@ def optimize_model():
            episode_durations.append(t + 1)
            plot_durations()
            break
-     # Update the target network
+     # Update the target network, copying all weights and biases in DQN
    if i_episode % TARGET_UPDATE == 0:
        target_net.load_state_dict(policy_net.state_dict())

@@ -463,3 +494,16 @@ def optimize_model():
env.close()
plt.ioff()
plt.show()
+
+ ######################################################################
+ # Here is the diagram that illustrates the overall resulting flow.
+ #
+ # .. figure:: /_static/img/reinforcement_learning_diagram.jpg
+ #
+ # Actions are chosen either randomly or based on a policy, getting the next
+ # step sample from the gym environment. We record the results in the
+ # replay memory and also run the optimization step on every iteration.
+ # Optimization picks a random batch from the replay memory to train the
+ # new policy. The "older" target_net, which is used in optimization to compute
+ # the expected Q values, is updated occasionally to keep it current.
+ #
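
In condensed form, the loop the diagram describes looks roughly like this (a sketch reusing the helpers defined in this file, not the tutorial's full training cell):

from itertools import count

for i_episode in range(num_episodes):
    env.reset()
    last_screen = get_screen()
    current_screen = get_screen()
    state = current_screen - last_screen
    for t in count():
        action = select_action(state)                   # epsilon-greedy: policy_net or random
        _, reward, done, _ = env.step(action.item())
        reward = torch.tensor([reward], device=device)

        last_screen = current_screen
        current_screen = get_screen()
        next_state = None if done else current_screen - last_screen

        memory.push(state, action, next_state, reward)  # store the transition
        state = next_state
        optimize_model()                                # train on a random replay batch
        if done:
            episode_durations.append(t + 1)
            break
    if i_episode % TARGET_UPDATE == 0:                  # occasionally refresh target_net
        target_net.load_state_dict(policy_net.state_dict())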