Skip to content

Commit 97c3fef

Browse files
authored
Merge pull request tensorflow#45390 from k-w-w/r2.4
Revert recent SavedModel changes
2 parents ab0402b + 45967fe commit 97c3fef

File tree

9 files changed

+30
-158
lines changed

9 files changed

+30
-158
lines changed

tensorflow/core/protobuf/saved_object_graph.proto

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -76,9 +76,6 @@ message SavedUserObject {
7676
string identifier = 1;
7777
// Version information from the producer of this SavedUserObject.
7878
VersionDef version = 2;
79-
// Deprecated! At the time of deprecation, Keras was the only user of this
80-
// field, and its saving and loading code will be updated shortly.
81-
// Please save your application-specific metadata to separate file
8279
// Initialization-related metadata.
8380
string metadata = 3;
8481
}

tensorflow/python/keras/saving/BUILD

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,6 @@ py_library(
4949
deps = [
5050
"//tensorflow/python:lib",
5151
"//tensorflow/python:math_ops",
52-
"//tensorflow/python:platform",
5352
"//tensorflow/python:saver",
5453
"//tensorflow/python:tensor_spec",
5554
"//tensorflow/python/eager:def_function",

tensorflow/python/keras/saving/saved_model/constants.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,3 @@
2626
# Keys for the serialization cache.
2727
# Maps to the keras serialization dict {Layer --> SerializedAttributes object}
2828
KERAS_CACHE_KEY = 'keras_serialized_attributes'
29-
30-
31-
# Name of Keras metadata file stored in the SavedModel.
32-
SAVED_METADATA_PATH = 'keras_metadata.pb'

tensorflow/python/keras/saving/saved_model/load.py

Lines changed: 4 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,9 @@
1717
from __future__ import division
1818
from __future__ import print_function
1919

20-
import os
2120
import re
2221
import types
2322

24-
from google.protobuf import message
25-
2623
from tensorflow.core.framework import versions_pb2
2724
from tensorflow.python.eager import context
2825
from tensorflow.python.eager import function as defun
@@ -41,7 +38,6 @@
4138
from tensorflow.python.keras.utils import generic_utils
4239
from tensorflow.python.keras.utils import metrics_utils
4340
from tensorflow.python.keras.utils.generic_utils import LazyLoader
44-
from tensorflow.python.platform import gfile
4541
from tensorflow.python.platform import tf_logging as logging
4642
from tensorflow.python.saved_model import load as tf_load
4743
from tensorflow.python.saved_model import loader_impl
@@ -125,26 +121,13 @@ def load(path, compile=True, options=None): # pylint: disable=redefined-builtin
125121
# TODO(kathywu): Add saving/loading of optimizer, compiled losses and metrics.
126122
# TODO(kathywu): Add code to load from objects that contain all endpoints
127123

128-
# Look for metadata file or parse the SavedModel
124+
# The Keras metadata file is not yet saved, so create it from the SavedModel.
129125
metadata = saved_metadata_pb2.SavedMetadata()
130126
meta_graph_def = loader_impl.parse_saved_model(path).meta_graphs[0]
131127
object_graph_def = meta_graph_def.object_graph_def
132-
path_to_metadata_pb = os.path.join(path, constants.SAVED_METADATA_PATH)
133-
if gfile.Exists(path_to_metadata_pb):
134-
try:
135-
with gfile.GFile(path_to_metadata_pb, 'rb') as f:
136-
file_content = f.read()
137-
metadata.ParseFromString(file_content)
138-
except message.DecodeError as e:
139-
raise IOError('Cannot parse keras metadata {}: {}.'
140-
.format(path_to_metadata_pb, str(e)))
141-
else:
142-
logging.warning('SavedModel saved prior to TF 2.4 detected when loading '
143-
'Keras model. Please ensure that you are saving the model '
144-
'with model.save() or tf.keras.models.save_model(), *NOT* '
145-
'tf.saved_model.save(). To confirm, there should be a file '
146-
'named "keras_metadata.pb" in the SavedModel directory.')
147-
_read_legacy_metadata(object_graph_def, metadata)
128+
# TODO(kathywu): When the keras metadata file is saved, load it directly
129+
# instead of calling the _read_legacy_metadata function.
130+
_read_legacy_metadata(object_graph_def, metadata)
148131

149132
if not metadata.nodes:
150133
# When there are no Keras objects, return the results from the core loader

tensorflow/python/keras/saving/saved_model/save.py

Lines changed: 1 addition & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -18,21 +18,15 @@
1818
from __future__ import print_function
1919

2020
import os
21-
22-
from tensorflow.core.framework import versions_pb2
2321
from tensorflow.python.distribute import distribution_strategy_context
2422
from tensorflow.python.keras import backend as K
25-
from tensorflow.python.keras.protobuf import saved_metadata_pb2
2623
from tensorflow.python.keras.saving import saving_utils
27-
from tensorflow.python.keras.saving.saved_model import constants
2824
from tensorflow.python.keras.saving.saved_model import save_impl
2925
from tensorflow.python.keras.saving.saved_model import utils
3026
from tensorflow.python.keras.utils.generic_utils import LazyLoader
3127
from tensorflow.python.keras.utils.io_utils import ask_to_proceed_with_overwrite
32-
from tensorflow.python.platform import gfile
3328
from tensorflow.python.saved_model import save as save_lib
3429

35-
3630
# To avoid circular dependencies between keras/engine and keras/saving,
3731
# code in keras/saving must delay imports.
3832

@@ -92,39 +86,7 @@ def save(model, filepath, overwrite, include_optimizer, signatures=None,
9286
# we use the default replica context here.
9387
with distribution_strategy_context._get_default_replica_context(): # pylint: disable=protected-access
9488
with utils.keras_option_scope(save_traces):
95-
saved_nodes, node_paths = save_lib.save_and_return_nodes(
96-
model, filepath, signatures, options)
97-
98-
# Save all metadata to a separate file in the SavedModel directory.
99-
metadata = generate_keras_metadata(saved_nodes, node_paths)
100-
101-
with gfile.GFile(
102-
os.path.join(filepath, constants.SAVED_METADATA_PATH), "wb") as w:
103-
w.write(metadata.SerializeToString(deterministic=True))
89+
save_lib.save(model, filepath, signatures, options)
10490

10591
if not include_optimizer:
10692
model.optimizer = orig_optimizer
107-
108-
109-
def generate_keras_metadata(saved_nodes, node_paths):
110-
"""Constructs a KerasMetadata proto with the metadata of each keras object."""
111-
metadata = saved_metadata_pb2.SavedMetadata()
112-
113-
for node_id, node in enumerate(saved_nodes):
114-
if isinstance(node, base_layer.Layer):
115-
path = node_paths[node]
116-
if not path:
117-
node_path = "root"
118-
else:
119-
node_path = "root.{}".format(
120-
".".join([ref.name for ref in path]))
121-
122-
metadata.nodes.add(
123-
node_id=node_id,
124-
node_path=node_path,
125-
version=versions_pb2.VersionDef(
126-
producer=1, min_consumer=1, bad_consumers=[]),
127-
identifier=node._object_identifier, # pylint: disable=protected-access
128-
metadata=node._tracking_metadata) # pylint: disable=protected-access
129-
130-
return metadata

tensorflow/python/keras/saving/saved_model/saved_model_test.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626

2727
import os
2828
import shutil
29+
import sys
2930

3031
from absl.testing import parameterized
3132
import numpy as np
@@ -410,16 +411,14 @@ def testBatchNormUpdates(self):
410411
self.evaluate(variables.variables_initializer(model.variables))
411412
saved_model_dir = self._save_model_dir()
412413

413-
# TODO(kathywu): Re-enable this check after removing the tf.saved_model.save
414-
# metadata warning.
415-
# with self.captureWritesToStream(sys.stderr) as captured_logs:
416-
model.save(saved_model_dir, save_format='tf')
417-
loaded = keras_load.load(saved_model_dir)
414+
with self.captureWritesToStream(sys.stderr) as captured_logs:
415+
model.save(saved_model_dir, save_format='tf')
416+
loaded = keras_load.load(saved_model_dir)
418417

419418
# Assert that saving does not log deprecation warnings
420419
# (even if it needs to set learning phase for compat reasons)
421-
# if context.executing_eagerly():
422-
# self.assertNotIn('deprecated', captured_logs.contents())
420+
if context.executing_eagerly():
421+
self.assertNotIn('deprecated', captured_logs.contents())
423422

424423
input_arr = array_ops.constant([[11], [12], [13]], dtype=dtypes.float32)
425424
input_arr2 = array_ops.constant([[14], [15], [16]], dtype=dtypes.float32)

tensorflow/python/saved_model/BUILD

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -349,7 +349,6 @@ py_strict_library(
349349
"//tensorflow/python:framework",
350350
"//tensorflow/python:framework_ops",
351351
"//tensorflow/python:lib",
352-
"//tensorflow/python:platform",
353352
"//tensorflow/python:resource_variable_ops",
354353
"//tensorflow/python:tensor_util",
355354
"//tensorflow/python:tf_export",

tensorflow/python/saved_model/save.py

Lines changed: 17 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,6 @@
4343
from tensorflow.python.ops import array_ops
4444
from tensorflow.python.ops import control_flow_ops
4545
from tensorflow.python.ops import resource_variable_ops
46-
from tensorflow.python.platform import tf_logging
4746
from tensorflow.python.saved_model import builder_impl
4847
from tensorflow.python.saved_model import constants
4948
from tensorflow.python.saved_model import function_serialization
@@ -183,9 +182,8 @@ def __init__(self, checkpoint_view, options, wrapped_functions=None):
183182
"""
184183
self.options = options
185184
self.checkpoint_view = checkpoint_view
186-
trackable_objects, path_to_root, node_ids, slot_variables = (
187-
self.checkpoint_view.objects_ids_and_slot_variables_and_paths())
188-
self.node_paths = path_to_root
185+
trackable_objects, node_ids, slot_variables = (
186+
self.checkpoint_view.objects_ids_and_slot_variables())
189187
self.nodes = trackable_objects
190188
self.node_ids = node_ids
191189
self.captured_tensor_node_ids = object_identity.ObjectIdentityDictionary()
@@ -752,17 +750,14 @@ def _serialize_object_graph(saveable_view, asset_file_def_index):
752750
if serialized is not None:
753751
proto.concrete_functions[name].CopyFrom(serialized)
754752

755-
saved_object_metadata = False
756753
for obj, obj_proto in zip(saveable_view.nodes, proto.nodes):
757-
has_saved_object_metadata = _write_object_proto(
758-
obj, obj_proto, asset_file_def_index, saveable_view.function_name_map)
759-
saved_object_metadata = saved_object_metadata or has_saved_object_metadata
760-
return proto, saved_object_metadata
754+
_write_object_proto(obj, obj_proto, asset_file_def_index,
755+
saveable_view.function_name_map)
756+
return proto
761757

762758

763759
def _write_object_proto(obj, proto, asset_file_def_index, function_name_map):
764760
"""Saves an object into SavedObject proto."""
765-
has_saved_object_metadata = False # The metadata field will be deprecated.
766761
if isinstance(obj, tracking.Asset):
767762
proto.asset.SetInParent()
768763
proto.asset.asset_file_def_index = asset_file_def_index[obj]
@@ -798,14 +793,11 @@ def _write_object_proto(obj, proto, asset_file_def_index, function_name_map):
798793
if registered_type_proto is None:
799794
# Fallback for types with no matching registration
800795
# pylint:disable=protected-access
801-
metadata = obj._tracking_metadata
802-
if metadata:
803-
has_saved_object_metadata = True
804796
registered_type_proto = saved_object_graph_pb2.SavedUserObject(
805797
identifier=obj._object_identifier,
806798
version=versions_pb2.VersionDef(
807799
producer=1, min_consumer=1, bad_consumers=[]),
808-
metadata=metadata)
800+
metadata=obj._tracking_metadata)
809801
# pylint:enable=protected-access
810802
proto.user_object.CopyFrom(registered_type_proto)
811803

@@ -818,7 +810,6 @@ def _write_object_proto(obj, proto, asset_file_def_index, function_name_map):
818810
# documentation.
819811
if hasattr(obj, "_write_object_proto"):
820812
obj._write_object_proto(proto, options) # pylint: disable=protected-access
821-
return has_saved_object_metadata
822813

823814

824815
def _export_debug_info(exported_graph, export_dir):
@@ -1016,7 +1007,8 @@ def serve():
10161007
instances with input signatures or concrete functions. Keys of such a
10171008
dictionary may be arbitrary strings, but will typically be from the
10181009
`tf.saved_model.signature_constants` module.
1019-
options: `tf.saved_model.SaveOptions` object for configuring save options.
1010+
options: Optional, `tf.saved_model.SaveOptions` object that specifies
1011+
options for saving.
10201012
10211013
Raises:
10221014
ValueError: If `obj` is not trackable.
@@ -1030,40 +1022,15 @@ def serve():
10301022
May not be called from within a function body.
10311023
@end_compatibility
10321024
"""
1033-
save_and_return_nodes(obj, export_dir, signatures, options,
1034-
raise_metadata_warning=True)
1035-
1036-
1037-
def save_and_return_nodes(obj, export_dir, signatures=None, options=None,
1038-
raise_metadata_warning=False):
1039-
"""Saves a SavedModel while returning all saved nodes and their paths.
1040-
1041-
Please see `tf.saved_model.save` for details.
1042-
1043-
Args:
1044-
obj: A trackable object to export.
1045-
export_dir: A directory in which to write the SavedModel.
1046-
signatures: A function or dictionary of functions to save in the SavedModel
1047-
as signatures.
1048-
options: `tf.saved_model.SaveOptions` object for configuring save options.
1049-
raise_metadata_warning: Whether to raise the metadata warning. This arg will
1050-
be removed in TF 2.5.
1051-
1052-
Returns:
1053-
A tuple of (a list of saved nodes in the order they are serialized to the
1054-
`SavedObjectGraph`, dictionary mapping nodes to one possible path from
1055-
the root node to the key node)
1056-
"""
10571025
options = options or save_options.SaveOptions()
10581026
# TODO(allenl): Factor out some subset of SavedModelBuilder which is 2.x
10591027
# compatible (no sessions) and share it with this export API rather than
10601028
# making a SavedModel proto and writing it directly.
10611029
saved_model = saved_model_pb2.SavedModel()
10621030
meta_graph_def = saved_model.meta_graphs.add()
10631031

1064-
_, exported_graph, object_saver, asset_info, saved_nodes, node_paths = (
1065-
_build_meta_graph(obj, signatures, options, meta_graph_def,
1066-
raise_metadata_warning))
1032+
_, exported_graph, object_saver, asset_info = _build_meta_graph(
1033+
obj, signatures, options, meta_graph_def)
10671034
saved_model.saved_model_schema_version = constants.SAVED_MODEL_SCHEMA_VERSION
10681035

10691036
# Write the checkpoint, copy assets into the assets directory, and write out
@@ -1103,8 +1070,6 @@ def save_and_return_nodes(obj, export_dir, signatures=None, options=None,
11031070
# constants in the saved graph.
11041071
ops.dismantle_graph(exported_graph)
11051072

1106-
return saved_nodes, node_paths
1107-
11081073

11091074
def export_meta_graph(obj, filename, signatures=None, options=None):
11101075
"""Exports the MetaGraph proto of the `obj` to a file.
@@ -1131,7 +1096,7 @@ def export_meta_graph(obj, filename, signatures=None, options=None):
11311096
"""
11321097
options = options or save_options.SaveOptions()
11331098
export_dir = os.path.dirname(filename)
1134-
meta_graph_def, exported_graph, _, _, _, _ = _build_meta_graph(
1099+
meta_graph_def, exported_graph, _, _ = _build_meta_graph(
11351100
obj, signatures, options)
11361101

11371102
file_io.atomic_write_string_to_file(
@@ -1150,8 +1115,7 @@ def export_meta_graph(obj, filename, signatures=None, options=None):
11501115
def _build_meta_graph_impl(obj,
11511116
signatures,
11521117
options,
1153-
meta_graph_def=None,
1154-
raise_metadata_warning=True):
1118+
meta_graph_def=None):
11551119
"""Creates a MetaGraph containing the resources and functions of an object."""
11561120
if ops.inside_function():
11571121
raise AssertionError(
@@ -1195,35 +1159,17 @@ def _build_meta_graph_impl(obj,
11951159
for fdef in func._stateless_fn._function_cache.all_values(): # pylint: disable=protected-access
11961160
function_aliases[fdef.name] = alias
11971161

1198-
object_graph_proto, saved_object_metadata = _serialize_object_graph(
1199-
saveable_view, asset_info.asset_index)
1162+
object_graph_proto = _serialize_object_graph(saveable_view,
1163+
asset_info.asset_index)
12001164
meta_graph_def.object_graph_def.CopyFrom(object_graph_proto)
12011165

1202-
if saved_object_metadata and raise_metadata_warning:
1203-
tf_logging.warn(
1204-
'FOR KERAS USERS: The object that you are saving contains one or more '
1205-
'Keras models or layers. If you are loading the SavedModel with '
1206-
'`tf.keras.models.load_model`, continue reading (otherwise, you may '
1207-
'ignore the following instructions). Please change your code to save '
1208-
'with `tf.keras.models.save_model` or `model.save`, and confirm that '
1209-
'the file "keras.metadata" exists in the export directory. In the '
1210-
'future, Keras will only load the SavedModels that have this file. In '
1211-
'other words, `tf.saved_model.save` will no longer write SavedModels '
1212-
'that can be recovered as Keras models (this will apply in TF 2.5).'
1213-
'\n\nFOR DEVS: If you are overwriting _tracking_metadata in your class,'
1214-
' this property has been used to save metadata in the SavedModel. The '
1215-
'metadta field will be deprecated soon, so please move the metadata to '
1216-
'a different file.')
1217-
1218-
return (meta_graph_def, exported_graph, object_saver, asset_info,
1219-
saveable_view.nodes, saveable_view.node_paths)
1166+
return meta_graph_def, exported_graph, object_saver, asset_info
12201167

12211168

12221169
def _build_meta_graph(obj,
12231170
signatures,
12241171
options,
1225-
meta_graph_def=None,
1226-
raise_metadata_warning=True):
1172+
meta_graph_def=None):
12271173
"""Creates a MetaGraph under a save context.
12281174
12291175
Args:
@@ -1236,8 +1182,6 @@ def _build_meta_graph(obj,
12361182
options: `tf.saved_model.SaveOptions` object that specifies options for
12371183
saving.
12381184
meta_graph_def: Optional, the MetaGraphDef proto fill.
1239-
raise_metadata_warning: Whether to raise a warning when user objects contain
1240-
non-empty metadata.
12411185
12421186
Raises:
12431187
AssertionError: If `export_meta_graph` is executing inside a `tf.function`.
@@ -1251,5 +1195,4 @@ def _build_meta_graph(obj,
12511195
"""
12521196

12531197
with save_context.save_context(options):
1254-
return _build_meta_graph_impl(obj, signatures, options, meta_graph_def,
1255-
raise_metadata_warning)
1198+
return _build_meta_graph_impl(obj, signatures, options, meta_graph_def)

0 commit comments

Comments
 (0)