Skip to content

Commit

Permalink
Add native_module.get_unsupported_collections().
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 233584508
  • Loading branch information
TensorFlow Hub Authors authored and vbardiovskyg committed Feb 26, 2019
1 parent d9b489b commit d9e6efb
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 0 deletions.
4 changes: 4 additions & 0 deletions tensorflow_hub/native_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -709,6 +709,10 @@ def check_collections_are_supported(saved_model_handler, supported):
" as appropriate." % list(unsupported))


def get_unsupported_collections(used_collection_keys):
return list(set(used_collection_keys) - _SUPPORTED_COLLECTIONS)


def register_ops_if_needed(graph_ops):
"""Register graph ops absent in op_def_registry, if present in c++ registry.
Expand Down
20 changes: 20 additions & 0 deletions tensorflow_hub/native_module_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,26 @@ def wrong_module_fn():
spec = native_module.create_module_spec(wrong_module_fn)
self.assertIn("No signatures present", str(cm.exception))

def testUnsupportedCollections(self):

def module_fn():
scale = tf.get_variable("x", (), collections=["my_scope"])
x = tf.placeholder(tf.float32, shape=[None, 3])
native_module.add_signature("my_func", {"x": x}, {"y": x*scale})

with self.assertRaises(ValueError) as cm:
_ = native_module.create_module_spec(module_fn)
self.assertIn("Unsupported collections in graph", cm)

with tf.Graph().as_default() as tmp_graph:
module_fn()
unsupported_collections = native_module.get_unsupported_collections(
tmp_graph.get_all_collection_keys())
self.assertEqual(["my_scope"], unsupported_collections)

_ = native_module.create_module_spec(
module_fn, drop_collections=unsupported_collections)


class RecoverPartitionedVariableMapTest(tf.test.TestCase):

Expand Down

0 comments on commit d9e6efb

Please sign in to comment.