Commit 390a9e3

Feat: IQN
1 parent ea34053 commit 390a9e3

File tree

5 files changed: +226 -1 lines changed

    README.md
    distributional/2-IQN/config.py
    distributional/2-IQN/memory.py
    distributional/2-IQN/model.py
    distributional/2-IQN/train.py

README.md

Lines changed: 1 addition & 1 deletion
@@ -35,7 +35,7 @@ So you can run this example in your computer(maybe it take just only 1~2 minitue
 
 ## Distributional DQN
 - [x] QRDQN [[18]](#reference)
-- [ ] IQN [[19]](#reference)
+- [x] IQN [[19]](#reference)
 
 ## Exploration
 - [ ] ICM [[22]](#refercence)

distributional/2-IQN/config.py

Lines changed: 18 additions & 0 deletions
import torch

env_name = 'CartPole-v1'
gamma = 0.99
batch_size = 32
lr = 0.001
initial_exploration = 1000
goal_score = 200
log_interval = 10
update_target = 100
replay_memory_capacity = 1000
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


num_quantile_sample = 32
num_tau_sample = 16
num_tau_prime_sample = 8
quantile_embedding_dim = 64
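
For reference, the three sample counts correspond to the N, N', and K of the IQN paper (Dabney et al., 2018): num_tau_sample (N) and num_tau_prime_sample (N') are the numbers of quantile fractions drawn for the online and target distributions inside the loss, while num_quantile_sample (K) is the number drawn when estimating action values for the greedy policy,

    Q(s, a) \approx \frac{1}{K} \sum_{k=1}^{K} Z_{\tau_k}(s, a), \qquad \tau_k \sim U(0, 1)

(here restricted to U(0, 0.5) in get_action for a CVaR-style policy). quantile_embedding_dim is the length of the cosine feature vector used to embed each \tau (see model.py below).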

distributional/2-IQN/memory.py

Lines changed: 22 additions & 0 deletions
import random
from collections import namedtuple, deque


Transition = namedtuple('Transition', ('state', 'next_state', 'action', 'reward', 'mask'))


class Memory(object):
    def __init__(self, capacity):
        self.memory = deque(maxlen=capacity)
        self.capacity = capacity

    def push(self, state, next_state, action, reward, mask):
        self.memory.append(Transition(state, next_state, action, reward, mask))

    def sample(self, batch_size):
        transitions = random.sample(self.memory, batch_size)
        batch = Transition(*zip(*transitions))
        return batch

    def __len__(self):
        return len(self.memory)
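
A minimal usage sketch of this replay buffer, for illustration only (the tensor shapes assume CartPole observations as produced in train.py below):

    import torch
    from memory import Memory

    buffer = Memory(capacity=1000)
    state = torch.zeros(1, 4)        # one CartPole observation, shape [1, num_inputs]
    next_state = torch.ones(1, 4)
    buffer.push(state, next_state, action=0, reward=1.0, mask=1)

    if len(buffer) >= 1:
        batch = buffer.sample(1)     # namedtuple of tuples: batch.state, batch.action, ...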

distributional/2-IQN/model.py

Lines changed: 83 additions & 0 deletions
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from config import batch_size, gamma, quantile_embedding_dim, num_tau_sample, num_tau_prime_sample, num_quantile_sample


class QRDQN(nn.Module):
    def __init__(self, num_inputs, num_outputs):
        super(QRDQN, self).__init__()
        self.num_inputs = num_inputs
        self.num_outputs = num_outputs

        self.fc1 = nn.Linear(num_inputs, 128)
        self.fc2 = nn.Linear(128, num_outputs)
        self.phi = nn.Linear(quantile_embedding_dim, 128)

        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)

    def forward(self, state, tau, num_quantiles):
        input_size = state.size()[0]  # batch_size (train) or 1 (get_action)
        tau = tau.expand(input_size * num_quantiles, quantile_embedding_dim)
        pi_mtx = torch.Tensor(np.pi * np.arange(0, quantile_embedding_dim)).expand(input_size * num_quantiles, quantile_embedding_dim)
        cos_tau = torch.cos(tau * pi_mtx)  # cosine embedding cos(pi * i * tau), i = 0 .. quantile_embedding_dim - 1

        phi = self.phi(cos_tau)
        phi = F.relu(phi)

        state_tile = state.expand(input_size, num_quantiles, self.num_inputs)
        state_tile = state_tile.flatten().view(-1, self.num_inputs)

        x = F.relu(self.fc1(state_tile))
        x = self.fc2(x * phi)  # merge state features and quantile embedding by element-wise product
        z = x.view(-1, num_quantiles, self.num_outputs)

        z = z.transpose(1, 2)  # [input_size, num_output, num_quantile]
        return z

    def get_action(self, state):
        tau = torch.Tensor(np.random.rand(num_quantile_sample, 1) * 0.5)  # CVaR: restrict tau to U(0, 0.5) for risk-averse action selection
        z = self.forward(state, tau, num_quantile_sample)
        q = z.mean(dim=2, keepdim=True)
        action = torch.argmax(q)
        return action.item()

    @classmethod
    def train_model(cls, online_net, target_net, optimizer, batch):
        states = torch.stack(batch.state)
        next_states = torch.stack(batch.next_state)
        actions = torch.Tensor(batch.action).long()
        rewards = torch.Tensor(batch.reward)
        masks = torch.Tensor(batch.mask)

        tau = torch.Tensor(np.random.rand(batch_size * num_tau_sample, 1))
        z = online_net(states, tau, num_tau_sample)
        action = actions.unsqueeze(1).unsqueeze(1).expand(-1, 1, num_tau_sample)
        z_a = z.gather(1, action).squeeze(1)

        tau_prime = torch.Tensor(np.random.rand(batch_size * num_tau_prime_sample, 1))
        next_z = target_net(next_states, tau_prime, num_tau_prime_sample)
        next_action = next_z.mean(dim=2).max(1)[1]
        next_action = next_action.unsqueeze(1).unsqueeze(1).expand(batch_size, 1, num_tau_prime_sample)
        next_z_a = next_z.gather(1, next_action).squeeze(1)

        T_z = rewards.unsqueeze(1) + gamma * next_z_a * masks.unsqueeze(1)

        T_z_tile = T_z.view(-1, num_tau_prime_sample, 1).expand(-1, num_tau_prime_sample, num_tau_sample)
        z_a_tile = z_a.view(-1, 1, num_tau_sample).expand(-1, num_tau_prime_sample, num_tau_sample)

        error_loss = T_z_tile - z_a_tile
        huber_loss = nn.SmoothL1Loss(reduction='none')(T_z_tile, z_a_tile)
        tau = torch.arange(0, 1, 1 / num_tau_sample).view(1, num_tau_sample)  # fixed uniform grid of quantile fractions for the loss weights

        loss = (tau - (error_loss < 0).float()).abs() * huber_loss  # quantile weights |tau - 1{delta < 0}| applied to the Huber term
        loss = loss.mean(dim=2).sum(dim=1).mean()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        return loss
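
For context, the forward pass follows the implicit quantile network of Dabney et al. (2018): each sampled fraction \tau is embedded with cosine features and a linear layer,

    \phi_j(\tau) = \mathrm{ReLU}\Big( \sum_{i=0}^{n-1} \cos(\pi i \tau)\, w_{ij} + b_j \Big), \qquad n = \texttt{quantile\_embedding\_dim},

and this embedding is merged with the state features by element-wise multiplication (x * phi) before the final layer outputs a quantile value for every action. train_model then minimizes the quantile regression Huber loss between the N' target samples and the N online samples,

    \rho^{\kappa}_{\tau}(\delta) = \lvert \tau - \mathbf{1}\{\delta < 0\} \rvert \, L_{\kappa}(\delta), \qquad \delta = T z' - z,

with L_\kappa the Huber (SmoothL1) term, averaged over the N online fractions, summed over the N' target samples, and averaged over the batch. Note that the weighting \tau in this implementation comes from the fixed grid torch.arange(0, 1, 1 / num_tau_sample) rather than from the fractions sampled for the online network.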

distributional/2-IQN/train.py

Lines changed: 102 additions & 0 deletions
import os
import sys
import gym
import random
import argparse
import numpy as np

import torch
import torch.optim as optim
import torch.nn.functional as F
from tensorboardX import SummaryWriter

from model import QRDQN
from memory import Memory

from config import env_name, initial_exploration, batch_size, update_target, goal_score, log_interval, device, replay_memory_capacity, lr


def get_action(state, target_net, epsilon, env):
    if np.random.rand() <= epsilon:
        return env.action_space.sample()
    else:
        return target_net.get_action(state)

def update_target_model(online_net, target_net):
    # Target <- Net
    target_net.load_state_dict(online_net.state_dict())


def main():
    env = gym.make(env_name)
    env.seed(500)
    torch.manual_seed(500)

    num_inputs = env.observation_space.shape[0]
    num_actions = env.action_space.n
    print('state size:', num_inputs)
    print('action size:', num_actions)

    online_net = QRDQN(num_inputs, num_actions)
    target_net = QRDQN(num_inputs, num_actions)
    update_target_model(online_net, target_net)

    optimizer = optim.Adam(online_net.parameters(), lr=lr)
    writer = SummaryWriter('logs')

    online_net.to(device)
    target_net.to(device)
    online_net.train()
    target_net.train()
    memory = Memory(replay_memory_capacity)
    running_score = 0
    epsilon = 1.0
    steps = 0
    loss = 0

    for e in range(3000):
        done = False

        score = 0
        state = env.reset()
        state = torch.Tensor(state)
        state = state.unsqueeze(0)

        while not done:
            steps += 1
            action = get_action(state, target_net, epsilon, env)
            next_state, reward, done, _ = env.step(action)

            next_state = torch.Tensor(next_state)
            next_state = next_state.unsqueeze(0)

            mask = 0 if done else 1
            reward = reward if not done or score == 499 else -1
            memory.push(state, next_state, action, reward, mask)

            score += reward
            state = next_state

            if steps > initial_exploration:
                epsilon -= 0.00005
                epsilon = max(epsilon, 0.1)

                batch = memory.sample(batch_size)
                loss = QRDQN.train_model(online_net, target_net, optimizer, batch)

                if steps % update_target == 0:
                    update_target_model(online_net, target_net)

        score = score if score == 500.0 else score + 1
        running_score = 0.99 * running_score + 0.01 * score
        if e % log_interval == 0:
            print('{} episode | score: {:.2f} | epsilon: {:.2f}'.format(
                e, running_score, epsilon))
            writer.add_scalar('log/score', float(running_score), e)
            writer.add_scalar('log/loss', float(loss), e)

            if running_score > goal_score:
                break


if __name__ == "__main__":
    main()
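
With PyTorch, tensorboardX, and an older gym release installed (env.seed and the four-value env.step return are the pre-0.26 gym API), this script can be launched as python train.py from distributional/2-IQN/. The SummaryWriter records the running score and loss under logs/, viewable with tensorboard --logdir logs. Epsilon-greedy exploration decays from 1.0 toward 0.1 once initial_exploration steps have been collected, and training stops when the running score exceeds goal_score (200).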
