From d9e6efb0e9fc3bc670df1427879e257504c5c599 Mon Sep 17 00:00:00 2001 From: TensorFlow Hub Authors Date: Tue, 12 Feb 2019 13:43:44 +0100 Subject: [PATCH] Add native_module.get_unsupported_collections(). PiperOrigin-RevId: 233584508 --- tensorflow_hub/native_module.py | 4 ++++ tensorflow_hub/native_module_test.py | 20 ++++++++++++++++++++ 2 files changed, 24 insertions(+) diff --git a/tensorflow_hub/native_module.py b/tensorflow_hub/native_module.py index 226ab6e56..df6d6d76b 100644 --- a/tensorflow_hub/native_module.py +++ b/tensorflow_hub/native_module.py @@ -709,6 +709,10 @@ def check_collections_are_supported(saved_model_handler, supported): " as appropriate." % list(unsupported)) +def get_unsupported_collections(used_collection_keys): + return list(set(used_collection_keys) - _SUPPORTED_COLLECTIONS) + + def register_ops_if_needed(graph_ops): """Register graph ops absent in op_def_registry, if present in c++ registry. diff --git a/tensorflow_hub/native_module_test.py b/tensorflow_hub/native_module_test.py index 2a9a1d5f1..fc10de7ed 100644 --- a/tensorflow_hub/native_module_test.py +++ b/tensorflow_hub/native_module_test.py @@ -129,6 +129,26 @@ def wrong_module_fn(): spec = native_module.create_module_spec(wrong_module_fn) self.assertIn("No signatures present", str(cm.exception)) + def testUnsupportedCollections(self): + + def module_fn(): + scale = tf.get_variable("x", (), collections=["my_scope"]) + x = tf.placeholder(tf.float32, shape=[None, 3]) + native_module.add_signature("my_func", {"x": x}, {"y": x*scale}) + + with self.assertRaises(ValueError) as cm: + _ = native_module.create_module_spec(module_fn) + self.assertIn("Unsupported collections in graph", cm) + + with tf.Graph().as_default() as tmp_graph: + module_fn() + unsupported_collections = native_module.get_unsupported_collections( + tmp_graph.get_all_collection_keys()) + self.assertEqual(["my_scope"], unsupported_collections) + + _ = native_module.create_module_spec( + module_fn, drop_collections=unsupported_collections) + class RecoverPartitionedVariableMapTest(tf.test.TestCase):