
Fix MultiOptimizer list of layers #2180


Merged

Conversation

WindQAQ
Member

@WindQAQ WindQAQ commented Sep 25, 2020

Description

Fixes #2178

Type of change

Checklist:

  • I've properly formatted my code according to the guidelines
    • By running Black + Flake8
    • By running pre-commit hooks
  • This PR addresses an already submitted issue for TensorFlow Addons
  • I have made corresponding changes to the documentation
  • I have added tests that prove my fix is effective or that my feature works
  • This PR contains modifications to C++ custom-ops

How Has This Been Tested?

Additional test.

@WindQAQ WindQAQ requested a review from a team September 25, 2020 19:07
@bot-of-gabrieldemarmiesse

@hyang0129

You are owner of some files modified in this pull request.
Would you kindly review the changes whenever you have the time to?
Thank you very much.

@WindQAQ WindQAQ mentioned this pull request Sep 25, 2020
@pytest.mark.with_device(["cpu", "gpu"])
@pytest.mark.parametrize("dtype", [tf.float16, tf.float32, tf.float64])
Member Author

This is not used in the original tests, so I removed it.

Member

But we should add this test, right?

Contributor

@bhack bhack left a comment

From a quick overview, it seems that we are not covering the (optimizer, model) case?

        The name of each variable is used rather than `var.ref()` to enable serialization and deserialization.
        """
-        if type(layer) == list:
+        if isinstance(layer, list):
+            weights = [var.name for sublayer in layer for var in sublayer.weights]
Contributor

Should this be trainable_weights here?

Member Author

This is discussed in #969 (comment).

Member Author

I would respect the decision of the code owner. A change in design is not part of this PR. /cc @hyang0129

Contributor

If it is interpreted as "we have weights and we set those same weights", then at the semantic level it is OK.

Member Author

Yep. We can wait for the response from the code owner. I see that the underlying fit only takes trainable variables into account. As you recommend, we can change it to trainable_weights after discussion.

https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/engine/training.py#L737

Contributor

Yes, this is V1.
In V2 you need to pass a variable list or a callable, as it has no default, but they assume one in the default example:
https://github.com/tensorflow/tensorflow/blob/v2.3.1/tensorflow/python/keras/optimizer_v2/optimizer_v2.py#L118-L120

Instead, in our apply_gradients we are controlling the variable list internally before passing it to the real apply_gradients of each optimizer. It seems to me that we are behaving a little like the fit->minimize method, no?

Member Author

We are not "controlling the var list" internally; we are fetching only the vars that users pass in.

Contributor

If we look at the argument: OK, we receive a user argument, but internally we access our prepared spec["weights"]:
for name in spec["weights"]:

Contributor

@hyang0129 hyang0129 Sep 27, 2020

The multi optimizer stores a reference to the weight names.

When optimizing, it uses this reference to allocate grad-var pairs to the correct optimizer.

It does not determine what weights are passed to the multi optimizer.

So you may have a situation where 5 vars are passed but spec['weights'] contains 10 var names and only 3 of them match. Thus, only 3 are passed on.

The original design is intended to be used in the model.fit function, not a custom optimization loop. Even in a custom optimization loop, the user is expected to specify the var list to include only trainable weights. I am fairly certain that when an optimizer is called in a fit loop, the var list passed includes only trainable variables. Thus, the var list passed to the multi optimizer includes only trainable variables. Finally, the multi optimizer only passes on to its sub-optimizers the variables in the var list it received that match a variable in spec['weights'] for that particular optimizer. If this is true, then a non-trainable variable will never make it to the multi optimizer or sub-optimizers.

This behavior has been tested in a Colab notebook. When a layer's trainable attribute has been set to false, the multi optimizer's sub-optimizer assigned to that layer does not optimize the weights for that layer.
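The allocation logic described above can be sketched in plain Python (a simplified, hypothetical stand-in for the real MultiOptimizer internals; `specs`, `allocate`, and the weight names are illustrative, not the library's actual code):

```python
# Each spec pairs a (stand-in) optimizer with the weight *names* it owns,
# mirroring how MultiOptimizer stores names rather than var.ref().
specs = [
    {"optimizer": "opt_a", "weights": ["dense_1/kernel:0", "dense_1/bias:0"]},
    {"optimizer": "opt_b", "weights": ["dense_2/kernel:0", "dense_2/bias:0"]},
]

def allocate(grads_and_vars, specs):
    """Route each (grad, var_name) pair to the optimizer whose spec lists it."""
    buckets = {spec["optimizer"]: [] for spec in specs}
    for grad, var_name in grads_and_vars:
        for spec in specs:
            if var_name in spec["weights"]:
                buckets[spec["optimizer"]].append((grad, var_name))
    return buckets

# Only vars whose names appear in a spec are passed on; unmatched vars
# (e.g. "other:0") are simply dropped, as described in the comment above.
grads_and_vars = [(0.1, "dense_1/kernel:0"), (0.2, "dense_2/bias:0"), (0.3, "other:0")]
buckets = allocate(grads_and_vars, specs)
```

This also illustrates the superset case discussed below: the specs may name more weights than the var list contains, and only the intersection reaches each sub-optimizer.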

Contributor

If it is a case where weights is a superset, then checking by name is probably OK.

But I still don't see where if var.name == name needs to handle non-trainable weights.

        The name of each variable is used rather than `var.ref()` to enable serialization and deserialization.
        """
-        if type(layer) == list:
+        if isinstance(layer, list):
+            weights = [var.name for sublayer in layer for var in sublayer.weights]
+        else:
+            weights = [var.name for var in layer.weights]
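The fixed branch can be exercised outside TensorFlow with minimal stand-ins (the `Var` and `FakeLayer` classes below are hypothetical; only the name-collection logic mirrors the PR):

```python
class Var:
    """Stand-in for a tf.Variable exposing only a .name attribute."""
    def __init__(self, name):
        self.name = name

class FakeLayer:
    """Stand-in for a Keras layer exposing only a .weights list."""
    def __init__(self, names):
        self.weights = [Var(n) for n in names]

def collect_weight_names(layer):
    # Mirrors the PR fix: a list of layers flattens all sublayer weights;
    # a single layer (or model, which subclasses Layer) takes its own weights.
    # isinstance() also accepts list subclasses, unlike type(layer) == list.
    if isinstance(layer, list):
        return [var.name for sublayer in layer for var in sublayer.weights]
    return [var.name for var in layer.weights]

single = FakeLayer(["w:0", "b:0"])
pair = [FakeLayer(["w1:0"]), FakeLayer(["w2:0"])]
```

Note the design point from the docstring above: storing `var.name` strings rather than `var.ref()` keeps the spec serializable.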
Contributor

Ditto?

@WindQAQ
Member Author

WindQAQ commented Sep 25, 2020

From a quick overview, it seems that we are not covering the (optimizer, model) case?

I'm not sure what this case is. Can you elaborate?

@bhack
Contributor

bhack commented Sep 25, 2020

I meant it seems that tf.Model could be an input, right? Do we have a test for this input case?

@WindQAQ
Member Author

WindQAQ commented Sep 26, 2020

I meant it seems that tf.Model could be an input, right? Do we have a test for this input case?

Thanks for the info. Updated the related tests :-)

    def create_optimizer_spec(
        cls,
        optimizer: tf.keras.optimizers.Optimizer,
        layer: Union[
Contributor

Also, the naming here is a little confusing, in the internal code too: layers with an "s" could cover the list case, but what about tf.keras.Model?

Member Author

@WindQAQ WindQAQ Sep 26, 2020

Sure, the naming is not great to me either, but I cannot come up with a new one... do you have any suggestions?

Contributor

Ugly as it is, we already have optimizers_and_layers.
With another boolean we could have layers_or_model 😄

Member Author

@WindQAQ WindQAQ Sep 26, 2020

Can you share the full input signature you propose?

Contributor

That's a good idea. Please go ahead with the renaming.

Contributor

layer -> layers_or_model

Member Author

@WindQAQ WindQAQ Sep 26, 2020

Hi, tf.keras.Model is a subclass of tf.keras.layers.Layer. Do we still need to do this?

Contributor

@bhack bhack Sep 26, 2020

It is a borderline case because it uses multiple inheritance:
class Model(base_layer.Layer, version_utils.ModelVersionSelector):
We are only using the features of the single base class Layer, but I don't know about readability.

Contributor

So, as you prefer: we cannot upcast to Layer in Python, and it is generally a borderline practice.

Member Author

Updated the naming. See if this is better now :-)

@hyang0129
Contributor

Sorry, I was on vacation. I'll take a look tomorrow.

multi_optimizer = MultiOptimizer(optimizers_and_layers)
model.compile(multi_optimizer, loss="mse")

x = np.random.rand(128, 4)
Contributor

I would recommend using a signal rather than complete noise. The purpose of this test is to demonstrate that the model weights will move when there is a signal, based on the optimizer setup.

Here, when you use np.random.rand, you are generating noise without any signal. Technically, the signal is the average, but the model will likely just memorize the input, given the model size and the number of examples in x.

You can choose to leave it as is, because it will still test the multi optimizer. In the future, though, people might be (temporarily) confused as to what signal the model is trying to learn.
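One way to follow this suggestion, sketched with NumPy (an illustrative linear signal; the PR's commit history shows the inputs were ultimately changed to ones instead):

```python
import numpy as np

rng = np.random.default_rng(42)
x = rng.random((128, 4))

# A deterministic linear signal the model can actually learn,
# instead of targets that are pure noise.
true_w = np.array([1.0, -2.0, 3.0, 0.5])
y = x @ true_w
```

With `y` defined this way, a drop in MSE genuinely means the weights moved toward `true_w`, rather than the model memorizing 128 noise samples.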

Member Author

Nice suggestion :-) Updated

hyang0129 previously approved these changes Sep 27, 2020
@@ -217,6 +217,25 @@ def pytest_collection_modifyitems(items):
item.add_marker(pytest.mark.skip("The gpu is not available."))


def assert_not_allclose(a, b, **kwargs):
Contributor

nice
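For reference, a minimal sketch of such a helper (assuming NumPy; the helper actually added to test_utils in the PR may differ in detail):

```python
import numpy as np

def assert_not_allclose(a, b, **kwargs):
    """Raise AssertionError if `a` and `b` agree within tolerance.

    The inverse of np.testing.assert_allclose: it passes only when the
    arrays genuinely differ. kwargs (rtol, atol, ...) are forwarded to
    np.allclose.
    """
    if np.allclose(a, b, **kwargs):
        raise AssertionError("arrays are unexpectedly close")
```

In the tests here it is used to check that training actually moved a layer's weights away from their initial values.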

Each optimizer will optimize only the weights associated with its paired layer.
This can be used to implement discriminative layer training by assigning
different learning rates to each optimizer layer pair.
`(tf.keras.optimizers.Optimizer, List[tf.keras.layers.Layer])` pairs are also supported.
Contributor

`List[tf.keras.layers.Layer])` -> `List([tf.keras.layers.Layer])`. It was missing a `(`.

Member Author

It's `(optimizer, List[layer])`, where the `()` stands for a Tuple.

@hyang0129
Contributor

@bhack I think this is good to go.

@WindQAQ WindQAQ merged commit 392f36c into tensorflow:master Sep 27, 2020
jrruijli pushed a commit to jrruijli/addons that referenced this pull request Dec 23, 2020
* Fix MultiOptimizer list of layers

* Fix name

* Remove unused tests

* Change list to iterable

* Update doc

* Update code snippet

* Update doc

* Back to list

* Update error message

* Update doc

* Fix tmpdir fixture

* Fix tmpdir

* Update doc

* Add test on tf.keras.Model

* Add nested model tests

* Better naming

* Add custom subclass model tests

* Inherit from Layer

* Move assert_not_allclose to test_utils

* Change input to ones

* Inherit from Model

* Test all weights instead of first one

* Update doc
Successfully merging this pull request may close these issues.

Error in MultiOptimizer when layers list are used
6 participants