
Commit f6efb6a
Feat: A3C
1 parent f069a49

8 files changed: 94 additions, 94 deletions

README.md
Lines changed: 3 additions & 3 deletions

@@ -20,16 +20,16 @@ So you can run this example in your computer(maybe it take just only 1~2 minitue
 - [x] Actor Critic [[10]](#reference)
 - [x] Advantage Actor Critic
 - [x] GAE(Generalized Advantage Estimation) [[12]](#reference)
-- [ ] TNPG [[20]](#reference)
+- [x] TNPG [[20]](#reference)
 - [ ] TRPO [[13]](#reference)
 - [ ] PPO [[14]](#reference)
-- [ ] ACER [[21]](#reference)
 
 ## Parallel
 - [x] Asynchronous Q-learning [[11]](#reference)
-- [ ] A3C (Asynchronous Advange Actor Critice) [[11]](#reference)
+- [x] A3C (Asynchronous Advange Actor Critice) [[11]](#reference)
 
 ## Will
+- [ ] ACER [[21]](#reference)
 - [ ] APE-X [[15]](#reference)
 - [ ] R2D2 [[16]](#reference)
 - [ ] RND [[17]](#reference)

parallel/1-Async-Q-Learning/worker.py
Lines changed: 0 additions & 2 deletions

@@ -46,11 +46,9 @@ def run(self):
         running_score = 0
         epsilon = 1.0
         steps = 0
-        total_step = 0
         while self.global_ep.value < max_episode:
             if self.global_ep_r.value > goal_score:
                 break
-            total_step +=1
             done = False
 
             score = 0

parallel/2-A3C/config.py
Lines changed: 10 additions & 0 deletions

@@ -0,0 +1,10 @@
+import torch
+
+env_name = 'CartPole-v1'
+gamma = 0.99
+lr = 0.0001
+goal_score = 200
+log_interval = 10
+n_step = 10
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+max_episode = 30000

parallel/2-A3C/memory.py
Lines changed: 0 additions & 3 deletions

@@ -1,9 +1,6 @@
 import random
 from collections import namedtuple
 
-# Taken from
-# https://github.com/pytorch/tutorials/blob/master/Reinforcement%20(Q-)Learning%20with%20PyTorch.ipynb
-
 Transition = namedtuple('Transition', ('state', 'next_state', 'action', 'reward', 'mask'))
 
 
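The `Memory` class body sits outside this diff, but the new worker depends on its interface: `push()` per transition, `sample()` returning one batched `Transition`, and `len()` to detect a full n-step window. A minimal sketch of that interface, written here as an assumption rather than the repository's actual implementation:

```python
from collections import namedtuple

Transition = namedtuple('Transition',
                        ('state', 'next_state', 'action', 'reward', 'mask'))

class Memory(object):
    """Rollout buffer with the push/sample/len interface worker.py relies on."""

    def __init__(self, capacity):
        self.capacity = capacity   # n_step in this commit
        self.memory = []

    def push(self, state, next_state, action, reward, mask):
        # worker.py resets the buffer once it is full, so no eviction is needed
        self.memory.append(Transition(state, next_state, action, reward, mask))

    def sample(self):
        # transpose the list of Transitions into a Transition of tuples,
        # which is what LocalModel.push_to_global_model torch.stack's over
        return Transition(*zip(*self.memory))

    def __len__(self):
        return len(self.memory)
```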

parallel/2-A3C/model.py
Lines changed: 34 additions & 24 deletions

@@ -2,6 +2,8 @@
 import torch.nn as nn
 import torch.nn.functional as F
 
+from config import gamma
+
 def set_init(layers):
     for layer in layers:
         nn.init.normal_(layer.weight, mean=0., std=0.1)
@@ -13,26 +15,27 @@ def __init__(self, num_inputs, num_outputs):
         self.num_inputs = num_inputs
         self.num_outputs = num_outputs
 
-        self.fc1 = nn.Linear(num_inputs, 128)
-        self.fc2 = nn.Linear(128, 128)
+        self.fc = nn.Linear(num_inputs, 128)
         self.fc_actor = nn.Linear(128, num_outputs)
-
-        self.fc3 = nn.Linear(num_inputs, 128)
-        self.fc4 = nn.Linear(128, 128)
         self.fc_critic = nn.Linear(128, 1)
 
-        set_init([self.fc1, self.fc2, self.fc_actor, self.fc3, self.fc4, self.fc_critic])
+        for m in self.modules():
+            if isinstance(m, nn.Linear):
+                nn.init.xavier_uniform(m.weight)
 
     def forward(self, input):
-        x = F.relu(self.fc1(input))
-        x = F.relu(self.fc2(x))
+        x = F.relu(self.fc(input))
         policy = F.softmax(self.fc_actor(x))
-
-        y = F.relu(self.fc3(input))
-        y = F.relu(self.fc4(y))
-        value = self.fc_critic(y)
+        value = self.fc_critic(x)
         return policy, value
 
+    def get_action(self, input):
+        policy, _ = self.forward(input)
+        policy = policy[0].data.numpy()
+
+        action = np.random.choice(self.num_outputs, 1, p=policy)[0]
+        return action
+
 
 class GlobalModel(Model):
     def __init__(self, num_inputs, num_outputs):
@@ -43,34 +46,41 @@ class LocalModel(Model):
     def __init__(self, num_inputs, num_outputs):
         super(LocalModel, self).__init__(num_inputs, num_outputs)
 
-    def push_to_global_model(self, batch, global_model, global_optimizer, args):
+    def push_to_global_model(self, batch, global_model, global_optimizer):
        states = torch.stack(batch.state)
        next_states = torch.stack(batch.next_state)
-        actions = torch.Tensor(batch.action).long()
+        actions = torch.stack(batch.action)
        rewards = torch.Tensor(batch.reward)
        masks = torch.Tensor(batch.mask)
 
-        policy, value = self.forward(states[0])
+        policy, value = self.forward(states)
+        policy = policy.view(-1, self.num_outputs)
+        value = value.view(-1)
+
        _, last_value = self.forward(next_states[-1])
 
-        running_returns = last_value[0]
+        running_return = last_value[0].data
+        running_returns = torch.zeros(rewards.size())
        for t in reversed(range(0, len(rewards))):
-            running_returns = rewards[t] + args.gamma * running_returns * masks[t]
+            running_return = rewards[t] + gamma * running_return * masks[t]
+            running_returns[t] = running_return
 
-        pred = running_returns
-        td_error = pred - value[0]
 
-        log_policy = torch.log(policy[0] + 1e-5)[actions[0]]
-        loss1 = - log_policy * td_error.item()
-        loss2 = F.mse_loss(value[0], pred.detach())
-        entropy = torch.log(policy + 1e-5) * policy
-        loss = loss1 + loss2 - 0.01 * entropy.sum()
+        td_error = running_returns - value.detach()
+        log_policy = (torch.log(policy + 1e-10) * actions).sum(dim=1, keepdim=True)
+        loss_policy = - log_policy * td_error
+        loss_value = torch.pow(td_error, 2)
+        entropy = (torch.log(policy + 1e-10) * policy).sum(dim=1, keepdim=True)
+
+        loss = (loss_policy + loss_value - 0.01 * entropy).mean()
 
        global_optimizer.zero_grad()
        loss.backward()
        for lp, gp in zip(self.parameters(), global_model.parameters()):
            gp._grad = lp.grad
        global_optimizer.step()
 
+        return loss
+
     def pull_from_global_model(self, global_model):
        self.load_state_dict(global_model.state_dict())
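The rewritten `push_to_global_model` computes an n-step bootstrapped return for every step of the rollout, R_t = r_t + gamma * mask_t * R_{t+1}, seeded with the critic's estimate for the last `next_state`, and uses R_t - V(s_t) as the advantage. Two details are worth flagging when reading the diff: `td_error` is built from `value.detach()`, so `loss_value` as written sends no gradient into `fc_critic`, and with `entropy` defined as sum(p * log p) (the negative of the entropy H), subtracting `0.01 * entropy` adds +0.01 * H to the loss, which is the opposite sign of the usual `-beta * H` exploration bonus. Below is a standalone sketch of the same computation in the conventional form; the function and argument names are mine, not the repository's.

```python
import torch
import torch.nn.functional as F

def n_step_returns(rewards, masks, bootstrap_value, gamma=0.99):
    """R_t = r_t + gamma * mask_t * R_{t+1}, seeded with V(s_{t+n}) so the
    truncated rollout is bootstrapped; mask_t = 0 stops the recursion at
    episode boundaries."""
    returns = torch.zeros(rewards.size())
    running_return = bootstrap_value
    for t in reversed(range(len(rewards))):
        running_return = rewards[t] + gamma * masks[t] * running_return
        returns[t] = running_return
    return returns

def a3c_loss(policy, value, actions_one_hot, returns, entropy_coef=0.01):
    """Actor-critic loss in the textbook form: policy gradient weighted by the
    advantage, an MSE critic term that keeps `value` in the graph, and a
    -beta*H entropy bonus that rewards exploration."""
    advantage = (returns - value).detach()
    log_policy = (torch.log(policy + 1e-10) * actions_one_hot).sum(dim=1)
    loss_policy = -(log_policy * advantage).mean()
    loss_value = F.mse_loss(value, returns.detach())
    entropy = -(policy * torch.log(policy + 1e-10)).sum(dim=1).mean()
    return loss_policy + loss_value - entropy_coef * entropy
```

The gradient-sharing step that follows in the diff (`gp._grad = lp.grad` then `global_optimizer.step()`) is unchanged: the local model computes gradients, copies them onto the global parameters, and the shared optimizer applies them.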

parallel/2-A3C/shared_adam.py
Lines changed: 0 additions & 4 deletions

@@ -1,7 +1,3 @@
-"""
-Shared optimizer, the parameters in the optimizer will shared in the multiprocessors.
-"""
-
 import torch
 class SharedAdam(torch.optim.Adam): # extend a pytorch optimizer so it shares grads across processes
     def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):
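Only the docstring is removed here; the class body is outside the diff. The idea the deleted docstring described is that the Adam state (step counter and the two moment buffers) must live in shared memory so that every worker process updates the same optimizer state. A minimal sketch of how such an `__init__` is commonly written for PyTorch versions of this era, offered as an assumption about the rest of the file rather than the commit's code:

```python
import torch

class SharedAdam(torch.optim.Adam):
    """Adam whose per-parameter state is allocated up front and moved into
    shared memory, so forked workers all read and write the same moments."""

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):
        super(SharedAdam, self).__init__(params, lr=lr, betas=betas, eps=eps,
                                         weight_decay=weight_decay)
        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]
                state['step'] = 0
                state['exp_avg'] = torch.zeros_like(p.data)
                state['exp_avg_sq'] = torch.zeros_like(p.data)
                # share the moment buffers before the worker processes fork
                state['exp_avg'].share_memory_()
                state['exp_avg_sq'].share_memory_()
```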

parallel/2-A3C/train.py
Lines changed: 19 additions & 27 deletions

@@ -1,49 +1,41 @@
-import os
-import sys
 import gym
-import argparse
-import numpy as np
+import torch
 
 from model import Model
 from worker import Worker
 from shared_adam import SharedAdam
 from tensorboardX import SummaryWriter
 import torch.multiprocessing as mp
 
-parser = argparse.ArgumentParser()
-parser.add_argument('--env_name', type=str, default="CartPole-v1", help='')
-parser.add_argument('--load_model', type=str, default=None)
-parser.add_argument('--save_path', default='./save_model/', help='')
-parser.add_argument('--render', default=False, action="store_true")
-parser.add_argument('--gamma', default=0.9, help='')
-parser.add_argument('--goal_score', default=400, help='')
-parser.add_argument('--log_interval', default=10, help='')
-parser.add_argument('--logdir', type=str, default='./logs',
-                    help='tensorboardx logs directory')
-parser.add_argument('--MAX_EP', default=10000)
-args = parser.parse_args()
+from config import env_name, lr
 
-if __name__ == "__main__":
-    env = gym.make(args.env_name)
-    global_model = Model(env.observation_space.shape[0], env.action_space.n)
+def main():
+    env = gym.make(env_name)
+    env.seed(500)
+    torch.manual_seed(500)
+
+    num_inputs = env.observation_space.shape[0]
+    num_actions = env.action_space.n
+    global_model = Model(num_inputs, num_actions)
     global_model.share_memory()
-    global_optimizer = SharedAdam(global_model.parameters(), lr=0.0001)
+    global_optimizer = SharedAdam(global_model.parameters(), lr=lr)
     global_ep, global_ep_r, res_queue = mp.Value('i', 0), mp.Value('d', 0.), mp.Queue()
 
-    # mp.cpu_count()
-    workers = [Worker(global_model, global_optimizer, global_ep, global_ep_r, res_queue, i, args) for i in range(mp.cpu_count())]
+    writer = SummaryWriter('logs')
+
+    workers = [Worker(global_model, global_optimizer, global_ep, global_ep_r, res_queue, i) for i in range(mp.cpu_count())]
     [w.start() for w in workers]
     res = []
     while True:
         r = res_queue.get()
         if r is not None:
             res.append(r)
+            [ep, ep_r, loss] = r
+            writer.add_scalar('log/score', float(ep_r), ep)
+            writer.add_scalar('log/loss', float(loss), ep)
         else:
             break
     [w.join() for w in workers]
 
-    # import matplotlib.pyplot as plt
-    # plt.plot(res)
-    # plt.ylabel('Moving average ep reward')
-    # plt.xlabel('Step')
-    # plt.show()
+if __name__=="__main__":
+    main()
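The parent process now owns the `SummaryWriter` and only consumes `[episode, running_score, loss]` entries from `res_queue`; each worker pushes a `None` sentinel when its loop ends, which is what breaks the logging loop. A stripped-down sketch of that producer/consumer protocol (the worker function and queue names here are illustrative, not repository code):

```python
import torch.multiprocessing as mp

def fake_worker(queue):
    # stands in for Worker.run(): push [episode, score, loss], then a sentinel
    for ep in range(3):
        queue.put([ep, float(10 * ep), 0.5])
    queue.put(None)

if __name__ == "__main__":
    demo_queue = mp.Queue()
    p = mp.Process(target=fake_worker, args=(demo_queue,))
    p.start()
    while True:
        r = demo_queue.get()
        if r is None:          # first sentinel stops the logging loop
            break
        ep, score, loss = r
        print('episode {} | score {:.1f} | loss {:.2f}'.format(ep, score, loss))
    p.join()
```

With several workers, the loop in `main()` stops at the first sentinel it happens to dequeue, so later entries from other workers go unlogged; `join()` still waits for every process to finish.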

parallel/2-A3C/worker.py
Lines changed: 28 additions & 31 deletions

@@ -4,28 +4,13 @@
 import numpy as np
 from model import LocalModel
 from memory import Memory
-
-def record(global_ep, global_ep_r, ep_r, res_queue, name):
-    with global_ep.get_lock():
-        global_ep.value += 1
-    with global_ep_r.get_lock():
-        if global_ep_r.value == 0.:
-            global_ep_r.value = ep_r
-        else:
-            global_ep_r.value = global_ep_r.value * 0.99 + ep_r * 0.01
-    res_queue.put(global_ep_r.value)
-    print(
-        name,
-        "Ep:", global_ep.value,
-        "| Ep_r:", global_ep_r.value,
-    )
+from config import env_name, n_step, max_episode, log_interval
 
 class Worker(mp.Process):
-    def __init__(self, global_model, global_optimizer, global_ep, global_ep_r, res_queue, name, args):
+    def __init__(self, global_model, global_optimizer, global_ep, global_ep_r, res_queue, name):
        super(Worker, self).__init__()
-        self.args = args
 
-        self.env = gym.make(self.args.env_name)
+        self.env = gym.make(env_name)
        self.env.seed(500)
 
        self.name = 'w%i' % name
@@ -34,16 +19,28 @@ def __init__(self, global_model, global_optimizer, global_ep, global_ep_r, res_q
        self.local_model = LocalModel(self.env.observation_space.shape[0], self.env.action_space.n)
        self.num_actions = self.env.action_space.n
 
+    def record(self, score, loss):
+        with self.global_ep.get_lock():
+            self.global_ep.value += 1
+        with self.global_ep_r.get_lock():
+            if self.global_ep_r.value == 0.:
+                self.global_ep_r.value = score
+            else:
+                self.global_ep_r.value = 0.99 * self.global_ep_r.value + 0.01 * score
+        if self.global_ep.value % log_interval == 0:
+            print('{} , {} episode | score: {:.2f}'.format(
+                self.name, self.global_ep.value, self.global_ep_r.value))
+
+        self.res_queue.put([self.global_ep.value, self.global_ep_r.value, loss])
 
     def get_action(self, policy, num_actions):
        policy = policy.data.numpy()[0]
        action = np.random.choice(num_actions, 1, p=policy)[0]
        return action
 
     def run(self):
-        self.local_model.train()
-        total_step = 1
-        while self.global_ep.value < self.args.MAX_EP:
+
+        while self.global_ep.value < max_episode:
            self.local_model.pull_from_global_model(self.global_model)
            done = False
            score = 0
@@ -52,10 +49,9 @@ def run(self):
            state = self.env.reset()
            state = torch.Tensor(state)
            state = state.unsqueeze(0)
-            memory = Memory(100)
+            memory = Memory(n_step)
 
            while True:
-                self.local_model.eval()
                policy, value = self.local_model(state)
                action = self.get_action(policy, self.num_actions)
 
@@ -65,21 +61,22 @@ def run(self):
 
                mask = 0 if done else 1
                reward = reward if not done or score == 499 else -1
-                score += reward
+                action_one_hot = torch.zeros(2)
+                action_one_hot[action] = 1
+                memory.push(state, next_state, action_one_hot, reward, mask)
 
-                memory.push(state, next_state, action, reward, mask)
+                score += reward
+                state = next_state
 
-                if len(memory) == 10 or done:
+                if len(memory) == n_step or done:
                    batch = memory.sample()
-                    self.local_model.push_to_global_model(batch, self.global_model, self.global_optimizer, self.args)
+                    loss = self.local_model.push_to_global_model(batch, self.global_model, self.global_optimizer)
                    self.local_model.pull_from_global_model(self.global_model)
-                    memory = Memory(100)
+                    memory = Memory(n_step)
 
                if done:
-                    record(self.global_ep, self.global_ep_r, score, self.res_queue, self.name)
+                    running_score = self.record(score, loss)
                    break
 
 
-                total_step += 1
-                state = next_state
        self.res_queue.put(None)
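The worker now stores a one-hot action vector so that the new `(log_policy * actions).sum(dim=1)` term in `push_to_global_model` picks out the log-probability of the chosen action. Note that `torch.zeros(2)` hard-codes CartPole's two actions even though `self.num_actions` is already computed in `__init__`; a small size-agnostic helper (my naming, not the commit's) would look like:

```python
import torch

def one_hot_action(action, num_actions):
    """Encode a sampled action index as the one-hot vector that
    push_to_global_model multiplies against the log-probabilities."""
    encoded = torch.zeros(num_actions)
    encoded[action] = 1
    return encoded

# inside Worker.run(), instead of torch.zeros(2):
# action_one_hot = one_hot_action(action, self.num_actions)
```

Also, `record()` returns nothing, so the `running_score = self.record(score, loss)` assignment always yields `None` and is not used in the code shown here; only the method's side effects matter.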
