Skip to content

Commit eee6fee

Browse files
update momentum
1 parent d449c68 commit eee6fee

File tree

1 file changed

+74
-78
lines changed

1 file changed

+74
-78
lines changed

ann_class2/momentum.py

+74 -78
Original file line number | Diff line number | Diff line change
@@ -50,10 +50,15 @@ def main():
5050
W2 = np.random.randn(M, K) / np.sqrt(M)
5151
b2 = np.zeros(K)
5252

53+
# save initial weights
54+
W1_0 = W1.copy()
55+
b1_0 = b1.copy()
56+
W2_0 = W2.copy()
57+
b2_0 = b2.copy()
58+
5359
# 1. batch
54-
# cost = -16
55-
LL_batch = []
56-
CR_batch = []
60+
losses_batch = []
61+
errors_batch = []
5762
for i in range(max_iter):
5863
for j in range(n_batches):
5964
Xbatch = Xtrain[j*batch_sz:(j*batch_sz + batch_sz),]
@@ -68,26 +73,25 @@ def main():
6873
b1 -= lr*(derivative_b1(Z, Ybatch, pYbatch, W2) + reg*b1)
6974

7075
if j % print_period == 0:
71-
# calculate just for LL
7276
pY, _ = forward(Xtest, W1, b1, W2, b2)
73-
ll = cost(pY, Ytest_ind)
74-
LL_batch.append(ll)
75-
print("Cost at iteration i=%d, j=%d: %.6f" % (i, j, ll))
77+
l = cost(pY, Ytest_ind)
78+
losses_batch.append(l)
79+
print("Cost at iteration i=%d, j=%d: %.6f" % (i, j, l))
7680

77-
err = error_rate(pY, Ytest)
78-
CR_batch.append(err)
79-
print("Error rate:", err)
81+
e = error_rate(pY, Ytest)
82+
errors_batch.append(e)
83+
print("Error rate:", e)
8084

8185
pY, _ = forward(Xtest, W1, b1, W2, b2)
8286
print("Final error rate:", error_rate(pY, Ytest))
8387

8488
# 2. batch with momentum
85-
W1 = np.random.randn(D, M) / np.sqrt(D)
86-
b1 = np.zeros(M)
87-
W2 = np.random.randn(M, K) / np.sqrt(M)
88-
b2 = np.zeros(K)
89-
LL_momentum = []
90-
CR_momentum = []
89+
W1 = W1_0.copy()
90+
b1 = b1_0.copy()
91+
W2 = W2_0.copy()
92+
b2 = b2_0.copy()
93+
losses_momentum = []
94+
errors_momentum = []
9195
mu = 0.9
9296
dW2 = 0
9397
db2 = 0
@@ -99,100 +103,92 @@ def main():
99103
Ybatch = Ytrain_ind[j*batch_sz:(j*batch_sz + batch_sz),]
100104
pYbatch, Z = forward(Xbatch, W1, b1, W2, b2)
101105

106+
# gradients
107+
gW2 = derivative_w2(Z, Ybatch, pYbatch) + reg*W2
108+
gb2 = derivative_b2(Ybatch, pYbatch) + reg*b2
109+
gW1 = derivative_w1(Xbatch, Z, Ybatch, pYbatch, W2) + reg*W1
110+
gb1 = derivative_b1(Z, Ybatch, pYbatch, W2) + reg*b1
111+
112+
# update velocities
113+
dW2 = mu*dW2 - lr*gW2
114+
db2 = mu*db2 - lr*gb2
115+
dW1 = mu*dW1 - lr*gW1
116+
db1 = mu*db1 - lr*gb1
117+
102118
# updates
103-
dW2 = mu*dW2 - lr*(derivative_w2(Z, Ybatch, pYbatch) + reg*W2)
104119
W2 += dW2
105-
db2 = mu*db2 - lr*(derivative_b2(Ybatch, pYbatch) + reg*b2)
106120
b2 += db2
107-
dW1 = mu*dW1 - lr*(derivative_w1(Xbatch, Z, Ybatch, pYbatch, W2) + reg*W1)
108121
W1 += dW1
109-
db1 = mu*db1 - lr*(derivative_b1(Z, Ybatch, pYbatch, W2) + reg*b1)
110122
b1 += db1
111123

112124
if j % print_period == 0:
113-
# calculate just for LL
114125
pY, _ = forward(Xtest, W1, b1, W2, b2)
115-
# print "pY:", pY
116-
ll = cost(pY, Ytest_ind)
117-
LL_momentum.append(ll)
118-
print("Cost at iteration i=%d, j=%d: %.6f" % (i, j, ll))
119-
120-
err = error_rate(pY, Ytest)
121-
CR_momentum.append(err)
122-
print("Error rate:", err)
126+
l = cost(pY, Ytest_ind)
127+
losses_momentum.append(l)
128+
print("Cost at iteration i=%d, j=%d: %.6f" % (i, j, l))
129+
130+
e = error_rate(pY, Ytest)
131+
errors_momentum.append(e)
132+
print("Error rate:", e)
123133
pY, _ = forward(Xtest, W1, b1, W2, b2)
124134
print("Final error rate:", error_rate(pY, Ytest))
125135

126136

127137
# 3. batch with Nesterov momentum
128-
W1 = np.random.randn(D, M) / np.sqrt(D)
129-
b1 = np.zeros(M)
130-
W2 = np.random.randn(M, K) / np.sqrt(M)
131-
b2 = np.zeros(K)
132-
LL_nest = []
133-
CR_nest = []
138+
W1 = W1_0.copy()
139+
b1 = b1_0.copy()
140+
W2 = W2_0.copy()
141+
b2 = b2_0.copy()
142+
143+
losses_nesterov = []
144+
errors_nesterov = []
145+
134146
mu = 0.9
135-
# alternate version uses dW
136-
# dW2 = 0
137-
# db2 = 0
138-
# dW1 = 0
139-
# db1 = 0
140147
vW2 = 0
141148
vb2 = 0
142149
vW1 = 0
143150
vb1 = 0
144151
for i in range(max_iter):
145152
for j in range(n_batches):
146-
# because we want g(t) = grad(f(W(t-1) - lr*mu*dW(t-1)))
147-
# dW(t) = mu*dW(t-1) + g(t)
148-
# W(t) = W(t-1) - mu*dW(t)
149-
W1_tmp = W1 - lr*mu*vW1
150-
b1_tmp = b1 - lr*mu*vb1
151-
W2_tmp = W2 - lr*mu*vW2
152-
b2_tmp = b2 - lr*mu*vb2
153-
154153
Xbatch = Xtrain[j*batch_sz:(j*batch_sz + batch_sz),]
155154
Ybatch = Ytrain_ind[j*batch_sz:(j*batch_sz + batch_sz),]
156-
# pYbatch, Z = forward(Xbatch, W1, b1, W2, b2)
157-
pYbatch, Z = forward(Xbatch, W1_tmp, b1_tmp, W2_tmp, b2_tmp)
155+
pYbatch, Z = forward(Xbatch, W1, b1, W2, b2)
158156

159157
# updates
160-
# dW2 = mu*mu*dW2 - (1 + mu)*lr*(derivative_w2(Z, Ybatch, pYbatch) + reg*W2)
161-
# W2 += dW2
162-
# db2 = mu*mu*db2 - (1 + mu)*lr*(derivative_b2(Ybatch, pYbatch) + reg*b2)
163-
# b2 += db2
164-
# dW1 = mu*mu*dW1 - (1 + mu)*lr*(derivative_w1(Xbatch, Z, Ybatch, pYbatch, W2) + reg*W1)
165-
# W1 += dW1
166-
# db1 = mu*mu*db1 - (1 + mu)*lr*(derivative_b1(Z, Ybatch, pYbatch, W2) + reg*b1)
167-
# b1 += db1
168-
vW2 = mu*vW2 + derivative_w2(Z, Ybatch, pYbatch) + reg*W2_tmp
169-
W2 -= lr*vW2
170-
vb2 = mu*vb2 + derivative_b2(Ybatch, pYbatch) + reg*b2_tmp
171-
b2 -= lr*vb2
172-
vW1 = mu*vW1 + derivative_w1(Xbatch, Z, Ybatch, pYbatch, W2_tmp) + reg*W1_tmp
173-
W1 -= lr*vW1
174-
vb1 = mu*vb1 + derivative_b1(Z, Ybatch, pYbatch, W2_tmp) + reg*b1_tmp
175-
b1 -= lr*vb1
158+
gW2 = derivative_w2(Z, Ybatch, pYbatch) + reg*W2
159+
gb2 = derivative_b2(Ybatch, pYbatch) + reg*b2
160+
gW1 = derivative_w1(Xbatch, Z, Ybatch, pYbatch, W2) + reg*W1
161+
gb1 = derivative_b1(Z, Ybatch, pYbatch, W2) + reg*b1
162+
163+
# v update
164+
vW2 = mu*vW2 - lr*gW2
165+
vb2 = mu*vb2 - lr*gb2
166+
vW1 = mu*vW1 - lr*gW1
167+
vb1 = mu*vb1 - lr*gb1
168+
169+
# param update
170+
W2 += mu*vW2 - lr*gW2
171+
b2 += mu*vb2 - lr*gb2
172+
W1 += mu*vW1 - lr*gW1
173+
b1 += mu*vb1 - lr*gb1
176174

177175
if j % print_period == 0:
178-
# calculate just for LL
179176
pY, _ = forward(Xtest, W1, b1, W2, b2)
180-
# print "pY:", pY
181-
ll = cost(pY, Ytest_ind)
182-
LL_nest.append(ll)
183-
print("Cost at iteration i=%d, j=%d: %.6f" % (i, j, ll))
184-
185-
err = error_rate(pY, Ytest)
186-
CR_nest.append(err)
187-
print("Error rate:", err)
177+
l = cost(pY, Ytest_ind)
178+
losses_nesterov.append(l)
179+
print("Cost at iteration i=%d, j=%d: %.6f" % (i, j, l))
180+
181+
e = error_rate(pY, Ytest)
182+
errors_nesterov.append(e)
183+
print("Error rate:", e)
188184
pY, _ = forward(Xtest, W1, b1, W2, b2)
189185
print("Final error rate:", error_rate(pY, Ytest))
190186

191187

192188

193-
plt.plot(LL_batch, label="batch")
194-
plt.plot(LL_momentum, label="momentum")
195-
plt.plot(LL_nest, label="nesterov")
189+
plt.plot(losses_batch, label="batch")
190+
plt.plot(losses_momentum, label="momentum")
191+
plt.plot(losses_nesterov, label="nesterov")
196192
plt.legend()
197193
plt.show()
198194

0 commit comments

Comments
 (0)