Skip to content

Commit 487ad6b

Browse files
committed
Create bandit_thompson_sampling.py
1 parent 0ab4ac5 commit 487ad6b

File tree

1 file changed

+61
-0
lines changed

1 file changed

+61
-0
lines changed
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
import numpy as np
2+
import matplotlib.pyplot as plt
3+
from scipy.stats import norm
4+
5+
class bandit:
    """One arm of a Gaussian bandit together with its posterior bookkeeping.

    Rewards are drawn around ``mean_rate`` with precision ``tau``. The
    estimate of the mean is tracked as a Normal with mean ``m`` and
    precision ``lambda_``; both start at the values set in ``__init__``
    (``m = 0``, ``lambda_ = 1``).
    """

    def __init__(self, mean_rate):
        """Store the arm's true mean and initialise the estimate state."""
        self.mean_rate = mean_rate
        # Reward noise precision used by pull() and update().
        self.tau = 1
        # Current estimate of the mean and the precision of that estimate.
        self.m = 0
        self.lambda_ = 1
        # Running sum of observed rewards (kept for convenience) and the
        # number of updates applied so far.
        self.sum_x = 0
        self.N = 0

    def pull(self):
        """Draw one reward: standard normal noise scaled by 1/sqrt(tau), shifted by mean_rate."""
        noise = np.random.randn()
        return noise / np.sqrt(self.tau) + self.mean_rate

    def sample(self):
        """Draw one value from the current estimate: Normal(m, 1/lambda_)."""
        noise = np.random.randn()
        return noise / np.sqrt(self.lambda_) + self.m

    def update(self, reward):
        """Fold one observed reward into the estimate.

        The precision grows by ``tau``, the reward is added to the running
        sum, and the mean becomes ``tau * sum_x / lambda_``.
        """
        self.N += 1
        self.sum_x += reward
        self.lambda_ += self.tau
        weighted_total = self.tau * self.sum_x
        self.m = weighted_total / self.lambda_
25+
26+
def plot(bandits, trial):
    """Plot each bandit's current Normal estimate of its mean.

    For every bandit, draws the pdf of Normal(b.m, sqrt(1 / b.lambda_))
    over a fixed grid on [-3, 6], labelled with the bandit's true mean and
    number of plays, then shows the figure.

    Args:
        bandits: iterable of objects exposing ``m``, ``lambda_``,
            ``mean_rate`` and ``N``.
        trial: trial count shown in the figure title.
    """
    grid = np.linspace(-3, 6, 200)
    for b in bandits:
        # lambda_ is a precision; the pdf call takes a standard deviation.
        std = np.sqrt(1. / b.lambda_)
        density = norm.pdf(grid, b.m, std)
        caption = f"real mean: {b.mean_rate:.4f}, num plays: {b.N}"
        plt.plot(grid, density, label=caption)
    plt.title(f"Bandit distributions after {trial} trials")
    plt.legend()
    plt.show()
34+
35+
36+
if __name__ == '__main__':
    # Thompson sampling experiment over three Gaussian arms.
    n_iter = 10000
    mean_rate = [1.25, 2.35, 3.45]
    # Trial indices at which the current estimates are plotted; a set keeps
    # the per-iteration membership test constant-time.
    plot_pts = {5, 100, 1000, 5000, 9999}
    optimal_idx = np.argmax(mean_rate)
    bandit_list = [bandit(rate) for rate in mean_rate]
    reward_list = []
    n_optimal = 0

    for t in range(n_iter):
        if t in plot_pts:
            plot(bandit_list, t)

        # Thompson sampling: draw once from every arm's current estimate
        # and play the arm with the largest draw.
        idx = np.argmax([b.sample() for b in bandit_list])
        if idx == optimal_idx:
            n_optimal += 1

        # Keep the reward as a float. Truncating it with int() rounds toward
        # zero and biases the estimate away from the arm's true mean.
        reward = bandit_list[idx].pull()
        reward_list.append(reward)
        bandit_list[idx].update(reward)

    for b in bandit_list:
        print(f"mean estimate is {b.m}")
    # Total is the sum of rewards; the mean is reported on the next line.
    print("total reward earned:", np.sum(reward_list))
    print("overall win rate:", np.mean(reward_list))
    print("num times selected optimal bandit:", n_optimal)
61+

0 commit comments

Comments
 (0)