
Commit 3bfca8b

Feat: TNPG use memory
1 parent a621ab7 commit 3bfca8b

4 files changed: +37 additions, -20 deletions

PG/5-TNPG/config.py

Lines changed: 1 addition & 1 deletion
@@ -2,7 +2,7 @@
 
 env_name = 'CartPole-v1'
 gamma = 0.99
-lr = 0.01
+lr = 0.001
 goal_score = 200
 log_interval = 10
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

PG/5-TNPG/memory.py

Lines changed: 18 additions & 0 deletions
@@ -0,0 +1,18 @@
+import random
+from collections import namedtuple, deque
+
+Transition = namedtuple('Transition', ('state', 'next_state', 'action', 'reward', 'mask'))
+
+class Memory(object):
+    def __init__(self):
+        self.memory = deque()
+
+    def push(self, state, next_state, action, reward, mask):
+        self.memory.append(Transition(state, next_state, action, reward, mask))
+
+    def sample(self):
+        memory = self.memory
+        return Transition(*zip(*memory))
+
+    def __len__(self):
+        return len(self.memory)
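
A minimal usage sketch of the new Memory class; the tensor shapes and reward/mask values below are made up for illustration and are not part of the commit:

import torch
from memory import Memory

memory = Memory()
for step in range(3):
    state = torch.randn(1, 4)            # CartPole-sized observation (illustrative)
    next_state = torch.randn(1, 4)
    action_one_hot = torch.zeros(2)
    action_one_hot[step % 2] = 1
    reward = 1.0
    mask = 0 if step == 2 else 1         # 0 marks the terminal transition
    memory.push(state, next_state, action_one_hot, reward, mask)

batch = memory.sample()                  # one Transition holding a tuple per field
print(len(memory))                       # 3
print(len(batch.state), batch.reward)    # 3 (1.0, 1.0, 1.0)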

PG/5-TNPG/model.py

Lines changed: 13 additions & 5 deletions
@@ -98,17 +98,25 @@ def forward(self, input):
 
     @classmethod
     def train_model(cls, net, transitions, k):
-        states, actions, rewards, masks = transitions
+        states, actions, rewards, masks = transitions.state, transitions.action, transitions.reward, transitions.mask
+
         states = torch.stack(states)
         actions = torch.stack(actions)
         rewards = torch.Tensor(rewards)
         masks = torch.Tensor(masks)
 
-        policy = net(states)
-        policy = policy.view(-1, net.num_outputs)
-        policy_action = (policy * actions.detach()).sum(dim=1)
+        returns = torch.zeros_like(rewards)
+
+        running_return = 0
+        for t in reversed(range(len(rewards))):
+            running_return = rewards[t] + gamma * running_return * masks[t]
+            returns[t] = running_return
+
+        policies = net(states)
+        policies = policies.view(-1, net.num_outputs)
+        policy_actions = (policies * actions.detach()).sum(dim=1)
 
-        loss = (policy_action * rewards).mean()
+        loss = (policy_actions * returns).mean()
 
         loss_grad = torch.autograd.grad(loss, net.parameters())
         loss_grad = flat_grad(loss_grad)
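
The added loop is a standard backward recursion for discounted returns, with mask zeroing the running sum at episode boundaries; note that train_model now uses gamma, which model.py must already have in scope (e.g. imported from config), since this hunk does not add it. A standalone sketch of the same recursion, with illustrative reward/mask values:

import torch

gamma = 0.99
rewards = torch.Tensor([1.0, 1.0, 1.0])
masks = torch.Tensor([1, 1, 0])          # 0 = terminal step

returns = torch.zeros_like(rewards)
running_return = 0
for t in reversed(range(len(rewards))):
    running_return = rewards[t] + gamma * running_return * masks[t]
    returns[t] = running_return

print(returns)                           # tensor([2.9701, 1.9900, 1.0000])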

PG/5-TNPG/train.py

Lines changed: 5 additions & 14 deletions
@@ -10,6 +10,7 @@
 from model import QNet
 from tensorboardX import SummaryWriter
 
+from memory import Memory
 from config import env_name, goal_score, log_interval, device, lr, gamma
 
 
@@ -24,7 +25,6 @@ def main():
     print('action size:', num_actions)
 
     net = QNet(num_inputs, num_actions)
-
     writer = SummaryWriter('logs')
 
     net.to(device)
@@ -33,9 +33,9 @@ def main():
     steps = 0
     loss = 0
     k=0
-    for e in range(3000):
+    for e in range(30000):
         done = False
-        memory = []
+        memory = Memory()
 
         score = 0
         state = env.reset()
@@ -56,22 +56,13 @@ def main():
 
             action_one_hot = torch.zeros(2)
             action_one_hot[action] = 1
-            memory.append([state, next_state, action_one_hot, reward, mask])
+            memory.push(state, next_state, action_one_hot, reward, mask)
 
             score += reward
             state = next_state
 
         sum_reward = 0
-        memory.reverse()
-        states, actions, rewards, masks = [], [], [], []
-        for t, transition in enumerate(memory):
-            state, next_state, action, reward, mask = transition
-            sum_reward = (reward + gamma * sum_reward)
-            states.append(state)
-            actions.append(action)
-            rewards.append(sum_reward)
-            masks.append(mask)
-        loss = QNet.train_model(net, (states, actions, rewards, masks), k)
+        loss = QNet.train_model(net, memory.sample(), k)
         k+=1
 
         score = score if score == 500.0 else score + 1
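
Putting the pieces together, the per-episode flow now looks roughly like the sketch below. It assumes the older gym API (env.reset() returning an observation, env.step() returning four values) and replaces the repo's actual action selection with random sampling; only Memory, QNet, and the train_model call reflect this commit:

import gym
import torch
from memory import Memory
from model import QNet

env = gym.make('CartPole-v1')
net = QNet(env.observation_space.shape[0], env.action_space.n)

for e in range(3):                                   # tiny loop for illustration
    memory = Memory()
    state = torch.Tensor(env.reset()).unsqueeze(0)
    done = False
    while not done:
        action = env.action_space.sample()           # stand-in for the policy
        next_state, reward, done, _ = env.step(action)
        next_state = torch.Tensor(next_state).unsqueeze(0)

        action_one_hot = torch.zeros(2)
        action_one_hot[action] = 1
        mask = 0 if done else 1

        memory.push(state, next_state, action_one_hot, reward, mask)
        state = next_state

    # train_model now receives the whole episode as one Transition batch
    loss = QNet.train_model(net, memory.sample(), e)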
