added dqn agent & networks, DDQN, Dueling Networks. TODO: Prioritized ER, Rainbow
Andrewzh112 committed Feb 5, 2021
1 parent c12979c commit f404b30
Showing 6 changed files with 588 additions and 14 deletions.
201 changes: 201 additions & 0 deletions qlearning/agent.py
@@ -0,0 +1,201 @@
import pandas as pd
import numpy as np
import torch
from torch import optim
from copy import deepcopy
from qlearning.networks import QNaive, QBasic, QDueling
from qlearning.experience_replay import ReplayBuffer


class BaseAgent:
def __init__(self, state_dim, n_actions, epsilon_init, epsilon_min, epsilon_desc, gamma, alpha, n_episodes):
self.actions = list(range(n_actions))
self.n_actions = n_actions
if isinstance(state_dim, int):
self.states = list(range(state_dim))
self.state_dim = state_dim
self.epsilon = epsilon_init
self.epsilon_min = epsilon_min
self.epsilon_desc = epsilon_desc
self.gamma = gamma
self.alpha = alpha
self.n_episodes = n_episodes

def epsilon_greedy(self, state):
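        # explore with probability epsilon, otherwise take the greedy action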
if np.random.random() > self.epsilon:
return self.greedy_action(state)
return np.random.choice(self.actions)

def greedy_action(self, state):
raise NotImplementedError


class TabularAgent(BaseAgent):
def __init__(self, states, actions, epsilon_init, epsilon_min, epsilon_desc, gamma, alpha, n_episodes):
super().__init__(states, actions, epsilon_init, epsilon_min, epsilon_desc, gamma, alpha, n_episodes)
# initialize table with 0 Q-values
self.q_table = pd.DataFrame(np.zeros((self.state_dim, self.n_actions)),
index=states, columns=actions)

def greedy_action(self, state):
Qs = self.q_table.loc[state]
        # idxmax returns the action label with the highest Q-value
        return Qs.idxmax()

def update(self, state, action, reward, next_state):
# update Q-table
max_Q_ = self.q_table.loc[next_state].max()
Q_sa = self.q_table.loc[state, action]
self.q_table.loc[state, action] += self.alpha * (
reward + self.gamma * max_Q_ - Q_sa)
# update epsilon
self.decrease_epsilon()

def decrease_epsilon(self):
self.epsilon = max(
self.epsilon_min,
self.epsilon * self.epsilon_desc)


class NaiveNeuralAgent(BaseAgent):
def __init__(self, *args, **kwargs):
super().__init__(*args)
self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
if kwargs['policy'] == 'naive neural':
self.Q_function = QNaive(
kwargs['state_dim'],
kwargs['action_dim'],
kwargs['hidden_dim'],
self.n_actions).to(self.device)
self.optimizer = optim.Adam(self.Q_function.parameters(), self.alpha)
self.criterion = torch.nn.MSELoss()

def number2tensor(self, number):
return torch.tensor([number]).to(self.device)

def greedy_action(self, state):
state = self.number2tensor(state)
next_action = self.Q_function(state).argmax()
return next_action.item()

    def update(self, state, action, reward, next_state):
        # convert raw state indices to tensors, mirroring greedy_action
        state, next_state = self.number2tensor(state), self.number2tensor(next_state)
        q_prime = self.Q_function(next_state).max()
q_target = torch.tensor([reward + self.gamma * q_prime]).to(self.device)
q_pred = self.Q_function(state)[action]
loss = self.criterion(q_target, q_pred)
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()

def decrease_epsilon(self):
self.epsilon = max(
self.epsilon_min,
self.epsilon - self.epsilon_desc)


class DQNAgent(BaseAgent):
def __init__(self, *args, **kwargs):
super().__init__(*args)
self.algorithm = kwargs['algorithm']
self.batch_size = kwargs['batch_size']
self.grad_clip = kwargs['grad_clip']
self.prioritize = kwargs['prioritize']
self.memory = ReplayBuffer(kwargs['max_size'], self.state_dim)
self.target_update_interval = kwargs['target_update_interval']
self.n_updates = 0

self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
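        # Dueling variants split the Q-network into separate value and advantage streams; otherwise use the plain conv Q-network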
if self.algorithm.startswith('Dueling'):
self.Q_function = QDueling(
kwargs['input_channels'],
self.n_actions,
kwargs['cpt_dir'],
kwargs['algorithm'] + '_' + kwargs['env_name'],
kwargs['img_size'],
kwargs['hidden_dim']).to(self.device)
else:
self.Q_function = QBasic(
kwargs['input_channels'],
self.n_actions,
kwargs['cpt_dir'],
kwargs['algorithm'] + '_' + kwargs['env_name'],
kwargs['img_size'],
kwargs['hidden_dim']).to(self.device)

        # instantiate the target network as a frozen copy of the online network
self.target_Q = deepcopy(self.Q_function)
self.freeze_network(self.target_Q)
self.target_Q.name = kwargs['algorithm'] + '_' + kwargs['env_name'] + '_target'

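        # RMSprop with squared-gradient smoothing constant 0.95, as in the original DQN setup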
self.optimizer = torch.optim.RMSprop(self.Q_function.parameters(), lr=self.alpha, alpha=0.95)
self.criterion = torch.nn.MSELoss()

def greedy_action(self, observation):
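        # add a batch dimension, then act greedily w.r.t. the online Q-network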
observation = torch.tensor(observation, dtype=torch.float32).unsqueeze(0).to(self.device)
next_action = self.Q_function(observation).argmax()
return next_action.item()

def update_target_network(self):
self.copy_network_weights(self.Q_function, self.target_Q)

def copy_network_weights(self, src_network, tgt_network):
tgt_network.load_state_dict(src_network.state_dict())

def freeze_network(self, network):
for p in network.parameters():
p.requires_grad = False

def update(self):
        # loss = (reward + gamma * max_a' Q_target(s', a') - Q(s, a)) ** 2
        # don't learn until the replay buffer holds at least one full batch
if self.memory.ctr < self.batch_size:
return
self.optimizer.zero_grad()
observations, rewards, actions, next_observations, dones = self.sample_transitions()

        # Double DQN: the online network selects the next action, the target network evaluates it
if self.algorithm.endswith('DDQN'):
next_actions = self.Q_function(next_observations).argmax(-1)
q_prime = self.target_Q(next_observations)[list(range(self.batch_size)), next_actions]
elif self.algorithm.endswith('DQN'):
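            # vanilla DQN: bootstrap from the target network's own maximum Q-value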
q_prime = self.target_Q(next_observations).max(-1)[0]
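        # zero out the bootstrap term for terminal transitions via (~dones)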
q_target = rewards.to(self.device) + self.gamma * q_prime * (~dones)
q_pred = self.Q_function(observations)[list(range(self.batch_size)), actions]
loss = self.criterion(q_target, q_pred)
loss.backward()
if self.grad_clip is not None:
torch.nn.utils.clip_grad_norm_(self.Q_function.parameters(), self.grad_clip)
self.optimizer.step()
self.decrease_epsilon()
self.n_updates += 1
if self.n_updates % self.target_update_interval == 0:
self.update_target_network()

priorities = None
# if self.prioritize:
# # TODO
# pass
# else:
# priorities = None
return priorities

def decrease_epsilon(self):
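        # linear epsilon decay (the tabular agent decays multiplicatively instead)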
self.epsilon = max(
self.epsilon_min,
self.epsilon - self.epsilon_desc)

def store_transition(self, state, reward, action, next_state, done, priority=None):
state, next_state = torch.from_numpy(state), torch.from_numpy(next_state)
self.memory.store(state, reward, action, next_state, done, priority=priority)

def sample_transitions(self):
return self.memory.sample(self.batch_size, self.device)

def save_models(self):
self.target_Q.check_point()
self.Q_function.check_point()

def load_models(self):
self.target_Q.load_checkpoint()
self.Q_function.load_checkpoint()
self.target_Q.to(self.device)
self.Q_function.to(self.device)
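
Note: qlearning/networks.py (QNaive, QBasic, QDueling) is among the six changed files but isn't expanded on this page. As a rough sketch of what a dueling head typically looks like (names, layer sizes, and the feature dimension here are assumptions, not the repo's actual code), the value and advantage streams are combined as Q(s, a) = V(s) + A(s, a) - mean_a A(s, a):

# Hypothetical dueling head sketch (not the repo's QDueling); assumes a flattened
# feature vector of size feature_dim coming out of the convolutional trunk.
import torch
from torch import nn

class DuelingHeadSketch(nn.Module):
    def __init__(self, feature_dim, hidden_dim, n_actions):
        super().__init__()
        self.value = nn.Sequential(nn.Linear(feature_dim, hidden_dim), nn.ReLU(),
                                   nn.Linear(hidden_dim, 1))
        self.advantage = nn.Sequential(nn.Linear(feature_dim, hidden_dim), nn.ReLU(),
                                       nn.Linear(hidden_dim, n_actions))

    def forward(self, features):
        V = self.value(features)                     # (batch, 1) state value
        A = self.advantage(features)                 # (batch, n_actions) advantages
        # subtracting the mean advantage keeps V and A identifiable
        return V + A - A.mean(dim=-1, keepdim=True)  # (batch, n_actions) Q-values

The DQNAgent above only relies on the forward pass returning per-action Q-values of shape (batch, n_actions).
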
124 changes: 124 additions & 0 deletions qlearning/atari/main.py
@@ -0,0 +1,124 @@
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
import argparse
from collections import deque
from pathlib import Path
from gym import wrappers
# from ple.games.flappybird import FlappyBird

from qlearning.agent import DQNAgent
from qlearning.atari.utils import processed_atari


parser = argparse.ArgumentParser(description='Q Learning Atari Agents')
parser.add_argument('--algorithm', type=str, default='DQN', help='The type of algorithm you wish to use.\nDQN\n \
DDQN\n \
DuelingDQN\n \
DuelingDDQN')

# environment & data
parser.add_argument('--env_name', type=str, default='PongNoFrameskip-v4', help='Atari environment.\nPongNoFrameskip-v4\n \
BreakoutNoFrameskip-v4\n \
SpaceInvadersNoFrameskip-v4\n \
EnduroNoFrameskip-v4\n \
AtlantisNoFrameskip-v4\n \
BankHeistNoFrameskip-v4')
parser.add_argument('--n_repeats', type=int, default=4, help='The number of repeated actions')
parser.add_argument('--img_size', type=int, default=84, help='The height and width of images after resizing')
parser.add_argument('--input_channels', type=int, default=1, help='The input channels after preprocessing')
parser.add_argument('--hidden_dim', type=int, default=512, help='The hidden size for second fc layer')
parser.add_argument('--max_size', type=int, default=100000, help='Buffer size')
parser.add_argument('--target_update_interval', type=int, default=1000, help='Interval for updating target network')

# training
parser.add_argument('-e', '--n_episodes', '--epochs', type=int, default=1000, help='Number of episodes agent interacts with env')
parser.add_argument('--alpha', type=float, default=0.0001, help='Learning rate')
parser.add_argument('--gamma', type=float, default=0.99, help='Discount factor')
parser.add_argument('--prioritize', action="store_true", default=False, help='Use Prioritized Experience Replay (not yet implemented)')
parser.add_argument('--epsilon_init', type=float, default=1.0, help='Initial epsilon value')
parser.add_argument('--epsilon_min', type=float, default=0.1, help='Minimum epsilon value to decay to')
parser.add_argument('--epsilon_desc', type=float, default=1e-5, help='Epsilon decrease')
parser.add_argument('--grad_clip', type=float, default=10, help='Norm of the grad clip, None for no clip')
parser.add_argument('-b', '--batch_size', type=int, default=32, help='Batch size')

# logging
parser.add_argument('--progress_window', type=int, default=100, help='Window of episodes for progress')
parser.add_argument('--print_every', type=int, default=1, help='Print progress interval')
parser.add_argument('--cpt_dir', type=str, default='qlearning/atari/model_weights', help='Directory to save model weights')
parser.add_argument('--log_dir', type=str, default='qlearning/atari/logs', help='Directory to submit logs')

# testing
parser.add_argument('--test', action="store_true", default=False, help='Testing + rendering')
parser.add_argument('--video', action="store_false", default=True, help='Pass to disable video output when testing (videos are saved by default)')
parser.add_argument('--video_dir', type=str, default='qlearning/atari/videos', help='Directory for agent playing videos')
args = parser.parse_args()

writer = SummaryWriter(args.log_dir)
Path(args.log_dir).mkdir(parents=True, exist_ok=True)
Path(args.cpt_dir).mkdir(parents=True, exist_ok=True)


if __name__ == '__main__':
env = processed_atari(args.env_name, args.img_size, args.input_channels, args.n_repeats)
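    # processed_atari presumably applies the Atari preprocessing configured above (frame repeats, resize, stacked frames)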

# if testing agent and want to output videos, make dir & wrap env to auto output video files
if args.test and args.video:
Path(args.video_dir).mkdir(parents=True, exist_ok=True)
env = wrappers.Monitor(env, args.video_dir,
video_callable=lambda episode_id: True,
force=True)
if 'DQN' in args.algorithm:
agent = DQNAgent(env.observation_space.shape,
env.action_space.n,
args.epsilon_init, args.epsilon_min, args.epsilon_desc,
args.gamma, args.alpha, args.n_episodes,
input_channels=args.input_channels,
algorithm=args.algorithm,
img_size=args.img_size,
hidden_dim=args.hidden_dim,
max_size=args.max_size,
target_update_interval=args.target_update_interval,
batch_size=args.batch_size,
cpt_dir=args.cpt_dir,
grad_clip=args.grad_clip,
prioritize=args.prioritize,
env_name=args.env_name)
else:
raise NotImplementedError
scores, best_score = deque(maxlen=args.progress_window), -np.inf

# load weights & make sure model in eval mode during test
if args.test:
agent.load_models()
agent.Q_function.eval()
pbar = tqdm(range(args.n_episodes))
for e in pbar:

        # reset every episode; keep the network in training mode unless we are testing
        done, score, observation = False, 0, env.reset()
        if not args.test:
            agent.Q_function.train()
while not done:
            # when testing, always take the greedy action; otherwise act epsilon-greedily
if args.test:
action = agent.greedy_action(observation)
env.render()
else:
action = agent.epsilon_greedy(observation)
next_observation, reward, done, info = env.step(action)

# only update parameters during training
if not args.test:
priorities = agent.update()
agent.store_transition(observation, reward, action, next_observation, done, priorities)
score += reward
observation = next_observation

# logging
        writer.add_scalars('Performance and training', {'Score': score, 'Epsilon': agent.epsilon}, global_step=e)
scores.append(score)
if score > best_score and not args.test:
agent.save_models()
best_score = score
if (e + 1) % args.print_every == 0:
            tqdm.write(f'Episode: {e + 1}/{args.n_episodes}, Average Score: {np.mean(scores):.2f}, Best Score: {best_score}, Epsilon: {agent.epsilon:.3f}')
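
Note: qlearning/experience_replay.py is also among the six changed files but isn't shown above. A minimal sketch of the interface DQNAgent relies on (a ctr counter plus store and sample methods); the internal storage layout here is an assumption:

# Hypothetical sketch of the ReplayBuffer interface used by DQNAgent
# (the real qlearning/experience_replay.py may differ).
import torch

class ReplayBufferSketch:
    def __init__(self, max_size, state_shape):
        self.max_size = max_size
        self.ctr = 0  # number of transitions stored so far
        self.states = torch.zeros((max_size, *state_shape))
        self.next_states = torch.zeros((max_size, *state_shape))
        self.rewards = torch.zeros(max_size)
        self.actions = torch.zeros(max_size, dtype=torch.long)
        self.dones = torch.zeros(max_size, dtype=torch.bool)

    def store(self, state, reward, action, next_state, done, priority=None):
        i = self.ctr % self.max_size          # overwrite the oldest transition when full
        self.states[i], self.next_states[i] = state, next_state
        self.rewards[i], self.actions[i], self.dones[i] = reward, action, done
        self.ctr += 1

    def sample(self, batch_size, device):
        high = min(self.ctr, self.max_size)
        idx = torch.randint(0, high, (batch_size,))
        return (self.states[idx].to(device), self.rewards[idx].to(device),
                self.actions[idx].to(device), self.next_states[idx].to(device),
                self.dones[idx].to(device))

A prioritized variant (the TODO in the commit message) would additionally keep a priority per transition and sample proportionally to it.
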