Skip to content

Commit

Permalink
Check if the temporary cache directory is empty when downloading a model.
Browse files Browse the repository at this point in the history

PiperOrigin-RevId: 316417786
  • Loading branch information
Jin Dong authored and akhorlin committed Jun 15, 2020
1 parent 89cc3b5 commit c0610ab
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 1 deletion.
5 changes: 4 additions & 1 deletion tensorflow_hub/resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,9 +380,12 @@ def atomic_download(handle,
overwrite=False)
# Must test condition again, since another process could have created
# the module and deleted the old lock file since last test.
if tf_v1.gfile.Exists(module_dir):
if (tf_v1.gfile.Exists(module_dir) and
tf_v1.gfile.ListDirectory(module_dir)):
# Lock file will be deleted in the finally-clause.
return module_dir
if tf_v1.gfile.Exists(module_dir):
tf_v1.gfile.DeleteRecursively(module_dir)
break # Proceed to downloading the module.
# These errors are believed to be permanent problems with the
# module_dir that justify failing the download.
Expand Down
32 changes: 32 additions & 0 deletions tensorflow_hub/resolver_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,38 @@ def fake_download_fn_with_rogue_behavior(handle, tmp_dir):
"Downloader Hostname: %s .PID:%d." % (re.escape(socket.gethostname()),
os.getpid()))

def testModuleDownloadedWhenEmptyFolderExists(self):
  """Checks that a download still happens when module_dir exists but is empty.

  Models a cached module whose folder is still present while the module files
  inside it have been deleted. The test then verifies the returned directory,
  the downloaded file, the absence of the lock file, and the contents of the
  module descriptor file.
  """
  module_dir = os.path.join(self.get_temp_dir(), "module")

  def write_module_file(handle, tmp_dir):
    # Both arguments are deliberately unused: this fake writes its single
    # file straight into module_dir.
    del handle, tmp_dir
    tf_v1.gfile.MakeDirs(module_dir)
    target_file = os.path.join(module_dir, "file")
    tf_utils.atomic_write_string_to_file(target_file, "content", False)

  # Start from an existing-but-empty folder at the cache location.
  self.assertFalse(tf_v1.gfile.Exists(module_dir))
  tf_v1.gfile.MakeDirs(module_dir)

  resolved_dir = resolver.atomic_download(
      "module", write_module_file, module_dir)
  self.assertEqual(module_dir, resolved_dir)

  # The fake download must have run and the lock file must be gone.
  self.assertEqual(tf_v1.gfile.ListDirectory(module_dir), ["file"])
  lock_path = resolver._lock_filename(module_dir)
  self.assertFalse(tf_v1.gfile.Exists(lock_path))

  # Only the module folder and its descriptor should sit next to each other.
  cache_root = os.path.abspath(os.path.join(module_dir, ".."))
  self.assertEqual(
      sorted(tf_v1.gfile.ListDirectory(cache_root)),
      ["module", "module.descriptor.txt"])

  descriptor_text = tf_utils.read_file_to_string(
      resolver._module_descriptor_file(module_dir))
  descriptor_pattern = (
      "Module: module\n"
      "Download Time: .*\n"
      "Downloader Hostname: %s .PID:%d." % (re.escape(socket.gethostname()),
                                            os.getpid()))
  self.assertRegexpMatches(descriptor_text, descriptor_pattern)

def testModuleConcurrentDownload(self):
module_dir = os.path.join(self.get_temp_dir(), "module")

Expand Down

0 comments on commit c0610ab

Please sign in to comment.