Skip to content

Commit 0d25fb2

Browse files
update
1 parent 0cb1607 commit 0d25fb2

File tree

1 file changed

+8
-8
lines changed

1 file changed

+8
-8
lines changed

ann_class2/adam.py

+8-8
Original file line numberDiff line numberDiff line change
@@ -163,23 +163,23 @@ def main():
163163
# updates
164164
gW2 = derivative_w2(Z, Ybatch, pYbatch) + reg*W2
165165
cache_W2 = decay_rate*cache_W2 + (1 - decay_rate)*gW2*gW2
166-
dW2 = mu * dW2 - (1 - mu) * lr0 * gW2 / (np.sqrt(cache_W2) + eps)
167-
W2 += dW2
166+
dW2 = mu * dW2 + (1 - mu) * lr0 * gW2 / (np.sqrt(cache_W2) + eps)
167+
W2 -= dW2
168168

169169
gb2 = derivative_b2(Ybatch, pYbatch) + reg*b2
170170
cache_b2 = decay_rate*cache_b2 + (1 - decay_rate)*gb2*gb2
171-
db2 = mu * db2 - (1 - mu) * lr0 * gb2 / (np.sqrt(cache_b2) + eps)
172-
b2 += db2
171+
db2 = mu * db2 + (1 - mu) * lr0 * gb2 / (np.sqrt(cache_b2) + eps)
172+
b2 -= db2
173173

174174
gW1 = derivative_w1(Xbatch, Z, Ybatch, pYbatch, W2) + reg*W1
175175
cache_W1 = decay_rate*cache_W1 + (1 - decay_rate)*gW1*gW1
176-
dW1 = mu * dW1 - (1 - mu) * lr0 * gW1 / (np.sqrt(cache_W1) + eps)
177-
W1 += dW1
176+
dW1 = mu * dW1 + (1 - mu) * lr0 * gW1 / (np.sqrt(cache_W1) + eps)
177+
W1 -= dW1
178178

179179
gb1 = derivative_b1(Z, Ybatch, pYbatch, W2) + reg*b1
180180
cache_b1 = decay_rate*cache_b1 + (1 - decay_rate)*gb1*gb1
181-
db1 = mu * db1 - (1 - mu) * lr0 * gb1 / (np.sqrt(cache_b1) + eps)
182-
b1 += db1
181+
db1 = mu * db1 + (1 - mu) * lr0 * gb1 / (np.sqrt(cache_b1) + eps)
182+
b1 -= db1
183183

184184
if j % print_period == 0:
185185
pY, _ = forward(Xtest, W1, b1, W2, b2)

0 commit comments

Comments
 (0)