
Commit ee5f44c

Feat: Use memory in TNPG TRPO

1 parent: 3bfca8b

File tree

PG/5-TNPG/model.py
PG/5-TNPG/train.py
PG/6-TRPO/memory.py
PG/6-TRPO/model.py
PG/6-TRPO/train.py

5 files changed: +48 -37 lines

PG/5-TNPG/model.py

Lines changed: 3 additions & 3 deletions
@@ -76,9 +76,9 @@ def conjugate_gradient(net, states, loss_grad, n_step=10, residual_tol=1e-10):
             break
     return x
 
-class QNet(nn.Module):
+class TNPG(nn.Module):
     def __init__(self, num_inputs, num_outputs):
-        super(QNet, self).__init__()
+        super(TNPG, self).__init__()
         self.t = 0
         self.num_inputs = num_inputs
         self.num_outputs = num_outputs
@@ -97,7 +97,7 @@ def forward(self, input):
         return policy
 
     @classmethod
-    def train_model(cls, net, transitions, k):
+    def train_model(cls, net, transitions):
         states, actions, rewards, masks = transitions.state, transitions.action, transitions.reward, transitions.mask
 
         states = torch.stack(states)
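Note: the hunk header above references conjugate_gradient(net, states, loss_grad, n_step=10, residual_tol=1e-10), whose body is unchanged and not shown in this diff. For orientation, here is a generic truncated conjugate-gradient sketch that solves A x = b given only a matrix-vector product callback (a stand-in for the Fisher-vector product TNPG needs); the function name, signature, and Avp callback are illustrative assumptions, not the repository's exact implementation.

```python
import torch

def conjugate_gradient_sketch(Avp, b, n_step=10, residual_tol=1e-10):
    """Solve A x = b using only Avp(p) ~= A @ p (illustrative stand-in)."""
    x = torch.zeros_like(b)
    r = b.clone()                      # residual b - A x, with x = 0 initially
    p = r.clone()                      # search direction
    rdotr = torch.dot(r, r)
    for _ in range(n_step):            # truncation: at most n_step iterations
        Ap = Avp(p)
        alpha = rdotr / torch.dot(p, Ap)
        x += alpha * p
        r -= alpha * Ap
        new_rdotr = torch.dot(r, r)
        if new_rdotr < residual_tol:
            break
        p = r + (new_rdotr / rdotr) * p
        rdotr = new_rdotr
    return x
```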

PG/5-TNPG/train.py

Lines changed: 4 additions & 7 deletions
@@ -7,7 +7,7 @@
 import torch
 import torch.optim as optim
 import torch.nn.functional as F
-from model import QNet
+from model import TNPG
 from tensorboardX import SummaryWriter
 
 from memory import Memory
@@ -24,18 +24,17 @@ def main():
     print('state size:', num_inputs)
     print('action size:', num_actions)
 
-    net = QNet(num_inputs, num_actions)
+    net = TNPG(num_inputs, num_actions)
     writer = SummaryWriter('logs')
 
     net.to(device)
     net.train()
     running_score = 0
     steps = 0
     loss = 0
-    k=0
     for e in range(30000):
         done = False
-        memory = Memory()
+        memory = Memory()
 
         score = 0
         state = env.reset()
@@ -61,9 +60,7 @@ def main():
             score += reward
             state = next_state
 
-        sum_reward = 0
-        loss = QNet.train_model(net, memory.sample(), k)
-        k+=1
+        loss = TNPG.train_model(net, memory.sample())
 
         score = score if score == 500.0 else score + 1
         running_score = 0.99 * running_score + 0.01 * score
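Net effect of this file's changes: transitions are pushed into Memory during the episode and TNPG.train_model is called once per episode on the batched sample, with the unused counter k removed. A condensed sketch of that flow is shown below; the environment name, tensor shaping, and mask convention are assumptions based on the surrounding CartPole-style code (classic gym API), not a verbatim copy of train.py.

```python
import gym
import torch
from model import TNPG
from memory import Memory

env = gym.make('CartPole-v1')            # assumed env, matching the 500-score cap
num_inputs = env.observation_space.shape[0]
num_actions = env.action_space.n
net = TNPG(num_inputs, num_actions)

memory = Memory()
done, score = False, 0
state = torch.Tensor(env.reset()).unsqueeze(0)
while not done:
    action = net.get_action(state)
    next_state, reward, done, _ = env.step(action)
    next_state = torch.Tensor(next_state).unsqueeze(0)
    mask = 0 if done else 1              # assumed convention: 0 marks the terminal step

    action_one_hot = torch.zeros(num_actions)
    action_one_hot[action] = 1
    memory.push(state, next_state, action_one_hot, reward, mask)

    score += reward
    state = next_state

# One update per episode on the whole trajectory, no running counter k.
loss = TNPG.train_model(net, memory.sample())
```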

PG/6-TRPO/memory.py

Lines changed: 18 additions & 0 deletions
@@ -0,0 +1,18 @@
+import random
+from collections import namedtuple, deque
+
+Transition = namedtuple('Transition', ('state', 'next_state', 'action', 'reward', 'mask'))
+
+class Memory(object):
+    def __init__(self):
+        self.memory = deque()
+
+    def push(self, state, next_state, action, reward, mask):
+        self.memory.append(Transition(state, next_state, action, reward, mask))
+
+    def sample(self):
+        memory = self.memory
+        return Transition(*zip(*memory))
+
+    def __len__(self):
+        return len(self.memory)
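Usage sketch for the new Memory class, with dummy values: sample() applies Transition(*zip(*memory)), which transposes a deque of per-step Transitions into a single Transition whose fields are per-episode tuples, ready for torch.stack in train_model. The tensor shapes below are illustrative.

```python
import torch
from memory import Memory

memory = Memory()
s0, s1, s2 = torch.zeros(1, 4), torch.ones(1, 4), 2 * torch.ones(1, 4)
memory.push(s0, s1, torch.tensor([1.0, 0.0]), 1.0, 1)
memory.push(s1, s2, torch.tensor([0.0, 1.0]), 1.0, 0)   # mask 0: terminal step

batch = memory.sample()
print(len(batch.state), batch.reward)     # 2 (1.0, 1.0)
states = torch.stack(batch.state)         # shape (2, 1, 4), as train_model expects
```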

PG/6-TRPO/model.py

Lines changed: 16 additions & 8 deletions
@@ -75,9 +75,9 @@ def conjugate_gradient(net, states, loss_grad, n_step=10, residual_tol=1e-10):
             break
     return x
 
-class QNet(nn.Module):
+class TRPO(nn.Module):
     def __init__(self, num_inputs, num_outputs):
-        super(QNet, self).__init__()
+        super(TRPO, self).__init__()
         self.t = 0
         self.num_inputs = num_inputs
         self.num_outputs = num_outputs
@@ -90,19 +90,27 @@ def __init__(self, num_inputs, num_outputs):
                 nn.init.xavier_uniform(m.weight)
 
     def forward(self, input):
-        x = torch.tanh(self.fc_1(input))
+        x = torch.relu(self.fc_1(input))
         policy = F.softmax(self.fc_2(x))
 
         return policy
 
     @classmethod
-    def train_model(cls, net, transitions, k):
-        states, actions, rewards, masks = transitions
+    def train_model(cls, net, transitions):
+        states, actions, rewards, masks = transitions.state, transitions.action, transitions.reward, transitions.mask
+
         states = torch.stack(states)
        actions = torch.stack(actions)
         rewards = torch.Tensor(rewards)
         masks = torch.Tensor(masks)
 
+        returns = torch.zeros_like(rewards)
+
+        running_return = 0
+        for t in reversed(range(len(rewards))):
+            running_return = rewards[t] + gamma * running_return * masks[t]
+            returns[t] = running_return
+
         policy = net(states)
         policy = policy.view(-1, net.num_outputs)
         policy_action = (policy * actions.detach()).sum(dim=1)
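The loop added to train_model accumulates discounted Monte-Carlo returns backwards through the episode, with the mask zeroing the accumulated future return at terminal steps (this replaces the manual accumulation previously done in train.py). A standalone version of the same recursion, with a hand-checkable example; the helper name, gamma value, and inputs are chosen here for illustration only.

```python
import torch

def discounted_returns(rewards, masks, gamma=0.99):
    # Same backward recursion as in train_model:
    # G_t = r_t + gamma * G_{t+1} * mask_t, with the return after the episode = 0.
    returns = torch.zeros_like(rewards)
    running_return = 0.0
    for t in reversed(range(len(rewards))):
        running_return = rewards[t] + gamma * running_return * masks[t]
        returns[t] = running_return
    return returns

# Hand check with gamma = 0.5 (illustrative):
rewards = torch.tensor([1.0, 1.0, 1.0])
masks = torch.tensor([1.0, 1.0, 0.0])   # last step is terminal
print(discounted_returns(rewards, masks, gamma=0.5))   # tensor([1.7500, 1.5000, 1.0000])
```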
@@ -111,7 +119,7 @@ def train_model(cls, net, transitions, k):
         old_policy = old_policy.view(-1, net.num_outputs)
         old_policy_action = (old_policy * actions.detach()).sum(dim=1)
 
-        surrogate_loss = ((policy_action / old_policy_action) * rewards).mean()
+        surrogate_loss = ((policy_action / old_policy_action) * returns).mean()
 
         surrogate_loss_grad = torch.autograd.grad(surrogate_loss, net.parameters())
         surrogate_loss_grad = flat_grad(surrogate_loss_grad)
@@ -130,7 +138,7 @@ def train_model(cls, net, transitions, k):
         policy = net(states)
         policy = policy.view(-1, net.num_outputs)
         policy_action = (policy * actions.detach()).sum(dim=1)
-        surrogate_loss = ((policy_action / old_policy_action) * rewards).mean()
+        surrogate_loss = ((policy_action / old_policy_action) * returns).mean()
 
         kl = kl_divergence(policy, old_policy)
         kl = kl.mean()
@@ -144,6 +152,6 @@ def train_model(cls, net, transitions, k):
     def get_action(self, input):
         policy = self.forward(input)
         policy = policy[0].data.numpy()
-
+
         action = np.random.choice(self.num_outputs, 1, p=policy)[0]
         return action
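The context lines reference kl_divergence(policy, old_policy), which this commit does not touch. For reference, a generic categorical KL sketch between two softmax policies is shown below; the KL direction, epsilon, and detach placement are assumptions and may differ from the repository's helper.

```python
import torch

def kl_divergence_sketch(new_policy, old_policy, eps=1e-10):
    # Per-state KL(old || new) for categorical action distributions.
    # Shapes: (batch, num_actions); the old policy is treated as a constant.
    old = old_policy.detach()
    return (old * (torch.log(old + eps) - torch.log(new_policy + eps))).sum(dim=1)
```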

PG/6-TRPO/train.py

Lines changed: 7 additions & 19 deletions
@@ -7,9 +7,10 @@
 import torch
 import torch.optim as optim
 import torch.nn.functional as F
-from model import QNet
+from model import TRPO
 from tensorboardX import SummaryWriter
 
+from memory import Memory
 from config import env_name, goal_score, log_interval, device, gamma
 
 
@@ -23,19 +24,17 @@ def main():
     print('state size:', num_inputs)
     print('action size:', num_actions)
 
-    net = QNet(num_inputs, num_actions)
-
+    net = TRPO(num_inputs, num_actions)
     writer = SummaryWriter('logs')
 
     net.to(device)
     net.train()
     running_score = 0
     steps = 0
     loss = 0
-    k=0
-    for e in range(3000):
+    for e in range(30000):
         done = False
-        memory = []
+        memory = Memory()
 
         score = 0
         state = env.reset()
@@ -56,23 +55,12 @@ def main():
 
             action_one_hot = torch.zeros(2)
             action_one_hot[action] = 1
-            memory.append([state, next_state, action_one_hot, reward, mask])
+            memory.push(state, next_state, action_one_hot, reward, mask)
 
             score += reward
             state = next_state
 
-        sum_reward = 0
-        memory.reverse()
-        states, actions, rewards, masks = [], [], [], []
-        for t, transition in enumerate(memory):
-            state, next_state, action, reward, mask = transition
-            sum_reward = (reward + gamma * sum_reward)
-            states.append(state)
-            actions.append(action)
-            rewards.append(sum_reward)
-            masks.append(mask)
-        loss = QNet.train_model(net, (states, actions, rewards, masks), k)
-        k+=1
+        loss = TRPO.train_model(net, memory.sample())
 
         score = score if score == 500.0 else score + 1
         running_score = 0.99 * running_score + 0.01 * score
