Skip to content

Commit 226fbef

Browse files
alt nesterov
1 parent d1d5903 commit 226fbef

File tree

1 file changed

+35
-13
lines changed

1 file changed

+35
-13
lines changed

ann_class2/momentum.py

+35-13
Original file line numberDiff line numberDiff line change
@@ -128,25 +128,47 @@ def main():
128128
LL_nest = []
129129
CR_nest = []
130130
mu = 0.9
131-
dW2 = 0
132-
db2 = 0
133-
dW1 = 0
134-
db1 = 0
131+
# alternate version uses dW
132+
# dW2 = 0
133+
# db2 = 0
134+
# dW1 = 0
135+
# db1 = 0
136+
vW2 = 0
137+
vb2 = 0
138+
vW1 = 0
139+
vb1 = 0
135140
for i in xrange(max_iter):
136141
for j in xrange(n_batches):
142+
# because we want g(t) = grad(f(W(t-1) - lr*mu*dW(t-1)))
143+
# dW(t) = mu*dW(t-1) + g(t)
144+
# W(t) = W(t-1) - mu*dW(t)
145+
W1_tmp = W1 - lr*mu*vW1
146+
b1_tmp = b1 - lr*mu*vb1
147+
W2_tmp = W2 - lr*mu*vW2
148+
b2_tmp = b2 - lr*mu*vb2
149+
137150
Xbatch = Xtrain[j*batch_sz:(j*batch_sz + batch_sz),]
138151
Ybatch = Ytrain_ind[j*batch_sz:(j*batch_sz + batch_sz),]
139-
pYbatch, Z = forward(Xbatch, W1, b1, W2, b2)
152+
# pYbatch, Z = forward(Xbatch, W1, b1, W2, b2)
153+
pYbatch, Z = forward(Xbatch, W1_tmp, b1_tmp, W2_tmp, b2_tmp)
140154

141155
# updates
142-
dW2 = mu*mu*dW2 - (1 + mu)*lr*(derivative_w2(Z, Ybatch, pYbatch) + reg*W2)
143-
W2 += dW2
144-
db2 = mu*mu*db2 - (1 + mu)*lr*(derivative_b2(Ybatch, pYbatch) + reg*b2)
145-
b2 += db2
146-
dW1 = mu*mu*dW1 - (1 + mu)*lr*(derivative_w1(Xbatch, Z, Ybatch, pYbatch, W2) + reg*W1)
147-
W1 += dW1
148-
db1 = mu*mu*db1 - (1 + mu)*lr*(derivative_b1(Z, Ybatch, pYbatch, W2) + reg*b1)
149-
b1 += db1
156+
# dW2 = mu*mu*dW2 - (1 + mu)*lr*(derivative_w2(Z, Ybatch, pYbatch) + reg*W2)
157+
# W2 += dW2
158+
# db2 = mu*mu*db2 - (1 + mu)*lr*(derivative_b2(Ybatch, pYbatch) + reg*b2)
159+
# b2 += db2
160+
# dW1 = mu*mu*dW1 - (1 + mu)*lr*(derivative_w1(Xbatch, Z, Ybatch, pYbatch, W2) + reg*W1)
161+
# W1 += dW1
162+
# db1 = mu*mu*db1 - (1 + mu)*lr*(derivative_b1(Z, Ybatch, pYbatch, W2) + reg*b1)
163+
# b1 += db1
164+
vW2 = mu*vW2 + derivative_w2(Z, Ybatch, pYbatch) + reg*W2_tmp
165+
W2 -= lr*vW2
166+
vb2 = mu*vb2 + derivative_b2(Ybatch, pYbatch) + reg*b2_tmp
167+
b2 -= lr*vb2
168+
vW1 = mu*vW1 + derivative_w1(Xbatch, Z, Ybatch, pYbatch, W2_tmp) + reg*W1_tmp
169+
W1 -= lr*vW1
170+
vb1 = mu*vb1 + derivative_b1(Z, Ybatch, pYbatch, W2_tmp) + reg*b1_tmp
171+
b1 -= lr*vb1
150172

151173
if j % print_period == 0:
152174
# calculate just for LL

0 commit comments

Comments
 (0)