
Commit 1f98ebf

Feat: Apex Dqn
1 parent d67917b commit 1f98ebf

File tree

7 files changed, +417 -6 lines changed


README.md

Lines changed: 8 additions & 4 deletions
@@ -22,20 +22,22 @@
 - [x] GAE(Generalized Advantage Estimation) [[12]](#reference)
 - [x] TNPG [[20]](#reference)
 - [x] TRPO [[13]](#reference)
-- [ ] PPO [[14]](#reference)
+- [x] PPO [[14]](#reference)
 
 ## Parallel
 - [x] Asynchronous Q-learning [[11]](#reference)
 - [x] A3C (Asynchronous Advantage Actor-Critic) [[11]](#reference)
 - [x] ACER [[21]](#reference)
-- [ ] APE-X [[15]](#reference)
+- [ ] PPO [[14]](#reference)
+- [x] APE-X DQN [[15]](#reference)
+- [ ] IMPALA [[23]](#reference)
+- [ ] R2D2 [[16]](#reference)
 
 ## Will
-- [ ] R2D2 [[16]](#reference)
 - [ ] RND [[17]](#reference)
+- [ ] ICM [[22]](#reference)
 - [ ] QRDQN [[18]](#reference)
 - [ ] IQN [[19]](#reference)
-- [ ] PAAC
 
 
 ## Reference
@@ -60,6 +62,8 @@
 [19][Implicit Quantile Networks for Distributional Reinforcement Learning](https://arxiv.org/pdf/1806.06923.pdf)
 [20][A Natural Policy Gradient](https://papers.nips.cc/paper/2073-a-natural-policy-gradient.pdf)
 [21][SAMPLE EFFICIENT ACTOR-CRITIC WITH EXPERIENCE REPLAY](https://arxiv.org/pdf/1611.01224.pdf)
+[22][Curiosity-driven Exploration by Self-supervised Prediction](https://arxiv.org/pdf/1705.05363.pdf)
+[23][IMPALA: Scalable Distributed Deep-RL with Importance Weighted Actor-Learner Architectures](https://arxiv.org/pdf/1802.01561.pdf)
 
 
 ## Acknowledgements

parallel/1-Async-Q-Learning/worker.py

Lines changed: 1 addition & 2 deletions
@@ -43,7 +43,6 @@ def get_action(self, state, epsilon):
         return self.target_net.get_action(state)
 
     def run(self):
-        running_score = 0
         epsilon = 1.0
         steps = 0
         while self.global_ep.value < max_episode:
@@ -84,7 +83,7 @@ def run(self):
                 loss = QNet.train_model(self.online_net, self.target_net, self.optimizer, batch)
                 memory = Memory(async_update_step)
                 if done:
-                    running_score = self.record(score, epsilon, loss)
+                    self.record(score, epsilon, loss)
                     break
             if steps % update_target == 0:
                 self.update_target_model()

parallel/5-ApeX/config.py

Lines changed: 19 additions & 0 deletions
import torch

env_name = 'CartPole-v1'
gamma = 0.99
lr = 0.002
goal_score = 200
log_interval = 10
max_episode = 30000

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


replay_memory_capacity = 10000
n_step = 3              # length of the n-step return
local_mini_batch = 32
batch_size = 32         # prioritized batch size sampled by the learner
alpha = 0.5             # priority exponent for prioritized replay
beta = 0.4              # importance-sampling exponent for prioritized replay

parallel/5-ApeX/memory.py

Lines changed: 98 additions & 0 deletions
import random
import numpy as np
import torch
from collections import namedtuple, deque

from config import gamma, batch_size, alpha, beta

Transition = namedtuple('Transition', ('state', 'next_state', 'action', 'reward', 'mask', 'step'))


class N_Step_Buffer(object):
    """Collapses a short sequence of transitions into a single n-step transition."""
    def __init__(self):
        self.memory = []
        self.step = 0

    def push(self, state, next_state, action, reward, mask):
        self.step += 1
        self.memory.append([state, next_state, action, reward, mask])

    def sample(self):
        [state, _, action, _, _] = self.memory[0]
        [_, next_state, _, _, mask] = self.memory[-1]

        # Discounted n-step return, accumulated backwards: R = r_t + gamma * R.
        sum_reward = 0
        for t in reversed(range(len(self.memory))):
            [_, _, _, reward, _] = self.memory[t]
            sum_reward = reward + gamma * sum_reward
        reward = sum_reward
        step = self.step
        self.reset()

        return [state, next_state, action, reward, mask, step]

    def reset(self):
        self.memory = []
        self.step = 0

    def __len__(self):
        return len(self.memory)


class LocalBuffer(object):
    """Actor-side buffer that batches transitions before they are sent to the learner."""
    def __init__(self):
        self.memory = []

    def push(self, state, next_state, action, reward, mask, step):
        self.memory.append(Transition(state, next_state, action, reward, mask, step))

    def sample(self):
        transitions = self.memory
        batch = Transition(*zip(*transitions))
        return batch

    def reset(self):
        self.memory = []

    def __len__(self):
        return len(self.memory)


class Memory(object):
    """Central replay memory with proportional prioritized sampling."""
    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = deque(maxlen=capacity)
        self.memory_probability = deque(maxlen=capacity)

    def push(self, state, next_state, action, reward, mask, step, prior):
        self.memory.append(Transition(state, next_state, action, reward, mask, step))
        self.memory_probability.append(prior)

    def sample(self):
        # Sampling probabilities proportional to priority^alpha.
        probability = torch.Tensor(self.memory_probability)
        probability = probability.pow(alpha)
        probability = probability / probability.sum()

        p = probability.numpy()

        indexes = np.random.choice(range(len(self.memory_probability)), batch_size, p=p)

        transitions = [self.memory[idx] for idx in indexes]
        transitions_p = torch.Tensor([self.memory_probability[idx] for idx in indexes])

        batch = Transition(*zip(*transitions))

        # Importance-sampling weights, normalized by the largest weight in the batch.
        weights = (self.capacity * transitions_p).pow(-beta)
        weights = weights / weights.max()

        return indexes, batch, weights

    def update_prior(self, indexes, priors):
        priors_idx = 0
        for idx in indexes:
            self.memory_probability[idx] = priors[priors_idx]
            priors_idx += 1

    def __len__(self):
        return len(self.memory)
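
A note on the sampling scheme above: transitions are drawn with probability proportional to priority^alpha, and sample() returns importance-sampling weights (capacity * priority)^(-beta), normalized by the batch maximum, which the learner can use to scale its loss. The snippet below is only an illustration of the intended call pattern with dummy CartPole-sized data, assuming it is run from parallel/5-ApeX/; in the real flow the priorities would be absolute TD errors produced by the learner, which is not shown in this diff.

# Illustration only: exercises the Memory API with fake transitions.
import numpy as np
from memory import Memory
from config import batch_size, replay_memory_capacity

memory = Memory(replay_memory_capacity)

# Fill the buffer with dummy transitions (state dim 4, like CartPole) and a flat initial priority.
for _ in range(1000):
    state = np.random.randn(4).astype(np.float32)
    next_state = np.random.randn(4).astype(np.float32)
    memory.push(state, next_state, action=0, reward=1.0, mask=1, step=1, prior=1.0)

# Prioritized batch plus importance-sampling weights for the loss.
indexes, batch, weights = memory.sample()
print(len(indexes), weights.shape)  # batch_size, torch.Size([batch_size])

# After the learner computes new absolute TD errors for these samples,
# the stored priorities are refreshed so future sampling reflects them.
new_priors = np.abs(np.random.randn(batch_size)).tolist()
memory.update_prior(indexes, new_priors)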

parallel/5-ApeX/model.py

Lines changed: 40 additions & 0 deletions
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


class Model(nn.Module):
    """Dueling Q-network: a shared hidden layer with separate advantage and value heads."""
    def __init__(self, num_inputs, num_outputs):
        super(Model, self).__init__()
        self.num_inputs = num_inputs
        self.num_outputs = num_outputs

        self.fc = nn.Linear(num_inputs, 128)
        self.fc_adv = nn.Linear(128, num_outputs)
        self.fc_val = nn.Linear(128, 1)

        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)

    def forward(self, x):
        x = F.relu(self.fc(x))
        adv = self.fc_adv(x)
        adv = adv.view(-1, self.num_outputs)
        val = self.fc_val(x)
        val = val.view(-1, 1)

        # Q(s, a) = V(s) + (A(s, a) - mean_a A(s, a))
        qvalue = val + (adv - adv.mean(dim=1, keepdim=True))
        return qvalue


class LocalModel(Model):
    """Per-actor copy of the network, periodically synced with the shared global model."""
    def __init__(self, num_inputs, num_outputs):
        super(LocalModel, self).__init__(num_inputs, num_outputs)

    def pull_from_global_model(self, global_model):
        self.load_state_dict(global_model.state_dict())

    def get_action(self, input):
        qvalue = self.forward(input)
        _, action = torch.max(qvalue, 1)
        return action.numpy()[0]
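
For reference, LocalModel is the per-actor copy: it pulls the latest parameters from the shared global network and then acts greedily over the dueling Q-values. A minimal sketch of that usage with CartPole-sized shapes follows; the epsilon-greedy exploration wrapper lives in the workers, which are not shown in this diff.

# Illustration only: Model / LocalModel usage with CartPole-sized dimensions.
import torch
from model import Model, LocalModel

global_model = Model(num_inputs=4, num_outputs=2)
global_model.share_memory()  # shared across processes, as in train.py

local_model = LocalModel(num_inputs=4, num_outputs=2)
local_model.pull_from_global_model(global_model)  # copy the current parameters

state = torch.Tensor([[0.01, -0.02, 0.03, 0.04]])  # batch of one observation
with torch.no_grad():
    action = local_model.get_action(state)  # greedy argmax over Q(s, .)
print(action)  # 0 or 1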

parallel/5-ApeX/train.py

Lines changed: 58 additions & 0 deletions
import gym
import torch

from model import Model
from worker import Actor, Learner
import torch.multiprocessing as mp
from tensorboardX import SummaryWriter

from memory import Memory
from config import env_name, lr, replay_memory_capacity


def main():
    env = gym.make(env_name)
    env.seed(500)
    torch.manual_seed(500)

    num_inputs = env.observation_space.shape[0]
    num_actions = env.action_space.n
    env.close()

    # Online and target networks live in shared memory so that all
    # actor and learner processes see the same parameters.
    global_target_model = Model(num_inputs, num_actions)
    global_online_model = Model(num_inputs, num_actions)
    global_target_model.train()
    global_online_model.train()

    global_target_model.load_state_dict(global_online_model.state_dict())
    global_target_model.share_memory()
    global_online_model.share_memory()

    global_memory = Memory(replay_memory_capacity)

    # global_memory_pipe: queue the actors push experience into;
    # res_queue: (episode, loss) pairs sent back for logging.
    global_ep, global_ep_r, res_queue, global_memory_pipe = mp.Value('i', 0), mp.Value('d', 0.), mp.Queue(), mp.Queue()

    writer = SummaryWriter('logs')

    # Two actors, each with its own fixed exploration epsilon.
    n = 2
    epsilons = [(i * 0.05 + 0.1) for i in range(n)]

    actors = [Actor(global_target_model, global_memory_pipe, global_ep, global_ep_r, epsilons[i], i) for i in range(n)]
    [w.start() for w in actors]
    learner = Learner(global_online_model, global_target_model, global_memory, global_memory_pipe, res_queue)
    learner.start()

    res = []
    while True:
        r = res_queue.get()
        if r is not None:
            res.append(r)
            [ep, loss] = r
            # writer.add_scalar('log/score', float(ep_r), ep)
            writer.add_scalar('log/loss', float(loss), ep)
        else:
            break
    [w.join() for w in actors]


if __name__ == "__main__":
    main()
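
The Actor and Learner classes imported above live in worker.py, which is part of this commit but not included in this diff view. Purely as a hypothetical sketch of the learner's prioritized update in the Ape-X scheme (names and details here are illustrative, not the repo's actual worker code): sample an importance-weighted batch from Memory, regress the online network toward an n-step target built from the target network, and write the absolute TD errors back as new priorities.

# Hypothetical sketch, NOT the repo's worker.py: one prioritized learner update
# built from the pieces in this commit (Model, Memory, config).
import torch

from config import gamma

def learner_step(online_net, target_net, optimizer, memory):
    indexes, batch, weights = memory.sample()

    states = torch.stack([torch.as_tensor(s, dtype=torch.float32) for s in batch.state])
    next_states = torch.stack([torch.as_tensor(s, dtype=torch.float32) for s in batch.next_state])
    actions = torch.LongTensor(batch.action).unsqueeze(1)
    rewards = torch.Tensor(batch.reward)
    masks = torch.Tensor(batch.mask)
    steps = torch.Tensor(batch.step)

    # Q(s, a) for the actions that were actually taken.
    q_value = online_net(states).gather(1, actions).squeeze(1)

    # n-step target: R^(n) + gamma^n * max_a' Q_target(s_{t+n}, a'), masked at episode end.
    with torch.no_grad():
        next_q = target_net(next_states).max(1)[0]
        target = rewards + masks * (gamma ** steps) * next_q

    td_error = target - q_value

    # Importance-sampling weights from the prioritized memory scale the loss.
    loss = (weights * td_error.pow(2)).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Absolute TD errors become the new priorities for the sampled transitions.
    memory.update_prior(indexes, td_error.detach().abs().tolist())
    return loss.item()

The actual Learner in this commit may differ (for example, a Huber loss or double-DQN targets); the sketch only shows how sample(), the returned weights, and update_prior() are meant to fit together.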
