Commit 9e6c8ba

edowson authored and soumith committed
reinforcement_q_learning: Remove hard-coded entries for action space. (pytorch#452)
This commit removes hard-coded entries for the output action space and gets the value directly from the gym environment.

Signed-off-by: Elvis Dowson <elvis.dowson@gmail.com>
1 parent 0eec7fa commit 9e6c8ba

File tree

1 file changed: +8 -5 lines changed

intermediate_source/reinforcement_q_learning.py

Lines changed: 8 additions & 5 deletions
@@ -208,7 +208,7 @@ def __len__(self):
 
 class DQN(nn.Module):
 
-    def __init__(self, h, w):
+    def __init__(self, h, w, outputs):
         super(DQN, self).__init__()
         self.conv1 = nn.Conv2d(3, 16, kernel_size=5, stride=2)
         self.bn1 = nn.BatchNorm2d(16)
@@ -224,7 +224,7 @@ def conv2d_size_out(size, kernel_size = 5, stride = 2):
         convw = conv2d_size_out(conv2d_size_out(conv2d_size_out(w)))
         convh = conv2d_size_out(conv2d_size_out(conv2d_size_out(h)))
         linear_input_size = convw * convh * 32
-        self.head = nn.Linear(linear_input_size, 2) # 448 or 512
+        self.head = nn.Linear(linear_input_size, outputs)
 
     # Called with either one element to determine next action, or a batch
     # during optimization. Returns tensor([[left0exp,right0exp]...]).
@@ -324,8 +324,11 @@ def get_screen():
 init_screen = get_screen()
 _, _, screen_height, screen_width = init_screen.shape
 
-policy_net = DQN(screen_height, screen_width).to(device)
-target_net = DQN(screen_height, screen_width).to(device)
+# Get number of actions from gym action space
+n_actions = env.action_space.n
+
+policy_net = DQN(screen_height, screen_width, n_actions).to(device)
+target_net = DQN(screen_height, screen_width, n_actions).to(device)
 target_net.load_state_dict(policy_net.state_dict())
 target_net.eval()
 
@@ -349,7 +352,7 @@ def select_action(state):
             # found, so we pick action with the larger expected reward.
             return policy_net(state).max(1)[1].view(1, 1)
     else:
-        return torch.tensor([[random.randrange(2)]], device=device, dtype=torch.long)
+        return torch.tensor([[random.randrange(n_actions)]], device=device, dtype=torch.long)
 
 
 episode_durations = []
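
For context, a minimal standalone sketch of the pattern this commit adopts, assuming the CartPole-v0 environment used in the tutorial: the number of discrete actions is read from the gym action space instead of being hard-coded.

import random

import gym
import torch

# Query the environment for its discrete action count rather than assuming 2.
env = gym.make('CartPole-v0').unwrapped
n_actions = env.action_space.n  # 2 for CartPole-v0, but no longer hard-coded

# Random exploration can then sample uniformly over the actual action space.
random_action = torch.tensor([[random.randrange(n_actions)]], dtype=torch.long)
print(n_actions, random_action)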

Comments (0)