forked from fayimora/rl-coursework
-
Notifications
You must be signed in to change notification settings - Fork 0
/
sarsa.py
134 lines (105 loc) · 4.43 KB
/
sarsa.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
from collections import defaultdict
from datetime import datetime
from random import randint, random

import matplotlib
# The backend must be selected before pylab/pyplot is imported; older matplotlib
# releases ignore (or warn about) a later matplotlib.use() call.
matplotlib.use('Agg')
import pylab as plt
from progressbar import ProgressBar

from common import epsilon_greedy_policy, plot_value_function, action_value_to_value_function, load
from env import step
from mc_control import monte_carlo_control
from state import State
def compute_mse(action_value_function, mc_action_value_function=None):
    """Return the mean-squared error of an action-value table against a reference.

    The error is averaged over every (dealer, player, action) key with
    dealer in 1..10, player in 1..21 and action in {0, 1} (420 keys).

    Args:
        action_value_function: mapping (dealer, player, action) -> value.
            Keys that are absent count as 0.0. The mapping is NOT modified,
            so measuring a defaultdict does not insert keys into it.
        mc_action_value_function: optional reference mapping with the same
            keys. When omitted it is read from 'mc_result.dat' via load(),
            exactly as before; pass it in to avoid re-reading the file on
            every call.

    Returns:
        The mean-squared error as a float.
    """
    if mc_action_value_function is None:
        mc_action_value_function = load('mc_result.dat')
    keys = [(dealer, player, action)
            for dealer in range(1, 11)
            for player in range(1, 22)
            for action in range(0, 2)]
    # .get() rather than [] so a defaultdict argument is read without being mutated.
    err_sq = sum((action_value_function.get(key, 0.0) - mc_action_value_function[key]) ** 2
                 for key in keys)
    # float() avoids integer floor division under Python 2 when all values are ints.
    return float(err_sq) / len(keys)
def _plot_learning_curve(lambd, mses, n_episodes, epi_batch):
    """Save a PNG of the MSE learning curve recorded during one sarsa() run.

    mses holds the MSE measured before episodes 0, epi_batch, 2*epi_batch, ...
    (while < n_episodes), followed by the MSE after the final episode.
    """
    print(r"Plotting learning curve for $\lambda$= %s" % lambd)
    x = list(range(0, n_episodes, epi_batch)) + [n_episodes]
    fig = plt.figure()
    plt.title(r'Learning curve of MSE against episode number: $\lambda$ = ' + str(lambd))
    plt.xlabel("episode number")
    plt.xlim([0, n_episodes])
    plt.xticks(x)
    plt.ylabel("Mean-Squared Error (MSE)")
    plt.plot(x, mses)
    # No ':' or spaces in the timestamp: both break file names on some platforms.
    timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    fname = "mse_lambda%f_%s.png" % (lambd, timestamp)
    plt.savefig(fname)
    plt.close(fig)  # release the figure so repeated runs do not accumulate open figures


def sarsa(lambd, n_episodes=1000, epi_batch=100, n_zero=100):
    """Run backward-view Sarsa(lambda) control and return its final MSE.

    Exploration is epsilon-greedy with epsilon = n_zero / (n_zero + N(s)),
    where N(s) counts visits to (dealer, player). The step size is
    1 / N(s_t, a_t) for the pair just taken. No discount factor is applied.

    For lambd == 0.0 or 1.0 the MSE against the Monte-Carlo reference is also
    sampled every epi_batch episodes and a learning-curve PNG is written.

    Args:
        lambd: trace-decay parameter lambda.
        n_episodes: number of episodes to run (default 1000, as before).
        epi_batch: episodes between learning-curve samples (default 100).
        n_zero: exploration constant N0 (default 100).

    Returns:
        compute_mse() of the learned action-value function.
    """
    action_value_function = defaultdict(float)
    n_s = defaultdict(int)      # N(s): visits per (dealer, player)
    n_s_a = defaultdict(int)    # N(s, a): visits per (dealer, player, action)
    # Learning curves are only recorded for the two extreme lambdas.
    track_learning = lambd in (0.0, 1.0)
    mses = []

    for episode in range(n_episodes):
        if track_learning and episode % epi_batch == 0:
            mses.append(compute_mse(action_value_function))

        # Initialize state, action, epsilon and eligibility trace.
        state = State()
        current_dealer = state.dealer
        current_player = state.player
        epsilon = float(n_zero) / (n_zero + n_s[(current_dealer, current_player)])
        current_action = epsilon_greedy_policy(action_value_function, state, epsilon)
        eligibility_trace = defaultdict(float)

        while not state.terminal:
            current_key = (current_dealer, current_player, current_action)
            n_s[(current_dealer, current_player)] += 1
            n_s_a[current_key] += 1
            reward = step(state, current_action)
            new_dealer = state.dealer
            new_player = state.player

            if state.terminal:
                # The value of a terminal state is 0 by definition. Bootstrapping
                # from Q[(dealer, player, a')] here would read a non-terminal
                # estimate whenever the terminal state shares a key with one.
                new_action = None
                new_action_value = 0.0
            else:
                epsilon = float(n_zero) / (n_zero + n_s[(new_dealer, new_player)])
                new_action = epsilon_greedy_policy(action_value_function, state, epsilon)
                new_action_value = action_value_function[(new_dealer, new_player, new_action)]

            # NOTE(review): one step size, 1/N(s_t, a_t), is applied to every
            # traced pair; a per-pair 1/N(s, a) is the other common choice.
            alpha = 1.0 / n_s_a[current_key]
            delta = reward + new_action_value - action_value_function[current_key]
            eligibility_trace[current_key] += 1

            # Only pairs visited this episode carry a nonzero trace, so updating
            # just those is equivalent to sweeping the whole table, and cheaper.
            for key, trace in list(eligibility_trace.items()):
                action_value_function[key] += alpha * delta * trace
                eligibility_trace[key] = trace * lambd

            current_dealer = new_dealer
            current_player = new_player
            current_action = new_action

    final_mse = compute_mse(action_value_function)
    if track_learning:
        mses.append(final_mse)
        _plot_learning_curve(lambd, mses, n_episodes, epi_batch)
    return final_mse
if __name__ == '__main__':
    # Sweep lambda over 0.0, 0.1, ..., 1.0 and record the final MSE of each run.
    # The same list drives both the runs and the plot's x-axis, so they cannot drift apart.
    lambdas = [float(i) / 10 for i in range(11)]
    mses = []
    pbar = ProgressBar(maxval=len(lambdas)).start()
    for done, lambd in enumerate(lambdas, 1):
        mses.append(sarsa(lambd=lambd))
        pbar.update(done)  # number of completed runs, so the bar reaches maxval
    pbar.finish()

    # Plot the final MSE against lambda.
    fig = plt.figure()
    plt.title(r'Mean-Squared Error against $\lambda$')
    plt.xlabel(r"$\lambda$")
    plt.xlim([0., 1.])
    plt.xticks(lambdas)
    plt.ylabel("Mean-Squared Error")
    plt.plot(lambdas, mses)
    # No ':' or spaces in the timestamp: both break file names on some platforms.
    timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    fname = "mse_vs_lambda_" + timestamp + ".png"
    plt.savefig(fname)
    plt.close(fig)