Skip to content

Commit

Permalink
Stop raising an error when attempting to load v1 modules. This should…
Browse files Browse the repository at this point in the history
… already be working for some modules.

PiperOrigin-RevId: 245691598
  • Loading branch information
TensorFlow Hub Authors authored and andresusanopinto committed May 2, 2019
1 parent 09f4596 commit d0a4c42
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 20 deletions.
9 changes: 1 addition & 8 deletions tensorflow_hub/e2e_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,14 +200,7 @@ def test_load_v1(self):
os.chdir(os.path.dirname(full_module_path))
server_port = test_utils.start_http_server()
handle = "http://localhost:%d/half_plus_two_v1.tar.gz" % server_port
try:
hub.load(handle)
self.fail("Loading v1 modules not support. Failure expected.")
except NotImplementedError as e:
self.assertEqual(str(e),
"TF Hub module '%s' is stored using TF 1.x "
"format. Loading of the module using hub.load() is not "
"supported." % handle)
hub.load(handle)


if __name__ == "__main__":
Expand Down
21 changes: 9 additions & 12 deletions tensorflow_hub/module_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TensorFlow Hub Module API for Tensorflow 2.0"""
"""TensorFlow Hub Module API for Tensorflow 2.0."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

from tensorflow_hub import native_module
from tensorflow_hub import registry
from tensorflow_hub import tf_v1

Expand All @@ -40,12 +39,12 @@ def resolve(handle):
return registry.resolver(handle)


def load(handle):
def load(handle, tags=None):
"""Loads a module from a handle.
Currently this method only works with Tensorflow 2.x and can only load modules
created by calling tensorflow.saved_model.save(). The method works in both
eager and graph modes.
Currently this method is fully supported only with Tensorflow 2.x and with
modules created by calling tensorflow.saved_model.save(). The method works in
both eager and graph modes.
Depending on the type of handle used, the call may involve downloading a
Tensorflow Hub module to a local cache location specified by the
Expand All @@ -63,6 +62,8 @@ def load(handle):
Args:
handle: (string) the Module handle to resolve.
tags: A set of strings specifying the graph variant to use, if loading from
a v1 module.
Returns:
A trackable object (see tf.saved_model.load() documentation for details).
Expand All @@ -73,11 +74,7 @@ def load(handle):
"""
if hasattr(tf_v1.saved_model, "load_v2"):
module_handle = resolve(handle)
if tf_v1.gfile.Exists(native_module.get_module_proto_path(module_handle)):
raise NotImplementedError("TF Hub module '%s' is stored using TF 1.x "
"format. Loading of the module using "
"hub.load() is not supported." % handle)
return tf_v1.saved_model.load_v2(module_handle)
return tf_v1.saved_model.load_v2(module_handle, tags=tags)
else:
raise NotImplementedError("hub.load() is not implemented for TF < 1.14.x, "
"Current version: %s", tf.__version__)
"Current version: %s" % tf.__version__)

0 comments on commit d0a4c42

Please sign in to comment.