
Commit f069a49

committed
Feat: TNPG (but it has a strange point; not perfect)
1 parent 0706e9f commit f069a49

File tree

4 files changed: +230 −2 lines changed


PG/5-TNPG/config.py

Lines changed: 8 additions & 0 deletions
@@ -0,0 +1,8 @@
import torch

env_name = 'CartPole-v1'
gamma = 0.99
lr = 0.01
goal_score = 200
log_interval = 10
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

PG/5-TNPG/model.py

Lines changed: 130 additions & 0 deletions
@@ -0,0 +1,130 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from config import gamma, lr


def flat_grad(grads):
    # Concatenate a sequence of gradient tensors into a single flat vector.
    grad_flatten = []
    for grad in grads:
        grad_flatten.append(grad.view(-1))
    grad_flatten = torch.cat(grad_flatten)
    return grad_flatten

def flat_hessian(hessians):
    # Flatten the Hessian-vector product tensors into a single vector.
    hessians_flatten = []
    for hessian in hessians:
        hessians_flatten.append(hessian.contiguous().view(-1))
    hessians_flatten = torch.cat(hessians_flatten).data
    return hessians_flatten

def flat_params(model):
    # Flatten all model parameters into a single vector.
    params = []
    for param in model.parameters():
        params.append(param.data.view(-1))
    params_flatten = torch.cat(params)
    return params_flatten

def update_model(model, new_params):
    # Write a flat parameter vector back into the model, parameter by parameter.
    index = 0
    for params in model.parameters():
        params_length = len(params.view(-1))
        new_param = new_params[index: index + params_length]
        new_param = new_param.view(params.size())
        params.data.copy_(new_param)
        index += params_length

def kl_divergence(net, old_net, states):
    # KL(old_policy || policy), summed over actions for each state.
    policy = net(states)
    old_policy = old_net(states).detach()

    kl = old_policy * torch.log(old_policy / policy)

    kl = kl.sum(1, keepdim=True)
    return kl

def fisher_vector_product(net, states, p, cg_damp=0.1):
    # Fisher-vector product F·p, computed as the Hessian of the mean KL
    # divergence times p, plus a damping term for numerical stability.
    kl = kl_divergence(net, net, states)
    kl = kl.mean()
    kl_grad = torch.autograd.grad(kl, net.parameters(), create_graph=True)  # create_graph=True because we need a second derivative
    kl_grad = flat_grad(kl_grad)

    kl_grad_p = (kl_grad * p.detach()).sum()
    kl_hessian_p = torch.autograd.grad(kl_grad_p, net.parameters())
    kl_hessian_p = flat_hessian(kl_hessian_p)

    return kl_hessian_p + cg_damp * p.detach()


def conjugate_gradient(net, states, loss_grad, n_step=10, residual_tol=1e-10):
    # Approximately solve F x = loss_grad with the conjugate gradient method,
    # using only Fisher-vector products (F is never formed explicitly).
    x = torch.zeros_like(loss_grad)
    r = loss_grad.clone()
    p = loss_grad.clone()
    r_dot_r = torch.dot(r, r)

    for i in range(n_step):
        A_dot_p = fisher_vector_product(net, states, p)
        alpha = r_dot_r / torch.dot(p, A_dot_p)
        x += alpha * p
        r -= alpha * A_dot_p
        new_r_dot_r = torch.dot(r, r)
        beta = new_r_dot_r / r_dot_r
        p = r + beta * p
        r_dot_r = new_r_dot_r
        if r_dot_r < residual_tol:
            break
    return x

class QNet(nn.Module):
    def __init__(self, num_inputs, num_outputs):
        super(QNet, self).__init__()
        self.t = 0
        self.num_inputs = num_inputs
        self.num_outputs = num_outputs

        self.fc_1 = nn.Linear(num_inputs, 128)
        self.fc_2 = nn.Linear(128, num_outputs)

        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)

    def forward(self, input):
        x = torch.tanh(self.fc_1(input))
        policy = F.softmax(self.fc_2(x), dim=-1)

        return policy

    @classmethod
    def train_model(cls, net, transitions, k):
        states, actions, rewards, masks = transitions
        states = torch.stack(states)
        actions = torch.stack(actions).to(states.device)
        rewards = torch.Tensor(rewards).to(states.device)
        masks = torch.Tensor(masks).to(states.device)

        policy = net(states)
        policy = policy.view(-1, net.num_outputs)
        policy_action = (policy * actions.detach()).sum(dim=1)

        # Objective: probability of the taken action weighted by the discounted return.
        loss = (policy_action * rewards).mean()

        loss_grad = torch.autograd.grad(loss, net.parameters())
        loss_grad = flat_grad(loss_grad)

        # Natural gradient direction: approximately F^{-1} * gradient, via conjugate gradient.
        step_dir = conjugate_gradient(net, states, loss_grad.data)

        params = flat_params(net)
        new_params = params + lr * step_dir

        update_model(net, new_params)

        return -loss

    def get_action(self, input):
        policy = self.forward(input)
        policy = policy[0].detach().cpu().numpy()

        action = np.random.choice(self.num_outputs, 1, p=policy)[0]
        return action
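For reference (not part of the commit): the update performed by train_model above is a truncated natural policy gradient step. With g the flattened gradient of the objective (loss_grad) and F the Fisher matrix, approximated here by the Hessian of the mean KL divergence plus damping, the step can be sketched in LaTeX as

\[
F\,s \approx g \quad\text{(solved by truncated conjugate gradient)}, \qquad
\theta \leftarrow \theta + \mathrm{lr}\cdot s \approx \theta + \mathrm{lr}\cdot F^{-1} g .
\]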

PG/5-TNPG/train.py

Lines changed: 90 additions & 0 deletions
@@ -0,0 +1,90 @@
import os
import sys
import gym
import random
import numpy as np

import torch
import torch.optim as optim
import torch.nn.functional as F
from model import QNet
from tensorboardX import SummaryWriter

from config import env_name, goal_score, log_interval, device, lr, gamma


def main():
    env = gym.make(env_name)
    env.seed(500)
    torch.manual_seed(500)

    num_inputs = env.observation_space.shape[0]
    num_actions = env.action_space.n
    print('state size:', num_inputs)
    print('action size:', num_actions)

    net = QNet(num_inputs, num_actions)

    writer = SummaryWriter('logs')

    net.to(device)
    net.train()
    running_score = 0
    steps = 0
    loss = 0
    k = 0
    for e in range(3000):
        done = False
        memory = []

        score = 0
        state = env.reset()
        state = torch.Tensor(state).to(device)
        state = state.unsqueeze(0)

        while not done:
            steps += 1

            action = net.get_action(state)
            next_state, reward, done, _ = env.step(action)

            next_state = torch.Tensor(next_state).to(device)
            next_state = next_state.unsqueeze(0)

            mask = 0 if done else 1
            reward = reward if not done or score == 499 else -1

            action_one_hot = torch.zeros(num_actions)
            action_one_hot[action] = 1
            memory.append([state, next_state, action_one_hot, reward, mask])

            score += reward
            state = next_state

        # Compute discounted returns by iterating over the episode in reverse.
        sum_reward = 0
        memory.reverse()
        states, actions, rewards, masks = [], [], [], []
        for t, transition in enumerate(memory):
            state, next_state, action, reward, mask = transition
            sum_reward = (reward + gamma * sum_reward)
            states.append(state)
            actions.append(action)
            rewards.append(sum_reward)
            masks.append(mask)
        loss = QNet.train_model(net, (states, actions, rewards, masks), k)
        k += 1

        score = score if score == 500.0 else score + 1
        running_score = 0.99 * running_score + 0.01 * score
        if e % log_interval == 0:
            print('{} episode | score: {:.2f}'.format(
                e, running_score))
            writer.add_scalar('log/score', float(running_score), e)
            writer.add_scalar('log/loss', float(loss), e)

        if running_score > goal_score:
            break


if __name__ == "__main__":
    main()
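A minimal usage sketch (not part of the commit), run from PG/5-TNPG/, showing how the policy network is queried for a single CartPole state; it assumes the older gym API used in train.py above, where env.reset() returns only the observation:

import gym
import torch

from model import QNet

env = gym.make('CartPole-v1')
net = QNet(env.observation_space.shape[0], env.action_space.n)

state = torch.Tensor(env.reset()).unsqueeze(0)  # shape (1, 4) for CartPole
action = net.get_action(state)                  # sampled from the softmax policy
print('sampled action:', action)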

README.md

Lines changed: 2 additions & 2 deletions
@@ -20,10 +20,10 @@ So you can run this example in your computer(maybe it take just only 1~2 minitue
  - [x] Actor Critic [[10]](#reference)
  - [x] Advantage Actor Critic
  - [x] GAE(Generalized Advantage Estimation) [[12]](#reference)
- - [ ] ACER [[21]](#reference)
- - [ ] NPG [[20]](#reference)
+ - [ ] TNPG [[20]](#reference)
  - [ ] TRPO [[13]](#reference)
  - [ ] PPO [[14]](#reference)
+ - [ ] ACER [[21]](#reference)

  ## Parallel
  - [x] Asynchronous Q-learning [[11]](#reference)
