From 05b489f624c9102e97f481f1cabf2f28a6ed5785 Mon Sep 17 00:00:00 2001
From: Andrew Zhao
Date: Mon, 22 Feb 2021 19:10:38 +0800
Subject: [PATCH] added td3

---
 policy/agent.py                            | 93 ++++++++++++++++++++--
 policy/{lunarlander => continuous}/main.py | 69 +++++++++++-----
 policy/utils.py                            | 15 ++++
 3 files changed, 154 insertions(+), 23 deletions(-)
 rename policy/{lunarlander => continuous}/main.py (66%)

diff --git a/policy/agent.py b/policy/agent.py
index e26a312..55f45ad 100644
--- a/policy/agent.py
+++ b/policy/agent.py
@@ -4,7 +4,7 @@
 from copy import deepcopy
 
 from policy.networks import ActorCritic, Actor, Critic
-from policy.utils import ReplayBuffer, OUActionNoise, clip_action
+from policy.utils import ReplayBuffer, OUActionNoise, clip_action, GaussianActionNoise
 
 
 class BlackJackAgent:
@@ -211,6 +211,15 @@ def __init__(self, state_dim, action_dim, hidden_dims, max_action, gamma,
         self.tau = tau
         self.batch_size = batch_size
         self.max_action = max_action
+        self.state_dim = state_dim
+        self.action_dim = action_dim
+        self.hidden_dims = hidden_dims
+        self.critic_lr = critic_lr
+        self.critic_wd = critic_wd
+        self.final_init = final_init
+        self.checkpoint = checkpoint
+        self.sigma = sigma
+
         self.memory = ReplayBuffer(state_dim, action_dim, maxsize)
         self.noise = OUActionNoise(torch.zeros(action_dim, device=self.device),
                                    sigma=sigma,
@@ -272,15 +281,18 @@ def save_models(self):
     def load_models(self):
         self.critic.load_checkpoint()
         self.actor.load_checkpoint()
-        self.target_critic.save_checkpoint()
-        self.target_actor.save_checkpoint()
+        self.target_critic.load_checkpoint()
+        self.target_actor.load_checkpoint()
 
-    def choose_action(self, observation):
+    def choose_action(self, observation, test):
         self.actor.eval()
         observation = torch.from_numpy(observation).to(self.device)
         with torch.no_grad():
             mu = self.actor(observation)
-            action = mu + self.noise()
+            if test:
+                action = mu
+            else:
+                action = mu + self.noise()
         self.actor.train()
         action = action.cpu().detach().numpy()
         # clip noised action to ensure not out of bounds
@@ -293,3 +305,74 @@ def store_transition(self, state, action, reward, next_state, done):
         next_state = torch.tensor(next_state)
         done = torch.tensor(done, dtype=torch.bool)
         self.memory.store_transition(state, action, reward, next_state, done)
+
+
+class TD3Agent(DDPGAgent):
+    def __init__(self, *args, **kwargs):
+        excluded_kwargs = ['actor_update_iter', 'action_sigma', 'action_clip']
+        super().__init__(*args, **{k: v for k, v in kwargs.items() if k not in excluded_kwargs})
+        self.ctr = 0
+        self.actor_update_iter = kwargs['actor_update_iter']
+        self.action_sigma = kwargs['action_sigma']
+        self.action_clip = kwargs['action_clip']
+        self.noise = GaussianActionNoise(mu=0, sigma=self.sigma)
+        self.actor_loss = 0
+
+        # second critic network (TD3 uses a twin pair of critics)
+        self.critic2 = Critic(*self.state_dim, *self.action_dim, self.hidden_dims,
+                              self.critic_lr, self.critic_wd,
+                              self.final_init, self.checkpoint, 'Critic2')
+        self.target_critic2 = self.get_target_network(self.critic2)
+        self.target_critic2.name = 'Target_Critic2'
+
+    def choose_action(self, observation, test):
+        self.actor.eval()
+        self.ctr += 1
+        observation = torch.from_numpy(observation).to(self.device)
+        with torch.no_grad():
+            action = self.actor(observation)
+            if not test:
+                action = action + self.noise(action.size())
+        self.actor.train()
+        action = action.cpu().detach().numpy()
+        # clip noised action to ensure not out of bounds
+        return clip_action(action, self.max_action)
+
+    def update(self):
+        experiences = self.memory.sample_transition(self.batch_size)
+        states, actions, rewards, next_states, dones = [data.to(self.device) for data in experiences]
+
+        # delayed policy update: actor loss maximizes the first critic's Q values
+        if self.ctr % self.actor_update_iter == 0:
+            self.actor.optimizer.zero_grad()
+            qs = self.critic(states, self.actor(states))
+            actor_loss = - qs.mean()
+            actor_loss.backward()
+            self.actor.optimizer.step()
+            self.actor_loss = actor_loss.item()
+
+            self.update_target_network(self.critic, self.target_critic)
+            self.update_target_network(self.critic2, self.target_critic2)
+            self.update_target_network(self.actor, self.target_actor)
+
+        # calculate targets & only update online critic network
+        self.critic.optimizer.zero_grad()
+        with torch.no_grad():
+            # y <- r + gamma * min_(i=1,2) Q_(theta'_i)(s', a_tilde)
+            target_actions = self.target_actor(next_states)
+            target_actions += self.noise(
+                target_actions.size(), clip=self.action_clip, sigma=self.action_sigma)
+            target_actions = clip_action(target_actions, self.max_action)
+            q_primes1 = self.target_critic(next_states, target_actions).squeeze()
+            q_primes2 = self.target_critic2(next_states, target_actions).squeeze()
+            q_primes = torch.min(q_primes1, q_primes2)
+            targets = rewards + self.gamma * q_primes * (~dones)
+        # theta_i <- argmin_(theta_i) N^(-1) sum(y - Q_(theta_i)(s, a))^2
+        qs1 = self.critic(states, actions)
+        qs2 = self.critic2(states, actions)
+        critic_loss1 = F.mse_loss(targets.unsqueeze(-1), qs1)
+        critic_loss2 = F.mse_loss(targets.unsqueeze(-1), qs2)
+        critic_loss = critic_loss1 + critic_loss2
+        critic_loss.backward()
+        self.critic.optimizer.step()
+        return self.actor_loss, critic_loss.item()
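
Note on the update step above: critic_loss.backward() produces gradients for both critics, but only self.critic.optimizer.step() is called, so the second critic is only trained if its parameters are reachable from that optimizer. The standalone sketch below shows the standard TD3 update, in which each critic keeps its own optimizer and both are explicitly zeroed and stepped; the names used here (actor, critic1, critic2, their target copies, policy_delay, noise_clip) are illustrative assumptions, not identifiers from this repository.

import torch
import torch.nn.functional as F


def td3_update(step, batch, actor, target_actor,
               critic1, critic2, target_critic1, target_critic2,
               actor_opt, critic1_opt, critic2_opt,
               gamma=0.99, tau=0.005, policy_noise=0.2, noise_clip=0.5,
               max_action=1.0, policy_delay=2):
    """One TD3 gradient step on a sampled batch (illustrative sketch)."""
    states, actions, rewards, next_states, dones = batch

    with torch.no_grad():
        # target policy smoothing: a~ = clip(mu'(s') + clip(eps, -c, c), -max, max)
        eps = (torch.randn_like(actions) * policy_noise).clamp(-noise_clip, noise_clip)
        next_actions = (target_actor(next_states) + eps).clamp(-max_action, max_action)
        # clipped double-Q target: y = r + gamma * (1 - done) * min_i Q'_i(s', a~)
        q_next = torch.min(target_critic1(next_states, next_actions),
                           target_critic2(next_states, next_actions))
        targets = rewards.unsqueeze(-1) + gamma * (1. - dones.float().unsqueeze(-1)) * q_next

    # both critics regress onto the same target, each stepped by its own optimizer
    critic1_opt.zero_grad()
    critic2_opt.zero_grad()
    critic_loss = (F.mse_loss(critic1(states, actions), targets)
                   + F.mse_loss(critic2(states, actions), targets))
    critic_loss.backward()
    critic1_opt.step()
    critic2_opt.step()

    # delayed policy update and soft target updates
    if step % policy_delay == 0:
        actor_opt.zero_grad()
        actor_loss = -critic1(states, actor(states)).mean()
        actor_loss.backward()
        actor_opt.step()
        for net, target in ((critic1, target_critic1), (critic2, target_critic2),
                            (actor, target_actor)):
            for p, tp in zip(net.parameters(), target.parameters()):
                tp.data.mul_(1. - tau).add_(tau * p.data)
    return critic_loss.item()

Taking the minimum over the two target critics counters the Q-value overestimation that plain DDPG suffers from, and delaying the actor and target updates (every policy_delay critic steps) lets the value estimates settle before the policy moves, which is what actor_update_iter controls in the patch above.
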
diff --git a/policy/lunarlander/main.py b/policy/continuous/main.py
similarity index 66%
rename from policy/lunarlander/main.py
rename to policy/continuous/main.py
index fcb3ca1..9ecbd03 100644
--- a/policy/lunarlander/main.py
+++ b/policy/continuous/main.py
@@ -12,22 +12,25 @@
 parser = argparse.ArgumentParser(description='Lunar Lander Agents')
 
 # training hyperparams
-parser.add_argument('--agent', type=str, default='DDPG', help='Agent style')
+parser.add_argument('--agent', type=str, default='TD3', help='Agent Algorithm')
 parser.add_argument('--n_episodes', type=int, default=3000, help='Number of episodes you wish to run for')
-parser.add_argument('--batch_size', type=int, default=64, help='Minibatch size')
+parser.add_argument('--batch_size', type=int, default=100, help='Minibatch size')
 parser.add_argument('--hidden_dim', type=int, default=2048, help='Hidden dimension of FC layers')
 parser.add_argument('--hidden_dims', type=list, default=[400, 300], help='Hidden dimensions of FC layers')
 parser.add_argument('--critic_lr', type=float, default=1e-3, help='Learning rate for Critic')
-parser.add_argument('--critic_wd', type=float, default=1e-2, help='Weight decay for Critic')
+parser.add_argument('--critic_wd', type=float, default=0., help='Weight decay for Critic')
 parser.add_argument('--actor_lr', type=float, default=1e-4, help='Learning rate for Actor')
 parser.add_argument('--actor_wd', type=float, default=0., help='Weight decay for Actor')
 parser.add_argument('--gamma', type=float, default=0.99, help='Reward discount factor')
 parser.add_argument('--final_init', type=float, default=3e-3, help='The range for output layer initialization')
-parser.add_argument('--tau', type=float, default=0.001, help='Weight of target network update')
+parser.add_argument('--tau', type=float, default=0.005, help='Weight of target network update')
 parser.add_argument('--maxsize', type=int, default=1e6, help='Size of Replay Buffer')
-parser.add_argument('--sigma', type=float, default=0.2, help='Sigma for UOnoise')
+parser.add_argument('--sigma', type=float, default=0.1, help='Sigma for exploration noise')
 parser.add_argument('--theta', type=float, default=0.15, help='Theta for UOnoise')
 parser.add_argument('--dt', type=float, default=1e-2, help='dt for UOnoise')
+parser.add_argument('--actor_update_iter', type=int, default=2, help='Update actor and target networks every this many critic updates')
+parser.add_argument('--action_sigma', type=float, default=0.2, help='Std of target policy smoothing noise')
+parser.add_argument('--action_clip', type=float, default=0.5, help='Clip bound for target policy smoothing noise')
 
 # eval params
 parser.add_argument('--render', action="store_true", default=False, help='Render environment while training')
@@ -36,19 +39,20 @@
 parser.add_argument('--load_models', action="store_true", default=False, help='Load pretrained models')
 
 # checkpoint + logs
-parser.add_argument('--checkpoint', type=str, default='policy/lunarlander/checkpoint', help='Checkpoint for model weights')
-parser.add_argument('--logdir', type=str, default='policy/lunarlander/logs', help='Directory to save logs')
+parser.add_argument('--checkpoint', type=str, default='policy/continuous/checkpoint', help='Checkpoint for model weights')
+parser.add_argument('--logdir', type=str, default='policy/continuous/logs', help='Directory to save logs')
 args = parser.parse_args()
 
 
 def main():
-    env_type = 'Continuous' if args.agent in ['DDPG'] else ''
+    env_type = 'Continuous' if args.agent in ['DDPG', 'TD3'] else ''
     env = gym.make(f'LunarLander{env_type}-v2')
     agent_ = getattr(Agent, args.agent.replace(' ', '') + 'Agent')
     if args.test:
         args.load_models = True
         args.render = True
-    if args.agent in ['DDPG']:
+    print(args)
+    if args.agent == 'DDPG':
         max_action = float(env.action_space.high[0])
         agent = agent_(state_dim=env.observation_space.shape,
                        action_dim=env.action_space.shape,
@@ -67,6 +71,29 @@ def main():
                        theta=args.theta,
                        dt=args.dt,
                        checkpoint=args.checkpoint)
+    elif args.agent == 'TD3':
+        max_action = float(env.action_space.high[0])
+        agent = agent_(state_dim=env.observation_space.shape,
+                       action_dim=env.action_space.shape,
+                       hidden_dims=args.hidden_dims,
+                       max_action=max_action,
+                       gamma=args.gamma,
+                       tau=args.tau,
+                       critic_lr=args.critic_lr,
+                       critic_wd=args.critic_wd,
+                       actor_lr=args.actor_lr,
+                       actor_wd=args.actor_wd,
+                       batch_size=args.batch_size,
+                       final_init=args.final_init,
+                       maxsize=int(args.maxsize),
+                       sigma=args.sigma,
+                       theta=args.theta,
+                       dt=args.dt,
+                       checkpoint=args.checkpoint,
+                       actor_update_iter=args.actor_update_iter,
+                       action_sigma=args.action_sigma,
+                       action_clip=args.action_clip
+                       )
     else:
         agent = agent_(state_dim=env.observation_space.shape,
                        action_dim=env.action_space.n,
@@ -83,29 +110,32 @@ def main():
         agent.load_models()
     pbar = tqdm(range(args.n_episodes))
     score_history = deque(maxlen=args.window_legnth)
-    best_score = env.reward_range[0]
+    best_score = - np.inf
     for e in pbar:
         done, score, observation = False, 0, env.reset()
         # reset DDPG UO Noise and also keep track of actor/critic losses
-        if args.agent in ['DDPG']:
-            agent.noise.reset()
+        if args.agent in ['DDPG', 'TD3']:
+            if args.agent == 'DDPG':
+                agent.noise.reset()
             actor_losses, critic_losses = [], []
         while not done:
             if args.render:
                 env.render()
-            action = agent.choose_action(observation)
+            action = agent.choose_action(observation, args.test)
             next_observation, reward, done, _ = env.step(action)
             score += reward
 
             # update for td methods, recording for mc methods
-            if args.agent == 'Actor Critic':
+            if args.test:
+                continue
+            elif args.agent == 'Actor Critic':
                 agent.update(reward, next_observation, done)
-            elif args.agent in ['DDPG']:
+            elif args.agent in ['DDPG', 'TD3']:
                 agent.store_transition(observation, action, reward, next_observation, done)
                 # if we have memory smaller than batch size, do not update
-                if agent.memory.idx < args.batch_size:
+                if agent.memory.idx < args.batch_size or (args.agent == 'TD3' and agent.ctr < 10000):
                     continue
                 actor_loss, critic_loss = agent.update()
                 actor_losses.append(actor_loss)
@@ -117,12 +147,14 @@ def main():
         score_history.append(score)
+        if args.test:
+            continue
         # update for mc methods w/ full trajectory
-        if args.agent == 'Policy Gradient':
+        elif args.agent == 'Policy Gradient':
             agent.update()
 
         # logging & saving
-        elif args.agent in ['DDPG']:
+        elif args.agent in ['DDPG', 'TD3']:
             writer.add_scalars(
                 'Scores',
                 {'Episodic': score, 'Windowed Average': np.mean(score_history)},
@@ -137,6 +169,7 @@ def main():
         if np.mean(score_history) > best_score:
             best_score = np.mean(score_history)
             agent.save_models()
+
         tqdm.write(
             f'Episode: {e + 1}/{args.n_episodes}, Score: {score}, Average Score: {np.mean(score_history)}')
diff --git a/policy/utils.py b/policy/utils.py
index 9f97ac7..d61fc2b 100644
--- a/policy/utils.py
+++ b/policy/utils.py
@@ -21,6 +21,21 @@ def reset(self):
         self.x_prev = self.x0 if self.x0 is not None else torch.zeros_like(self.mu)
 
 
+class GaussianActionNoise:
+    def __init__(self, mu, sigma=0.2):
+        self.mu = mu
+        self.sigma = sigma
+        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
+
+    def __call__(self, output_dim, clip=None, sigma=None):
+        if sigma is None:
+            sigma = self.sigma
+        noise = torch.randn(*output_dim, device=self.device) * sigma + self.mu
+        if clip is not None:
+            noise = noise.clip(-clip, clip)
+        return noise
+
+
 class ReplayBuffer:
     def __init__(self, state_dim, action_dim, maxsize):
         self.states = torch.empty(maxsize, *state_dim)
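
For reference, a minimal standalone illustration of how the clipped Gaussian noise in GaussianActionNoise is meant to be used; the tensor shapes below assume LunarLanderContinuous-v2's 2-dimensional action space and are only for demonstration, not code from this repository.

import torch


def gaussian_noise(shape, sigma=0.1, clip=None, device='cpu'):
    """Zero-mean Gaussian noise; optionally clipped for target policy smoothing."""
    noise = torch.randn(*shape, device=device) * sigma
    if clip is not None:
        # clamp returns a new tensor, so the result must be kept
        noise = noise.clamp(-clip, clip)
    return noise


# exploration noise added to the online actor's action (unclipped, --sigma)
action = torch.tanh(torch.randn(2)) + gaussian_noise((2,), sigma=0.1)
# smoothing noise added to a batch of target actions (--action_sigma, clipped to +/- --action_clip)
target_actions = torch.tanh(torch.randn(64, 2)) + gaussian_noise((64, 2), sigma=0.2, clip=0.5)

Clipping only the smoothing noise keeps the perturbed target action close to the target policy's output, so the minimum over the twin target critics is evaluated in a small neighbourhood rather than at arbitrary actions, while the exploration noise stays unclipped and the final action is bounded separately by clip_action.
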