Commit d34a333

first commit td3 implementation

1 parent e9a9f7c commit d34a333

6 files changed: +347 -3 lines changed

networks/layers.py

Lines changed: 2 additions & 2 deletions

@@ -46,7 +46,7 @@ def __init__(self, in_channels, out_channels, conv_type='basic', mode=None, acti
         elif normalization == 'gn':
             norm = nn.GroupNorm(groups, out_channels)
         else:
-            raise NotImplementedError('Please only choose normalization [bn, ln, in]')
+            raise NotImplementedError('Please only choose normalization [bn, ln, in, gn]')

         # activations
         if activation == 'relu':
@@ -78,7 +78,7 @@ def __init__(self, in_channels, activation, normalization, groups=1):
         elif normalization == 'gn':
             norm = nn.GroupNorm(groups, in_channels)
         else:
-            raise NotImplementedError('Please only choose normalization [bn, ln, in]')
+            raise NotImplementedError('Please only choose normalization [bn, ln, in, gn]')

         # activations
         if activation == 'relu':
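For context, the normalization dispatch this hunk extends presumably looks something like the sketch below; only the 'gn' branch and the updated error message are visible in the diff, so the other branches and the helper name make_norm are assumptions for illustration.

import torch.nn as nn

def make_norm(normalization, out_channels, groups=1):
    # Hypothetical helper mirroring the branch structure in layers.py.
    if normalization == 'bn':
        return nn.BatchNorm2d(out_channels)
    elif normalization == 'ln':
        return nn.GroupNorm(1, out_channels)        # layer-norm style: one group over all channels
    elif normalization == 'in':
        return nn.InstanceNorm2d(out_channels)
    elif normalization == 'gn':
        return nn.GroupNorm(groups, out_channels)   # branch shown in the diff
    else:
        raise NotImplementedError('Please only choose normalization [bn, ln, in, gn]')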

simsiam/main.py

Lines changed: 1 addition & 1 deletion

@@ -81,7 +81,7 @@

 optimizer = torch.optim.SGD(model.parameters(), lr=args.lr,
                             momentum=args.momentum, weight_decay=args.wd)
-scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, int(args.epochs * 0.1))
+scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs)

 start_epoch = 0
 if args.continue_train:
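With T_max equal to args.epochs, the learning rate follows a single half-cosine from the initial lr down to eta_min over the whole run, instead of reaching its minimum after 10% of the epochs and then cycling back up. A minimal sketch of that usage, with a stand-in model and an assumed epoch count:

import torch

model = torch.nn.Linear(8, 2)                                   # stand-in model, not from the repo
optimizer = torch.optim.SGD(model.parameters(), lr=0.05, momentum=0.9)
epochs = 100                                                    # assumed value of args.epochs
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

for epoch in range(epochs):
    # ... one training epoch would run here ...
    optimizer.step()
    scheduler.step()                                            # lr decays along one cosine over all epochs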

td3/agent.py

Lines changed: 178 additions & 0 deletions

@@ -0,0 +1,178 @@
import os
from itertools import chain

import torch
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter

from td3.utils import ReplayBuffer
from td3.networks import Actor, Critic


class Agent:
    def __init__(self, env, alpha, beta, hidden_dims, tau,
                 batch_size, gamma, d, warmup, max_size, c,
                 sigma, one_device, log_dir, checkpoint_dir):
        state_space = env.observation_space.shape[0]
        n_actions = env.action_space.shape[0]

        # training params
        self.gamma = gamma
        self.tau = tau
        self.max_action = env.action_space.high[0]
        self.min_action = env.action_space.low[0]
        self.buffer = ReplayBuffer(max_size, state_space, n_actions)
        self.batch_size = batch_size
        self.learn_step_counter = 0
        self.time_step = 0
        self.warmup = warmup
        self.n_actions = n_actions
        self.d = d
        self.c = c
        self.sigma = sigma

        # training device
        if one_device:
            os.environ["CUDA_VISIBLE_DEVICES"] = "0"
        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

        # logging/checkpointing
        self.writer = SummaryWriter(log_dir)
        self.checkpoint_dir = checkpoint_dir

        # networks & optimizers
        self.actor = Actor(state_space, hidden_dims, n_actions, 'actor').to(self.device)
        self.critic_1 = Critic(state_space, hidden_dims, n_actions, 'critic_1').to(self.device)
        self.critic_2 = Critic(state_space, hidden_dims, n_actions, 'critic_2').to(self.device)

        self.critic_optimizer = torch.optim.Adam(
            chain(self.critic_1.parameters(), self.critic_2.parameters()), lr=beta)
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=alpha)

        self.target_actor = Actor(state_space, hidden_dims, n_actions, 'target_actor').to(self.device)
        self.target_critic_1 = Critic(state_space, hidden_dims, n_actions, 'target_critic_1').to(self.device)
        self.target_critic_2 = Critic(state_space, hidden_dims, n_actions, 'target_critic_2').to(self.device)

        # copy weights
        self.update_network_parameters(tau=1)

    def _get_noise(self, clip=True):
        noise = torch.randn(self.n_actions, dtype=torch.float, device=self.device) * self.sigma
        if clip:
            noise = noise.clamp(-self.c, self.c)
        return noise

    def _clamp_action_bound(self, action):
        return action.clamp(self.min_action, self.max_action)

    def choose_action(self, observation):
        if self.time_step < self.warmup:
            mu = self._get_noise(clip=False)
        else:
            state = torch.tensor(observation, dtype=torch.float).to(self.device)
            mu = self.actor(state) + self._get_noise(clip=False)
        self.time_step += 1
        return self._clamp_action_bound(mu).cpu().detach().numpy()

    def remember(self, state, action, reward, state_, done):
        self.buffer.store_transition(state, action, reward, state_, done)

    def critic_step(self, state, action, reward, state_, done):
        # get target actions w/ noise
        target_actions = self.target_actor(state_) + self._get_noise()
        target_actions = self._clamp_action_bound(target_actions)

        # target & online values
        q1_ = self.target_critic_1(state_, target_actions)
        q2_ = self.target_critic_2(state_, target_actions)

        # done mask
        q1_[done], q2_[done] = 0.0, 0.0

        q1 = self.critic_1(state, action)
        q2 = self.critic_2(state, action)

        q1_ = q1_.view(-1)
        q2_ = q2_.view(-1)

        critic_value_ = torch.min(q1_, q2_)

        target = reward + self.gamma * critic_value_
        target = target.unsqueeze(1)

        self.critic_optimizer.zero_grad()

        q1_loss = F.mse_loss(target, q1)
        q2_loss = F.mse_loss(target, q2)
        critic_loss = q1_loss + q2_loss
        critic_loss.backward()
        self.critic_optimizer.step()

        self.writer.add_scalar('Critic loss', critic_loss.item(), global_step=self.learn_step_counter)

    def actor_step(self, state):
        # calculate loss, update actor params
        self.actor_optimizer.zero_grad()
        actor_loss = -torch.mean(self.critic_1(state, self.actor(state)))
        actor_loss.backward()
        self.actor_optimizer.step()

        # update & log
        self.update_network_parameters()
        self.writer.add_scalar('Actor loss', actor_loss.item(), global_step=self.learn_step_counter)

    def learn(self):
        self.learn_step_counter += 1

        # if the buffer is not yet filled w/ enough samples
        if self.buffer.counter < self.batch_size:
            return

        # transitions
        state, action, reward, state_, done = self.buffer.sample_buffer(self.batch_size)
        reward = torch.tensor(reward, dtype=torch.float).to(self.device)
        done = torch.tensor(done).to(self.device)
        state = torch.tensor(state, dtype=torch.float).to(self.device)
        state_ = torch.tensor(state_, dtype=torch.float).to(self.device)
        action = torch.tensor(action, dtype=torch.float).to(self.device)

        self.critic_step(state, action, reward, state_, done)
        if self.learn_step_counter % self.d == 0:
            self.actor_step(state)

    def momentum_update(self, online_network, target_network, tau):
        # Polyak averaging: target <- tau * online + (1 - tau) * target,
        # so tau=1 performs the hard copy used at initialization
        for param_o, param_t in zip(online_network.parameters(), target_network.parameters()):
            param_t.data = tau * param_o.data + (1. - tau) * param_t.data

    def update_network_parameters(self, tau=None):
        if tau is None:
            tau = self.tau
        self.momentum_update(self.critic_1, self.target_critic_1, tau)
        self.momentum_update(self.critic_2, self.target_critic_2, tau)
        self.momentum_update(self.actor, self.target_actor, tau)

    def add_scalar(self, tag, scalar_value, global_step=None):
        self.writer.add_scalar(tag, scalar_value, global_step=global_step)

    def save_networks(self):
        torch.save({
            'actor': self.actor.state_dict(),
            'target_actor': self.target_actor.state_dict(),
            'critic_1': self.critic_1.state_dict(),
            'critic_2': self.critic_2.state_dict(),
            'target_critic_1': self.target_critic_1.state_dict(),
            'target_critic_2': self.target_critic_2.state_dict(),
            'critic_optimizer': self.critic_optimizer.state_dict(),
            'actor_optimizer': self.actor_optimizer.state_dict(),
        }, self.checkpoint_dir)

    def load_state_dicts(self):
        state_dict = torch.load(self.checkpoint_dir)
        self.actor.load_state_dict(state_dict['actor'])
        self.target_actor.load_state_dict(state_dict['target_actor'])
        self.critic_1.load_state_dict(state_dict['critic_1'])
        self.critic_2.load_state_dict(state_dict['critic_2'])
        self.target_critic_1.load_state_dict(state_dict['target_critic_1'])
        self.target_critic_2.load_state_dict(state_dict['target_critic_2'])
        self.critic_optimizer.load_state_dict(state_dict['critic_optimizer'])
        self.actor_optimizer.load_state_dict(state_dict['actor_optimizer'])
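For reference, critic_step above computes TD3's clipped double-Q target with target-policy smoothing: y = r + gamma * min(Q1', Q2')(s', pi'(s') + clipped noise), zeroed at terminal states. A condensed standalone sketch of just that computation (the function name and default values are illustrative, not from the repo; done is a boolean tensor):

import torch

def td3_target(reward, done, state_, target_actor, target_critic_1, target_critic_2,
               gamma=0.99, sigma=0.2, c=0.5, min_action=-1.0, max_action=1.0):
    # smoothed target action: deterministic target policy plus clipped Gaussian noise
    noise = (torch.randn_like(target_actor(state_)) * sigma).clamp(-c, c)
    target_action = (target_actor(state_) + noise).clamp(min_action, max_action)
    # clipped double-Q: take the pessimistic estimate of the two target critics
    q_ = torch.min(target_critic_1(state_, target_action),
                   target_critic_2(state_, target_action)).view(-1)
    q_[done] = 0.0                       # no bootstrapping past terminal states
    return (reward + gamma * q_).unsqueeze(1)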

td3/main.py

Lines changed: 90 additions & 0 deletions

@@ -0,0 +1,90 @@
"""
https://www.youtube.com/watch?v=ZhFO8EWADmY&ab_channel=MachineLearningwithPhil
"""

import argparse
import gym
from tqdm import tqdm
from pathlib import Path
from collections import deque

from td3.agent import Agent


parser = argparse.ArgumentParser()

# agent hyperparameters
parser.add_argument('--env_name', type=str, default='Pendulum-v0', help='Gym env name')
parser.add_argument('--hidden_dims', type=int, nargs='+', default=[400, 300], help='List of hidden dims for the fc networks')
parser.add_argument('--tau', type=float, default=0.005, help='Soft update param')
parser.add_argument('--gamma', type=float, default=0.99, help='Reward discount factor')
parser.add_argument('--sigma', type=float, default=0.2, help='Gaussian noise standard deviation')
parser.add_argument('--c', type=float, default=0.5, help='Noise clip')

# training hp params
parser.add_argument('--n_episodes', type=int, default=1000, help='Number of episodes')
parser.add_argument('--batch_size', type=int, default=100, help='Batch size')
parser.add_argument('--alpha', type=float, default=0.001, help='Learning rate actor')
parser.add_argument('--beta', type=float, default=0.001, help='Learning rate critic')
parser.add_argument('--warmup', type=int, default=1000, help='Number of warmup steps')
parser.add_argument('--d', type=int, default=2, help='Skip iteration')
parser.add_argument('--max_size', type=int, default=1000000, help='Replay buffer size')
parser.add_argument('--no_render', action="store_true", default=False, help='Whether to render')
parser.add_argument('--window_size', type=int, default=100, help='Score tracking moving average window size')

# misc
parser.add_argument('--one_device', action="store_false", default=True, help='Whether to only train on device 0')
parser.add_argument('--log_dir', type=str, default='td3/logs', help='Path to where log files will be saved')
parser.add_argument('--checkpoint_dir', type=str, default='td3/network_weights', help='Path to where model weights will be saved')
args = parser.parse_args()


if __name__ == '__main__':
    # paths
    Path(args.checkpoint_dir).mkdir(parents=True, exist_ok=True)
    Path(args.log_dir).mkdir(parents=True, exist_ok=True)
    args.checkpoint_dir += f'/{args.env_name}_td3.pth'

    # env & agent
    env = gym.make(args.env_name)
    agent = Agent(env, args.alpha, args.beta, args.hidden_dims, args.tau, args.batch_size,
                  args.gamma, args.d, args.warmup, args.max_size, args.c, args.sigma,
                  args.one_device, args.log_dir, args.checkpoint_dir)

    best_score = env.reward_range[0]
    score_history = deque([], maxlen=args.window_size)
    episodes = tqdm(range(args.n_episodes))

    for e in episodes:
        # resetting
        state = env.reset()
        done = False
        score = 0

        while not done:
            action = agent.choose_action(state)
            state_, reward, done, _ = env.step(action)
            agent.remember(state, action, reward, state_, done)
            agent.learn()

            # reset, log & render
            score += reward
            state = state_
            episodes.set_postfix({'Reward': reward})
            if args.no_render:
                continue
            env.render()

        # logging
        score_history.append(score)
        moving_avg = sum(score_history) / len(score_history)
        agent.add_scalar('Average Score', moving_avg, global_step=e)
        tqdm.write(f'Episode: {e + 1}/{args.n_episodes}, \
            Episode Score: {score}, \
            Average Score: {moving_avg}, \
            Best Score: {best_score}')

        # save weights @ best score
        if moving_avg > best_score:
            best_score = moving_avg
            agent.save_networks()

td3/networks.py

Lines changed: 43 additions & 0 deletions

@@ -0,0 +1,43 @@
import torch.nn as nn
import torch


class Critic(nn.Module):
    def __init__(self, input_dims, hidden_dims, n_actions, name):
        super().__init__()
        self.name = name

        fcs = []
        prev_dim = input_dims + n_actions
        # input layers
        for hidden_dim in hidden_dims:
            fcs.extend([nn.Linear(prev_dim, hidden_dim), nn.ReLU()])
            prev_dim = hidden_dim

        # output layer
        fcs.append(nn.Linear(prev_dim, 1))
        self.q = nn.Sequential(*fcs)

    def forward(self, state, action):
        x = torch.cat([state, action], dim=1)
        return self.q(x)


class Actor(nn.Module):
    def __init__(self, input_dims, hidden_dims, n_actions, name):
        super().__init__()
        self.name = name

        fcs = []
        prev_dim = input_dims
        # input layers
        for hidden_size in hidden_dims:
            fcs.extend([nn.Linear(prev_dim, hidden_size), nn.ReLU()])
            prev_dim = hidden_size

        # output layer
        fcs.extend([nn.Linear(prev_dim, n_actions), nn.Tanh()])
        self.pi = nn.Sequential(*fcs)

    def forward(self, state):
        return self.pi(state)
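A quick shape check for these networks (assumes the repo root is on PYTHONPATH; the 3-dim observation and 1-dim action match Pendulum-v0, the batch size is arbitrary):

import torch
from td3.networks import Actor, Critic

state_dim, n_actions, batch = 3, 1, 4
actor = Actor(state_dim, [400, 300], n_actions, 'actor')
critic = Critic(state_dim, [400, 300], n_actions, 'critic')

state = torch.randn(batch, state_dim)
action = actor(state)                  # tanh output in [-1, 1], shape (4, 1)
q_value = critic(state, action)        # shape (4, 1)
print(action.shape, q_value.shape)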

td3/utils.py

Lines changed: 33 additions & 0 deletions

@@ -0,0 +1,33 @@
import numpy as np


class ReplayBuffer:
    def __init__(self, max_size, input_shape, n_actions):
        self.mem_size = max_size
        self.counter = 0
        self.state_memory = np.zeros((self.mem_size, input_shape))
        self.new_state_memory = np.zeros((self.mem_size, input_shape))
        self.action_memory = np.zeros((self.mem_size, n_actions))
        self.reward_memory = np.zeros(self.mem_size)
        self.terminal_memory = np.zeros(self.mem_size, dtype=bool)

    def store_transition(self, state, action, reward, state_, done):
        index = self.counter % self.mem_size
        self.state_memory[index] = state
        self.new_state_memory[index] = state_
        self.terminal_memory[index] = done
        self.reward_memory[index] = reward
        self.action_memory[index] = action

        self.counter += 1

    def sample_buffer(self, batch_size):
        max_mem = min(self.counter, self.mem_size)
        batch = np.random.choice(max_mem, batch_size)
        state = self.state_memory[batch]
        state_ = self.new_state_memory[batch]
        done = self.terminal_memory[batch]
        reward = self.reward_memory[batch]
        action = self.action_memory[batch]

        return state, action, reward, state_, done
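A tiny round-trip through the buffer (again assuming the repo root is on PYTHONPATH; the sizes mirror Pendulum-v0 and are otherwise arbitrary):

import numpy as np
from td3.utils import ReplayBuffer

buffer = ReplayBuffer(max_size=1000, input_shape=3, n_actions=1)
for _ in range(10):
    s = np.random.randn(3)
    a = np.random.uniform(-2, 2, size=1)
    buffer.store_transition(s, a, reward=0.0, state_=np.random.randn(3), done=False)

state, action, reward, state_, done = buffer.sample_buffer(batch_size=4)
print(state.shape, action.shape, reward.shape)   # (4, 3) (4, 1) (4,)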
