From 54a242de6e148d0836c4a58308b07bbe79c274b6 Mon Sep 17 00:00:00 2001 From: TensorFlow Hub Authors Date: Tue, 30 Apr 2019 14:23:24 +0200 Subject: [PATCH] Make hub.Module usable within "tf.compat.v1.wrap_function". This makes it impossible to raise an error about incorrect usage within functions that do not prune the ops, so log a warning instead. PiperOrigin-RevId: 245928528 --- tensorflow_hub/module.py | 4 ++ tensorflow_hub/native_module.py | 43 ++++++++++----- tensorflow_hub/native_module_test.py | 79 +++++++++------------------- 3 files changed, 59 insertions(+), 67 deletions(-) diff --git a/tensorflow_hub/module.py b/tensorflow_hub/module.py index 6b3b4ed51..8772ba59d 100644 --- a/tensorflow_hub/module.py +++ b/tensorflow_hub/module.py @@ -210,6 +210,10 @@ def __call__(self, inputs=None, # pylint: disable=invalid-name - Add constant tensors to ASSET_FILEPATHS, even if those are not needed directly needed for the signature. + Note: `hub.Module` implementation depends on graph pruning that happens + usually during `session.run` as so it can lead to errors when used inside + function graphs that execute all its ops (e.g. `tf.data.Dataset.map`). + Args: inputs: Inputs to the signature. A dict from input names to tensor values. If the signature only expects one input, one may pass diff --git a/tensorflow_hub/native_module.py b/tensorflow_hub/native_module.py index 1f4eb1990..c6acfd75c 100644 --- a/tensorflow_hub/native_module.py +++ b/tensorflow_hub/native_module.py @@ -381,13 +381,21 @@ def __init__(self, spec, meta_graph, trainable, checkpoint_path, name): register_ops_if_needed({ op.name for op in self._meta_graph.meta_info_def.stripped_op_list.op}) - # Use an init scope to clear dependencies and lift state ops outside of - # function-building graphs. Modules can be constructed from deep inside + if _is_tpu_graph_function(): + # TODO(b/129142908): Hub should not use `tf.init_scope` since that makes + # it incompatible with tf.compat.v1.wrap_function. For now the only use + # case where hub used it was for tpu compatibility. This should be cleaned + # up at an early convinience. + scope_func = tf.init_scope + else: + scope_func = lambda: tf.control_dependencies(None) + + # Clear dependencies so modules can be constructed from deep inside # functions that have dependencies active. Note that the dependencies # would be active when applying the Module signature, just not active # when creating the Module state. This use case has showed up in some # TPU training code. - with tf.init_scope(): + with scope_func(): self._init_state(name) def _init_state(self, name): @@ -489,7 +497,7 @@ def create_apply_graph(self, signature, input_tensors, name): # TODO(b/112575006): The following adds functionality of function call # within a TPU context. Work to generalize this for all function calls is # ongoing. - if self._is_tpu_graph_function(): + if _is_tpu_graph_function(): for k, v in self._state_map.items(): feed_map[k] = apply_graph.capture(v) meta_graph_lib.prune_unused_nodes(meta_graph, signature_def) @@ -497,9 +505,15 @@ def create_apply_graph(self, signature, input_tensors, name): # infeeds which no longer exist. meta_graph_lib.prune_feed_map(meta_graph, infeed_map) elif apply_graph.building_function: - raise NotImplementedError( - "Using TF-Hub module within a TensorFlow defined function " - "is currently not supported.") + # Log a warning if a user is using a hub module in function graph. + # This is only expected to work if the function graph is pruned and + # not all nodes are executed. + # + # E.g. it could work with "tf.compat.v1.wrap_function", but it will not + # work with defun, Dataset.map_fn, etc... + logging.warning("Using `hub.Module` while building a function: %s. This " + "can lead to errors if the function is not pruned.", + apply_graph.name) # As state ops in the apply graph are unused, replace them with Placeholders # so that in a heirarchical instantiation, apply_graph state ops are @@ -536,7 +550,7 @@ def create_apply_graph(self, signature, input_tensors, name): meta_graph_lib.filter_collections(meta_graph, import_collections) meta_graph_lib.prefix_shared_name_attributes(meta_graph, absolute_scope_name) - if len(meta_graph.collection_def) and self._is_tpu_graph_function(): + if len(meta_graph.collection_def) and _is_tpu_graph_function(): raise NotImplementedError( "Applying modules with collections inside TPU functions is not " "supported.") @@ -560,12 +574,6 @@ def get_tensor(name): return tensor_info.build_output_map(signature_def.outputs, get_tensor) - def _is_tpu_graph_function(self): - apply_graph = tf_v1.get_default_graph() - return (apply_graph.building_function and - type(apply_graph._get_control_flow_context()).__name__.endswith( # pylint: disable=protected-access - "TPUReplicateContext")) - def export(self, path, session): """See `Module.export`.""" def variables_saver(variables_path): @@ -1091,3 +1099,10 @@ def find_signature_inputs_from_multivalued_ops(inputs): "colocation constraints have to be rewritten.\nAffected inputs: %s" % ", ".join("%s='%s'" % pair for pair in warnings)) return None + + +def _is_tpu_graph_function(): + graph = tf_v1.get_default_graph() + return (graph.building_function and + type(graph._get_control_flow_context()).__name__.endswith( # pylint: disable=protected-access + "TPUReplicateContext")) diff --git a/tensorflow_hub/native_module_test.py b/tensorflow_hub/native_module_test.py index ace183117..534a699cc 100644 --- a/tensorflow_hub/native_module_test.py +++ b/tensorflow_hub/native_module_test.py @@ -778,59 +778,6 @@ def import_computation(first, second): got = sess.run(x(9.0, 6.0)) self.assertEqual(got, [19.0, 16.0]) - # The following tests should all fail until b/112575006 is resolved. - def testModuleWithDefun(self): - spec = hub.create_module_spec(stateful_rv_with_input_module_fn) - - @function.Defun() - def import_computation(first, second): - m = hub.Module(spec, name="module_", trainable=True) - return [m(first), m(second)] - with tf_v1.Graph().as_default(), tf_v1.Session() as sess: - # In the case where we don't handle the variables, they will not be - # hoisted so they are not handled properly. - with self.assertRaisesRegexp( - NotImplementedError, - "Using TF-Hub module within a TensorFlow defined function " - "is currently not supported."): - import_computation(9.0, 6.0) - - def testModuleWithEagerDefun(self): - spec = hub.create_module_spec(stateful_rv_with_input_module_fn) - - def import_computation(first, second): - # In the case where we don't handle the variables, they will not be - # hoisted so they are not handled properly. - with self.assertRaisesRegexp( - NotImplementedError, - "Using TF-Hub module within a TensorFlow defined function " - "is currently not supported."): - m = hub.Module(spec, trainable=True) - return [m(first), m(second)] - - x = function_eager.defun(import_computation) - with tf_v1.Graph().as_default(), tf_v1.Session() as sess: - sess.run(x(9.0, 6.0)) - - def testModuleWithWrapFunc(self): - spec = hub.create_module_spec(stateful_rv_with_input_module_fn) - - def import_computation(first, second): - m = hub.Module(spec, trainable=True) - return [m(first), m(second)] - - # In the case where we don't handle the variables, they will not be - # hoisted so they are not handled properly. - with tf_v1.Graph().as_default(), tf_v1.Session() as sess: - with self.assertRaisesRegexp( - NotImplementedError, - "Using TF-Hub module within a TensorFlow defined function " - "is currently not supported."): - tf_v1.wrap_function( - import_computation, - [tf.TensorSpec((), tf.float32), - tf.TensorSpec((), tf.float32)]) - def _exportModulewithTrainedVariable(self): export_path = os.path.join(self.get_temp_dir(), "var-module") with tf.Graph().as_default(): @@ -1765,5 +1712,31 @@ def testExportModuleSpec_withWrongScope(self): name_transform_fn=lambda x: "block/" + x) +class TFHubUsageWithEager(tf.test.TestCase): + + def testWrapFunction(self): + if not tf.executing_eagerly(): + self.skipTest("Test requires eager.") + + spec = hub.create_module_spec(stateful_rv_with_input_module_fn) + + initializers = [] + def use_module(x, y): + m = hub.Module(spec, name="module_", trainable=True) + initializers.append(tf_v1.initializers.global_variables()) + return [m(x), m(y)] + + input_signature = [ + tf.TensorSpec((), tf.float32), + tf.TensorSpec((), tf.float32), + ] + + f = tf_v1.wrap_function(use_module, input_signature) + f.prune([], initializers)() + self.assertAllEqual( + [x.numpy() for x in f(9.0, 6.0)], + [19.0, 16.0]) + + if __name__ == "__main__": tf.test.main()