
Commit 5f1dcfe

Author: Michael Antonov
Fixed reinforcement learning to run with any screen size; added diagrams.
1 parent a46b643 commit 5f1dcfe

File tree

2 files changed: +71 −27 lines changed
_static/img/reinforcement_learning_diagram.jpg (binary image, 23.2 KB)

intermediate_source/reinforcement_q_learning.py

+71-27
@@ -23,7 +23,10 @@
 As the agent observes the current state of the environment and chooses
 an action, the environment *transitions* to a new state, and also
 returns a reward that indicates the consequences of the action. In this
-task, the environment terminates if the pole falls over too far.
+task, rewards are +1 for every incremental timestep and the environment
+terminates if the pole falls over too far or the cart moves more than 2.4
+units away from the center. This means better-performing scenarios will run
+for a longer duration, accumulating a larger return.
 
 The CartPole task is designed so that the inputs to the agent are 4 real
 values representing the environment state (position, velocity, etc.).
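
For reference, the reward structure described above can be observed directly from gym. The sketch below is not part of this commit; it runs random actions on CartPole-v0 and sums the rewards, assuming the classic gym API in which ``step()`` returns ``(observation, reward, done, info)``.

import gym

env = gym.make('CartPole-v0')
env.reset()
total_reward, done = 0.0, False
while not done:
    # Every surviving step yields reward = +1; the episode ends when the pole
    # tips too far or the cart drifts more than 2.4 units from the center.
    _, reward, done, _ = env.step(env.action_space.sample())
    total_reward += reward
env.close()
print(total_reward)  # equals the episode length, so longer episodes mean a larger return
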
@@ -97,7 +100,9 @@
 # For this, we're going to need two classes:
 #
 # - ``Transition`` - a named tuple representing a single transition in
-#   our environment
+#   our environment. It essentially maps (state, action) pairs
+#   to their (next_state, reward) result, with the state being the
+#   screen difference image as described later on.
 # - ``ReplayMemory`` - a cyclic buffer of bounded size that holds the
 #   transitions observed recently. It also implements a ``.sample()``
 #   method for selecting a random batch of transitions for training.
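
As a rough illustration of those two helpers (not part of the diff), here is a minimal deque-based sketch. The tutorial's actual ``ReplayMemory`` may differ in implementation detail, but the ``Transition`` field names match the ones used above.

import random
from collections import namedtuple, deque

Transition = namedtuple('Transition',
                        ('state', 'action', 'next_state', 'reward'))

class ReplayMemory(object):

    def __init__(self, capacity):
        self.memory = deque(maxlen=capacity)  # old transitions are evicted automatically

    def push(self, *args):
        """Save a transition, e.g. push(state, action, next_state, reward)."""
        self.memory.append(Transition(*args))

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)  # random batch for training

    def __len__(self):
        return len(self.memory)
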
@@ -197,22 +202,32 @@ def __len__(self):
 # difference between the current and previous screen patches. It has two
 # outputs, representing :math:`Q(s, \mathrm{left})` and
 # :math:`Q(s, \mathrm{right})` (where :math:`s` is the input to the
-# network). In effect, the network is trying to predict the *quality* of
+# network). In effect, the network is trying to predict the *expected return* of
 # taking each action given the current input.
 #
 
 class DQN(nn.Module):
 
-    def __init__(self):
+    def __init__(self, h, w):
         super(DQN, self).__init__()
         self.conv1 = nn.Conv2d(3, 16, kernel_size=5, stride=2)
         self.bn1 = nn.BatchNorm2d(16)
         self.conv2 = nn.Conv2d(16, 32, kernel_size=5, stride=2)
         self.bn2 = nn.BatchNorm2d(32)
         self.conv3 = nn.Conv2d(32, 32, kernel_size=5, stride=2)
         self.bn3 = nn.BatchNorm2d(32)
-        self.head = nn.Linear(448, 2)
 
+        # Number of Linear input connections depends on the output of the
+        # conv2d layers and therefore on the input image size, so compute it.
+        def conv2d_size_out(size, kernel_size=5, stride=2):
+            return (size - (kernel_size - 1) - 1) // stride + 1
+        convw = conv2d_size_out(conv2d_size_out(conv2d_size_out(w)))
+        convh = conv2d_size_out(conv2d_size_out(conv2d_size_out(h)))
+        linear_input_size = convw * convh * 32
+        self.head = nn.Linear(linear_input_size, 2)  # 448 or 512
+
+    # Called with either one element to determine the next action, or a batch
+    # during optimization. Returns tensor([[left0exp,right0exp]...]).
     def forward(self, x):
         x = F.relu(self.bn1(self.conv1(x)))
         x = F.relu(self.bn2(self.conv2(x)))
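
A quick standalone check (not part of the commit) that the ``conv2d_size_out`` formula above agrees with what ``nn.Conv2d`` actually produces; the 40x90 input is the "typical" down-scaled screen size mentioned later in this diff.

import torch
import torch.nn as nn

def conv2d_size_out(size, kernel_size=5, stride=2):
    # Output width/height of an unpadded convolution.
    return (size - (kernel_size - 1) - 1) // stride + 1

h, w = 40, 90
convw = conv2d_size_out(conv2d_size_out(conv2d_size_out(w)))
convh = conv2d_size_out(conv2d_size_out(conv2d_size_out(h)))
print(convh, convw, convh * convw * 32)  # 2 8 512 -> the head gets 512 inputs

convs = nn.Sequential(nn.Conv2d(3, 16, kernel_size=5, stride=2),
                      nn.Conv2d(16, 32, kernel_size=5, stride=2),
                      nn.Conv2d(32, 32, kernel_size=5, stride=2))
print(convs(torch.zeros(1, 3, h, w)).shape)  # torch.Size([1, 32, 2, 8])
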
@@ -234,23 +249,20 @@ def forward(self, x):
                     T.Resize(40, interpolation=Image.CUBIC),
                     T.ToTensor()])
 
-# This is based on the code from gym.
-screen_width = 600
-
-
-def get_cart_location():
+def get_cart_location(screen_width):
     world_width = env.x_threshold * 2
     scale = screen_width / world_width
     return int(env.state[0] * scale + screen_width / 2.0)  # MIDDLE OF CART
 
-
 def get_screen():
-    screen = env.render(mode='rgb_array').transpose(
-        (2, 0, 1))  # transpose into torch order (CHW)
-    # Strip off the top and bottom of the screen
-    screen = screen[:, 160:320]
-    view_width = 320
-    cart_location = get_cart_location()
+    # Returned screen requested by gym is 400x600x3, but is sometimes larger,
+    # such as 800x1200x3. Transpose it into torch order (CHW).
+    screen = env.render(mode='rgb_array').transpose((2, 0, 1))
+    # Cart is in the lower half, so strip off the top and bottom of the screen
+    _, screen_height, screen_width = screen.shape
+    screen = screen[:, int(screen_height * 0.4):int(screen_height * 0.8)]
+    view_width = int(screen_width * 0.6)
+    cart_location = get_cart_location(screen_width)
     if cart_location < view_width // 2:
         slice_range = slice(view_width)
     elif cart_location > (screen_width - view_width // 2):
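
The window-slicing logic in the new ``get_screen`` can be seen in isolation with a dummy one-dimensional "screen" (a sketch, not part of the commit): the three branches clamp the 60%-wide view to the left edge, the right edge, or center it on the cart.

import numpy as np

def centered_slice(cart_location, screen_width, view_width):
    if cart_location < view_width // 2:
        return slice(view_width)                       # clamp to the left edge
    elif cart_location > (screen_width - view_width // 2):
        return slice(-view_width, None)                # clamp to the right edge
    else:
        return slice(cart_location - view_width // 2,
                     cart_location + view_width // 2)  # center on the cart

screen_width = 600
view_width = int(screen_width * 0.6)                   # 360, as in get_screen()
row = np.arange(screen_width)                          # stand-in for one pixel row
for cart in (50, 300, 580):
    window = row[centered_slice(cart, screen_width, view_width)]
    print(cart, window[0], window[-1], len(window))    # every window is 360 px wide
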
@@ -298,15 +310,23 @@ def get_screen():
 # episode.
 #
 
-BATCH_SIZE = 128
+BATCH_SIZE = 196  # 128
 GAMMA = 0.999
 EPS_START = 0.9
-EPS_END = 0.05
-EPS_DECAY = 200
+EPS_END = 0.07
+EPS_DECAY = 300
 TARGET_UPDATE = 10
 
-policy_net = DQN().to(device)
-target_net = DQN().to(device)
+# Get screen size so that we can initialize layers correctly based on the shape
+# returned from AI gym. Typical dimensions at this point are close to 3x40x90,
+# which is the result of a clamped and down-scaled buffer in get_screen().
+init_screen = get_screen()
+_, _, screen_height, screen_width = init_screen.shape
+# screen_height = init_screen.shape[2]
+# print("Screen size w,h:", screen_width, " ", screen_height)
+
+policy_net = DQN(screen_height, screen_width).to(device)
+target_net = DQN(screen_height, screen_width).to(device)
 target_net.load_state_dict(policy_net.state_dict())
 target_net.eval()
 
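To see what the new ``EPS_END``/``EPS_DECAY`` values imply, here is a small sketch (not part of the commit) that assumes the tutorial's usual exponential schedule, ``eps = EPS_END + (EPS_START - EPS_END) * exp(-steps_done / EPS_DECAY)``; the exact formula lives in ``select_action`` and is not shown in this diff.

import math

EPS_START, EPS_END, EPS_DECAY = 0.9, 0.07, 300

def eps_threshold(steps_done):
    return EPS_END + (EPS_START - EPS_END) * math.exp(-1. * steps_done / EPS_DECAY)

for steps in (0, 100, 300, 1000, 3000):
    print(steps, round(eps_threshold(steps), 3))
# 0.9 at the start, ~0.375 after EPS_DECAY steps, flattening out near 0.07:
# the agent explores heavily at first and acts mostly greedily later.
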
@@ -325,6 +345,9 @@ def select_action(state):
     steps_done += 1
     if sample > eps_threshold:
         with torch.no_grad():
+            # t.max(1) will return the largest column value of each row.
+            # The second element of the max result is the index of where the
+            # max element was found, so we pick the action with the larger expected reward.
             return policy_net(state).max(1)[1].view(1, 1)
     else:
         return torch.tensor([[random.randrange(2)]], device=device, dtype=torch.long)
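
A tiny demo (not part of the commit) of the ``.max(1)`` call referenced in the new comments: it returns per-row maxima together with their column indices, and the indices are the greedy actions.

import torch

q_values = torch.tensor([[0.2, 0.7],   # pretend policy_net output for 3 states
                         [1.5, 0.1],
                         [0.3, 0.9]])
values, indices = q_values.max(1)      # max over the action dimension
print(values)                          # tensor([0.7000, 1.5000, 0.9000])
print(indices)                         # tensor([1, 0, 1]) -> right, left, right
print(q_values.max(1)[1].view(-1, 1))  # same indices as a column of chosen actions
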
@@ -376,10 +399,12 @@ def optimize_model():
         return
     transitions = memory.sample(BATCH_SIZE)
     # Transpose the batch (see http://stackoverflow.com/a/19343/3343043 for
-    # detailed explanation).
+    # detailed explanation). This converts a batch-array of Transitions
+    # to a Transition of batch-arrays.
     batch = Transition(*zip(*transitions))
 
     # Compute a mask of non-final states and concatenate the batch elements
+    # (a final state would've been the one after which the simulation ended)
     non_final_mask = torch.tensor(tuple(map(lambda s: s is not None,
                                           batch.next_state)), device=device, dtype=torch.uint8)
     non_final_next_states = torch.cat([s for s in batch.next_state
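
The ``zip(*transitions)`` transpose described in the new comment can be checked with a toy example (not part of the commit): a list of ``Transition`` tuples becomes one ``Transition`` whose fields are tuples of batch entries.

from collections import namedtuple

Transition = namedtuple('Transition', ('state', 'action', 'next_state', 'reward'))

transitions = [Transition('s0', 'a0', 's1', 1.0),
               Transition('s1', 'a1', None, 1.0)]  # None marks a final state
batch = Transition(*zip(*transitions))
print(batch.state)       # ('s0', 's1')
print(batch.action)      # ('a0', 'a1')
print(batch.next_state)  # ('s1', None)
print(batch.reward)      # (1.0, 1.0)
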
@@ -389,10 +414,15 @@ def optimize_model():
     reward_batch = torch.cat(batch.reward)
 
     # Compute Q(s_t, a) - the model computes Q(s_t), then we select the
-    # columns of actions taken
+    # columns of actions taken. These are the actions which would've been taken
+    # for each batch state according to policy_net.
     state_action_values = policy_net(state_batch).gather(1, action_batch)
 
     # Compute V(s_{t+1}) for all next states.
+    # Expected values of actions for non_final_next_states are computed based
+    # on the "older" target_net, selecting their best reward with max(1)[0].
+    # This is merged based on the mask, such that we'll have either the expected
+    # state value or 0 in case the state was final.
     next_state_values = torch.zeros(BATCH_SIZE, device=device)
     next_state_values[non_final_mask] = target_net(non_final_next_states).max(1)[0].detach()
     # Compute the expected Q values
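
A small numeric sketch (not part of the commit) of the two steps the comments above describe: ``gather`` selects Q(s_t, a) for the actions that were actually taken, and the mask leaves the next-state value at 0 wherever the state was final.

import torch

q = torch.tensor([[0.2, 0.7],
                  [1.5, 0.1]])                # pretend policy_net(state_batch)
actions = torch.tensor([[1], [0]])            # actions that were taken
print(q.gather(1, actions))                   # tensor([[0.7000], [1.5000]])

non_final_mask = torch.tensor([True, False])  # the second state was final
next_state_values = torch.zeros(2)
# best target_net value of the single non-final next state
next_state_values[non_final_mask] = torch.tensor([0.9])
print(next_state_values)                      # tensor([0.9000, 0.0000])
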
@@ -418,10 +448,11 @@ def optimize_model():
 # fails), we restart the loop.
 #
 # Below, `num_episodes` is set small. You should download
-# the notebook and run lot more epsiodes.
+# the notebook and run a lot more episodes, such as 300+ for meaningful
+# duration improvements.
 #
 
-num_episodes = 50
+num_episodes = 500
 for i_episode in range(num_episodes):
     # Initialize the environment and state
     env.reset()
@@ -454,7 +485,7 @@ def optimize_model():
             episode_durations.append(t + 1)
             plot_durations()
             break
-    # Update the target network
+    # Update the target network, copying all weights and biases in DQN
     if i_episode % TARGET_UPDATE == 0:
         target_net.load_state_dict(policy_net.state_dict())
 
@@ -463,3 +494,16 @@ def optimize_model():
 env.close()
 plt.ioff()
 plt.show()
+
+######################################################################
+# Here is the diagram that illustrates the overall resulting flow.
+#
+# .. figure:: /_static/img/reinforcement_learning_diagram.jpg
+#
+# Actions are chosen either randomly or based on a policy, getting the next
+# step sample from the gym environment. We record the results in the replay
+# memory and also run an optimization step on every iteration. Optimization
+# picks a random batch from the replay memory to train the new policy. The
+# "older" target_net, which is used in the optimization to compute the expected
+# Q values, is updated occasionally to keep it current.
+#
