Make hub.Module usable within "tf.compat.v1.wrap_function".
This makes it impossible to raise an error about incorrect usage
within functions that do not prune the ops, so log a warning instead.

PiperOrigin-RevId: 245928528
TensorFlow Hub Authors authored and andresusanopinto committed May 2, 2019
1 parent 39d2b2a commit 54a242d
Showing 3 changed files with 59 additions and 67 deletions.
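
For context, here is a minimal sketch of the usage this commit enables, modeled on the test added to native_module_test.py below; the module handle and its scalar-float signature are placeholders, not part of the change:

import tensorflow as tf
import tensorflow_hub as hub

tf_v1 = tf.compat.v1

# Hypothetical handle; any TF1-style hub.Module spec with a matching
# signature would work the same way.
MODULE_HANDLE = "https://tfhub.dev/example/module/1"

initializers = []

def use_module(x, y):
  # Constructing the Module inside a function body no longer raises
  # NotImplementedError; a warning is logged instead.
  m = hub.Module(MODULE_HANDLE, trainable=True)
  initializers.append(tf_v1.initializers.global_variables())
  return [m(x), m(y)]

f = tf_v1.wrap_function(
    use_module,
    [tf.TensorSpec((), tf.float32), tf.TensorSpec((), tf.float32)])

# wrap_function only runs the ops needed for the requested fetches, so run
# the initializers once through a pruned function, then call the module.
f.prune([], initializers)()
outputs = f(9.0, 6.0)

That explicit prune of the initializers mirrors the new test: the per-fetch pruning done by `wrap_function` is exactly the graph pruning `hub.Module` relies on.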
4 changes: 4 additions & 0 deletions tensorflow_hub/module.py
@@ -210,6 +210,10 @@ def __call__(self, inputs=None,  # pylint: disable=invalid-name
- Add constant tensors to ASSET_FILEPATHS, even if those are not
directly needed for the signature.
Note: the `hub.Module` implementation depends on graph pruning, which usually
happens during `session.run`, so it can lead to errors when used inside
function graphs that execute all of their ops (e.g. `tf.data.Dataset.map`).
Args:
inputs: Inputs to the signature. A dict from input names to tensor
values. If the signature only expects one input, one may pass
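
The Note above describes the remaining caveat; the sketch below illustrates it with a hypothetical text-embedding handle (the `Dataset.map` call is left commented out because it is expected to fail):

import tensorflow as tf
import tensorflow_hub as hub

# Hypothetical handle, used only for illustration.
MODULE_HANDLE = "https://tfhub.dev/example/text-module/1"

def embed(text):
  # Since this commit, building the Module here logs:
  #   "Using `hub.Module` while building a function: <name>. This can lead
  #    to errors if the function is not pruned."
  # Dataset.map executes every op in its function graph (nothing is pruned
  # away), so the module's state ops may still fail at runtime.
  m = hub.Module(MODULE_HANDLE)
  return m(tf.reshape(text, [1]))

ds = tf.data.Dataset.from_tensor_slices(tf.constant(["a", "b"]))
# ds = ds.map(embed)  # expected to fail; shown only to illustrate the caveat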
43 changes: 29 additions & 14 deletions tensorflow_hub/native_module.py
@@ -381,13 +381,21 @@ def __init__(self, spec, meta_graph, trainable, checkpoint_path, name):
register_ops_if_needed({
op.name for op in self._meta_graph.meta_info_def.stripped_op_list.op})

# Use an init scope to clear dependencies and lift state ops outside of
# function-building graphs. Modules can be constructed from deep inside
if _is_tpu_graph_function():
# TODO(b/129142908): Hub should not use `tf.init_scope` since that makes
# it incompatible with tf.compat.v1.wrap_function. For now the only use
# case where hub used it was for TPU compatibility. This should be cleaned
# up at the earliest convenience.
scope_func = tf.init_scope
else:
scope_func = lambda: tf.control_dependencies(None)

# Clear dependencies so modules can be constructed from deep inside
# functions that have dependencies active. Note that the dependencies
# would be active when applying the Module signature, just not active
# when creating the Module state. This use case has shown up in some
# TPU training code.
with tf.init_scope():
with scope_func():
self._init_state(name)

def _init_state(self, name):
@@ -489,17 +497,23 @@ def create_apply_graph(self, signature, input_tensors, name):
# TODO(b/112575006): The following adds support for function calls
# within a TPU context. Work to generalize this for all function calls is
# ongoing.
if self._is_tpu_graph_function():
if _is_tpu_graph_function():
for k, v in self._state_map.items():
feed_map[k] = apply_graph.capture(v)
meta_graph_lib.prune_unused_nodes(meta_graph, signature_def)
# After we prune the metagraph def, we might need to prune away
# infeeds which no longer exist.
meta_graph_lib.prune_feed_map(meta_graph, infeed_map)
elif apply_graph.building_function:
raise NotImplementedError(
"Using TF-Hub module within a TensorFlow defined function "
"is currently not supported.")
# Log a warning if a user is using a hub module inside a function graph.
# This is only expected to work if the function graph is pruned so that
# not all of its nodes get executed.
#
# E.g. it can work with "tf.compat.v1.wrap_function", but it will not
# work with defun, Dataset.map_fn, etc.
logging.warning("Using `hub.Module` while building a function: %s. This "
"can lead to errors if the function is not pruned.",
apply_graph.name)

# As state ops in the apply graph are unused, replace them with Placeholders
# so that in a hierarchical instantiation, apply_graph state ops are
@@ -536,7 +550,7 @@ def create_apply_graph(self, signature, input_tensors, name):
meta_graph_lib.filter_collections(meta_graph, import_collections)
meta_graph_lib.prefix_shared_name_attributes(meta_graph,
absolute_scope_name)
if len(meta_graph.collection_def) and self._is_tpu_graph_function():
if len(meta_graph.collection_def) and _is_tpu_graph_function():
raise NotImplementedError(
"Applying modules with collections inside TPU functions is not "
"supported.")
@@ -560,12 +574,6 @@ def get_tensor(name):

return tensor_info.build_output_map(signature_def.outputs, get_tensor)

def _is_tpu_graph_function(self):
apply_graph = tf_v1.get_default_graph()
return (apply_graph.building_function and
type(apply_graph._get_control_flow_context()).__name__.endswith( # pylint: disable=protected-access
"TPUReplicateContext"))

def export(self, path, session):
"""See `Module.export`."""
def variables_saver(variables_path):
@@ -1091,3 +1099,10 @@ def find_signature_inputs_from_multivalued_ops(inputs):
"colocation constraints have to be rewritten.\nAffected inputs: %s" %
", ".join("%s='%s'" % pair for pair in warnings))
return None


def _is_tpu_graph_function():
graph = tf_v1.get_default_graph()
return (graph.building_function and
type(graph._get_control_flow_context()).__name__.endswith( # pylint: disable=protected-access
"TPUReplicateContext"))
79 changes: 26 additions & 53 deletions tensorflow_hub/native_module_test.py
@@ -778,59 +778,6 @@ def import_computation(first, second):
got = sess.run(x(9.0, 6.0))
self.assertEqual(got, [19.0, 16.0])

# The following tests should all fail until b/112575006 is resolved.
def testModuleWithDefun(self):
spec = hub.create_module_spec(stateful_rv_with_input_module_fn)

@function.Defun()
def import_computation(first, second):
m = hub.Module(spec, name="module_", trainable=True)
return [m(first), m(second)]
with tf_v1.Graph().as_default(), tf_v1.Session() as sess:
# In the case where we don't handle the variables, they will not be
# hoisted so they are not handled properly.
with self.assertRaisesRegexp(
NotImplementedError,
"Using TF-Hub module within a TensorFlow defined function "
"is currently not supported."):
import_computation(9.0, 6.0)

def testModuleWithEagerDefun(self):
spec = hub.create_module_spec(stateful_rv_with_input_module_fn)

def import_computation(first, second):
# In the case where we don't handle the variables, they will not be
# hoisted so they are not handled properly.
with self.assertRaisesRegexp(
NotImplementedError,
"Using TF-Hub module within a TensorFlow defined function "
"is currently not supported."):
m = hub.Module(spec, trainable=True)
return [m(first), m(second)]

x = function_eager.defun(import_computation)
with tf_v1.Graph().as_default(), tf_v1.Session() as sess:
sess.run(x(9.0, 6.0))

def testModuleWithWrapFunc(self):
spec = hub.create_module_spec(stateful_rv_with_input_module_fn)

def import_computation(first, second):
m = hub.Module(spec, trainable=True)
return [m(first), m(second)]

# In the case where we don't handle the variables, they will not be
# hoisted so they are not handled properly.
with tf_v1.Graph().as_default(), tf_v1.Session() as sess:
with self.assertRaisesRegexp(
NotImplementedError,
"Using TF-Hub module within a TensorFlow defined function "
"is currently not supported."):
tf_v1.wrap_function(
import_computation,
[tf.TensorSpec((), tf.float32),
tf.TensorSpec((), tf.float32)])

def _exportModulewithTrainedVariable(self):
export_path = os.path.join(self.get_temp_dir(), "var-module")
with tf.Graph().as_default():
@@ -1765,5 +1712,31 @@ def testExportModuleSpec_withWrongScope(self):
name_transform_fn=lambda x: "block/" + x)


class TFHubUsageWithEager(tf.test.TestCase):

def testWrapFunction(self):
if not tf.executing_eagerly():
self.skipTest("Test requires eager.")

spec = hub.create_module_spec(stateful_rv_with_input_module_fn)

initializers = []
def use_module(x, y):
m = hub.Module(spec, name="module_", trainable=True)
initializers.append(tf_v1.initializers.global_variables())
return [m(x), m(y)]

input_signature = [
tf.TensorSpec((), tf.float32),
tf.TensorSpec((), tf.float32),
]

f = tf_v1.wrap_function(use_module, input_signature)
f.prune([], initializers)()
self.assertAllEqual(
[x.numpy() for x in f(9.0, 6.0)],
[19.0, 16.0])


if __name__ == "__main__":
tf.test.main()
