
Commit c8706e7

Feat: TRPO
1 parent f6efb6a commit c8706e7

File tree

4 files changed: 249 additions and 1 deletion


PG/6-TRPO/config.py

Lines changed: 9 additions & 0 deletions
@@ -0,0 +1,9 @@
import torch

env_name = 'CartPole-v1'
gamma = 0.99
goal_score = 200
log_interval = 10
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

max_kl = 0.01

PG/6-TRPO/model.py

Lines changed: 149 additions & 0 deletions
@@ -0,0 +1,149 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from config import gamma, max_kl


def flat_grad(grads):
    grad_flatten = []
    for grad in grads:
        grad_flatten.append(grad.view(-1))
    grad_flatten = torch.cat(grad_flatten)
    return grad_flatten

def flat_hessian(hessians):
    hessians_flatten = []
    for hessian in hessians:
        hessians_flatten.append(hessian.contiguous().view(-1))
    hessians_flatten = torch.cat(hessians_flatten).data
    return hessians_flatten

def flat_params(model):
    params = []
    for param in model.parameters():
        params.append(param.data.view(-1))
    params_flatten = torch.cat(params)
    return params_flatten

def update_model(model, new_params):
    index = 0
    for params in model.parameters():
        params_length = len(params.view(-1))
        new_param = new_params[index: index + params_length]
        new_param = new_param.view(params.size())
        params.data.copy_(new_param)
        index += params_length

def kl_divergence(policy, old_policy):
    kl = old_policy * torch.log(old_policy / policy)

    kl = kl.sum(1, keepdim=True)
    return kl

def fisher_vector_product(net, states, p, cg_damp=0.1):
    # Returns (H + cg_damp * I) @ p, where H is the Hessian of the mean KL
    # between the current policy and a detached copy of itself.
    policy = net(states)
    old_policy = net(states).detach()
    kl = kl_divergence(policy, old_policy)
    kl = kl.mean()
    kl_grad = torch.autograd.grad(kl, net.parameters(), create_graph=True)  # create_graph=True because we need higher-order derivatives
    kl_grad = flat_grad(kl_grad)

    kl_grad_p = (kl_grad * p.detach()).sum()
    kl_hessian_p = torch.autograd.grad(kl_grad_p, net.parameters())
    kl_hessian_p = flat_hessian(kl_hessian_p)

    return kl_hessian_p + cg_damp * p.detach()


def conjugate_gradient(net, states, loss_grad, n_step=10, residual_tol=1e-10):
    # Approximately solves H x = loss_grad using only Hessian-vector products.
    x = torch.zeros(loss_grad.size())
    r = loss_grad.clone()
    p = loss_grad.clone()
    r_dot_r = torch.dot(r, r)

    for i in range(n_step):
        A_dot_p = fisher_vector_product(net, states, p)
        alpha = r_dot_r / torch.dot(p, A_dot_p)
        x += alpha * p
        r -= alpha * A_dot_p
        new_r_dot_r = torch.dot(r, r)
        beta = new_r_dot_r / r_dot_r
        p = r + beta * p
        r_dot_r = new_r_dot_r
        if r_dot_r < residual_tol:
            break
    return x

class QNet(nn.Module):
    def __init__(self, num_inputs, num_outputs):
        super(QNet, self).__init__()
        self.t = 0
        self.num_inputs = num_inputs
        self.num_outputs = num_outputs

        self.fc_1 = nn.Linear(num_inputs, 128)
        self.fc_2 = nn.Linear(128, num_outputs)

        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)

    def forward(self, input):
        x = torch.tanh(self.fc_1(input))
        policy = F.softmax(self.fc_2(x), dim=-1)

        return policy

    @classmethod
    def train_model(cls, net, transitions, k):
        states, actions, rewards, masks = transitions
        states = torch.stack(states)
        actions = torch.stack(actions)
        rewards = torch.Tensor(rewards)
        masks = torch.Tensor(masks)

        policy = net(states)
        policy = policy.view(-1, net.num_outputs)
        policy_action = (policy * actions.detach()).sum(dim=1)

        old_policy = net(states).detach()
        old_policy = old_policy.view(-1, net.num_outputs)
        old_policy_action = (old_policy * actions.detach()).sum(dim=1)

        surrogate_loss = ((policy_action / old_policy_action) * rewards).mean()

        surrogate_loss_grad = torch.autograd.grad(surrogate_loss, net.parameters())
        surrogate_loss_grad = flat_grad(surrogate_loss_grad)

        # Natural gradient direction: approximately solve H * step_dir = g.
        step_dir = conjugate_gradient(net, states, surrogate_loss_grad.data)

        params = flat_params(net)
        shs = (step_dir * fisher_vector_product(net, states, step_dir)).sum(0, keepdim=True)
        step_size = torch.sqrt((2 * max_kl) / shs)[0]
        full_step = step_size * step_dir

        # Backtracking line search: shrink the step until the KL constraint holds.
        fraction = 1.0
        for _ in range(10):
            new_params = params + fraction * full_step
            update_model(net, new_params)
            policy = net(states)
            policy = policy.view(-1, net.num_outputs)
            policy_action = (policy * actions.detach()).sum(dim=1)
            surrogate_loss = ((policy_action / old_policy_action) * rewards).mean()

            kl = kl_divergence(policy, old_policy)
            kl = kl.mean()

            if kl < max_kl:
                break
            fraction = fraction * 0.5

        return -surrogate_loss

    def get_action(self, input):
        policy = self.forward(input)
        policy = policy[0].data.numpy()

        action = np.random.choice(self.num_outputs, 1, p=policy)[0]
        return action
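
For reference (a summary, not part of the commit), train_model performs the standard TRPO natural-gradient update, with max_kl as the trust-region radius δ, g the gradient of the surrogate objective, and H the Hessian of the mean KL (the Fisher matrix):

    s ≈ H⁻¹ g                          (returned by conjugate_gradient)
    step_size = sqrt(2δ / (sᵀ H s))    (shs in the code is sᵀ H s)
    θ_new = θ + fraction * step_size * s,  fraction ∈ {1, 1/2, 1/4, ...}

conjugate_gradient never forms H explicitly; it only evaluates Hessian-vector products through fisher_vector_product, and the backtracking loop halves the step until the measured KL stays below max_kl.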

PG/6-TRPO/train.py

Lines changed: 90 additions & 0 deletions
@@ -0,0 +1,90 @@
import os
import sys
import gym
import random
import numpy as np

import torch
import torch.optim as optim
import torch.nn.functional as F
from model import QNet
from tensorboardX import SummaryWriter

from config import env_name, goal_score, log_interval, device, gamma


def main():
    env = gym.make(env_name)
    env.seed(500)
    torch.manual_seed(500)

    num_inputs = env.observation_space.shape[0]
    num_actions = env.action_space.n
    print('state size:', num_inputs)
    print('action size:', num_actions)

    net = QNet(num_inputs, num_actions)

    writer = SummaryWriter('logs')

    net.to(device)
    net.train()
    running_score = 0
    steps = 0
    loss = 0
    k = 0
    for e in range(3000):
        done = False
        memory = []

        score = 0
        state = env.reset()
        state = torch.Tensor(state).to(device)
        state = state.unsqueeze(0)

        while not done:
            steps += 1

            action = net.get_action(state)
            next_state, reward, done, _ = env.step(action)

            next_state = torch.Tensor(next_state)
            next_state = next_state.unsqueeze(0)

            mask = 0 if done else 1
            reward = reward if not done or score == 499 else -1

            action_one_hot = torch.zeros(2)
            action_one_hot[action] = 1
            memory.append([state, next_state, action_one_hot, reward, mask])

            score += reward
            state = next_state

        sum_reward = 0
        memory.reverse()
        states, actions, rewards, masks = [], [], [], []
        for t, transition in enumerate(memory):
            state, next_state, action, reward, mask = transition
            sum_reward = (reward + gamma * sum_reward)
            states.append(state)
            actions.append(action)
            rewards.append(sum_reward)
            masks.append(mask)
        loss = QNet.train_model(net, (states, actions, rewards, masks), k)
        k += 1

        score = score if score == 500.0 else score + 1
        running_score = 0.99 * running_score + 0.01 * score
        if e % log_interval == 0:
            print('{} episode | score: {:.2f}'.format(
                e, running_score))
            writer.add_scalar('log/score', float(running_score), e)
            writer.add_scalar('log/loss', float(loss), e)

        if running_score > goal_score:
            break


if __name__ == "__main__":
    main()
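
As a small aside (hypothetical data, not part of the commit), the reversed loop over memory in main() is a standard reverse accumulation of discounted returns; a self-contained sketch:

# Minimal sketch with a made-up 4-step episode, mirroring the loop in main().
gamma = 0.99
rewards = [1.0, 1.0, 1.0, -1.0]    # e.g. a short episode that ends in failure
returns = []
running = 0.0
for r in reversed(rewards):        # walk the episode backwards
    running = r + gamma * running  # G_t = r_t + gamma * G_{t+1}
    returns.append(running)
print(returns)                     # ≈ [-1.0, 0.01, 1.0099, 1.9998], last step first

The returns are built last-step-first, which matches the reversed memory order passed to QNet.train_model.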

README.md

Lines changed: 1 addition & 1 deletion
@@ -21,7 +21,7 @@ So you can run this example in your computer(maybe it take just only 1~2 minitue
 - [x] Advantage Actor Critic
 - [x] GAE(Generalized Advantage Estimation) [[12]](#reference)
 - [x] TNPG [[20]](#reference)
-- [ ] TRPO [[13]](#reference)
+- [x] TRPO [[13]](#reference)
 - [ ] PPO [[14]](#reference)

 ## Parallel
