1+ import numpy as np
2+ import matplotlib .pyplot as plt
3+ from scipy .stats import norm
4+
class bandit:
    """One arm of a Gaussian bandit together with a normal estimate of its mean.

    Rewards are drawn from a normal distribution centred on ``mean_rate``
    with precision ``tau``. The estimate of that mean is itself a normal
    distribution with mean ``m`` and precision ``lambda_``; it starts at
    mean 0 / precision 1 and is refined by ``update``.
    """

    def __init__(self, mean_rate):
        """Create an arm whose true mean reward is ``mean_rate``."""
        # True reward distribution: mean and (fixed) precision.
        self.mean_rate = mean_rate
        self.tau = 1
        # Current estimate of the mean: centre and precision.
        self.m = 0
        self.lambda_ = 1
        # Running sum of observed rewards and number of updates so far.
        self.sum_x = 0
        self.N = 0

    def pull(self):
        """Return one random reward drawn around the true mean."""
        standard_draw = np.random.randn()
        scale = np.sqrt(self.tau)
        return standard_draw / scale + self.mean_rate

    def sample(self):
        """Return one random draw from the current estimate of the mean."""
        standard_draw = np.random.randn()
        scale = np.sqrt(self.lambda_)
        return standard_draw / scale + self.m

    def update(self, reward):
        """Fold one observed ``reward`` into the estimate of the mean."""
        self.N += 1
        self.sum_x += reward
        # Each observation adds the reward precision to the estimate's precision.
        self.lambda_ += self.tau
        # Conjugate normal update for a zero-centred starting estimate:
        # precision-weighted sum of rewards divided by the total precision.
        self.m = self.tau * self.sum_x / self.lambda_
25+
def plot(bandits, trial):
    """Plot the normal density each bandit currently holds for its mean.

    For every bandit a curve with mean ``m`` and variance ``1 / lambda_`` is
    drawn over a fixed grid on [-3, 6]; the legend shows the bandit's
    ``mean_rate`` and its play count ``N``. ``trial`` is only used in the
    title. The figure is displayed with ``plt.show()``.
    """
    grid = np.linspace(-3, 6, 200)
    for arm in bandits:
        # lambda_ is a precision, so the standard deviation is sqrt(1 / lambda_).
        std_dev = np.sqrt(1. / arm.lambda_)
        density = norm.pdf(grid, arm.m, std_dev)
        legend_entry = f"real mean: {arm.mean_rate:.4f} , num plays: {arm.N} "
        plt.plot(grid, density, label=legend_entry)
    plt.title(f"Bandit distributions after {trial} trials")
    plt.legend()
    plt.show()
34+
35+
if __name__ == '__main__':
    # Thompson-sampling experiment over three Gaussian bandits.
    n_iter = 10000
    mean_rate = [1.25, 2.35, 3.45]
    # Trial indices at which the current estimates are plotted; a set keeps
    # the per-iteration membership test constant-time.
    plot_pts = {5, 100, 1000, 5000, 9999}
    optimal_idx = np.argmax(mean_rate)
    bandit_list = [bandit(rate) for rate in mean_rate]
    reward_list = []
    n_optimal = 0
    for t in range(n_iter):
        if t in plot_pts:
            plot(bandit_list, t)
        # Thompson sampling: draw once from every bandit's current estimate
        # and play the bandit whose draw is largest.
        idx = np.argmax([b.sample() for b in bandit_list])
        if idx == optimal_idx:
            n_optimal += 1
        # Keep the reward as a float. Truncating it with int() rounds toward
        # zero, which biases both the recorded rewards and the mean estimates
        # away from the bandits' real means.
        reward = bandit_list[idx].pull()
        reward_list.append(reward)
        bandit_list[idx].update(reward)

    for b in bandit_list:
        print(f"mean estimate is {b.m} ")
    # Total is the sum of all rewards (previously printed the mean by mistake).
    print("total reward earned:", np.sum(reward_list))
    # For real-valued rewards this figure is the average reward per trial.
    print("overall win rate:", np.mean(reward_list))
    print("num times selected optimal bandit:", n_optimal)
61+
0 commit comments