
Commit 3ac5e05

Create bandit_epsilon.py
1 parent 16b800c commit 3ac5e05

File tree: 1 file changed (+70, -0)

bandit_epsilon.py (70 additions, 0 deletions)

import numpy as np
import matplotlib.pyplot as plt

from utils.epsilon import standard_epsilon, decaying_epsilon, exponential_decay_epsilon


class Bandit:
    def __init__(self, win_rate):
        self.win_rate = win_rate  # true success probability, unknown to the agent
        self.n = 0                # number of pulls so far
        self.sample_mean = 0      # running estimate of win_rate

    def pull(self):
        # Bernoulli reward: True with probability win_rate
        return np.random.random() < self.win_rate

    def update(self, reward):
        # Incremental mean update: equivalent to averaging all rewards seen so far
        self.n += 1
        learning_rate = 1 / self.n
        self.sample_mean = self.sample_mean + learning_rate * (reward - self.sample_mean)


if __name__ == '__main__':
    min_epsilon = 0.05
    init_epsilon = 0.9
    alpha = 0.999
    n_iter = 10000
    win_rate = [0.25, 0.35, 0.45]

    # schedule names; these must match the branch checks inside the loop below
    eps_list = ["standard_epsilon", "decaying_epsilon", "exp_epsilon"]

    optimal_idx = np.argmax(win_rate)
    for key in eps_list:
        if key == 'standard_epsilon':
            epsilon = standard_epsilon(min_epsilon)  # constant schedule, set once
        bandit_list = [Bandit(rate) for rate in win_rate]
        reward_list = []
        n_explore = 0
        n_exploit = 0
        n_optimal = 0
        for t in range(n_iter):
            # time-dependent schedules are recomputed at every step
            if key == 'decaying_epsilon':
                epsilon = decaying_epsilon(min_epsilon, t)
            elif key == 'exp_epsilon':
                epsilon = exponential_decay_epsilon(min_epsilon, init_epsilon, alpha, t)

            if np.random.random() < epsilon:
                n_explore += 1
                idx = np.random.choice(len(bandit_list))  # Explore: pick a random arm
            else:
                n_exploit += 1
                idx = np.argmax([b.sample_mean for b in bandit_list])  # Exploit: best estimate

            if idx == optimal_idx:
                n_optimal += 1

            reward = int(bandit_list[idx].pull())
            reward_list.append(reward)
            bandit_list[idx].update(reward)

        for b in bandit_list:
            print(f"Under {key}: mean estimate is {b.sample_mean}")
        print("total reward earned:", np.sum(reward_list))
        print("overall win rate:", np.sum(reward_list) / n_iter)
        print("num_times_explored:", n_explore)
        print("num_times_exploited:", n_exploit)
        print("num times selected optimal bandit:", n_optimal)

        cumulative_rewards = np.cumsum(reward_list)
        win_rates = cumulative_rewards / (np.arange(n_iter) + 1)
        plt.plot(win_rates)
        plt.title(f"{key}")
        plt.show()
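The three schedule functions are imported from utils/epsilon.py, which is not part of this commit. A minimal sketch of what that module might contain, matching the call signatures above; the specific decay forms here are assumptions, not the repository's actual implementation:

# utils/epsilon.py -- hypothetical sketch; the real module is not in this commit
def standard_epsilon(min_epsilon):
    # constant exploration rate for the whole run
    return min_epsilon


def decaying_epsilon(min_epsilon, t):
    # assumed 1/t decay, floored at min_epsilon
    return max(min_epsilon, 1 / (t + 1))


def exponential_decay_epsilon(min_epsilon, init_epsilon, alpha, t):
    # assumed geometric decay from init_epsilon, floored at min_epsilon
    return max(min_epsilon, init_epsilon * alpha ** t)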

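A note on update(): the learning_rate = 1/n step is the incremental mean formula, m_n = m_{n-1} + (x_n - m_{n-1}) / n, which yields exactly the arithmetic mean of all rewards seen so far without storing them. A quick standalone check:

import numpy as np

rng = np.random.default_rng(0)
rewards = rng.integers(0, 2, size=1000)  # simulated Bernoulli rewards

mean = 0.0
for n, x in enumerate(rewards, start=1):
    mean += (x - mean) / n  # same update rule as Bandit.update

assert np.isclose(mean, rewards.mean())  # equals the batch mean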