added policy gradient methods
Andrewzh112 committed Feb 13, 2021
1 parent c03f3ab commit ffde0ba
Showing 8 changed files with 301 additions and 75 deletions.
124 changes: 124 additions & 0 deletions policy/agent.py
@@ -0,0 +1,124 @@
import numpy as np


class BlackJackAgent:
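# Tabular agent for Blackjack-v0: estimates state values (V) or action values (Q)
# from first-visit Monte Carlo returns and acts either with a fixed heuristic or
# an epsilon-greedy policy over the learned Q estimates.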
def __init__(self, method, env, function='V', gamma=0.99, epsilon=0.1):
self.method = method
self.values = {(i, j, b): 0 for i in range(env.observation_space[0].n) for j in range(env.observation_space[1].n) for b in [True, False]}
self.vreturns = {(i, j, b): [] for i in range(env.observation_space[0].n) for j in range(env.observation_space[1].n) for b in [True, False]}
self.qs = {(i, j, b, a): 10 for i in range(env.observation_space[0].n) for j in range(env.observation_space[1].n) for b in [True, False] for a in range(env.action_space.n)}
self.qreturns = {(i, j, b, a): [] for i in range(env.observation_space[0].n) for j in range(env.observation_space[1].n) for b in [True, False] for a in range(env.action_space.n)}
self.value_function = lambda i, j, k: self.values[(i, j, k)]
self.q_function = lambda i, j, k, l: self.qs[(i, j, k, l)]
self.get_state_name = lambda state: (state[0], state[1], state[2])
self.get_state_action_name = lambda state, action: (state[0], state[1], state[2], action)
self.gamma = gamma
self.actions = list(range(env.action_space.n))
self.policy = {state: 0 for state in self.values.keys()}
self.epsilon = epsilon
self.function = function

def choose_action(self, state):
sum_, show, ace = state
if self.method == 'lucky':
return self.feeling_lucky(sum_)
if self.method == 'egreedy':
return self.epsilon_greedy(state)

def epsilon_greedy(self, state):
if np.random.random() < self.epsilon:
return np.random.choice(self.actions)
else:
state_name = self.get_state_name(state)
return self.policy[state_name]

def feeling_lucky(self, sum_):
if sum_ < 20:
return 1
return 0

def update(self, rewards, states, actions):
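# First-visit Monte Carlo evaluation: for the first occurrence of each state
# (or state-action pair) in the episode, accumulate the discounted return
# G = r_t + gamma * r_{t+1} + gamma^2 * r_{t+2} + ... and average it across
# episodes to estimate V (or Q).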
visited = set()
if self.function == 'V':
for i, state in enumerate(states):
state_name = self.get_state_name(state)
if state_name in visited:
continue
G = 0
for j, reward in enumerate(rewards[i:]):
G += self.gamma ** j * reward
self.vreturns[state_name].append(G)
self.values[state_name] = np.mean(self.vreturns[state_name])
visited.add(state_name)
elif self.function == 'Q':
for i, (state, action) in enumerate(zip(states, actions)):
state_action_name = self.get_state_action_name(state, action)
if state_action_name in visited:
continue
G = 0
for j, reward in enumerate(rewards[i:]):
G += self.gamma ** j * reward
self.qreturns[state_action_name].append(G)
self.qs[state_action_name] = np.mean(self.qreturns[state_action_name])
visited.add(state_action_name)
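# greedy policy improvement: make the policy pick the highest-valued action
# in every state visited this episode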
for state in states:
Q_prime, A_prime = -np.inf, None
for action in self.actions:
state_action_name = self.get_state_action_name(state, action)
curr_Q = self.qs[state_action_name]
if curr_Q > Q_prime:
Q_prime = curr_Q
A_prime = action
state_name = self.get_state_name(state)
self.policy[state_name] = A_prime
else:
raise NotImplementedError


class CartPoleNoob:
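# Tabular TD agent for CartPole-v0: discretizes the pole angle into buckets and
# learns either state values via TD(0) or action values via Q-learning.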
def __init__(self, method, env, function='V', alpha=0.1, gamma=0.99, epsilon=0.1, n_bins=10):
self.method = method
self.alpha = alpha
self.gamma = gamma
self.epsilon = epsilon
self.function = function
self.actions = list(range(env.action_space.n))
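# bin edges spanning the pole-angle limits (roughly +/- 12 degrees in radians),
# used to discretize the continuous angle into len(self.rad) + 1 buckets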
self.rad = np.linspace(-0.2094, 0.2094, n_bins)
self.values = {r: 0 for r in range(len(self.rad) + 1)}
self.qs = {(r, a): 10 for r in range(len(self.rad) + 1) for a in self.actions}

def choose_action(self, state):
if self.method == 'naive':
return self.naive_action(state)
if self.method == 'egreedy':
return self.epsilon_greedy(state)

def naive_action(self, state):
if state[2] < 0:
return 0
return 1

def epsilon_greedy(self, state):
if np.random.random() < self.epsilon:
return np.random.choice(self.actions)
else:
s = self.get_bucket_index([state[2]])[0]
action = np.array([self.qs[(s, a)] for a in self.actions]).argmax()
return action

def get_bucket_index(self, states):
inds = np.digitize(states, self.rad)
return inds

def update(self, state, action, reward, state_):
r, r_ = self.get_bucket_index([state[2], state_[2]])
if self.function == 'V':
# TD update w/ bootstrap
self.values[r] += self.alpha * (reward + self.gamma * self.values[r_] - self.values[r])
elif self.function == 'Q':
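# Q-learning: bootstrap from the greedy action value in the next bucket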
Q_ = np.array([self.qs[(r_, a)] for a in self.actions]).max()
self.qs[(r, action)] += self.alpha * (reward + self.gamma * Q_ - self.qs[(r, action)])
self.decrease_eps()

def decrease_eps(self):
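# linearly anneal exploration, never dropping below 1% random actions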
self.epsilon = max(0.01, self.epsilon - 1e-5)
34 changes: 34 additions & 0 deletions policy/blackjack/main.py
@@ -0,0 +1,34 @@
import gym
import argparse
from tqdm import trange
from policy.agent import BlackJackAgent


parser = argparse.ArgumentParser(description='Black Jack Agents')
parser.add_argument('--method', type=str, default='lucky', help='The name of the policy you wish to evaluate')
parser.add_argument('--function', type=str, default='Q', help='The function to evaluate')
parser.add_argument('--n_episodes', type=int, default=500000, help='Number of episodes you wish to run for')
args = parser.parse_args()


def first_visit_monte_carlo():
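# Monte Carlo needs complete episodes: roll the policy out to termination,
# then update the estimates from the recorded states, actions, and rewards.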
env = gym.make('Blackjack-v0')
agent = BlackJackAgent(args.method, env, args.function)
for _ in trange(args.n_episodes):
state, done = env.reset(), False
states, actions, rewards = [], [], []
while not done:
action = agent.choose_action(state)
state_, reward, done, _ = env.step(action)
states.append(state)
rewards.append(reward)
actions.append(action)
state = state_
agent.update(rewards, states, actions)

print(agent.value_function(21, 2, True))
print(agent.q_function(16, 2, False, 0))


if __name__ == '__main__':
first_visit_monte_carlo()
27 changes: 27 additions & 0 deletions policy/cartpole/main.py
@@ -0,0 +1,27 @@
import gym
import argparse
from tqdm import trange
from policy.agent import CartPoleNoob


parser = argparse.ArgumentParser(description='Cartpole Agents')
parser.add_argument('--method', type=str, default='egreedy', help='The name of the policy you wish to evaluate')
parser.add_argument('--function', type=str, default='Q', help='The function to evaluate')
parser.add_argument('--n_episodes', type=int, default=500000, help='Number of episodes you wish to run for')
args = parser.parse_args()


def td():
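# TD learning updates online after every single transition, so no episode
# buffer is needed here.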
env = gym.make('CartPole-v0')
agent = CartPoleNoob(args.method, env, args.function)
for _ in trange(args.n_episodes):
state, done = env.reset(), False
while not done:
action = agent.choose_action(state)
state_, reward, done, _ = env.step(action)
agent.update(state, action, reward, state_)
state = state_
print(agent.values if args.function == 'V' else agent.qs)

if __name__ == '__main__':
td()
77 changes: 37 additions & 40 deletions qlearning/agent.py
@@ -3,7 +3,8 @@
import torch
from torch import optim
from copy import deepcopy
from qlearning.networks import QNaive, QBasic, QDueling
from qlearning import networks
from qlearning.networks import QNaive
from qlearning.experience_replay import ReplayBuffer


@@ -95,48 +96,40 @@ def decrease_epsilon(self):
class DQNAgent(BaseAgent):
def __init__(self, *args, **kwargs):
super().__init__(*args)
self.algorithm = kwargs['algorithm']
self.batch_size = kwargs['batch_size']
self.grad_clip = kwargs['grad_clip']
self.prioritize = kwargs['prioritize']
self.alpha = kwargs['alpha']
self.beta = kwargs['beta']
self.eps = kwargs['eps']
self.memory = ReplayBuffer(kwargs['max_size'], self.state_dim)
self.target_update_interval = kwargs['target_update_interval']
for k, v in kwargs.items():
setattr(self, k, v)
self.memory = ReplayBuffer(self.max_size, self.state_dim)
self.n_updates = 0

self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
if self.algorithm.startswith('Dueling'):
self.Q_function = QDueling(
kwargs['input_channels'],
self.n_actions,
kwargs['cpt_dir'],
kwargs['algorithm'] + '_' + kwargs['env_name'],
kwargs['img_size'],
kwargs['hidden_dim'],
noised=kwargs['noised']).to(self.device)
else:
self.Q_function = QBasic(
kwargs['input_channels'],
self.n_actions,
kwargs['cpt_dir'],
kwargs['algorithm'] + '_' + kwargs['env_name'],
kwargs['img_size'],
kwargs['hidden_dim'],
noised=kwargs['noised']).to(self.device)

network = self.algorithm
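# double-DQN variants reuse their non-double network class, so map 'DDQN'
# names onto the corresponding 'DQN' class before the attribute lookup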
if 'DD' in network:
network = network.replace('DDQN', 'DQN')
network = getattr(networks, network)
self.Q_function = network(
input_channels=self.input_channels,
out_features=self.n_actions,
cpt_dir=self.cpt_dir,
name=self.algorithm + '_' + self.env_name,
img_size=self.img_size,
hidden_dim=self.hidden_dim,
n_repeats=self.n_repeats,
noised=self.noised,
num_atoms=self.num_atoms).to(self.device)

# instantiate target network
self.target_Q = deepcopy(self.Q_function)
self.freeze_network(self.target_Q)
self.target_Q.name = kwargs['algorithm'] + '_' + kwargs['env_name'] + '_target'
self.target_Q.name = self.algorithm + '_' + self.env_name + '_target'

self.optimizer = torch.optim.RMSprop(self.Q_function.parameters(), lr=self.lr, alpha=0.95)
self.criterion = torch.nn.MSELoss(reduction='none')

def greedy_action(self, observation):
observation = torch.tensor(observation, dtype=torch.float32).unsqueeze(0).to(self.device)
next_action = self.Q_function(observation).argmax()
with torch.no_grad():
observation = torch.tensor(observation, dtype=torch.float32).unsqueeze(0).to(self.device)
next_action = self.Q_function(observation).argmax()
return next_action.item()

def update_target_network(self):
@@ -161,18 +154,18 @@ def update(self):
# double DQN uses online network to select action for Q'
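# while the target network evaluates that action, reducing maximization bias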
if self.algorithm.endswith('DDQN'):
next_actions = self.Q_function(next_observations).argmax(-1)
q_prime = self.target_Q(next_observations)[list(range(self.batch_size)), next_actions]
q_prime = self.target_Q(next_observations).gather(1, next_actions.unsqueeze(1))
elif self.algorithm.endswith('DQN'):
q_prime = self.target_Q(next_observations).max(-1)[0]

# calculate target + estimate
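# TD target: y = r + gamma * Q'(s', a*), with the bootstrap term zeroed out
# on terminal transitions via (~dones)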
q_target = rewards + self.gamma * q_prime * (~dones)
q_pred = self.Q_function(observations)[list(range(self.batch_size)), actions]
loss = self.criterion(q_target.detach(), q_pred)
q_target = rewards + self.gamma * q_prime.squeeze() * (~dones)
q_pred = self.Q_function(observations).gather(1, actions.unsqueeze(1))
loss = self.criterion(q_target.detach(), q_pred.squeeze())

# for updating priorities if using priority replay
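# per-sample TD losses (reduction='none') plus a small eps keep every
# transition's sampling probability strictly positive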
if self.prioritize:
priorities = (idx, loss.clone().detach() + self.eps)
priorities = (idx, loss.detach().cpu() + self.eps)
else:
priorities = None

@@ -182,13 +175,16 @@ def update(self):
if self.grad_clip is not None:
torch.nn.utils.clip_grad_norm_(self.Q_function.parameters(), self.grad_clip)
self.optimizer.step()
self.decrease_epsilon()
self.adjust_epsilon_and_beta()
self.n_updates += 1
if self.n_updates % self.target_update_interval == 0:
self.update_target_network()
return priorities

def decrease_epsilon(self):
def adjust_epsilon_and_beta(self):
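# anneal the importance-sampling exponent beta up toward its cap and decay
# the exploration rate epsilon down toward its floor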
self.beta = min(
self.beta_min,
self.beta + self.beta_dec)
self.epsilon = max(
self.epsilon_min,
self.epsilon - self.epsilon_desc)
@@ -198,7 +194,8 @@ def store_transition(self, state, reward, action, next_state, done, priority=Non
self.memory.store(state, reward, action, next_state, done, priority=priority)

def sample_transitions(self):
return self.memory.sample(self.batch_size, self.device)
transition = self.memory.sample(self.batch_size, self.device, self.beta)
return transition

def save_models(self):
self.target_Q.check_point()