forked from philtabor/Youtube-Code-Repository
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
10 changed files
with
322 additions
and
0 deletions.
There are no files selected for viewing
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
import numpy as np | ||
import torch as T | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
from torch.distributions import Categorical | ||
|
||
|
||
class ActorCritic(nn.Module): | ||
def __init__(self, input_dims, n_actions, gamma=0.99, tau=0.98): | ||
super(ActorCritic, self).__init__() | ||
self.gamma = gamma | ||
self.tau = tau | ||
|
||
self.input = nn.Linear(*input_dims, 256) | ||
self.dense = nn.Linear(256, 256) | ||
|
||
self.gru = nn.GRUCell(256, 256) | ||
self.pi = nn.Linear(256, n_actions) | ||
self.v = nn.Linear(256, 1) | ||
|
||
def forward(self, state, hx): | ||
x = F.relu(self.input(state)) | ||
x = F.relu(self.dense(x)) | ||
hx = self.gru(x, (hx)) | ||
|
||
pi = self.pi(hx) | ||
v = self.v(hx) | ||
|
||
probs = T.softmax(pi, dim=1) | ||
dist = Categorical(probs) | ||
action = dist.sample() | ||
log_prob = dist.log_prob(action) | ||
|
||
return action.numpy()[0], v, log_prob, hx | ||
|
||
def calc_R(self, done, rewards, values): | ||
values = T.cat(values).squeeze() | ||
if len(values.size()) == 1: # batch of states | ||
R = values[-1] * (1-int(done)) | ||
elif len(values.size()) == 0: # single state | ||
R = values*(1-int(done)) | ||
|
||
batch_return = [] | ||
for reward in rewards[::-1]: | ||
R = reward + self.gamma * R | ||
batch_return.append(R) | ||
batch_return.reverse() | ||
batch_return = T.tensor(batch_return, | ||
dtype=T.float).reshape(values.size()) | ||
return batch_return | ||
|
||
def calc_loss(self, new_states, hx, done, | ||
rewards, values, log_probs, r_i_t=None): | ||
if r_i_t is not None: | ||
rewards += r_i_t.detach().numpy() | ||
returns = self.calc_R(done, rewards, values) | ||
next_v = T.zeros(1, 1) if done else self.forward(T.tensor([new_states], | ||
dtype=T.float), hx)[1] | ||
|
||
values.append(next_v.detach()) | ||
values = T.cat(values).squeeze() | ||
log_probs = T.cat(log_probs) | ||
rewards = T.tensor(rewards) | ||
|
||
delta_t = rewards + self.gamma*values[1:] - values[:-1] | ||
n_steps = len(delta_t) | ||
gae = np.zeros(n_steps) | ||
for t in range(n_steps): | ||
for k in range(0, n_steps-t): | ||
temp = (self.gamma*self.tau)**k*delta_t[t+k] | ||
gae[t] += temp | ||
gae = T.tensor(gae, dtype=T.float) | ||
|
||
actor_loss = -(log_probs*gae).sum() | ||
entropy_loss = (-log_probs*T.exp(log_probs)).sum() | ||
# [T] vs () | ||
critic_loss = F.mse_loss(values[:-1].squeeze(), returns) | ||
|
||
total_loss = actor_loss + critic_loss - 0.01*entropy_loss | ||
return total_loss |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
import torch as T | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
|
||
|
||
class ICM(nn.Module): | ||
def __init__(self, input_dims, n_actions=2, alpha=1, beta=0.2): | ||
super(ICM, self).__init__() | ||
self.alpha = alpha | ||
self.beta = beta | ||
# hard coded for cartpole environment | ||
self.inverse = nn.Linear(4*2, 256) | ||
self.pi_logits = nn.Linear(256, n_actions) | ||
|
||
self.dense1 = nn.Linear(4+1, 256) | ||
self.new_state = nn.Linear(256, 4) | ||
|
||
device = T.device('cpu') | ||
self.to(device) | ||
|
||
def forward(self, state, new_state, action): | ||
inverse = F.elu(self.inverse(T.cat([state, new_state], dim=1))) | ||
pi_logits = self.pi_logits(inverse) | ||
|
||
# from [T] to [T,1] | ||
action = action.reshape((action.size()[0], 1)) | ||
forward_input = T.cat([state, action], dim=1) | ||
dense = F.elu(self.dense1(forward_input)) | ||
state_ = self.new_state(dense) | ||
|
||
return pi_logits, state_ | ||
|
||
def calc_loss(self, state, new_state, action): | ||
state = T.tensor(state, dtype=T.float) | ||
action = T.tensor(action, dtype=T.float) | ||
new_state = T.tensor(new_state, dtype=T.float) | ||
|
||
pi_logits, state_ = self.forward(state, new_state, action) | ||
|
||
inverse_loss = nn.CrossEntropyLoss() | ||
L_I = (1-self.beta)*inverse_loss(pi_logits, action.to(T.long)) | ||
|
||
forward_loss = nn.MSELoss() | ||
L_F = self.beta*forward_loss(state_, new_state) | ||
|
||
intrinsic_reward = self.alpha*((state_ - new_state).pow(2)).mean(dim=1) | ||
return intrinsic_reward, L_I, L_F |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
import os | ||
import torch.multiprocessing as mp | ||
from parallel_env import ParallelEnv | ||
|
||
os.environ['OMP_NUM_THREADS'] = '1' | ||
|
||
|
||
if __name__ == '__main__': | ||
mp.set_start_method('spawn') | ||
env_id = 'CartPole-v0' | ||
n_threads = 12 | ||
n_actions = 2 | ||
input_shape = [4] | ||
env = ParallelEnv(env_id=env_id, n_threads=n_threads, | ||
n_actions=n_actions, input_shape=input_shape, icm=True) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
class Memory: | ||
def __init__(self): | ||
self.states = [] | ||
self.actions = [] | ||
self.rewards = [] | ||
self.new_states = [] | ||
self.values = [] | ||
self.log_probs = [] | ||
|
||
def remember(self, state, action, reward, new_state, value, log_p): | ||
self.actions.append(action) | ||
self.rewards.append(reward) | ||
self.states.append(state) | ||
self.new_states.append(new_state) | ||
self.log_probs.append(log_p) | ||
self.values.append(value) | ||
|
||
def clear_memory(self): | ||
self.states = [] | ||
self.actions = [] | ||
self.rewards = [] | ||
self.new_states = [] | ||
self.values = [] | ||
self.log_probs = [] | ||
|
||
def sample_memory(self): | ||
return self.states, self.actions, self.rewards, self.new_states,\ | ||
self.values, self.log_probs |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
import torch.multiprocessing as mp | ||
from actor_critic import ActorCritic | ||
from icm import ICM | ||
from shared_adam import SharedAdam | ||
from worker import worker | ||
|
||
|
||
class ParallelEnv: | ||
def __init__(self, env_id, input_shape, n_actions, icm, n_threads=8): | ||
names = [str(i) for i in range(1, n_threads+1)] | ||
|
||
global_actor_critic = ActorCritic(input_shape, n_actions) | ||
global_actor_critic.share_memory() | ||
global_optim = SharedAdam(global_actor_critic.parameters()) | ||
|
||
if not icm: | ||
global_icm = None | ||
global_icm_optim = None | ||
else: | ||
global_icm = ICM(input_shape, n_actions) | ||
global_icm.share_memory() | ||
global_icm_optim = SharedAdam(global_icm.parameters()) | ||
|
||
self.ps = [mp.Process(target=worker, | ||
args=(name, input_shape, n_actions, | ||
global_actor_critic, global_icm, | ||
global_optim, global_icm_optim, env_id, | ||
n_threads, icm)) | ||
for name in names] | ||
|
||
[p.start() for p in self.ps] | ||
[p.join() for p in self.ps] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
import torch as T | ||
|
||
|
||
class SharedAdam(T.optim.Adam): | ||
def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), eps=1e-8, | ||
weight_decay=0): | ||
super(SharedAdam, self).__init__(params, lr=lr, betas=betas, | ||
eps=eps, weight_decay=weight_decay) | ||
|
||
for group in self.param_groups: | ||
for p in group['params']: | ||
state = self.state[p] | ||
state['step'] = 0 | ||
state['exp_avg'] = T.zeros_like(p.data) | ||
state['exp_avg_sq'] = T.zeros_like(p.data) | ||
|
||
state['exp_avg'].share_memory_() | ||
state['exp_avg_sq'].share_memory_() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
import matplotlib.pyplot as plt | ||
import numpy as np | ||
|
||
|
||
def plot_learning_curve(x, scores, figure_file): | ||
running_avg = np.zeros(len(scores)) | ||
for i in range(len(running_avg)): | ||
running_avg[i] = np.mean(scores[max(0, i-100):(i+1)]) | ||
plt.plot(x, running_avg) | ||
plt.title('Running average of previous 100 episodes') | ||
plt.savefig(figure_file) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
import gym | ||
import numpy as np | ||
import torch as T | ||
from actor_critic import ActorCritic | ||
from icm import ICM | ||
from memory import Memory | ||
from utils import plot_learning_curve | ||
|
||
|
||
def worker(name, input_shape, n_actions, global_agent, global_icm, | ||
optimizer, icm_optimizer, env_id, n_threads, icm=False): | ||
T_MAX = 20 | ||
|
||
local_agent = ActorCritic(input_shape, n_actions) | ||
|
||
if icm: | ||
local_icm = ICM(input_shape, n_actions) | ||
algo = 'ICM' | ||
else: | ||
intrinsic_reward = T.zeros(1) | ||
algo = 'A3C' | ||
|
||
memory = Memory() | ||
|
||
env = gym.make(env_id) | ||
|
||
t_steps, max_eps, episode, scores, avg_score = 0, 1000, 0, [], 0 | ||
|
||
while episode < max_eps: | ||
obs = env.reset() | ||
hx = T.zeros(1, 256) | ||
score, done, ep_steps = 0, False, 0 | ||
while not done: | ||
state = T.tensor([obs], dtype=T.float) | ||
action, value, log_prob, hx = local_agent(state, hx) | ||
obs_, reward, done, info = env.step(action) | ||
t_steps += 1 | ||
ep_steps += 1 | ||
score += reward | ||
reward = 0 # turn off extrinsic rewards | ||
memory.remember(obs, action, reward, obs_, value, log_prob) | ||
obs = obs_ | ||
if ep_steps % T_MAX == 0 or done: | ||
states, actions, rewards, new_states, values, log_probs = \ | ||
memory.sample_memory() | ||
if icm: | ||
intrinsic_reward, L_I, L_F = \ | ||
local_icm.calc_loss(states, new_states, actions) | ||
|
||
loss = local_agent.calc_loss(obs, hx, done, rewards, values, | ||
log_probs, intrinsic_reward) | ||
|
||
optimizer.zero_grad() | ||
hx = hx.detach_() | ||
if icm: | ||
icm_optimizer.zero_grad() | ||
(L_I + L_F).backward() | ||
|
||
loss.backward() | ||
T.nn.utils.clip_grad_norm_(local_agent.parameters(), 40) | ||
|
||
for local_param, global_param in zip( | ||
local_agent.parameters(), | ||
global_agent.parameters()): | ||
global_param._grad = local_param.grad | ||
optimizer.step() | ||
local_agent.load_state_dict(global_agent.state_dict()) | ||
|
||
if icm: | ||
for local_param, global_param in zip( | ||
local_icm.parameters(), | ||
global_icm.parameters()): | ||
global_param._grad = local_param.grad | ||
icm_optimizer.step() | ||
local_icm.load_state_dict(global_icm.state_dict()) | ||
memory.clear_memory() | ||
|
||
if name == '1': | ||
scores.append(score) | ||
avg_score = np.mean(scores[-100:]) | ||
print('{} episode {} thread {} of {} steps {:.2f}M score {:.2f} ' | ||
'intrinsic_reward {:.2f} avg score (100) {:.1f}'.format( | ||
algo, episode, name, n_threads, | ||
t_steps/1e6, score, | ||
T.sum(intrinsic_reward), | ||
avg_score)) | ||
episode += 1 | ||
if name == '1': | ||
x = [z for z in range(episode)] | ||
fname = algo + '_CartPole_no_rewards.png' | ||
plot_learning_curve(x, scores, fname) |