added prioritized replay
Andrewzh112 committed Feb 5, 2021
1 parent f404b30 commit 13db269
Showing 4 changed files with 57 additions and 33 deletions.
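
For context (not part of the diff): prioritized experience replay samples transition i with probability P(i) = p_i^alpha / sum_k p_k^alpha, where p_i is the transition's priority (here the per-sample TD loss plus a small eps), and corrects the sampling bias with importance weights w_i = (N * P(i))^(-beta), normalized by the largest weight in the batch. A minimal standalone sketch of that sampling math in PyTorch (names and defaults illustrative, not taken from this repo):

import torch

def per_sample(priorities, batch_size, alpha=0.6, beta=0.4):
    # P(i) = p_i^alpha / sum_k p_k^alpha
    probs = priorities ** alpha
    probs = probs / probs.sum()
    idx = torch.multinomial(probs, batch_size)
    # importance-sampling weights w_i = (N * P(i))^(-beta), scaled so the max weight is 1
    weights = (len(priorities) * probs[idx]) ** (-beta)
    return idx, weights / weights.max()

idx, weights = per_sample(torch.rand(1000) + 1e-5, batch_size=32)
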
47 changes: 27 additions & 20 deletions qlearning/agent.py
@@ -8,7 +8,7 @@


class BaseAgent:
def __init__(self, state_dim, n_actions, epsilon_init, epsilon_min, epsilon_desc, gamma, alpha, n_episodes):
def __init__(self, state_dim, n_actions, epsilon_init, epsilon_min, epsilon_desc, gamma, lr, n_episodes):
self.actions = list(range(n_actions))
self.n_actions = n_actions
if isinstance(state_dim, int):
@@ -18,7 +18,7 @@ def __init__(self, state_dim, n_actions, epsilon_init, epsilon_min, epsilon_desc
self.epsilon_min = epsilon_min
self.epsilon_desc = epsilon_desc
self.gamma = gamma
self.alpha = alpha
self.lr = lr
self.n_episodes = n_episodes

def epsilon_greedy(self, state):
@@ -31,8 +31,8 @@ def greedy_action(self, state):


class TabularAgent(BaseAgent):
def __init__(self, states, actions, epsilon_init, epsilon_min, epsilon_desc, gamma, alpha, n_episodes):
super().__init__(states, actions, epsilon_init, epsilon_min, epsilon_desc, gamma, alpha, n_episodes)
def __init__(self, states, actions, epsilon_init, epsilon_min, epsilon_desc, gamma, lr, n_episodes):
super().__init__(states, actions, epsilon_init, epsilon_min, epsilon_desc, gamma, lr, n_episodes)
# initialize table with 0 Q-values
self.q_table = pd.DataFrame(np.zeros((self.state_dim, self.n_actions)),
index=states, columns=actions)
@@ -45,7 +45,7 @@ def update(self, state, action, reward, next_state):
# update Q-table
max_Q_ = self.q_table.loc[next_state].max()
Q_sa = self.q_table.loc[state, action]
self.q_table.loc[state, action] += self.alpha * (
self.q_table.loc[state, action] += self.lr * (
reward + self.gamma * max_Q_ - Q_sa)
# update epsilon
self.decrease_epsilon()
@@ -66,7 +66,7 @@ def __init__(self, *args, **kwargs):
kwargs['action_dim'],
kwargs['hidden_dim'],
self.n_actions).to(self.device)
self.optimizer = optim.Adam(self.Q_function.parameters(), self.alpha)
self.optimizer = optim.Adam(self.Q_function.parameters(), self.lr)
self.criterion = torch.nn.MSELoss()

def number2tensor(self, number):
@@ -99,6 +99,9 @@ def __init__(self, *args, **kwargs):
self.batch_size = kwargs['batch_size']
self.grad_clip = kwargs['grad_clip']
self.prioritize = kwargs['prioritize']
self.alpha = kwargs['alpha']
self.beta = kwargs['beta']
self.eps = kwargs['eps']
self.memory = ReplayBuffer(kwargs['max_size'], self.state_dim)
self.target_update_interval = kwargs['target_update_interval']
self.n_updates = 0
@@ -126,8 +129,8 @@ def __init__(self, *args, **kwargs):
self.freeze_network(self.target_Q)
self.target_Q.name = kwargs['algorithm'] + '_' + kwargs['env_name'] + '_target'

self.optimizer = torch.optim.RMSprop(self.Q_function.parameters(), lr=self.alpha, alpha=0.95)
self.criterion = torch.nn.MSELoss()
self.optimizer = torch.optim.RMSprop(self.Q_function.parameters(), lr=self.lr, alpha=0.95)
self.criterion = torch.nn.MSELoss(reduction='none')

def greedy_action(self, observation):
observation = torch.tensor(observation, dtype=torch.float32).unsqueeze(0).to(self.device)
@@ -145,22 +148,33 @@ def freeze_network(self, network):
p.requires_grad = False

def update(self):
# Q_t = Q_t + alpha * (reward + gamma * Q'_t - Q^target_t) ** 2
# Q_t = Q_t + lr * (reward + gamma * Q'_t - Q^target_t) ** 2
# keep sampling until we have full batch
if self.memory.ctr < self.batch_size:
return
self.optimizer.zero_grad()
observations, rewards, actions, next_observations, dones = self.sample_transitions()
observations, rewards, actions, next_observations, dones, idx, weights = self.sample_transitions()

# double DQN use online network to select action for target
# double DQN uses online network to select action for Q'
if self.algorithm.endswith('DDQN'):
next_actions = self.Q_function(next_observations).argmax(-1)
q_prime = self.target_Q(next_observations)[list(range(self.batch_size)), next_actions]
elif self.algorithm.endswith('DQN'):
q_prime = self.target_Q(next_observations).max(-1)[0]
q_target = rewards.to(self.device) + self.gamma * q_prime * (~dones)

# calculate target + estimate
q_target = rewards + self.gamma * q_prime * (~dones)
q_pred = self.Q_function(observations)[list(range(self.batch_size)), actions]
loss = self.criterion(q_target, q_pred)
loss = self.criterion(q_target.detach(), q_pred)

# for updating priorities if using priority replay
if self.prioritize:
priorities = (idx, loss.clone().detach() + self.eps)
else:
priorities = None

# update
loss = (loss * weights).mean()
loss.backward()
if self.grad_clip is not None:
torch.nn.utils.clip_grad_norm_(self.Q_function.parameters(), self.grad_clip)
@@ -169,13 +183,6 @@ def update(self):
self.n_updates += 1
if self.n_updates % self.target_update_interval == 0:
self.update_target_network()

priorities = None
# if self.prioritize:
# # TODO
# pass
# else:
# priorities = None
return priorities

def decrease_epsilon(self):
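
The update() above supports both target rules. As a standalone sketch (function and argument names illustrative, not repo code): vanilla DQN bootstraps from max_a Q_target(s', a), while Double DQN selects the next action with the online network and evaluates it with the target network, which reduces overestimation of the bootstrap value.

import torch

def td_target(rewards, next_obs, dones, gamma, q_online, q_target, double=True):
    # rewards: [B], dones: [B] bool, next_obs: [B, ...]; q_* map a batch of obs to [B, n_actions]
    with torch.no_grad():
        if double:
            next_actions = q_online(next_obs).argmax(-1)             # choose action with online net
            q_prime = q_target(next_obs).gather(1, next_actions.unsqueeze(1)).squeeze(1)
        else:
            q_prime = q_target(next_obs).max(-1)[0]                  # plain DQN: max over target net
        return rewards + gamma * q_prime * (~dones)
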
21 changes: 16 additions & 5 deletions qlearning/atari/main.py
@@ -12,7 +12,7 @@


parser = argparse.ArgumentParser(description='Q Learning Atari Agents')
parser.add_argument('--algorithm', type=str, default='DQN', help='The type of algorithm you wish to use.\nDQN\n \
parser.add_argument('--algorithm', type=str, default='DuelingDDQN', help='The type of algorithm you wish to use.\nDQN\n \
DDQN\n \
DuelingDQN\n \
DuelingDDQN')
@@ -33,14 +33,17 @@

# training
parser.add_argument('-e', '--n_episodes', '--epochs', type=int, default=1000, help='Number of episodes agent interacts with env')
parser.add_argument('--alpha', type=float, default=0.0001, help='Learning rate')
parser.add_argument('--lr', type=float, default=0.00025, help='Learning rate')
parser.add_argument('--gamma', type=float, default=0.99, help='Discount factor')
parser.add_argument('--prioritize', action="store_false", default=True, help='Use Prioritized Experience Replay')
parser.add_argument('--epsilon_init', type=float, default=1.0, help='Initial epsilon value')
parser.add_argument('--epsilon_min', type=float, default=0.1, help='Minimum epsilon value to decay to')
parser.add_argument('--epsilon_desc', type=float, default=1e-5, help='Epsilon decrease')
parser.add_argument('--grad_clip', type=float, default=10, help='Norm of the grad clip, None for no clip')
parser.add_argument('-b', '--batch_size', type=int, default=32, help='Batch size')
parser.add_argument('--no_prioritize', action="store_true", default=False, help='Use Prioritized Experience Replay')
parser.add_argument('--alpha', type=float, default=0.6, help='Prioritized Experience Replay alpha')
parser.add_argument('--beta', type=float, default=0.4, help='Prioritized Experience Replay beta')
parser.add_argument('--eps', type=float, default=1e-5, help='Prioritized Experience Replay epsilon')

# logging
parser.add_argument('--progress_window', type=int, default=100, help='Window of episodes for progress')
@@ -72,7 +75,7 @@
agent = DQNAgent(env.observation_space.shape,
env.action_space.n,
args.epsilon_init, args.epsilon_min, args.epsilon_desc,
args.gamma, args.alpha, args.n_episodes,
args.gamma, args.lr, args.n_episodes,
input_channels=args.input_channels,
algorithm=args.algorithm,
img_size=args.img_size,
@@ -82,10 +85,18 @@
batch_size=args.batch_size,
cpt_dir=args.cpt_dir,
grad_clip=args.grad_clip,
prioritize=args.prioritize,
prioritize=not args.no_prioritize,
alpha=args.alpha,
beta=args.beta,
eps=args.eps,
env_name=args.env_name)
else:
raise NotImplementedError
# force some parameters depending on if using priority replay
if args.no_prioritize:
args.alpha, args.beta, args.epsilon = 1, 0, 0
else:
args.lr /= 4
scores, best_score = deque(maxlen=args.progress_window), -np.inf

# load weights & make sure model in eval mode during test
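
Assuming the script is run from the repository root, a training run with the new prioritized-replay flags could look like the following (invocation illustrative; unlisted arguments keep their defaults):

python qlearning/atari/main.py --algorithm DuelingDDQN --alpha 0.6 --beta 0.4 --eps 1e-5
python qlearning/atari/main.py --algorithm DuelingDDQN --no_prioritize   # uniform replay instead
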
16 changes: 11 additions & 5 deletions qlearning/experience_replay.py
@@ -2,14 +2,15 @@


class ReplayBuffer:
def __init__(self, max_size, state_dim):
def __init__(self, max_size, state_dim, alpha=1):
self.states = torch.empty(max_size, *state_dim)
self.rewards = torch.empty(max_size)
self.actions = torch.zeros(max_size, dtype=torch.long)
self.next_states = torch.empty(max_size, *state_dim)
self.dones = torch.ones(max_size, dtype=torch.bool)
self.priorities = torch.zeros(max_size)
self.max_size = max_size
self.alpha = alpha
self.ctr = 0

def store(self, state, reward, action, next_state, done, priority=None):
@@ -20,18 +21,23 @@ def store(self, state, reward, action, next_state, done, priority=None):
self.next_states[i] = next_state.cpu()
self.dones[i] = done
if priority is not None:
self.priorities[i] = priority
idx, priority = priority
self.priorities[idx] = priority.cpu()
else:
self.priorities[i] = 1
self.ctr += 1

def sample(self, batch_size, device):
def sample(self, batch_size, device, beta=0):
max_mem = min(self.ctr, self.max_size)
assert max_mem > 0
idx = torch.multinomial(self.priorities, batch_size)
sample_distribution = self.priorities ** self.alpha
sample_distribution /= sample_distribution.sum()
idx = torch.multinomial(sample_distribution, batch_size)
states = self.states[idx].to(device)
rewards = self.rewards[idx].to(device)
actions = self.actions[idx]
next_states = self.next_states[idx].to(device)
dones = self.dones[idx].to(device)
return states, rewards, actions, next_states, dones
weights = ((max_mem * sample_distribution[idx]) ** (- beta)).to(device)
weights /= weights.max()
return states, rewards, actions, next_states, dones, idx, weights
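
A minimal usage sketch of the updated buffer (import path, shapes, and values assumed, not taken from the repo): fresh transitions get priority 1, sample() now also returns the sampled indices and importance weights, and refreshed priorities flow back in through store()'s priority=(idx, values) argument, as the agent's update() does above.

import torch
from qlearning.experience_replay import ReplayBuffer   # assumed import path

buffer = ReplayBuffer(max_size=1000, state_dim=(4, 84, 84), alpha=0.6)
obs = torch.zeros(4, 84, 84)
for _ in range(64):
    buffer.store(obs, reward=0.0, action=0, next_state=obs, done=False)

states, rewards, actions, next_states, dones, idx, weights = buffer.sample(32, device='cpu', beta=0.4)
# later, refreshed priorities (per-sample loss + eps in the agent above) are written back, e.g.
# buffer.store(state, reward, action, next_state, done, priority=(idx, per_sample_loss.detach() + 1e-5))
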
6 changes: 3 additions & 3 deletions qlearning/frozen_lake/main.py
@@ -12,7 +12,7 @@
parser.add_argument('--policy', type=str, default='naive neural', help='The type of policy you wish to use')
parser.add_argument('--trailing_n', type=int, default=10, help='Window size of plotting win %')
parser.add_argument('--n_episodes', type=int, default=10000, help='Number of episodes agent interacts with env')
parser.add_argument('--alpha', type=float, default=0.001, help='Learning rate')
parser.add_argument('--lr', type=float, default=0.001, help='Learning rate')
parser.add_argument('--gamma', type=float, default=0.99, help='Discount factor')
parser.add_argument('--epsilon_init', type=float, default=1.0, help='Initial epsilon value')
parser.add_argument('--epsilon_min', type=float, default=0.01, help='Minimum epsilon value to decay to')
@@ -35,12 +35,12 @@ def __init__(self, policy):
self.agent = TabularAgent(np.arange(env.observation_space.n),
np.arange(env.action_space.n),
args.epsilon_init, args.epsilon_min, args.epsilon_desc,
args.gamma, args.alpha, args.n_episodes)
args.gamma, args.lr, args.n_episodes)
elif policy == 'naive neural':
self.agent = NaiveNeuralAgent(np.arange(env.observation_space.n),
np.arange(env.action_space.n),
args.epsilon_init, args.epsilon_min, args.epsilon_desc,
args.gamma, args.alpha, args.n_episodes,
args.gamma, args.lr, args.n_episodes,
policy=policy, state_dim=64, action_dim=64, hidden_dim=128)

def __call__(self, state):
