Commit ffde0ba

added policy gradient methods
1 parent c03f3ab commit ffde0ba

File tree

8 files changed: +301 -75 lines changed


policy/agent.py

Lines changed: 124 additions & 0 deletions
@@ -0,0 +1,124 @@
import numpy as np


class BlackJackAgent:
    def __init__(self, method, env, function='V', gamma=0.99, epsilon=0.1):
        self.method = method
        self.values = {(i, j, b): 0 for i in range(env.observation_space[0].n) for j in range(env.observation_space[1].n) for b in [True, False]}
        self.vreturns = {(i, j, b): [] for i in range(env.observation_space[0].n) for j in range(env.observation_space[1].n) for b in [True, False]}
        # optimistic initial Q-values (10) encourage early exploration
        self.qs = {(i, j, b, a): 10 for i in range(env.observation_space[0].n) for j in range(env.observation_space[1].n) for b in [True, False] for a in range(env.action_space.n)}
        self.qreturns = {(i, j, b, a): [] for i in range(env.observation_space[0].n) for j in range(env.observation_space[1].n) for b in [True, False] for a in range(env.action_space.n)}
        self.value_function = lambda i, j, k: self.values[(i, j, k)]
        self.q_function = lambda i, j, k, l: self.qs[(i, j, k, l)]
        self.get_state_name = lambda state: (state[0], state[1], state[2])
        self.get_state_action_name = lambda state, action: (state[0], state[1], state[2], action)
        self.gamma = gamma
        self.actions = list(range(env.action_space.n))
        self.policy = {state: 0 for state in self.values.keys()}
        self.epsilon = epsilon
        self.function = function

    def choose_action(self, state):
        sum_, show, ace = state
        if self.method == 'lucky':
            return self.feeling_lucky(sum_)
        if self.method == 'egreedy':
            return self.epsilon_greedy(state)

    def epsilon_greedy(self, state):
        if np.random.random() < self.epsilon:
            return np.random.choice(self.actions)
        else:
            state_name = self.get_state_name(state)
            return self.policy[state_name]

    def feeling_lucky(self, sum_):
        # naive baseline: hit on anything below 20, otherwise stick
        if sum_ < 20:
            return 1
        return 0

    def update(self, rewards, states, actions, function='V'):
        visited = set()
        if self.function == 'V':
            # first-visit Monte Carlo estimate of V
            for i, state in enumerate(states):
                state_name = self.get_state_name(state)
                if state_name in visited:
                    continue
                G = 0
                # discount starts at gamma**0 for the immediate reward
                for j, reward in enumerate(rewards[i:]):
                    G += self.gamma ** j * reward
                self.vreturns[state_name].append(G)
                self.values[state_name] = np.mean(self.vreturns[state_name])
                visited.add(state_name)
        elif self.function == 'Q':
            # first-visit Monte Carlo estimate of Q, followed by greedy policy improvement
            for i, (state, action) in enumerate(zip(states, actions)):
                state_action_name = self.get_state_action_name(state, action)
                if state_action_name in visited:
                    continue
                G = 0
                for j, reward in enumerate(rewards[i:]):
                    G += self.gamma ** j * reward
                self.qreturns[state_action_name].append(G)
                self.qs[state_action_name] = np.mean(self.qreturns[state_action_name])
                visited.add(state_action_name)
            for state in states:
                Q_prime, A_prime = -np.inf, None
                # improve greedily over all available actions
                for action in self.actions:
                    state_action_name = self.get_state_action_name(state, action)
                    curr_Q = self.qs[state_action_name]
                    if curr_Q > Q_prime:
                        Q_prime = curr_Q
                        A_prime = action
                state_name = self.get_state_name(state)
                self.policy[state_name] = A_prime
        else:
            raise NotImplementedError


class CartPoleNoob:
    def __init__(self, method, env, function='V', alpha=0.1, gamma=0.99, epsilon=0.1, n_bins=10):
        self.method = method
        self.alpha = alpha
        self.gamma = gamma
        self.epsilon = epsilon
        self.function = function
        self.actions = list(range(env.action_space.n))
        # bin edges over the pole-angle range (about +/- 12 degrees in radians)
        self.rad = np.linspace(-0.2094, 0.2094, n_bins)
        self.values = {r: 0 for r in range(len(self.rad) + 1)}
        self.qs = {(r, a): 10 for r in range(len(self.rad) + 1) for a in self.actions}

    def choose_action(self, state):
        if self.method == 'naive':
            return self.naive_action(state)
        if self.method == 'egreedy':
            return self.epsilon_greedy(state)

    def naive_action(self, state):
        # push the cart toward the side the pole is leaning
        if state[2] < 0:
            return 0
        return 1

    def epsilon_greedy(self, state):
        if np.random.random() < self.epsilon:
            return np.random.choice(self.actions)
        else:
            s = self.get_bucket_index([state[2]])[0]
            action = np.array([self.qs[(s, a)] for a in self.actions]).argmax()
            return action

    def get_bucket_index(self, states):
        inds = np.digitize(states, self.rad)
        return inds

    def update(self, state, action, reward, state_):
        r, r_ = self.get_bucket_index([state[2], state_[2]])
        if self.function == 'V':
            # TD(0) update w/ bootstrap
            self.values[r] += self.alpha * (reward + self.gamma * self.values[r_] - self.values[r])
        elif self.function == 'Q':
            # Q-learning: bootstrap from the greedy action in the next state
            Q_ = np.array([self.qs[(r_, a)] for a in self.actions]).max()
            self.qs[(r, action)] += self.alpha * (reward + self.gamma * Q_ - self.qs[(r, action)])
        self.decrease_eps()

    def decrease_eps(self):
        # linear epsilon decay with a floor of 0.01
        self.epsilon = max(0.01, self.epsilon - 1e-5)
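
For context on the loops in `update` above: both Monte Carlo branches accumulate a discounted return for each first visit. A minimal standalone sketch of that computation, written backwards in a single pass (the helper name `discounted_returns` is illustrative, not something defined in this commit):

def discounted_returns(rewards, gamma=0.99):
    # G_t = r_{t+1} + gamma * G_{t+1}, with the return after the last step equal to 0
    G, out = 0.0, []
    for r in reversed(rewards):
        G = r + gamma * G
        out.append(G)
    return list(reversed(out))

print(discounted_returns([0, 0, 1]))  # approximately [0.9801, 0.99, 1.0]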

policy/blackjack/main.py

Lines changed: 34 additions & 0 deletions
@@ -0,0 +1,34 @@
import gym
import argparse
from tqdm import trange
from policy.agent import BlackJackAgent


parser = argparse.ArgumentParser(description='Black Jack Agents')
parser.add_argument('--method', type=str, default='lucky', help='The name of the policy you wish to evaluate')
parser.add_argument('--function', type=str, default='Q', help='The function to evaluate')
parser.add_argument('--n_episodes', type=int, default=500000, help='Number of episodes you wish to run for')
args = parser.parse_args()


def first_visit_monte_carlo():
    env = gym.make('Blackjack-v0')
    agent = BlackJackAgent(args.method, env, args.function)
    for _ in trange(args.n_episodes):
        state, done = env.reset(), False
        states, actions, rewards = [state], [], []
        while not done:
            action = agent.choose_action(state)
            state_, reward, done, _ = env.step(action)
            # append the successor so that states[i] lines up with rewards[i:]
            states.append(state_)
            rewards.append(reward)
            actions.append(action)
            state = state_
        agent.update(rewards, states, actions)

    print(agent.value_function(21, 2, True))
    print(agent.q_function(16, 2, False, 0))


if __name__ == '__main__':
    first_visit_monte_carlo()
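
For reference, the two printed entries follow Gym Blackjack-v0's observation layout of (player sum, dealer showing card, usable ace): the first is the estimated value of holding 21 with a usable ace against a dealer 2, the second is the action value of sticking (action 0) on a hard 16 against a dealer 2. The quantity being averaged is the textbook first-visit Monte Carlo estimate (a summary, not text from the commit):

$$
V(s) \approx \frac{1}{N(s)} \sum_{i=1}^{N(s)} G_i,
\qquad
G_t = \sum_{k=0}^{T-t-1} \gamma^{k}\, r_{t+k+1},
$$

where N(s) counts the episodes in which s is visited and G_i is the return following the first visit to s in episode i; Q(s, a) averages returns following first visits to each state-action pair in the same way.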

policy/cartpole/main.py

Lines changed: 27 additions & 0 deletions
@@ -0,0 +1,27 @@
import gym
import argparse
from tqdm import trange
from policy.agent import CartPoleNoob


parser = argparse.ArgumentParser(description='Cartpole Agents')
parser.add_argument('--method', type=str, default='egreedy', help='The name of the policy you wish to evaluate')
parser.add_argument('--function', type=str, default='Q', help='The function to evaluate')
parser.add_argument('--n_episodes', type=int, default=500000, help='Number of episodes you wish to run for')
args = parser.parse_args()


def td():
    env = gym.make('CartPole-v0')
    agent = CartPoleNoob(args.method, env, args.function)
    for _ in trange(args.n_episodes):
        state, done = env.reset(), False
        while not done:
            action = agent.choose_action(state)
            state_, reward, done, _ = env.step(action)
            agent.update(state, action, reward, state_)
            state = state_
    print(agent.values)

if __name__ == '__main__':
    td()
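
The cartpole agent keys its tables by discretizing the pole angle (state[2]) with np.digitize; a small illustration of that bucketing (the angle values below are made up for the example):

import numpy as np

rad = np.linspace(-0.2094, 0.2094, 10)   # same bin edges as CartPoleNoob
angles = [-0.3, -0.01, 0.0, 0.15, 0.3]   # hypothetical pole angles in radians
print(np.digitize(angles, rad))          # [ 0  5  5  8 10], i.e. bucket ids in 0..10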

qlearning/agent.py

Lines changed: 37 additions & 40 deletions
@@ -3,7 +3,8 @@
 import torch
 from torch import optim
 from copy import deepcopy
-from qlearning.networks import QNaive, QBasic, QDueling
+from qlearning import networks
+from qlearning.networks import QNaive
 from qlearning.experience_replay import ReplayBuffer


@@ -95,48 +96,40 @@ def decrease_epsilon(self):
 class DQNAgent(BaseAgent):
     def __init__(self, *args, **kwargs):
         super().__init__(*args)
-        self.algorithm = kwargs['algorithm']
-        self.batch_size = kwargs['batch_size']
-        self.grad_clip = kwargs['grad_clip']
-        self.prioritize = kwargs['prioritize']
-        self.alpha = kwargs['alpha']
-        self.beta = kwargs['beta']
-        self.eps = kwargs['eps']
-        self.memory = ReplayBuffer(kwargs['max_size'], self.state_dim)
-        self.target_update_interval = kwargs['target_update_interval']
+        for k, v in kwargs.items():
+            setattr(self, k, v)
+        self.memory = ReplayBuffer(self.max_size, self.state_dim)
         self.n_updates = 0
-
         self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
-        if self.algorithm.startswith('Dueling'):
-            self.Q_function = QDueling(
-                kwargs['input_channels'],
-                self.n_actions,
-                kwargs['cpt_dir'],
-                kwargs['algorithm'] + '_' + kwargs['env_name'],
-                kwargs['img_size'],
-                kwargs['hidden_dim'],
-                noised=kwargs['noised']).to(self.device)
-        else:
-            self.Q_function = QBasic(
-                kwargs['input_channels'],
-                self.n_actions,
-                kwargs['cpt_dir'],
-                kwargs['algorithm'] + '_' + kwargs['env_name'],
-                kwargs['img_size'],
-                kwargs['hidden_dim'],
-                noised=kwargs['noised']).to(self.device)
+
+        network = self.algorithm
+        if 'DD' in network:
+            import re
+            network = re.sub('DDQN', 'DQN', network)
+        network = getattr(networks, network)
+        self.Q_function = network(
+            input_channels=self.input_channels,
+            out_features=self.n_actions,
+            cpt_dir=self.cpt_dir,
+            name=self.algorithm + '_' + self.env_name,
+            img_size=self.img_size,
+            hidden_dim=self.hidden_dim,
+            n_repeats=self.n_repeats,
+            noised=self.noised,
+            num_atoms=self.num_atoms).to(self.device)

         # instantiate target network
         self.target_Q = deepcopy(self.Q_function)
         self.freeze_network(self.target_Q)
-        self.target_Q.name = kwargs['algorithm'] + '_' + kwargs['env_name'] + '_target'
+        self.target_Q.name = self.algorithm + '_' + self.env_name + '_target'

         self.optimizer = torch.optim.RMSprop(self.Q_function.parameters(), lr=self.lr, alpha=0.95)
         self.criterion = torch.nn.MSELoss(reduction='none')

     def greedy_action(self, observation):
-        observation = torch.tensor(observation, dtype=torch.float32).unsqueeze(0).to(self.device)
-        next_action = self.Q_function(observation).argmax()
+        with torch.no_grad():
+            observation = torch.tensor(observation, dtype=torch.float32).unsqueeze(0).to(self.device)
+            next_action = self.Q_function(observation).argmax()
         return next_action.item()

     def update_target_network(self):
@@ -161,18 +154,18 @@ def update(self):
         # double DQN uses online network to select action for Q'
         if self.algorithm.endswith('DDQN'):
             next_actions = self.Q_function(next_observations).argmax(-1)
-            q_prime = self.target_Q(next_observations)[list(range(self.batch_size)), next_actions]
+            q_prime = self.target_Q(next_observations).gather(1, next_actions.unsqueeze(1))
         elif self.algorithm.endswith('DQN'):
             q_prime = self.target_Q(next_observations).max(-1)[0]

         # calculate target + estimate
-        q_target = rewards + self.gamma * q_prime * (~dones)
-        q_pred = self.Q_function(observations)[list(range(self.batch_size)), actions]
-        loss = self.criterion(q_target.detach(), q_pred)
+        q_target = rewards + self.gamma * q_prime.squeeze() * (~dones)
+        q_pred = self.Q_function(observations).gather(1, actions.unsqueeze(1))
+        loss = self.criterion(q_target.detach(), q_pred.squeeze())

         # for updating priorities if using priority replay
         if self.prioritize:
-            priorities = (idx, loss.clone().detach() + self.eps)
+            priorities = (idx, loss.detach().cpu() + self.eps)
         else:
             priorities = None

@@ -182,13 +175,16 @@ def update(self):
         if self.grad_clip is not None:
             torch.nn.utils.clip_grad_norm_(self.Q_function.parameters(), self.grad_clip)
         self.optimizer.step()
-        self.decrease_epsilon()
+        self.adjust_epsilon_and_beta()
         self.n_updates += 1
         if self.n_updates % self.target_update_interval == 0:
             self.update_target_network()
         return priorities

-    def decrease_epsilon(self):
+    def adjust_epsilon_and_beta(self):
+        self.beta = min(
+            self.beta_min,
+            self.beta + self.beta_dec)
         self.epsilon = max(
             self.epsilon_min,
             self.epsilon - self.epsilon_desc)
@@ -198,7 +194,8 @@ def store_transition(self, state, reward, action, next_state, done, priority=Non
         self.memory.store(state, reward, action, next_state, done, priority=priority)

     def sample_transitions(self):
-        return self.memory.sample(self.batch_size, self.device)
+        transition = self.memory.sample(self.batch_size, self.device, self.beta)
+        return transition

     def save_models(self):
         self.target_Q.check_point()
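
The change from list-based advanced indexing to torch.gather in this diff selects one Q-value per batch row; a minimal sketch of why the two forms agree (the tensor shapes and values below are illustrative only):

import torch

q = torch.arange(12.0).reshape(4, 3)            # batch of 4 states, 3 actions each
actions = torch.tensor([2, 0, 1, 2])
old_style = q[list(range(4)), actions]          # advanced indexing, shape (4,)
new_style = q.gather(1, actions.unsqueeze(1))   # gather, shape (4, 1)
assert torch.equal(old_style, new_style.squeeze(1))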
