@@ -163,23 +163,23 @@ def main():
163
163
# updates
164
164
gW2 = derivative_w2 (Z , Ybatch , pYbatch ) + reg * W2
165
165
cache_W2 = decay_rate * cache_W2 + (1 - decay_rate )* gW2 * gW2
166
- dW2 = mu * dW2 - (1 - mu ) * lr0 * gW2 / (np .sqrt (cache_W2 ) + eps )
167
- W2 + = dW2
166
+ dW2 = mu * dW2 + (1 - mu ) * lr0 * gW2 / (np .sqrt (cache_W2 ) + eps )
167
+ W2 - = dW2
168
168
169
169
gb2 = derivative_b2 (Ybatch , pYbatch ) + reg * b2
170
170
cache_b2 = decay_rate * cache_b2 + (1 - decay_rate )* gb2 * gb2
171
- db2 = mu * db2 - (1 - mu ) * lr0 * gb2 / (np .sqrt (cache_b2 ) + eps )
172
- b2 + = db2
171
+ db2 = mu * db2 + (1 - mu ) * lr0 * gb2 / (np .sqrt (cache_b2 ) + eps )
172
+ b2 - = db2
173
173
174
174
gW1 = derivative_w1 (Xbatch , Z , Ybatch , pYbatch , W2 ) + reg * W1
175
175
cache_W1 = decay_rate * cache_W1 + (1 - decay_rate )* gW1 * gW1
176
- dW1 = mu * dW1 - (1 - mu ) * lr0 * gW1 / (np .sqrt (cache_W1 ) + eps )
177
- W1 + = dW1
176
+ dW1 = mu * dW1 + (1 - mu ) * lr0 * gW1 / (np .sqrt (cache_W1 ) + eps )
177
+ W1 - = dW1
178
178
179
179
gb1 = derivative_b1 (Z , Ybatch , pYbatch , W2 ) + reg * b1
180
180
cache_b1 = decay_rate * cache_b1 + (1 - decay_rate )* gb1 * gb1
181
- db1 = mu * db1 - (1 - mu ) * lr0 * gb1 / (np .sqrt (cache_b1 ) + eps )
182
- b1 + = db1
181
+ db1 = mu * db1 + (1 - mu ) * lr0 * gb1 / (np .sqrt (cache_b1 ) + eps )
182
+ b1 - = db1
183
183
184
184
if j % print_period == 0 :
185
185
pY , _ = forward (Xtest , W1 , b1 , W2 , b2 )
0 commit comments