Implement MovingAverage optimizer #215

Merged
merged 6 commits into tensorflow:master on Apr 29, 2019

Conversation

Squadrick
Member

Closes #5

@Squadrick
Member Author

Squadrick commented Apr 27, 2019

@seanpmorgan @facaiy Sorry about the delay on this. Could you take a look at this PR?

Also, I've made a few changes to the API.

In contrib, we had to use swapping_saver() to get a tf.train.Saver object that swapped in the running means before saving the weights.
In addons, I've added assign_average_vars(var_list), which assigns the running means to the variables in var_list.

I've included a docstring with an example that should show the new API usage. Let me know what you think.
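To make the API change concrete, here is a minimal pure-Python sketch of the idea behind assign_average_vars. This is illustrative only, not the actual addons implementation: the class name, storage scheme, and plain-float "variables" are all stand-ins. The optimizer keeps an exponential moving average shadow of each variable, and assign_average_vars copies the averaged values back.

```python
class MovingAverageSketch:
    """Illustrative sketch: shadow each variable with a running mean."""

    def __init__(self, decay):
        self.decay = decay
        self._shadow = {}  # variable index -> running mean

    def update(self, var_list):
        # Standard EMA update: shadow = decay * shadow + (1 - decay) * value
        for i, v in enumerate(var_list):
            prev = self._shadow.get(i, v)
            self._shadow[i] = self.decay * prev + (1 - self.decay) * v

    def assign_average_vars(self, var_list):
        # Return each variable replaced by its running mean (returned rather
        # than assigned in place, since plain floats are immutable).
        return [self._shadow.get(i, v) for i, v in enumerate(var_list)]
```

In the real optimizer the averages live in TensorFlow variables and assign_average_vars writes them back with assignments, but the data flow is the same.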

* Port MovingAverageOptimizer from tf.contrib.opt

* Inherits base Keras optimizer_v2

* `swapping_saver` replaced with `assign_average_vars`

* Update test cases for TF2.X

* Update docs
@Smokrow
Contributor

Smokrow commented Apr 28, 2019

Thank you for your PR :). I've left a few review comments above.

@Squadrick
Member Author

@Smokrow Made a change regarding the internal function, and I've left comments regarding the colab notebook example and paper. Thanks for taking the time to review this PR.

@facaiy
Member

facaiy commented Apr 29, 2019

@Smokrow Good job, Moritz, thanks for your help :-)

@WindQAQ Hi, Tzu-Wei, do you have time to take another look? I'm kind of busy this week.

Member

@WindQAQ WindQAQ left a comment


LGTM! Thanks for the contribution! Could you make the tests more compatible with eager mode? Other minor requests are in the comments.

@WindQAQ WindQAQ self-assigned this Apr 29, 2019
@WindQAQ
Member

WindQAQ commented Apr 29, 2019

Also, do we really need to override _create_slots? It seems the slots will be created when apply_gradients calls self._optimizer.apply_gradients. I also didn't see this method used anywhere else.
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/optimizer_v2/optimizer_v2.py#L427

* Use _set_hyper() and _get_hyper() instead of member variables for
average_decay, num_updates and sequential_update

* Remove _create_slots() from MovingAverage

* Use _serialize_hyperparameter() in get_config()

* Replace if-else with tf.cond() to work with tensors

* Use absolute import of tensorflow_addons in moving_average_test.py
@WindQAQ WindQAQ self-requested a review April 29, 2019 13:26
Member

@WindQAQ WindQAQ left a comment


Hi @Squadrick, I wasn't aware that using _set_hyper for bool and None types would cause so much trouble. Could you revert sequential_update and num_updates to the original implementation? Sorry for any inconvenience caused.

* Tests modified for static and eager execution

* num_updates and sequential_update reverted back to instance variables

* Type check of num_updates and sequential_update
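The type checks mentioned in the commit above can be sketched as a small validation helper. This is a hypothetical sketch, not the PR's actual code: the function name and error messages are invented, and the real validation may differ.

```python
def check_moving_average_args(num_updates, sequential_update):
    # num_updates may be omitted (None) or given as an integer step count.
    if num_updates is not None and not isinstance(num_updates, int):
        raise TypeError(
            "num_updates must be None or an int, got %r" % type(num_updates))
    # sequential_update toggles whether the average is updated in lock-step,
    # so it must be a plain bool rather than a tensor or other object.
    if not isinstance(sequential_update, bool):
        raise TypeError(
            "sequential_update must be a bool, got %r" % type(sequential_update))
```

Failing fast on bad argument types here gives a clear Python-level error instead of a confusing graph-construction failure later.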
@Squadrick
Member Author

@WindQAQ Reverted them back and added support for eager execution along with the tests.

Member

@WindQAQ WindQAQ left a comment


Very close! It seems some test cases are failing? Also, make sure to run make code-format before committing. Thanks for the contribution :-)

* Remove six import in moving_average_test

* Wrap zip objects in list to pass tests in python3

* Fix typos
@Squadrick
Member Author

@WindQAQ Had to wrap all zip objects in list() to get the tests to pass in python3, didn't know about that. Everything should be working now, could you trigger the CI?

@WindQAQ
Member

WindQAQ commented Apr 29, 2019

Had to wrap all zip objects in list() to get the tests to pass in python3, didn't know about that.

It's likely caused by zip returning a one-shot generator in Python 3. Anyway, thanks for the contribution! I'm looking forward to seeing the example colab notebook :-)
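The Python 3 behavior being discussed here is easy to demonstrate without TensorFlow: zip returns a one-shot iterator, so it cannot be indexed or iterated twice unless it is first materialized with list().

```python
pairs = zip([1, 2, 3], ["a", "b", "c"])
first_pass = [p for p in pairs]    # consumes the iterator
second_pass = [p for p in pairs]   # already exhausted: empty in Python 3

assert first_pass == [(1, "a"), (2, "b"), (3, "c")]
assert second_pass == []

# Wrapping in list() materializes the pairs so they can be reused:
pairs = list(zip([1, 2, 3], ["a", "b", "c"]))
assert pairs[0] == (1, "a")        # indexing now works too
```

In Python 2, zip eagerly returned a list, which is why tests written against it could silently rely on re-iteration.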

@WindQAQ WindQAQ merged commit e278852 into tensorflow:master Apr 29, 2019
@Squadrick
Member Author

@WindQAQ @Smokrow Thanks for the reviews, guys.

@Squadrick Squadrick deleted the moving-avg-opt branch April 29, 2019 15:56

with tf.name_scope(name):
    self._ema = tf.train.ExponentialMovingAverage(
        average_decay, num_updates=num_updates)
Contributor


Shouldn't the constructor pass optimizer.iterations as num_updates?
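For context on why num_updates matters: tf.train.ExponentialMovingAverage documents that when num_updates is set, the decay actually applied is min(decay, (1 + num_updates) / (10 + num_updates)), so feeding it a step counter such as optimizer.iterations warms up the average early in training. A small sketch of that schedule (the function name is illustrative):

```python
def effective_decay(decay, num_updates):
    # Mirrors the dynamic decay documented for
    # tf.train.ExponentialMovingAverage: small step counts clamp the decay
    # down, so early averages track the variables closely; as training
    # proceeds, the ratio approaches 1 and the configured decay takes over.
    return min(decay, (1 + num_updates) / (10 + num_updates))
```

At step 0 this yields 0.1 regardless of the configured decay, while after many steps it settles at the configured value (e.g. 0.999).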

base_config = self._optimizer.get_config()
return dict(list(base_config.items()) + list(config.items()))

def assign_average_vars(self, var_list):
Contributor


Checkpoints are usually saved in regular intervals. Is it a standard practice to continue training with the averaged variables after saving a checkpoint?

Development

Successfully merging this pull request may close these issues.

Implement MovingAverageOptimizer
7 participants