
Implement MovingAverage optimizer #215


Merged: 6 commits, Apr 29, 2019
14 changes: 14 additions & 0 deletions tensorflow_addons/optimizers/BUILD
@@ -7,6 +7,7 @@ py_library(
srcs = [
"__init__.py",
"lazy_adam.py",
"moving_average.py",
],
srcs_version = "PY2AND3",
deps = [
@@ -26,3 +27,16 @@ py_test(
":optimizers",
],
)

py_test(
name = "moving_average_test",
size = "small",
srcs = [
"moving_average_test.py",
],
main = "moving_average_test.py",
srcs_version = "PY2AND3",
deps = [
":optimizers",
],
)
2 changes: 2 additions & 0 deletions tensorflow_addons/optimizers/README.md
@@ -4,11 +4,13 @@
| Submodule | Maintainers | Contact Info |
|:---------- |:------------- |:--------------|
| lazy_adam | SIG-Addons | addons@tensorflow.org |
| moving_average | Dheeraj R. Reddy | dheeraj98reddy@gmail.com |

## Components
| Submodule | Optimizer | Reference |
|:----------------------- |:---------------------- |:---------|
| lazy_adam | LazyAdam | https://arxiv.org/abs/1412.6980 |
| moving_average | MovingAverage | |


## Contribution Guidelines
1 change: 1 addition & 0 deletions tensorflow_addons/optimizers/__init__.py
@@ -19,3 +19,4 @@
from __future__ import print_function

from tensorflow_addons.optimizers.lazy_adam import LazyAdam
from tensorflow_addons.optimizers.moving_average import MovingAverage
134 changes: 134 additions & 0 deletions tensorflow_addons/optimizers/moving_average.py
@@ -0,0 +1,134 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
from tensorflow_addons.utils import keras_utils


@keras_utils.register_keras_custom_object
class MovingAverage(tf.keras.optimizers.Optimizer):
"""Optimizer that computes a moving average of the variables.

Empirically it has been found that using the moving average of the trained
parameters of a deep network is better than using its trained parameters
directly. This optimizer allows you to compute this moving average and swap
the variables at save time so that any code outside the training loop will,
by default, use the averaged values instead of the original ones.

Example of usage:

```python
opt = tf.keras.optimizers.SGD(learning_rate)
opt = tfa.optimizers.MovingAverage(opt)

```
"""

def __init__(self,
optimizer,
average_decay=0.1,
num_updates=None,
sequential_update=True,
name="MovingAverage",
**kwargs):

super(MovingAverage, self).__init__(name, **kwargs)

if not isinstance(optimizer, tf.keras.optimizers.Optimizer):
raise TypeError(
"optimizer is not an object of tf.keras.optimizers.Optimizer")

if num_updates is not None and not isinstance(num_updates, int):
raise TypeError("num_updates must be None or of integer type")

if not isinstance(sequential_update, bool):
raise TypeError("sequential_update must be of bool type")

self._optimizer = optimizer

with tf.name_scope(name):
self._ema = tf.train.ExponentialMovingAverage(
average_decay, num_updates=num_updates)
Review comment (Contributor): Shouldn't the constructor pass optimizer.iterations as num_updates?


self._set_hyper("average_decay", average_decay)
self._num_updates = num_updates
self._sequential_update = sequential_update
self._init = True

def apply_gradients(self, grads_and_vars, name=None):
var_list = [v for (_, v) in grads_and_vars]

if tf.executing_eagerly() and self._init:
# In eager mode, register var_list with the EMA once up front so the
# shadow variables exist before the first update.
self._ema.apply(var_list)
self._init = False

train_op = self._optimizer.apply_gradients(grads_and_vars, name=name)

if self._sequential_update:
with tf.control_dependencies([train_op]):
ma_op = self._ema.apply(var_list)
else:
ma_op = self._ema.apply(var_list)

return tf.group(train_op, ma_op, name="train_with_avg")

def get_config(self):
config = {
'average_decay': self._serialize_hyperparameter('average_decay'),
'num_updates': self._num_updates,
'sequential_update': self._sequential_update
}
base_config = self._optimizer.get_config()
return dict(list(base_config.items()) + list(config.items()))

def assign_average_vars(self, var_list):
Review comment (Contributor): Checkpoints are usually saved at regular intervals. Is it standard practice to continue training with the averaged variables after saving a checkpoint?

"""Update variables in var_list with the running mean of the variables.

Example:
```python
model = tf.keras.Sequential([...])
opt = tfa.optimizers.MovingAverage(
tf.keras.optimizers.SGD(lr=2.0), 0.5)

model.compile(opt, ...)
model.fit(x, y, ...)

# Update the weights to their mean before saving
opt.assign_average_vars(model.variables)

model.save('model.h5')
```
"""
assign = tf.group([v.assign(self._ema.average(v)) for v in var_list])
return assign

@property
def weights(self):
return self._optimizer.weights

def _resource_apply_dense(self, grad, var):
return self._optimizer._resource_apply_dense(grad, var) # pylint: disable=protected-access

def _resource_apply_sparse_duplicate_indices(self, grad, var, indices):
return self._optimizer._resource_apply_sparse_duplicate_indices( # pylint: disable=protected-access
grad, var, indices)

def _resource_apply_sparse(self, grad, var, indices):
return self._optimizer._resource_apply_sparse(grad, var, indices) # pylint: disable=protected-access
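
For reviewers who want to try the change locally, here is a minimal end-to-end sketch of how the wrapper defined above is meant to be driven in a custom training loop, based only on the API shown in this file (the constructor, apply_gradients, and assign_average_vars). The toy data, model, learning rate, and average_decay value are illustrative assumptions, not part of the PR.

```python
import tensorflow as tf
import tensorflow_addons as tfa  # assumes this PR is merged and installed

# Toy regression data and model (illustrative only).
x = tf.random.normal([32, 4])
y = tf.random.normal([32, 1])
model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])

# Wrap a plain SGD optimizer; average_decay controls how strongly the
# exponential moving average smooths the raw weights.
opt = tfa.optimizers.MovingAverage(
    tf.keras.optimizers.SGD(learning_rate=0.1), average_decay=0.99)

for _ in range(100):
    with tf.GradientTape() as tape:
        loss = tf.reduce_mean(tf.square(model(x) - y))
    grads = tape.gradient(loss, model.trainable_variables)
    # Updates the raw weights and, via tf.train.ExponentialMovingAverage,
    # the shadow averages of every variable in the list.
    opt.apply_gradients(list(zip(grads, model.trainable_variables)))

# Overwrite the raw weights with their running averages before exporting,
# so code outside the training loop sees the averaged values.
opt.assign_average_vars(model.variables)
model.save('averaged_model.h5')
```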
134 changes: 134 additions & 0 deletions tensorflow_addons/optimizers/moving_average_test.py
@@ -0,0 +1,134 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for MovingAverage optimizers."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

from tensorflow_addons.optimizers import MovingAverage
from tensorflow_addons.utils import test_utils


class MovingAverageTest(tf.test.TestCase):
@test_utils.run_in_graph_and_eager_modes
def test_run(self):
for sequential_update in [True, False]:
var0 = tf.Variable([1.0, 2.0])
var1 = tf.Variable([3.0, 4.0])

grads0 = tf.constant([0.1, 0.1])
grads1 = tf.constant([0.01, 0.01])

grads_and_vars = list(zip([grads0, grads1], [var0, var1]))

opt = MovingAverage(
tf.keras.optimizers.SGD(lr=2.0),
average_decay=0.5,
sequential_update=sequential_update)

if not tf.executing_eagerly():
update = opt.apply_gradients(grads_and_vars)
self.evaluate(tf.compat.v1.global_variables_initializer())
self.evaluate(update)
self.evaluate(update)
else:
opt.apply_gradients(grads_and_vars)
opt.apply_gradients(grads_and_vars)

self.assertAllClose(var0.read_value(), [0.6, 1.6])
self.assertAllClose(var1.read_value(), [2.96, 3.96])

ema_var0 = opt._ema.average(var0) # pylint: disable=protected-access
ema_var1 = opt._ema.average(var1) # pylint: disable=protected-access

if sequential_update:
self.assertAllClose(ema_var0.read_value(), [0.75, 1.75])
self.assertAllClose(ema_var1.read_value(), [2.975, 3.975])

assign = opt.assign_average_vars([var0, var1])
self.evaluate(assign)

if sequential_update:
self.assertAllClose(var0.read_value(), [0.75, 1.75])
self.assertAllClose(var1.read_value(), [2.975, 3.975])

perturb = tf.group([
var0.assign_add([1.0, 1.0]),
var1.assign_add([2.0, 2.0]),
ema_var0.assign_add([3.0, 3.0]),
ema_var1.assign_add([4.0, 4.0])
])
self.evaluate(perturb)

if sequential_update:
self.assertAllClose(var0.read_value(), [1.75, 2.75])
self.assertAllClose(var1.read_value(), [4.975, 5.975])
self.assertAllClose(ema_var0.read_value(), [3.75, 4.75])
self.assertAllClose(ema_var1.read_value(), [6.975, 7.975])

@test_utils.run_in_graph_and_eager_modes
def test_opt_failure(self):
base_opt = None
for sequential_update in [True, False]:
with self.assertRaises(TypeError):
MovingAverage(base_opt, 0.5, sequential_update)

@test_utils.run_in_graph_and_eager_modes
def test_model_weights_update(self):
grad = tf.Variable([[0.1]])
model = tf.keras.Sequential([
tf.keras.layers.Dense(
1,
kernel_initializer=tf.keras.initializers.Constant([[1.0]]),
use_bias=False)
])
model.build(input_shape=[1, 1])

opt = MovingAverage(tf.keras.optimizers.SGD(lr=2.0), 0.5)
update = opt.apply_gradients(list(zip([grad], model.variables)))

self.evaluate(tf.compat.v1.global_variables_initializer())
self.evaluate(update)
self.assertAllClose(model.variables[0].read_value(), [[0.8]])

mean_update = opt.assign_average_vars(model.variables)
self.evaluate(mean_update)
self.assertAllClose(model.variables[0].read_value(), [[0.9]])

@test_utils.run_in_graph_and_eager_modes
def test_config(self):
sgd_opt = tf.keras.optimizers.SGD(
lr=2.0, nesterov=True, momentum=0.3, decay=0.1)
opt = MovingAverage(
sgd_opt,
average_decay=0.5,
num_updates=100,
sequential_update=False)
config = opt.get_config()

self.assertEqual(config['average_decay'], 0.5)
self.assertEqual(config['decay'], 0.1)
self.assertEqual(config['learning_rate'], 2.0)
self.assertEqual(config['momentum'], 0.3)
self.assertEqual(config['name'], 'SGD')
self.assertEqual(config['nesterov'], True)
self.assertEqual(config['num_updates'], 100)
self.assertEqual(config['sequential_update'], False)


if __name__ == '__main__':
tf.test.main()
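
As a sanity check on the constants asserted in test_run, here is a short hand replay of the SGD and EMA update rules used by the test. It assumes, as the asserted values imply, that the shadow average starts at the variable's initial value; the snippet is illustrative and not part of the PR.

```python
# Hand replay of test_run for var0 = [1.0, 2.0], grads0 = [0.1, 0.1],
# SGD with lr = 2.0 and an exponential moving average with decay = 0.5.
lr, decay = 2.0, 0.5
var = [1.0, 2.0]
ema = list(var)  # assumption: shadow variable starts at the initial value

for _ in range(2):  # the test calls apply_gradients twice
    var = [v - lr * 0.1 for v in var]                              # SGD step
    ema = [decay * e + (1 - decay) * v for e, v in zip(ema, var)]  # EMA step

print(var)  # ~[0.6, 1.6]   -> assertAllClose(var0, [0.6, 1.6])
print(ema)  # ~[0.75, 1.75] -> assertAllClose(ema_var0, [0.75, 1.75])
```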