
Commit 6fea9ea

[WIP]: ACER
1 parent c8706e7 commit 6fea9ea

8 files changed: +357 −7 lines changed


README.md

Lines changed: 4 additions & 2 deletions
@@ -27,10 +27,10 @@ So you can run this example in your computer(maybe it take just only 1~2 minitue
 ## Parallel
 - [x] Asynchronous Q-learning [[11]](#reference)
 - [x] A3C (Asynchronous Advantage Actor-Critic) [[11]](#reference)
-
-## Will
 - [ ] ACER [[21]](#reference)
 - [ ] APE-X [[15]](#reference)
+
+## Will
 - [ ] R2D2 [[16]](#reference)
 - [ ] RND [[17]](#reference)
 - [ ] QRDQN [[18]](#reference)

@@ -67,3 +67,5 @@ So you can run this example in your computer(maybe it take just only 1~2 minitue
 - https://github.com/reinforcement-learning-kr/pg_travel
 - https://github.com/reinforcement-learning-kr/distributional_rl
 - https://github.com/Kaixhin/Rainbow
+- https://github.com/Kaixhin/ACER
+- https://github.com/higgsfield/RL-Adventure-2

parallel/2-A3C/model.py

Lines changed: 0 additions & 5 deletions
@@ -4,11 +4,6 @@
 
 from config import gamma
 
-def set_init(layers):
-    for layer in layers:
-        nn.init.normal_(layer.weight, mean=0., std=0.1)
-        nn.init.constant_(layer.bias, 0.1)
-
 class Model(nn.Module):
     def __init__(self, num_inputs, num_outputs):
         super(Model, self).__init__()

parallel/3-ACER/config.py

Lines changed: 17 additions & 0 deletions
import torch

env_name = 'CartPole-v1'
gamma = 0.99
lr = 0.001
goal_score = 200
log_interval = 10
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
max_episode = 30000


replay_memory_capacity = 1000
truncation_clip = 10
delta = 1
trust_region_decay = 0.99
replay_ratio = 4
max_gradient_norm = 40
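
For orientation, here is an assumed mapping from these names to the symbols in the ACER paper (Wang et al., 2016); the commit itself does not state it, so treat it as one reading of the WIP code rather than documentation:

```python
# Assumed correspondence to the ACER paper (not stated in the commit):
# truncation_clip    -> c, the importance-weight truncation threshold
# delta              -> delta, the trust-region bound on the step away from the average policy
# trust_region_decay -> alpha, the soft-update rate of the averaged policy network
# replay_ratio       -> r, the expected number of off-policy (replay) updates per on-policy episode
```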

parallel/3-ACER/memory.py

Lines changed: 34 additions & 0 deletions
import random
from collections import namedtuple, deque

Transition = namedtuple('Transition', ('state', 'next_state', 'action', 'reward', 'mask', 'policy'))


class Memory(object):
    def __init__(self, capacity):
        self.memory = deque(maxlen=capacity)
        self.capacity = capacity

    def push(self, trajectory):
        self.memory.append(trajectory.trajectory)

    def sample(self):
        trajectory = self.memory[random.randrange(len(self.memory))]
        return Transition(*zip(*trajectory))

    def __len__(self):
        return len(self.memory)


class Trajectory(object):
    def __init__(self):
        self.trajectory = []

    def push(self, state, next_state, action, reward, mask, policy):
        self.trajectory.append(Transition(state, next_state, action, reward, mask, policy))

    def sample(self):
        trajectory = self.trajectory
        return Transition(*zip(*trajectory))

    def __len__(self):
        return len(self.trajectory)
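
A minimal usage sketch for the two containers above (illustrative only; the tensors below are placeholders, not what the worker actually stores):

```python
import torch
from memory import Memory, Trajectory

memory = Memory(capacity=1000)

trajectory = Trajectory()
for _ in range(2):                                  # a fake two-step trajectory
    state = torch.zeros(1, 4)                       # placeholder CartPole observation
    next_state = torch.zeros(1, 4)
    trajectory.push(state, next_state, action=0, reward=1.0, mask=1,
                    policy=torch.tensor([0.5, 0.5]))

memory.push(trajectory)                  # Memory keeps the underlying list of Transitions
batch = memory.sample()                  # one stored trajectory, fields zipped together
print(len(batch.state), batch.reward)    # 2 (1.0, 1.0)
```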

parallel/3-ACER/model.py

Lines changed: 134 additions & 0 deletions
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from config import gamma, truncation_clip, delta, max_gradient_norm, trust_region_decay


class Model(nn.Module):
    def __init__(self, num_inputs, num_outputs):
        super(Model, self).__init__()
        self.num_inputs = num_inputs
        self.num_outputs = num_outputs

        self.fc = nn.Linear(num_inputs, 128)
        self.fc_actor = nn.Linear(128, num_outputs)
        self.fc_critic = nn.Linear(128, num_outputs)

        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)

    def forward(self, input):
        x = F.relu(self.fc(input))
        policy = F.softmax(self.fc_actor(x), dim=1)
        q_value = self.fc_critic(x)
        value = (policy * q_value).sum(-1, keepdim=True).view(-1)
        return policy, q_value, value


class LocalModel(Model):
    def __init__(self, num_inputs, num_outputs):
        super(LocalModel, self).__init__(num_inputs, num_outputs)

    def pull_from_global_model(self, global_model):
        self.load_state_dict(global_model.state_dict())

    def update_model(self, loss, global_optimizer, global_model, global_average_model):
        global_optimizer.zero_grad()
        loss.backward()
        # nn.utils.clip_grad_norm_(self.parameters(), max_gradient_norm)

        for lp, gp in zip(self.parameters(), global_model.parameters()):
            if gp.grad is not None:
                return
            gp.grad = lp.grad

        global_optimizer.step()

        # Polyak-average the global parameters into the average (trust-region) network, in place
        for gp, gap in zip(global_model.parameters(), global_average_model.parameters()):
            gap.data.copy_(trust_region_decay * gap.data + (1 - trust_region_decay) * gp.data)

    def compute_q_retraces(self, rewards, masks, values, q_actions, rho_actions, next_value):
        q_retraces = torch.zeros(rewards.size())
        q_retraces[-1] = next_value

        q_ret = q_retraces[-1]
        for step in reversed(range(len(rewards) - 1)):
            q_ret = rewards[step] + gamma * q_ret
            q_retraces[step] = q_ret
            q_ret = rho_actions[step] * (q_ret - q_actions[step]) + values[step]

        return q_retraces

    def get_loss(self, on_policy, trajectory, average_model):
        states, next_states, actions, rewards, masks, old_policies = trajectory
        states = torch.stack(states)
        next_states = torch.stack(next_states)
        actions = torch.Tensor(actions).long().view(-1, 1)
        rewards = torch.Tensor(rewards)
        masks = torch.Tensor(masks)
        old_policies = torch.stack(old_policies)

        states = states.view(-1, self.num_inputs)
        next_states = next_states.view(-1, self.num_inputs)
        policies, Qs, Vs = self.forward(states)

        Q_actions = Qs.gather(1, actions).view(-1)

        if not on_policy:
            rhos = policies / old_policies
        else:
            rhos = torch.ones_like(policies)

        rho_actions = rhos.gather(1, actions).view(-1)

        if masks[-1] == 0:
            Qret = 0
        else:
            Qret = Vs[-1]
        Qrets = self.compute_q_retraces(rewards, masks, Vs, Q_actions, rho_actions, Qret)
        log_policy = torch.log(policies)
        log_policy_action = log_policy.gather(1, actions).view(-1)

        actor_loss_1 = - (log_policy_action * (
            rho_actions.clamp(max=truncation_clip) * (Qrets - Vs)
        ).detach()).mean()
        actor_loss_2 = - (log_policy * (
            (1 - truncation_clip / rhos).clamp(min=0) * policies * (Qs - Vs.view(-1, 1).expand_as(Qs))
        ).detach()).sum(1).mean()
        actor_loss = actor_loss_1 + actor_loss_2

        # regress Q(s, a) toward the per-step retrace targets
        value_loss = ((Qrets - Q_actions) ** 2).mean()

        g_1 = ((1 / log_policy_action) * (
            rho_actions.clamp(max=truncation_clip) * (Qrets - Vs)
        ))
        g_2 = ((1 / log_policy) * (
            (1 - truncation_clip / rhos).clamp(min=0) * policies * (Qs - Vs.view(-1, 1).expand_as(Qs))
        )).sum(1)
        g = (g_1 + g_2).detach()
        average_policies, _, _ = average_model(states)
        k = (average_policies / policies).gather(1, actions).view(-1)

        kl = (average_policies * torch.log(average_policies / policies)).sum(1).mean(0)

        k_dot_g = (k * g).sum()
        k_dot_k = (k * k).sum()

        adj = ((k_dot_g - delta) / k_dot_k).clamp(min=0).detach()
        trust_region_actor_loss = actor_loss + adj * kl

        loss = trust_region_actor_loss + value_loss

        return loss

    def get_action(self, input):
        policy, _, _ = self.forward(input)
        policy = policy[0].data.numpy()

        action = np.random.choice(self.num_outputs, 1, p=policy)[0]
        return action, policy
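
For intuition, compute_q_retraces above follows the Retrace-style recursion used in ACER, roughly Q_ret(x_t, a_t) = r_t + gamma * [rho_{t+1} * (Q_ret(x_{t+1}, a_{t+1}) - Q(x_{t+1}, a_{t+1})) + V(x_{t+1})]. The standalone sketch below unrolls the same backward sweep on made-up numbers; it mirrors the loop in the method but is not the committed code:

```python
import torch

gamma = 0.99
rewards     = torch.tensor([1.0, 1.0, 1.0])   # r_0, r_1, r_2 (made-up)
values      = torch.tensor([0.5, 0.4, 0.3])   # V(x_t) from the critic head
q_actions   = torch.tensor([0.6, 0.5, 0.4])   # Q(x_t, a_t) for the taken actions
rho_actions = torch.tensor([1.0, 1.0, 1.0])   # importance weights (on-policy => 1)
next_value  = 0.0                             # bootstrap value; 0 because the episode ended

q_retraces = torch.zeros(3)
q_retraces[-1] = next_value
q_ret = q_retraces[-1]
for t in reversed(range(len(rewards) - 1)):   # same backward sweep as compute_q_retraces
    q_ret = rewards[t] + gamma * q_ret        # r_t + discounted corrected return from t+1
    q_retraces[t] = q_ret
    q_ret = rho_actions[t] * (q_ret - q_actions[t]) + values[t]   # off-policy correction

print(q_retraces)   # per-step targets for the critic and for the (Q_ret - V) advantage
```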

parallel/3-ACER/shared_adam.py

Lines changed: 34 additions & 0 deletions
import torch


class SharedAdam(torch.optim.Adam):  # extend a PyTorch optimizer so its state (step counts and Adam moments) is shared across processes
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):
        super(SharedAdam, self).__init__(params, lr, betas, eps, weight_decay)
        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]
                state['shared_steps'], state['step'] = torch.zeros(1).share_memory_(), 0
                state['exp_avg'] = p.data.new().resize_as_(p.data).zero_().share_memory_()
                state['exp_avg_sq'] = p.data.new().resize_as_(p.data).zero_().share_memory_()

    def step(self, closure=None):
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                self.state[p]['shared_steps'] += 1
                self.state[p]['step'] = self.state[p]['shared_steps'][0] - 1  # a "step += 1" comes later
        super(SharedAdam, self).step(closure)


# class SharedAdam(torch.optim.Adam):
#     def __init__(self, params, lr=1e-3, betas=(0.9, 0.9), eps=1e-8,
#                  weight_decay=0):
#         super(SharedAdam, self).__init__(params, lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
#         # State initialization
#         for group in self.param_groups:
#             for p in group['params']:
#                 state = self.state[p]
#                 state['step'] = 0
#                 state['exp_avg'] = torch.zeros_like(p.data)
#                 state['exp_avg_sq'] = torch.zeros_like(p.data)
#
#                 # share in memory
#                 state['exp_avg'].share_memory_()
#                 state['exp_avg_sq'].share_memory_()
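
A minimal sketch of how the shared optimizer is meant to be used, with an nn.Linear standing in for the global ACER model (an assumption for illustration): both the parameters and the Adam statistics live in shared memory, so hogwild-style workers started via torch.multiprocessing update a single set of moments.

```python
import torch
import torch.nn as nn
from shared_adam import SharedAdam

global_model = nn.Linear(4, 2)     # stand-in for the shared global model
global_model.share_memory()        # move parameters into shared memory

optimizer = SharedAdam(global_model.parameters(), lr=1e-3)

# what a worker does after copying its local gradients onto the global parameters
loss = global_model(torch.randn(1, 4)).sum()
optimizer.zero_grad()
loss.backward()
optimizer.step()                   # exp_avg / exp_avg_sq are updated in shared memory
```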

parallel/3-ACER/train.py

Lines changed: 47 additions & 0 deletions
import gym
import torch

from model import Model
from worker import Worker
from shared_adam import SharedAdam
from tensorboardX import SummaryWriter
import torch.multiprocessing as mp

from config import env_name, lr


def main():
    env = gym.make(env_name)
    env.seed(500)
    torch.manual_seed(500)

    num_inputs = env.observation_space.shape[0]
    num_actions = env.action_space.n
    env.close()

    global_model = Model(num_inputs, num_actions)
    global_average_model = Model(num_inputs, num_actions)
    global_model.share_memory()
    global_average_model.share_memory()
    global_optimizer = SharedAdam(global_model.parameters(), lr=lr)
    global_ep, global_ep_r, res_queue = mp.Value('i', 0), mp.Value('d', 0.), mp.Queue()

    writer = SummaryWriter('logs')

    # n = mp.cpu_count()
    n = 1
    workers = [Worker(global_model, global_average_model, global_optimizer, global_ep, global_ep_r, res_queue, i) for i in range(n)]
    [w.start() for w in workers]
    res = []
    while True:
        r = res_queue.get()
        if r is not None:
            res.append(r)
            [ep, ep_r, loss] = r
            writer.add_scalar('log/score', float(ep_r), ep)
            writer.add_scalar('log/loss', float(loss), ep)
        else:
            break
    [w.join() for w in workers]


if __name__ == "__main__":
    main()
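
One caveat worth flagging: the loop in main breaks only when a worker puts None on res_queue, and the Worker defined in worker.py below never sends that sentinel, so main would block on res_queue.get() after all workers finish. A self-contained sketch of the usual producer/consumer handshake (generic stand-ins, not the committed Worker):

```python
import torch.multiprocessing as mp

def worker(res_queue):
    for ep in range(3):
        res_queue.put([ep, float(ep), 0.0])   # stand-in for [episode, score, loss]
    res_queue.put(None)                       # sentinel: tells the consumer loop to stop

if __name__ == "__main__":
    res_queue = mp.Queue()
    p = mp.Process(target=worker, args=(res_queue,))
    p.start()
    while True:
        r = res_queue.get()
        if r is None:
            break
        print(r)
    p.join()
```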

parallel/3-ACER/worker.py

Lines changed: 87 additions & 0 deletions
import gym
import torch
import torch.multiprocessing as mp
import numpy as np
from model import LocalModel
from memory import Memory, Trajectory
from config import env_name, max_episode, log_interval, replay_memory_capacity, replay_ratio


class Worker(mp.Process):
    def __init__(self, global_model, global_average_model, global_optimizer, global_ep, global_ep_r, res_queue, name):
        super(Worker, self).__init__()

        self.env = gym.make(env_name)
        self.env.seed(500)

        self.name = 'w%i' % name
        self.global_ep, self.global_ep_r, self.res_queue = global_ep, global_ep_r, res_queue
        self.global_model, self.global_average_model, self.global_optimizer = global_model, global_average_model, global_optimizer
        self.local_model = LocalModel(self.env.observation_space.shape[0], self.env.action_space.n)
        self.num_actions = self.env.action_space.n

        self.memory = Memory(replay_memory_capacity)

    def record(self, score, loss):
        with self.global_ep.get_lock():
            self.global_ep.value += 1
        with self.global_ep_r.get_lock():
            if self.global_ep_r.value == 0.:
                self.global_ep_r.value = score
            else:
                self.global_ep_r.value = 0.99 * self.global_ep_r.value + 0.01 * score
        if self.global_ep.value % log_interval == 0:
            print('{} , {} episode | score: {:.2f}'.format(
                self.name, self.global_ep.value, self.global_ep_r.value))

        self.res_queue.put([self.global_ep.value, self.global_ep_r.value, loss])

    def run(self):
        while self.global_ep.value < max_episode:
            self.algorithm(True)
            n = np.random.poisson(replay_ratio)
            for _ in range(n):
                self.algorithm(False)

    def algorithm(self, on_policy):
        self.local_model.pull_from_global_model(self.global_model)
        if not on_policy and len(self.memory) > 100:
            trajectory = self.memory.sample()
        else:
            trajectory, score = self.run_env()
        loss = self.local_model.get_loss(on_policy, trajectory, self.global_average_model)
        self.local_model.update_model(loss, self.global_optimizer, self.global_model, self.global_average_model)
        if on_policy:
            self.record(score, loss)

    def run_env(self):
        done = False
        score = 0
        steps = 0

        state = self.env.reset()
        state = torch.Tensor(state)
        state = state.unsqueeze(0)
        trajectory = Trajectory()

        while True:
            action, policy = self.local_model.get_action(state)
            policy = torch.Tensor(policy)

            next_state, reward, done, _ = self.env.step(action)
            next_state = torch.Tensor(next_state)
            next_state = next_state.unsqueeze(0)

            mask = 0 if done else 1
            reward = reward if not done or score == 499 else -1
            trajectory.push(state, next_state, action, reward, mask, policy)

            score += reward
            state = next_state

            if done:
                break

        self.memory.push(trajectory)
        trajectory = trajectory.sample()
        return trajectory, score
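
The Poisson draw in Worker.run implements ACER's replay ratio: after each on-policy episode the worker performs Poisson(replay_ratio) off-policy updates, i.e. replay_ratio of them on average. A quick standalone check of that expectation (illustrative only):

```python
import numpy as np

replay_ratio = 4                     # same value as in config.py
draws = np.random.poisson(replay_ratio, size=10000)
print(draws.mean())                  # ~4.0 off-policy updates per on-policy episode on average
print((draws == 0).mean())           # ~0.018 (= e^-4): episodes followed by no replay at all
```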
