Add a model.check_trainable_weights_consistency #8234

Merged · 4 commits · Oct 25, 2017
keras/engine/topology.py (3 changes: 2 additions & 1 deletion)
@@ -16,6 +16,7 @@
 from .. import initializers
 from ..utils.io_utils import ask_to_proceed_with_overwrite
 from ..utils.layer_utils import print_summary as print_layer_summary
+from ..utils.layer_utils import count_params
 from ..utils.generic_utils import has_arg
 from ..utils import conv_utils
 from ..legacy import interfaces
@@ -1269,7 +1270,7 @@ def count_params(self):
                                self.name + ', but the layer isn\'t built. '
                                'You can build it manually via: `' +
                                self.name + '.build(batch_input_shape)`.')
-        return sum([K.count_params(p) for p in self.weights])
+        return count_params(self.weights)


 class InputLayer(Layer):
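
Note: the change above is a pure refactor; `Layer.count_params()` keeps its behavior and simply delegates to the new shared helper. A minimal sketch of the call site, assuming the Keras 2.x API (the layer and shapes here are illustrative, not from the diff):

# Minimal sketch (Keras 2.x assumed): count_params() still reports the
# total number of scalars in a built layer's weights.
from keras.layers import Dense

layer = Dense(10)
layer.build((None, 4))       # kernel: 4 * 10 = 40, bias: 10
print(layer.count_params())  # -> 50
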
keras/engine/training.py (21 changes: 21 additions & 0 deletions)
@@ -22,6 +22,7 @@
 from .. import losses
 from .. import metrics as metrics_module
 from ..utils.generic_utils import Progbar
+from ..utils.layer_utils import count_params
 from .. import callbacks as cbks
 from ..legacy import interfaces

@@ -945,9 +946,29 @@ def handle_metrics(metrics, weights=None):
         trainable_weights = self.trainable_weights
         self._collected_trainable_weights = trainable_weights

+    def _check_trainable_weights_consistency(self):
+        """Check trainable weights count consistency.
+
+        This will raise a warning if `trainable_weights` and
+        `_collected_trainable_weights` are inconsistent (i.e. do not have
+        the same number of parameters).
+        Inconsistency will typically arise when one modifies `model.trainable`
+        without calling `model.compile` again.
+        """
+        if not hasattr(self, '_collected_trainable_weights'):
+            return
+
+        if (count_params(self.trainable_weights) !=
+                count_params(self._collected_trainable_weights)):
+            warnings.warn(UserWarning(
+                'Discrepancy between trainable weights and collected trainable'
+                ' weights, did you set `model.trainable` without calling'
+                ' `model.compile` after?'))
+
     def _make_train_function(self):
         if not hasattr(self, 'train_function'):
             raise RuntimeError('You must compile your model before using it.')
+        self._check_trainable_weights_consistency()
         if self.train_function is None:
             inputs = self._feed_inputs + self._feed_targets + self._feed_sample_weights
             if self.uses_learning_phase and not isinstance(K.learning_phase(), int):
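
Note: a short sketch of the scenario this check flags (the model, shapes, and data are illustrative, not from the diff). Toggling `trainable` after `compile()` leaves `_collected_trainable_weights` stale, so the next training call warns:

# Illustrative sketch (Keras 2.x assumed): flipping `trainable` after
# compile() makes the current and collected trainable counts disagree.
import numpy as np
from keras.models import Model
from keras.layers import Input, Dense

x = Input(shape=(3,))
model = Model(inputs=x, outputs=Dense(1)(x))
model.compile(optimizer='adam', loss='mse')  # collects trainable weights

model.trainable = False                      # modified without recompiling
model.fit(np.zeros((2, 3)), np.zeros((2, 1)), epochs=1, verbose=0)
# -> UserWarning: Discrepancy between trainable weights and collected
#    trainable weights, ...

Recompiling with `model.compile(...)` refreshes the collected weights and silences the warning.
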
keras/utils/layer_utils.py (20 changes: 18 additions & 2 deletions)
@@ -5,6 +5,18 @@
 import numpy as np


+def count_params(weights):
+    """Count the total number of scalars composing the weights.
+
+    # Arguments
+        weights: An iterable containing the weights on which to compute params.
+
+    # Returns
+        The total number of scalars composing the weights.
+    """
+    return int(np.sum([K.count_params(p) for p in set(weights)]))
+
+
 def print_summary(model, line_length=None, positions=None, print_fn=print):
     """Prints a summary of a model.

@@ -134,8 +146,12 @@ def print_layer_summary_with_connections(layer):
     else:
         print_fn('_' * line_length)

-    trainable_count = int(
-        np.sum([K.count_params(p) for p in set(model.trainable_weights)]))
+    model._check_trainable_weights_consistency()
+    if hasattr(model, '_collected_trainable_weights'):
+        trainable_count = count_params(model._collected_trainable_weights)
+    else:
+        trainable_count = count_params(model.trainable_weights)
+
     non_trainable_count = int(
         np.sum([K.count_params(p) for p in set(model.non_trainable_weights)]))

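
Note: the new helper deduplicates with `set(weights)`, so a weight tensor that appears more than once in the list is counted a single time. A small sketch of that behavior, assuming standard `keras.backend` variable creation:

# Sketch of the dedup behavior: duplicate entries collapse via set(),
# so a shared weight is not double-counted.
from keras import backend as K
from keras.utils.layer_utils import count_params

w = K.variable([[1., 2.], [3., 4.]])  # one variable holding 4 scalars
assert count_params([w, w]) == 4      # counted once, not twice
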
tests/keras/engine/test_training.py (41 changes: 41 additions & 0 deletions)
@@ -923,5 +923,46 @@ def test_model_custom_target_tensors():
                  [output_a_np, output_b_np])


+@pytest.mark.skipif(sys.version_info < (3,), reason='Cannot catch warnings in python 2')
+@keras_test
+def test_trainable_weights_count_consistency():
+    """Tests the trainable weights consistency check of Model.
+
+    This verifies that a warning is shown if `model.trainable` is modified
+    and the model is summarized/run without a new call to `.compile()`.
+
+    Reproduces issue #8121.
+    """
+    a = Input(shape=(3,), name='input_a')
+    model1 = Model(inputs=a, outputs=Dense(1)(a))
+
+    model1.trainable = False
+    b = Input(shape=(3,), name='input_b')
+    y = model1(b)
+    model2 = Model(inputs=b, outputs=Dense(1)(y))
+
+    model2.compile(optimizer='adam', loss='mse')
+
+    model1.trainable = True
+
+    # Should warn on .summary()
+    with pytest.warns(UserWarning) as w:
+        model2.summary()
+    warning_raised = any(['Discrepancy' in str(w_.message) for w_ in w])
+    assert warning_raised, 'No warning raised when trainable is modified without .compile.'
+
+    # And on .fit()
+    with pytest.warns(UserWarning) as w:
+        model2.fit(x=np.zeros((5, 3)), y=np.zeros((5, 1)))
+    warning_raised = any(['Discrepancy' in str(w_.message) for w_ in w])
+    assert warning_raised, 'No warning raised when trainable is modified without .compile.'
+
+    # And shouldn't warn if we recompile
+    model2.compile(optimizer='adam', loss='mse')
+    with pytest.warns(None) as w:
+        model2.summary()
+    assert len(w) == 0, 'Warning raised even when .compile() is called after modifying .trainable'
+
+
 if __name__ == '__main__':
     pytest.main([__file__])