
Commit d8bc712

[WIP] Feat: R2D2 Double, Duel, PER
1 parent f77670c commit d8bc712

File tree

POMDP/4-R2D2-Single/config.py
POMDP/4-R2D2-Single/memory.py
POMDP/4-R2D2-Single/model.py
POMDP/4-R2D2-Single/train.py

4 files changed, +354 -0 lines

POMDP/4-R2D2-Single/config.py

Lines changed: 17 additions & 0 deletions
@@ -0,0 +1,17 @@
import torch

env_name = 'CartPole-v1'
gamma = 0.99
batch_size = 32
lr = 0.001
initial_exploration = 1000
goal_score = 200
log_interval = 10
update_target = 100
replay_memory_capacity = 1000
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# R2D2-specific hyperparameters
sequence_length = 32      # length of each stored transition sequence
burn_in_length = 4        # leading steps used only to warm up the LSTM state
eta = 0.9                 # mixes max and mean priority in the replay memory
local_mini_batch = 8      # sequences gathered locally before being pushed to replay

POMDP/4-R2D2-Single/memory.py

Lines changed: 109 additions & 0 deletions
@@ -0,0 +1,109 @@
import random
from collections import namedtuple, deque
from config import sequence_length, burn_in_length, eta
import torch
import numpy as np

Transition = namedtuple('Transition', ('state', 'next_state', 'action', 'reward', 'mask', 'rnn_state'))


class LocalBuffer(object):
    def __init__(self):
        self.local_memory = []
        self.memory = []

    def push(self, state, next_state, action, reward, mask, rnn_state):
        self.local_memory.append(Transition(state, next_state, action, reward, mask, torch.stack(rnn_state).view(2, -1)))
        if len(self.local_memory) == sequence_length or mask == 0:
            if len(self.local_memory) > burn_in_length:
                length = len(self.local_memory)
                # Pad short episodes with zero transitions up to sequence_length.
                while len(self.local_memory) < sequence_length:
                    self.local_memory.append(Transition(
                        torch.Tensor([0, 0]),
                        torch.Tensor([0, 0]),
                        0,
                        0,
                        0,
                        torch.zeros([2, 1, 16]).view(2, -1)
                    ))
                self.memory.append([self.local_memory, length])
            self.local_memory = []

    def sample(self):
        episodes = self.memory
        batch_state, batch_next_state, batch_action, batch_reward, batch_mask, batch_rnn_state = [], [], [], [], [], []
        lengths = []
        for episode, length in episodes:
            batch = Transition(*zip(*episode))

            batch_state.append(torch.stack(list(batch.state)))
            batch_next_state.append(torch.stack(list(batch.next_state)))
            batch_action.append(torch.Tensor(list(batch.action)))
            batch_reward.append(torch.Tensor(list(batch.reward)))
            batch_mask.append(torch.Tensor(list(batch.mask)))
            batch_rnn_state.append(torch.stack(list(batch.rnn_state)))

            lengths.append(length)
        self.memory = []
        return Transition(batch_state, batch_next_state, batch_action, batch_reward, batch_mask, batch_rnn_state), lengths


class Memory(object):
    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = deque(maxlen=capacity)
        self.memory_probability = deque(maxlen=capacity)

    def td_error_to_prior(self, td_error, lengths):
        abs_td_error_sum = td_error.sum(dim=1, keepdim=True).view(-1).abs().detach().numpy()
        lengths_burn = [length - burn_in_length for length in lengths]

        # Normalize by the number of non-burn-in steps in each sequence.
        prior = abs_td_error_sum / lengths_burn
        return prior

    def push(self, td_error, batch, lengths):
        # batch.state: [local_mini_batch, sequence_length, item]
        prior = self.td_error_to_prior(td_error, lengths)

        for i in range(len(lengths)):  # one entry per stored sequence
            if len(self.memory_probability) > 0:
                memory_probability = np.array(self.memory_probability)
                probability_max = max(memory_probability.max(), prior[i])
                probability_mean = (memory_probability.sum() + prior[i]) / (len(self.memory_probability) + 1)
            else:
                probability_max = prior[i]
                probability_mean = prior[i]
            self.memory.append([Transition(batch.state[i], batch.next_state[i], batch.action[i], batch.reward[i], batch.mask[i], batch.rnn_state[i]), lengths[i]])
            # Prioritized replay mix: p = eta * max + (1 - eta) * mean.
            p = eta * probability_max + (1 - eta) * probability_mean
            self.memory_probability.append(p)

    def sample(self, batch_size):
        probability = np.array(self.memory_probability)
        probability = probability / probability.sum()

        indexes = np.random.choice(range(len(self.memory_probability)), batch_size, p=probability)
        # indexes = np.random.choice(range(len(self.memory_probability)), batch_size)
        episodes = [self.memory[idx][0] for idx in indexes]
        lengths = [self.memory[idx][1] for idx in indexes]

        batch_state, batch_next_state, batch_action, batch_reward, batch_mask, batch_rnn_state = [], [], [], [], [], []
        for episode in episodes:
            batch_state.append(episode.state)
            batch_next_state.append(episode.next_state)
            batch_action.append(episode.action)
            batch_reward.append(episode.reward)
            batch_mask.append(episode.mask)
            batch_rnn_state.append(episode.rnn_state)

        return Transition(batch_state, batch_next_state, batch_action, batch_reward, batch_mask, batch_rnn_state), indexes, lengths

    def update_prior(self, indexes, td_error, lengths):
        prior = self.td_error_to_prior(td_error, lengths)
        for priors_idx, idx in enumerate(indexes):
            self.memory_probability[idx] = prior[priors_idx]

    def __len__(self):
        return len(self.memory)
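
Read together, the two classes form the data path of this commit: LocalBuffer packs transitions into padded sequences of sequence_length steps, and Memory stores whole sequences under a priority mixed from max and mean (p = eta * max + (1 - eta) * mean). Below is a minimal, hypothetical usage sketch, not part of the commit; the dummy transitions, the zero (h, c) pair, and the random TD-error tensor are assumptions chosen only to match the shapes this code expects (2-dimensional observations, a 16-unit LSTM).

import torch
from config import sequence_length, burn_in_length, local_mini_batch
from memory import LocalBuffer, Memory

local_buffer = LocalBuffer()
memory = Memory(capacity=1000)
hidden = (torch.zeros(1, 1, 16), torch.zeros(1, 1, 16))   # dummy (h, c) of the 16-unit LSTM

# Collect local_mini_batch dummy sequences of full length.
for _ in range(local_mini_batch):
    for _ in range(sequence_length):
        state, next_state = torch.rand(2), torch.rand(2)
        local_buffer.push(state, next_state, 0, 1.0, 1, hidden)   # action=0, reward=1, mask=1

batch, lengths = local_buffer.sample()

# Initial priorities come from a TD-error tensor of shape
# [num_sequences, sequence_length - burn_in_length, 1].
td_error = torch.randn(local_mini_batch, sequence_length - burn_in_length, 1)
memory.push(td_error, batch, lengths)

sampled_batch, indexes, sampled_lengths = memory.sample(batch_size=4)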

POMDP/4-R2D2-Single/model.py

Lines changed: 105 additions & 0 deletions
@@ -0,0 +1,105 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from config import gamma, device, batch_size, sequence_length, burn_in_length


class R2D2(nn.Module):
    def __init__(self, num_inputs, num_outputs):
        super(R2D2, self).__init__()
        self.num_inputs = num_inputs
        self.num_outputs = num_outputs

        self.lstm = nn.LSTM(input_size=num_inputs, hidden_size=16, batch_first=True)
        self.fc = nn.Linear(16, 128)
        self.fc_adv = nn.Linear(128, num_outputs)  # advantage stream (dueling)
        self.fc_val = nn.Linear(128, 1)            # value stream (dueling)

        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)

    def forward(self, x, hidden=None):
        # x: [batch_size, sequence_length, num_inputs]
        batch_size = x.size()[0]
        sequence_length = x.size()[1]
        if hidden is not None:
            out, hidden = self.lstm(x, hidden)
        else:
            out, hidden = self.lstm(x)

        out = F.relu(self.fc(out))
        adv = self.fc_adv(out)
        adv = adv.view(batch_size, sequence_length, self.num_outputs)
        val = self.fc_val(out)
        val = val.view(batch_size, sequence_length, 1)

        # Dueling aggregation: Q(s, a) = V(s) + (A(s, a) - mean_a A(s, a))
        qvalue = val + (adv - adv.mean(dim=2, keepdim=True))

        return qvalue, hidden

    @classmethod
    def get_td_error(cls, online_net, target_net, batch, lengths):
        def slice_burn_in(item):
            return item[:, burn_in_length:, :]

        batch_size = torch.stack(batch.state).size()[0]
        states = torch.stack(batch.state).view(batch_size, sequence_length, online_net.num_inputs)
        next_states = torch.stack(batch.next_state).view(batch_size, sequence_length, online_net.num_inputs)
        actions = torch.stack(batch.action).view(batch_size, sequence_length, -1).long()
        rewards = torch.stack(batch.reward).view(batch_size, sequence_length, -1)
        masks = torch.stack(batch.mask).view(batch_size, sequence_length, -1)
        rnn_state = torch.stack(batch.rnn_state).view(batch_size, sequence_length, 2, -1)

        # Stored recurrent states: step 0 initializes the online net on states,
        # step 1 initializes the target/online nets on next_states.
        [h0, c0] = rnn_state[:, 0, :, :].transpose(0, 1)
        h0 = h0.unsqueeze(0).detach()
        c0 = c0.unsqueeze(0).detach()

        [h1, c1] = rnn_state[:, 1, :, :].transpose(0, 1)
        h1 = h1.unsqueeze(0).detach()
        c1 = c1.unsqueeze(0).detach()

        pred, _ = online_net(states, (h0, c0))
        next_pred, _ = target_net(next_states, (h1, c1))

        next_pred_online, _ = online_net(next_states, (h1, c1))

        # Drop the burn-in prefix before computing the loss.
        pred = slice_burn_in(pred)
        next_pred = slice_burn_in(next_pred)
        actions = slice_burn_in(actions)
        rewards = slice_burn_in(rewards)
        masks = slice_burn_in(masks)
        next_pred_online = slice_burn_in(next_pred_online)

        pred = pred.gather(2, actions)

        # Double DQN: the online net selects the action, the target net evaluates it.
        _, next_pred_online_action = next_pred_online.max(2)

        target = rewards + masks * gamma * next_pred.gather(2, next_pred_online_action.unsqueeze(2))

        td_error = pred - target.detach()

        # Per-sequence slices excluding padded steps (computed but not yet used in this WIP).
        td_error_slice = []
        for idx, length in enumerate(lengths):
            td_error_slice.append(td_error[idx][:length - burn_in_length][:])

        return td_error

    @classmethod
    def train_model(cls, online_net, target_net, optimizer, batch, lengths):
        td_error = cls.get_td_error(online_net, target_net, batch, lengths)

        loss = pow(td_error, 2).mean()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        return loss, td_error

    def get_action(self, state, hidden):
        # Shape [1, 1, num_inputs]: a batch of one, a sequence of one step.
        state = state.unsqueeze(0).unsqueeze(0)

        qvalue, hidden = self.forward(state, hidden)

        _, action = torch.max(qvalue, 2)
        return action.numpy()[0][0], hidden
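
For orientation, here is a small, hypothetical shape check of the forward pass, not part of the commit; the batch and sequence sizes (8 and 32) are arbitrary, while num_inputs=2 and the 16-unit LSTM come from the code above.

import torch
from model import R2D2

net = R2D2(num_inputs=2, num_outputs=2)
x = torch.rand(8, 32, 2)           # [batch, sequence, num_inputs]
qvalue, (h, c) = net(x)            # hidden defaults to zeros inside nn.LSTM
print(qvalue.shape)                # torch.Size([8, 32, 2])
print(h.shape, c.shape)            # torch.Size([1, 8, 16]) for each

# Dueling aggregation used in forward:    Q(s, a) = V(s) + (A(s, a) - mean_a A(s, a))
# Double-DQN target used in get_td_error: y = r + gamma * mask * Q_target(s', argmax_a Q_online(s', a))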

POMDP/4-R2D2-Single/train.py

Lines changed: 123 additions & 0 deletions
@@ -0,0 +1,123 @@
import os
import sys
import gym
import random
import numpy as np

import torch
import torch.optim as optim
import torch.nn.functional as F
from model import R2D2
from memory import Memory, LocalBuffer
from tensorboardX import SummaryWriter

from config import env_name, initial_exploration, batch_size, update_target, goal_score, log_interval, device, replay_memory_capacity, lr, sequence_length, local_mini_batch

from collections import deque


def get_action(state, target_net, epsilon, env, hidden):
    action, hidden = target_net.get_action(state, hidden)

    if np.random.rand() <= epsilon:
        return env.action_space.sample(), hidden
    else:
        return action, hidden


def update_target_model(online_net, target_net):
    # Target <- Net
    target_net.load_state_dict(online_net.state_dict())


def state_to_partial_observability(state):
    # Keep only cart position and pole angle (drop the velocities).
    state = state[[0, 2]]
    return state


def main():
    env = gym.make(env_name)
    env.seed(500)
    torch.manual_seed(500)

    # num_inputs = env.observation_space.shape[0]
    num_inputs = 2
    num_actions = env.action_space.n
    print('state size:', num_inputs)
    print('action size:', num_actions)

    online_net = R2D2(num_inputs, num_actions)
    target_net = R2D2(num_inputs, num_actions)
    update_target_model(online_net, target_net)

    optimizer = optim.Adam(online_net.parameters(), lr=lr)
    writer = SummaryWriter('logs')

    online_net.to(device)
    target_net.to(device)
    online_net.train()
    target_net.train()
    memory = Memory(replay_memory_capacity)
    running_score = 0
    epsilon = 1.0
    steps = 0
    loss = 0
    local_buffer = LocalBuffer()

    for e in range(30000):
        done = False

        score = 0
        state = env.reset()
        state = state_to_partial_observability(state)
        state = torch.Tensor(state).to(device)

        hidden = None

        while not done:
            steps += 1

            action, hidden = get_action(state, target_net, epsilon, env, hidden)

            next_state, reward, done, _ = env.step(action)

            next_state = state_to_partial_observability(next_state)
            next_state = torch.Tensor(next_state)

            mask = 0 if done else 1
            reward = reward if not done or score == 499 else -1

            local_buffer.push(state, next_state, action, reward, mask, hidden)
            # Once enough sequences are gathered locally, compute their initial
            # priorities from the TD error and move them into the replay memory.
            if len(local_buffer.memory) == local_mini_batch:
                batch, lengths = local_buffer.sample()
                td_error = R2D2.get_td_error(online_net, target_net, batch, lengths)
                memory.push(td_error, batch, lengths)

            score += reward
            state = next_state

            if steps > initial_exploration and len(memory) > batch_size:
                epsilon -= 0.00005
                epsilon = max(epsilon, 0.1)

                batch, indexes, lengths = memory.sample(batch_size)
                loss, td_error = R2D2.train_model(online_net, target_net, optimizer, batch, lengths)

                # Refresh the priorities of the sampled sequences with the new TD error.
                memory.update_prior(indexes, td_error, lengths)

                if steps % update_target == 0:
                    update_target_model(online_net, target_net)

        score = score if score == 500.0 else score + 1
        if running_score == 0:
            running_score = score
        else:
            running_score = 0.99 * running_score + 0.01 * score
        if e % log_interval == 0:
            print('{} episode | score: {:.2f} | epsilon: {:.2f}'.format(
                e, running_score, epsilon))
            writer.add_scalar('log/score', float(running_score), e)
            writer.add_scalar('log/loss', float(loss), e)

            if running_score > goal_score:
                break


if __name__ == "__main__":
    main()
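
As a final illustration (not part of the commit, and assuming train.py's dependencies such as gym and tensorboardX are importable), this is what state_to_partial_observability removes: CartPole-v1 observes [cart position, cart velocity, pole angle, pole angular velocity], and keeping indexes [0, 2] discards both velocities, so the recurrent network has to infer them over time.

import numpy as np
from train import state_to_partial_observability

full_obs = np.array([0.02, -0.47, 0.03, 0.65])    # [x, x_dot, theta, theta_dot]
print(state_to_partial_observability(full_obs))   # [0.02 0.03] -> position and angle only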
