Skip to content

Commit

Permalink
Generalize function to get testdata and move it to test_utils.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 276661074
  • Loading branch information
TensorFlow Hub Authors authored and andresusanopinto committed Oct 25, 2019
1 parent 26679cb commit eb51c5a
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 9 deletions.
11 changes: 2 additions & 9 deletions tensorflow_hub/e2e_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
class End2EndTest(tf.test.TestCase):

def setUp(self):
super(End2EndTest, self).setUp()
# Set current directory to test temp directory where we can create
# files and serve them through the HTTP server.
os.chdir(self.get_temp_dir())
Expand Down Expand Up @@ -184,19 +185,11 @@ def add(self, x):
self.assertIsNotNone(restored_module)
self.assertTrue(hasattr(restored_module, "add"))

def _full_module_path(self, module_name):
for directory, _, files in tf_v1.gfile.Walk(test_utils.test_srcdir()):
for f in files:
full_path = f
if full_path.endswith(module_name):
return os.path.join(directory, f)
raise ValueError("No %s in test source directory" % module_name)

def test_load_v1(self):
if (not hasattr(tf_v1.saved_model, "load_v2") or
not tf_v1.executing_eagerly()):
return # The test only applies when running V2 mode.
full_module_path = self._full_module_path("half_plus_two_v1.tar.gz")
full_module_path = test_utils.get_test_data_path("half_plus_two_v1.tar.gz")
os.chdir(os.path.dirname(full_module_path))
server_port = test_utils.start_http_server()
handle = "http://localhost:%d/half_plus_two_v1.tar.gz" % server_port
Expand Down
10 changes: 10 additions & 0 deletions tensorflow_hub/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import threading

from absl import flags
import tensorflow as tf

# TODO(b/73987364): It is not possible to extend feature columns without
# depending on TensorFlow internal implementation details.
Expand Down Expand Up @@ -161,3 +162,12 @@ def get_dense_features_module():
if hasattr(feature_column_v2, "DenseFeatures"):
return feature_column_v2
return dense_features_v2


def get_test_data_path(file_or_dirname):
"""Return full test data path."""
for directory, subdirs, files in tf.io.gfile.walk(test_srcdir()):
for f in subdirs + files:
if f.endswith(file_or_dirname):
return os.path.join(directory, f)
raise ValueError("No %s in test directory" % file_or_dirname)

0 comments on commit eb51c5a

Please sign in to comment.