Adds a contextmanager to easily evaluate TFHub modules.
PiperOrigin-RevId: 220646945
TensorFlow Hub Authors authored and arnoegw committed Nov 9, 2018
1 parent 772a74a commit e815d2e
Showing 4 changed files with 163 additions and 11 deletions.
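For context, a minimal usage sketch of the API added by this commit. The module handle below is only an illustrative assumption (any TF1-style Hub text-embedding module would behave the same way) and is not part of the change:

```python
import tensorflow_hub as hub

# Hypothetical module handle, used purely for illustration.
MODULE_HANDLE = "https://tfhub.dev/google/nnlm-en-dim128/1"

# The context manager builds its own graph and session, so the caller needs
# no explicit tf.Session(), variable initializer or table initializer.
with hub.eval_function_for_module(MODULE_HANDLE) as embed:
  # Inputs and outputs are plain Python/numpy values, not tensors.
  embeddings = embed(["Hello world!", "How are you?"])
  print(embeddings.shape)  # Expected: (2, 128) for this particular module.
```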
1 change: 1 addition & 0 deletions tensorflow_hub/__init__.py
@@ -33,6 +33,7 @@
from tensorflow_hub.image_util import get_expected_image_size
from tensorflow_hub.image_util import get_num_image_channels
from tensorflow_hub.image_util import ImageModuleInfo
from tensorflow_hub.module import eval_function_for_module
from tensorflow_hub.module import load_module_spec
from tensorflow_hub.module import Module
from tensorflow_hub.module_spec import ModuleSpec
90 changes: 90 additions & 0 deletions tensorflow_hub/module.py
@@ -18,6 +18,7 @@
from __future__ import division
from __future__ import print_function

import contextlib
import six
import tensorflow as tf
from tensorflow_hub import module_spec
@@ -464,3 +465,92 @@ def _prepare_outputs(dict_outputs, as_dict):
    return dict_outputs["default"]
  else:
    raise TypeError("There is no output named 'default'. Use as_dict=True.")


@contextlib.contextmanager
def eval_function_for_module(spec, tags=None):
  """Context manager that yields a function to directly evaluate a Module.

  This creates a separate graph, in which all of the signatures of the module
  are instantiated. Then, it creates a session and initializes the module
  variables. Finally, it returns a function which can be used to evaluate the
  module signatures.

  The function returned by eval_function_for_module has the same syntax as
  Module.__call__, except that inputs and outputs are not tensors but actual
  values as used with Session.run().

  ```python
  with hub.eval_function_for_module("/tmp/text-embedding") as f:
    # The module can be directly evaluated using f without constructing a graph.
    embeddings = f(["Hello world!",], signature="mysignature")
  ```

  Args:
    spec: A ModuleSpec defining the Module to instantiate or a path where to
      load a ModuleSpec from via `load_module_spec`.
    tags: A set of strings specifying the graph variant to use.

  Yields:
    A function whose keyword arguments are fed into the tfhub module and which
    returns a dictionary with the value of the output tensors.

  Raises:
    RuntimeError: explaining the reason why it failed to instantiate the
      Module.
    ValueError: if the requested graph variant does not exist.
  """
  # We create a separate graph and add all the signatures of the module to it.
  original_graph = tf.get_default_graph()
  with tf.Graph().as_default():
    module = Module(spec, tags=tags)
    input_tensors_per_signature = {}
    output_tensors_per_signature = {}
    for signature in module.get_signature_names():
      # We scope with the signature name as different signatures will likely
      # contain tensors with the same name (e.g. the input and output tensors).
      with tf.variable_scope(signature):
        input_tensors = {}
        for name, tensorinfo in module.get_input_info_dict(signature).items():
          # We need to be careful with the shape as it may be fully-known,
          # partially-known or even unknown.
          shape = tensorinfo.get_shape()
          effective_shape = None if shape.dims is None else shape.as_list()
          if tensorinfo.is_sparse:
            input_tensors[name] = tf.sparse_placeholder(
                tensorinfo.dtype, shape=effective_shape, name=name)
          else:
            input_tensors[name] = tf.placeholder(
                tensorinfo.dtype, shape=effective_shape, name=name)
        input_tensors_per_signature[signature] = input_tensors
        output_tensors_per_signature[signature] = module(
            input_tensors_per_signature[signature],
            signature=signature,
            as_dict=True)

    # Evaluating the tfhub module requires an active tensorflow session.
    with tf.train.SingularMonitoredSession() as sess:

      def func(
          inputs=None,
          _sentinel=None,  # pylint: disable=invalid-name
          signature=None,
          as_dict=None):
        """Function that directly evaluates a signature in the module."""
        signature = signature or "default"
        input_tensors = input_tensors_per_signature[signature]

        dict_inputs = _prepare_dict_inputs(inputs, input_tensors)

        # The input arguments are directly fed into the session.
        feed_dict = {
            input_tensors[key]: value for key, value in dict_inputs.items()
        }
        output = output_tensors_per_signature[signature]
        output = _prepare_outputs(output, as_dict)
        return sess.run(output, feed_dict=feed_dict)

      with original_graph.as_default():
        # Yield the function since that will keep the session alive until the
        # user exits the context.
        yield func
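The yielded function mirrors `Module.__call__`: it accepts a single value or a dict of inputs, an optional `signature` name, and `as_dict`. A short sketch of those call patterns, assuming a hypothetical module exported at `/tmp/text-embedding` that also defines a `"mysignature"` signature with a `"text"` input:

```python
import tensorflow_hub as hub

# Hypothetical export path, signature and input names, for illustration only.
with hub.eval_function_for_module("/tmp/text-embedding") as f:
  # Default signature, single input, returns the "default" output value.
  default_out = f(["Hello world!"])

  # Named signature with dict inputs; as_dict=True returns every output of
  # that signature keyed by name, just like Module.__call__.
  outputs = f({"text": ["Hello world!"]}, signature="mysignature", as_dict=True)
  print(sorted(outputs.keys()))
```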
60 changes: 54 additions & 6 deletions tensorflow_hub/module_test.py
@@ -163,16 +163,17 @@ def get_tags(self):

def get_signature_names(self, tags=None):
if tags == set(["special"]):
return iter(["default", "extra"])
return iter(["default", "extra", "sparse"])
else:
return iter(["default"])

def get_input_info_dict(self, signature=None, tags=None):
result = {
"x": tensor_info.ParsedTensorInfo(
tf.float32,
tf.TensorShape([None]),
is_sparse=False),
"x":
tensor_info.ParsedTensorInfo(
tf.float32,
tf.TensorShape([None]),
is_sparse=(signature == "sparse" and tags == set(["special"]))),
}
if tags == set(["special"]) and signature == "extra":
result["y"] = result["x"]
@@ -207,6 +208,11 @@ def __init__(self, name, trainable):

def create_apply_graph(self, signature, input_tensors, name):
with tf.name_scope(name):
if signature == "sparse":
input_tensors = {
key: tf.sparse_tensor_to_dense(value)
for key, value in input_tensors.items()
}
result = {"default": 2 * input_tensors["x"]}
if signature == "extra":
result["z"] = 2 * input_tensors["x"] + 3 * input_tensors["y"]
@@ -256,11 +262,53 @@ def testModuleInterfaceGettersDefaultSignatureAndTags(self):
def testModuleInterfaceGettersExplicitSignatureAndTags(self):
"""Tests that tags from Module(...) apply to module.get_*()."""
m = module.Module(_ModuleSpec(), tags={"special"})
self.assertItemsEqual(m.get_signature_names(), ["default", "extra"])
self.assertItemsEqual(m.get_signature_names(),
["default", "extra", "sparse"])
self.assertItemsEqual(m.get_input_info_dict(signature="extra").keys(),
["x", "y"])
self.assertItemsEqual(m.get_output_info_dict(signature="extra").keys(),
["z", "default"])


class EvalFunctionForModuleTest(tf.test.TestCase):
  """Tests for hub.eval_function_for_module(...).

  This tests that hub.eval_function_for_module parses input variables,
  signatures and tags correctly and that it returns the correct output.

  End-to-end tests with the native module are done in native_module_test.py.
  """

  def testSingleInput(self):
    with module.eval_function_for_module(_ModuleSpec()) as f:
      self.assertAllEqual(f([1, 2]), [2, 4])

  def testSparseInput(self):
    with module.eval_function_for_module(_ModuleSpec(), tags={"special"}) as f:
      self.assertAllEqual(
          f(tf.SparseTensorValue([[0]], [1], [2]),  # Value is [1, 0].
            signature="sparse"),
          [2, 0])

  def testDictInput(self):
    with module.eval_function_for_module(_ModuleSpec()) as f:
      self.assertAllEqual(f({"x": [1, 2]}), [2, 4])

  def testDictOutput(self):
    with module.eval_function_for_module(_ModuleSpec()) as f:
      result = f({"x": [1, 2]}, as_dict=True)
      self.assertTrue(isinstance(result, dict))
      self.assertAllEqual(list(result.keys()), ["default"])

  def testSignature(self):
    with module.eval_function_for_module(_ModuleSpec()) as f:
      self.assertAllEqual(f([1, 2]), [2, 4])

  def testExplicitSignatureAndTags(self):
    with module.eval_function_for_module(_ModuleSpec(), tags={"special"}) as f:
      result = f(dict(x=[1], y=[2]), signature="extra", as_dict=True)
      self.assertAllEqual(result["default"], [2])
      self.assertAllEqual(result["z"], [8])


if __name__ == "__main__":
  tf.test.main()
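For reference, the `tf.SparseTensorValue([[0]], [1], [2])` fed in `testSparseInput` above encodes the dense vector `[1, 0]` (one nonzero value, 1, at index 0, with dense shape 2). A standalone TF1 sketch of how such a value feeds the kind of sparse placeholder that `eval_function_for_module` creates:

```python
import tensorflow as tf

# indices=[[0]], values=[1], dense_shape=[2] encodes the dense vector [1, 0].
sparse_value = tf.SparseTensorValue(indices=[[0]], values=[1], dense_shape=[2])

with tf.Graph().as_default():
  sp = tf.sparse_placeholder(tf.int32, shape=[None])
  dense = tf.sparse_tensor_to_dense(sp)
  with tf.Session() as sess:
    print(sess.run(dense, feed_dict={sp: sparse_value}))  # [1 0]
```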
23 changes: 18 additions & 5 deletions tensorflow_hub/native_module_test.py
@@ -579,9 +579,8 @@ def testLoadTrainableModuleFromFuncDef(self):
got = sess.run(x)
self.assertAllClose(got, [3.1, 3.2, 3.3])

def testModuleWithTrainedVariable(self):
def _exportModulewithTrainedVariable(self):
export_path = os.path.join(self.get_temp_dir(), "var-module")

with tf.Graph().as_default():
spec = hub.create_module_spec(stateful_module_fn)
m = hub.Module(spec, trainable=True)
@@ -590,15 +589,22 @@ def testModuleWithTrainedVariable(self):
with tf.Session() as sess:
sess.run(assign_op)
m.export(export_path, sess)
return export_path

def testModuleWithTrainedVariable(self):
with tf.Graph().as_default():
f = hub.Module(export_path)
f = hub.Module(self._exportModulewithTrainedVariable())
out = f()
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
got = sess.run(out)
self.assertAllClose(got, [9.0, 9.0, 9.0])

def testModuleEvalWithTrainedVariable(self):
export_path = self._exportModulewithTrainedVariable()
with hub.eval_function_for_module(export_path) as f:
self.assertAllClose(f(), [9.0, 9.0, 9.0])


def table_lookup_module_fn():
x = tf.placeholder(dtype=tf.int64, name="x")
@@ -611,7 +617,7 @@ def table_lookup_module_fn():

class TFHubTableLookupModuleTest(tf.test.TestCase):

def testModuleWithTable(self):
def _exportModuleWithTable(self):
export_path = os.path.join(self.get_temp_dir(), "table-module")
with tf.Graph().as_default():
spec = hub.create_module_spec(table_lookup_module_fn)
@@ -620,16 +626,23 @@ def testModuleWithTable(self):
# variables to export.
with tf.Session() as sess:
m.export(export_path, sess)
return export_path

def testModuleWithTable(self):
with tf.Graph().as_default():
v = tf.placeholder(dtype=tf.int64)
f = hub.Module(export_path)
f = hub.Module(self._exportModuleWithTable())
y = f(v)
with tf.Session() as sess:
sess.run(tf.tables_initializer())
got = sess.run(y, feed_dict={v: [0, 1, 2, 3]})
self.assertAllEqual(list(got), [b"index0", b"hello", b"world", b"UNK"])

def testModuleEvalWithTable(self):
with hub.eval_function_for_module(self._exportModuleWithTable()) as f:
got = f([0, 1, 2, 3])
self.assertAllEqual(list(got), [b"index0", b"hello", b"world", b"UNK"])


def do_table_lookup(indices, vocabulary_file):
table = tf.contrib.lookup.index_to_string_table_from_file(
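To make the boilerplate saving concrete, here is a self-contained sketch in the spirit of these tests; the `double_module_fn` and export path are made-up illustrations, not part of the commit. It exports a trivial stateless module and evaluates it both the classic way and with the new context manager:

```python
import os
import tempfile

import tensorflow as tf
import tensorflow_hub as hub


def double_module_fn():
  # A stateless module that doubles its float input.
  x = tf.placeholder(dtype=tf.float32, shape=[None], name="x")
  hub.add_signature(inputs=x, outputs=2 * x)


export_path = os.path.join(tempfile.mkdtemp(), "double-module")
with tf.Graph().as_default():
  spec = hub.create_module_spec(double_module_fn)
  m = hub.Module(spec)
  # Export requires a session even though there are no variables to export.
  with tf.Session() as sess:
    m.export(export_path, sess)

# Classic evaluation: build a graph, open a session, run initializers.
with tf.Graph().as_default():
  out = hub.Module(export_path)([1.0, 2.0, 3.0])
  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(out))  # [2. 4. 6.]

# New evaluation: the context manager handles graph, session and init.
with hub.eval_function_for_module(export_path) as f:
  print(f([1.0, 2.0, 3.0]))  # [2. 4. 6.]
```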
