Skip to content

Commit

Permalink
Add feature to supply SavedModel with a main_op.
Browse files Browse the repository at this point in the history
Change: 139254613
  • Loading branch information
sukritiramesh authored and tensorflower-gardener committed Nov 15, 2016
1 parent d2693c8 commit 97e0b7b
Show file tree
Hide file tree
Showing 6 changed files with 176 additions and 17 deletions.
12 changes: 12 additions & 0 deletions tensorflow/python/saved_model/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,17 @@ py_library(
],
)

py_library(
name = "main_op",
srcs = ["main_op.py"],
srcs_version = "PY2AND3",
deps = [
":constants",
"//tensorflow/python:data_flow_ops",
"//tensorflow/python:platform",
],
)

py_test(
name = "saved_model_test",
size = "small",
Expand All @@ -66,6 +77,7 @@ py_test(
deps = [
":builder",
":loader",
":main_op",
":tag_constants",
":utils",
"//tensorflow:tensorflow_py",
Expand Down
44 changes: 34 additions & 10 deletions tensorflow/python/saved_model/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,20 @@ def _maybe_add_legacy_init_op(self, legacy_init_op=None):
legacy_init_op)
ops.add_to_collection(constants.LEGACY_INIT_OP_KEY, legacy_init_op)

def _add_main_op(self, main_op):
"""Add main op to the SavedModel.
Args:
main_op: Main op to run as part of graph initialization.
Raises:
TypeError if main op is not of type `Operation`.
"""
if main_op is not None:
if not isinstance(main_op, ops.Operation):
raise TypeError("main_op needs to be an Operation: %r" % main_op)
ops.add_to_collection(constants.MAIN_OP_KEY, main_op)

def _maybe_save_assets(self, assets_collection_to_add=None):
"""Saves assets to the meta graph.
Expand Down Expand Up @@ -253,7 +267,8 @@ def add_meta_graph(self,
signature_def_map=None,
assets_collection=None,
legacy_init_op=None,
clear_devices=False):
clear_devices=False,
main_op=None):
"""Adds the current meta graph to the SavedModel.
Creates a Saver in the current scope and uses the Saver to export the meta
Expand All @@ -267,10 +282,11 @@ def add_meta_graph(self,
assets_collection: Assets collection to be saved with SavedModel. Note
that this collection should be a subset of the assets saved as part of
the first meta graph in the SavedModel.
legacy_init_op: Op or group of ops to execute after the restore op upon a
load.
legacy_init_op: Legacy support for op or group of ops to execute after the
restore op upon a load.
clear_devices: Set to true if the device info on the default graph should
be cleared.
main_op: Op or group of ops to execute when the graph is loaded.
Raises:
AssertionError: If the variables for the SavedModel have not been saved
Expand All @@ -284,8 +300,11 @@ def add_meta_graph(self,
# Save asset files and write them to disk, if any.
self._save_and_write_assets(assets_collection)

# Add legacy init op to the SavedModel.
self._maybe_add_legacy_init_op(legacy_init_op)
if main_op is None:
# Add legacy init op to the SavedModel.
self._maybe_add_legacy_init_op(legacy_init_op)
else:
self._add_main_op(main_op)

# Initialize a saver to generate a sharded output for all variables in the
# current scope.
Expand All @@ -305,7 +324,8 @@ def add_meta_graph_and_variables(self,
signature_def_map=None,
assets_collection=None,
legacy_init_op=None,
clear_devices=False):
clear_devices=False,
main_op=None):
"""Adds the current meta graph to the SavedModel and saves variables.
Creates a Saver to save the variables from the provided session. Exports the
Expand All @@ -321,10 +341,11 @@ def add_meta_graph_and_variables(self,
signature_def_map: The map of signature def map to add to the meta graph
def.
assets_collection: Assets collection to be saved with SavedModel.
legacy_init_op: Op or group of ops to execute after the restore op upon a
load.
legacy_init_op: Legacy support for op or group of ops to execute after the
restore op upon a load.
clear_devices: Set to true if the device info on the default graph should
be cleared.
main_op: Op or group of ops to execute when the graph is loaded.
"""
if self._has_saved_variables:
raise AssertionError("Variables and assets have already been saved. "
Expand All @@ -344,8 +365,11 @@ def add_meta_graph_and_variables(self,
compat.as_text(variables_dir),
compat.as_text(constants.VARIABLES_FILENAME))

# Add legacy init op to the SavedModel.
self._maybe_add_legacy_init_op(legacy_init_op)
if main_op is None:
# Add legacy init op to the SavedModel.
self._maybe_add_legacy_init_op(legacy_init_op)
else:
self._add_main_op(main_op)

# Initialize a saver to generate a sharded output for all variables in the
# current scope.
Expand Down
1 change: 1 addition & 0 deletions tensorflow/python/saved_model/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
ASSETS_KEY = "saved_model_assets"

LEGACY_INIT_OP_KEY = "legacy_init_op"
MAIN_OP_KEY = "saved_model_main_op"

SAVED_MODEL_SCHEMA_VERSION = 1
SAVED_MODEL_FILENAME_PB = "saved_model.pb"
Expand Down
38 changes: 31 additions & 7 deletions tensorflow/python/saved_model/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,29 @@ def _get_asset_tensors(export_dir, meta_graph_def_to_load):
return asset_tensor_dict


def _get_main_op_tensor(meta_graph_def_to_load):
"""Gets the main op tensor, if one exists.
Args:
meta_graph_def_to_load: The meta graph def from the SavedModel to be loaded.
Returns:
The main op tensor, if it exists and `None` otherwise.
Raises:
RuntimeError: If the collection def corresponding to the main op key has
other than exactly one tensor.
"""
collection_def = meta_graph_def_to_load.collection_def
main_op_tensor = None
if constants.MAIN_OP_KEY in collection_def:
main_ops = collection_def[constants.MAIN_OP_KEY].node_list.value
if len(main_ops) != 1:
raise RuntimeError("Expected exactly one SavedModel main op.")
main_op_tensor = tf.get_collection(constants.MAIN_OP_KEY)[0]
return main_op_tensor


def _get_legacy_init_op_tensor(meta_graph_def_to_load):
"""Gets the legacy init op tensor, if one exists.
Expand Down Expand Up @@ -220,12 +243,13 @@ def load(sess, tags, export_dir):
asset_tensors_dictionary = _get_asset_tensors(export_dir,
meta_graph_def_to_load)

# TODO(sukritiramesh): Add support for a single main op to run upon load,
# which will supersede the legacy_init_op.
legacy_init_op_tensor = _get_legacy_init_op_tensor(meta_graph_def_to_load)

if legacy_init_op_tensor is not None:
sess.run(fetches=[legacy_init_op_tensor],
feed_dict=asset_tensors_dictionary)
main_op_tensor = _get_main_op_tensor(meta_graph_def_to_load)
if main_op_tensor is not None:
sess.run(fetches=[main_op_tensor], feed_dict=asset_tensors_dictionary)
else:
legacy_init_op_tensor = _get_legacy_init_op_tensor(meta_graph_def_to_load)
if legacy_init_op_tensor is not None:
sess.run(fetches=[legacy_init_op_tensor],
feed_dict=asset_tensors_dictionary)

return meta_graph_def_to_load
63 changes: 63 additions & 0 deletions tensorflow/python/saved_model/main_op.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""SavedModel main op.
Builds a main op that defines the sequence of ops to be run as part of the
SavedModel load/restore operations.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

from tensorflow.python.framework import ops
from tensorflow.python.ops import data_flow_ops as tf_data_flow_ops
from tensorflow.python.ops import variables as tf_variables


def main_op():
"""Returns a main op to init variables and tables.
Returns the main op including the group of ops that initializes all
variables, initializes local variables and initialize all tables.
Returns:
The set of ops to be run as part of the main op upon the load operation.
"""
init = tf_variables.initialize_all_variables()
init_local = tf_variables.initialize_local_variables()
init_tables = tf_data_flow_ops.initialize_all_tables()
return tf.group(init, init_local, init_tables)


def main_op_with_restore(restore_op_name):
"""Returns a main op to init variables, tables and restore the graph.
Returns the main op including the group of ops that initializes all
variables, initialize local variables, initialize all tables and the restore
op name.
Args:
restore_op_name: Name of the op to use to restore the graph.
Returns:
The set of ops to be run as part of the main op upon the load operation.
"""
simple_main_op = main_op()
with ops.control_dependency([simple_main_op]):
restore = restore_op_name
return tf.group(restore)
35 changes: 35 additions & 0 deletions tensorflow/python/saved_model/saved_model_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from tensorflow.python.saved_model import builder as saved_model_builder
from tensorflow.python.saved_model import constants
from tensorflow.python.saved_model import loader
from tensorflow.python.saved_model import main_op
from tensorflow.python.saved_model import tag_constants
from tensorflow.python.saved_model import utils
from tensorflow.python.util import compat
Expand Down Expand Up @@ -394,6 +395,40 @@ def testAssets(self):
compat.as_bytes("ignored.txt"))
self.assertFalse(file_io.file_exists(ignored_asset_path))

def testCustomMainOp(self):
export_dir = os.path.join(tf.test.get_temp_dir(), "test_main_op")
builder = saved_model_builder.SavedModelBuilder(export_dir)

with self.test_session(graph=tf.Graph()) as sess:
# Add `v1` and `v2` variables to the graph.
v1 = tf.Variable(1, name="v1")
tf.add_to_collection("v", v1)
v2 = tf.Variable(2, name="v2")
tf.add_to_collection("v", v2)

# Initialize another variable `v3` to 42.
v3 = tf.Variable(42, name="v3", trainable=False, collections=[])
tf.add_to_collection("v", v3)

# Set up an assignment op to be run as part of the main_op.
assign_v3 = tf.assign(v3, tf.add(v1, v2))
custom_main_op = tf.group(main_op.main_op(), assign_v3)

sess.run(tf.global_variables_initializer())
builder.add_meta_graph_and_variables(
sess, ["foo"], main_op=custom_main_op)

# Save the SavedModel to disk.
builder.save()

with self.test_session(graph=tf.Graph()) as sess:
loader.load(sess, ["foo"], export_dir)
self.assertEqual(1, tf.get_collection("v")[0].eval())
self.assertEqual(2, tf.get_collection("v")[1].eval())
# Evaluates to the sum of the first two variables and assigned as part of
# the main_op, following a restore.
self.assertEqual(3, tf.get_collection("v")[2].eval())

def testLegacyInitOp(self):
export_dir = os.path.join(tf.test.get_temp_dir(), "test_legacy_init_op")
builder = saved_model_builder.SavedModelBuilder(export_dir)
Expand Down

0 comments on commit 97e0b7b

Please sign in to comment.