-
Notifications
You must be signed in to change notification settings - Fork 19.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add a model.check_trainable_weights_consistency #8234
Add a model.check_trainable_weights_consistency #8234
Conversation
This will raise a UserWarning when the user modifies model.trainable and tries to print a model summary or launch a fit without having called .compile. Calling .compile() is necessary because trainable weights are collected in compile (model._collected_trainable_weights).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That's a very good idea, thanks.
keras/engine/training.py
Outdated
if not hasattr(self, '_collected_trainable_weights'): | ||
return | ||
|
||
if count_params(self.trainable_weights) != \ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please no \
keras/engine/training.py
Outdated
count_params(self._collected_trainable_weights): | ||
warnings.warn(UserWarning( | ||
'Discrepancy between trainable weights and collected trainable' | ||
' weights, did you set model.trainable without calling' |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: ` around code keywords
keras/utils/layer_utils.py
Outdated
@@ -5,6 +5,11 @@ | |||
import numpy as np | |||
|
|||
|
|||
def count_params(weights): | |||
"""Count the total number of scalars composing the weights""" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Docstring formatting: one-line summary should end with a period; the dosctring should include a # Arguments
and # Returns
section.
Thanks for the quick review. I think I addressed your comments. |
keras/engine/training.py
Outdated
@@ -945,9 +946,25 @@ def handle_metrics(metrics, weights=None): | |||
trainable_weights = self.trainable_weights | |||
self._collected_trainable_weights = trainable_weights | |||
|
|||
def check_trainable_weights_consistency(self): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please make this method private (underscore) and make the docstring style-compliant.
fc4581e
to
7efdf99
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks
This will raise a UserWarning when the user modifies model.trainable
and tries to print a model summary or launch a fit without having
called .compile.
Calling .compile() is necessary because trainable weights are collected
in compile (model._collected_trainable_weights).
Fixes #8121