diff --git a/policy/agent.py b/policy/agent.py index 55f45ad..19e1393 100644 --- a/policy/agent.py +++ b/policy/agent.py @@ -3,9 +3,9 @@ from torch.nn import functional as F from copy import deepcopy -from policy.networks import ActorCritic, Actor, Critic +from policy.networks import ActorCritic, Actor, Critic, SACActor, SACCritic, SACValue from policy.utils import ReplayBuffer, OUActionNoise, clip_action, GaussianActionNoise - +torch.autograd.set_detect_anomaly(True) class BlackJackAgent: def __init__(self, method, env, function='V', gamma=0.99, epsilon=0.1): @@ -332,7 +332,7 @@ def choose_action(self, observation, test): with torch.no_grad(): action = self.actor(observation) if not test: - action = action + self.noise(action.size()) + action = action + self.noise(action.size()).to(self.device) self.actor.train() action = action.cpu().detach().numpy() # clip noised action to ensure not out of bounds @@ -357,12 +357,14 @@ def update(self): # calculate targets & only update online critic network self.critic.optimizer.zero_grad() + self.critic2.optimizer.zero_grad() with torch.no_grad(): # y <- r + gamma * min_(i=1,2) Q_(theta'_i)(s', a_telda) target_actions = self.target_actor(next_states) target_actions += self.noise( - target_actions.size(), clip=self.action_clip, sigma=self.action_sigma) - target_actions = clip_action(target_actions, self.max_action) + target_actions.size(), clip=self.action_clip, sigma=self.action_sigma).to(self.device) + target_actions = clip_action(target_actions.cpu().numpy(), self.max_action) + target_actions = torch.from_numpy(target_actions).to(self.device) q_primes1 = self.target_critic(next_states, target_actions).squeeze() q_primes2 = self.target_critic2(next_states, target_actions).squeeze() q_primes = torch.min(q_primes1, q_primes2) @@ -375,4 +377,117 @@ def update(self): critic_loss = critic_loss1 + critic_loss2 critic_loss.backward() self.critic.optimizer.step() + self.critic2.optimizer.step() return self.actor_loss, critic_loss.item() + + +class SACAgent: + def __init__(self, state_dim, action_dim, hidden_dims, max_action, gamma, + tau, reward_scale, lr, batch_size, maxsize, checkpoint): + self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') + self.gamma = gamma + self.tau = tau + self.reward_scale = reward_scale + self.batch_size = batch_size + + self.memory = ReplayBuffer(state_dim, action_dim, maxsize) + self.critic1 = SACCritic(*state_dim, *action_dim, hidden_dims, lr, + checkpoint, 'Critic') + self.critic2 = SACCritic(*state_dim, *action_dim, hidden_dims, + lr, checkpoint, 'Critic2') + self.actor = SACActor(*state_dim, *action_dim, hidden_dims, max_action, + lr, checkpoint, 'Actor') + self.value = SACValue(*state_dim, hidden_dims, + lr, checkpoint, 'Valuator') + self.target_value = self.get_target_network(self.value) + self.target_value.name = 'Target_Valuator' + + def get_target_network(self, online_network, freeze_weights=True): + target_network = deepcopy(online_network) + if freeze_weights: + for param in target_network.parameters(): + param.requires_grad = False + return target_network + + def choose_action(self, observation, test): + self.actor.eval() + observation = torch.from_numpy(observation).to(self.device) + with torch.no_grad(): + action, _ = self.actor(observation) + self.actor.train() + action = action.cpu().detach().numpy() + return action + + def update(self): + experiences = self.memory.sample_transition(self.batch_size) + states, actions, rewards, next_states, dones = [data.to(self.device) for data in 
experiences] + + ###### UPDATE VALUATOR ###### + self.value.optimizer.zero_grad() + with torch.no_grad(): + policy_actions, log_probs = self.actor(states, reparameterize=False) + action_values1 = self.critic1(states, policy_actions).squeeze() + action_values2 = self.critic2(states, policy_actions).squeeze() + action_values = torch.min(action_values1, action_values2) + target = action_values - log_probs.squeeze() + values = self.value(states).squeeze() + value_loss = 0.5 * F.mse_loss(target, values) + value_loss.backward() + self.value.optimizer.step() + + ###### UPDATE CRITIC ###### + self.critic1.optimizer.zero_grad() + self.critic2.optimizer.zero_grad() + with torch.no_grad(): + v_hat = self.target_value(next_states).squeeze() * (~dones) + targets = rewards * self.reward_scale + self.gamma * v_hat + qs1 = self.critic1(states, actions).squeeze() + qs2 = self.critic2(states, actions).squeeze() + critic_loss1 = 0.5 * F.mse_loss(targets, qs1) + critic_loss2 = 0.5 * F.mse_loss(targets, qs2) + critic_loss = critic_loss1 + critic_loss2 + critic_loss.backward() + self.critic1.optimizer.step() + self.critic2.optimizer.step() + + ###### UPDATE ACTOR ###### + self.actor.optimizer.zero_grad() + actions, log_probs = self.actor(states) + action_values1 = self.critic1(states, actions).squeeze() + action_values2 = self.critic2(states, actions).squeeze() + action_values = torch.min(action_values1, action_values2) + actor_loss = torch.mean(log_probs.squeeze() - action_values) + actor_loss.backward() + self.actor.optimizer.step() + + ###### UPDATE TARGET VALUE ###### + self.update_target_network(self.value, self.target_value) + + return value_loss.item(), critic_loss.item(), actor_loss.item() + + def update_target_network(self, src, tgt): + for src_weight, tgt_weight in zip(src.parameters(), tgt.parameters()): + tgt_weight.data = tgt_weight.data * self.tau + src_weight.data * (1. 
+                                                          - self.tau)
+
+    def store_transition(self, state, action, reward, next_state, done):
+        state = torch.tensor(state)
+        action = torch.tensor(action)
+        reward = torch.tensor(reward)
+        next_state = torch.tensor(next_state)
+        done = torch.tensor(done, dtype=torch.bool)
+        self.memory.store_transition(state, action, reward, next_state, done)
+
+    def save_models(self):
+        self.critic1.save_checkpoint()
+        self.critic2.save_checkpoint()
+        self.actor.save_checkpoint()
+        self.value.save_checkpoint()
+        self.target_value.save_checkpoint()
+
+    def load_models(self):
+        self.critic1.load_checkpoint()
+        self.critic2.load_checkpoint()
+        self.actor.load_checkpoint()
+        self.value.load_checkpoint()
+        self.target_value.load_checkpoint()
+
diff --git a/policy/continuous/main.py b/policy/continuous/main.py
index f3b7fb9..5ef8fcb 100644
--- a/policy/continuous/main.py
+++ b/policy/continuous/main.py
@@ -12,10 +12,10 @@
 parser = argparse.ArgumentParser(description='Continuous Environment Agents')
 
 # training hyperparams
-parser.add_argument('--agent', type=str, default='TD3', help='Agent Algorithm')
+parser.add_argument('--agent', type=str, default='SAC', help='Agent Algorithm')
 parser.add_argument('--environment', type=str, default='LunarLanderContinuous-v2', help='Agent Algorithm')
 parser.add_argument('--n_episodes', type=int, default=3000, help='Number of episodes you wish to run for')
-parser.add_argument('--batch_size', type=int, default=100, help='Minibatch size')
+parser.add_argument('--batch_size', type=int, default=256, help='Minibatch size')
 parser.add_argument('--hidden_dim', type=int, default=2048, help='Hidden dimension of FC layers')
 parser.add_argument('--hidden_dims', type=list, default=[400, 300], help='Hidden dimensions of FC layers')
 parser.add_argument('--critic_lr', type=float, default=1e-3, help='Learning rate for Critic')
@@ -33,6 +33,7 @@ parser.add_argument('--actor_update_iter', type=int, default=2,
                     help='Update actor and target network every')
 parser.add_argument('--action_sigma', type=float, default=0.2, help='Std of noise for actions')
 parser.add_argument('--action_clip', type=float, default=0.5, help='Max action bound')
+parser.add_argument('--reward_scale', type=float, default=2., help='Reward scale for Soft Actor-Critic')
 
 # eval params
 parser.add_argument('--render', action="store_true", default=False, help='Render environment while training')
@@ -95,6 +96,20 @@ def main():
                        action_sigma=args.action_sigma,
                        action_clip=args.action_clip
         )
+    elif args.agent == 'SAC':
+        max_action = float(env.action_space.high[0])
+        agent = agent_(state_dim=env.observation_space.shape,
+                       action_dim=env.action_space.shape,
+                       hidden_dims=args.hidden_dims,
+                       max_action=max_action,
+                       gamma=args.gamma,
+                       tau=args.tau,
+                       reward_scale=args.reward_scale,
+                       lr=args.critic_lr,
+                       batch_size=args.batch_size,
+                       maxsize=int(args.maxsize),
+                       checkpoint=args.checkpoint,
+                       )
     else:
         agent = agent_(state_dim=env.observation_space.shape,
                        action_dim=env.action_space.n,
@@ -116,13 +131,15 @@ def main():
         done, score, observation = False, 0, env.reset()
 
         # reset DDPG UO Noise and also keep track of actor/critic losses
-        if args.agent in ['DDPG', 'TD3']:
+        if args.agent in ['DDPG', 'TD3', 'SAC']:
             if args.agent == 'DDPG':
                 agent.noise.reset()
             actor_losses, critic_losses = [], []
+            if args.agent == 'SAC':
+                value_losses = []
         while not done:
             if args.render:
-                env.render()
+                env.render(mode='human')
 
             action = agent.choose_action(observation, args.test)
             next_observation, reward, done, _ = env.step(action)
@@ -133,15 +150,23 @@
             if args.test or args.agent in ['MonteCarlo', 'SARSA']:
                 continue
             elif args.agent == 'Actor Critic':
                 agent.update(reward, next_observation, done)
-            elif args.agent in ['DDPG', 'TD3']:
+            elif args.agent in ['DDPG', 'TD3', 'SAC']:
                 agent.store_transition(observation, action, reward, next_observation, done)
                 # if we have memory smaller than batch size, do not update
                 if agent.memory.idx < args.batch_size or (args.agent == 'TD3' and agent.ctr < args.warmup_steps):
                     continue
-                actor_loss, critic_loss = agent.update()
+                if args.agent == 'SAC':
+                    value_loss, critic_loss, actor_loss = agent.update()
+                    value_losses.append(value_loss)
+                else:
+                    actor_loss, critic_loss = agent.update()
                 actor_losses.append(actor_loss)
                 critic_losses.append(critic_loss)
-                pbar.set_postfix({'Reward': reward, 'Actor Loss': actor_loss, 'Critic Loss': critic_loss})
+                if args.agent == 'SAC':
+                    pbar.set_postfix({'Reward': reward, 'Actor Loss': actor_loss,
+                                      'Critic Loss': critic_loss, 'Value Loss': value_loss})
+                else:
+                    pbar.set_postfix({'Reward': reward, 'Actor Loss': actor_loss, 'Critic Loss': critic_loss})
             else:
                 agent.store_reward(reward)
             observation = next_observation
@@ -155,15 +180,20 @@
             agent.update()
 
         # logging & saving
-        elif args.agent in ['DDPG', 'TD3']:
+        elif args.agent in ['DDPG', 'TD3', 'SAC']:
             writer.add_scalars(
                 'Scores',
                 {'Episodic': score, 'Windowed Average': np.mean(score_history)},
                 global_step=e)
-            writer.add_scalars(
-                'Losses',
-                {'Actor': np.mean(actor_losses), 'Critic': np.mean(critic_losses)},
-                global_step=e)
-            actor_losses, critic_losses = [], []
+            if actor_losses:
+                loss_dict = {'Actor': np.mean(actor_losses), 'Critic': np.mean(critic_losses)}
+                if args.agent == 'SAC':
+                    loss_dict['Value'] = np.mean(value_losses)
+                    value_losses = []
+                writer.add_scalars(
+                    'Losses',
+                    loss_dict,
+                    global_step=e)
+                actor_losses, critic_losses = [], []
diff --git a/policy/networks.py b/policy/networks.py
index d9e709e..6454f16 100644
--- a/policy/networks.py
+++ b/policy/networks.py
@@ -1,6 +1,7 @@
 import math
 import torch
 from torch import nn
+from torch.distributions.multivariate_normal import MultivariateNormal
 
 
 class ActorCritic(nn.Module):
@@ -112,3 +113,104 @@ def save_checkpoint(self):
 
     def load_checkpoint(self):
         self.load_state_dict(torch.load(self.checkpoint_path + '/' + self.name + '.pth'))
+
+
+class SACCritic(nn.Module):
+    def __init__(self, state_dim, action_dim, hidden_dims, lr,
+                 checkpoint_path, name):
+        super().__init__()
+        self.checkpoint_path = checkpoint_path
+        self.name = name
+        encoder = []
+        prev_dim = state_dim + action_dim
+        for i, dim in enumerate(hidden_dims):
+            encoder.extend([
+                nn.Linear(prev_dim, dim),
+                nn.LayerNorm(dim)
+            ])
+            if i < len(hidden_dims) - 1:
+                encoder.append(nn.ReLU(True))
+            prev_dim = dim
+        self.encoder = nn.Sequential(*encoder)
+        self.value = nn.Linear(prev_dim, 1)
+        self.optimizer = torch.optim.Adam(self.parameters(), lr=lr)
+        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
+        self.to(self.device)
+
+    def forward(self, states, actions):
+        scores = self.encoder(torch.cat([states, actions], dim=1))
+        return self.value(scores)
+
+    def save_checkpoint(self):
+        torch.save(self.state_dict(), self.checkpoint_path + '/' + self.name + '.pth')
+
+    def load_checkpoint(self):
+        self.load_state_dict(torch.load(self.checkpoint_path + '/' + self.name + '.pth'))
+
+
+class SACValue(SACCritic):
+    def __init__(self, state_dim, hidden_dims, lr,
+                 checkpoint_path, name):
+        super().__init__(state_dim, 0, hidden_dims, lr,
+                         checkpoint_path, name)
+
+    def forward(self, states):
+        scores = self.encoder(states)
+        return self.value(scores)
+
+
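+# SACActor implements a squashed Gaussian policy: the head outputs the mean and
+# log-std of a Gaussian, actions are sampled with the reparameterization trick,
+# squashed through tanh and scaled by max_action, and the log-probability gets
+# the corresponding tanh change-of-variables correction.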
+class SACActor(nn.Module):
+    def __init__(self, state_dim, action_dim, hidden_dims,
+                 max_action, lr, checkpoint_path, name):
+        super().__init__()
+        self.log_std_min = -20
+        self.log_std_max = 2
+        self.epsilon = 1e-6
+        self.checkpoint_path = checkpoint_path
+        self.name = name
+        self.max_action = max_action
+        encoder = []
+        prev_dim = state_dim
+        for i, dim in enumerate(hidden_dims):
+            encoder.extend([
+                nn.Linear(prev_dim, dim),
+                nn.LayerNorm(dim)
+            ])
+            if i < len(hidden_dims) - 1:
+                encoder.append(nn.ReLU(True))
+            prev_dim = dim
+        self.encoder = nn.Sequential(*encoder)
+
+        # mu & log_std heads for the action distribution
+        self.actor = nn.Linear(prev_dim, action_dim * 2)
+        self.optimizer = torch.optim.Adam(self.parameters(), lr=lr)
+        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
+        self.to(self.device)
+
+    def sample(self, mu, log_std, reparameterize=True):
+        if mu.dim() == 1:
+            mu = mu.unsqueeze(0)
+        distribution = MultivariateNormal(mu, scale_tril=torch.diag_embed(log_std.exp()))
+        if reparameterize:
+            actions = distribution.rsample()
+        else:
+            actions = distribution.sample()
+        log_probs = distribution.log_prob(actions)
+        # squash the Gaussian sample with tanh, then scale to the action bounds;
+        # the log-prob correction uses the unscaled tanh output
+        squashed_actions = torch.tanh(actions)
+        bounded_actions = squashed_actions * self.max_action
+        bounded_log_probs = log_probs - torch.log(
+            (1 - squashed_actions.pow(2)).clamp(0, 1) + self.epsilon).sum(dim=1)
+        return bounded_actions.squeeze(), bounded_log_probs
+
+    def forward(self, states, reparameterize=True):
+        scores = self.encoder(states)
+        # split the 2 * action_dim outputs into mu and log_std halves
+        mu, log_std = self.actor(scores).chunk(2, dim=-1)
+        log_std = log_std.clamp(self.log_std_min, self.log_std_max)
+        action, log_prob = self.sample(mu, log_std, reparameterize=reparameterize)
+        return action, log_prob
+
+    def save_checkpoint(self):
+        torch.save(self.state_dict(), self.checkpoint_path + '/' + self.name + '.pth')
+
+    def load_checkpoint(self):
+        self.load_state_dict(torch.load(self.checkpoint_path + '/' + self.name + '.pth'))
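
Usage sketch (reviewer note, not part of the diff): the snippet below shows how the new
SACAgent pieces fit together in a training loop, assuming the classic gym reset()/step()
API already used in policy/continuous/main.py. The hyperparameter values and the
'checkpoints' path are illustrative placeholders, not repo defaults.

    import gym
    from policy.agent import SACAgent

    env = gym.make('LunarLanderContinuous-v2')
    agent = SACAgent(state_dim=env.observation_space.shape,
                     action_dim=env.action_space.shape,
                     hidden_dims=[400, 300],
                     max_action=float(env.action_space.high[0]),
                     gamma=0.99,
                     tau=0.995,  # fraction of the target weights kept per soft update (see update_target_network)
                     reward_scale=2.,
                     lr=3e-4,
                     batch_size=256,
                     maxsize=int(1e6),
                     checkpoint='checkpoints')

    observation, done = env.reset(), False
    while not done:
        action = agent.choose_action(observation, test=False)
        next_observation, reward, done, _ = env.step(action)
        agent.store_transition(observation, action, reward, next_observation, done)
        # update only once at least one full batch of transitions is stored
        if agent.memory.idx >= agent.batch_size:
            value_loss, critic_loss, actor_loss = agent.update()
        observation = next_observation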