Commit

added td3
Andrewzh112 committed Feb 22, 2021
1 parent 2bb03c9 commit 05b489f
Showing 3 changed files with 154 additions and 23 deletions.
93 changes: 88 additions & 5 deletions policy/agent.py
@@ -4,7 +4,7 @@
from copy import deepcopy

from policy.networks import ActorCritic, Actor, Critic
from policy.utils import ReplayBuffer, OUActionNoise, clip_action
from policy.utils import ReplayBuffer, OUActionNoise, clip_action, GaussianActionNoise


class BlackJackAgent:
@@ -211,6 +211,15 @@ def __init__(self, state_dim, action_dim, hidden_dims, max_action, gamma,
self.tau = tau
self.batch_size = batch_size
self.max_action = max_action
self.state_dim = state_dim
self.action_dim = action_dim
self.hidden_dims = hidden_dims
self.critic_lr = critic_lr
self.critic_wd = critic_wd
self.final_init = final_init
self.checkpoint = checkpoint
self.sigma = sigma

self.memory = ReplayBuffer(state_dim, action_dim, maxsize)
self.noise = OUActionNoise(torch.zeros(action_dim, device=self.device),
sigma=sigma,
@@ -272,15 +281,18 @@ def save_models(self):
def load_models(self):
self.critic.load_checkpoint()
self.actor.load_checkpoint()
self.target_critic.save_checkpoint()
self.target_actor.save_checkpoint()
self.target_critic.load_checkpoint()
self.target_actor.load_checkpoint()

def choose_action(self, observation):
def choose_action(self, observation, test):
self.actor.eval()
observation = torch.from_numpy(observation).to(self.device)
with torch.no_grad():
mu = self.actor(observation)
action = mu + self.noise()
if test:
action = mu
else:
action = mu + self.noise()
self.actor.train()
action = action.cpu().detach().numpy()
# clip noised action to ensure not out of bounds
@@ -293,3 +305,74 @@ def store_transition(self, state, action, reward, next_state, done):
next_state = torch.tensor(next_state)
done = torch.tensor(done, dtype=torch.bool)
self.memory.store_transition(state, action, reward, next_state, done)


class TD3Agent(DDPGAgent):
def __init__(self, *args, **kwargs):
        excluded_kwargs = ['actor_update_iter', 'action_sigma', 'action_clip']
        super().__init__(*args, **{k: v for k, v in kwargs.items() if k not in excluded_kwargs})
self.ctr = 0
self.actor_update_iter = kwargs['actor_update_iter']
self.action_sigma = kwargs['action_sigma']
self.action_clip = kwargs['action_clip']
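        # exploration noise: TD3 replaces DDPG's correlated OU process with plain Gaussian noise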
self.noise = GaussianActionNoise(mu=0, sigma=self.sigma)
self.actor_loss = 0

        # TD3's twin critics: a second online critic plus its own target network
self.critic2 = Critic(*self.state_dim, *self.action_dim, self.hidden_dims,
self.critic_lr, self.critic_wd,
self.final_init, self.checkpoint, 'Critic2')
self.target_critic2 = self.get_target_network(self.critic2)
self.target_critic2.name = 'Target_Critic2'

def choose_action(self, observation, test):
self.actor.eval()
self.ctr += 1
observation = torch.from_numpy(observation).to(self.device)
with torch.no_grad():
action = self.actor(observation)
if not test:
action = action + self.noise(action.size())
self.actor.train()
action = action.cpu().detach().numpy()
# clip noised action to ensure not out of bounds
return clip_action(action, self.max_action)

def update(self):
experiences = self.memory.sample_transition(self.batch_size)
states, actions, rewards, next_states, dones = [data.to(self.device) for data in experiences]

        # delayed policy update: maximize Q w.r.t. the actor, but only every actor_update_iter steps
if self.ctr % self.actor_update_iter == 0:
self.actor.optimizer.zero_grad()
qs = self.critic(states, self.actor(states))
actor_loss = - qs.mean()
actor_loss.backward()
self.actor.optimizer.step()
self.actor_loss = actor_loss.item()

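            # Polyak-average the target actor and both target critics toward their online networks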
self.update_target_network(self.critic, self.target_critic)
self.update_target_network(self.critic2, self.target_critic2)
self.update_target_network(self.actor, self.target_actor)

        # compute TD targets and update both online critic networks
        self.critic.optimizer.zero_grad()
        self.critic2.optimizer.zero_grad()
with torch.no_grad():
            # y <- r + gamma * min_(i=1,2) Q_(theta'_i)(s', a_tilde)
target_actions = self.target_actor(next_states)
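            # target policy smoothing: perturb the target action with clipped Gaussian noise before evaluating it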
target_actions += self.noise(
target_actions.size(), clip=self.action_clip, sigma=self.action_sigma)
target_actions = clip_action(target_actions, self.max_action)
q_primes1 = self.target_critic(next_states, target_actions).squeeze()
q_primes2 = self.target_critic2(next_states, target_actions).squeeze()
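            # clipped double-Q learning: bootstrap from the smaller of the two target-critic estimates to curb overestimation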
q_primes = torch.min(q_primes1, q_primes2)
targets = rewards + self.gamma * q_primes * (~dones)
# theta_i <- argmin_(theta_i) N^(-1) sum(y - Q_(theta_i)(s, a))^2
qs1 = self.critic(states, actions)
qs2 = self.critic2(states, actions)
critic_loss1 = F.mse_loss(targets.unsqueeze(-1), qs1)
critic_loss2 = F.mse_loss(targets.unsqueeze(-1), qs2)
critic_loss = critic_loss1 + critic_loss2
critic_loss.backward()
        self.critic.optimizer.step()
        self.critic2.optimizer.step()
return self.actor_loss, critic_loss.item()
69 changes: 51 additions & 18 deletions policy/lunarlander/main.py → policy/continuous/main.py
@@ -12,22 +12,25 @@

parser = argparse.ArgumentParser(description='Lunar Lander Agents')
# training hyperparams
parser.add_argument('--agent', type=str, default='DDPG', help='Agent style')
parser.add_argument('--agent', type=str, default='TD3', help='Agent Algorithm')
parser.add_argument('--n_episodes', type=int, default=3000, help='Number of episodes you wish to run for')
parser.add_argument('--batch_size', type=int, default=64, help='Minibatch size')
parser.add_argument('--batch_size', type=int, default=100, help='Minibatch size')
parser.add_argument('--hidden_dim', type=int, default=2048, help='Hidden dimension of FC layers')
parser.add_argument('--hidden_dims', type=list, default=[400, 300], help='Hidden dimensions of FC layers')
parser.add_argument('--critic_lr', type=float, default=1e-3, help='Learning rate for Critic')
parser.add_argument('--critic_wd', type=float, default=1e-2, help='Weight decay for Critic')
parser.add_argument('--critic_wd', type=float, default=0., help='Weight decay for Critic')
parser.add_argument('--actor_lr', type=float, default=1e-4, help='Learning rate for Actor')
parser.add_argument('--actor_wd', type=float, default=0., help='Weight decay for Actor')
parser.add_argument('--gamma', type=float, default=0.99, help='Reward discount factor')
parser.add_argument('--final_init', type=float, default=3e-3, help='The range for output layer initialization')
parser.add_argument('--tau', type=float, default=0.001, help='Weight of target network update')
parser.add_argument('--tau', type=float, default=0.005, help='Weight of target network update')
parser.add_argument('--maxsize', type=int, default=1e6, help='Size of Replay Buffer')
parser.add_argument('--sigma', type=float, default=0.2, help='Sigma for UOnoise')
parser.add_argument('--sigma', type=float, default=0.1, help='Sigma for exploration noise')
parser.add_argument('--theta', type=float, default=0.15, help='Theta for OU noise')
parser.add_argument('--dt', type=float, default=1e-2, help='dt for OU noise')
parser.add_argument('--actor_update_iter', type=int, default=2, help='Update actor and target networks every N critic updates')
parser.add_argument('--action_sigma', type=float, default=0.2, help='Std of target policy smoothing noise')
parser.add_argument('--action_clip', type=float, default=0.5, help='Clip bound for target policy smoothing noise')

# eval params
parser.add_argument('--render', action="store_true", default=False, help='Render environment while training')
@@ -36,19 +39,20 @@
parser.add_argument('--load_models', action="store_true", default=False, help='Load pretrained models')

# checkpoint + logs
parser.add_argument('--checkpoint', type=str, default='policy/lunarlander/checkpoint', help='Checkpoint for model weights')
parser.add_argument('--logdir', type=str, default='policy/lunarlander/logs', help='Directory to save logs')
parser.add_argument('--checkpoint', type=str, default='policy/continuous/checkpoint', help='Checkpoint for model weights')
parser.add_argument('--logdir', type=str, default='policy/continuous/logs', help='Directory to save logs')
args = parser.parse_args()


def main():
env_type = 'Continuous' if args.agent in ['DDPG'] else ''
env_type = 'Continuous' if args.agent in ['DDPG', 'TD3'] else ''
env = gym.make(f'LunarLander{env_type}-v2')
agent_ = getattr(Agent, args.agent.replace(' ', '') + 'Agent')
if args.test:
args.load_models = True
args.render = True
if args.agent in ['DDPG']:
print(args)
if args.agent == 'DDPG':
max_action = float(env.action_space.high[0])
agent = agent_(state_dim=env.observation_space.shape,
action_dim=env.action_space.shape,
@@ -67,6 +71,29 @@ def main():
theta=args.theta,
dt=args.dt,
checkpoint=args.checkpoint)
elif args.agent == 'TD3':
max_action = float(env.action_space.high[0])
agent = agent_(state_dim=env.observation_space.shape,
action_dim=env.action_space.shape,
hidden_dims=args.hidden_dims,
max_action=max_action,
gamma=args.gamma,
tau=args.tau,
critic_lr=args.critic_lr,
critic_wd=args.critic_wd,
actor_lr=args.actor_lr,
actor_wd=args.actor_wd,
batch_size=args.batch_size,
final_init=args.final_init,
maxsize=int(args.maxsize),
sigma=args.sigma,
theta=args.theta,
dt=args.dt,
checkpoint=args.checkpoint,
actor_update_iter=args.actor_update_iter,
action_sigma=args.action_sigma,
action_clip=args.action_clip
)
else:
agent = agent_(state_dim=env.observation_space.shape,
                       action_dim=env.action_space.n,
@@ -83,29 +110,32 @@ def main():
agent.load_models()
pbar = tqdm(range(args.n_episodes))
score_history = deque(maxlen=args.window_legnth)
best_score = env.reward_range[0]
best_score = - np.inf
for e in pbar:
done, score, observation = False, 0, env.reset()

        # reset DDPG's OU noise and keep track of actor/critic losses
if args.agent in ['DDPG']:
agent.noise.reset()
if args.agent in ['DDPG', 'TD3']:
if args.agent == 'DDPG':
agent.noise.reset()
actor_losses, critic_losses = [], []
while not done:
if args.render:
env.render()

action = agent.choose_action(observation)
action = agent.choose_action(observation, args.test)
next_observation, reward, done, _ = env.step(action)
score += reward

# update for td methods, recording for mc methods
if args.agent == 'Actor Critic':
if args.test:
continue
elif args.agent == 'Actor Critic':
agent.update(reward, next_observation, done)
elif args.agent in ['DDPG']:
elif args.agent in ['DDPG', 'TD3']:
agent.store_transition(observation, action, reward, next_observation, done)
                # skip updates until the buffer holds a full batch (and, for TD3, until the 10k-step warm-up has elapsed)
if agent.memory.idx < args.batch_size:
if agent.memory.idx < args.batch_size or (args.agent == 'TD3' and agent.ctr < 10000):
continue
actor_loss, critic_loss = agent.update()
actor_losses.append(actor_loss)
@@ -117,12 +147,14 @@ def main():

score_history.append(score)

if args.test:
continue
# update for mc methods w/ full trajectory
if args.agent == 'Policy Gradient':
elif args.agent == 'Policy Gradient':
agent.update()

# logging & saving
elif args.agent in ['DDPG']:
elif args.agent in ['DDPG', 'TD3']:
writer.add_scalars(
'Scores',
{'Episodic': score, 'Windowed Average': np.mean(score_history)},
@@ -137,6 +169,7 @@ def main():
if np.mean(score_history) > best_score:
best_score = np.mean(score_history)
agent.save_models()

tqdm.write(
f'Episode: {e + 1}/{args.n_episodes}, Score: {score}, Average Score: {np.mean(score_history)}')

15 changes: 15 additions & 0 deletions policy/utils.py
@@ -21,6 +21,21 @@ def reset(self):
self.x_prev = self.x0 if self.x0 is not None else torch.zeros_like(self.mu)


class GaussianActionNoise:
def __init__(self, mu, sigma=0.2):
self.mu = mu
self.sigma = sigma
self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

def __call__(self, output_dim, clip=None, sigma=None):
if sigma is None:
sigma = self.sigma
        # sample uncorrelated Gaussian noise on the same device as the action tensors
        noise = torch.randn(*output_dim, device=self.device) * sigma + self.mu
        if clip is not None:
            # Tensor.clip is not in-place, so keep the clipped result
            noise = noise.clip(-clip, clip)
return noise


class ReplayBuffer:
def __init__(self, state_dim, action_dim, maxsize):
self.states = torch.empty(maxsize, *state_dim)
