Commit 7e99e19

Simple Implementation of ICM
1 parent b5be8d7 commit 7e99e19

10 files changed: +322 -0 lines changed

ReinforcementLearning/ICM/actor_critic.py

+80
@@ -0,0 +1,80 @@
import numpy as np
import torch as T
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical


class ActorCritic(nn.Module):
    def __init__(self, input_dims, n_actions, gamma=0.99, tau=0.98):
        super(ActorCritic, self).__init__()
        self.gamma = gamma
        self.tau = tau

        self.input = nn.Linear(*input_dims, 256)
        self.dense = nn.Linear(256, 256)

        self.gru = nn.GRUCell(256, 256)
        self.pi = nn.Linear(256, n_actions)
        self.v = nn.Linear(256, 1)

    def forward(self, state, hx):
        x = F.relu(self.input(state))
        x = F.relu(self.dense(x))
        hx = self.gru(x, hx)

        pi = self.pi(hx)
        v = self.v(hx)

        probs = T.softmax(pi, dim=1)
        dist = Categorical(probs)
        action = dist.sample()
        log_prob = dist.log_prob(action)

        return action.numpy()[0], v, log_prob, hx

    def calc_R(self, done, rewards, values):
        values = T.cat(values).squeeze()
        if len(values.size()) == 1:  # batch of states
            R = values[-1] * (1 - int(done))
        elif len(values.size()) == 0:  # single state
            R = values * (1 - int(done))

        batch_return = []
        for reward in rewards[::-1]:
            R = reward + self.gamma * R
            batch_return.append(R)
        batch_return.reverse()
        batch_return = T.tensor(batch_return,
                                dtype=T.float).reshape(values.size())
        return batch_return

    def calc_loss(self, new_states, hx, done,
                  rewards, values, log_probs, r_i_t=None):
        if r_i_t is not None:
            # add the intrinsic reward to each transition's (zeroed) extrinsic
            # reward element-wise; rewards arrives as a plain list
            rewards = np.array(rewards) + r_i_t.detach().numpy()
        returns = self.calc_R(done, rewards, values)
        next_v = T.zeros(1, 1) if done else self.forward(
            T.tensor([new_states], dtype=T.float), hx)[1]

        values.append(next_v.detach())
        values = T.cat(values).squeeze()
        log_probs = T.cat(log_probs)
        rewards = T.tensor(rewards)

        # generalized advantage estimation over the T-step rollout
        delta_t = rewards + self.gamma*values[1:] - values[:-1]
        n_steps = len(delta_t)
        gae = np.zeros(n_steps)
        for t in range(n_steps):
            for k in range(0, n_steps-t):
                temp = (self.gamma*self.tau)**k * delta_t[t+k]
                gae[t] += temp
        gae = T.tensor(gae, dtype=T.float)

        actor_loss = -(log_probs*gae).sum()
        entropy_loss = (-log_probs*T.exp(log_probs)).sum()
        # values[:-1] is [T]; returns was reshaped to match
        critic_loss = F.mse_loss(values[:-1].squeeze(), returns)

        total_loss = actor_loss + critic_loss - 0.01*entropy_loss
        return total_loss

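Note (not part of the commit): the nested loop in calc_loss computes the generalized advantage estimate gae[t] = sum_k (gamma*tau)^k * delta[t+k] in O(T^2) time. A minimal sketch of an equivalent O(T) backward recursion, assuming delta is the same 1-D array of TD errors as delta_t above (the helper name is hypothetical):

    import numpy as np

    def gae_backward(delta, gamma=0.99, tau=0.98):
        # gae[t] = delta[t] + gamma*tau*gae[t+1], accumulated from the last step
        gae = np.zeros(len(delta))
        running = 0.0
        for t in reversed(range(len(delta))):
            running = delta[t] + gamma * tau * running
            gae[t] = running
        return gae

Both forms produce the same advantages; the recursion just avoids re-summing the tail for every t.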
ReinforcementLearning/ICM/icm.py

+47
@@ -0,0 +1,47 @@
import torch as T
import torch.nn as nn
import torch.nn.functional as F


class ICM(nn.Module):
    def __init__(self, input_dims, n_actions=2, alpha=1, beta=0.2):
        super(ICM, self).__init__()
        self.alpha = alpha
        self.beta = beta

        # sizes are hard coded for the CartPole environment (4-dim observations)
        # inverse model: (state, new_state) -> predicted action logits
        self.inverse = nn.Linear(4*2, 256)
        self.pi_logits = nn.Linear(256, n_actions)

        # forward model: (state, action) -> predicted new state
        self.dense1 = nn.Linear(4+1, 256)
        self.new_state = nn.Linear(256, 4)

        device = T.device('cpu')
        self.to(device)

    def forward(self, state, new_state, action):
        inverse = F.elu(self.inverse(T.cat([state, new_state], dim=1)))
        pi_logits = self.pi_logits(inverse)

        # from [T] to [T,1]
        action = action.reshape((action.size()[0], 1))
        forward_input = T.cat([state, action], dim=1)
        dense = F.elu(self.dense1(forward_input))
        state_ = self.new_state(dense)

        return pi_logits, state_

    def calc_loss(self, state, new_state, action):
        state = T.tensor(state, dtype=T.float)
        action = T.tensor(action, dtype=T.float)
        new_state = T.tensor(new_state, dtype=T.float)

        pi_logits, state_ = self.forward(state, new_state, action)

        inverse_loss = nn.CrossEntropyLoss()
        L_I = (1-self.beta)*inverse_loss(pi_logits, action.to(T.long))

        forward_loss = nn.MSELoss()
        L_F = self.beta*forward_loss(state_, new_state)

        # curiosity signal: scaled forward-model prediction error per transition
        intrinsic_reward = self.alpha*((state_ - new_state).pow(2)).mean(dim=1)
        return intrinsic_reward, L_I, L_F

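Note (not part of the commit): input_dims is accepted but unused here; all layer sizes are hard coded for CartPole's 4-dimensional observations and 2 actions. A quick shape check of the intrinsic reward on made-up transition data (illustrative only):

    import numpy as np
    from icm import ICM

    icm = ICM(input_dims=[4])
    states = [np.random.randn(4) for _ in range(20)]
    new_states = [np.random.randn(4) for _ in range(20)]
    actions = [float(np.random.randint(2)) for _ in range(20)]

    r_i, L_I, L_F = icm.calc_loss(states, new_states, actions)
    print(r_i.shape)  # torch.Size([20]): one curiosity reward per transition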
ReinforcementLearning/ICM/main.py

+15
@@ -0,0 +1,15 @@
import os
import torch.multiprocessing as mp
from parallel_env import ParallelEnv

os.environ['OMP_NUM_THREADS'] = '1'


if __name__ == '__main__':
    mp.set_start_method('spawn')
    env_id = 'CartPole-v0'
    n_threads = 12
    n_actions = 2
    input_shape = [4]
    env = ParallelEnv(env_id=env_id, n_threads=n_threads,
                      n_actions=n_actions, input_shape=input_shape, icm=True)

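Note (not part of the commit): passing icm=False runs the same pipeline as a plain A3C baseline; worker() then uses a zero intrinsic reward and labels the run 'A3C'. For example:

    env = ParallelEnv(env_id=env_id, n_threads=n_threads,
                      n_actions=n_actions, input_shape=input_shape, icm=False)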
ReinforcementLearning/ICM/memory.py

+28
@@ -0,0 +1,28 @@
class Memory:
    def __init__(self):
        self.states = []
        self.actions = []
        self.rewards = []
        self.new_states = []
        self.values = []
        self.log_probs = []

    def remember(self, state, action, reward, new_state, value, log_p):
        self.actions.append(action)
        self.rewards.append(reward)
        self.states.append(state)
        self.new_states.append(new_state)
        self.log_probs.append(log_p)
        self.values.append(value)

    def clear_memory(self):
        self.states = []
        self.actions = []
        self.rewards = []
        self.new_states = []
        self.values = []
        self.log_probs = []

    def sample_memory(self):
        return self.states, self.actions, self.rewards, self.new_states,\
            self.values, self.log_probs
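
Note (not part of the commit): a minimal sketch of how worker.py uses this buffer over one rollout segment (placeholder values):

    import torch as T
    from memory import Memory

    memory = Memory()
    memory.remember(state=[0., 0., 0., 0.], action=1, reward=0,
                    new_state=[0., 0., .1, 0.],
                    value=T.zeros(1, 1), log_p=T.zeros(1))
    states, actions, rewards, new_states, values, log_probs = memory.sample_memory()
    memory.clear_memory()  # called after every T_MAX-step update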

ReinforcementLearning/ICM/parallel_env.py

+32
@@ -0,0 +1,32 @@
import torch.multiprocessing as mp
from actor_critic import ActorCritic
from icm import ICM
from shared_adam import SharedAdam
from worker import worker


class ParallelEnv:
    def __init__(self, env_id, input_shape, n_actions, icm, n_threads=8):
        names = [str(i) for i in range(1, n_threads+1)]

        global_actor_critic = ActorCritic(input_shape, n_actions)
        global_actor_critic.share_memory()
        global_optim = SharedAdam(global_actor_critic.parameters())

        if not icm:
            global_icm = None
            global_icm_optim = None
        else:
            global_icm = ICM(input_shape, n_actions)
            global_icm.share_memory()
            global_icm_optim = SharedAdam(global_icm.parameters())

        self.ps = [mp.Process(target=worker,
                              args=(name, input_shape, n_actions,
                                    global_actor_critic, global_icm,
                                    global_optim, global_icm_optim, env_id,
                                    n_threads, icm))
                   for name in names]

        [p.start() for p in self.ps]
        [p.join() for p in self.ps]

ReinforcementLearning/ICM/shared_adam.py

+18
@@ -0,0 +1,18 @@
import torch as T


class SharedAdam(T.optim.Adam):
    def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), eps=1e-8,
                 weight_decay=0):
        super(SharedAdam, self).__init__(params, lr=lr, betas=betas,
                                         eps=eps, weight_decay=weight_decay)

        # eagerly create Adam's state buffers and move them into shared memory
        # so every worker process updates the same optimizer statistics
        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]
                state['step'] = 0
                state['exp_avg'] = T.zeros_like(p.data)
                state['exp_avg_sq'] = T.zeros_like(p.data)

                state['exp_avg'].share_memory_()
                state['exp_avg_sq'].share_memory_()

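Note (not part of the commit): a quick way to confirm the optimizer state really lands in shared memory (illustrative only):

    from actor_critic import ActorCritic
    from shared_adam import SharedAdam

    optim = SharedAdam(ActorCritic([4], 2).parameters())
    state = next(iter(optim.state.values()))
    print(state['exp_avg'].is_shared(), state['exp_avg_sq'].is_shared())  # True True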
ReinforcementLearning/ICM/utils.py

+11
@@ -0,0 +1,11 @@
import matplotlib.pyplot as plt
import numpy as np


def plot_learning_curve(x, scores, figure_file):
    running_avg = np.zeros(len(scores))
    for i in range(len(running_avg)):
        running_avg[i] = np.mean(scores[max(0, i-100):(i+1)])
    plt.plot(x, running_avg)
    plt.title('Running average of previous 100 episodes')
    plt.savefig(figure_file)

ReinforcementLearning/ICM/worker.py

+91
@@ -0,0 +1,91 @@
import gym
import numpy as np
import torch as T
from actor_critic import ActorCritic
from icm import ICM
from memory import Memory
from utils import plot_learning_curve


def worker(name, input_shape, n_actions, global_agent, global_icm,
           optimizer, icm_optimizer, env_id, n_threads, icm=False):
    T_MAX = 20

    local_agent = ActorCritic(input_shape, n_actions)

    if icm:
        local_icm = ICM(input_shape, n_actions)
        algo = 'ICM'
    else:
        intrinsic_reward = T.zeros(1)
        algo = 'A3C'

    memory = Memory()

    env = gym.make(env_id)

    t_steps, max_eps, episode, scores, avg_score = 0, 1000, 0, [], 0

    while episode < max_eps:
        obs = env.reset()
        hx = T.zeros(1, 256)
        score, done, ep_steps = 0, False, 0
        while not done:
            state = T.tensor([obs], dtype=T.float)
            action, value, log_prob, hx = local_agent(state, hx)
            obs_, reward, done, info = env.step(action)
            t_steps += 1
            ep_steps += 1
            score += reward
            reward = 0  # turn off extrinsic rewards
            memory.remember(obs, action, reward, obs_, value, log_prob)
            obs = obs_
            if ep_steps % T_MAX == 0 or done:
                states, actions, rewards, new_states, values, log_probs = \
                    memory.sample_memory()
                if icm:
                    intrinsic_reward, L_I, L_F = \
                        local_icm.calc_loss(states, new_states, actions)

                loss = local_agent.calc_loss(obs, hx, done, rewards, values,
                                             log_probs, intrinsic_reward)

                optimizer.zero_grad()
                hx = hx.detach_()
                if icm:
                    icm_optimizer.zero_grad()
                    (L_I + L_F).backward()

                loss.backward()
                T.nn.utils.clip_grad_norm_(local_agent.parameters(), 40)

                # copy local gradients into the shared global networks,
                # step the shared optimizers, then sync the local copies
                for local_param, global_param in zip(
                        local_agent.parameters(),
                        global_agent.parameters()):
                    global_param._grad = local_param.grad
                optimizer.step()
                local_agent.load_state_dict(global_agent.state_dict())

                if icm:
                    for local_param, global_param in zip(
                            local_icm.parameters(),
                            global_icm.parameters()):
                        global_param._grad = local_param.grad
                    icm_optimizer.step()
                    local_icm.load_state_dict(global_icm.state_dict())
                memory.clear_memory()

        if name == '1':
            scores.append(score)
            avg_score = np.mean(scores[-100:])
            print('{} episode {} thread {} of {} steps {:.2f}M score {:.2f} '
                  'intrinsic_reward {:.2f} avg score (100) {:.1f}'.format(
                      algo, episode, name, n_threads,
                      t_steps/1e6, score,
                      T.sum(intrinsic_reward),
                      avg_score))
        episode += 1
    if name == '1':
        x = [z for z in range(episode)]
        fname = algo + '_CartPole_no_rewards.png'
        plot_learning_curve(x, scores, fname)

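Note (not part of the commit): for debugging outside torch.multiprocessing, the worker can be run directly in the main process against locally built "global" networks. This sketch assumes the same pre-0.26 gym reset/step API that the loop above already uses:

    from actor_critic import ActorCritic
    from icm import ICM
    from shared_adam import SharedAdam
    from worker import worker

    global_agent = ActorCritic([4], n_actions=2)
    global_icm = ICM([4], n_actions=2)
    worker('1', [4], 2, global_agent, global_icm,
           SharedAdam(global_agent.parameters()),
           SharedAdam(global_icm.parameters()),
           'CartPole-v0', n_threads=1, icm=True)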