Commit c12979c (parent 3f0eb49)
2 changed files, 141 additions and 0 deletions.
qlearning/frozen_lake/agent.py
@@ -0,0 +1,39 @@
import pandas as pd
import numpy as np


class Tabular_Agent:
    def __init__(self, states, actions, epsilon_init, epsilon_min, epsilon_desc, gamma, alpha, n_episodes):
        self.actions = actions
        self.n_actions = len(actions)
        self.states = states
        self.n_states = len(states)
        self.epsilon = epsilon_init
        self.epsilon_min = epsilon_min
        self.epsilon_desc = epsilon_desc
        self.gamma = gamma
        self.alpha = alpha
        self.n_episodes = n_episodes
        # initialize the Q-table with zero Q-values, indexed by state and action
        self.q_table = pd.DataFrame(np.zeros((self.n_states, self.n_actions)),
                                    index=states, columns=actions)

    def greedy_action(self, state):
        # pick the action with the highest Q-value for this state
        # (idxmax returns the column label, which equals the positional argmax here
        # because actions are 0..n_actions-1)
        Qs = self.q_table.loc[state]
        return Qs.idxmax()

    def strategize(self, state):
        # epsilon-greedy: exploit with probability 1 - epsilon, otherwise explore
        if np.random.random() > self.epsilon:
            return self.greedy_action(state)
        return np.random.choice(self.actions)

    def update(self, state, action, reward, next_state):
        # one-step Q-learning update of the Q-table
        max_Q_ = self.q_table.loc[next_state].max()
        Q_sa = self.q_table.loc[state, action]
        self.q_table.loc[state, action] += self.alpha * (
            reward + self.gamma * max_Q_ - Q_sa)
        # decay epsilon, never dropping below epsilon_min
        self.epsilon = max(
            self.epsilon_min,
            self.epsilon * self.epsilon_desc)
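The update method above implements the standard one-step tabular Q-learning rule, Q(s, a) <- Q(s, a) + alpha * (reward + gamma * max_a' Q(s', a') - Q(s, a)), followed by a multiplicative epsilon decay. As a rough illustration of how the class is driven, here is a minimal single-episode sketch; it assumes the same FrozenLake-v0 environment, old gym reset/step API, and hyperparameter defaults used by the training script that follows, and the n_episodes value passed here is arbitrary since the agent only stores it.

import gym
import numpy as np
from qlearning.frozen_lake.agent import Tabular_Agent

env = gym.make('FrozenLake-v0')
# hyperparameters mirror the training script's defaults; n_episodes is illustrative
agent = Tabular_Agent(states=np.arange(env.observation_space.n),
                      actions=np.arange(env.action_space.n),
                      epsilon_init=1.0, epsilon_min=0.01, epsilon_desc=0.9999995,
                      gamma=0.9, alpha=0.001, n_episodes=1000)

state, done = env.reset(), False
while not done:
    action = agent.strategize(state)                   # epsilon-greedy action selection
    next_state, reward, done, info = env.step(action)  # old gym API: (obs, reward, done, info)
    agent.update(state, action, reward, next_state)    # one-step Q-update plus epsilon decay
    state = next_state
env.close()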
@@ -0,0 +1,102 @@
import matplotlib.pyplot as plt
import gym
import numpy as np
import argparse
from collections import deque
from qlearning.frozen_lake.agent import Tabular_Agent


parser = argparse.ArgumentParser()
parser.add_argument('--policy', type=str, default='Q', help='The type of policy to use: random, direct, or Q')
parser.add_argument('--trailing_n', type=int, default=10, help='Window size for plotting the win percentage')
parser.add_argument('--n_episodes', type=int, default=500000, help='Number of episodes the agent interacts with the env')
parser.add_argument('--alpha', type=float, default=0.001, help='Learning rate')
parser.add_argument('--gamma', type=float, default=0.9, help='Discount factor')
parser.add_argument('--epsilon_init', type=float, default=1.0, help='Initial epsilon value')
parser.add_argument('--epsilon_min', type=float, default=0.01, help='Minimum epsilon value to decay to')
parser.add_argument('--epsilon_desc', type=float, default=0.9999995, help='Epsilon decay multiplier applied after every update')
parser.add_argument('--progress_window', type=int, default=100, help='Window of episodes for progress reporting')
parser.add_argument('--print_every', type=int, default=1000, help='Print progress every this many episodes')
args = parser.parse_args()


class Policies:
    """
    FrozenLake action encoding:
    LEFT = 0
    DOWN = 1
    RIGHT = 2
    UP = 3
    """
    def __init__(self, policy):
        self.policy = policy
        if policy == 'Q':
            # relies on the module-level `env` created in __main__ before Policies is instantiated
            self.agent = Tabular_Agent(np.arange(env.observation_space.n),
                                       np.arange(env.action_space.n),
                                       args.epsilon_init, args.epsilon_min, args.epsilon_desc,
                                       args.gamma, args.alpha, args.n_episodes)

    def __call__(self, state):
        if self.policy == 'random':
            return env.action_space.sample()
        if self.policy == 'direct':
            return self._direct_policy(state)
        if self.policy == 'Q':
            return self._epsilon_greedy(state)

    def _direct_policy(self, state):
        # hand-coded route through the 4x4 FrozenLake map
        if state in [0, 4, 6, 9, 10]:
            return 1
        if state in [1, 2, 8, 13, 14]:
            return 2
        if state in [3]:
            return 0

    def _epsilon_greedy(self, state):
        return self.agent.strategize(state)

    def update(self, state, action, reward, next_state):
        self.agent.update(state, action, reward, next_state)


def plot_win_perc(scores, trailing_n, n_episodes):
    # trailing win percentage over a sliding window of trailing_n episodes
    win_perc = [sum(scores[i - trailing_n:i]) / trailing_n for i in range(trailing_n, n_episodes)]
    plt.figure(figsize=(12, 12))
    plt.plot(win_perc)
    plt.xlabel('Number of Trials')
    plt.ylabel('Winning Percentage')
    plt.title(f'Win % over Trailing {trailing_n} Games')
    plt.show()


def plot_avg_score(scores):
    plt.figure(figsize=(12, 12))
    plt.plot(scores)
    plt.ylabel('Score')
    plt.title(f'Average Score of Last {args.progress_window} Games')
    plt.show()


if __name__ == '__main__':
    env = gym.make('FrozenLake-v0')
    scores, avg_scores = [], []
    pi = Policies(args.policy)
    for i in range(args.n_episodes):
        done, observation, score = False, env.reset(), 0
        while not done:
            action = pi(observation)
            # old gym API: step returns (obs, reward, done, info)
            next_observation, reward, done, info = env.step(action)
            if args.policy == 'Q':
                pi.update(observation, action, reward, next_observation)
            score += reward
            observation = next_observation
        scores.append(score)
        # track the running average over the configured progress window
        avg_scores.append(np.mean(scores[-args.progress_window:]))
        if (i + 1) % args.print_every == 0 and args.policy == 'Q':
            print(f'Episode: {i + 1}/{args.n_episodes}, Average Score: {avg_scores[-1]}, Epsilon: {pi.agent.epsilon}')
    env.close()

    # plotting
    if args.policy == 'Q':
        plot_avg_score(avg_scores)
    else:
        plot_win_perc(scores, args.trailing_n, args.n_episodes)
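For reference, two example invocations. The training script's file name is not visible in this diff, so the module path qlearning.frozen_lake.main below is only a guess; the flag values simply restate the argparse defaults.

# train the tabular Q-learning agent (module path is assumed; adjust to the actual script name)
python -m qlearning.frozen_lake.main --policy Q --n_episodes 500000 --alpha 0.001 --gamma 0.9 --print_every 1000

# baseline: random policy, then plot the trailing win percentage over 10-game windows
python -m qlearning.frozen_lake.main --policy random --n_episodes 1000 --trailing_n 10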