Commit d34a333 (1 parent: e9a9f7c)
Showing 6 changed files with 347 additions and 3 deletions.
@@ -0,0 +1,178 @@
import os
from itertools import chain

import torch
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter

from td3.utils import ReplayBuffer
from td3.networks import Actor, Critic


class Agent:
    def __init__(self, env, alpha, beta, hidden_dims, tau,
                 batch_size, gamma, d, warmup, max_size, c,
                 sigma, one_device, log_dir, checkpoint_dir):
        state_space = env.observation_space.shape[0]
        n_actions = env.action_space.shape[0]

        # training params
        self.gamma = gamma
        self.tau = tau
        self.max_action = env.action_space.high[0]
        self.min_action = env.action_space.low[0]
        self.buffer = ReplayBuffer(max_size, state_space, n_actions)
        self.batch_size = batch_size
        self.learn_step_counter = 0
        self.time_step = 0
        self.warmup = warmup
        self.n_actions = n_actions
        self.d = d
        self.c = c
        self.sigma = sigma

        # training device
        if one_device:
            os.environ["CUDA_VISIBLE_DEVICES"] = "0"
        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

        # logging/checkpointing
        self.writer = SummaryWriter(log_dir)
        self.checkpoint_dir = checkpoint_dir

        # networks & optimizers
        self.actor = Actor(state_space, hidden_dims, n_actions, 'actor').to(self.device)
        self.critic_1 = Critic(state_space, hidden_dims, n_actions, 'critic_1').to(self.device)
        self.critic_2 = Critic(state_space, hidden_dims, n_actions, 'critic_2').to(self.device)

        self.critic_optimizer = torch.optim.Adam(
            chain(self.critic_1.parameters(), self.critic_2.parameters()), lr=beta)
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=alpha)

        self.target_actor = Actor(state_space, hidden_dims, n_actions, 'target_actor').to(self.device)
        self.target_critic_1 = Critic(state_space, hidden_dims, n_actions, 'target_critic_1').to(self.device)
        self.target_critic_2 = Critic(state_space, hidden_dims, n_actions, 'target_critic_2').to(self.device)

        # copy weights
        self.update_network_parameters(tau=1)

    def _get_noise(self, clip=True):
        noise = torch.randn(self.n_actions, dtype=torch.float, device=self.device) * self.sigma
        if clip:
            noise = noise.clamp(-self.c, self.c)
        return noise

    def _clamp_action_bound(self, action):
        return action.clamp(self.min_action, self.max_action)

    def choose_action(self, observation):
        if self.time_step < self.warmup:
            mu = self._get_noise(clip=False)
        else:
            state = torch.tensor(observation, dtype=torch.float).to(self.device)
            mu = self.actor(state) + self._get_noise(clip=False)
        self.time_step += 1
        return self._clamp_action_bound(mu).cpu().detach().numpy()

    def remember(self, state, action, reward, state_, done):
        self.buffer.store_transition(state, action, reward, state_, done)

    def critic_step(self, state, action, reward, state_, done):
        # get target actions w/ noise
        target_actions = self.target_actor(state_) + self._get_noise()
        target_actions = self._clamp_action_bound(target_actions)

        # target & online values
        q1_ = self.target_critic_1(state_, target_actions)
        q2_ = self.target_critic_2(state_, target_actions)

        # done mask
        q1_[done], q2_[done] = 0.0, 0.0

        q1 = self.critic_1(state, action)
        q2 = self.critic_2(state, action)

        q1_ = q1_.view(-1)
        q2_ = q2_.view(-1)

        critic_value_ = torch.min(q1_, q2_)

        target = reward + self.gamma * critic_value_
        target = target.unsqueeze(1)

        self.critic_optimizer.zero_grad()

        q1_loss = F.mse_loss(target, q1)
        q2_loss = F.mse_loss(target, q2)
        critic_loss = q1_loss + q2_loss
        critic_loss.backward()
        self.critic_optimizer.step()

        self.writer.add_scalar('Critic loss', critic_loss.item(), global_step=self.learn_step_counter)

    def actor_step(self, state):
        # calculate loss, update actor params
        self.actor_optimizer.zero_grad()
        actor_loss = -torch.mean(self.critic_1(state, self.actor(state)))
        actor_loss.backward()
        self.actor_optimizer.step()

        # update & log
        self.update_network_parameters()
        self.writer.add_scalar('Actor loss', actor_loss.item(), global_step=self.learn_step_counter)

    def learn(self):
        self.learn_step_counter += 1

        # if the buffer is not yet filled w/ enough samples
        if self.buffer.counter < self.batch_size:
            return

        # transitions
        state, action, reward, state_, done = self.buffer.sample_buffer(self.batch_size)
        reward = torch.tensor(reward, dtype=torch.float).to(self.device)
        done = torch.tensor(done).to(self.device)
        state = torch.tensor(state, dtype=torch.float).to(self.device)
        state_ = torch.tensor(state_, dtype=torch.float).to(self.device)
        action = torch.tensor(action, dtype=torch.float).to(self.device)

        self.critic_step(state, action, reward, state_, done)
        if self.learn_step_counter % self.d == 0:
            self.actor_step(state)

    def momentum_update(self, online_network, target_network, tau):
        # Polyak averaging: target <- (1 - tau) * target + tau * online,
        # so update_network_parameters(tau=1) performs the initial hard copy
        for param_o, param_t in zip(online_network.parameters(), target_network.parameters()):
            param_t.data = param_t.data * (1. - tau) + param_o.data * tau

    def update_network_parameters(self, tau=None):
        if tau is None:
            tau = self.tau
        self.momentum_update(self.critic_1, self.target_critic_1, tau)
        self.momentum_update(self.critic_2, self.target_critic_2, tau)
        self.momentum_update(self.actor, self.target_actor, tau)

    def add_scalar(self, tag, scalar_value, global_step=None):
        self.writer.add_scalar(tag, scalar_value, global_step=global_step)

    def save_networks(self):
        torch.save({
            'actor': self.actor.state_dict(),
            'target_actor': self.target_actor.state_dict(),
            'critic_1': self.critic_1.state_dict(),
            'critic_2': self.critic_2.state_dict(),
            'target_critic_1': self.target_critic_1.state_dict(),
            'target_critic_2': self.target_critic_2.state_dict(),
            'critic_optimizer': self.critic_optimizer.state_dict(),
            'actor_optimizer': self.actor_optimizer.state_dict(),
        }, self.checkpoint_dir)

    def load_state_dicts(self):
        state_dict = torch.load(self.checkpoint_dir)
        self.actor.load_state_dict(state_dict['actor'])
        self.target_actor.load_state_dict(state_dict['target_actor'])
        self.critic_1.load_state_dict(state_dict['critic_1'])
        self.critic_2.load_state_dict(state_dict['critic_2'])
        self.target_critic_1.load_state_dict(state_dict['target_critic_1'])
        self.target_critic_2.load_state_dict(state_dict['target_critic_2'])
        self.critic_optimizer.load_state_dict(state_dict['critic_optimizer'])
        self.actor_optimizer.load_state_dict(state_dict['actor_optimizer'])
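For reference, a minimal standalone sketch of the clipped double-Q target that critic_step computes. The tensors (target_mu, q1_next, q2_next) and the action bound are illustrative stand-ins, not values from a real run:

import torch

# illustrative shapes: a batch of 4 transitions in a 1-dimensional action space
batch_size, n_actions = 4, 1
gamma, sigma, c = 0.99, 0.2, 0.5
max_action = 2.0  # assumed action bound, e.g. Pendulum

reward = torch.randn(batch_size)
done = torch.tensor([False, False, True, False])

# stand-ins for target_actor(s') plus clipped noise (target policy smoothing)
target_mu = torch.randn(batch_size, n_actions)
noise = (torch.randn(batch_size, n_actions) * sigma).clamp(-c, c)
target_actions = (target_mu + noise).clamp(-max_action, max_action)

q1_next = torch.randn(batch_size, 1)  # would be target_critic_1(s', target_actions)
q2_next = torch.randn(batch_size, 1)  # would be target_critic_2(s', target_actions)
q1_next[done], q2_next[done] = 0.0, 0.0  # no bootstrapping past terminal states

# clipped double-Q target: y = r + gamma * min(Q1', Q2')
target = reward + gamma * torch.min(q1_next.view(-1), q2_next.view(-1))
print(target.shape)  # torch.Size([4])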
@@ -0,0 +1,90 @@
"""
https://www.youtube.com/watch?v=ZhFO8EWADmY&ab_channel=MachineLearningwithPhil
"""

import argparse
import gym
from tqdm import tqdm
from pathlib import Path
from collections import deque

from td3.agent import Agent


parser = argparse.ArgumentParser()

# agent hyperparameters
parser.add_argument('--env_name', type=str, default='Pendulum-v0', help='Gym env name')
parser.add_argument('--hidden_dims', type=int, nargs='+', default=[400, 300], help='Hidden dims for the fc networks')
parser.add_argument('--tau', type=float, default=0.005, help='Soft update param')
parser.add_argument('--gamma', type=float, default=0.99, help='Reward discount factor')
parser.add_argument('--sigma', type=float, default=0.2, help='Gaussian noise standard deviation')
parser.add_argument('--c', type=float, default=0.5, help='Noise clip')

# training hyperparameters
parser.add_argument('--n_episodes', type=int, default=1000, help='Number of episodes')
parser.add_argument('--batch_size', type=int, default=100, help='Batch size')
parser.add_argument('--alpha', type=float, default=0.001, help='Learning rate actor')
parser.add_argument('--beta', type=float, default=0.001, help='Learning rate critic')
parser.add_argument('--warmup', type=int, default=1000, help='Number of warmup steps')
parser.add_argument('--d', type=int, default=2, help='Delayed policy update interval')
parser.add_argument('--max_size', type=int, default=1000000, help='Replay buffer size')
parser.add_argument('--no_render', action="store_true", default=False, help='Whether to render')
parser.add_argument('--window_size', type=int, default=100, help='Score tracking moving average window size')

# misc
parser.add_argument('--one_device', action="store_false", default=True, help='Whether to only train on device 0')
parser.add_argument('--log_dir', type=str, default='td3/logs', help='Path to where log files will be saved')
parser.add_argument('--checkpoint_dir', type=str, default='td3/network_weights', help='Path to where model weights will be saved')
args = parser.parse_args()


if __name__ == '__main__':
    # paths
    Path(args.checkpoint_dir).mkdir(parents=True, exist_ok=True)
    Path(args.log_dir).mkdir(parents=True, exist_ok=True)
    args.checkpoint_dir += f'/{args.env_name}_td3.pth'

    # env & agent
    env = gym.make(args.env_name)
    agent = Agent(env, args.alpha, args.beta, args.hidden_dims, args.tau, args.batch_size,
                  args.gamma, args.d, args.warmup, args.max_size, args.c, args.sigma,
                  args.one_device, args.log_dir, args.checkpoint_dir)

    best_score = env.reward_range[0]
    score_history = deque([], maxlen=args.window_size)
    episodes = tqdm(range(args.n_episodes))

    for e in episodes:
        # resetting
        state = env.reset()
        done = False
        score = 0

        while not done:
            action = agent.choose_action(state)
            state_, reward, done, _ = env.step(action)
            agent.remember(state, action, reward, state_, done)
            agent.learn()

            # track score, advance state, log & render
            score += reward
            state = state_
            episodes.set_postfix({'Reward': reward})
            if args.no_render:
                continue
            env.render()

        # logging
        score_history.append(score)
        moving_avg = sum(score_history) / len(score_history)
        agent.add_scalar('Average Score', moving_avg, global_step=e)
        tqdm.write(f'Episode: {e + 1}/{args.n_episodes}, '
                   f'Episode Score: {score}, '
                   f'Average Score: {moving_avg}, '
                   f'Best Score: {best_score}')

        # save weights @ best score
        if moving_avg > best_score:
            best_score = moving_avg
            agent.save_networks()
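Once a checkpoint exists, a deterministic rollout could look like the sketch below. This is only an illustration: it assumes the same pre-0.26 gym API used above (4-tuple step return), reuses the training defaults and checkpoint path, and queries agent.actor directly because choose_action always adds exploration noise.

import gym
import torch

from td3.agent import Agent

# hypothetical evaluation setup; hyperparameters mirror the training defaults above
env = gym.make('Pendulum-v0')
agent = Agent(env, alpha=0.001, beta=0.001, hidden_dims=[400, 300], tau=0.005,
              batch_size=100, gamma=0.99, d=2, warmup=0, max_size=1000000, c=0.5,
              sigma=0.2, one_device=True, log_dir='td3/logs',
              checkpoint_dir='td3/network_weights/Pendulum-v0_td3.pth')
agent.load_state_dicts()

state, done, score = env.reset(), False, 0.0
while not done:
    with torch.no_grad():
        # act greedily: use the actor directly instead of choose_action, which adds noise
        state_t = torch.tensor(state, dtype=torch.float).to(agent.device)
        action = agent.actor(state_t).cpu().numpy()
    state, reward, done, _ = env.step(action)
    score += reward
print(f'Evaluation return: {score}')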
@@ -0,0 +1,43 @@
import torch.nn as nn
import torch


class Critic(nn.Module):
    def __init__(self, input_dims, hidden_dims, n_actions, name):
        super().__init__()
        self.name = name

        fcs = []
        prev_dim = input_dims + n_actions
        # hidden layers
        for hidden_dim in hidden_dims:
            fcs.extend([nn.Linear(prev_dim, hidden_dim), nn.ReLU()])
            prev_dim = hidden_dim

        # output layer
        fcs.append(nn.Linear(prev_dim, 1))
        self.q = nn.Sequential(*fcs)

    def forward(self, state, action):
        x = torch.cat([state, action], dim=1)
        return self.q(x)


class Actor(nn.Module):
    def __init__(self, input_dims, hidden_dims, n_actions, name):
        super().__init__()
        self.name = name

        fcs = []
        prev_dim = input_dims
        # hidden layers
        for hidden_size in hidden_dims:
            fcs.extend([nn.Linear(prev_dim, hidden_size), nn.ReLU()])
            prev_dim = hidden_size

        # output layer
        fcs.extend([nn.Linear(prev_dim, n_actions), nn.Tanh()])
        self.pi = nn.Sequential(*fcs)

    def forward(self, state):
        return self.pi(state)
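A quick shape check for the two modules above (a sketch; the dimensions are made up to roughly match Pendulum-v0's 3-dimensional observations and 1-dimensional actions):

import torch

from td3.networks import Actor, Critic

state_dim, action_dim, batch = 3, 1, 8
actor = Actor(state_dim, [400, 300], action_dim, 'actor')
critic = Critic(state_dim, [400, 300], action_dim, 'critic')

state = torch.randn(batch, state_dim)
action = actor(state)            # tanh output in [-1, 1], shape (8, 1)
q_value = critic(state, action)  # state and action are concatenated, shape (8, 1)
print(action.shape, q_value.shape)

Note that the actor's tanh output is not rescaled by the environment's action bound; the agent only clamps to [min_action, max_action], so on Pendulum the policy's mean action is confined to [-1, 1] even though the action space is [-2, 2].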
@@ -0,0 +1,33 @@
import numpy as np


class ReplayBuffer:
    def __init__(self, max_size, input_shape, n_actions):
        self.mem_size = max_size
        self.counter = 0
        self.state_memory = np.zeros((self.mem_size, input_shape))
        self.new_state_memory = np.zeros((self.mem_size, input_shape))
        self.action_memory = np.zeros((self.mem_size, n_actions))
        self.reward_memory = np.zeros(self.mem_size)
        # np.bool was removed in NumPy 1.24; the builtin bool works across versions
        self.terminal_memory = np.zeros(self.mem_size, dtype=bool)

    def store_transition(self, state, action, reward, state_, done):
        index = self.counter % self.mem_size
        self.state_memory[index] = state
        self.new_state_memory[index] = state_
        self.terminal_memory[index] = done
        self.reward_memory[index] = reward
        self.action_memory[index] = action

        self.counter += 1

    def sample_buffer(self, batch_size):
        max_mem = min(self.counter, self.mem_size)
        batch = np.random.choice(max_mem, batch_size)
        state = self.state_memory[batch]
        state_ = self.new_state_memory[batch]
        done = self.terminal_memory[batch]
        reward = self.reward_memory[batch]
        action = self.action_memory[batch]

        return state, action, reward, state_, done
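Finally, a minimal usage sketch for the buffer (the dimensions and transition values are made up for illustration):

import numpy as np

from td3.utils import ReplayBuffer

buffer = ReplayBuffer(max_size=1000, input_shape=3, n_actions=1)

# store a handful of fake transitions
for _ in range(5):
    state = np.random.randn(3)
    action = np.random.uniform(-1, 1, size=1)
    state_ = np.random.randn(3)
    buffer.store_transition(state, action, reward=1.0, state_=state_, done=False)

# sampling draws indices (with replacement) from the filled part of the buffer
states, actions, rewards, states_, dones = buffer.sample_buffer(batch_size=4)
print(states.shape, actions.shape, rewards.shape, dones.dtype)  # (4, 3) (4, 1) (4,) bool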