Add a model.check_trainable_weights_consistency (keras-team#8234)
* Add a model.check_trainable_weights_consistency

This will raise a UserWarning when the user modifies model.trainable
and then prints a model summary or calls fit without having
called .compile() again.

Calling .compile() is necessary because trainable weights are collected
in compile (model._collected_trainable_weights).

* Fix comments and cosmetics on count_params

* Make Model.check_trainable_weights_consistency private

Also fix its docstring

* Fix docstring of test_trainable_weights_count_consistency
julienr authored and fchollet committed Oct 25, 2017
1 parent e48bc45 commit cab77c8
Showing 4 changed files with 82 additions and 3 deletions.
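
For context, the scenario this commit guards against can be reproduced as follows (a minimal sketch distilled from the new test at the bottom of this diff; the Keras 2.x functional API is assumed):

import numpy as np
from keras.layers import Input, Dense
from keras.models import Model

a = Input(shape=(3,))
inner = Model(inputs=a, outputs=Dense(1)(a))

inner.trainable = False                      # freeze the inner model
b = Input(shape=(3,))
outer = Model(inputs=b, outputs=Dense(1)(inner(b)))
outer.compile(optimizer='adam', loss='mse')  # trainable weights are collected here

inner.trainable = True                       # unfreeze without recompiling

# Both of these now emit the new UserWarning, because
# outer._collected_trainable_weights is stale:
outer.summary()
outer.fit(np.zeros((5, 3)), np.zeros((5, 1)), verbose=0)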
3 changes: 2 additions & 1 deletion keras/engine/topology.py
@@ -16,6 +16,7 @@
 from .. import initializers
 from ..utils.io_utils import ask_to_proceed_with_overwrite
 from ..utils.layer_utils import print_summary as print_layer_summary
+from ..utils.layer_utils import count_params
 from ..utils.generic_utils import has_arg
 from ..utils import conv_utils
 from ..legacy import interfaces

@@ -1269,7 +1270,7 @@ def count_params(self):
                                    self.name + ', but the layer isn\'t built. '
                                    'You can build it manually via: `' +
                                    self.name + '.build(batch_input_shape)`.')
-        return sum([K.count_params(p) for p in self.weights])
+        return count_params(self.weights)


 class InputLayer(Layer):
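
With this refactor, Layer.count_params delegates to the shared count_params helper instead of summing inline, so layers and the summary utility count parameters identically. For instance (a sketch; the layer shapes are illustrative):

from keras.layers import Input, Dense
from keras.models import Model

x = Input(shape=(3,))
layer = Dense(4)
Model(inputs=x, outputs=layer(x))  # calling the layer builds its weights

# A Dense(4) on 3 inputs holds a 3x4 kernel plus a bias of length 4:
assert layer.count_params() == 3 * 4 + 4  # 16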
21 changes: 21 additions & 0 deletions keras/engine/training.py
@@ -22,6 +22,7 @@
 from .. import losses
 from .. import metrics as metrics_module
 from ..utils.generic_utils import Progbar
+from ..utils.layer_utils import count_params
 from .. import callbacks as cbks
 from ..legacy import interfaces

@@ -954,9 +955,29 @@ def handle_metrics(metrics, weights=None):
         trainable_weights = self.trainable_weights
         self._collected_trainable_weights = trainable_weights

+    def _check_trainable_weights_consistency(self):
+        """Check trainable weights count consistency.
+        This will raise a warning if `trainable_weights` and
+        `_collected_trainable_weights` are inconsistent (i.e. have a different
+        number of parameters).
+        Inconsistency will typically arise when one modifies `model.trainable`
+        without calling `model.compile` again.
+        """
+        if not hasattr(self, '_collected_trainable_weights'):
+            return
+
+        if (count_params(self.trainable_weights) !=
+                count_params(self._collected_trainable_weights)):
+            warnings.warn(UserWarning(
+                'Discrepancy between trainable weights and collected trainable'
+                ' weights; did you set `model.trainable` without calling'
+                ' `model.compile` afterwards?'))
+
     def _make_train_function(self):
         if not hasattr(self, 'train_function'):
             raise RuntimeError('You must compile your model before using it.')
+        self._check_trainable_weights_consistency()
         if self.train_function is None:
             inputs = self._feed_inputs + self._feed_targets + self._feed_sample_weights
             if self.uses_learning_phase and not isinstance(K.learning_phase(), int):
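
Note that the check compares total parameter counts rather than the weight lists themselves, so it is a heuristic: swapping one set of weights for another of equal size would pass unnoticed. The supported workflow is to recompile after toggling trainable; continuing the sketch from above:

inner.trainable = True
outer.compile(optimizer='adam', loss='mse')  # re-collects trainable weights

# No UserWarning this time: _collected_trainable_weights is up to date.
outer.summary()
outer.fit(np.zeros((5, 3)), np.zeros((5, 1)), verbose=0)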
20 changes: 18 additions & 2 deletions keras/utils/layer_utils.py
@@ -5,6 +5,18 @@
 import numpy as np


+def count_params(weights):
+    """Count the total number of scalars composing the weights.
+
+    # Arguments
+        weights: An iterable containing the weights
+            on which to compute params.
+
+    # Returns
+        The total number of scalars composing the weights.
+    """
+    return int(np.sum([K.count_params(p) for p in set(weights)]))
+
+
 def print_summary(model, line_length=None, positions=None, print_fn=print):
     """Prints a summary of a model.

@@ -134,8 +146,12 @@ def print_layer_summary_with_connections(layer):
     else:
         print_fn('_' * line_length)

-    trainable_count = int(
-        np.sum([K.count_params(p) for p in set(model.trainable_weights)]))
+    model._check_trainable_weights_consistency()
+    if hasattr(model, '_collected_trainable_weights'):
+        trainable_count = count_params(model._collected_trainable_weights)
+    else:
+        trainable_count = count_params(model.trainable_weights)

     non_trainable_count = int(
         np.sum([K.count_params(p) for p in set(model.non_trainable_weights)]))
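
One subtlety of the new helper: it sums over set(weights), so a variable that appears more than once (e.g. a weight shared between layers) is counted only once. A small illustration (a sketch, assuming a backend whose variables are hashable, as with TensorFlow):

import numpy as np
from keras import backend as K
from keras.utils.layer_utils import count_params

w = K.variable(np.zeros((3, 4)))  # 12 scalars
b = K.variable(np.zeros((4,)))    # 4 scalars

# The duplicate reference to `w` collapses in set(), giving 16, not 28.
assert count_params([w, b, w]) == 16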
41 changes: 41 additions & 0 deletions tests/keras/engine/test_training.py
@@ -937,5 +937,46 @@ def test_model_custom_target_tensors():
                    [output_a_np, output_b_np])


+@pytest.mark.skipif(sys.version_info < (3,),
+                    reason='Cannot catch warnings in python 2')
+@keras_test
+def test_trainable_weights_count_consistency():
+    """Tests the trainable weights consistency check of Model.
+
+    This verifies that a warning is shown if model.trainable is modified
+    and the model is summarized/run without a new call to .compile().
+
+    Reproduces issue #8121.
+    """
+    a = Input(shape=(3,), name='input_a')
+    model1 = Model(inputs=a, outputs=Dense(1)(a))
+
+    model1.trainable = False
+    b = Input(shape=(3,), name='input_b')
+    y = model1(b)
+    model2 = Model(inputs=b, outputs=Dense(1)(y))
+
+    model2.compile(optimizer='adam', loss='mse')
+
+    model1.trainable = True
+
+    # Should warn on .summary()
+    with pytest.warns(UserWarning) as w:
+        model2.summary()
+    warning_raised = any(['Discrepancy' in str(w_.message) for w_ in w])
+    assert warning_raised, 'No warning raised when trainable is modified without .compile.'
+
+    # And on .fit()
+    with pytest.warns(UserWarning) as w:
+        model2.fit(x=np.zeros((5, 3)), y=np.zeros((5, 1)))
+    warning_raised = any(['Discrepancy' in str(w_.message) for w_ in w])
+    assert warning_raised, 'No warning raised when trainable is modified without .compile.'
+
+    # And shouldn't warn if we recompile
+    model2.compile(optimizer='adam', loss='mse')
+    with pytest.warns(None) as w:
+        model2.summary()
+    assert len(w) == 0, ('Warning raised even when .compile() is called '
+                         'after modifying .trainable')


 if __name__ == '__main__':
     pytest.main([__file__])
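
The new test can be run in isolation with pytest's keyword filter (a standard pytest invocation; adjust the path to your checkout):

pytest tests/keras/engine/test_training.py -k trainable_weights_count_consistency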
