
Commit 68809d8

added SAC implementation
1 parent 4483219 commit 68809d8

3 files changed: 261 additions, 14 deletions


policy/agent.py

Lines changed: 120 additions & 5 deletions
@@ -3,9 +3,9 @@
 from torch.nn import functional as F
 from copy import deepcopy
 
-from policy.networks import ActorCritic, Actor, Critic
+from policy.networks import ActorCritic, Actor, Critic, SACActor, SACCritic, SACValue
 from policy.utils import ReplayBuffer, OUActionNoise, clip_action, GaussianActionNoise
-
+torch.autograd.set_detect_anomaly(True)
 
 class BlackJackAgent:
     def __init__(self, method, env, function='V', gamma=0.99, epsilon=0.1):
@@ -332,7 +332,7 @@ def choose_action(self, observation, test):
         with torch.no_grad():
             action = self.actor(observation)
             if not test:
-                action = action + self.noise(action.size())
+                action = action + self.noise(action.size()).to(self.device)
         self.actor.train()
         action = action.cpu().detach().numpy()
         # clip noised action to ensure not out of bounds
@@ -357,12 +357,14 @@ def update(self):
 
         # calculate targets & only update online critic network
         self.critic.optimizer.zero_grad()
+        self.critic2.optimizer.zero_grad()
         with torch.no_grad():
            # y <- r + gamma * min_(i=1,2) Q_(theta'_i)(s', a_tilde)
             target_actions = self.target_actor(next_states)
             target_actions += self.noise(
-                target_actions.size(), clip=self.action_clip, sigma=self.action_sigma)
-            target_actions = clip_action(target_actions, self.max_action)
+                target_actions.size(), clip=self.action_clip, sigma=self.action_sigma).to(self.device)
+            target_actions = clip_action(target_actions.cpu().numpy(), self.max_action)
+            target_actions = torch.from_numpy(target_actions).to(self.device)
             q_primes1 = self.target_critic(next_states, target_actions).squeeze()
             q_primes2 = self.target_critic2(next_states, target_actions).squeeze()
             q_primes = torch.min(q_primes1, q_primes2)
@@ -375,4 +377,117 @@ def update(self):
         critic_loss = critic_loss1 + critic_loss2
         critic_loss.backward()
         self.critic.optimizer.step()
+        self.critic2.optimizer.step()
         return self.actor_loss, critic_loss.item()
+
+
+class SACAgent:
+    def __init__(self, state_dim, action_dim, hidden_dims, max_action, gamma,
+                 tau, reward_scale, lr, batch_size, maxsize, checkpoint):
+        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
+        self.gamma = gamma
+        self.tau = tau
+        self.reward_scale = reward_scale
+        self.batch_size = batch_size
+
+        self.memory = ReplayBuffer(state_dim, action_dim, maxsize)
+        self.critic1 = SACCritic(*state_dim, *action_dim, hidden_dims, lr,
+                                 checkpoint, 'Critic')
+        self.critic2 = SACCritic(*state_dim, *action_dim, hidden_dims,
+                                 lr, checkpoint, 'Critic2')
+        self.actor = SACActor(*state_dim, *action_dim, hidden_dims, lr,
+                              max_action, checkpoint, 'Actor')
+        self.value = SACValue(*state_dim, hidden_dims,
+                              lr, checkpoint, 'Valuator')
+        self.target_value = self.get_target_network(self.value)
+        self.target_value.name = 'Target_Valuator'
+
+    def get_target_network(self, online_network, freeze_weights=True):
+        target_network = deepcopy(online_network)
+        if freeze_weights:
+            for param in target_network.parameters():
+                param.requires_grad = False
+        return target_network
+
+    def choose_action(self, observation, test):
+        self.actor.eval()
+        observation = torch.from_numpy(observation).to(self.device)
+        with torch.no_grad():
+            action, _ = self.actor(observation)
+        self.actor.train()
+        action = action.cpu().detach().numpy()
+        return action
+
+    def update(self):
+        experiences = self.memory.sample_transition(self.batch_size)
+        states, actions, rewards, next_states, dones = [data.to(self.device) for data in experiences]
+
+        ###### UPDATE VALUATOR ######
+        self.value.optimizer.zero_grad()
+        with torch.no_grad():
+            policy_actions, log_probs = self.actor(states, reparameterize=False)
+            action_values1 = self.critic1(states, policy_actions).squeeze()
+            action_values2 = self.critic2(states, policy_actions).squeeze()
+            action_values = torch.min(action_values1, action_values2)
+            target = action_values - log_probs.squeeze()
+        values = self.value(states).squeeze()
+        value_loss = 0.5 * F.mse_loss(target, values)
+        value_loss.backward()
+        self.value.optimizer.step()
+
+        ###### UPDATE CRITIC ######
+        self.critic1.optimizer.zero_grad()
+        self.critic2.optimizer.zero_grad()
+        with torch.no_grad():
+            v_hat = self.target_value(next_states).squeeze() * (~dones)
+            targets = rewards * self.reward_scale + self.gamma * v_hat
+        qs1 = self.critic1(states, actions).squeeze()
+        qs2 = self.critic2(states, actions).squeeze()
+        critic_loss1 = 0.5 * F.mse_loss(targets, qs1)
+        critic_loss2 = 0.5 * F.mse_loss(targets, qs2)
+        critic_loss = critic_loss1 + critic_loss2
+        critic_loss.backward()
+        self.critic1.optimizer.step()
+        self.critic2.optimizer.step()
+
+        ###### UPDATE ACTOR ######
+        self.actor.optimizer.zero_grad()
+        actions, log_probs = self.actor(states)
+        action_values1 = self.critic1(states, actions).squeeze()
+        action_values2 = self.critic2(states, actions).squeeze()
+        action_values = torch.min(action_values1, action_values2)
+        actor_loss = torch.mean(log_probs.squeeze() - action_values)
+        actor_loss.backward()
+        self.actor.optimizer.step()
+
+        ###### UPDATE TARGET VALUE ######
+        self.update_target_network(self.value, self.target_value)
+
+        return value_loss.item(), critic_loss.item(), actor_loss.item()
+
+    def update_target_network(self, src, tgt):
+        for src_weight, tgt_weight in zip(src.parameters(), tgt.parameters()):
+            tgt_weight.data = tgt_weight.data * self.tau + src_weight.data * (1. - self.tau)
+
+    def store_transition(self, state, action, reward, next_state, done):
+        state = torch.tensor(state)
+        action = torch.tensor(action)
+        reward = torch.tensor(reward)
+        next_state = torch.tensor(next_state)
+        done = torch.tensor(done, dtype=torch.bool)
+        self.memory.store_transition(state, action, reward, next_state, done)
+
+    def save_models(self):
+        self.critic1.save_checkpoint()
+        self.critic2.save_checkpoint()
+        self.actor.save_checkpoint()
+        self.value.save_checkpoint()
+        self.target_value.save_checkpoint()
+
+    def load_models(self):
+        self.critic1.load_checkpoint()
+        self.critic2.load_checkpoint()
+        self.actor.load_checkpoint()
+        self.value.load_checkpoint()
+        self.target_value.load_checkpoint()
+
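For reference, SACAgent.update() follows the original (value-network) formulation of Soft Actor-Critic: a value loss, a twin-critic loss, and an entropy-regularized actor loss, followed by a Polyak step on the target value network. The snippet below is a minimal sketch of those three objectives with dummy tensors standing in for the real networks; all names and values here are illustrative assumptions, not code from this commit.

# Sketch (not part of the commit): the three SAC objectives optimized in update(),
# written against dummy tensors so the arithmetic is visible in isolation.
import torch
from torch.nn import functional as F

batch_size, gamma, reward_scale = 4, 0.99, 2.0                    # illustrative values
rewards = torch.randn(batch_size)
dones = torch.zeros(batch_size, dtype=torch.bool)
log_probs = torch.randn(batch_size)                               # log pi(a|s) from the squashed Gaussian actor
q1, q2 = torch.randn(batch_size), torch.randn(batch_size)         # Q_1(s, a~pi), Q_2(s, a~pi)
values = torch.randn(batch_size)                                  # V(s) from the value network
target_values_next = torch.randn(batch_size)                      # V_bar(s') from the frozen target value network

# value loss: V(s) regresses onto min_i Q_i(s, a~pi) - log pi(a|s)
value_loss = 0.5 * F.mse_loss(torch.min(q1, q2) - log_probs, values)

# critic loss: both Qs regress onto the scaled reward plus the discounted target value,
# with the bootstrap term masked out on terminal transitions
q_target = reward_scale * rewards + gamma * target_values_next * (~dones)
critic_loss = 0.5 * F.mse_loss(q_target, q1) + 0.5 * F.mse_loss(q_target, q2)

# actor loss: minimize E[log pi(a|s) - min_i Q_i(s, a~pi)] over reparameterized actions
actor_loss = (log_probs - torch.min(q1, q2)).mean()

The Polyak averaging in update_target_network() then nudges the target value network toward the online one by a factor controlled by tau.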

policy/continuous/main.py

Lines changed: 39 additions & 9 deletions
@@ -12,10 +12,10 @@
 
 parser = argparse.ArgumentParser(description='Continuous Environment Agents')
 # training hyperparams
-parser.add_argument('--agent', type=str, default='TD3', help='Agent Algorithm')
+parser.add_argument('--agent', type=str, default='SAC', help='Agent Algorithm')
 parser.add_argument('--environment', type=str, default='LunarLanderContinuous-v2', help='Environment to train on')
 parser.add_argument('--n_episodes', type=int, default=3000, help='Number of episodes you wish to run for')
-parser.add_argument('--batch_size', type=int, default=100, help='Minibatch size')
+parser.add_argument('--batch_size', type=int, default=256, help='Minibatch size')
 parser.add_argument('--hidden_dim', type=int, default=2048, help='Hidden dimension of FC layers')
 parser.add_argument('--hidden_dims', type=list, default=[400, 300], help='Hidden dimensions of FC layers')
 parser.add_argument('--critic_lr', type=float, default=1e-3, help='Learning rate for Critic')
@@ -33,6 +33,7 @@
 parser.add_argument('--actor_update_iter', type=int, default=2, help='Update actor and target network every')
 parser.add_argument('--action_sigma', type=float, default=0.2, help='Std of noise for actions')
 parser.add_argument('--action_clip', type=float, default=0.5, help='Max action bound')
+parser.add_argument('--reward_scale', type=float, default=2., help='Reward scale for Soft Actor-Critic')
 
 # eval params
 parser.add_argument('--render', action="store_true", default=False, help='Render environment while training')
@@ -95,6 +96,20 @@ def main():
                        action_sigma=args.action_sigma,
                        action_clip=args.action_clip
                        )
+    elif args.agent == 'SAC':
+        max_action = float(env.action_space.high[0])
+        agent = agent_(state_dim=env.observation_space.shape,
+                       action_dim=env.action_space.shape,
+                       hidden_dims=args.hidden_dims,
+                       max_action=max_action,
+                       gamma=args.gamma,
+                       tau=args.tau,
+                       reward_scale=args.reward_scale,
+                       lr=args.critic_lr,
+                       batch_size=args.batch_size,
+                       maxsize=int(args.maxsize),
+                       checkpoint=args.checkpoint,
+                       )
     else:
         agent = agent_(state_dim=env.observation_space.shape,
                        action_dim=env.action_space.n,
@@ -116,13 +131,15 @@ def main():
         done, score, observation = False, 0, env.reset()
 
         # reset DDPG OU noise and also keep track of actor/critic losses
-        if args.agent in ['DDPG', 'TD3']:
+        if args.agent in ['DDPG', 'TD3', 'SAC']:
             if args.agent == 'DDPG':
                 agent.noise.reset()
             actor_losses, critic_losses = [], []
+            if args.agent == 'SAC':
+                value_losses = []
         while not done:
             if args.render:
-                env.render()
+                env.render(mode='human')
 
             action = agent.choose_action(observation, args.test)
             next_observation, reward, done, _ = env.step(action)
@@ -133,15 +150,23 @@ def main():
                 continue
             elif args.agent == 'Actor Critic':
                 agent.update(reward, next_observation, done)
-            elif args.agent in ['DDPG', 'TD3']:
+            elif args.agent in ['DDPG', 'TD3', 'SAC']:
                 agent.store_transition(observation, action, reward, next_observation, done)
                 # if we have memory smaller than batch size, do not update
                 if agent.memory.idx < args.batch_size or (args.agent == 'TD3' and agent.ctr < args.warmup_steps):
                     continue
-                actor_loss, critic_loss = agent.update()
+                if args.agent == 'SAC':
+                    value_loss, critic_loss, actor_loss = agent.update()
+                    value_losses.append(value_loss)
+                else:
+                    actor_loss, critic_loss = agent.update()
                 actor_losses.append(actor_loss)
                 critic_losses.append(critic_loss)
-                pbar.set_postfix({'Reward': reward, 'Actor Loss': actor_loss, 'Critic Loss': critic_loss})
+                if args.agent == 'SAC':
+                    pbar.set_postfix({'Reward': reward, 'Actor Loss': actor_loss,
+                                      'Critic Loss': critic_loss, 'Value Loss': value_loss})
+                else:
+                    pbar.set_postfix({'Reward': reward, 'Actor Loss': actor_loss, 'Critic Loss': critic_loss})
             else:
                 agent.store_reward(reward)
             observation = next_observation
@@ -155,15 +180,20 @@ def main():
             agent.update()
 
         # logging & saving
-        elif args.agent in ['DDPG', 'TD3']:
+        elif args.agent in ['DDPG', 'TD3', 'SAC']:
             writer.add_scalars(
                 'Scores',
                 {'Episodic': score, 'Windowed Average': np.mean(score_history)},
                 global_step=e)
+
            if actor_losses:
+                loss_dict = {'Actor': np.mean(actor_losses), 'Critic': np.mean(critic_losses)}
+                if args.agent == 'SAC':
+                    loss_dict['Value'] = np.mean(value_losses)
+                    value_losses = []
                 writer.add_scalars(
                     'Losses',
-                    {'Actor': np.mean(actor_losses), 'Critic': np.mean(critic_losses)},
+                    loss_dict,
                     global_step=e)
                 actor_losses, critic_losses = [], []
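As a usage note, the SAC branch above can be exercised without the argparse plumbing by driving SACAgent directly. The sketch below is hypothetical and not part of the commit: it assumes the repo layout (policy.agent.SACAgent), classic Gym's (obs, reward, done, info) step API, and placeholder hyperparameters and checkpoint path rather than the script's defaults.

# Hypothetical sketch, not part of the commit: one episode mirroring the SAC branch of main.py.
import gym
from policy.agent import SACAgent

env = gym.make('LunarLanderContinuous-v2')
agent = SACAgent(state_dim=env.observation_space.shape,
                 action_dim=env.action_space.shape,
                 hidden_dims=[400, 300],
                 max_action=float(env.action_space.high[0]),
                 gamma=0.99, tau=0.995, reward_scale=2., lr=1e-3,       # placeholder values
                 batch_size=256, maxsize=int(1e6), checkpoint='model_weights')

observation, done = env.reset(), False
while not done:
    action = agent.choose_action(observation, test=False)
    next_observation, reward, done, _ = env.step(action)
    agent.store_transition(observation, action, reward, next_observation, done)
    if agent.memory.idx >= agent.batch_size:          # skip updates until one batch is buffered
        value_loss, critic_loss, actor_loss = agent.update()
    observation = next_observation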

policy/networks.py

Lines changed: 102 additions & 0 deletions
@@ -1,6 +1,7 @@
 import math
 import torch
 from torch import nn
+from torch.distributions.multivariate_normal import MultivariateNormal
 
 
 class ActorCritic(nn.Module):
@@ -112,3 +113,104 @@ def save_checkpoint(self):
 
     def load_checkpoint(self):
         self.load_state_dict(torch.load(self.checkpoint_path + '/' + self.name + '.pth'))
+
+
+class SACCritic(nn.Module):
+    def __init__(self, state_dim, action_dim, hidden_dims, lr,
+                 checkpoint_path, name):
+        super().__init__()
+        self.checkpoint_path = checkpoint_path
+        self.name = name
+        encoder = []
+        prev_dim = state_dim + action_dim
+        for i, dim in enumerate(hidden_dims):
+            encoder.extend([
+                nn.Linear(prev_dim, dim),
+                nn.LayerNorm(dim)
+            ])
+            if i < len(hidden_dims) - 1:
+                encoder.append(nn.ReLU(True))
+            prev_dim = dim
+        self.encoder = nn.Sequential(*encoder)
+        self.value = nn.Linear(prev_dim, 1)
+        self.optimizer = torch.optim.Adam(self.parameters(), lr=lr)
+        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
+        self.to(self.device)
+
+    def forward(self, states, actions):
+        scores = self.encoder(torch.cat([states, actions], dim=1))
+        return self.value(scores)
+
+    def save_checkpoint(self):
+        torch.save(self.state_dict(), self.checkpoint_path + '/' + self.name + '.pth')
+
+    def load_checkpoint(self):
+        self.load_state_dict(torch.load(self.checkpoint_path + '/' + self.name + '.pth'))
+
+
+class SACValue(SACCritic):
+    def __init__(self, state_dim, hidden_dims, lr,
+                 checkpoint_path, name):
+        super().__init__(state_dim, 0, hidden_dims, lr,
+                         checkpoint_path, name)
+
+    def forward(self, states):
+        scores = self.encoder(states)
+        return self.value(scores)
+
+
+class SACActor(nn.Module):
+    def __init__(self, state_dim, action_dim,
+                 hidden_dims, lr, max_action,
+                 checkpoint_path, name):
+        super().__init__()
+        self.log_std_min = -20
+        self.log_std_max = 2
+        self.epsilon = 1e-6
+        self.checkpoint_path = checkpoint_path
+        self.name = name
+        self.max_action = max_action
+        encoder = []
+        prev_dim = state_dim
+        for i, dim in enumerate(hidden_dims):
+            encoder.extend([
+                nn.Linear(prev_dim, dim),
+                nn.LayerNorm(dim)
+            ])
+            if i < len(hidden_dims) - 1:
+                encoder.append(nn.ReLU(True))
+            prev_dim = dim
+        self.encoder = nn.Sequential(*encoder)
+
+        # mu & log_std for the action distribution
+        self.actor = nn.Linear(prev_dim, action_dim * 2)
+        self.optimizer = torch.optim.Adam(self.parameters(), lr=lr)
+        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
+        self.to(self.device)
+
+    def sample(self, mu, log_std, reparameterize=True):
+        if mu.dim() == 1:
+            mu = mu.unsqueeze(0)
+        distribution = MultivariateNormal(mu, scale_tril=torch.diag_embed(log_std.exp()))
+        if reparameterize:
+            actions = distribution.rsample()
+        else:
+            actions = distribution.sample()
+        log_probs = distribution.log_prob(actions)
+        bounded_actions = torch.tanh(actions) * self.max_action
+        bounded_log_probs = log_probs - torch.log(
+            (1 - bounded_actions.pow(2)).clamp(0, 1) + self.epsilon).sum(dim=1)
+        return bounded_actions.squeeze(), bounded_log_probs
+
+    def forward(self, states, reparameterize=True):
+        scores = self.encoder(states)
+        mu, log_std = self.actor(scores).chunk(2, dim=-1)
+        log_std = log_std.clamp(self.log_std_min, self.log_std_max)
+        action, log_prob = self.sample(mu, log_std, reparameterize=reparameterize)
+        return action, log_prob
+
+    def save_checkpoint(self):
+        torch.save(self.state_dict(), self.checkpoint_path + '/' + self.name + '.pth')
+
+    def load_checkpoint(self):
+        self.load_state_dict(torch.load(self.checkpoint_path + '/' + self.name + '.pth'))
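The core trick in SACActor is the tanh-squashed Gaussian policy: draw a reparameterized sample, bound it with tanh, and correct the log-density for the change of variables. Below is a standalone sketch of that computation with illustrative shapes and values; it is not part of the commit, and it applies the correction to the unscaled tanh(u) as in the common formulation, whereas sample() above applies it to the already-scaled, clamped action.

# Standalone sketch, not part of the commit: squashed-Gaussian sampling with the
# change-of-variables log-prob correction. Shapes and values are illustrative.
import torch
from torch.distributions.multivariate_normal import MultivariateNormal

action_dim, max_action, eps = 2, 1.0, 1e-6
mu = torch.zeros(1, action_dim)                          # batch of one state
log_std = torch.full((1, action_dim), -0.5)

dist = MultivariateNormal(mu, scale_tril=torch.diag_embed(log_std.exp()))
u = dist.rsample()                                       # reparameterized draw, keeps gradients
log_prob = dist.log_prob(u)                              # log N(u | mu, diag(std^2))
action = torch.tanh(u) * max_action                      # bounded action in [-max_action, max_action]
# tanh change of variables: log pi(a|s) = log N(u) - sum_i log(1 - tanh(u_i)^2 + eps)
log_prob = log_prob - torch.log(1 - torch.tanh(u).pow(2) + eps).sum(dim=1)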
