Commit
added SAC implementation
Andrewzh112 committed Mar 7, 2021
1 parent 4483219 commit 68809d8
Showing 3 changed files with 261 additions and 14 deletions.
125 changes: 120 additions & 5 deletions policy/agent.py
@@ -3,9 +3,9 @@
from torch.nn import functional as F
from copy import deepcopy

from policy.networks import ActorCritic, Actor, Critic
from policy.networks import ActorCritic, Actor, Critic, SACActor, SACCritic, SACValue
from policy.utils import ReplayBuffer, OUActionNoise, clip_action, GaussianActionNoise

torch.autograd.set_detect_anomaly(True)

class BlackJackAgent:
def __init__(self, method, env, function='V', gamma=0.99, epsilon=0.1):
@@ -332,7 +332,7 @@ def choose_action(self, observation, test):
with torch.no_grad():
action = self.actor(observation)
if not test:
action = action + self.noise(action.size())
action = action + self.noise(action.size()).to(self.device)
self.actor.train()
action = action.cpu().detach().numpy()
# clip noised action to ensure not out of bounds
@@ -357,12 +357,14 @@ def update(self):

# calculate targets & only update online critic network
self.critic.optimizer.zero_grad()
self.critic2.optimizer.zero_grad()
with torch.no_grad():
# y <- r + gamma * min_(i=1,2) Q_(theta'_i)(s', a_telda)
target_actions = self.target_actor(next_states)
target_actions += self.noise(
target_actions.size(), clip=self.action_clip, sigma=self.action_sigma)
target_actions = clip_action(target_actions, self.max_action)
target_actions.size(), clip=self.action_clip, sigma=self.action_sigma).to(self.device)
target_actions = clip_action(target_actions.cpu().numpy(), self.max_action)
target_actions = torch.from_numpy(target_actions).to(self.device)
q_primes1 = self.target_critic(next_states, target_actions).squeeze()
q_primes2 = self.target_critic2(next_states, target_actions).squeeze()
q_primes = torch.min(q_primes1, q_primes2)
@@ -375,4 +377,117 @@ def update(self):
critic_loss = critic_loss1 + critic_loss2
critic_loss.backward()
self.critic.optimizer.step()
self.critic2.optimizer.step()
return self.actor_loss, critic_loss.item()


class SACAgent:
def __init__(self, state_dim, action_dim, hidden_dims, max_action, gamma,
tau, reward_scale, lr, batch_size, maxsize, checkpoint):
self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
self.gamma = gamma
self.tau = tau
self.reward_scale = reward_scale
self.batch_size = batch_size

self.memory = ReplayBuffer(state_dim, action_dim, maxsize)
self.critic1 = SACCritic(*state_dim, *action_dim, hidden_dims, lr,
checkpoint, 'Critic')
self.critic2 = SACCritic(*state_dim, *action_dim, hidden_dims,
lr, checkpoint, 'Critic2')
self.actor = SACActor(*state_dim, *action_dim, hidden_dims, lr,
max_action, checkpoint, 'Actor')
self.value = SACValue(*state_dim, hidden_dims,
lr, checkpoint, 'Valuator')
self.target_value = self.get_target_network(self.value)
self.target_value.name = 'Target_Valuator'

def get_target_network(self, online_network, freeze_weights=True):
target_network = deepcopy(online_network)
if freeze_weights:
for param in target_network.parameters():
param.requires_grad = False
return target_network

def choose_action(self, observation, test):
self.actor.eval()
observation = torch.from_numpy(observation).to(self.device)
with torch.no_grad():
action, _ = self.actor(observation)
self.actor.train()
action = action.cpu().detach().numpy()
return action

def update(self):
experiences = self.memory.sample_transition(self.batch_size)
states, actions, rewards, next_states, dones = [data.to(self.device) for data in experiences]

###### UPDATE VALUATOR ######
self.value.optimizer.zero_grad()
with torch.no_grad():
policy_actions, log_probs = self.actor(states, reparameterize=False)
action_values1 = self.critic1(states, policy_actions).squeeze()
action_values2 = self.critic2(states, policy_actions).squeeze()
action_values = torch.min(action_values1, action_values2)
target = action_values - log_probs.squeeze()
values = self.value(states).squeeze()
value_loss = 0.5 * F.mse_loss(target, values)
value_loss.backward()
self.value.optimizer.step()

###### UPDATE CRITIC ######
self.critic1.optimizer.zero_grad()
self.critic2.optimizer.zero_grad()
with torch.no_grad():
v_hat = self.target_value(next_states).squeeze() * (~dones)
targets = rewards * self.reward_scale + self.gamma * v_hat
qs1 = self.critic1(states, actions).squeeze()
qs2 = self.critic2(states, actions).squeeze()
critic_loss1 = 0.5 * F.mse_loss(targets, qs1)
critic_loss2 = 0.5 * F.mse_loss(targets, qs2)
critic_loss = critic_loss1 + critic_loss2
critic_loss.backward()
self.critic1.optimizer.step()
self.critic2.optimizer.step()

###### UPDATE ACTOR ######
self.actor.optimizer.zero_grad()
actions, log_probs = self.actor(states)
action_values1 = self.critic1(states, actions).squeeze()
action_values2 = self.critic2(states, actions).squeeze()
action_values = torch.min(action_values1, action_values2)
actor_loss = torch.mean(log_probs.squeeze() - action_values)
actor_loss.backward()
self.actor.optimizer.step()

###### UPDATE TARGET VALUE ######
self.update_target_network(self.value, self.target_value)

return value_loss.item(), critic_loss.item(), actor_loss.item()

def update_target_network(self, src, tgt):
for src_weight, tgt_weight in zip(src.parameters(), tgt.parameters()):
tgt_weight.data = tgt_weight.data * self.tau + src_weight.data * (1. - self.tau)

def store_transition(self, state, action, reward, next_state, done):
state = torch.tensor(state)
action = torch.tensor(action)
reward = torch.tensor(reward)
next_state = torch.tensor(next_state)
done = torch.tensor(done, dtype=torch.bool)
self.memory.store_transition(state, action, reward, next_state, done)

def save_models(self):
self.critic1.save_checkpoint()
self.critic2.save_checkpoint()
self.actor.save_checkpoint()
self.value.save_checkpoint()
self.target_value.save_checkpoint()

def load_models(self):
self.critic1.load_checkpoint()
self.critic2.load_checkpoint()
self.actor.load_checkpoint()
self.value.load_checkpoint()
self.target_value.load_checkpoint()
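
For quick reference, a minimal usage sketch of the SACAgent class above. The environment name, hidden_dims, batch_size, reward_scale and critic learning rate come from this commit's main.py; tau, maxsize and the checkpoint directory are illustrative assumptions (their actual defaults sit in argparse flags not shown here), and the pre-0.26 gym reset/step API is assumed, as in main.py.

import gym
from policy.agent import SACAgent

env = gym.make('LunarLanderContinuous-v2')
agent = SACAgent(state_dim=env.observation_space.shape,
                 action_dim=env.action_space.shape,
                 hidden_dims=[400, 300],
                 max_action=float(env.action_space.high[0]),
                 gamma=0.99,
                 tau=0.005,                  # assumed value; not shown in this commit
                 reward_scale=2.0,
                 lr=1e-3,
                 batch_size=256,
                 maxsize=100_000,            # assumed replay capacity
                 checkpoint='model_weights') # assumed checkpoint directory

observation, done = env.reset(), False
while not done:
    action = agent.choose_action(observation, test=False)
    next_observation, reward, done, _ = env.step(action)
    agent.store_transition(observation, action, reward, next_observation, done)
    # update() returns (value_loss, critic_loss, actor_loss) once a full batch is buffered
    if agent.memory.idx >= agent.batch_size:
        value_loss, critic_loss, actor_loss = agent.update()
    observation = next_observation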

48 changes: 39 additions & 9 deletions policy/continuous/main.py
@@ -12,10 +12,10 @@

parser = argparse.ArgumentParser(description='Continuous Environment Agents')
# training hyperparams
parser.add_argument('--agent', type=str, default='TD3', help='Agent Algorithm')
parser.add_argument('--agent', type=str, default='SAC', help='Agent Algorithm')
parser.add_argument('--environment', type=str, default='LunarLanderContinuous-v2', help='Environment name')
parser.add_argument('--n_episodes', type=int, default=3000, help='Number of episodes you wish to run for')
parser.add_argument('--batch_size', type=int, default=100, help='Minibatch size')
parser.add_argument('--batch_size', type=int, default=256, help='Minibatch size')
parser.add_argument('--hidden_dim', type=int, default=2048, help='Hidden dimension of FC layers')
parser.add_argument('--hidden_dims', type=list, default=[400, 300], help='Hidden dimensions of FC layers')
parser.add_argument('--critic_lr', type=float, default=1e-3, help='Learning rate for Critic')
@@ -33,6 +33,7 @@
parser.add_argument('--actor_update_iter', type=int, default=2, help='Update actor and target network every')
parser.add_argument('--action_sigma', type=float, default=0.2, help='Std of noise for actions')
parser.add_argument('--action_clip', type=float, default=0.5, help='Max action bound')
parser.add_argument('--reward_scale', type=float, default=2., help='Reward scale for Soft Actor-Critic')

# eval params
parser.add_argument('--render', action="store_true", default=False, help='Render environment while training')
@@ -95,6 +96,20 @@ def main():
action_sigma=args.action_sigma,
action_clip=args.action_clip
)
elif args.agent == 'SAC':
max_action = float(env.action_space.high[0])
agent = agent_(state_dim=env.observation_space.shape,
action_dim=env.action_space.shape,
hidden_dims=args.hidden_dims,
max_action=max_action,
gamma=args.gamma,
tau=args.tau,
reward_scale=args.reward_scale,
lr=args.critic_lr,
batch_size=args.batch_size,
maxsize=int(args.maxsize),
checkpoint=args.checkpoint,
)
else:
agent = agent_(state_dim=env.observation_space.shape,
action_dim=env.action_space.n,
@@ -116,13 +131,15 @@ def main():
done, score, observation = False, 0, env.reset()

# reset DDPG UO Noise and also keep track of actor/critic losses
if args.agent in ['DDPG', 'TD3']:
if args.agent in ['DDPG', 'TD3', 'SAC']:
if args.agent == 'DDPG':
agent.noise.reset()
actor_losses, critic_losses = [], []
if args.agent == 'SAC':
value_losses = []
while not done:
if args.render:
env.render()
env.render(mode='human')

action = agent.choose_action(observation, args.test)
next_observation, reward, done, _ = env.step(action)
@@ -133,15 +150,23 @@ def main():
continue
elif args.agent == 'Actor Critic':
agent.update(reward, next_observation, done)
elif args.agent in ['DDPG', 'TD3']:
elif args.agent in ['DDPG', 'TD3', 'SAC']:
agent.store_transition(observation, action, reward, next_observation, done)
# if we have memory smaller than batch size, do not update
if agent.memory.idx < args.batch_size or (args.agent == 'TD3' and agent.ctr < args.warmup_steps):
continue
actor_loss, critic_loss = agent.update()
if args.agent == 'SAC':
value_loss, critic_loss, actor_loss = agent.update()
value_losses.append(value_loss)
else:
actor_loss, critic_loss = agent.update()
actor_losses.append(actor_loss)
critic_losses.append(critic_loss)
pbar.set_postfix({'Reward': reward, 'Actor Loss': actor_loss, 'Critic Loss': critic_loss})
if args.agent == 'SAC':
pbar.set_postfix({'Reward': reward, 'Actor Loss': actor_loss,
'Critic Loss': critic_loss, 'Value Loss': value_loss})
else:
pbar.set_postfix({'Reward': reward, 'Actor Loss': actor_loss, 'Critic Loss': critic_loss})
else:
agent.store_reward(reward)
observation = next_observation
@@ -155,15 +180,20 @@ def main():
agent.update()

# logging & saving
elif args.agent in ['DDPG', 'TD3']:
elif args.agent in ['DDPG', 'TD3', 'SAC']:
writer.add_scalars(
'Scores',
{'Episodic': score, 'Windowed Average': np.mean(score_history)},
global_step=e)

if actor_losses:
loss_dict = {'Actor': np.mean(actor_losses), 'Critic': np.mean(critic_losses)}
if args.agent == 'SAC':
loss_dict['Value'] = np.mean(value_losses)
value_losses = []
writer.add_scalars(
'Losses',
{'Actor': np.mean(actor_losses), 'Critic': np.mean(critic_losses)},
loss_dict,
global_step=e)
actor_losses, critic_losses = [], []

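The per-episode bookkeeping above feeds TensorBoard via SummaryWriter.add_scalars, with the SAC value loss folded into the same loss dictionary. A self-contained sketch of that logging pattern; the numbers and the 'runs/sac_demo' log directory are made up for illustration.

import numpy as np
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter('runs/sac_demo')
score_history = [120.0, 95.5, 140.2]                    # recent episodic scores
actor_losses, critic_losses, value_losses = [0.30, 0.25], [1.10, 0.90], [0.60, 0.55]
e = len(score_history) - 1                              # current episode index

writer.add_scalars('Scores',
                   {'Episodic': score_history[-1], 'Windowed Average': np.mean(score_history)},
                   global_step=e)

loss_dict = {'Actor': np.mean(actor_losses), 'Critic': np.mean(critic_losses)}
loss_dict['Value'] = np.mean(value_losses)              # extra entry only for SAC
writer.add_scalars('Losses', loss_dict, global_step=e)
writer.close()
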
102 changes: 102 additions & 0 deletions policy/networks.py
@@ -1,6 +1,7 @@
import math
import torch
from torch import nn
from torch.distributions.multivariate_normal import MultivariateNormal


class ActorCritic(nn.Module):
@@ -112,3 +113,104 @@ def save_checkpoint(self):

def load_checkpoint(self):
self.load_state_dict(torch.load(self.checkpoint_path + '/' + self.name + '.pth'))


class SACCritic(nn.Module):
def __init__(self, state_dim, action_dim, hidden_dims, lr,
checkpoint_path, name):
super().__init__()
self.checkpoint_path = checkpoint_path
self.name = name
encoder = []
prev_dim = state_dim + action_dim
for i, dim in enumerate(hidden_dims):
encoder.extend([
nn.Linear(prev_dim, dim),
nn.LayerNorm(dim)
])
if i < len(hidden_dims) - 1:
encoder.append(nn.ReLU(True))
prev_dim = dim
self.encoder = nn.Sequential(*encoder)
self.value = nn.Linear(prev_dim, 1)
self.optimizer = torch.optim.Adam(self.parameters(), lr=lr)
self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
self.to(self.device)

def forward(self, states, actions):
scores = self.encoder(torch.cat([states, actions], dim=1))
return self.value(scores)

def save_checkpoint(self):
torch.save(self.state_dict(), self.checkpoint_path + '/' + self.name + '.pth')

def load_checkpoint(self):
self.load_state_dict(torch.load(self.checkpoint_path + '/' + self.name + '.pth'))


class SACValue(SACCritic):
def __init__(self, state_dim, hidden_dims, lr,
checkpoint_path, name):
super().__init__(state_dim, 0, hidden_dims, lr,
checkpoint_path, name)

def forward(self, states):
scores = self.encoder(states)
return self.value(scores)


class SACActor(nn.Module):
def __init__(self, state_dim, action_dim,
hidden_dims, lr, max_action,
checkpoint_path, name):
super().__init__()
self.log_std_min = -20
self.log_std_max = 2
self.epsilon = 1e-6
self.checkpoint_path = checkpoint_path
self.name = name
self.max_action = max_action
encoder = []
prev_dim = state_dim
for i, dim in enumerate(hidden_dims):
encoder.extend([
nn.Linear(prev_dim, dim),
nn.LayerNorm(dim)
])
if i < len(hidden_dims) - 1:
encoder.append(nn.ReLU(True))
prev_dim = dim
self.encoder = nn.Sequential(*encoder)

# mu & logvar for action
self.actor = nn.Linear(prev_dim, action_dim * 2)
self.optimizer = torch.optim.Adam(self.parameters(), lr=lr)
self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
self.to(self.device)

def sample(self, mu, log_std, reparameterize=True):
if mu.dim() == 1:
mu = mu.unsqueeze(0)
distribution = MultivariateNormal(mu, scale_tril=torch.diag_embed(log_std.exp()))
if reparameterize:
actions = distribution.rsample()
else:
actions = distribution.sample()
log_probs = distribution.log_prob(actions)
squashed_actions = torch.tanh(actions)
bounded_actions = squashed_actions * self.max_action
# correction uses the unscaled tanh so the change-of-variables term stays valid when max_action != 1
bounded_log_probs = log_probs - torch.log(
(1 - squashed_actions.pow(2)).clamp(0, 1) + self.epsilon).sum(dim=1)
return bounded_actions.squeeze(), bounded_log_probs

def forward(self, states, reparameterize=True):
scores = self.encoder(states)
mu, log_std = self.actor(scores).chunk(2, dim=-1)  # first half is mu, second half is log_std
log_std = log_std.clamp(self.log_std_min, self.log_std_max)
action, log_prob = self.sample(mu, log_std, reparameterize=reparameterize)
return action, log_prob

def save_checkpoint(self):
torch.save(self.state_dict(), self.checkpoint_path + '/' + self.name + '.pth')

def load_checkpoint(self):
self.load_state_dict(torch.load(self.checkpoint_path + '/' + self.name + '.pth'))
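
SACActor.sample above implements the usual squashed-Gaussian reparameterization: draw u ~ N(mu, diag(exp(log_std))^2) with rsample so gradients flow through the noise, squash with tanh, rescale by max_action, and correct the Gaussian log-density with the change-of-variables term sum_i log(1 - tanh(u_i)^2). A standalone sketch of just that computation, with made-up shapes and values:

import torch
from torch.distributions.multivariate_normal import MultivariateNormal

batch_size, action_dim, max_action, eps = 4, 2, 1.0, 1e-6
mu = torch.zeros(batch_size, action_dim)
log_std = torch.full((batch_size, action_dim), -1.0)

dist = MultivariateNormal(mu, scale_tril=torch.diag_embed(log_std.exp()))
u = dist.rsample()                         # reparameterized sample, keeps gradients
squashed = torch.tanh(u)                   # bounded to (-1, 1)
action = squashed * max_action             # rescaled to the environment's action bounds

# change of variables: log pi(a|s) = log N(u; mu, std) - sum_i log(1 - tanh(u_i)^2)
log_prob = dist.log_prob(u) - torch.log(1 - squashed.pow(2) + eps).sum(dim=1)
print(action.shape, log_prob.shape)        # torch.Size([4, 2]) torch.Size([4])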
