-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
c03f3ab
commit ffde0ba
Showing
8 changed files
with
301 additions
and
75 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,124 @@ | ||
import numpy as np | ||
|
||
|
||
class BlackJackAgent:
    """Tabular first-visit Monte Carlo agent with V- and Q-function modes.

    States are 3-item sequences; ``choose_action`` unpacks them as
    ``(sum_, show, ace)`` (presumably player sum, dealer's shown card and
    usable-ace flag -- confirm against the environment).

    Args:
        method: Action-selection rule, ``'lucky'`` or ``'egreedy'``.
        env: Environment exposing ``observation_space[0].n``,
            ``observation_space[1].n`` and ``action_space.n``.
        function: ``'V'`` to estimate state values, ``'Q'`` to estimate
            action values and improve the greedy policy.
        gamma: Discount factor applied to future rewards.
        epsilon: Exploration probability used by ``epsilon_greedy``.
    """

    def __init__(self, method, env, function='V', gamma=0.99, epsilon=0.1):
        self.method = method
        self.function = function
        self.gamma = gamma
        self.epsilon = epsilon
        self.actions = list(range(env.action_space.n))

        state_names = [
            (i, j, b)
            for i in range(env.observation_space[0].n)
            for j in range(env.observation_space[1].n)
            for b in [True, False]
        ]
        state_action_names = [
            (i, j, b, a) for (i, j, b) in state_names for a in self.actions
        ]

        # Value estimates and the per-key lists of sampled returns they
        # are averaged from.
        self.values = {name: 0 for name in state_names}
        self.vreturns = {name: [] for name in state_names}
        # Q estimates start at 10 for every state-action pair.
        self.qs = {name: 10 for name in state_action_names}
        self.qreturns = {name: [] for name in state_action_names}
        # Greedy policy table; every state initially maps to action 0.
        self.policy = {name: 0 for name in state_names}

    def value_function(self, i, j, k):
        """Return the current value estimate of state ``(i, j, k)``."""
        return self.values[(i, j, k)]

    def q_function(self, i, j, k, l):
        """Return the current Q estimate of state ``(i, j, k)``, action ``l``."""
        return self.qs[(i, j, k, l)]

    def get_state_name(self, state):
        """Return the table key for ``state``: its first three items."""
        return (state[0], state[1], state[2])

    def get_state_action_name(self, state, action):
        """Return the table key for the ``(state, action)`` pair."""
        return (state[0], state[1], state[2], action)

    def choose_action(self, state):
        """Pick an action for ``state`` according to ``self.method``.

        Raises:
            ValueError: If ``self.method`` is not ``'lucky'`` or
                ``'egreedy'`` (previously ``None`` was returned silently).
        """
        sum_, _show, _ace = state
        if self.method == 'lucky':
            return self.feeling_lucky(sum_)
        if self.method == 'egreedy':
            return self.epsilon_greedy(state)
        raise ValueError(f'Unknown method: {self.method!r}')

    def epsilon_greedy(self, state):
        """Return a random action with probability epsilon, else the policy's."""
        if np.random.random() < self.epsilon:
            return np.random.choice(self.actions)
        return self.policy[self.get_state_name(state)]

    def feeling_lucky(self, sum_):
        """Return action 1 while ``sum_`` is below 20, otherwise action 0."""
        return 1 if sum_ < 20 else 0

    def update(self, rewards, states, actions, function='V'):
        """Apply a first-visit Monte Carlo update from one finished episode.

        ``states[t]``, ``actions[t]`` and ``rewards[t]`` must describe the
        same time step. Entries past the shortest of the sequences are
        ignored when sampling returns.

        Args:
            rewards: Reward received at each step.
            states: State observed at each step.
            actions: Action taken at each step.
            function: Unused; kept for backward compatibility. The mode is
                taken from ``self.function``.

        Raises:
            NotImplementedError: If ``self.function`` is not ``'V'`` or
                ``'Q'``.
        """
        returns = self._discounted_returns(rewards)
        visited = set()
        if self.function == 'V':
            for state, G in zip(states, returns):
                state_name = self.get_state_name(state)
                if state_name in visited:
                    continue  # first-visit: only the earliest occurrence counts
                visited.add(state_name)
                self.vreturns[state_name].append(G)
                self.values[state_name] = np.mean(self.vreturns[state_name])
        elif self.function == 'Q':
            for state, action, G in zip(states, actions, returns):
                state_action_name = self.get_state_action_name(state, action)
                if state_action_name in visited:
                    continue
                visited.add(state_action_name)
                self.qreturns[state_action_name].append(G)
                self.qs[state_action_name] = np.mean(
                    self.qreturns[state_action_name]
                )
            # Policy improvement: compare every available action, not just
            # the ones that happened to be taken in this episode.
            for state in states:
                self.policy[self.get_state_name(state)] = self._greedy_action(state)
        else:
            raise NotImplementedError

    def _discounted_returns(self, rewards):
        """Return ``G_t = r_t + gamma * G_{t+1}`` for every step ``t``."""
        returns = [0.0] * len(rewards)
        G = 0.0
        # Walk backwards so each return reuses the one after it.
        for t in range(len(rewards) - 1, -1, -1):
            G = rewards[t] + self.gamma * G
            returns[t] = G
        return returns

    def _greedy_action(self, state):
        """Return the action with the highest Q for ``state`` (first on ties)."""
        return max(
            self.actions,
            key=lambda action: self.qs[self.get_state_action_name(state, action)],
        )
|
||
|
||
class CartPoleNoob:
    """Tabular TD(0) / Q-learning agent over a discretised ``state[2]``.

    Only the third state component is used (presumably the pole angle in
    radians -- confirm against the environment). It is bucketed with
    ``np.digitize`` over ``n_bins`` evenly spaced edges in
    ``[-0.2094, 0.2094]``, giving ``n_bins + 1`` buckets.

    Args:
        method: Action-selection rule, ``'naive'`` or ``'egreedy'``.
        env: Environment exposing ``action_space.n``.
        function: ``'V'`` for state-value TD(0), ``'Q'`` for Q-learning.
        alpha: Learning rate.
        gamma: Discount factor.
        epsilon: Initial exploration probability.
        n_bins: Number of bucket edges for the discretisation.
    """

    def __init__(self, method, env, function='V', alpha=0.1, gamma=0.99, epsilon=0.1, n_bins=10):
        self.method = method
        self.alpha = alpha
        self.gamma = gamma
        self.epsilon = epsilon
        self.function = function
        self.actions = list(range(env.action_space.n))
        self.rad = np.linspace(-0.2094, 0.2094, n_bins)
        # np.digitize yields indices 0..len(edges), hence the +1.
        n_buckets = len(self.rad) + 1
        self.values = {r: 0 for r in range(n_buckets)}
        self.qs = {(r, a): 10 for r in range(n_buckets) for a in self.actions}

    def choose_action(self, state):
        """Pick an action for ``state`` according to ``self.method``.

        Raises:
            ValueError: If ``self.method`` is not ``'naive'`` or
                ``'egreedy'`` (previously ``None`` was returned silently).
        """
        if self.method == 'naive':
            return self.naive_action(state)
        if self.method == 'egreedy':
            return self.epsilon_greedy(state)
        raise ValueError(f'Unknown method: {self.method!r}')

    def naive_action(self, state):
        """Return action 0 when ``state[2]`` is negative, otherwise action 1."""
        return 0 if state[2] < 0 else 1

    def epsilon_greedy(self, state):
        """Return a random action with probability epsilon, else the greedy one."""
        if np.random.random() < self.epsilon:
            return np.random.choice(self.actions)
        s = self.get_bucket_index([state[2]])[0]
        return max(self.actions, key=lambda a: self.qs[(s, a)])

    def get_bucket_index(self, states):
        """Return the bucket index of each value in ``states``."""
        return np.digitize(states, self.rad)

    def update(self, state, action, reward, state_, done=False):
        """Apply a one-step TD update, then decay epsilon.

        Args:
            state: State before the step.
            action: Action taken (only used in ``'Q'`` mode).
            reward: Reward received for the step.
            state_: State after the step.
            done: Whether ``state_`` is terminal. When true the target is
                just ``reward``, with no bootstrap from ``state_``.
                Defaults to ``False``, which matches the previous
                behaviour of always bootstrapping.

        Raises:
            NotImplementedError: If ``self.function`` is not ``'V'`` or
                ``'Q'``.
        """
        r, r_ = (int(b) for b in self.get_bucket_index([state[2], state_[2]]))
        if self.function == 'V':
            # TD update w/ bootstrap
            future = 0.0 if done else self.values[r_]
            self.values[r] += self.alpha * (reward + self.gamma * future - self.values[r])
        elif self.function == 'Q':
            future = 0.0 if done else max(self.qs[(r_, a)] for a in self.actions)
            self.qs[(r, action)] += self.alpha * (reward + self.gamma * future - self.qs[(r, action)])
        else:
            raise NotImplementedError
        self.decrease_eps()

    def decrease_eps(self):
        """Lower epsilon by 1e-5, never going below 0.01."""
        self.epsilon = max(0.01, self.epsilon - 1e-5)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
import gym | ||
import argparse | ||
from tqdm import trange | ||
from policy.agent import BlackJackAgent | ||
|
||
|
||
# Command-line options for the blackjack script; parsed once at import
# time and read through the module-level ``args``.
parser = argparse.ArgumentParser(description='Black Jack Agents')
parser.add_argument(
    '--method',
    type=str,
    default='lucky',
    help='The name of the policy you wish to evaluate',
)
parser.add_argument(
    '--function',
    type=str,
    default='Q',
    help='The function to evaluate',
)
parser.add_argument(
    '--n_episodes',
    type=int,
    default=500000,
    help='Number of episodes you wish to run for',
)
args = parser.parse_args()
|
||
|
||
def first_visit_monte_carlo():
    """Run first-visit Monte Carlo on Blackjack and print two estimates.

    Plays ``args.n_episodes`` episodes with a ``BlackJackAgent`` built
    from the command-line options, updating the agent after each episode.
    Then prints the value estimate for state ``(21, 2, True)`` and the Q
    estimate for state ``(16, 2, False)`` with action ``0``.
    """
    env = gym.make('Blackjack-v0')
    agent = BlackJackAgent(args.method, env, args.function)
    for _ in trange(args.n_episodes):
        state, done = env.reset(), False
        # Start empty: the loop records each state together with the
        # action taken in it and the reward received. Seeding the list
        # with the initial state would store it twice and shift every
        # later state out of step with its action and reward.
        states, actions, rewards = [], [], []
        while not done:
            action = agent.choose_action(state)
            next_state, reward, done, _info = env.step(action)
            states.append(state)
            rewards.append(reward)
            actions.append(action)
            state = next_state
        agent.update(rewards, states, actions)

    print(agent.value_function(21, 2, True))
    print(agent.q_function(16, 2, False, 0))


if __name__ == '__main__':
    first_visit_monte_carlo()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
import gym | ||
import argparse | ||
from tqdm import trange | ||
from policy.agent import CartPoleNoob | ||
|
||
|
||
# Command-line options for the cart-pole script; parsed once at import
# time and read through the module-level ``args``.
parser = argparse.ArgumentParser(description='Cartpole Agents')
parser.add_argument(
    '--method',
    type=str,
    default='egreedy',
    help='The name of the policy you wish to evaluate',
)
parser.add_argument(
    '--function',
    type=str,
    default='Q',
    help='The function to evaluate',
)
parser.add_argument(
    '--n_episodes',
    type=int,
    default=500000,
    help='Number of episodes you wish to run for',
)
args = parser.parse_args()
|
||
|
||
def td():
    """Train a ``CartPoleNoob`` agent with one-step TD updates and print its table.

    Plays ``args.n_episodes`` episodes on ``CartPole-v0``, updating the
    agent after every step. Afterwards prints the table that was actually
    trained: ``agent.qs`` when ``--function`` is ``'Q'``, otherwise
    ``agent.values``.
    """
    env = gym.make('CartPole-v0')
    agent = CartPoleNoob(args.method, env, args.function)
    for _ in trange(args.n_episodes):
        state, done = env.reset(), False
        while not done:
            action = agent.choose_action(state)
            next_state, reward, done, _info = env.step(action)
            agent.update(state, action, reward, next_state)
            state = next_state
    # In 'Q' mode only agent.qs is updated; agent.values would still hold
    # its initial zeros.
    print(agent.qs if args.function == 'Q' else agent.values)


if __name__ == '__main__':
    td()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.