
Commit

Merge branch 'sursu-where-q_best'
ShangtongZhang committed Jun 29, 2019
2 parents d593539 + 16f655c
commit af7336a
Showing 1 changed file with 4 additions and 4 deletions.
chapter02/ten_armed_testbed.py (8 changes: 4 additions & 4 deletions)
@@ -61,15 +61,15 @@ def act(self):
             UCB_estimation = self.q_estimation + \
                 self.UCB_param * np.sqrt(np.log(self.time + 1) / (self.action_count + 1e-5))
             q_best = np.max(UCB_estimation)
-            return np.random.choice([action for action, q in enumerate(UCB_estimation) if q == q_best])
+            return np.random.choice(np.where(UCB_estimation == q_best)[0])
 
         if self.gradient:
             exp_est = np.exp(self.q_estimation)
             self.action_prob = exp_est / np.sum(exp_est)
             return np.random.choice(self.indices, p=self.action_prob)
 
         q_best = np.max(self.q_estimation)
-        return np.random.choice([action for action, q in enumerate(self.q_estimation) if q == q_best])
+        return np.random.choice(np.where(self.q_estimation == q_best)[0])
 
     # take an action, update estimation for this action
     def step(self, action):
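The act() change swaps a Python list comprehension for np.where when breaking ties among maximal estimates: both collect every index whose value equals the maximum, and np.random.choice then picks one of those indices uniformly. A minimal standalone sketch of the equivalence (the q array below is hypothetical, not taken from the file):

import numpy as np

q = np.array([0.3, 0.7, 0.7, 0.1])    # hypothetical action-value estimates with a tie
q_best = np.max(q)

# old form: list comprehension over enumerate()
ties_list = [action for action, value in enumerate(q) if value == q_best]

# new form: np.where returns the same indices as a NumPy array
ties_where = np.where(q == q_best)[0]

assert ties_list == list(ties_where)   # both give [1, 2]
action = np.random.choice(ties_where)  # uniform random tie-breaking among the maxima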
@@ -97,8 +97,8 @@ def step(self, action):
 
 
 def simulate(runs, time, bandits):
-    best_action_counts = np.zeros((len(bandits), runs, time))
-    rewards = np.zeros(best_action_counts.shape)
+    rewards = np.zeros((len(bandits), runs, time))
+    best_action_counts = np.zeros(rewards.shape)
     for i, bandit in enumerate(bandits):
         for r in trange(runs):
             bandit.reset()
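The simulate() change only reorders the two allocations so that rewards is built from the explicit (len(bandits), runs, time) shape and best_action_counts mirrors it via rewards.shape; the resulting arrays are identical either way. A tiny sketch with made-up sizes:

import numpy as np

n_bandits, runs, time = 3, 200, 1000           # made-up sizes for illustration

rewards = np.zeros((n_bandits, runs, time))
best_action_counts = np.zeros(rewards.shape)   # derived from rewards, same shape

assert best_action_counts.shape == rewards.shape == (n_bandits, runs, time)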
