import numpy as np
import matplotlib.pyplot as plt

from utils.epsilon import standard_epsilon, decaying_epsilon, exponential_decay_epsilon

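# utils.epsilon is not shown here; a minimal sketch of what the three schedules
# might look like (signatures inferred from the calls below -- an assumption,
# not the repo's actual implementation):
#
#   def standard_epsilon(min_epsilon):
#       return min_epsilon                                   # constant epsilon
#
#   def decaying_epsilon(min_epsilon, t):
#       return max(min_epsilon, 1.0 / (t + 1))               # 1/t decay with a floor
#
#   def exponential_decay_epsilon(min_epsilon, init_epsilon, alpha, t):
#       return max(min_epsilon, init_epsilon * alpha ** t)   # geometric decay with a floor
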
class Bandit:
    """One slot-machine arm with a fixed, hidden win rate."""

    def __init__(self, win_rate):
        self.win_rate = win_rate    # true payout probability
        self.n = 0                  # number of pulls so far
        self.sample_mean = 0.0      # running estimate of win_rate

    def pull(self):
        # Bernoulli reward: True with probability win_rate
        return np.random.random() < self.win_rate

    def update(self, reward):
        # Incremental sample-mean update: Q_n = Q_{n-1} + (1/n)(r_n - Q_{n-1})
        self.n += 1
        self.sample_mean += (reward - self.sample_mean) / self.n

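# Why the incremental update equals the plain sample mean: if Q_{n-1} is the
# average of the first n-1 rewards, then
#     Q_n = Q_{n-1} + (1/n) * (r_n - Q_{n-1})
#         = ((n - 1) * Q_{n-1} + r_n) / n,
# i.e. the average of all n rewards, computed in O(1) memory.
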
if __name__ == '__main__':
    min_epsilon = 0.05      # floor for epsilon under every schedule
    init_epsilon = 0.9      # starting epsilon for the exponential schedule
    alpha = 0.999           # exponential decay factor
    n_iter = 10000
    win_rate = [0.25, 0.35, 0.45]

    # Keys must match the schedule names checked in the loop below
    eps_list = ["standard_epsilon", "decaying_epsilon", "exp_epsilon"]

    optimal_idx = np.argmax(win_rate)
    for key in eps_list:
        # The constant schedule is set once; the decaying schedules are
        # recomputed at every step inside the loop below.
        if key == 'standard_epsilon':
            epsilon = standard_epsilon(min_epsilon)
        bandit_list = [Bandit(rate) for rate in win_rate]
        reward_list = []
        n_explore = 0
        n_exploit = 0
        n_optimal = 0
        for t in range(n_iter):
            if key == 'decaying_epsilon':
                epsilon = decaying_epsilon(min_epsilon, t)
            elif key == 'exp_epsilon':
                epsilon = exponential_decay_epsilon(min_epsilon, init_epsilon, alpha, t)

            if np.random.random() < epsilon:
                n_explore += 1
                idx = np.random.choice(len(bandit_list))  # explore: random arm
            else:
                n_exploit += 1
                idx = np.argmax([b.sample_mean for b in bandit_list])  # exploit: greedy arm

            if idx == optimal_idx:
                n_optimal += 1

            reward = int(bandit_list[idx].pull())
            reward_list.append(reward)
            bandit_list[idx].update(reward)

        for b in bandit_list:
            print(f"Under {key}: mean estimate is {b.sample_mean:.4f} (true rate {b.win_rate})")
        print("total reward earned:", np.sum(reward_list))
        print("overall win rate:", np.sum(reward_list) / n_iter)
        print("num times explored:", n_explore)
        print("num times exploited:", n_exploit)
        print("num times selected optimal bandit:", n_optimal)

        # Running win rate over time for this schedule
        cumulative_rewards = np.cumsum(reward_list)
        win_rates = cumulative_rewards / (np.arange(n_iter) + 1)
        plt.plot(win_rates)
        plt.title(f"{key}")
        plt.show()

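# Expected qualitative behavior (an expectation, not a measured result): each
# curve should climb toward the best arm's true rate of 0.45 as the sample
# means converge and exploitation dominates; a nonzero epsilon floor leaves a
# small residual exploration gap below that ceiling.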