Commit 97e0471

better rmsprop
1 parent 9739285 commit 97e0471

File tree: 1 file changed (+15 -10 lines)

ann_class2/dropout_theano.py (+15 -10)
@@ -35,7 +35,7 @@ def __init__(self, hidden_layer_sizes, p_keep):
         self.hidden_layer_sizes = hidden_layer_sizes
         self.dropout_rates = p_keep
 
-    def fit(self, X, Y, learning_rate=1e-6, mu=0.99, decay=0.999, epochs=300, batch_sz=100, show_fig=False):
+    def fit(self, X, Y, learning_rate=1e-4, mu=0.9, decay=0.9, epochs=8, batch_sz=100, show_fig=False):
         # make a validation set
         X, Y = shuffle(X, Y)
         X = X.astype(np.float32)
@@ -66,12 +66,6 @@ def fit(self, X, Y, learning_rate=1e-6, mu=0.99, decay=0.999, epochs=300, batch_
         for h in self.hidden_layers:
             self.params += h.params
 
-        # for momentum
-        dparams = [theano.shared(np.zeros(p.get_value().shape)) for p in self.params]
-
-        # for rmsprop
-        cache = [theano.shared(np.zeros(p.get_value().shape)) for p in self.params]
-
         # set up theano functions and variables
         thX = T.matrix('X')
         thY = T.ivector('Y')
@@ -80,12 +74,23 @@ def fit(self, X, Y, learning_rate=1e-6, mu=0.99, decay=0.999, epochs=300, batch_
         # this cost is for training
         cost = -T.mean(T.log(pY_train[T.arange(thY.shape[0]), thY]))
 
+        # gradients wrt each param
+        grads = T.grad(cost, self.params)
+
+        # for momentum
+        dparams = [theano.shared(np.zeros_like(p.get_value())) for p in self.params]
+
+        # for rmsprop
+        cache = [theano.shared(np.ones_like(p.get_value())) for p in self.params]
+
+        new_cache = [decay*c + (1-decay)*g*g for p, c, g in zip(self.params, cache, grads)]
+        new_dparams = [mu*dp - learning_rate*g/T.sqrt(new_c + 1e-10) for p, new_c, dp, g in zip(self.params, new_cache, dparams, grads)]
         updates = [
-            (c, decay*c + (1-decay)*T.grad(cost, p)*T.grad(cost, p)) for p, c in zip(self.params, cache)
+            (c, new_c) for c, new_c in zip(cache, new_cache)
         ] + [
-            (p, p + mu*dp - learning_rate*T.grad(cost, p)/T.sqrt(c + 1e-10)) for p, c, dp in zip(self.params, cache, dparams)
+            (dp, new_dp) for dp, new_dp in zip(dparams, new_dparams)
         ] + [
-            (dp, mu*dp - learning_rate*T.grad(cost, p)/T.sqrt(c + 1e-10)) for p, c, dp in zip(self.params, cache, dparams)
+            (p, p + new_dp) for p, new_dp in zip(self.params, new_dparams)
         ]
 
         # momentum only
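
Compared to the previous version, the gradients are now computed once with T.grad(cost, self.params) and reused, the RMSProp cache is initialized to ones instead of zeros (presumably so the first steps are not divided by a near-zero cache), and the parameter update applies the same velocity that gets stored in dparams. Below is a minimal NumPy sketch of the update rule the new symbolic graph builds; the function name rmsprop_momentum_step, the toy quadratic objective, and the iteration count are illustrative assumptions, not from the repo:

import numpy as np

def rmsprop_momentum_step(p, dp, cache, g, learning_rate=1e-4, mu=0.9, decay=0.9, eps=1e-10):
    # running average of squared gradients, as in new_cache
    cache = decay * cache + (1 - decay) * g * g
    # momentum on the RMSProp-scaled gradient, as in new_dparams
    dp = mu * dp - learning_rate * g / np.sqrt(cache + eps)
    # apply the stored velocity to the parameter, as in (p, p + new_dp)
    p = p + dp
    return p, dp, cache

# toy usage: minimize f(p) = sum(p**2), whose gradient is 2*p
p = np.random.randn(5).astype(np.float32)
dp = np.zeros_like(p)     # velocity, like the dparams shared variables
cache = np.ones_like(p)   # like cache, initialized to ones in this commit
for _ in range(5000):
    p, dp, cache = rmsprop_momentum_step(p, dp, cache, 2 * p)
print(p)  # each coordinate should have settled near zero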
