added dqn agent & networks, DDQN, Dueling Networks. TODO: Prioritized ER, Rainbow
1 parent c12979c · commit f404b30
Showing 6 changed files with 588 additions and 14 deletions.
@@ -0,0 +1,201 @@
import pandas as pd
import numpy as np
import torch
from torch import optim
from copy import deepcopy
from qlearning.networks import QNaive, QBasic, QDueling
from qlearning.experience_replay import ReplayBuffer


class BaseAgent:
    def __init__(self, state_dim, n_actions, epsilon_init, epsilon_min, epsilon_desc, gamma, alpha, n_episodes):
        self.actions = list(range(n_actions))
        self.n_actions = n_actions
        if isinstance(state_dim, int):
            self.states = list(range(state_dim))
        self.state_dim = state_dim
        self.epsilon = epsilon_init
        self.epsilon_min = epsilon_min
        self.epsilon_desc = epsilon_desc
        self.gamma = gamma
        self.alpha = alpha
        self.n_episodes = n_episodes

    def epsilon_greedy(self, state):
        if np.random.random() > self.epsilon:
            return self.greedy_action(state)
        return np.random.choice(self.actions)

    def greedy_action(self, state):
        raise NotImplementedError


class TabularAgent(BaseAgent):
    def __init__(self, states, actions, epsilon_init, epsilon_min, epsilon_desc, gamma, alpha, n_episodes):
        super().__init__(states, actions, epsilon_init, epsilon_min, epsilon_desc, gamma, alpha, n_episodes)
        # initialize table with 0 Q-values, indexed by the state and action labels built in BaseAgent
        self.q_table = pd.DataFrame(np.zeros((self.state_dim, self.n_actions)),
                                    index=self.states, columns=self.actions)

    def greedy_action(self, state):
        Qs = self.q_table.loc[state]
        return Qs.argmax()

    def update(self, state, action, reward, next_state):
        # update Q-table with the one-step temporal-difference target
        max_Q_ = self.q_table.loc[next_state].max()
        Q_sa = self.q_table.loc[state, action]
        self.q_table.loc[state, action] += self.alpha * (
            reward + self.gamma * max_Q_ - Q_sa)
        # update epsilon
        self.decrease_epsilon()

    def decrease_epsilon(self):
        self.epsilon = max(
            self.epsilon_min,
            self.epsilon * self.epsilon_desc)


class NaiveNeuralAgent(BaseAgent):
    def __init__(self, *args, **kwargs):
        super().__init__(*args)
        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        if kwargs['policy'] == 'naive neural':
            self.Q_function = QNaive(
                kwargs['state_dim'],
                kwargs['action_dim'],
                kwargs['hidden_dim'],
                self.n_actions).to(self.device)
        self.optimizer = optim.Adam(self.Q_function.parameters(), self.alpha)
        self.criterion = torch.nn.MSELoss()

    def number2tensor(self, number):
        return torch.tensor([number]).to(self.device)

    def greedy_action(self, state):
        state = self.number2tensor(state)
        next_action = self.Q_function(state).argmax()
        return next_action.item()

    def update(self, state, action, reward, next_state):
        # convert raw states to tensors before passing them through the network
        state = self.number2tensor(state)
        next_state = self.number2tensor(next_state)
        # bootstrap target; detached so no gradient flows through the next-state value
        q_prime = self.Q_function(next_state).max()
        q_target = (reward + self.gamma * q_prime).detach()
        q_pred = self.Q_function(state)[action]
        loss = self.criterion(q_pred, q_target)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def decrease_epsilon(self):
        self.epsilon = max(
            self.epsilon_min,
            self.epsilon - self.epsilon_desc)


class DQNAgent(BaseAgent):
    def __init__(self, *args, **kwargs):
        super().__init__(*args)
        self.algorithm = kwargs['algorithm']
        self.batch_size = kwargs['batch_size']
        self.grad_clip = kwargs['grad_clip']
        self.prioritize = kwargs['prioritize']
        self.memory = ReplayBuffer(kwargs['max_size'], self.state_dim)
        self.target_update_interval = kwargs['target_update_interval']
        self.n_updates = 0

        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        if self.algorithm.startswith('Dueling'):
            self.Q_function = QDueling(
                kwargs['input_channels'],
                self.n_actions,
                kwargs['cpt_dir'],
                kwargs['algorithm'] + '_' + kwargs['env_name'],
                kwargs['img_size'],
                kwargs['hidden_dim']).to(self.device)
        else:
            self.Q_function = QBasic(
                kwargs['input_channels'],
                self.n_actions,
                kwargs['cpt_dir'],
                kwargs['algorithm'] + '_' + kwargs['env_name'],
                kwargs['img_size'],
                kwargs['hidden_dim']).to(self.device)

        # instantiate target network as a frozen copy of the online network
        self.target_Q = deepcopy(self.Q_function)
        self.freeze_network(self.target_Q)
        self.target_Q.name = kwargs['algorithm'] + '_' + kwargs['env_name'] + '_target'

        self.optimizer = torch.optim.RMSprop(self.Q_function.parameters(), lr=self.alpha, alpha=0.95)
        self.criterion = torch.nn.MSELoss()

    def greedy_action(self, observation):
        observation = torch.tensor(observation, dtype=torch.float32).unsqueeze(0).to(self.device)
        next_action = self.Q_function(observation).argmax()
        return next_action.item()

    def update_target_network(self):
        self.copy_network_weights(self.Q_function, self.target_Q)

    def copy_network_weights(self, src_network, tgt_network):
        tgt_network.load_state_dict(src_network.state_dict())

    def freeze_network(self, network):
        for p in network.parameters():
            p.requires_grad = False

    def update(self):
        # minimize (reward + gamma * max_a' Q_target(s', a') - Q(s, a)) ** 2
        # wait until the buffer holds at least one full batch before learning
        if self.memory.ctr < self.batch_size:
            return
        self.optimizer.zero_grad()
        observations, rewards, actions, next_observations, dones = self.sample_transitions()

        # double DQN: use the online network to select the action for the target
        if self.algorithm.endswith('DDQN'):
            next_actions = self.Q_function(next_observations).argmax(-1)
            q_prime = self.target_Q(next_observations)[list(range(self.batch_size)), next_actions]
        elif self.algorithm.endswith('DQN'):
            q_prime = self.target_Q(next_observations).max(-1)[0]
        q_target = rewards.to(self.device) + self.gamma * q_prime * (~dones)
        q_pred = self.Q_function(observations)[list(range(self.batch_size)), actions]
        loss = self.criterion(q_target, q_pred)
        loss.backward()
        if self.grad_clip is not None:
            torch.nn.utils.clip_grad_norm_(self.Q_function.parameters(), self.grad_clip)
        self.optimizer.step()
        self.decrease_epsilon()
        self.n_updates += 1
        if self.n_updates % self.target_update_interval == 0:
            self.update_target_network()

        priorities = None
        # if self.prioritize:
        #     # TODO
        #     pass
        # else:
        #     priorities = None
        return priorities
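
    # A minimal sketch for the prioritized-ER TODO above, assuming proportional
    # prioritization by absolute TD error (an illustration, not the committed design):
    # inside update(), before loss.backward(), one could compute
    #     priorities = (q_target - q_pred).abs().detach().cpu().numpy()
    # and return it so the replay buffer can bias future sampling towards
    # transitions with large errors.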

    def decrease_epsilon(self):
        self.epsilon = max(
            self.epsilon_min,
            self.epsilon - self.epsilon_desc)

    def store_transition(self, state, reward, action, next_state, done, priority=None):
        state, next_state = torch.from_numpy(state), torch.from_numpy(next_state)
        self.memory.store(state, reward, action, next_state, done, priority=priority)

    def sample_transitions(self):
        return self.memory.sample(self.batch_size, self.device)

    def save_models(self):
        self.target_Q.check_point()
        self.Q_function.check_point()

    def load_models(self):
        self.target_Q.load_checkpoint()
        self.Q_function.load_checkpoint()
        self.target_Q.to(self.device)
        self.Q_function.to(self.device)
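
The `ReplayBuffer` imported from `qlearning/experience_replay.py` is not shown in this excerpt. Judging only from how `DQNAgent` uses it (constructed as `ReplayBuffer(max_size, state_dim)`, exposing a `ctr` counter, `store(state, reward, action, next_state, done, priority=None)`, and `sample(batch_size, device)` returning observation, reward, action, next-observation, and done tensors), a compatible buffer might look roughly like the sketch below; the field names and dtypes are assumptions, and the actual module may differ.

import numpy as np
import torch


class ReplayBuffer:
    """Circular replay buffer sketch matching the calls DQNAgent makes
    (an assumption, not the committed qlearning/experience_replay.py)."""

    def __init__(self, max_size, state_dim):
        # state_dim is expected to be the observation shape tuple, e.g. env.observation_space.shape
        self.max_size = max_size
        self.ctr = 0  # total transitions stored so far; DQNAgent checks this before learning
        self.states = torch.zeros((max_size, *state_dim), dtype=torch.float32)
        self.next_states = torch.zeros((max_size, *state_dim), dtype=torch.float32)
        self.actions = torch.zeros(max_size, dtype=torch.int64)
        self.rewards = torch.zeros(max_size, dtype=torch.float32)
        self.dones = torch.zeros(max_size, dtype=torch.bool)

    def store(self, state, reward, action, next_state, done, priority=None):
        idx = self.ctr % self.max_size  # overwrite the oldest transition once full
        self.states[idx] = state
        self.next_states[idx] = next_state
        self.actions[idx] = action
        self.rewards[idx] = reward
        self.dones[idx] = bool(done)
        self.ctr += 1

    def sample(self, batch_size, device):
        # sample uniformly from the filled part of the buffer
        high = min(self.ctr, self.max_size)
        idx = np.random.choice(high, batch_size, replace=False)
        return (self.states[idx].to(device),
                self.rewards[idx].to(device),
                self.actions[idx].to(device),
                self.next_states[idx].to(device),
                self.dones[idx].to(device))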
@@ -0,0 +1,124 @@
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
import argparse
from collections import deque
from pathlib import Path
from gym import wrappers
# from ple.games.flappybird import FlappyBird

from qlearning.agent import DQNAgent
from qlearning.atari.utils import processed_atari

parser = argparse.ArgumentParser(description='Q Learning Atari Agents')
parser.add_argument('--algorithm', type=str, default='DQN', help='The type of algorithm you wish to use.\nDQN\n \
                    DDQN\n \
                    DuelingDQN\n \
                    DuelingDDQN')

# environment & data
parser.add_argument('--env_name', type=str, default='PongNoFrameskip-v4', help='Atari environment.\nPongNoFrameskip-v4\n \
                    BreakoutNoFrameskip-v4\n \
                    SpaceInvadersNoFrameskip-v4\n \
                    EnduroNoFrameskip-v4\n \
                    AtlantisNoFrameskip-v4\n \
                    BankHeistNoFrameskip-v4')
parser.add_argument('--n_repeats', type=int, default=4, help='The number of repeated actions')
parser.add_argument('--img_size', type=int, default=84, help='The height and width of images after resizing')
parser.add_argument('--input_channels', type=int, default=1, help='The input channels after preprocessing')
parser.add_argument('--hidden_dim', type=int, default=512, help='The hidden size of the second fc layer')
parser.add_argument('--max_size', type=int, default=100000, help='Buffer size')
parser.add_argument('--target_update_interval', type=int, default=1000, help='Interval for updating the target network')

# training
parser.add_argument('-e', '--n_episodes', '--epochs', type=int, default=1000, help='Number of episodes the agent interacts with the env')
parser.add_argument('--alpha', type=float, default=0.0001, help='Learning rate')
parser.add_argument('--gamma', type=float, default=0.99, help='Discount factor')
parser.add_argument('--prioritize', action="store_false", default=True, help='Use Prioritized Experience Replay')
parser.add_argument('--epsilon_init', type=float, default=1.0, help='Initial epsilon value')
parser.add_argument('--epsilon_min', type=float, default=0.1, help='Minimum epsilon value to decay to')
parser.add_argument('--epsilon_desc', type=float, default=1e-5, help='Epsilon decrease per update')
parser.add_argument('--grad_clip', type=float, default=10, help='Norm for gradient clipping, None for no clipping')
parser.add_argument('-b', '--batch_size', type=int, default=32, help='Batch size')

# logging
parser.add_argument('--progress_window', type=int, default=100, help='Window of episodes for progress')
parser.add_argument('--print_every', type=int, default=1, help='Print progress interval')
parser.add_argument('--cpt_dir', type=str, default='qlearning/atari/model_weights', help='Directory to save model weights')
parser.add_argument('--log_dir', type=str, default='qlearning/atari/logs', help='Directory to submit logs')

# testing
parser.add_argument('--test', action="store_true", default=False, help='Testing + rendering')
parser.add_argument('--video', action="store_false", default=True, help='Output video files if testing')
parser.add_argument('--video_dir', type=str, default='qlearning/atari/videos', help='Directory for agent playing videos')
args = parser.parse_args()

writer = SummaryWriter(args.log_dir)
Path(args.log_dir).mkdir(parents=True, exist_ok=True)
Path(args.cpt_dir).mkdir(parents=True, exist_ok=True)

if __name__ == '__main__':
    env = processed_atari(args.env_name, args.img_size, args.input_channels, args.n_repeats)

    # if testing the agent and outputting videos, make the dir & wrap the env to record video files
    if args.test and args.video:
        Path(args.video_dir).mkdir(parents=True, exist_ok=True)
        env = wrappers.Monitor(env, args.video_dir,
                               video_callable=lambda episode_id: True,
                               force=True)
    if 'DQN' in args.algorithm:
        agent = DQNAgent(env.observation_space.shape,
                         env.action_space.n,
                         args.epsilon_init, args.epsilon_min, args.epsilon_desc,
                         args.gamma, args.alpha, args.n_episodes,
                         input_channels=args.input_channels,
                         algorithm=args.algorithm,
                         img_size=args.img_size,
                         hidden_dim=args.hidden_dim,
                         max_size=args.max_size,
                         target_update_interval=args.target_update_interval,
                         batch_size=args.batch_size,
                         cpt_dir=args.cpt_dir,
                         grad_clip=args.grad_clip,
                         prioritize=args.prioritize,
                         env_name=args.env_name)
    else:
        raise NotImplementedError
    scores, best_score = deque(maxlen=args.progress_window), -np.inf

    # load weights & make sure the model is in eval mode during testing
    if args.test:
        agent.load_models()
        agent.Q_function.eval()
    pbar = tqdm(range(args.n_episodes))
    for e in pbar:

        # reset every episode and keep the network in training mode while training
        done, score, observation = False, 0, env.reset()
        if not args.test:
            agent.Q_function.train()
        while not done:
            # if testing, only take the greedy action; otherwise act epsilon-greedily
            if args.test:
                action = agent.greedy_action(observation)
                env.render()
            else:
                action = agent.epsilon_greedy(observation)
            next_observation, reward, done, info = env.step(action)

            # only store transitions and update parameters during training
            if not args.test:
                priorities = agent.update()
                agent.store_transition(observation, reward, action, next_observation, done, priorities)
            score += reward
            observation = next_observation

        # logging
        writer.add_scalars('Performance and training', {'Score': score, 'Epsilon': agent.epsilon}, e)
        scores.append(score)
        if score > best_score and not args.test:
            agent.save_models()
            best_score = score
        if (e + 1) % args.print_every == 0:
            tqdm.write(f'Episode: {e + 1}/{args.n_episodes}, Average Score: {np.mean(scores)}, Best Score: {best_score}, Epsilon: {agent.epsilon}')
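
The networks named in the commit (`QNaive`, `QBasic`, `QDueling` in `qlearning/networks.py`) are among the changed files but are not visible in this excerpt. For the "Dueling Networks" part of the commit message, the standard dueling aggregation is Q(s, a) = V(s) + A(s, a) - mean_a A(s, a); the sketch below shows only that head, with made-up class name and layer sizes, as an illustration rather than the committed implementation.

# A minimal sketch of a dueling head (assumed names and sizes, not the real QDueling)
import torch
import torch.nn as nn


class DuelingHead(nn.Module):
    """Dueling aggregation: Q(s, a) = V(s) + A(s, a) - mean_a A(s, a)."""

    def __init__(self, feature_dim, hidden_dim, n_actions):
        super().__init__()
        self.value = nn.Sequential(
            nn.Linear(feature_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, 1))            # state value V(s)
        self.advantage = nn.Sequential(
            nn.Linear(feature_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, n_actions))    # advantages A(s, a)

    def forward(self, features):
        v = self.value(features)                      # (batch, 1)
        a = self.advantage(features)                  # (batch, n_actions)
        return v + a - a.mean(dim=1, keepdim=True)    # (batch, n_actions)


if __name__ == '__main__':
    # example: a batch of 32 flattened conv features of size 3136
    head = DuelingHead(feature_dim=3136, hidden_dim=512, n_actions=6)
    q_values = head(torch.randn(32, 3136))
    print(q_values.shape)  # torch.Size([32, 6])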