Skip to content

Commit

Permalink
Allow to use default tags with hub.load() when none are specified.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 276676596
  • Loading branch information
TensorFlow Hub Authors authored and andresusanopinto committed Oct 28, 2019
1 parent eb51c5a commit 27db714
Show file tree
Hide file tree
Showing 8 changed files with 91 additions and 7 deletions.
24 changes: 21 additions & 3 deletions tensorflow_hub/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
licenses(["notice"]) # Apache 2.0 License
licenses(["notice"])

load("//tensorflow_hub:protos.bzl", "tf_hub_proto_library") # buildifier: disable=load-on-top

package(default_visibility = ["//:__subpackages__"])

exports_files(["LICENSE"])

load("//tensorflow_hub:protos.bzl", "tf_hub_proto_library")

# This is the public import users should use.
py_library(
name = "tensorflow_hub",
Expand Down Expand Up @@ -442,3 +442,21 @@ py_library(
"//tensorflow_hub:expect_tensorflow_installed",
],
)

py_test(
name = "module_v2_test",
srcs = ["module_v2_test.py"],
data = [
"testdata/hub_module_v1_mini",
"testdata/saved_model_v2_mini",
],
# python_version = "PY2",
python_version = "PY3",
srcs_version = "PY2AND3",
deps = [
":module_v2",
":test_utils",
"//tensorflow_hub:expect_tensorflow_installed",
":config",
],
)
14 changes: 10 additions & 4 deletions tensorflow_hub/module_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import tensorflow as tf

from tensorflow_hub import native_module
from tensorflow_hub import registry
from tensorflow_hub import tf_v1

Expand Down Expand Up @@ -83,9 +84,14 @@ def load(handle, tags=None):
NotImplementedError: If the code is running against incompatible (1.x)
version of TF.
"""
if hasattr(tf_v1.saved_model, "load_v2"):
module_handle = resolve(handle)
return tf_v1.saved_model.load_v2(module_handle, tags=tags)
else:
if not hasattr(tf_v1.saved_model, "load_v2"):
raise NotImplementedError("hub.load() is not implemented for TF < 1.14.x, "
"Current version: %s" % tf.__version__)
module_path = resolve(handle)
is_hub_module_v1 = tf.io.gfile.exists(
native_module.get_module_proto_path(module_path))
if tags is None and is_hub_module_v1:
tags = []
obj = tf_v1.saved_model.load_v2(module_path, tags=tags)
obj._is_hub_module_v1 = is_hub_module_v1 # pylint: disable=protected-access
return obj
59 changes: 59 additions & 0 deletions tensorflow_hub/module_v2_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# Copyright 2019 The TensorFlow Hub Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow_hub.module_v2."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# pylint:disable=g-import-not-at-top,g-statement-before-imports
try:
import mock as mock
except ImportError:
import unittest.mock as mock
# pylint:disable=g-import-not-at-top,g-statement-before-imports

from absl.testing import parameterized
import tensorflow as tf
from tensorflow_hub import config
from tensorflow_hub import module_v2
from tensorflow_hub import test_utils

# Initialize resolvers and loaders.
config._run()


class ModuleV2Test(tf.test.TestCase, parameterized.TestCase):

@parameterized.named_parameters(
('v1_implicit_tags', 'hub_module_v1_mini', None, True),
('v1_explicit_tags', 'hub_module_v1_mini', [], True),
('v2_implicit_tags', 'saved_model_v2_mini', None, False),
('v2_explicit_tags', 'saved_model_v2_mini', ['serve'], False),
)
def test_load(self, module_name, tags, is_hub_module_v1):
path = test_utils.get_test_data_path(module_name)
m = module_v2.load(path, tags)
self.assertEqual(m._is_hub_module_v1, is_hub_module_v1)

@mock.patch.object(module_v2, 'tf_v1')
def test_load_with_old_tensorflow_raises_error(self, tf_v1_mock):
tf_v1_mock.saved_model = None
with self.assertRaises(NotImplementedError):
module_v2.load('dummy_module_name')


if __name__ == '__main__':
tf.test.main()
Binary file not shown.
1 change: 1 addition & 0 deletions tensorflow_hub/testdata/hub_module_v1_mini/tfhub_module.pb
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@

Binary file not shown.
Binary file not shown.
Binary file not shown.

0 comments on commit 27db714

Please sign in to comment.