Simple Implementation of ICM
philtabor committed Oct 5, 2021
1 parent b5be8d7 commit 7e99e19
Showing 10 changed files with 322 additions and 0 deletions.
(Two of the changed files cannot be displayed.)
80 changes: 80 additions & 0 deletions ReinforcementLearning/ICM/actor_critic.py
@@ -0,0 +1,80 @@
import numpy as np
import torch as T
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical


class ActorCritic(nn.Module):
def __init__(self, input_dims, n_actions, gamma=0.99, tau=0.98):
super(ActorCritic, self).__init__()
self.gamma = gamma
self.tau = tau

self.input = nn.Linear(*input_dims, 256)
self.dense = nn.Linear(256, 256)

self.gru = nn.GRUCell(256, 256)
self.pi = nn.Linear(256, n_actions)
self.v = nn.Linear(256, 1)

def forward(self, state, hx):
x = F.relu(self.input(state))
x = F.relu(self.dense(x))
        hx = self.gru(x, hx)

pi = self.pi(hx)
v = self.v(hx)

probs = T.softmax(pi, dim=1)
dist = Categorical(probs)
action = dist.sample()
log_prob = dist.log_prob(action)

return action.numpy()[0], v, log_prob, hx

def calc_R(self, done, rewards, values):
values = T.cat(values).squeeze()
if len(values.size()) == 1: # batch of states
R = values[-1] * (1-int(done))
elif len(values.size()) == 0: # single state
R = values*(1-int(done))

batch_return = []
for reward in rewards[::-1]:
R = reward + self.gamma * R
batch_return.append(R)
batch_return.reverse()
batch_return = T.tensor(batch_return,
dtype=T.float).reshape(values.size())
return batch_return

def calc_loss(self, new_states, hx, done,
rewards, values, log_probs, r_i_t=None):
        if r_i_t is not None:
            # add the intrinsic (curiosity) reward to each extrinsic reward,
            # element-wise, so the rollout length is unchanged
            rewards = np.array(rewards) + r_i_t.detach().numpy()
returns = self.calc_R(done, rewards, values)
next_v = T.zeros(1, 1) if done else self.forward(T.tensor([new_states],
dtype=T.float), hx)[1]

values.append(next_v.detach())
values = T.cat(values).squeeze()
log_probs = T.cat(log_probs)
rewards = T.tensor(rewards)

delta_t = rewards + self.gamma*values[1:] - values[:-1]
n_steps = len(delta_t)
gae = np.zeros(n_steps)
for t in range(n_steps):
for k in range(0, n_steps-t):
temp = (self.gamma*self.tau)**k*delta_t[t+k]
gae[t] += temp
gae = T.tensor(gae, dtype=T.float)

actor_loss = -(log_probs*gae).sum()
entropy_loss = (-log_probs*T.exp(log_probs)).sum()
        # squeeze so the value estimates match the shape of returns
        # (a [T] tensor, or a scalar for a single-step rollout)
        critic_loss = F.mse_loss(values[:-1].squeeze(), returns)

total_loss = actor_loss + critic_loss - 0.01*entropy_loss
return total_loss
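
The nested loop in calc_loss evaluates gae[t] = sum_k (gamma*tau)^k * delta[t+k] directly, which is O(T^2) in the rollout length. A minimal sketch of an equivalent single backward pass (not part of this commit; gae_one_pass is a hypothetical helper, and delta_t is assumed to be the same tensor of TD errors computed above, detached because the original builds gae from numpy and therefore passes no gradient through it):

import torch as T

def gae_one_pass(delta_t, gamma=0.99, tau=0.98):
    # same values as the nested loop: gae[t] = delta[t] + gamma*tau*gae[t+1]
    delta_t = delta_t.detach()
    gae = T.zeros_like(delta_t)
    running = 0.0
    for t in reversed(range(len(delta_t))):
        running = delta_t[t] + gamma * tau * running
        gae[t] = running
    return gae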
47 changes: 47 additions & 0 deletions ReinforcementLearning/ICM/icm.py
@@ -0,0 +1,47 @@
import torch as T
import torch.nn as nn
import torch.nn.functional as F


class ICM(nn.Module):
def __init__(self, input_dims, n_actions=2, alpha=1, beta=0.2):
super(ICM, self).__init__()
self.alpha = alpha
self.beta = beta
# hard coded for cartpole environment
self.inverse = nn.Linear(4*2, 256)
self.pi_logits = nn.Linear(256, n_actions)

self.dense1 = nn.Linear(4+1, 256)
self.new_state = nn.Linear(256, 4)

device = T.device('cpu')
self.to(device)

def forward(self, state, new_state, action):
inverse = F.elu(self.inverse(T.cat([state, new_state], dim=1)))
pi_logits = self.pi_logits(inverse)

# from [T] to [T,1]
action = action.reshape((action.size()[0], 1))
forward_input = T.cat([state, action], dim=1)
dense = F.elu(self.dense1(forward_input))
state_ = self.new_state(dense)

return pi_logits, state_

def calc_loss(self, state, new_state, action):
state = T.tensor(state, dtype=T.float)
action = T.tensor(action, dtype=T.float)
new_state = T.tensor(new_state, dtype=T.float)

pi_logits, state_ = self.forward(state, new_state, action)

inverse_loss = nn.CrossEntropyLoss()
L_I = (1-self.beta)*inverse_loss(pi_logits, action.to(T.long))

forward_loss = nn.MSELoss()
L_F = self.beta*forward_loss(state_, new_state)

intrinsic_reward = self.alpha*((state_ - new_state).pow(2)).mean(dim=1)
return intrinsic_reward, L_I, L_F
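
A minimal usage sketch (not part of this commit, and assuming the file layout above) that exercises the module on a dummy batch of CartPole transitions, passing lists of observations and actions the same way the worker does; shapes follow the hard-coded 4-dimensional state and 2 discrete actions:

import numpy as np
from icm import ICM

icm = ICM(input_dims=[4], n_actions=2)
states = [np.random.randn(4).astype(np.float32) for _ in range(8)]
new_states = [np.random.randn(4).astype(np.float32) for _ in range(8)]
actions = [float(np.random.randint(2)) for _ in range(8)]

# one curiosity bonus per transition, plus the inverse- and forward-model
# losses that train the ICM itself
intrinsic_reward, L_I, L_F = icm.calc_loss(states, new_states, actions)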
15 changes: 15 additions & 0 deletions ReinforcementLearning/ICM/main.py
@@ -0,0 +1,15 @@
import os
import torch.multiprocessing as mp
from parallel_env import ParallelEnv

os.environ['OMP_NUM_THREADS'] = '1'


if __name__ == '__main__':
mp.set_start_method('spawn')
env_id = 'CartPole-v0'
n_threads = 12
n_actions = 2
input_shape = [4]
env = ParallelEnv(env_id=env_id, n_threads=n_threads,
n_actions=n_actions, input_shape=input_shape, icm=True)
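
Running this entry point spawns the 12 worker processes and trains A3C with the intrinsic curiosity module; because the workers zero out the extrinsic CartPole reward, icm=True learns from the curiosity bonus alone, while icm=False runs plain A3C under the same zeroed rewards as a baseline (hence the _CartPole_no_rewards.png plot name in worker.py).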
28 changes: 28 additions & 0 deletions ReinforcementLearning/ICM/memory.py
@@ -0,0 +1,28 @@
class Memory:
def __init__(self):
self.states = []
self.actions = []
self.rewards = []
self.new_states = []
self.values = []
self.log_probs = []

def remember(self, state, action, reward, new_state, value, log_p):
self.actions.append(action)
self.rewards.append(reward)
self.states.append(state)
self.new_states.append(new_state)
self.log_probs.append(log_p)
self.values.append(value)

def clear_memory(self):
self.states = []
self.actions = []
self.rewards = []
self.new_states = []
self.values = []
self.log_probs = []

def sample_memory(self):
return self.states, self.actions, self.rewards, self.new_states,\
self.values, self.log_probs
32 changes: 32 additions & 0 deletions ReinforcementLearning/ICM/parallel_env.py
@@ -0,0 +1,32 @@
import torch.multiprocessing as mp
from actor_critic import ActorCritic
from icm import ICM
from shared_adam import SharedAdam
from worker import worker


class ParallelEnv:
def __init__(self, env_id, input_shape, n_actions, icm, n_threads=8):
names = [str(i) for i in range(1, n_threads+1)]

global_actor_critic = ActorCritic(input_shape, n_actions)
global_actor_critic.share_memory()
global_optim = SharedAdam(global_actor_critic.parameters())

if not icm:
global_icm = None
global_icm_optim = None
else:
global_icm = ICM(input_shape, n_actions)
global_icm.share_memory()
global_icm_optim = SharedAdam(global_icm.parameters())

self.ps = [mp.Process(target=worker,
args=(name, input_shape, n_actions,
global_actor_critic, global_icm,
global_optim, global_icm_optim, env_id,
n_threads, icm))
for name in names]

[p.start() for p in self.ps]
[p.join() for p in self.ps]
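
Note that the constructor starts the worker processes and immediately joins them, so constructing ParallelEnv in main.py runs the entire training session to completion before returning.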
18 changes: 18 additions & 0 deletions ReinforcementLearning/ICM/shared_adam.py
@@ -0,0 +1,18 @@
import torch as T


class SharedAdam(T.optim.Adam):
def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), eps=1e-8,
weight_decay=0):
super(SharedAdam, self).__init__(params, lr=lr, betas=betas,
eps=eps, weight_decay=weight_decay)

for group in self.param_groups:
for p in group['params']:
state = self.state[p]
state['step'] = 0
state['exp_avg'] = T.zeros_like(p.data)
state['exp_avg_sq'] = T.zeros_like(p.data)

state['exp_avg'].share_memory_()
state['exp_avg_sq'].share_memory_()
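
Creating every per-parameter state entry up front and calling share_memory_() on the moment tensors lets all worker processes update the same Adam statistics for the shared global networks; the integer step counter, by contrast, is an ordinary Python int and ends up tracked per process rather than shared.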
11 changes: 11 additions & 0 deletions ReinforcementLearning/ICM/utils.py
@@ -0,0 +1,11 @@
import matplotlib.pyplot as plt
import numpy as np


def plot_learning_curve(x, scores, figure_file):
running_avg = np.zeros(len(scores))
for i in range(len(running_avg)):
running_avg[i] = np.mean(scores[max(0, i-100):(i+1)])
plt.plot(x, running_avg)
plt.title('Running average of previous 100 episodes')
plt.savefig(figure_file)
91 changes: 91 additions & 0 deletions ReinforcementLearning/ICM/worker.py
@@ -0,0 +1,91 @@
import gym
import numpy as np
import torch as T
from actor_critic import ActorCritic
from icm import ICM
from memory import Memory
from utils import plot_learning_curve


def worker(name, input_shape, n_actions, global_agent, global_icm,
optimizer, icm_optimizer, env_id, n_threads, icm=False):
    T_MAX = 20  # rollout length: update the networks every 20 steps or at episode end

local_agent = ActorCritic(input_shape, n_actions)

if icm:
local_icm = ICM(input_shape, n_actions)
algo = 'ICM'
else:
intrinsic_reward = T.zeros(1)
algo = 'A3C'

memory = Memory()

env = gym.make(env_id)

t_steps, max_eps, episode, scores, avg_score = 0, 1000, 0, [], 0

while episode < max_eps:
obs = env.reset()
hx = T.zeros(1, 256)
score, done, ep_steps = 0, False, 0
while not done:
state = T.tensor([obs], dtype=T.float)
action, value, log_prob, hx = local_agent(state, hx)
obs_, reward, done, info = env.step(action)
t_steps += 1
ep_steps += 1
score += reward
reward = 0 # turn off extrinsic rewards
memory.remember(obs, action, reward, obs_, value, log_prob)
obs = obs_
if ep_steps % T_MAX == 0 or done:
states, actions, rewards, new_states, values, log_probs = \
memory.sample_memory()
if icm:
intrinsic_reward, L_I, L_F = \
local_icm.calc_loss(states, new_states, actions)

loss = local_agent.calc_loss(obs, hx, done, rewards, values,
log_probs, intrinsic_reward)

optimizer.zero_grad()
hx = hx.detach_()
if icm:
icm_optimizer.zero_grad()
(L_I + L_F).backward()

loss.backward()
T.nn.utils.clip_grad_norm_(local_agent.parameters(), 40)

for local_param, global_param in zip(
local_agent.parameters(),
global_agent.parameters()):
global_param._grad = local_param.grad
optimizer.step()
local_agent.load_state_dict(global_agent.state_dict())

if icm:
for local_param, global_param in zip(
local_icm.parameters(),
global_icm.parameters()):
global_param._grad = local_param.grad
icm_optimizer.step()
local_icm.load_state_dict(global_icm.state_dict())
memory.clear_memory()

if name == '1':
scores.append(score)
avg_score = np.mean(scores[-100:])
print('{} episode {} thread {} of {} steps {:.2f}M score {:.2f} '
'intrinsic_reward {:.2f} avg score (100) {:.1f}'.format(
algo, episode, name, n_threads,
t_steps/1e6, score,
T.sum(intrinsic_reward),
avg_score))
episode += 1
if name == '1':
x = [z for z in range(episode)]
fname = algo + '_CartPole_no_rewards.png'
plot_learning_curve(x, scores, fname)
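
At each update a worker backpropagates through its local networks, copies the resulting gradients onto the shared global parameters (global_param._grad = local_param.grad), steps the shared optimizer, and then reloads the updated global weights into its local copies; the same push-gradients, pull-weights pattern is applied to the ICM module when curiosity is enabled.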
