Skip loading functional model if config cannot be found, and instead fall back to creating a subclassed model.

Functional models that use subclassed layers/models that do not override get_config() will not be saved with their config objects. A follow-up CL will add an additional attribute to the metadata that contains the network topology without the config objects.

PiperOrigin-RevId: 269704731
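For context, a minimal sketch (not part of this change) of the scenario the fallback targets: a functional model built around a subclassed layer that does not override get_config(). The layer name, save path, and use of the public model.save / tf.keras.models.load_model APIs are illustrative assumptions, not code from this CL.

# A minimal sketch, assuming TF 2.x public Keras APIs; layer name and save
# path are hypothetical, and this example is not part of the commit itself.
import tensorflow as tf


class Scale(tf.keras.layers.Layer):
  """Subclassed layer that does NOT override get_config()."""

  def build(self, input_shape):
    self.w = self.add_weight('w', shape=[], initializer='ones')

  def call(self, inputs):
    return inputs * self.w


# Functional (graph network) model that uses the subclassed layer.
inputs = tf.keras.Input(shape=(4,))
model = tf.keras.Model(inputs, Scale()(inputs))

# Save in the Keras SavedModel format handled by saved_model/load.py.
model.save('/tmp/scale_model', save_format='tf')

# Scale has no serializable config, so the functional model's config cannot
# be fully reconstructed; with this change, loading falls back to reviving
# the model as a subclassed model instead of failing.
restored = tf.keras.models.load_model('/tmp/scale_model')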
k-w-w authored and tensorflower-gardener committed Sep 18, 2019
1 parent 8576b7b commit e502d88
Showing 1 changed file with 11 additions and 4 deletions.
15 changes: 11 additions & 4 deletions tensorflow/python/keras/saving/saved_model/load.py
@@ -102,6 +102,15 @@ def load(path, compile=True):  # pylint: disable=redefined-builtin
   return model
 
 
+def _is_graph_network(node):
+  # pylint: disable=protected-access
+  return (
+      isinstance(node, RevivedNetwork) and
+      node._serialized_attributes['metadata'].get('is_graph_network', False) and
+      hasattr(node, '_config'))
+  # pylint: enable=protected-access
+
+
 class KerasObjectLoader(tf_load.Loader):
   """Loader that recreates Keras objects."""
 
@@ -117,8 +126,7 @@ def _finalize(self):
     for node in self._nodes:
       if isinstance(node, RevivedLayer):
         node.built = True
-        is_graph_network = node._serialized_attributes['metadata'].get(
-            'is_graph_network', False)
+        is_graph_network = _is_graph_network(node)
         if not (isinstance(node, models_lib.Sequential) or is_graph_network):
           if hasattr(node.keras_api, 'call_and_return_conditional_losses'):
             node.call = utils.use_wrapped_call(
@@ -135,8 +143,7 @@ def _finalize(self):
           inputs = call_fn.input_signature[0]
 
         # Set model inputs and outputs.
-        is_graph_network = node._serialized_attributes['metadata'].get(
-            'is_graph_network', False)
+        is_graph_network = _is_graph_network(node)
         if isinstance(node, models_lib.Sequential):
           with trackable.no_automatic_dependency_tracking_scope(node):
             node._layers = []
