
Fixing minor issues from #969 (#2163)


Closed
wants to merge 121 commits into from
Changes from all commits
Commits
121 commits
e2c1378
initial setup. need to build tests
hyang0129 Jan 29, 2020
ef4b235
build some tests. need to test them
hyang0129 Jan 29, 2020
34e4e16
fixed typo
hyang0129 Jan 29, 2020
c9e8b99
created first test
hyang0129 Jan 29, 2020
3b545a4
created first test
hyang0129 Jan 29, 2020
fab5871
accidentally messed up another file
hyang0129 Jan 29, 2020
333e907
accidentally messed up another file
hyang0129 Jan 29, 2020
b79135d
accidentally messed up another file
hyang0129 Jan 29, 2020
c4d2588
added run all distributed
hyang0129 Jan 29, 2020
c47c34a
fixed formatting
hyang0129 Jan 29, 2020
a510f70
trying to fix tests not running on github CI.
hyang0129 Jan 29, 2020
6940433
realized that I should probably add the new optimizer files to the bu…
hyang0129 Jan 29, 2020
af932d8
added typeguard and docstring
hyang0129 Jan 29, 2020
ea1f62e
removed run_all_distributed
hyang0129 Jan 29, 2020
80e2768
graph and eager testing for SGD
hyang0129 Jan 30, 2020
f89ac90
reformatted
hyang0129 Jan 30, 2020
1137bc3
added distributed tests
hyang0129 Jan 30, 2020
1149719
removed distributed tests
hyang0129 Jan 30, 2020
87b396c
reverted discriminative layer grad adjust back to apply gradients
hyang0129 Jan 30, 2020
a5a8a6f
added distributed tests with one time virtual device init
hyang0129 Jan 30, 2020
e76835f
increased tolerance for distributed
hyang0129 Jan 30, 2020
504d4bb
changed how distributed is recognized for increasing tolerance
hyang0129 Jan 30, 2020
9eb2d3c
Redesigned Logic into Optimizer Wrapper (#1)
hyang0129 Jan 31, 2020
7e6579f
reformatted
hyang0129 Jan 31, 2020
0fefcdf
updated documentation
hyang0129 Jan 31, 2020
595ec9a
added typecheck for name
hyang0129 Feb 1, 2020
8fd351d
added typecheck for name
hyang0129 Feb 4, 2020
9d43e44
fixed blank line at end of init file
hyang0129 Feb 4, 2020
3d581df
realized no new line meant to add new line
hyang0129 Feb 4, 2020
742117b
ran buildifier
hyang0129 Feb 4, 2020
9511238
fixed accidentally affecting moving average
hyang0129 Feb 4, 2020
55399f8
changed print to logging.info
hyang0129 Feb 4, 2020
3fa5e19
changed print to logging.info
hyang0129 Feb 10, 2020
f3d402c
Revert "changed print to logging.info"
hyang0129 Feb 10, 2020
3457b3b
added tutorial.
hyang0129 Feb 10, 2020
29d440a
refactored to use static method
hyang0129 Feb 11, 2020
5b0531c
updated the usage of lr_mult in variables
hyang0129 Feb 11, 2020
40e8bba
renamed discriminative wrapper to disclayeropt
hyang0129 Feb 11, 2020
e3781e0
added note to dissuade directly calling apply_gradients
hyang0129 Feb 11, 2020
9c62b01
updated toy_cnn to use tempdir and no longer call context.eager
hyang0129 Feb 11, 2020
c0ad05a
added toy_rnn and sgd to the test permutations
hyang0129 Feb 11, 2020
9a69ae8
refactored permutes and train results into private fns
hyang0129 Feb 11, 2020
abbb961
reformatted files and fixed flake 8 issues
hyang0129 Feb 11, 2020
9f19a63
added missing functions in prep for tests
hyang0129 Feb 11, 2020
cd1f613
updated assign lr mult and explained further why
hyang0129 Feb 11, 2020
5f67423
forgot to run black so ran it to reformat
hyang0129 Feb 11, 2020
bbc0f6c
specified inputshape for rnn
hyang0129 Feb 11, 2020
b77bbfc
increased size of test
hyang0129 Feb 11, 2020
2d5fe1a
remove toy rnn for now
hyang0129 Feb 11, 2020
503a9e5
changed back to medium. maybe large was not actually increasing runtime
hyang0129 Feb 11, 2020
ff697cb
fixed input layer
hyang0129 Feb 11, 2020
740127c
fixed input layer being in wrong place
hyang0129 Feb 11, 2020
10b7417
virtual device modification issue
hyang0129 Feb 12, 2020
a2831d9
fixed incorrect usage of lr_mult
hyang0129 Feb 12, 2020
6baa024
added comments for tests explaining them better
hyang0129 Feb 12, 2020
ff86000
added new test
hyang0129 Feb 12, 2020
f244bf2
fixed typo
hyang0129 Feb 12, 2020
b9119e8
added inputshape so that pretrained rnn generates weights
hyang0129 Feb 12, 2020
a178620
changed test to allow head to learn. it should move the loss better
hyang0129 Feb 12, 2020
d9408d0
reformatted
hyang0129 Feb 12, 2020
353bcc3
fixed test for variable assignment
hyang0129 Feb 12, 2020
4c7cd95
reformatted
hyang0129 Feb 12, 2020
9ccd67a
fixed layer references from 1 to 0 because input layer isn't counted
hyang0129 Feb 12, 2020
9437d57
reformatted
hyang0129 Feb 12, 2020
0ba9348
increased lr and epochs because learning was happening, but assertless
hyang0129 Feb 12, 2020
126b5d4
attempting to use run distributed from test utils
hyang0129 Feb 13, 2020
6e560bc
removed tutorial
hyang0129 Feb 13, 2020
44b6300
switched to alternative distributed training method
hyang0129 Feb 13, 2020
2625d33
Changes (#2)
hyang0129 Feb 14, 2020
f827c6b
Merge pull request #3 from tensorflow/master
hyang0129 Feb 14, 2020
360f2ce
trying to use run distributed without graph and eager
hyang0129 Feb 14, 2020
e809eb9
trying to use run_distributed
hyang0129 Feb 14, 2020
cb30bbb
seems that doing any tensorstuff before tf.test.main creates the issu…
hyang0129 Feb 14, 2020
d03c568
forgot to return a model on first run of model fn
hyang0129 Feb 14, 2020
9f5e7bf
create model weights on init
hyang0129 Feb 14, 2020
7bb37de
changed how args are passed for testcase
hyang0129 Feb 14, 2020
6e4c78e
changed how args are passed for testcase
hyang0129 Feb 14, 2020
99324be
try fix init
hyang0129 Feb 14, 2020
4e1ca6d
trying to init weights on model properly
hyang0129 Feb 14, 2020
4b84779
trying to init weights on model properly
hyang0129 Feb 14, 2020
2041767
just trying all the possibilities
hyang0129 Feb 14, 2020
0431121
trying to fix weights setup
hyang0129 Feb 17, 2020
d41495b
expanded some comments for some tests
hyang0129 Feb 18, 2020
46f8870
Merge pull request #4 from tensorflow/master
hyang0129 Feb 24, 2020
5883b75
fixed some docstrings and expanded on some comments
hyang0129 Feb 24, 2020
f275f84
Merge branch 'master' of https://github.com/hyang0129/addons
hyang0129 Feb 24, 2020
00cba8c
reformatted files
hyang0129 Feb 24, 2020
d142572
capitalized comments properly.
hyang0129 Feb 24, 2020
c4fc0e2
removed sgd, reduced size of training inputs.
hyang0129 Feb 24, 2020
ef3130b
simplified checkpoint name
hyang0129 Feb 25, 2020
30d3d1d
reformatted
hyang0129 Feb 25, 2020
7e46a06
remove run tests in notebook
hyang0129 Feb 25, 2020
e266116
Merge pull request #5 from tensorflow/master
hyang0129 Feb 25, 2020
63438e3
Merge branch 'master' into hyang0129_master
gabrieldemarmiesse Feb 26, 2020
b4332ae
Merge branch 'master' into hyang0129_master
gabrieldemarmiesse Feb 26, 2020
b18e231
updated README.md
hyang0129 Feb 27, 2020
9eb2104
fixed formatting
hyang0129 Feb 27, 2020
ffcba16
Merge branch 'master' of https://github.com/tensorflow/addons
hyang0129 Mar 11, 2020
b54a538
removed distributed tests and added a warning if optimizer is initial…
hyang0129 Mar 11, 2020
39d28ef
renamed test_wrap to wrap_test because pytest thought it was a test.
hyang0129 Mar 11, 2020
4f6d082
Merge remote-tracking branch 'upstream/master'
hyang0129 Mar 23, 2020
fea9a04
Merge branch 'master' of https://github.com/hyang0129/addons
hyang0129 Mar 23, 2020
9d70f8d
converting tests into the pytest framework
hyang0129 Mar 23, 2020
02e79df
converted tests and parameterized
hyang0129 Mar 23, 2020
dea5198
cleaned up code
hyang0129 Mar 23, 2020
1f951e5
added additional checks and doc string for changes in lr multiplier d…
hyang0129 Mar 31, 2020
77a2b3d
Merge remote-tracking branch 'upstream/master'
hyang0129 Apr 21, 2020
c903d27
changed comment
hyang0129 Apr 21, 2020
8a0c9e2
Merge remote-tracking branch 'upstream/master'
hyang0129 May 1, 2020
a029857
Merge pull request #6 from tensorflow/master
hyang0129 Sep 9, 2020
cd01a2b
Simplified discriminative layer training by using a multi optimizer w…
hyang0129 Sep 9, 2020
6f82e8f
Refactored code using black and flake8
hyang0129 Sep 9, 2020
ef4ede1
updated init file
hyang0129 Sep 9, 2020
87d4c9d
fixed typeguard error and usage of private/experimental api.
hyang0129 Sep 9, 2020
35dca9b
restructured wrapper serialization and removed unnecessary components.
hyang0129 Sep 13, 2020
f189c39
expanded on docstr and added repr
hyang0129 Sep 13, 2020
c63c635
cleaned up docstrings, added assertion tests, and added explicit test…
hyang0129 Sep 14, 2020
2c46752
ran black and flake8
hyang0129 Sep 14, 2020
4b856f1
fixed doc string
hyang0129 Sep 14, 2020
0e3910a
modified code owners
hyang0129 Sep 14, 2020
6f50e6a
fixed typo in multioptimizer class name
hyang0129 Sep 14, 2020
2 changes: 2 additions & 0 deletions .github/CODEOWNERS
@@ -146,6 +146,8 @@
/tensorflow_addons/optimizers/tests/conditional_gradient_test.py @pkan2 @lokhande-vishnu
/tensorflow_addons/optimizers/cyclical_learning_rate.py @raphaelmeudec
/tensorflow_addons/optimizers/tests/cyclical_learning_rate_test.py @raphaelmeudec
/tensorflow_addons/optimizers/discriminative_layer_training.py @hyang0129
/tensorflow_addons/optimizers/tests/discriminative_layer_training_test.py @hyang0129
/tensorflow_addons/optimizers/lamb.py @junjiek
/tensorflow_addons/optimizers/tests/lamb_test.py @junjiek
/tensorflow_addons/optimizers/lazy_adam.py @ssaishruthi
3 changes: 3 additions & 0 deletions tensorflow_addons/optimizers/__init__.py
@@ -26,6 +26,9 @@
from tensorflow_addons.optimizers.cyclical_learning_rate import (
ExponentialCyclicalLearningRate,
)
from tensorflow_addons.optimizers.discriminative_layer_training import (
MultiOptimizer,
)
from tensorflow_addons.optimizers.lamb import LAMB
from tensorflow_addons.optimizers.lazy_adam import LazyAdam
from tensorflow_addons.optimizers.lookahead import Lookahead
166 changes: 166 additions & 0 deletions tensorflow_addons/optimizers/discriminative_layer_training.py
@@ -0,0 +1,166 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Discriminative Layer Training Optimizer for TensorFlow."""

from typing import Union

import tensorflow as tf
from typeguard import typechecked


@tf.keras.utils.register_keras_serializable(package="Addons")
class MultiOptimizer(tf.keras.optimizers.Optimizer):
"""Multi Optimizer Wrapper for Discriminative Layer Training.

Creates a wrapper around a set of instantiated optimizer-layer pairs. This is generally useful for
transfer learning of deep networks.

Each optimizer will optimize only the weights associated with its paired layer. This can be used
to implement discriminative layer training by assigning a different learning rate to each
optimizer-layer pair. (Optimizer, list(Layers)) pairs are also supported. Please note that the
layers must be instantiated before instantiating the optimizer.

Args:
optimizers_and_layers: a list of tuples of an optimizer and a layer or model. Each tuple should contain
exactly one instantiated optimizer and one object that subclasses tf.keras.Model or tf.keras.Layer. Nested
layers and models will be automatically discovered. Alternatively, in place of a single layer, you can pass
a list of layers.
optimizer_specs: specialized list used for serialization. Should be left as None in almost all cases. If you
are loading a serialized version of this optimizer, please use tf.keras.models.load_model after saving a
model compiled with this optimizer (a round-trip sketch follows the usage example below).

Usage:

```python
model = get_model()

opt1 = tf.keras.optimizers.Adam(learning_rate=1e-4)
opt2 = tf.keras.optimizers.Adam(learning_rate=1e-2)

opt_layer_pairs = [(opt1, model.layers[0]), (opt2, model.layers[1:])]

loss = tf.keras.losses.MSE
optimizer = tfa.optimizers.MultiOptimizer(opt_layer_pairs)

model.compile(optimizer=optimizer, loss=loss)

model.fit(x, y)
```
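
As noted for optimizer_specs, the wrapper is intended to round-trip through whole-model serialization.
A minimal sketch, assuming the compiled `model` from the usage example above and an illustrative path "my_model":

```python
# Save the whole model, including the wrapped per-layer optimizers.
model.save("my_model", save_format="tf")

# Reloading rebuilds the MultiOptimizer from its serialized optimizer_specs.
restored_model = tf.keras.models.load_model("my_model")
restored_model.fit(x, y)
```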

Reference:

[Universal Language Model Fine-tuning for Text Classification](https://arxiv.org/abs/1801.06146)
[Collaborative Layer-wise Discriminative Learning in Deep Neural Networks](https://arxiv.org/abs/1607.05440)

Notes:

Currently, MultiOptimizer does not support callbacks that modify optimizers. However, you can instantiate
optimizer-layer pairs with tf.keras.optimizers.schedules.LearningRateSchedule instead of a static learning
rate.
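
For example, a sketch of pairing a schedule-driven optimizer with the first layer (assuming the `model` and
`opt2` from the usage example above; the decay values are purely illustrative):

```python
schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=1e-3, decay_steps=1000, decay_rate=0.9
)
scheduled_opt = tf.keras.optimizers.Adam(learning_rate=schedule)

optimizer = tfa.optimizers.MultiOptimizer(
    [(scheduled_opt, model.layers[0]), (opt2, model.layers[1:])]
)
```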

This code should function on CPU, GPU, and TPU. Apply the with strategy.scope() context as you
would with any other optimizer.

"""

@typechecked
def __init__(
self,
optimizers_and_layers: Union[list, None] = None,
optimizer_specs: Union[list, None] = None,
name: str = "MultiOptimizer",
**kwargs
):

super(MultiOptimizer, self).__init__(name, **kwargs)

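# Build the specs either from fresh (optimizer, layer(s)) pairs or from previously serialized specs.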
if optimizer_specs is None and optimizers_and_layers is not None:
self.optimizer_specs = [
self.create_optimizer_spec(opt, layer)
for opt, layer in optimizers_and_layers
]

elif optimizer_specs is not None and optimizers_and_layers is None:
self.optimizer_specs = [
self.maybe_initialize_optimizer_spec(spec) for spec in optimizer_specs
]

else:
raise RuntimeError(
"You must specify either an list of optimizers and layers or a list of optimizer_specs"
)

def apply_gradients(self, grads_and_vars, name=None, **kwargs):
"""Wrapped apply_gradient method.

Returns a list of tf ops to be executed.
Name of variable is used rather than var.ref() to enable serialization and deserialization.
"""

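# Reset each optimizer's buffer of (gradient, variable) pairs before routing the incoming gradients.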
for spec in self.optimizer_specs:
spec["gv"] = []

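# Route each (grad, var) pair to the optimizer whose paired layer owns that variable, matching by name.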
for grad, var in tuple(grads_and_vars):
for spec in self.optimizer_specs:
for name in spec["weights"]:
if var.name == name:
spec["gv"].append((grad, var))

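# Let each optimizer apply its own slice of the gradients, then group the resulting update ops.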
return tf.group(
[
spec["optimizer"].apply_gradients(spec["gv"], **kwargs)
for spec in self.optimizer_specs
]
)

def get_config(self):
config = super(MultiOptimizer, self).get_config()
config.update({"optimizer_specs": self.optimizer_specs})
return config

@classmethod
def create_optimizer_spec(cls, optimizer_instance, layer):

assert isinstance(
optimizer_instance, tf.keras.optimizers.Optimizer
), "Object passed is not an instance of tf.keras.optimizers.Optimizer"

assert isinstance(
layer, (tf.keras.layers.Layer, tf.keras.Model, list)
), "Object passed is not an instance of tf.keras.layers.Layer, tf.keras.Model, or a list of layers"

# Record weight names (rather than variable refs) so the spec stays serializable.
if isinstance(layer, list):
weights = [var.name for sublayer in layer for var in sublayer.weights]
else:
weights = [var.name for var in layer.weights]

return {
"optimizer": optimizer_instance,
"weights": weights,
}

@classmethod
def maybe_initialize_optimizer_spec(cls, optimizer_spec):
# A spec restored from a config still holds the optimizer as a dict, so deserialize it back into an optimizer.
if isinstance(optimizer_spec["optimizer"], dict):
optimizer_spec["optimizer"] = tf.keras.optimizers.deserialize(
optimizer_spec["optimizer"]
)

return optimizer_spec

def __repr__(self):
return "Multi Optimizer with %i optimizer layer pairs" % len(
self.optimizer_specs
)
113 changes: 113 additions & 0 deletions tensorflow_addons/optimizers/tests/discriminative_layer_training_test.py
@@ -0,0 +1,113 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for Discriminative Layer Training Optimizer for TensorFlow."""

import pytest
import numpy as np
import tensorflow as tf

from tensorflow_addons.optimizers.discriminative_layer_training import MultiOptimizer
from tensorflow_addons.utils import test_utils


def _dtypes_to_test(use_gpu):
# Based on issue #347 in the following link,
# "https://github.com/tensorflow/addons/issues/347"
# tf.half is not registered for 'ResourceScatterUpdate' OpKernel
# for 'GPU' devices.
# So we have to remove tf.half when testing with gpu.
# The function "_DtypesToTest" is from
# "https://github.com/tensorflow/tensorflow/blob/5d4a6cee737a1dc6c20172a1dc1
# 5df10def2df72/tensorflow/python/kernel_tests/conv_ops_3d_test.py#L53-L62"
# TODO(WindQAQ): Clean up this in TF2.4

if use_gpu:
return [tf.float32, tf.float64]
else:
return [tf.half, tf.float32, tf.float64]


@pytest.mark.with_device(["cpu", "gpu"])
@pytest.mark.parametrize("dtype", [tf.float16, tf.float32, tf.float64])
@pytest.mark.parametrize("serialize", [True, False])
def test_fit_layer_optimizer(dtype, device, serialize):
# Test ensures that each optimizer is only optimizing its own layer with its learning rate

if "gpu" in device and dtype == tf.float16:
pytest.xfail("See https://github.com/tensorflow/addons/issues/347")

model = tf.keras.Sequential(
[tf.keras.Input(shape=[1]), tf.keras.layers.Dense(1), tf.keras.layers.Dense(1)]
)

x = np.array(np.ones([100]))
y = np.array(np.ones([100]))

weights_before_train = (
model.layers[0].weights[0].numpy(),
model.layers[1].weights[0].numpy(),
)

opt1 = tf.keras.optimizers.Adam(learning_rate=1e-3)
opt2 = tf.keras.optimizers.SGD(learning_rate=0)

opt_layer_pairs = [(opt1, model.layers[0]), (opt2, model.layers[1])]

loss = tf.keras.losses.MSE
optimizer = MultiOptimizer(opt_layer_pairs)

model.compile(optimizer=optimizer, loss=loss)

# serialize whole model including optimizer, clear the session, then reload the whole model.
if serialize:
model.save("test", save_format="tf")
tf.keras.backend.clear_session()
model = tf.keras.models.load_model("test")

model.fit(x, y, batch_size=8, epochs=10)

weights_after_train = (
model.layers[0].weights[0].numpy(),
model.layers[1].weights[0].numpy(),
)

with np.testing.assert_raises(AssertionError):
# expect weights to be different for layer 1
test_utils.assert_allclose_according_to_type(
weights_before_train[0], weights_after_train[0]
)

# expect weights to be the same for layer 2
test_utils.assert_allclose_according_to_type(
weights_before_train[1], weights_after_train[1]
)


def test_serialization():

model = tf.keras.Sequential(
[tf.keras.Input(shape=[1]), tf.keras.layers.Dense(1), tf.keras.layers.Dense(1)]
)

opt1 = tf.keras.optimizers.Adam(learning_rate=1e-3)
opt2 = tf.keras.optimizers.SGD(learning_rate=0)

opt_layer_pairs = [(opt1, model.layers[0]), (opt2, model.layers[1])]

optimizer = MultiOptimizer(opt_layer_pairs)
config = tf.keras.optimizers.serialize(optimizer)

new_optimizer = tf.keras.optimizers.deserialize(config)
assert new_optimizer.get_config() == optimizer.get_config()