added DDPG implementation
Andrewzh112 committed Feb 20, 2021
1 parent b627774 commit c00e3bc
Showing 4 changed files with 339 additions and 14 deletions.
92 changes: 88 additions & 4 deletions policy/agent.py
@@ -1,6 +1,8 @@
import numpy as np
import torch
from policy.networks import ActorCritic
from copy import deepcopy
from policy.networks import ActorCritic, Actor, Critic
from policy.utils import ReplayBuffer, OUActionNoise


class BlackJackAgent:
@@ -172,7 +174,6 @@ def __init__(self, input_dim, action_dim, hidden_dim, gamma, lr):
self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
self.log_proba, self.value = None, None


def choose_action(self, state):
state = torch.from_numpy(state).to(self.device)
self.value, action_logits = self.actor_critic(state)
@@ -186,13 +187,96 @@ def update(self, reward, state_, done):
# calculate TD loss
state_ = torch.from_numpy(state_).unsqueeze(0).to(self.device)
value_, _ = self.actor_critic(state_)
critic_loss = (reward + self.gamma * value_ * ~done - self.value).pow(2)
TD_error = reward + self.gamma * value_ * ~done - self.value
critic_loss = TD_error.pow(2)

# actor loss
actor_loss = - self.value.detach() * self.log_proba
actor_loss = - self.value * self.log_proba

# sgd + reset history
loss = critic_loss + actor_loss
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()


class DDPGAgent:
def __init__(self, state_dim, action_dim, hidden_dims, max_action, gamma,
tau, critic_lr, critic_wd, actor_lr, actor_wd, batch_size,
final_init, maxsize, sigma, theta, dt, checkpoint):
self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
self.gamma = gamma
self.tau = tau
self.batch_size = batch_size
self.memory = ReplayBuffer(state_dim, action_dim, maxsize)
self.noise = OUActionNoise(torch.zeros(action_dim, device=self.device),
sigma=sigma,
theta=theta,
dt=dt)
self.critic = Critic(*state_dim, *action_dim, hidden_dims, critic_lr, critic_wd,
final_init, checkpoint, 'Critic')
self.actor = Actor(*state_dim, *action_dim, hidden_dims, max_action,
actor_lr, actor_wd, final_init, checkpoint, 'Actor')
self.target_critic = deepcopy(self.critic)
self.target_critic.name = 'Target_Critic'
self.target_actor = deepcopy(self.actor)
self.target_actor.name = 'Target_Actor'

def update(self):
experiences = self.memory.sample_transition(self.batch_size)
states, actions, rewards, next_states, dones = [data.to(self.device) for data in experiences]
# compute TD targets; only the online critic network is updated here
with torch.no_grad():
next_actions = self.target_actor(next_states)
q_primes = self.target_critic(next_states, next_actions)
targets = rewards + self.gamma * q_primes * (~dones)
qs = self.critic(states, actions)
td_error = targets - qs
critic_loss = td_error.pow(2).mean()
self.critic.optimizer.zero_grad()
critic_loss.backward()
self.critic.optimizer.step()

# actor loss: maximize Q values by minimizing their negative mean
qs = self.critic(states, self.actor(states))
actor_loss = - qs.mean()
self.actor.optimizer.zero_grad()
actor_loss.backward()
self.actor.optimizer.step()

self.update_target_network(self.critic, self.target_critic)
self.update_target_network(self.actor, self.target_actor)
return actor_loss.item(), critic_loss.item()

def update_target_network(self, src, tgt):
# Polyak averaging: target <- tau * online + (1 - tau) * target (DDPG soft-update convention)
for src_weight, tgt_weight in zip(src.parameters(), tgt.parameters()):
tgt_weight.data = src_weight.data * self.tau + tgt_weight.data * (1. - self.tau)

def save_models(self):
self.critic.save_checkpoint()
self.actor.save_checkpoint()
self.target_critic.save_checkpoint()
self.target_actor.save_checkpoint()

def load_models(self):
self.critic.load_checkpoint()
self.actor.load_checkpoint()
self.target_critic.load_checkpoint()
self.target_actor.load_checkpoint()

def choose_action(self, observation):
self.actor.eval()
observation = torch.from_numpy(observation).to(self.device)
with torch.no_grad():
mu = self.actor(observation)
action = mu + self.noise()
self.actor.train()
return action.cpu().detach().numpy()

def store_transition(self, state, action, reward, next_state, done):
state = torch.tensor(state)
action = torch.tensor(action)
reward = torch.tensor(reward)
next_state = torch.tensor(next_state)
done = torch.tensor(done, dtype=torch.bool)
self.memory.store_transition(state, action, reward, next_state, done)
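For reference, here is a minimal sketch of the policy/utils.py helpers the agent relies on (ReplayBuffer and OUActionNoise). That module is presumably the fourth changed file in this commit, but its diff is not shown here, so the interfaces below are inferred only from the calls made in DDPGAgent (store_transition, sample_transition, idx, noise(), noise.reset()) and may differ from the actual implementation.

import torch


class ReplayBuffer:
    """Fixed-size buffer of (state, action, reward, next_state, done) tensors."""
    def __init__(self, state_dim, action_dim, maxsize):
        self.maxsize, self.idx = int(maxsize), 0
        self.states = torch.zeros(self.maxsize, *state_dim)
        self.actions = torch.zeros(self.maxsize, *action_dim)
        self.rewards = torch.zeros(self.maxsize)
        self.next_states = torch.zeros(self.maxsize, *state_dim)
        self.dones = torch.zeros(self.maxsize, dtype=torch.bool)

    def store_transition(self, state, action, reward, next_state, done):
        i = self.idx % self.maxsize  # overwrite the oldest transition once full
        self.states[i], self.actions[i], self.rewards[i] = state, action, reward
        self.next_states[i], self.dones[i] = next_state, done
        self.idx += 1

    def sample_transition(self, batch_size):
        high = min(self.idx, self.maxsize)
        idxs = torch.randint(0, high, (batch_size,))
        # rewards/dones get a trailing dim so they broadcast against (batch, 1) Q values
        return (self.states[idxs], self.actions[idxs], self.rewards[idxs].unsqueeze(1),
                self.next_states[idxs], self.dones[idxs].unsqueeze(1))


class OUActionNoise:
    """Ornstein-Uhlenbeck process: temporally correlated exploration noise."""
    def __init__(self, mu, sigma=0.2, theta=0.15, dt=1e-2):
        self.mu, self.sigma, self.theta, self.dt = mu, sigma, theta, dt
        self.reset()

    def reset(self):
        self.x_prev = torch.zeros_like(self.mu)

    def __call__(self):
        x = (self.x_prev
             + self.theta * (self.mu - self.x_prev) * self.dt
             + self.sigma * (self.dt ** 0.5) * torch.randn_like(self.mu))
        self.x_prev = x
        return x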
109 changes: 99 additions & 10 deletions policy/lunarlander/main.py
@@ -4,47 +4,136 @@
import numpy as np
from tqdm import tqdm
from collections import deque
from pathlib import Path
from torch.utils.tensorboard import SummaryWriter

from policy.utils import clip_action
from policy import agent as Agent


parser = argparse.ArgumentParser(description='Lunar Lander Agents')
parser.add_argument('--agent', type=str, default='Actor Critic', help='Agent style')
# training hyperparams
parser.add_argument('--agent', type=str, default='DDPG', help='Agent style')
parser.add_argument('--n_episodes', type=int, default=3000, help='Number of episodes you wish to run for')
parser.add_argument('--batch_size', type=int, default=64, help='Minibatch size')
parser.add_argument('--hidden_dim', type=int, default=2048, help='Hidden dimension of FC layers')
parser.add_argument('--lr', '--learning_rate', type=float, default=1e-4, help='Learning rate for Adam optimizer')
parser.add_argument('--hidden_dims', type=list, default=[400, 300], help='Hidden dimensions of FC layers')
parser.add_argument('--critic_lr', type=float, default=1e-3, help='Learning rate for Critic')
parser.add_argument('--critic_wd', type=float, default=1e-2, help='Weight decay for Critic')
parser.add_argument('--actor_lr', type=float, default=1e-4, help='Learning rate for Actor')
parser.add_argument('--actor_wd', type=float, default=0., help='Weight decay for Actor')
parser.add_argument('--gamma', type=float, default=0.99, help='Reward discount factor')
parser.add_argument('--final_init', type=float, default=3e-3, help='The range for output layer initialization')
parser.add_argument('--tau', type=float, default=0.001, help='Weight of target network update')
parser.add_argument('--maxsize', type=int, default=1e6, help='Size of Replay Buffer')
parser.add_argument('--sigma', type=float, default=0.2, help='Sigma for OU noise')
parser.add_argument('--theta', type=float, default=0.15, help='Theta for OU noise')
parser.add_argument('--dt', type=float, default=1e-2, help='dt for OU noise')

# eval params
parser.add_argument('--render', action="store_true", default=False, help='Render environment while training')
parser.add_argument('--window_length', type=int, default=100, help='Length of the window used to track recent scores')

# checkpoint + logs
parser.add_argument('--checkpoint', type=str, default='policy/lunarlander/checkpoint', help='Checkpoint for model weights')
parser.add_argument('--logdir', type=str, default='policy/lunarlander/logs', help='Directory to save logs')
args = parser.parse_args()


def main():
env = gym.make('LunarLander-v2')
env_type = 'Continuous' if args.agent in ['DDPG'] else ''
env = gym.make(f'LunarLander{env_type}-v2')
agent_ = getattr(Agent, args.agent.replace(' ', '') + 'Agent')
agent = agent_(input_dim=env.observation_space.shape,
action_dim=env.action_space.n,
hidden_dim=args.hidden_dim,
gamma=args.gamma,
lr=args.lr)
if args.agent in ['DDPG']:
max_action = float(env.action_space.high[0])
agent = agent_(state_dim=env.observation_space.shape,
action_dim=env.action_space.shape,
hidden_dims=args.hidden_dims,
max_action=max_action,
gamma=args.gamma,
tau=args.tau,
critic_lr=args.critic_lr,
critic_wd=args.critic_wd,
actor_lr=args.actor_lr,
actor_wd=args.actor_wd,
batch_size=args.batch_size,
final_init=args.final_init,
maxsize=int(args.maxsize),
sigma=args.sigma,
theta=args.theta,
dt=args.dt,
checkpoint=args.checkpoint)
else:
agent = agent_(state_dim=env.observation_space.shape,
action_dim=env.action_space.n,
hidden_dims=args.hidden_dims,
gamma=args.gamma,
lr=args.lr)

Path(args.logdir).mkdir(parents=True, exist_ok=True)
Path(args.checkpoint).mkdir(parents=True, exist_ok=True)

writer = SummaryWriter(args.logdir)

pbar = tqdm(range(args.n_episodes))
score_history = deque(maxlen=args.window_length)
best_score = env.reward_range[0]
for e in pbar:
done, score, observation = False, 0, env.reset()

# reset the DDPG OU noise and keep track of actor/critic losses
if args.agent in ['DDPG']:
agent.noise.reset()
actor_losses, critic_losses = [], []
while not done:
if args.render:
env.render()

action = agent.choose_action(observation)
# clip the noise-perturbed action so it stays within the action bounds
if args.agent in ['DDPG']:
action = clip_action(action, max_action)
next_observation, reward, done, _ = env.step(action)
score += reward

# update TD methods every step; MC methods only record rewards here
if args.agent == 'Actor Critic':
agent.update(reward, next_observation, done)
elif args.agent in ['DDPG']:
agent.store_transition(observation, action, reward, next_observation, done)
# skip updates until the replay buffer holds at least one full batch
if agent.memory.idx < args.batch_size:
continue
actor_loss, critic_loss = agent.update()
actor_losses.append(actor_loss)
critic_losses.append(critic_loss)
pbar.set_postfix({'Reward': reward, 'Actor Loss': actor_loss, 'Critic Loss': critic_loss})
else:
agent.store_reward(reward)
observation = next_observation
score += reward

score_history.append(score)

# update MC methods with the full trajectory
if args.agent == 'Policy Gradient':
agent.update()
score_history.append(score)

# logging & saving
elif args.agent in ['DDPG']:
writer.add_scalars(
'Scores',
{'Episodic': score, 'Windowed Average': np.mean(score_history)},
global_step=e)
if actor_losses:
writer.add_scalars(
'Losses',
{'Actor': np.mean(actor_losses), 'Critic': np.mean(critic_losses)},
global_step=e)
actor_losses, critic_losses = [], []

if score > best_score:
best_score = score
agent.save_models()
tqdm.write(
f'Episode: {e + 1}/{args.n_episodes}, Score: {score}, Average Score: {np.mean(score_history)}')

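clip_action is also imported from policy.utils above but its implementation is not part of this diff. A minimal sketch consistent with how it is called here, clip_action(action, max_action), might look like the following; this is an assumption about the helper, not the committed code.

import numpy as np


def clip_action(action, max_action):
    # keep the noise-perturbed action inside the valid range [-max_action, max_action]
    return np.clip(action, -max_action, max_action)

With the repository root on the Python path, training with the new agent would then presumably be started with something like: python -m policy.lunarlander.main --agent DDPG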
95 changes: 95 additions & 0 deletions policy/networks.py
@@ -1,3 +1,4 @@
import math
import torch
from torch import nn

@@ -17,3 +18,97 @@ def __init__(self, input_dim, n_actions, hidden_dim):
def forward(self, state):
features = self.encoder(state)
return self.v(features), self.pi(features)


class Critic(nn.Module):
def __init__(self, input_dim, action_dim, hidden_dims, lr, weight_decay,
final_init, checkpoint_path, name):
super().__init__()
self.checkpoint_path = checkpoint_path
self.name = name
encoder = []
prev_dim = input_dim
for i, dim in enumerate(hidden_dims):
encoder.extend([
nn.Linear(prev_dim, dim),
nn.LayerNorm(dim)
])
if i < len(hidden_dims) - 1:
encoder.append(nn.ReLU(True))
prev_dim = dim
self.state_encoder = nn.Sequential(*encoder)
self.action_encoder = nn.Sequential(nn.Linear(action_dim, prev_dim),
nn.LayerNorm(prev_dim))
self.q = nn.Linear(prev_dim, 1)
self.optimizer = torch.optim.Adam(self.parameters(), lr=lr, weight_decay=weight_decay)
self._init_weights(self.q, final_init)
self._init_weights(self.action_encoder, 1 / math.sqrt(action_dim))
self._init_weights(self.state_encoder, 1 / math.sqrt(hidden_dims[-2]))
self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
self.to(self.device)

def _init_weights(self, layers, b):
for m in layers.modules():
if isinstance(m, (nn.Linear, nn.LayerNorm)):
nn.init.uniform_(
m.weight,
a=-b,
b=b
)

def forward(self, states, actions):
state_values = self.state_encoder(states)
action_values = self.action_encoder(actions)
state_action_values = nn.functional.relu(torch.add(state_values, action_values))
return self.q(state_action_values)

def save_checkpoint(self):
torch.save(self.state_dict(), self.checkpoint_path + '/' + self.name + '.pth')

def load_checkpoint(self):
self.load_state_dict(torch.load(self.checkpoint_path + '/' + self.name + '.pth'))


class Actor(nn.Module):
def __init__(self, input_dim, action_dim, hidden_dims,
max_action, lr, weight_decay,
final_init, checkpoint_path, name):
super().__init__()
self.max_action = max_action
self.name = name
self.checkpoint_path = checkpoint_path
encoder = []
prev_dim = input_dim
for dim in hidden_dims:
encoder.extend([
nn.Linear(prev_dim, dim),
nn.LayerNorm(dim),
nn.ReLU(True)])
prev_dim = dim
self.state_encoder = nn.Sequential(*encoder)
self.mu = nn.Linear(prev_dim, action_dim)
self.optimizer = torch.optim.Adam(self.parameters(), lr=lr, weight_decay=weight_decay)
self._init_weights(self.mu, final_init)
self._init_weights(self.state_encoder, 1 / math.sqrt(hidden_dims[-2]))
self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
self.to(self.device)

def _init_weights(self, layers, b):
for m in layers.modules():
if isinstance(m, (nn.Linear, nn.LayerNorm)):
nn.init.uniform_(
m.weight,
a=-b,
b=b
)

def forward(self, states):
state_features = self.state_encoder(states)
# bound the output action to [-max_action, max_action]
return torch.tanh(self.mu(state_features)) * self.max_action

def save_checkpoint(self):
torch.save(self.state_dict(), self.checkpoint_path + '/' + self.name + '.pth')

def load_checkpoint(self):
self.load_state_dict(torch.load(self.checkpoint_path + '/' + self.name + '.pth'))
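As a quick shape sanity check (not part of the commit), the new networks can be instantiated with the LunarLanderContinuous-v2 dimensions (8-dimensional observations, 2-dimensional actions bounded by 1.0) and the defaults from main.py. The snippet below is only an illustrative sketch of the expected tensor shapes; the checkpoint path is a placeholder.

import torch
from policy.networks import Actor, Critic

# dims follow LunarLanderContinuous-v2; hyperparameters are the main.py defaults
actor = Actor(8, 2, [400, 300], max_action=1.0, lr=1e-4, weight_decay=0.,
              final_init=3e-3, checkpoint_path='policy/lunarlander/checkpoint', name='Actor')
critic = Critic(8, 2, [400, 300], lr=1e-3, weight_decay=1e-2,
                final_init=3e-3, checkpoint_path='policy/lunarlander/checkpoint', name='Critic')

states = torch.randn(64, 8, device=actor.device)   # a batch of observations
actions = actor(states)                             # -> (64, 2), squashed by tanh * max_action
q_values = critic(states, actions)                  # -> (64, 1)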
