@@ -128,25 +128,47 @@ def main():
128
128
LL_nest = []
129
129
CR_nest = []
130
130
mu = 0.9
131
- dW2 = 0
132
- db2 = 0
133
- dW1 = 0
134
- db1 = 0
131
+ # alternate version uses dW
132
+ # dW2 = 0
133
+ # db2 = 0
134
+ # dW1 = 0
135
+ # db1 = 0
136
+ vW2 = 0
137
+ vb2 = 0
138
+ vW1 = 0
139
+ vb1 = 0
135
140
for i in xrange (max_iter ):
136
141
for j in xrange (n_batches ):
142
+ # because we want g(t) = grad(f(W(t-1) - lr*mu*dW(t-1)))
143
+ # dW(t) = mu*dW(t-1) + g(t)
144
+ # W(t) = W(t-1) - mu*dW(t)
145
+ W1_tmp = W1 - lr * mu * vW1
146
+ b1_tmp = b1 - lr * mu * vb1
147
+ W2_tmp = W2 - lr * mu * vW2
148
+ b2_tmp = b2 - lr * mu * vb2
149
+
137
150
Xbatch = Xtrain [j * batch_sz :(j * batch_sz + batch_sz ),]
138
151
Ybatch = Ytrain_ind [j * batch_sz :(j * batch_sz + batch_sz ),]
139
- pYbatch , Z = forward (Xbatch , W1 , b1 , W2 , b2 )
152
+ # pYbatch, Z = forward(Xbatch, W1, b1, W2, b2)
153
+ pYbatch , Z = forward (Xbatch , W1_tmp , b1_tmp , W2_tmp , b2_tmp )
140
154
141
155
# updates
142
- dW2 = mu * mu * dW2 - (1 + mu )* lr * (derivative_w2 (Z , Ybatch , pYbatch ) + reg * W2 )
143
- W2 += dW2
144
- db2 = mu * mu * db2 - (1 + mu )* lr * (derivative_b2 (Ybatch , pYbatch ) + reg * b2 )
145
- b2 += db2
146
- dW1 = mu * mu * dW1 - (1 + mu )* lr * (derivative_w1 (Xbatch , Z , Ybatch , pYbatch , W2 ) + reg * W1 )
147
- W1 += dW1
148
- db1 = mu * mu * db1 - (1 + mu )* lr * (derivative_b1 (Z , Ybatch , pYbatch , W2 ) + reg * b1 )
149
- b1 += db1
156
+ # dW2 = mu*mu*dW2 - (1 + mu)*lr*(derivative_w2(Z, Ybatch, pYbatch) + reg*W2)
157
+ # W2 += dW2
158
+ # db2 = mu*mu*db2 - (1 + mu)*lr*(derivative_b2(Ybatch, pYbatch) + reg*b2)
159
+ # b2 += db2
160
+ # dW1 = mu*mu*dW1 - (1 + mu)*lr*(derivative_w1(Xbatch, Z, Ybatch, pYbatch, W2) + reg*W1)
161
+ # W1 += dW1
162
+ # db1 = mu*mu*db1 - (1 + mu)*lr*(derivative_b1(Z, Ybatch, pYbatch, W2) + reg*b1)
163
+ # b1 += db1
164
+ vW2 = mu * vW2 + derivative_w2 (Z , Ybatch , pYbatch ) + reg * W2_tmp
165
+ W2 -= lr * vW2
166
+ vb2 = mu * vb2 + derivative_b2 (Ybatch , pYbatch ) + reg * b2_tmp
167
+ b2 -= lr * vb2
168
+ vW1 = mu * vW1 + derivative_w1 (Xbatch , Z , Ybatch , pYbatch , W2_tmp ) + reg * W1_tmp
169
+ W1 -= lr * vW1
170
+ vb1 = mu * vb1 + derivative_b1 (Z , Ybatch , pYbatch , W2_tmp ) + reg * b1_tmp
171
+ b1 -= lr * vb1
150
172
151
173
if j % print_period == 0 :
152
174
# calculate just for LL
0 commit comments