Discriminative Layer Training #969
@@ -0,0 +1,166 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Discriminative Layer Training Optimizer for TensorFlow."""

from typing import Union

import tensorflow as tf
from typeguard import typechecked


@tf.keras.utils.register_keras_serializable(package="Addons")
class MultiOptimzer(tf.keras.optimizers.Optimizer):
"""Multi Optimizer Wrapper for Discriminative Layer Training. | ||
|
||
Creates a wrapper around a set of instantiated optimizer layer pairs. Generally useful for transfer learning | ||
of deep networks. | ||
|
||
Each optimizer will optimize only the weights associated with its paired layer. This can be used | ||
to implement discriminative layer training by assigning different learning rates to each optimizer | ||
layer pair. (Optimizer, list(Layers)) pairs are also supported. Please note that the layers must be | ||
instantiated before instantiating the optimizer. | ||
|
||
Args: | ||
optimizers_and_layers: a list of tuples of an optimizer and a layer or model. Each tuple should contain | ||
exactly 1 instantiated optimizer and 1 object that subclasses tf.keras.Model or tf.keras.Layer. Nested | ||
layers and models will be automatically discovered. Alternatively, in place of a single layer, you can pass | ||
a list of layers. | ||
optimizer_specs: specialized list for serialization. Should be left as None for almost all cases. If you are | ||
loading a serialized version of this optimizer, please use tf.keras.models.load_model after saving a | ||
model compiled with this optimizer. | ||
|
||
hyang0129 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
Usage: | ||
|
||
```python | ||
model = get_model() | ||
|
||
opt1 = tf.keras.optimizers.Adam(learning_rate=1e-4) | ||
opt2 = tf.keras.optimizers.Adam(learning_rate=1e-2) | ||
|
||
opt_layer_pairs = [(opt1, model.layers[0]), (opt2, model.layers[1:])] | ||
|
||
loss = tf.keras.losses.MSE | ||
optimizer = tfa.optimizers.MultiOpt(opt_layer_pairs) | ||
|
||
model.compile(optimizer=optimizer, loss = loss) | ||
|
||
model.fit(x,y) | ||
''' | ||
|
||
Reference: | ||
|
||
[Universal Language Model Fine-tuning for Text Classification](https://arxiv.org/abs/1801.06146) | ||
[Collaborative Layer-wise Discriminative Learning in Deep Neural Networks](https://arxiv.org/abs/1607.05440) | ||
|
||
Notes: | ||
|
||
hyang0129 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
Currently, MultiOpt does not support callbacks that modify optimizers. However, you can instantiate | ||
optimizer layer pairs with tf.keras.optimizers.schedules.LearningRateSchedule instead of a static learning | ||
rate. | ||
|
||
This code should function on CPU, GPU, and TPU. Apply the with strategy.scope() context as you | ||
would with any other optimizer. | ||
|
||
""" | ||

    @typechecked
    def __init__(
        self,
        optimizers_and_layers: Union[list, None] = None,
        optimizer_specs: Union[list, None] = None,
        name: str = "MultiOptimzer",
        **kwargs
    ):

        super(MultiOptimzer, self).__init__(name, **kwargs)

        if optimizer_specs is None and optimizers_and_layers is not None:
            self.optimizer_specs = [
                self.create_optimizer_spec(opt, layer)
                for opt, layer in optimizers_and_layers
            ]

        elif optimizer_specs is not None and optimizers_and_layers is None:
            self.optimizer_specs = [
                self.maybe_initialize_optimizer_spec(spec) for spec in optimizer_specs
            ]

        else:
            raise RuntimeError(
                "You must specify either a list of optimizers and layers or a list of optimizer_specs"
            )

    def apply_gradients(self, grads_and_vars, name=None, **kwargs):
        """Wrapped apply_gradient method.

        Returns a list of tf ops to be executed.
        Name of variable is used rather than var.ref() to enable serialization and deserialization.
        """

        for spec in self.optimizer_specs:
            spec["gv"] = []

        # Route each (gradient, variable) pair to the optimizer spec that owns the variable,
        # matching on variable name so the mapping survives serialization.
        for grad, var in tuple(grads_and_vars):
            for spec in self.optimizer_specs:
                for name in spec["weights"]:
                    if var.name == name:
                        spec["gv"].append((grad, var))

        return tf.group(
            [
                spec["optimizer"].apply_gradients(spec["gv"], **kwargs)
                for spec in self.optimizer_specs
            ]
        )

    def get_config(self):
        config = super(MultiOptimzer, self).get_config()
        config.update({"optimizer_specs": self.optimizer_specs})
        return config

    @classmethod
    def create_optimizer_spec(cls, optimizer_instance, layer):

        assert isinstance(
            optimizer_instance, tf.keras.optimizers.Optimizer
        ), "Object passed is not an instance of tf.keras.optimizers.Optimizer"

        # A single layer/model or a list of layers is accepted (see the class docstring).
        if isinstance(layer, list):
            assert all(
                isinstance(sublayer, (tf.keras.layers.Layer, tf.keras.Model))
                for sublayer in layer
            ), "Object passed is not an instance of tf.keras.layers.Layer nor tf.keras.Model"
            weights = [var.name for sublayer in layer for var in sublayer.weights]
        else:
            assert isinstance(
                layer, (tf.keras.layers.Layer, tf.keras.Model)
            ), "Object passed is not an instance of tf.keras.layers.Layer nor tf.keras.Model"
            weights = [var.name for var in layer.weights]

        return {
            "optimizer": optimizer_instance,
            "weights": weights,
        }

    @classmethod
    def maybe_initialize_optimizer_spec(cls, optimizer_spec):
        if type(optimizer_spec["optimizer"]) == dict:
            optimizer_spec["optimizer"] = tf.keras.optimizers.deserialize(
                optimizer_spec["optimizer"]
            )

        return optimizer_spec

    def __repr__(self):
        return "Multi Optimizer with %i optimizer layer pairs" % len(
            self.optimizer_specs
        )
@@ -0,0 +1,113 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for Discriminative Layer Training Optimizer for TensorFlow."""

import pytest
import numpy as np
import tensorflow as tf

from tensorflow_addons.optimizers.discriminative_layer_training import MultiOptimzer
from tensorflow_addons.utils import test_utils

def _dtypes_to_test(use_gpu):
    # Based on issue #347 in the following link,
    # "https://github.com/tensorflow/addons/issues/347"
    # tf.half is not registered for 'ResourceScatterUpdate' OpKernel
    # for 'GPU' devices.
    # So we have to remove tf.half when testing with gpu.

    # The function "_DtypesToTest" is from
    # "https://github.com/tensorflow/tensorflow/blob/5d4a6cee737a1dc6c20172a1dc15df10def2df72/tensorflow/python/kernel_tests/conv_ops_3d_test.py#L53-L62"
    # TODO(WindQAQ): Clean up this in TF2.4

    if use_gpu:
        return [tf.float32, tf.float64]
    else:
        return [tf.half, tf.float32, tf.float64]


@pytest.mark.with_device(["cpu", "gpu"])
@pytest.mark.parametrize("dtype", [tf.float16, tf.float32, tf.float64])
@pytest.mark.parametrize("serialize", [True, False])
def test_fit_layer_optimizer(dtype, device, serialize):
    # Test ensures that each optimizer is only optimizing its own layer with its learning rate

    if "gpu" in device and dtype == tf.float16:
        pytest.xfail("See https://github.com/tensorflow/addons/issues/347")

    model = tf.keras.Sequential(
        [tf.keras.Input(shape=[1]), tf.keras.layers.Dense(1), tf.keras.layers.Dense(1)]
    )

    x = np.array(np.ones([100]))
    y = np.array(np.ones([100]))

    weights_before_train = (
        model.layers[0].weights[0].numpy(),
        model.layers[1].weights[0].numpy(),
    )

    opt1 = tf.keras.optimizers.Adam(learning_rate=1e-3)
    opt2 = tf.keras.optimizers.SGD(learning_rate=0)

    opt_layer_pairs = [(opt1, model.layers[0]), (opt2, model.layers[1])]

    loss = tf.keras.losses.MSE
    optimizer = MultiOptimzer(opt_layer_pairs)

    model.compile(optimizer=optimizer, loss=loss)

    # Serialize the whole model including the optimizer, clear the session, then reload the whole model.
    if serialize:
        model.save("test", save_format="tf")
        tf.keras.backend.clear_session()
        model = tf.keras.models.load_model("test")

    model.fit(x, y, batch_size=8, epochs=10)

    weights_after_train = (
        model.layers[0].weights[0].numpy(),
        model.layers[1].weights[0].numpy(),
    )

    with np.testing.assert_raises(AssertionError):
        # expect weights to be different for layer 1
        test_utils.assert_allclose_according_to_type(
            weights_before_train[0], weights_after_train[0]
        )

    # expect weights to be same for layer 2
    test_utils.assert_allclose_according_to_type(
        weights_before_train[1], weights_after_train[1]
    )


def test_serialization():

    model = tf.keras.Sequential(
        [tf.keras.Input(shape=[1]), tf.keras.layers.Dense(1), tf.keras.layers.Dense(1)]
    )

    opt1 = tf.keras.optimizers.Adam(learning_rate=1e-3)
    opt2 = tf.keras.optimizers.SGD(learning_rate=0)

    opt_layer_pairs = [(opt1, model.layers[0]), (opt2, model.layers[1])]

    optimizer = MultiOptimzer(opt_layer_pairs)
    config = tf.keras.optimizers.serialize(optimizer)

    new_optimizer = tf.keras.optimizers.deserialize(config)
    assert new_optimizer.get_config() == optimizer.get_config()
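The optimizer_specs docstring note and the serialize branch of test_fit_layer_optimizer above both point at the same workflow: persist a model compiled with the wrapper and restore it with tf.keras.models.load_model. A minimal sketch of that round trip follows, assuming the class is exported as tfa.optimizers.MultiOptimzer; the model definition, save path, and training data are illustrative assumptions.

```python
import numpy as np
import tensorflow as tf
import tensorflow_addons as tfa

# Hypothetical two-layer model; any built Keras model works.
model = tf.keras.Sequential(
    [tf.keras.Input(shape=[1]), tf.keras.layers.Dense(4), tf.keras.layers.Dense(1)]
)

optimizer = tfa.optimizers.MultiOptimzer(
    [
        (tf.keras.optimizers.Adam(learning_rate=1e-4), model.layers[0]),
        (tf.keras.optimizers.Adam(learning_rate=1e-2), model.layers[1]),
    ]
)
model.compile(optimizer=optimizer, loss="mse")
model.fit(np.ones([32, 1]), np.ones([32, 1]), epochs=1, verbose=0)

# Save in the TF SavedModel format and reload; the wrapper is registered as a Keras
# serializable in the "Addons" package, so no custom_objects argument is needed.
model.save("multi_opt_model", save_format="tf")
restored = tf.keras.models.load_model("multi_opt_model")
restored.fit(np.ones([32, 1]), np.ones([32, 1]), epochs=1, verbose=0)
```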