Skip to content

Commit

Permalink
Prune corresponding infeed map after pruning unused nodes from the gr…
Browse files Browse the repository at this point in the history
…aph.

PiperOrigin-RevId: 238289552
  • Loading branch information
TensorFlow Hub Authors authored and arnoegw committed Mar 27, 2019
1 parent 8fc2098 commit 9f79ae3
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 0 deletions.
11 changes: 11 additions & 0 deletions tensorflow_hub/meta_graph_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,17 @@ def prune_unused_nodes(meta_graph, signature_def):
del graph


def prune_feed_map(meta_graph, feed_map):
"""Function to prune the feedmap of nodes which no longer exist."""
node_names = [x.name + ":0" for x in meta_graph.graph_def.node]
keys_to_delete = []
for k, _ in feed_map.items():
if k not in node_names:
keys_to_delete.append(k)
for k in keys_to_delete:
del feed_map[k]


def filter_collections(meta_graph, collections):
collections = frozenset(collections)
for name in list(meta_graph.collection_def.keys()):
Expand Down
3 changes: 3 additions & 0 deletions tensorflow_hub/native_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,6 +493,9 @@ def create_apply_graph(self, signature, input_tensors, name):
for k, v in self._state_map.items():
feed_map[k] = apply_graph.capture(v)
meta_graph_lib.prune_unused_nodes(meta_graph, signature_def)
# After we prune the metagraph def, we might need to prune away
# infeeds which no longer exist.
meta_graph_lib.prune_feed_map(meta_graph, infeed_map)
elif apply_graph.building_function:
raise NotImplementedError(
"Using TF-Hub module within a TensorFlow defined function "
Expand Down
18 changes: 18 additions & 0 deletions tensorflow_hub/native_module_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -692,6 +692,24 @@ def import_computation():
# the values would be different.
self.assertEqual(got[0], got[1])

def testTPUPruneWithUnusedInput(self):
spec = hub.create_module_spec(unused_input_module_fn)

@function.Defun()
def import_computation(x):
context = TPUReplicateContext()
context.Enter()
m = hub.Module(spec, name="module_", trainable=True)
return m({
"x": tf.cast(x, dtype=tf.int64),
"unused": tf.constant(2, dtype=tf.int64)
})

with tf_v1.Graph().as_default(), tf_v1.Session() as sess:
x = import_computation(5)
got = sess.run(x)
self.assertEqual(got, 25)

def testTPUModuleDoesntPruneControlDependencies(self):
spec = hub.create_module_spec(control_dependency_module_fn)

Expand Down

0 comments on commit 9f79ae3

Please sign in to comment.