Skip to content

Commit

Permalink
Add asset support in SavedModel loader py.
Browse files Browse the repository at this point in the history
Change: 136195816
  • Loading branch information
sukritiramesh authored and tensorflower-gardener committed Oct 14, 2016
1 parent 3ba17a9 commit 3373d23
Showing 1 changed file with 35 additions and 0 deletions.
35 changes: 35 additions & 0 deletions tensorflow/python/saved_model/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@
import os

from google.protobuf import text_format
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.core.protobuf import saved_model_pb2
from tensorflow.python.lib.io import file_io
from tensorflow.python.saved_model import constants
Expand Down Expand Up @@ -118,6 +119,37 @@ def _parse_saved_model(export_dir):
return saved_model


def _get_asset_tensors(export_dir, meta_graph_def_to_load):
"""Gets the asset tensors, if defined in the meta graph def to load.
Args:
export_dir: Directory where the SavedModel is located.
meta_graph_def_to_load: The meta graph def from the SavedModel to be loaded.
Returns:
A dictionary of asset tensors, keyed by the name of the asset tensor. The
value in the map corresponds to the absolute path of the asset file.
"""
# Collection-def that may contain the assets key.
collection_def = meta_graph_def_to_load.collection_def

asset_tensor_dict = {}
if constants.ASSETS_KEY in collection_def:
# Location of the assets for SavedModel.
assets_directory = os.path.join(
compat.as_bytes(export_dir),
compat.as_bytes(constants.ASSETS_DIRECTORY))
assets_any_proto = collection_def[constants.ASSETS_KEY].any_list.value
# Process each asset and add it to the asset tensor dictionary.
for asset_any_proto in assets_any_proto:
asset_proto = meta_graph_pb2.AssetFileDef()
asset_any_proto.Unpack(asset_proto)
asset_tensor_dict[asset_proto.tensor_info.name] = os.path.join(
compat.as_bytes(assets_directory),
compat.as_bytes(asset_proto.filename))
return asset_tensor_dict


def load(sess, tags, export_dir):
"""Loads the model from a SavedModel as specified by tags.
Expand Down Expand Up @@ -161,5 +193,8 @@ def load(sess, tags, export_dir):
# Restore the variables using the built saver in the provided session.
saver.restore(sess, variables_path)

# Get asset tensors, if any.
_get_asset_tensors(export_dir, meta_graph_def_to_load)

# Return the meta graph def that was loaded into the session.
return meta_graph_def_to_load

0 comments on commit 3373d23

Please sign in to comment.