Update trajectory_sampling.py #138

Merged 1 commit on Jan 17, 2021
25 changes: 17 additions & 8 deletions chapter08/trajectory_sampling.py
@@ -7,10 +7,11 @@
 
 import numpy as np
 import matplotlib
-matplotlib.use('Agg')
 import matplotlib.pyplot as plt
 from tqdm import tqdm
 
+matplotlib.use('Agg')
+
 # 2 actions
 ACTIONS = [0, 1]
 
@@ -23,12 +24,14 @@
 # epsilon greedy for behavior policy
 EPSILON = 0.1
 
+
 # break tie randomly
 def argmax(value):
     max_q = np.max(value)
     return np.random.choice([a for a, q in enumerate(value) if q == max_q])
 
-class Task():
+
+class Task:
     # @n_states: number of non-terminal states
     # @b: branch
     # Each episode starts with state 0, and state n_states is a terminal state
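(Not part of the PR: a minimal sketch of the random tie-breaking that the argmax helper in the hunk above implements, contrasted with np.argmax, which always returns the first maximal index.)

import numpy as np

# pick uniformly among all indices whose value equals the maximum
def argmax(value):
    max_q = np.max(value)
    return np.random.choice([a for a, q in enumerate(value) if q == max_q])

q = [1.0, 3.0, 3.0, 0.5]
print(np.argmax(q))   # always 1
print(argmax(q))      # 1 or 2, each with probability 0.5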
@@ -46,8 +49,9 @@ def __init__(self, n_states, b):
     def step(self, state, action):
         if np.random.rand() < TERMINATION_PROB:
             return self.n_states, 0
-        next = np.random.randint(self.b)
-        return self.transition[state, action, next], self.reward[state, action, next]
+        next_ = np.random.randint(self.b)
+        return self.transition[state, action, next_], self.reward[state, action, next_]
+
 
 # Evaluate the value of the start state for the greedy policy
 # derived from @q under the MDP @task
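(Not part of the PR: the local variable is renamed because next shadows Python's built-in next(). A minimal illustration, using the hypothetical functions broken and fixed.)

def broken(it):
    next = 3           # rebinds the name locally, hiding the built-in next()
    return next(it)    # raises TypeError: 'int' object is not callable

def fixed(it):
    next_ = 3          # trailing underscore is the usual way to dodge the clash
    return next(it)    # the built-in is still reachable

print(fixed(iter([10, 20])))   # 10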
@@ -65,6 +69,7 @@ def evaluate_pi(q, task):
         returns.append(rewards)
     return np.mean(returns)
 
+
 # perform expected update from a uniform state-action distribution of the MDP @task
 # evaluate the learned q value every @eval_interval steps
 def uniform(task, eval_interval):
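(Not part of the PR: the body of uniform() is collapsed in this diff. For context, a hedged sketch of the expected update it performs for one state-action pair, reusing the Task attributes and module constants visible above; expected_update is a hypothetical helper name, and the details approximate, rather than quote, the repository code.)

# q is assumed to be an np.ndarray of shape (n_states + 1, len(ACTIONS)),
# with the extra row standing in for the terminal state
def expected_update(q, task, state, action):
    next_states = task.transition[state, action]
    # full backup over all b successors, weighted by the probability
    # that the transition does not terminate the episode
    target = task.reward[state, action] + np.max(q[next_states, :], axis=1)
    q[state, action] = (1 - TERMINATION_PROB) * np.mean(target)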
@@ -84,6 +89,7 @@ def uniform(task, eval_interval):
 
     return zip(*performance)
 
+
 # perform expected update from an on-policy distribution of the MDP @task
 # evaluate the learned q value every @eval_interval steps
 def on_policy(task, eval_interval):
@@ -112,13 +118,14 @@ def on_policy(task, eval_interval):
 
     return zip(*performance)
 
+
 def figure_8_8():
     num_states = [1000, 10000]
     branch = [1, 3, 10]
     methods = [on_policy, uniform]
 
-    # average accross 30 tasks
-    n_tasks = 30
+    # average across 30 tasks
+    n_tasks = 30
 
     # number of evaluation points
     x_ticks = 100
@@ -129,13 +136,14 @@ def figure_8_8():
         for b in branch:
             tasks = [Task(n, b) for _ in range(n_tasks)]
             for method in methods:
+                steps = None
                 value = []
                 for task in tasks:
                     steps, v = method(task, MAX_STEPS / x_ticks)
                     value.append(v)
                 value = np.mean(np.asarray(value), axis=0)
-                plt.plot(steps, value, label='b = %d, %s' % (b, method.__name__))
-        plt.title('%d states' % (n))
+                plt.plot(steps, value, label=f'b = {b}, {method.__name__}')
+        plt.title(f'{n} states')
 
         plt.ylabel('value of start state')
         plt.legend()
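(Not part of the PR: the f-strings introduced above produce the same labels as the old %-formatting, and steps = None presumably just pre-binds the name so the later plt.plot call is not flagged as using a possibly-undefined variable. A quick check of the string equivalence:)

b, name = 3, 'on_policy'
assert 'b = %d, %s' % (b, name) == f'b = {b}, {name}' == 'b = 3, on_policy'
n = 1000
assert '%d states' % (n) == f'{n} states' == '1000 states'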
@@ -146,5 +154,6 @@ def figure_8_8():
     plt.savefig('../images/figure_8_8.png')
     plt.close()
 
+
 if __name__ == '__main__':
     figure_8_8()