
Commit

first commit td3 implementation
Andrewzh112 committed Dec 16, 2020
1 parent e9a9f7c commit d34a333
Showing 6 changed files with 347 additions and 3 deletions.
4 changes: 2 additions & 2 deletions networks/layers.py
@@ -46,7 +46,7 @@ def __init__(self, in_channels, out_channels, conv_type='basic', mode=None, acti
elif normalization == 'gn':
norm = nn.GroupNorm(groups, out_channels)
else:
- raise NotImplementedError('Please only choose normalization [bn, ln, in]')
+ raise NotImplementedError('Please only choose normalization [bn, ln, in, gn]')

# activations
if activation == 'relu':
@@ -78,7 +78,7 @@ def __init__(self, in_channels, activation, normalization, groups=1):
elif normalization == 'gn':
norm = nn.GroupNorm(groups, in_channels)
else:
- raise NotImplementedError('Please only choose normalization [bn, ln, in]')
+ raise NotImplementedError('Please only choose normalization [bn, ln, in, gn]')

# activations
if activation == 'relu':
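For context, the hunks above sit in a dispatch on the `normalization` string; a minimal sketch of that pattern, assuming 2D feature maps (the non-gn branches and the helper name are assumptions, since only the gn branch and the error message are visible in this diff):

import torch.nn as nn

def build_norm(normalization, out_channels, groups=1):
    # mirrors the branch structure the updated error message refers to
    if normalization == 'bn':
        norm = nn.BatchNorm2d(out_channels)
    elif normalization == 'ln':
        norm = nn.LayerNorm(out_channels)
    elif normalization == 'in':
        norm = nn.InstanceNorm2d(out_channels)
    elif normalization == 'gn':
        norm = nn.GroupNorm(groups, out_channels)
    else:
        raise NotImplementedError('Please only choose normalization [bn, ln, in, gn]')
    return norm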
2 changes: 1 addition & 1 deletion simsiam/main.py
@@ -81,7 +81,7 @@

optimizer = torch.optim.SGD(model.parameters(), lr=args.lr,
momentum=args.momentum, weight_decay=args.wd)
- scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, int(args.epochs * 0.1))
+ scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs)

start_epoch = 0
if args.continue_train:
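The only change here is the scheduler horizon: T_max moves from int(args.epochs * 0.1) to args.epochs, so the cosine decay now spans the whole run instead of bottoming out after 10% of the epochs and climbing back up (CosineAnnealingLR follows a full cosine with period 2 * T_max). A small standalone illustration with a toy model (the model, learning rate and epoch count are placeholders, not values from this repo):

import torch
import torch.nn as nn

model = nn.Linear(4, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=0.05)
epochs = 100

# With T_max=epochs the learning rate decays monotonically towards eta_min over
# the whole run; with T_max=int(epochs * 0.1) it would reach the minimum at
# epoch 10 and then rise again.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

for epoch in range(epochs):
    optimizer.step()
    scheduler.step()
    if (epoch + 1) % 25 == 0:
        print(epoch + 1, scheduler.get_last_lr()[0])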
178 changes: 178 additions & 0 deletions td3/agent.py
@@ -0,0 +1,178 @@
import os
from itertools import chain

import torch
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter

from td3.utils import ReplayBuffer
from td3.networks import Actor, Critic


class Agent:
def __init__(self, env, alpha, beta, hidden_dims, tau,
batch_size, gamma, d, warmup, max_size, c,
sigma, one_device, log_dir, checkpoint_dir):
state_space = env.observation_space.shape[0]
n_actions = env.action_space.shape[0]

# training params
self.gamma = gamma
self.tau = tau
self.max_action = env.action_space.high[0]
self.min_action = env.action_space.low[0]
self.buffer = ReplayBuffer(max_size, state_space, n_actions)
self.batch_size = batch_size
self.learn_step_counter = 0
self.time_step = 0
self.warmup = warmup
self.n_actions = n_actions
self.d = d
self.c = c
self.sigma = sigma

# training device
if one_device:
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# logging/checkpointing
self.writer = SummaryWriter(log_dir)
self.checkpoint_dir = checkpoint_dir

# networks & optimizers
self.actor = Actor(state_space, hidden_dims, n_actions, 'actor').to(self.device)
self.critic_1 = Critic(state_space, hidden_dims, n_actions, 'critic_1').to(self.device)
self.critic_2 = Critic(state_space, hidden_dims, n_actions, 'critic_2').to(self.device)

self.critic_optimizer = torch.optim.Adam(
chain(self.critic_1.parameters(), self.critic_2.parameters()), lr=beta)
self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=alpha)

self.target_actor = Actor(state_space, hidden_dims, n_actions, 'target_actor').to(self.device)
self.target_critic_1 = Critic(state_space, hidden_dims, n_actions, 'target_critic_1').to(self.device)
self.target_critic_2 = Critic(state_space, hidden_dims, n_actions, 'target_critic_2').to(self.device)

# copy weights
self.update_network_parameters(tau=1)

def _get_noise(self, clip=True):
noise = torch.randn(self.n_actions, dtype=torch.float, device=self.device) * self.sigma
if clip:
noise = noise.clamp(-self.c, self.c)
return noise

def _clamp_action_bound(self, action):
return action.clamp(self.min_action, self.max_action)

def choose_action(self, observation):
if self.time_step < self.warmup:
mu = self._get_noise(clip=False)
else:
state = torch.tensor(observation, dtype=torch.float).to(self.device)
mu = self.actor(state) + self._get_noise(clip=False)
self.time_step += 1
return self._clamp_action_bound(mu).cpu().detach().numpy()

def remember(self, state, action, reward, state_, done):
self.buffer.store_transition(state, action, reward, state_, done)

def critic_step(self, state, action, reward, state_, done):
# get target actions w/ noise
target_actions = self.target_actor(state_) + self._get_noise()
target_actions = self._clamp_action_bound(target_actions)

# target & online values
q1_ = self.target_critic_1(state_, target_actions)
q2_ = self.target_critic_2(state_, target_actions)

# done mask
q1_[done], q2_[done] = 0.0, 0.0

q1 = self.critic_1(state, action)
q2 = self.critic_2(state, action)

q1_ = q1_.view(-1)
q2_ = q2_.view(-1)

critic_value_ = torch.min(q1_, q2_)

target = reward + self.gamma * critic_value_
target = target.unsqueeze(1)

self.critic_optimizer.zero_grad()

q1_loss = F.mse_loss(target, q1)
q2_loss = F.mse_loss(target, q2)
critic_loss = q1_loss + q2_loss
critic_loss.backward()
self.critic_optimizer.step()

self.writer.add_scalar('Critic loss', critic_loss.item(), global_step=self.learn_step_counter)

def actor_step(self, state):
# calculate loss, update actor params
self.actor_optimizer.zero_grad()
actor_loss = -torch.mean(self.critic_1(state, self.actor(state)))
actor_loss.backward()
self.actor_optimizer.step()

# update & log
self.update_network_parameters()
self.writer.add_scalar('Actor loss', actor_loss.item(), global_step=self.learn_step_counter)

def learn(self):
self.learn_step_counter += 1

# if the buffer is not yet filled w/ enough samples
if self.buffer.counter < self.batch_size:
return

# transitions
state, action, reward, state_, done = self.buffer.sample_buffer(self.batch_size)
reward = torch.tensor(reward, dtype=torch.float).to(self.device)
done = torch.tensor(done).to(self.device)
state = torch.tensor(state, dtype=torch.float).to(self.device)
state_ = torch.tensor(state_, dtype=torch.float).to(self.device)
action = torch.tensor(action, dtype=torch.float).to(self.device)

self.critic_step(state, action, reward, state_, done)
if self.learn_step_counter % self.d == 0:
self.actor_step(state)

def momentum_update(self, online_network, target_network, tau):
for param_o, param_t in zip(online_network.parameters(), target_network.parameters()):
param_t.data = param_t.data * (1. - tau) + param_o.data * tau  # soft update: target <- tau * online + (1 - tau) * target; tau=1 copies the online weights

def update_network_parameters(self, tau=None):
if tau is None:
tau = self.tau
self.momentum_update(self.critic_1, self.target_critic_1, tau)
self.momentum_update(self.critic_2, self.target_critic_2, tau)
self.momentum_update(self.actor, self.target_actor, tau)

def add_scalar(self, tag, scalar_value, global_step=None):
self.writer.add_scalar(tag, scalar_value, global_step=global_step)

def save_networks(self):
torch.save({
'actor': self.actor.state_dict(),
'target_actor': self.target_actor.state_dict(),
'critic_1': self.critic_1.state_dict(),
'critic_2': self.critic_2.state_dict(),
'target_critic_1': self.target_critic_1.state_dict(),
'target_critic_2': self.target_critic_2.state_dict(),
'critic_optimizer': self.critic_optimizer.state_dict(),
'actor_optimizer': self.actor_optimizer.state_dict(),
}, self.checkpoint_dir)

def load_state_dicts(self):
state_dict = torch.load(self.checkpoint_dir)
self.actor.load_state_dict(state_dict['actor'])
self.target_actor.load_state_dict(state_dict['target_actor'])
self.critic_1.load_state_dict(state_dict['critic_1'])
self.critic_2.load_state_dict(state_dict['critic_2'])
self.target_critic_1.load_state_dict(state_dict['target_critic_1'])
self.target_critic_2.load_state_dict(state_dict['target_critic_2'])
self.critic_optimizer.load_state_dict(state_dict['critic_optimizer'])
self.actor_optimizer.load_state_dict(state_dict['actor_optimizer'])
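For reference, the updates implemented by critic_step, actor_step and momentum_update above are the standard TD3 rules (the notation here is mine, not from the repo):

\tilde a = \operatorname{clip}\big(\pi_{\phi'}(s') + \operatorname{clip}(\epsilon, -c, c),\ a_{\min},\ a_{\max}\big), \qquad \epsilon \sim \mathcal N(0, \sigma^2)
y = r + \gamma \,(1 - \mathbb{1}[\text{done}]) \,\min_{i=1,2} Q_{\theta_i'}(s', \tilde a)
L_{\text{critic}} = \big(Q_{\theta_1}(s, a) - y\big)^2 + \big(Q_{\theta_2}(s, a) - y\big)^2
L_{\text{actor}} = -\,Q_{\theta_1}\big(s, \pi_\phi(s)\big) \quad \text{(every } d \text{ learn steps)}, \qquad \theta' \leftarrow \tau\theta + (1 - \tau)\theta'

In the code the (1 - 1[done]) factor is applied by zeroing q1_ and q2_ at terminal transitions before the element-wise minimum, which is equivalent.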
90 changes: 90 additions & 0 deletions td3/main.py
@@ -0,0 +1,90 @@
"""
https://www.youtube.com/watch?v=ZhFO8EWADmY&ab_channel=MachineLearningwithPhil
"""

import argparse
import gym
from tqdm import tqdm
from pathlib import Path
from collections import deque

from td3.agent import Agent


parser = argparse.ArgumentParser()

# agent hyperparameters
parser.add_argument('--env_name', type=str, default='Pendulum-v0', help='Gym env name')
parser.add_argument('--hidden_dims', type=list, default=[400, 300], help='List of hidden dims for fc network')
parser.add_argument('--tau', type=float, default=0.005, help='Soft update param')
parser.add_argument('--gamma', type=float, default=0.99, help='Reward discount factor')
parser.add_argument('--sigma', type=float, default=0.2, help='Gaussian noise standard deviation')
parser.add_argument('--c', type=float, default=0.5, help='Noise clip')

# training hp params
parser.add_argument('--n_episodes', type=int, default=1000, help='Number of episodes')
parser.add_argument('--batch_size', type=int, default=100, help='Batch size')
parser.add_argument('--alpha', type=float, default=0.001, help='Learning rate actor')
parser.add_argument('--beta', type=float, default=0.001, help='Learning rate critic')
parser.add_argument('--warmup', type=int, default=1000, help='Number of warmup steps')
parser.add_argument('--d', type=int, default=2, help='Policy update delay (actor and target networks updated every d learn steps)')
parser.add_argument('--max_size', type=int, default=1000000, help='Replay buffer size')
parser.add_argument('--no_render', action="store_true", default=False, help='Whether to render')
parser.add_argument('--window_size', type=int, default=100, help='Score tracking moving average window size')

# misc
parser.add_argument('--one_device', action="store_false", default=True, help='Whether to only train on device 0')
parser.add_argument('--log_dir', type=str, default='td3/logs', help='Path to where log files will be saved')
parser.add_argument('--checkpoint_dir', type=str, default='td3/network_weights', help='Path to where model weights will be saved')
args = parser.parse_args()


if __name__ == '__main__':
# paths
Path(args.checkpoint_dir).mkdir(parents=True, exist_ok=True)
Path(args.log_dir).mkdir(parents=True, exist_ok=True)
args.checkpoint_dir += f'/{args.env_name}_td3.pth'

# env & agent
env = gym.make(args.env_name)
agent = Agent(env, args.alpha, args.beta, args.hidden_dims, args.tau, args.batch_size,
args.gamma, args.d, args.warmup, args.max_size, args.c, args.sigma,
args.one_device, args.log_dir, args.checkpoint_dir)

best_score = env.reward_range[0]
score_history = deque([], maxlen=args.window_size)
episodes = tqdm(range(args.n_episodes))

for e in episodes:
# resetting
state = env.reset()
done = False
score = 0

while not done:
action = agent.choose_action(state)
state_, reward, done, _ = env.step(action)
agent.remember(state, action, reward, state_, done)
agent.learn()

# reset, log & render
score += reward
state = state_
episodes.set_postfix({'Reward': reward})
if args.no_render:
continue
env.render()

# logging
score_history.append(score)
moving_avg = sum(score_history) / len(score_history)
agent.add_scalar('Average Score', moving_avg, global_step=e)
tqdm.write(f'Episode: {e + 1}/{args.n_episodes}, \
Episode Score: {score}, \
Average Score: {moving_avg}, \
Best Score: {best_score}')

# save weights @ best score
if moving_avg > best_score:
best_score = moving_avg
agent.save_networks()
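A natural companion to this training script, sketched here rather than taken from the repo, is an evaluation run that reloads the checkpoint written by save_networks and rolls out the policy without exploration noise (the constructor arguments and checkpoint path below just mirror the defaults above):

import gym
import torch

from td3.agent import Agent

env = gym.make('Pendulum-v0')
agent = Agent(env, alpha=0.001, beta=0.001, hidden_dims=[400, 300], tau=0.005,
              batch_size=100, gamma=0.99, d=2, warmup=1000, max_size=1000000,
              c=0.5, sigma=0.2, one_device=True, log_dir='td3/logs',
              checkpoint_dir='td3/network_weights/Pendulum-v0_td3.pth')
agent.load_state_dicts()

state, done, score = env.reset(), False, 0.0
while not done:
    with torch.no_grad():
        obs = torch.tensor(state, dtype=torch.float).to(agent.device)
        # deterministic action: the actor output clamped to the action bounds,
        # i.e. choose_action without the Gaussian exploration noise
        action = agent._clamp_action_bound(agent.actor(obs)).cpu().numpy()
    state, reward, done, _ = env.step(action)
    score += reward
print('Evaluation return:', score)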
43 changes: 43 additions & 0 deletions td3/networks.py
@@ -0,0 +1,43 @@
import torch.nn as nn
import torch


class Critic(nn.Module):
def __init__(self, input_dims, hidden_dims, n_actions, name):
super().__init__()
self.name = name

fcs = []
prev_dim = input_dims + n_actions
# input layers
for hidden_dim in hidden_dims:
fcs.extend([nn.Linear(prev_dim, hidden_dim), nn.ReLU()])
prev_dim = hidden_dim

# output layer
fcs.append(nn.Linear(prev_dim, 1))
self.q = nn.Sequential(*fcs)

def forward(self, state, action):
x = torch.cat([state, action], dim=1)
return self.q(x)


class Actor(nn.Module):
def __init__(self, input_dims, hidden_dims, n_actions, name):
super().__init__()
self.name = name

fcs = []
prev_dim = input_dims
# input layers
for hidden_size in hidden_dims:
fcs.extend([nn.Linear(prev_dim, hidden_size), nn.ReLU()])
prev_dim = hidden_size

# output layer
fcs.extend([nn.Linear(prev_dim, n_actions), nn.Tanh()])
self.pi = nn.Sequential(*fcs)

def forward(self, state):
return self.pi(state)
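A quick shape check of the two modules defined above (3-dimensional observations and 1-dimensional actions are just example sizes, matching Pendulum-v0):

import torch

from td3.networks import Actor, Critic

state_dim, n_actions, hidden_dims = 3, 1, [400, 300]
actor = Actor(state_dim, hidden_dims, n_actions, 'actor')
critic = Critic(state_dim, hidden_dims, n_actions, 'critic')

states = torch.randn(8, state_dim)     # batch of 8 observations
actions = actor(states)                # tanh output in [-1, 1], shape (8, 1)
q_values = critic(states, actions)     # state-action value, shape (8, 1)
print(actions.shape, q_values.shape)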
33 changes: 33 additions & 0 deletions td3/utils.py
@@ -0,0 +1,33 @@
import numpy as np


class ReplayBuffer:
def __init__(self, max_size, input_shape, n_actions):
self.mem_size = max_size
self.counter = 0
self.state_memory = np.zeros((self.mem_size, input_shape))
self.new_state_memory = np.zeros((self.mem_size, input_shape))
self.action_memory = np.zeros((self.mem_size, n_actions))
self.reward_memory = np.zeros(self.mem_size)
self.terminal_memory = np.zeros(self.mem_size, dtype=np.bool)

def store_transition(self, state, action, reward, state_, done):
index = self.counter % self.mem_size
self.state_memory[index] = state
self.new_state_memory[index] = state_
self.terminal_memory[index] = done
self.reward_memory[index] = reward
self.action_memory[index] = action

self.counter += 1

def sample_buffer(self, batch_size):
max_mem = min(self.counter, self.mem_size)
batch = np.random.choice(max_mem, batch_size)
state = self.state_memory[batch]
state_ = self.new_state_memory[batch]
done = self.terminal_memory[batch]
reward = self.reward_memory[batch]
action = self.action_memory[batch]

return state, action, reward, state_, done
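Finally, a minimal exercise of the buffer above with toy dimensions; note that np.random.choice samples indices with replacement by default, so a sampled batch can contain the same transition more than once:

import numpy as np

from td3.utils import ReplayBuffer

obs_dim, act_dim = 3, 1
buffer = ReplayBuffer(max_size=1000, input_shape=obs_dim, n_actions=act_dim)

# store a handful of random transitions
for _ in range(5):
    s, s_ = np.random.randn(obs_dim), np.random.randn(obs_dim)
    a = np.random.uniform(-1.0, 1.0, act_dim)
    buffer.store_transition(s, a, reward=-1.0, state_=s_, done=False)

state, action, reward, state_, done = buffer.sample_buffer(batch_size=4)
print(state.shape, action.shape, reward.shape, done.dtype)  # (4, 3) (4, 1) (4,) bool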
