@@ -50,10 +50,15 @@ def main():
50
50
W2 = np .random .randn (M , K ) / np .sqrt (M )
51
51
b2 = np .zeros (K )
52
52
53
+ # save initial weights
54
+ W1_0 = W1 .copy ()
55
+ b1_0 = b1 .copy ()
56
+ W2_0 = W2 .copy ()
57
+ b2_0 = b2 .copy ()
58
+
53
59
# 1. batch
54
- # cost = -16
55
- LL_batch = []
56
- CR_batch = []
60
+ losses_batch = []
61
+ errors_batch = []
57
62
for i in range (max_iter ):
58
63
for j in range (n_batches ):
59
64
Xbatch = Xtrain [j * batch_sz :(j * batch_sz + batch_sz ),]
@@ -68,26 +73,25 @@ def main():
68
73
b1 -= lr * (derivative_b1 (Z , Ybatch , pYbatch , W2 ) + reg * b1 )
69
74
70
75
if j % print_period == 0 :
71
- # calculate just for LL
72
76
pY , _ = forward (Xtest , W1 , b1 , W2 , b2 )
73
- ll = cost (pY , Ytest_ind )
74
- LL_batch .append (ll )
75
- print ("Cost at iteration i=%d, j=%d: %.6f" % (i , j , ll ))
77
+ l = cost (pY , Ytest_ind )
78
+ losses_batch .append (l )
79
+ print ("Cost at iteration i=%d, j=%d: %.6f" % (i , j , l ))
76
80
77
- err = error_rate (pY , Ytest )
78
- CR_batch .append (err )
79
- print ("Error rate:" , err )
81
+ e = error_rate (pY , Ytest )
82
+ errors_batch .append (e )
83
+ print ("Error rate:" , e )
80
84
81
85
pY , _ = forward (Xtest , W1 , b1 , W2 , b2 )
82
86
print ("Final error rate:" , error_rate (pY , Ytest ))
83
87
84
88
# 2. batch with momentum
85
- W1 = np . random . randn ( D , M ) / np . sqrt ( D )
86
- b1 = np . zeros ( M )
87
- W2 = np . random . randn ( M , K ) / np . sqrt ( M )
88
- b2 = np . zeros ( K )
89
- LL_momentum = []
90
- CR_momentum = []
89
+ W1 = W1_0 . copy ( )
90
+ b1 = b1_0 . copy ( )
91
+ W2 = W2_0 . copy ( )
92
+ b2 = b2_0 . copy ( )
93
+ losses_momentum = []
94
+ errors_momentum = []
91
95
mu = 0.9
92
96
dW2 = 0
93
97
db2 = 0
@@ -99,100 +103,92 @@ def main():
99
103
Ybatch = Ytrain_ind [j * batch_sz :(j * batch_sz + batch_sz ),]
100
104
pYbatch , Z = forward (Xbatch , W1 , b1 , W2 , b2 )
101
105
106
+ # gradients
107
+ gW2 = derivative_w2 (Z , Ybatch , pYbatch ) + reg * W2
108
+ gb2 = derivative_b2 (Ybatch , pYbatch ) + reg * b2
109
+ gW1 = derivative_w1 (Xbatch , Z , Ybatch , pYbatch , W2 ) + reg * W1
110
+ gb1 = derivative_b1 (Z , Ybatch , pYbatch , W2 ) + reg * b1
111
+
112
+ # update velocities
113
+ dW2 = mu * dW2 - lr * gW2
114
+ db2 = mu * db2 - lr * gb2
115
+ dW1 = mu * dW1 - lr * gW1
116
+ db1 = mu * db1 - lr * gb1
117
+
102
118
# updates
103
- dW2 = mu * dW2 - lr * (derivative_w2 (Z , Ybatch , pYbatch ) + reg * W2 )
104
119
W2 += dW2
105
- db2 = mu * db2 - lr * (derivative_b2 (Ybatch , pYbatch ) + reg * b2 )
106
120
b2 += db2
107
- dW1 = mu * dW1 - lr * (derivative_w1 (Xbatch , Z , Ybatch , pYbatch , W2 ) + reg * W1 )
108
121
W1 += dW1
109
- db1 = mu * db1 - lr * (derivative_b1 (Z , Ybatch , pYbatch , W2 ) + reg * b1 )
110
122
b1 += db1
111
123
112
124
if j % print_period == 0 :
113
- # calculate just for LL
114
125
pY , _ = forward (Xtest , W1 , b1 , W2 , b2 )
115
- # print "pY:", pY
116
- ll = cost (pY , Ytest_ind )
117
- LL_momentum .append (ll )
118
- print ("Cost at iteration i=%d, j=%d: %.6f" % (i , j , ll ))
119
-
120
- err = error_rate (pY , Ytest )
121
- CR_momentum .append (err )
122
- print ("Error rate:" , err )
126
+ l = cost (pY , Ytest_ind )
127
+ losses_momentum .append (l )
128
+ print ("Cost at iteration i=%d, j=%d: %.6f" % (i , j , l ))
129
+
130
+ e = error_rate (pY , Ytest )
131
+ errors_momentum .append (e )
132
+ print ("Error rate:" , e )
123
133
pY , _ = forward (Xtest , W1 , b1 , W2 , b2 )
124
134
print ("Final error rate:" , error_rate (pY , Ytest ))
125
135
126
136
127
137
# 3. batch with Nesterov momentum
128
- W1 = np .random .randn (D , M ) / np .sqrt (D )
129
- b1 = np .zeros (M )
130
- W2 = np .random .randn (M , K ) / np .sqrt (M )
131
- b2 = np .zeros (K )
132
- LL_nest = []
133
- CR_nest = []
138
+ W1 = W1_0 .copy ()
139
+ b1 = b1_0 .copy ()
140
+ W2 = W2_0 .copy ()
141
+ b2 = b2_0 .copy ()
142
+
143
+ losses_nesterov = []
144
+ errors_nesterov = []
145
+
134
146
mu = 0.9
135
- # alternate version uses dW
136
- # dW2 = 0
137
- # db2 = 0
138
- # dW1 = 0
139
- # db1 = 0
140
147
vW2 = 0
141
148
vb2 = 0
142
149
vW1 = 0
143
150
vb1 = 0
144
151
for i in range (max_iter ):
145
152
for j in range (n_batches ):
146
- # because we want g(t) = grad(f(W(t-1) - lr*mu*dW(t-1)))
147
- # dW(t) = mu*dW(t-1) + g(t)
148
- # W(t) = W(t-1) - mu*dW(t)
149
- W1_tmp = W1 - lr * mu * vW1
150
- b1_tmp = b1 - lr * mu * vb1
151
- W2_tmp = W2 - lr * mu * vW2
152
- b2_tmp = b2 - lr * mu * vb2
153
-
154
153
Xbatch = Xtrain [j * batch_sz :(j * batch_sz + batch_sz ),]
155
154
Ybatch = Ytrain_ind [j * batch_sz :(j * batch_sz + batch_sz ),]
156
- # pYbatch, Z = forward(Xbatch, W1, b1, W2, b2)
157
- pYbatch , Z = forward (Xbatch , W1_tmp , b1_tmp , W2_tmp , b2_tmp )
155
+ pYbatch , Z = forward (Xbatch , W1 , b1 , W2 , b2 )
158
156
159
157
# updates
160
- # dW2 = mu*mu*dW2 - (1 + mu)*lr*( derivative_w2(Z, Ybatch, pYbatch) + reg*W2)
161
- # W2 += dW2
162
- # db2 = mu*mu*db2 - (1 + mu)*lr*(derivative_b2( Ybatch, pYbatch) + reg*b2)
163
- # b2 += db2
164
- # dW1 = mu*mu*dW1 - (1 + mu)*lr*(derivative_w1(Xbatch, Z, Ybatch, pYbatch, W2) + reg*W1)
165
- # W1 += dW1
166
- # db1 = mu*mu*db1 - (1 + mu)* lr*(derivative_b1(Z, Ybatch, pYbatch, W2) + reg*b1)
167
- # b1 += db1
168
- vW2 = mu * vW2 + derivative_w2 ( Z , Ybatch , pYbatch ) + reg * W2_tmp
169
- W2 -= lr * vW2
170
- vb2 = mu * vb2 + derivative_b2 ( Ybatch , pYbatch ) + reg * b2_tmp
171
- b2 -= lr * vb2
172
- vW1 = mu * vW1 + derivative_w1 ( Xbatch , Z , Ybatch , pYbatch , W2_tmp ) + reg * W1_tmp
173
- W1 -= lr * vW1
174
- vb1 = mu * vb1 + derivative_b1 ( Z , Ybatch , pYbatch , W2_tmp ) + reg * b1_tmp
175
- b1 -= lr * vb1
158
+ gW2 = derivative_w2 (Z , Ybatch , pYbatch ) + reg * W2
159
+ gb2 = derivative_b2 ( Ybatch , pYbatch ) + reg * b2
160
+ gW1 = derivative_w1 ( Xbatch , Z , Ybatch , pYbatch , W2 ) + reg * W1
161
+ gb1 = derivative_b1 ( Z , Ybatch , pYbatch , W2 ) + reg * b1
162
+
163
+ # v update
164
+ vW2 = mu * vW2 - lr * gW2
165
+ vb2 = mu * vb2 - lr * gb2
166
+ vW1 = mu * vW1 - lr * gW1
167
+ vb1 = mu * vb1 - lr * gb1
168
+
169
+ # param update
170
+ W2 + = mu * vW2 - lr * gW2
171
+ b2 += mu * vb2 - lr * gb2
172
+ W1 + = mu * vW1 - lr * gW1
173
+ b1 += mu * vb1 - lr * gb1
176
174
177
175
if j % print_period == 0 :
178
- # calculate just for LL
179
176
pY , _ = forward (Xtest , W1 , b1 , W2 , b2 )
180
- # print "pY:", pY
181
- ll = cost (pY , Ytest_ind )
182
- LL_nest .append (ll )
183
- print ("Cost at iteration i=%d, j=%d: %.6f" % (i , j , ll ))
184
-
185
- err = error_rate (pY , Ytest )
186
- CR_nest .append (err )
187
- print ("Error rate:" , err )
177
+ l = cost (pY , Ytest_ind )
178
+ losses_nesterov .append (l )
179
+ print ("Cost at iteration i=%d, j=%d: %.6f" % (i , j , l ))
180
+
181
+ e = error_rate (pY , Ytest )
182
+ errors_nesterov .append (e )
183
+ print ("Error rate:" , e )
188
184
pY , _ = forward (Xtest , W1 , b1 , W2 , b2 )
189
185
print ("Final error rate:" , error_rate (pY , Ytest ))
190
186
191
187
192
188
193
- plt .plot (LL_batch , label = "batch" )
194
- plt .plot (LL_momentum , label = "momentum" )
195
- plt .plot (LL_nest , label = "nesterov" )
189
+ plt .plot (losses_batch , label = "batch" )
190
+ plt .plot (losses_momentum , label = "momentum" )
191
+ plt .plot (losses_nesterov , label = "nesterov" )
196
192
plt .legend ()
197
193
plt .show ()
198
194
0 commit comments