
Commit c00e3bc

added DDPG implementation
1 parent b627774 commit c00e3bc

File tree

4 files changed: +339 -14 lines

policy/agent.py

Lines changed: 88 additions & 4 deletions
@@ -1,6 +1,8 @@
 import numpy as np
 import torch
-from policy.networks import ActorCritic
+from copy import deepcopy
+from policy.networks import ActorCritic, Actor, Critic
+from policy.utils import ReplayBuffer, OUActionNoise
 
 
 class BlackJackAgent:
@@ -172,7 +174,6 @@ def __init__(self, input_dim, action_dim, hidden_dim, gamma, lr):
         self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
         self.log_proba, self.value = None, None
 
-
     def choose_action(self, state):
         state = torch.from_numpy(state).to(self.device)
         self.value, action_logits = self.actor_critic(state)
@@ -186,13 +187,96 @@ def update(self, reward, state_, done):
         # calculate TD loss
         state_ = torch.from_numpy(state_).unsqueeze(0).to(self.device)
         value_, _ = self.actor_critic(state_)
-        critic_loss = (reward + self.gamma * value_ * ~done - self.value).pow(2)
+        TD_error = reward + self.gamma * value_ * ~done - self.value
+        critic_loss = TD_error.pow(2)
 
         # actor loss
-        actor_loss = - self.value.detach() * self.log_proba
+        actor_loss = - self.value * self.log_proba
 
         # sgd + reset history
         loss = critic_loss + actor_loss
         self.optimizer.zero_grad()
         loss.backward()
         self.optimizer.step()
+
+
+class DDPGAgent:
+    def __init__(self, state_dim, action_dim, hidden_dims, max_action, gamma,
+                 tau, critic_lr, critic_wd, actor_lr, actor_wd, batch_size,
+                 final_init, maxsize, sigma, theta, dt, checkpoint):
+        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
+        self.gamma = gamma
+        self.tau = tau
+        self.batch_size = batch_size
+        self.memory = ReplayBuffer(state_dim, action_dim, maxsize)
+        self.noise = OUActionNoise(torch.zeros(action_dim, device=self.device),
+                                   sigma=sigma,
+                                   theta=theta,
+                                   dt=dt)
+        self.critic = Critic(*state_dim, *action_dim, hidden_dims, critic_lr, critic_wd,
+                             final_init, checkpoint, 'Critic')
+        self.actor = Actor(*state_dim, *action_dim, hidden_dims, max_action,
+                           actor_lr, actor_wd, final_init, checkpoint, 'Actor')
+        self.target_critic = deepcopy(self.critic)
+        self.target_critic.name = 'Target_Critic'
+        self.target_actor = deepcopy(self.actor)
+        self.target_actor.name = 'Target_Actor'
+
+    def update(self):
+        experiences = self.memory.sample_transition(self.batch_size)
+        states, actions, rewards, next_states, dones = [data.to(self.device) for data in experiences]
+        # calculate targets & only update online critic network
+        with torch.no_grad():
+            next_actions = self.target_actor(next_states)
+            q_primes = self.target_critic(next_states, next_actions)
+            targets = rewards + self.gamma * q_primes * (~dones)
+        qs = self.critic(states, actions)
+        td_error = targets - qs
+        critic_loss = td_error.pow(2).mean()
+        self.critic.optimizer.zero_grad()
+        critic_loss.backward()
+        self.critic.optimizer.step()
+
+        # actor loss is by maximizing Q values
+        qs = self.critic(states, self.actor(states))
+        actor_loss = - qs.mean()
+        self.actor.optimizer.zero_grad()
+        actor_loss.backward()
+        self.actor.optimizer.step()
+
+        self.update_target_network(self.critic, self.target_critic)
+        self.update_target_network(self.actor, self.target_actor)
+        return actor_loss.item(), critic_loss.item()
+
+    def update_target_network(self, src, tgt):
+        for src_weight, tgt_weight in zip(src.parameters(), tgt.parameters()):
+            tgt_weight.data = src_weight.data * self.tau + tgt_weight.data * (1. - self.tau)
+
+    def save_models(self):
+        self.critic.save_checkpoint()
+        self.actor.save_checkpoint()
+        self.target_critic.save_checkpoint()
+        self.target_actor.save_checkpoint()
+
+    def load_models(self):
+        self.critic.load_checkpoint()
+        self.actor.load_checkpoint()
+        self.target_critic.load_checkpoint()
+        self.target_actor.load_checkpoint()
+
+    def choose_action(self, observation):
+        self.actor.eval()
+        observation = torch.from_numpy(observation).to(self.device)
+        with torch.no_grad():
+            mu = self.actor(observation)
+            action = mu + self.noise()
+        self.actor.train()
+        return action.cpu().detach().numpy()
+
+    def store_transition(self, state, action, reward, next_state, done):
+        state = torch.tensor(state)
+        action = torch.tensor(action)
+        reward = torch.tensor(reward)
+        next_state = torch.tensor(next_state)
+        done = torch.tensor(done, dtype=torch.bool)
+        self.memory.store_transition(state, action, reward, next_state, done)
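
The DDPGAgent above depends on ReplayBuffer and OUActionNoise from policy.utils, which this view of the commit does not show. Below is a minimal sketch of what those two helpers could look like, written only against the calls made in agent.py and main.py (sample_transition, store_transition, the idx counter, and a callable noise object with reset); the bodies are assumptions, not the contents of the actual policy/utils.py.

import torch


class ReplayBuffer:
    """Hypothetical fixed-size, uniformly sampled experience buffer."""
    def __init__(self, state_dim, action_dim, maxsize):
        self.maxsize = int(maxsize)
        self.idx = 0  # total number of transitions stored so far
        self.states = torch.zeros(self.maxsize, *state_dim)
        self.actions = torch.zeros(self.maxsize, *action_dim)
        self.rewards = torch.zeros(self.maxsize)
        self.next_states = torch.zeros(self.maxsize, *state_dim)
        self.dones = torch.zeros(self.maxsize, dtype=torch.bool)

    def store_transition(self, state, action, reward, next_state, done):
        i = self.idx % self.maxsize  # overwrite the oldest entry once full
        self.states[i], self.actions[i] = state, action
        self.rewards[i], self.dones[i] = reward, done
        self.next_states[i] = next_state
        self.idx += 1

    def sample_transition(self, batch_size):
        high = min(self.idx, self.maxsize)
        batch = torch.randint(0, high, (batch_size,))
        return (self.states[batch], self.actions[batch], self.rewards[batch].unsqueeze(-1),
                self.next_states[batch], self.dones[batch].unsqueeze(-1))


class OUActionNoise:
    """Hypothetical Ornstein-Uhlenbeck process for temporally correlated exploration noise."""
    def __init__(self, mu, sigma=0.2, theta=0.15, dt=1e-2):
        self.mu, self.sigma, self.theta, self.dt = mu, sigma, theta, dt
        self.reset()

    def reset(self):
        self.x_prev = torch.zeros_like(self.mu)

    def __call__(self):
        # x_{t+1} = x_t + theta * (mu - x_t) * dt + sigma * sqrt(dt) * N(0, 1)
        x = (self.x_prev + self.theta * (self.mu - self.x_prev) * self.dt
             + self.sigma * (self.dt ** 0.5) * torch.randn_like(self.mu))
        self.x_prev = x
        return x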

policy/lunarlander/main.py

Lines changed: 99 additions & 10 deletions
@@ -4,47 +4,136 @@
 import numpy as np
 from tqdm import tqdm
 from collections import deque
+from pathlib import Path
+from torch.utils.tensorboard import SummaryWriter
+
+from policy.utils import clip_action
 from policy import agent as Agent
 
 
 parser = argparse.ArgumentParser(description='Lunar Lander Agents')
-parser.add_argument('--agent', type=str, default='Actor Critic', help='Agent style')
+# training hyperparams
+parser.add_argument('--agent', type=str, default='DDPG', help='Agent style')
 parser.add_argument('--n_episodes', type=int, default=3000, help='Number of episodes you wish to run for')
+parser.add_argument('--batch_size', type=int, default=64, help='Minibatch size')
 parser.add_argument('--hidden_dim', type=int, default=2048, help='Hidden dimension of FC layers')
-parser.add_argument('--lr', '--learning_rate', type=float, default=1e-4, help='Learning rate for Adam optimizer')
+parser.add_argument('--hidden_dims', type=list, default=[400, 300], help='Hidden dimensions of FC layers')
+parser.add_argument('--critic_lr', type=float, default=1e-3, help='Learning rate for Critic')
+parser.add_argument('--critic_wd', type=float, default=1e-2, help='Weight decay for Critic')
+parser.add_argument('--actor_lr', type=float, default=1e-4, help='Learning rate for Actor')
+parser.add_argument('--actor_wd', type=float, default=0., help='Weight decay for Actor')
 parser.add_argument('--gamma', type=float, default=0.99, help='Reward discount factor')
+parser.add_argument('--final_init', type=float, default=3e-3, help='The range for output layer initialization')
+parser.add_argument('--tau', type=float, default=0.001, help='Weight of target network update')
+parser.add_argument('--maxsize', type=int, default=1e6, help='Size of Replay Buffer')
+parser.add_argument('--sigma', type=float, default=0.2, help='Sigma for OU noise')
+parser.add_argument('--theta', type=float, default=0.15, help='Theta for OU noise')
+parser.add_argument('--dt', type=float, default=1e-2, help='dt for OU noise')
 
+# eval params
 parser.add_argument('--render', action="store_true", default=False, help='Render environment while training')
 parser.add_argument('--window_legnth', type=int, default=100, help='Length of window to keep track scores')
+
+# checkpoint + logs
+parser.add_argument('--checkpoint', type=str, default='policy/lunarlander/checkpoint', help='Checkpoint for model weights')
+parser.add_argument('--logdir', type=str, default='policy/lunarlander/logs', help='Directory to save logs')
 args = parser.parse_args()
 
 
 def main():
-    env = gym.make('LunarLander-v2')
+    env_type = 'Continuous' if args.agent in ['DDPG'] else ''
+    env = gym.make(f'LunarLander{env_type}-v2')
     agent_ = getattr(Agent, args.agent.replace(' ', '') + 'Agent')
-    agent = agent_(input_dim=env.observation_space.shape,
-                   action_dim=env.action_space.n,
-                   hidden_dim=args.hidden_dim,
-                   gamma=args.gamma,
-                   lr=args.lr)
+    if args.agent in ['DDPG']:
+        max_action = float(env.action_space.high[0])
+        agent = agent_(state_dim=env.observation_space.shape,
+                       action_dim=env.action_space.shape,
+                       hidden_dims=args.hidden_dims,
+                       max_action=max_action,
+                       gamma=args.gamma,
+                       tau=args.tau,
+                       critic_lr=args.critic_lr,
+                       critic_wd=args.critic_wd,
+                       actor_lr=args.actor_lr,
+                       actor_wd=args.actor_wd,
+                       batch_size=args.batch_size,
+                       final_init=args.final_init,
+                       maxsize=int(args.maxsize),
+                       sigma=args.sigma,
+                       theta=args.theta,
+                       dt=args.dt,
+                       checkpoint=args.checkpoint)
+    else:
+        agent = agent_(state_dim=env.observation_space.shape,
+                       action_dim=env.action_space.n,
+                       hidden_dims=args.hidden_dims,
+                       gamma=args.gamma,
+                       lr=args.lr)
+
+    Path(args.logdir).mkdir(parents=True, exist_ok=True)
+    Path(args.checkpoint).mkdir(parents=True, exist_ok=True)
+
+    writer = SummaryWriter(args.logdir)
+
     pbar = tqdm(range(args.n_episodes))
     score_history = deque(maxlen=args.window_legnth)
+    best_score = env.reward_range[0]
     for e in pbar:
         done, score, observation = False, 0, env.reset()
+
+        # reset DDPG OU noise and also keep track of actor/critic losses
+        if args.agent in ['DDPG']:
+            agent.noise.reset()
+            actor_losses, critic_losses = [], []
         while not done:
             if args.render:
                 env.render()
+
             action = agent.choose_action(observation)
+            # clip noised action to ensure not out of bounds
+            if args.agent in ['DDPG']:
+                action = clip_action(action, max_action)
             next_observation, reward, done, _ = env.step(action)
+            score += reward
+
+            # update for td methods, recording for mc methods
             if args.agent == 'Actor Critic':
                 agent.update(reward, next_observation, done)
+            elif args.agent in ['DDPG']:
+                agent.store_transition(observation, action, reward, next_observation, done)
+                # only update once the memory holds at least one full batch
+                if agent.memory.idx >= args.batch_size:
+                    actor_loss, critic_loss = agent.update()
+                    actor_losses.append(actor_loss)
+                    critic_losses.append(critic_loss)
+                    pbar.set_postfix({'Reward': reward,
+                                      'Actor Loss': actor_loss, 'Critic Loss': critic_loss})
             else:
                 agent.store_reward(reward)
             observation = next_observation
-            score += reward
+
+        score_history.append(score)
+
+        # update for mc methods w/ full trajectory
         if args.agent == 'Policy Gradient':
             agent.update()
-        score_history.append(score)
+
+        # logging & saving
+        elif args.agent in ['DDPG']:
+            writer.add_scalars(
+                'Scores',
+                {'Episodic': score, 'Windowed Average': np.mean(score_history)},
+                global_step=e)
+            if actor_losses:
+                writer.add_scalars(
+                    'Losses',
+                    {'Actor': np.mean(actor_losses), 'Critic': np.mean(critic_losses)},
+                    global_step=e)
+                actor_losses, critic_losses = [], []
+
+            if score > best_score:
+                best_score = score
+                agent.save_models()
         tqdm.write(
             f'Episode: {e + 1}/{args.n_episodes}, Score: {score}, Average Score: {np.mean(score_history)}')
 
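clip_action is likewise imported from policy.utils and is not shown in this commit. A plausible one-line sketch, assuming it only clamps the noised action to the environment's action bounds (the name and call signature come from the code above; the body is an assumption):

import numpy as np


def clip_action(action, max_action):
    # hypothetical helper: keep every action component inside [-max_action, max_action]
    return np.clip(action, -max_action, max_action)
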
policy/networks.py

Lines changed: 95 additions & 0 deletions
@@ -1,3 +1,4 @@
+import math
 import torch
 from torch import nn
 
@@ -17,3 +18,97 @@ def __init__(self, input_dim, n_actions, hidden_dim):
     def forward(self, state):
         features = self.encoder(state)
         return self.v(features), self.pi(features)
+
+
+class Critic(nn.Module):
+    def __init__(self, input_dim, action_dim, hidden_dims, lr, weight_decay,
+                 final_init, checkpoint_path, name):
+        super().__init__()
+        self.checkpoint_path = checkpoint_path
+        self.name = name
+        encoder = []
+        prev_dim = input_dim
+        for i, dim in enumerate(hidden_dims):
+            encoder.extend([
+                nn.Linear(prev_dim, dim),
+                nn.LayerNorm(dim)
+            ])
+            if i < len(hidden_dims) - 1:
+                encoder.append(nn.ReLU(True))
+            prev_dim = dim
+        self.state_encoder = nn.Sequential(*encoder)
+        self.action_encoder = nn.Sequential(nn.Linear(action_dim, prev_dim),
+                                            nn.LayerNorm(prev_dim))
+        self.q = nn.Linear(prev_dim, 1)
+        self.optimizer = torch.optim.Adam(self.parameters(), lr=lr, weight_decay=weight_decay)
+        self._init_weights(self.q, final_init)
+        self._init_weights(self.action_encoder, 1 / math.sqrt(action_dim))
+        self._init_weights(self.state_encoder, 1 / math.sqrt(hidden_dims[-2]))
+        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
+        self.to(self.device)
+
+    def _init_weights(self, layers, b):
+        for m in layers.modules():
+            if isinstance(m, (nn.Linear, nn.LayerNorm)):
+                nn.init.uniform_(
+                    m.weight,
+                    a=-b,
+                    b=b
+                )
+
+    def forward(self, states, actions):
+        state_values = self.state_encoder(states)
+        action_values = self.action_encoder(actions)
+        state_action_values = nn.functional.relu(torch.add(state_values, action_values))
+        return self.q(state_action_values)
+
+    def save_checkpoint(self):
+        torch.save(self.state_dict(), self.checkpoint_path + '/' + self.name + '.pth')
+
+    def load_checkpoint(self):
+        self.load_state_dict(torch.load(self.checkpoint_path + '/' + self.name + '.pth'))
+
+
+class Actor(nn.Module):
+    def __init__(self, input_dim, action_dim, hidden_dims,
+                 max_action, lr, weight_decay,
+                 final_init, checkpoint_path, name):
+        super().__init__()
+        self.max_action = max_action
+        self.name = name
+        self.checkpoint_path = checkpoint_path
+        encoder = []
+        prev_dim = input_dim
+        for dim in hidden_dims:
+            encoder.extend([
+                nn.Linear(prev_dim, dim),
+                nn.LayerNorm(dim),
+                nn.ReLU(True)])
+            prev_dim = dim
+        self.state_encoder = nn.Sequential(*encoder)
+        self.mu = nn.Linear(prev_dim, action_dim)
+        self.optimizer = torch.optim.Adam(self.parameters(), lr=lr, weight_decay=weight_decay)
+        self._init_weights(self.mu, final_init)
+        self._init_weights(self.state_encoder, 1 / math.sqrt(hidden_dims[-2]))
+        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
+        self.to(self.device)
+
+    def _init_weights(self, layers, b):
+        for m in layers.modules():
+            if isinstance(m, (nn.Linear, nn.LayerNorm)):
+                nn.init.uniform_(
+                    m.weight,
+                    a=-b,
+                    b=b
+                )
+
+    def forward(self, states):
+        state_features = self.state_encoder(states)
+        # bound the output action to [-max_action, max_action]
+        return torch.tanh(self.mu(state_features)) * self.max_action
+
+    def save_checkpoint(self):
+        torch.save(self.state_dict(), self.checkpoint_path + '/' + self.name + '.pth')
+
+    def load_checkpoint(self):
+        self.load_state_dict(torch.load(self.checkpoint_path + '/' + self.name + '.pth'))
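
As a quick sanity check of the new networks, the snippet below builds an Actor and a Critic with LunarLanderContinuous-v2 dimensions (8-dimensional observations, 2-dimensional actions) and the default hyperparameters from main.py; the checkpoint path, batch size, and names here are illustrative placeholders, not part of the commit:

import torch
from policy.networks import Actor, Critic

state_dim, action_dim = 8, 2  # LunarLanderContinuous-v2 observation/action sizes
actor = Actor(state_dim, action_dim, [400, 300], max_action=1.0, lr=1e-4,
              weight_decay=0., final_init=3e-3, checkpoint_path='/tmp', name='Actor')
critic = Critic(state_dim, action_dim, [400, 300], lr=1e-3, weight_decay=1e-2,
                final_init=3e-3, checkpoint_path='/tmp', name='Critic')

states = torch.randn(64, state_dim, device=actor.device)
actions = actor(states)             # shape (64, 2), squashed into [-max_action, max_action]
q_values = critic(states, actions)  # shape (64, 1)
print(actions.shape, q_values.shape)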

0 commit comments
