From c00e3bc910da70684e6952b656065eb4e88320cb Mon Sep 17 00:00:00 2001
From: Andrew Zhao
Date: Sat, 20 Feb 2021 17:53:46 +0800
Subject: [PATCH] added DDPG implementation

---
 policy/agent.py            |  92 +++++++++++++++++++++++++++++--
 policy/lunarlander/main.py | 109 +++++++++++++++++++++++++++++++++----
 policy/networks.py         |  95 ++++++++++++++++++++++++++++++++
 policy/utils.py            |  57 +++++++++++++++++++
 4 files changed, 339 insertions(+), 14 deletions(-)
 create mode 100644 policy/utils.py

diff --git a/policy/agent.py b/policy/agent.py
index 457a79b..e6e4190 100644
--- a/policy/agent.py
+++ b/policy/agent.py
@@ -1,6 +1,8 @@
 import numpy as np
 import torch
-from policy.networks import ActorCritic
+from copy import deepcopy
+from policy.networks import ActorCritic, Actor, Critic
+from policy.utils import ReplayBuffer, OUActionNoise
 
 
 class BlackJackAgent:
@@ -172,7 +174,6 @@ def __init__(self, input_dim, action_dim, hidden_dim, gamma, lr):
         self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
         self.log_proba, self.value = None, None
 
-
     def choose_action(self, state):
         state = torch.from_numpy(state).to(self.device)
         self.value, action_logits = self.actor_critic(state)
@@ -186,13 +187,96 @@ def update(self, reward, state_, done):
         # calculate TD loss
         state_ = torch.from_numpy(state_).unsqueeze(0).to(self.device)
         value_, _ = self.actor_critic(state_)
-        critic_loss = (reward + self.gamma * value_ * ~done - self.value).pow(2)
+        TD_error = reward + self.gamma * value_ * ~done - self.value
+        critic_loss = TD_error.pow(2)
 
         # actor loss
-        actor_loss = - self.value.detach() * self.log_proba
+        actor_loss = - TD_error.detach() * self.log_proba  # TD error as the advantage
 
         # sgd + reset history
         loss = critic_loss + actor_loss
         self.optimizer.zero_grad()
         loss.backward()
         self.optimizer.step()
+
+
+class DDPGAgent:
+    def __init__(self, state_dim, action_dim, hidden_dims, max_action, gamma,
+                 tau, critic_lr, critic_wd, actor_lr, actor_wd, batch_size,
+                 final_init, maxsize, sigma, theta, dt, checkpoint):
+        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
+        self.gamma = gamma
+        self.tau = tau
+        self.batch_size = batch_size
+        self.memory = ReplayBuffer(state_dim, action_dim, maxsize)
+        self.noise = OUActionNoise(torch.zeros(action_dim, device=self.device),
+                                   sigma=sigma,
+                                   theta=theta,
+                                   dt=dt)
+        self.critic = Critic(*state_dim, *action_dim, hidden_dims, critic_lr, critic_wd,
+                             final_init, checkpoint, 'Critic')
+        self.actor = Actor(*state_dim, *action_dim, hidden_dims, max_action,
+                           actor_lr, actor_wd, final_init, checkpoint, 'Actor')
+        self.target_critic = deepcopy(self.critic)
+        self.target_critic.name = 'Target_Critic'
+        self.target_actor = deepcopy(self.actor)
+        self.target_actor.name = 'Target_Actor'
+
+    def update(self):
+        experiences = self.memory.sample_transition(self.batch_size)
+        states, actions, rewards, next_states, dones = [data.to(self.device) for data in experiences]
+        # calculate targets & only update online critic network
+        with torch.no_grad():
+            next_actions = self.target_actor(next_states)
+            q_primes = self.target_critic(next_states, next_actions)
+            targets = rewards.unsqueeze(-1) + self.gamma * q_primes * (~dones).unsqueeze(-1)
+        qs = self.critic(states, actions)
+        td_error = targets - qs
+        critic_loss = td_error.pow(2).mean()
+        self.critic.optimizer.zero_grad()
+        critic_loss.backward()
+        self.critic.optimizer.step()
+
+        # actor loss: maximize Q under the current policy by minimizing -Q
+        qs = self.critic(states, self.actor(states))
+        actor_loss = - qs.mean()
+        self.actor.optimizer.zero_grad()
+        actor_loss.backward()
+        self.actor.optimizer.step()
+
+        self.update_target_network(self.critic, self.target_critic)
+        self.update_target_network(self.actor, self.target_actor)
+        return actor_loss.item(), critic_loss.item()
+
+    def update_target_network(self, src, tgt):  # soft update: tau * online + (1 - tau) * target
+        for src_weight, tgt_weight in zip(src.parameters(), tgt.parameters()):
+            tgt_weight.data = src_weight.data * self.tau + tgt_weight.data * (1. - self.tau)
+
+    def save_models(self):
+        self.critic.save_checkpoint()
+        self.actor.save_checkpoint()
+        self.target_critic.save_checkpoint()
+        self.target_actor.save_checkpoint()
+
+    def load_models(self):
+        self.critic.load_checkpoint()
+        self.actor.load_checkpoint()
+        self.target_critic.load_checkpoint()
+        self.target_actor.load_checkpoint()
+
+    def choose_action(self, observation):
+        self.actor.eval()
+        observation = torch.from_numpy(observation).to(self.device)
+        with torch.no_grad():
+            mu = self.actor(observation)
+        action = mu + self.noise()
+        self.actor.train()
+        return action.cpu().detach().numpy()
+
+    def store_transition(self, state, action, reward, next_state, done):
+        state = torch.tensor(state)
+        action = torch.tensor(action)
+        reward = torch.tensor(reward)
+        next_state = torch.tensor(next_state)
+        done = torch.tensor(done, dtype=torch.bool)
+        self.memory.store_transition(state, action, reward, next_state, done)
diff --git a/policy/lunarlander/main.py b/policy/lunarlander/main.py
index f11e4de..b36c6a6 100644
--- a/policy/lunarlander/main.py
+++ b/policy/lunarlander/main.py
@@ -4,47 +4,136 @@
 import numpy as np
 from tqdm import tqdm
 from collections import deque
+from pathlib import Path
+from torch.utils.tensorboard import SummaryWriter
+
+from policy.utils import clip_action
 from policy import agent as Agent
 
 parser = argparse.ArgumentParser(description='Lunar Lander Agents')
-parser.add_argument('--agent', type=str, default='Actor Critic', help='Agent style')
+# training hyperparams
+parser.add_argument('--agent', type=str, default='DDPG', help='Agent style')
 parser.add_argument('--n_episodes', type=int, default=3000, help='Number of episodes you wish to run for')
+parser.add_argument('--batch_size', type=int, default=64, help='Minibatch size')
 parser.add_argument('--hidden_dim', type=int, default=2048, help='Hidden dimension of FC layers')
-parser.add_argument('--lr', '--learning_rate', type=float, default=1e-4, help='Learning rate for Adam optimizer')
+parser.add_argument('--hidden_dims', type=int, nargs='+', default=[400, 300], help='Hidden dimensions of FC layers')
+parser.add_argument('--critic_lr', type=float, default=1e-3, help='Learning rate for Critic')
+parser.add_argument('--critic_wd', type=float, default=1e-2, help='Weight decay for Critic')
+parser.add_argument('--actor_lr', type=float, default=1e-4, help='Learning rate for Actor')
+parser.add_argument('--actor_wd', type=float, default=0., help='Weight decay for Actor')
 parser.add_argument('--gamma', type=float, default=0.99, help='Reward discount factor')
+parser.add_argument('--final_init', type=float, default=3e-3, help='The range for output layer initialization')
+parser.add_argument('--tau', type=float, default=0.001, help='Weight of target network update')
+parser.add_argument('--maxsize', type=int, default=1e6, help='Size of Replay Buffer')
+parser.add_argument('--sigma', type=float, default=0.2, help='Sigma for OU noise')
+parser.add_argument('--theta', type=float, default=0.15, help='Theta for OU noise')
+parser.add_argument('--dt', type=float, default=1e-2, help='dt for OU noise')
+# eval params
 parser.add_argument('--render', action="store_true",
                     default=False, help='Render environment while training')
 parser.add_argument('--window_legnth', type=int, default=100, help='Length of window to keep track scores')
+
+# checkpoint + logs
+parser.add_argument('--checkpoint', type=str, default='policy/lunarlander/checkpoint', help='Checkpoint for model weights')
+parser.add_argument('--logdir', type=str, default='policy/lunarlander/logs', help='Directory to save logs')
 args = parser.parse_args()
 
 
 def main():
-    env = gym.make('LunarLander-v2')
+    env_type = 'Continuous' if args.agent in ['DDPG'] else ''
+    env = gym.make(f'LunarLander{env_type}-v2')
     agent_ = getattr(Agent, args.agent.replace(' ', '') + 'Agent')
-    agent = agent_(input_dim=env.observation_space.shape,
-                   action_dim=env.action_space.n,
-                   hidden_dim=args.hidden_dim,
-                   gamma=args.gamma,
-                   lr=args.lr)
+    if args.agent in ['DDPG']:
+        max_action = float(env.action_space.high[0])
+        agent = agent_(state_dim=env.observation_space.shape,
+                       action_dim=env.action_space.shape,
+                       hidden_dims=args.hidden_dims,
+                       max_action=max_action,
+                       gamma=args.gamma,
+                       tau=args.tau,
+                       critic_lr=args.critic_lr,
+                       critic_wd=args.critic_wd,
+                       actor_lr=args.actor_lr,
+                       actor_wd=args.actor_wd,
+                       batch_size=args.batch_size,
+                       final_init=args.final_init,
+                       maxsize=int(args.maxsize),
+                       sigma=args.sigma,
+                       theta=args.theta,
+                       dt=args.dt,
+                       checkpoint=args.checkpoint)
+    else:
+        agent = agent_(input_dim=env.observation_space.shape,
+                       action_dim=env.action_space.n,
+                       hidden_dim=args.hidden_dim,
+                       gamma=args.gamma,
+                       lr=args.actor_lr)  # --lr was removed; reuse actor_lr here
+
+    Path(args.logdir).mkdir(parents=True, exist_ok=True)
+    Path(args.checkpoint).mkdir(parents=True, exist_ok=True)
+
+    writer = SummaryWriter(args.logdir)
+
     pbar = tqdm(range(args.n_episodes))
     score_history = deque(maxlen=args.window_legnth)
+    best_score = env.reward_range[0]
     for e in pbar:
         done, score, observation = False, 0, env.reset()
+
+        # reset DDPG OU noise and keep track of actor/critic losses
+        if args.agent in ['DDPG']:
+            agent.noise.reset()
+            actor_losses, critic_losses = [], []
         while not done:
            if args.render:
                env.render()
+
            action = agent.choose_action(observation)
+            # clip noised action to ensure not out of bounds
+            if args.agent in ['DDPG']:
+                action = clip_action(action, max_action)
            next_observation, reward, done, _ = env.step(action)
+            score += reward
+
+            # update for td methods, recording for mc methods
            if args.agent == 'Actor Critic':
                agent.update(reward, next_observation, done)
+            elif args.agent in ['DDPG']:
+                agent.store_transition(observation, action, reward, next_observation, done)
+                # only update once the buffer holds at least one batch of transitions
+                if agent.memory.idx >= args.batch_size:
+                    actor_loss, critic_loss = agent.update()
+                    actor_losses.append(actor_loss)
+                    critic_losses.append(critic_loss)
+                    pbar.set_postfix({'Reward': reward, 'Actor Loss': actor_loss,
+                                      'Critic Loss': critic_loss})
            else:
                agent.store_reward(reward)
            observation = next_observation
-            score += reward
+
+        score_history.append(score)
+
+        # update for mc methods w/ full trajectory
        if args.agent == 'Policy Gradient':
            agent.update()
-        score_history.append(score)
+
+        # logging & saving
+        elif args.agent in ['DDPG']:
+            writer.add_scalars(
+                'Scores',
+                {'Episodic': score, 'Windowed Average': np.mean(score_history)},
+                global_step=e)
+            if actor_losses:
+                writer.add_scalars(
+                    'Losses',
+                    {'Actor': np.mean(actor_losses), 'Critic': np.mean(critic_losses)},
+                    global_step=e)
+                actor_losses, critic_losses = [], []
+
+            if score > best_score:
+                best_score = score
+                agent.save_models()
        tqdm.write(
            f'Episode: {e + 1}/{args.n_episodes}, Score: {score}, Average Score: {np.mean(score_history)}')
diff --git a/policy/networks.py b/policy/networks.py
index 43d1ae8..d9e709e 100644
--- a/policy/networks.py
+++ b/policy/networks.py
@@ -1,3 +1,4 @@
+import math
 import torch
 from torch import nn
 
@@ -17,3 +18,97 @@ def forward(self, state):
         features = self.encoder(state)
         return self.v(features), self.pi(features)
+
+
+class Critic(nn.Module):
+    def __init__(self, input_dim, action_dim, hidden_dims, lr, weight_decay,
+                 final_init, checkpoint_path, name):
+        super().__init__()
+        self.checkpoint_path = checkpoint_path
+        self.name = name
+        encoder = []
+        prev_dim = input_dim
+        for i, dim in enumerate(hidden_dims):
+            encoder.extend([
+                nn.Linear(prev_dim, dim),
+                nn.LayerNorm(dim)
+            ])
+            if i < len(hidden_dims) - 1:
+                encoder.append(nn.ReLU(True))
+            prev_dim = dim
+        self.state_encoder = nn.Sequential(*encoder)
+        self.action_encoder = nn.Sequential(nn.Linear(action_dim, prev_dim),
+                                            nn.LayerNorm(prev_dim))
+        self.q = nn.Linear(prev_dim, 1)
+        self.optimizer = torch.optim.Adam(self.parameters(), lr=lr, weight_decay=weight_decay)
+        self._init_weights(self.q, final_init)
+        self._init_weights(self.action_encoder, 1 / math.sqrt(action_dim))
+        self._init_weights(self.state_encoder, 1 / math.sqrt(hidden_dims[-2]))
+        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
+        self.to(self.device)
+
+    def _init_weights(self, layers, b):
+        for m in layers.modules():
+            if isinstance(m, (nn.Linear, nn.LayerNorm)):
+                nn.init.uniform_(
+                    m.weight,
+                    a=-b,
+                    b=b
+                )
+
+    def forward(self, states, actions):
+        state_values = self.state_encoder(states)
+        action_values = self.action_encoder(actions)
+        state_action_values = nn.functional.relu(torch.add(state_values, action_values))
+        return self.q(state_action_values)
+
+    def save_checkpoint(self):
+        torch.save(self.state_dict(), self.checkpoint_path + '/' + self.name + '.pth')
+
+    def load_checkpoint(self):
+        self.load_state_dict(torch.load(self.checkpoint_path + '/' + self.name + '.pth'))
+
+
+class Actor(nn.Module):
+    def __init__(self, input_dim, action_dim, hidden_dims,
+                 max_action, lr, weight_decay,
+                 final_init, checkpoint_path, name):
+        super().__init__()
+        self.max_action = max_action
+        self.name = name
+        self.checkpoint_path = checkpoint_path
+        encoder = []
+        prev_dim = input_dim
+        for dim in hidden_dims:
+            encoder.extend([
+                nn.Linear(prev_dim, dim),
+                nn.LayerNorm(dim),
+                nn.ReLU(True)])
+            prev_dim = dim
+        self.state_encoder = nn.Sequential(*encoder)
+        self.mu = nn.Linear(prev_dim, action_dim)
+        self.optimizer = torch.optim.Adam(self.parameters(), lr=lr, weight_decay=weight_decay)
+        self._init_weights(self.mu, final_init)
+        self._init_weights(self.state_encoder, 1 / math.sqrt(hidden_dims[-2]))
+        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
+        self.to(self.device)
+
+    def _init_weights(self, layers, b):
+        for m in layers.modules():
+            if isinstance(m, (nn.Linear, nn.LayerNorm)):
+                nn.init.uniform_(
+                    m.weight,
+                    a=-b,
+                    b=b
+                )
+
+    def forward(self, states):
+        state_features = self.state_encoder(states)
+        # bound the output action to [-max_action, max_action]
+        return torch.tanh(self.mu(state_features)) * self.max_action
+
+    def save_checkpoint(self):
+        torch.save(self.state_dict(), self.checkpoint_path + '/' + self.name + '.pth')
+
+    def load_checkpoint(self):
+        self.load_state_dict(torch.load(self.checkpoint_path + '/' + self.name + '.pth'))
diff --git a/policy/utils.py b/policy/utils.py
new file mode 100644
index 0000000..5cc9dff
--- /dev/null
+++ b/policy/utils.py
@@ -0,0 +1,57 @@
+import math
+import numpy as np
+import torch
+
+
+class OUActionNoise:
+    def __init__(self, mu, sigma=0.2, theta=0.15, dt=1e-2, x0=None):
+        self.mu = mu
+        self.sigma = sigma
+        self.theta = theta
+        self.dt = dt
+        self.x0 = x0
+        self.device = self.mu.device
+        self.reset()
+
+    def __call__(self):
+        x = self.x_prev + self.theta * (self.mu - self.x_prev) * \
+            self.dt + self.sigma * math.sqrt(self.dt) * torch.randn(*self.mu.shape, device=self.device)
+        self.x_prev = x
+        return x
+
+    def reset(self):
+        self.x_prev = self.x0 if self.x0 is not None else torch.zeros_like(self.mu)
+
+
+class ReplayBuffer:
+    def __init__(self, state_dim, action_dim, maxsize):
+        self.states = torch.empty(maxsize, *state_dim)
+        self.actions = torch.empty(maxsize, *action_dim)
+        self.next_states = torch.empty(maxsize, *state_dim)
+        self.rewards = torch.empty(maxsize)
+        self.dones = torch.zeros(maxsize, dtype=torch.bool)
+        self.maxsize = maxsize
+        self.idx = 0
+
+    def store_transition(self, state, action, reward, next_state, done):
+        idx = self.idx % self.maxsize
+        self.states[idx] = state
+        self.actions[idx] = action
+        self.rewards[idx] = reward
+        self.next_states[idx] = next_state
+        self.dones[idx] = done
+        self.idx += 1
+
+    def sample_transition(self, batch_size):
+        curr_size = min(self.maxsize, self.idx)
+        idx = torch.multinomial(torch.ones(curr_size), batch_size)
+        states = self.states[idx]
+        actions = self.actions[idx]
+        rewards = self.rewards[idx]
+        next_states = self.next_states[idx]
+        dones = self.dones[idx]
+        return states, actions, rewards, next_states, dones
+
+
+def clip_action(action, max_action):
+    return np.clip(action, -max_action, max_action)
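
Not part of the patch: a minimal smoke test of the new policy/utils.py helpers, sketched here as a quick sanity check. It assumes the repository root is on PYTHONPATH so that policy.utils imports as laid out above, and uses the LunarLanderContinuous-v2 shapes (state (8,), action (2,)); the test itself and its values are illustrative, not code from the repository.

    import numpy as np
    import torch

    from policy.utils import OUActionNoise, ReplayBuffer, clip_action

    state_dim, action_dim = (8,), (2,)  # LunarLanderContinuous-v2 shapes (assumed)

    # OU noise: temporally correlated samples with the action's shape
    noise = OUActionNoise(torch.zeros(*action_dim), sigma=0.2, theta=0.15, dt=1e-2)
    assert noise().shape == (2,)

    # replay buffer: store a few fake transitions, then sample a batch
    buffer = ReplayBuffer(state_dim, action_dim, maxsize=100)
    for _ in range(10):
        action = clip_action(np.random.randn(*action_dim), max_action=1.0)
        buffer.store_transition(state=torch.randn(*state_dim),
                                action=torch.from_numpy(action).float(),
                                reward=torch.tensor(1.0),
                                next_state=torch.randn(*state_dim),
                                done=torch.tensor(False))
    states, actions, rewards, next_states, dones = buffer.sample_transition(batch_size=4)
    assert states.shape == (4, 8) and actions.shape == (4, 2) and rewards.shape == (4,)

Note that sample_transition returns rewards and dones as 1-D tensors of shape (batch,), which is why DDPGAgent.update broadcasts them against the (batch, 1) critic outputs with unsqueeze.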