Implement MovingAverage optimizer #215
Changes from all commits: 47c9654, 25d33f4, d2420c9, b8da9a6, ba51cb5, 30b12ee
@@ -0,0 +1,134 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
from tensorflow_addons.utils import keras_utils


@keras_utils.register_keras_custom_object
class MovingAverage(tf.keras.optimizers.Optimizer):
    """Optimizer that computes a moving average of the variables.

    Empirically it has been found that using the moving average of the trained
    parameters of a deep network is better than using its trained parameters
    directly. This optimizer allows you to compute this moving average and swap
    the variables at save time so that any code outside of the training loop
    will use by default the average values instead of the original ones.

    Example of usage:

    ```python
    opt = tf.keras.optimizers.SGD(learning_rate)
    opt = tfa.optimizers.MovingAverage(opt)

    ```
    """

    def __init__(self,
                 optimizer,
                 average_decay=0.1,
                 num_updates=None,
                 sequential_update=True,
                 name="MovingAverage",
                 **kwargs):

        super(MovingAverage, self).__init__(name, **kwargs)

        if not isinstance(optimizer, tf.keras.optimizers.Optimizer):
            raise TypeError(
                "optimizer is not an object of tf.keras.optimizers.Optimizer")

        if num_updates is not None and not isinstance(num_updates, int):
            raise TypeError("num_updates must be None or of integer type")

        if not isinstance(sequential_update, bool):
            raise TypeError("sequential_update must be of bool type")

        self._optimizer = optimizer

        with tf.name_scope(name):
            self._ema = tf.train.ExponentialMovingAverage(
                average_decay, num_updates=num_updates)
[Review comment] Shouldn't the constructor pass …

        self._set_hyper("average_decay", average_decay)
        self._num_updates = num_updates
        self._sequential_update = sequential_update
        self._init = True

    def apply_gradients(self, grads_and_vars, name=None):
        var_list = [v for (_, v) in grads_and_vars]

        if tf.executing_eagerly() and self._init:
            # This is to ensure that var_list is registered initially.
            self._ema.apply(var_list)
            self._init = False

        train_op = self._optimizer.apply_gradients(grads_and_vars, name=name)

        if self._sequential_update:
            with tf.control_dependencies([train_op]):
                ma_op = self._ema.apply(var_list)
        else:
            ma_op = self._ema.apply(var_list)

        return tf.group(train_op, ma_op, name="train_with_avg")

    def get_config(self):
        config = {
            'average_decay': self._serialize_hyperparameter('average_decay'),
            'num_updates': self._num_updates,
            'sequential_update': self._sequential_update
        }
        base_config = self._optimizer.get_config()
        return dict(list(base_config.items()) + list(config.items()))

    def assign_average_vars(self, var_list):
[Review comment] Checkpoints are usually saved at regular intervals. Is it standard practice to continue training with the averaged variables after saving a checkpoint? (A hedged swap-back sketch follows this file's diff.)
"""Update variables in var_list with the running mean of the variables. | ||
|
||
Example: | ||
```python | ||
model = tf.Sequential([...]) | ||
opt = tfa.optimizers.MovingAverage( | ||
tf.keras.optimizers.SGD(lr=2.0), 0.5) | ||
|
||
model.compile(opt, ...) | ||
model.fit(x, y, ...) | ||
|
||
# Update the weights to their mean before saving | ||
opt.assign_average_vars(model.variables) | ||
|
||
model.save('model.h5') | ||
``` | ||
""" | ||
assign = tf.group([v.assign(self._ema.average(v)) for v in var_list]) | ||
return assign | ||
|
||
@property | ||
def weights(self): | ||
return self._optimizer.weights | ||
|
||
def _resource_apply_dense(self, grad, var): | ||
return self._optimizer._resource_apply_dense(grad, var) # pylint: disable=protected-access | ||
|
||
def _resource_apply_sparse_duplicate_indices(self, grad, var, indices): | ||
return self._optimizer._resource_apply_sparse_duplicate_indices( # pylint: disable=protected-access | ||
grad, var, indices) | ||
|
||
def _resource_apply_sparse(self, grad, var, indices): | ||
return self._optimizer._resource_apply_sparse(grad, var, indices) # pylint: disable=protected-access |
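
On the checkpointing question raised above: one possible pattern (a minimal, hypothetical sketch, not part of this PR) is to snapshot the raw training weights, swap in the moving averages with `assign_average_vars`, save the checkpoint, and then restore the raw weights so training continues from where it left off. Only `assign_average_vars` comes from the diff above; the model, dummy data, `backup` list, and checkpoint path are assumptions for illustration, and the sketch assumes eager execution (TF 2.x) so the assign ops run immediately.

```python
import numpy as np
import tensorflow as tf
import tensorflow_addons as tfa

model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
opt = tfa.optimizers.MovingAverage(tf.keras.optimizers.SGD(lr=0.01), 0.5)
model.compile(opt, loss='mse')

# A tiny dummy dataset, just so the example actually trains a few steps.
x = np.array([[1.0], [2.0]], dtype=np.float32)
y = np.array([[2.0], [4.0]], dtype=np.float32)
model.fit(x, y, epochs=3, verbose=0)

# Snapshot the current (non-averaged) weights before swapping.
backup = [tf.identity(v) for v in model.variables]

# Swap the moving averages into the model and write the checkpoint.
opt.assign_average_vars(model.variables)
model.save_weights('ckpt_avg.h5')  # hypothetical checkpoint path

# Restore the raw training weights so optimization continues where it left off.
for var, value in zip(model.variables, backup):
    var.assign(value)
```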
@@ -0,0 +1,134 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for MovingAverage optimizers."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

from tensorflow_addons.optimizers import MovingAverage
from tensorflow_addons.utils import test_utils


class MovingAverageTest(tf.test.TestCase):
    @test_utils.run_in_graph_and_eager_modes
    def test_run(self):
        for sequential_update in [True, False]:
            var0 = tf.Variable([1.0, 2.0])
            var1 = tf.Variable([3.0, 4.0])

            grads0 = tf.constant([0.1, 0.1])
            grads1 = tf.constant([0.01, 0.01])

            grads_and_vars = list(zip([grads0, grads1], [var0, var1]))

            opt = MovingAverage(
                tf.keras.optimizers.SGD(lr=2.0),
                average_decay=0.5,
                sequential_update=sequential_update)

            if not tf.executing_eagerly():
                update = opt.apply_gradients(grads_and_vars)
                self.evaluate(tf.compat.v1.global_variables_initializer())
                self.evaluate(update)
                self.evaluate(update)
            else:
                opt.apply_gradients(grads_and_vars)
                opt.apply_gradients(grads_and_vars)

            self.assertAllClose(var0.read_value(), [0.6, 1.6])
            self.assertAllClose(var1.read_value(), [2.96, 3.96])

            ema_var0 = opt._ema.average(var0)  # pylint: disable=protected-access
            ema_var1 = opt._ema.average(var1)  # pylint: disable=protected-access

            if sequential_update:
                self.assertAllClose(ema_var0.read_value(), [0.75, 1.75])
                self.assertAllClose(ema_var1.read_value(), [2.975, 3.975])

            assign = opt.assign_average_vars([var0, var1])
            self.evaluate(assign)

            if sequential_update:
                self.assertAllClose(var0.read_value(), [0.75, 1.75])
                self.assertAllClose(var1.read_value(), [2.975, 3.975])

            perturb = tf.group([
                var0.assign_add([1.0, 1.0]),
                var1.assign_add([2.0, 2.0]),
                ema_var0.assign_add([3.0, 3.0]),
                ema_var1.assign_add([4.0, 4.0])
            ])
            self.evaluate(perturb)

            if sequential_update:
                self.assertAllClose(var0.read_value(), [1.75, 2.75])
                self.assertAllClose(var1.read_value(), [4.975, 5.975])
                self.assertAllClose(ema_var0.read_value(), [3.75, 4.75])
                self.assertAllClose(ema_var1.read_value(), [6.975, 7.975])

    @test_utils.run_in_graph_and_eager_modes
    def test_opt_failure(self):
        base_opt = None
        for sequential_update in [True, False]:
            with self.assertRaises(TypeError):
                MovingAverage(base_opt, 0.5, sequential_update)

    @test_utils.run_in_graph_and_eager_modes
    def test_model_weights_update(self):
        grad = tf.Variable([[0.1]])
        model = tf.keras.Sequential([
            tf.keras.layers.Dense(
                1,
                kernel_initializer=tf.keras.initializers.Constant([[1.0]]),
                use_bias=False)
        ])
        model.build(input_shape=[1, 1])

        opt = MovingAverage(tf.keras.optimizers.SGD(lr=2.0), 0.5)
        update = opt.apply_gradients(list(zip([grad], model.variables)))

        self.evaluate(tf.compat.v1.global_variables_initializer())
        self.evaluate(update)
        self.assertAllClose(model.variables[0].read_value(), [[0.8]])

        mean_update = opt.assign_average_vars(model.variables)
        self.evaluate(mean_update)
        self.assertAllClose(model.variables[0].read_value(), [[0.9]])

    @test_utils.run_in_graph_and_eager_modes
    def test_config(self):
        sgd_opt = tf.keras.optimizers.SGD(
            lr=2.0, nesterov=True, momentum=0.3, decay=0.1)
        opt = MovingAverage(
            sgd_opt,
            average_decay=0.5,
            num_updates=100,
            sequential_update=False)
        config = opt.get_config()

        self.assertEqual(config['average_decay'], 0.5)
        self.assertEqual(config['decay'], 0.1)
        self.assertEqual(config['learning_rate'], 2.0)
        self.assertEqual(config['momentum'], 0.3)
        self.assertEqual(config['name'], 'SGD')
        self.assertEqual(config['nesterov'], True)
        self.assertEqual(config['num_updates'], 100)
        self.assertEqual(config['sequential_update'], False)


if __name__ == '__main__':
    tf.test.main()
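
As a hedged sanity check (not part of the PR), the constants asserted in `test_run` and `test_model_weights_update` can be reproduced by hand, assuming the shadow value starts at the variable's initial value and is updated as `avg = decay * avg + (1 - decay) * var` after every SGD step:

```python
# Plain-Python reproduction of the expected test values (decay 0.5, SGD lr 2.0).
decay, lr = 0.5, 2.0

def run(var, grad, steps=2):
    avg = var                                    # shadow starts at the initial value
    for _ in range(steps):
        var -= lr * grad                         # SGD update
        avg = decay * avg + (1 - decay) * var    # EMA update
    return var, avg

print(run(1.0, 0.1))     # ~(0.6, 0.75)    -> var0 / ema_var0
print(run(3.0, 0.01))    # ~(2.96, 2.975)  -> var1 / ema_var1
print(run(1.0, 0.1, 1))  # ~(0.8, 0.9)     -> test_model_weights_update
```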