forked from tensorflow/hub
-
Notifications
You must be signed in to change notification settings - Fork 0
/
resolver.py
497 lines (415 loc) · 18 KB
/
resolver.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
# Copyright 2018 The TensorFlow Hub Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Interface and common utility methods to perform module address resolution."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import abc
import datetime
import os
import socket
import sys
import tarfile
import tempfile
import time
import uuid
from absl import flags
from absl import logging
import tensorflow as tf
from tensorflow_hub import tf_utils
from tensorflow_hub import tf_v1
FLAGS = flags.FLAGS

flags.DEFINE_string(
    name="tfhub_cache_dir",
    default=None,
    help=("If set, TF-Hub will download and cache Modules into this directory. "
          "Otherwise it will attempt to find a network path."))

# Name of the environment variable that, when non-empty, selects the module
# cache directory (it is consulted before the --tfhub_cache_dir flag).
_TFHUB_CACHE_DIR = "TFHUB_CACHE_DIR"
# Name of the environment variable that, when non-empty, switches download
# progress reporting from periodic log lines to console output.
_TFHUB_DOWNLOAD_PROGRESS = "TFHUB_DOWNLOAD_PROGRESS"
def tfhub_cache_dir(default_cache_dir=None, use_temp=False):
  """Returns the directory in which TF-Hub modules should be cached.

  The first non-empty source wins, checked in this order:
    1. the TFHUB_CACHE_DIR environment variable,
    2. the --tfhub_cache_dir flag,
    3. `default_cache_dir`,
    4. `<system temp dir>/tfhub_modules`, only when `use_temp` is True.

  Args:
    default_cache_dir: Cache location to use when neither the TFHUB_CACHE_DIR
      environment variable nor the --tfhub_cache_dir flag is set.
    use_temp: bool, whether to fall back to a "tfhub_modules" directory under
      the system's temp directory when none of the sources above is set.

  Returns:
    The selected cache directory. If nothing is configured, returns the falsy
    value of the last source consulted (e.g. None).
  """
  # FLAGS["tfhub_cache_dir"].value is used rather than FLAGS.tfhub_cache_dir
  # so that reading the flag never triggers argv parsing. Normally main() has
  # parsed flags already via tf.app.run(); where it has not (e.g. Colab), argv
  # may hold unknown flags, so parsing is deliberately skipped.
  cache_dir = os.getenv(_TFHUB_CACHE_DIR, "")
  if not cache_dir:
    cache_dir = FLAGS["tfhub_cache_dir"].value
  if not cache_dir:
    cache_dir = default_cache_dir

  if use_temp and not cache_dir:
    # Keep every TF-Hub module under <system's temp>/tfhub_modules.
    cache_dir = os.path.join(tempfile.gettempdir(), "tfhub_modules")

  if cache_dir:
    logging.log_first_n(logging.INFO, "Using %s to cache modules.", 1,
                        cache_dir)
  return cache_dir
def create_local_module_dir(cache_dir, module_name):
  """Ensures `cache_dir` exists and returns the path for `module_name` in it.

  Args:
    cache_dir: Directory holding cached modules; created with
      `tf_v1.gfile.MakeDirs` before the path is built.
    module_name: Name of the module's entry inside `cache_dir`.

  Returns:
    `os.path.join(cache_dir, module_name)`. Only `cache_dir` is created here;
    the module directory itself is not.
  """
  tf_v1.gfile.MakeDirs(cache_dir)
  module_path = os.path.join(cache_dir, module_name)
  return module_path
class DownloadManager(object):
  """Helper class responsible for TF-Hub module download and extraction."""

  def __init__(self, url):
    """Creates DownloadManager responsible for downloading a TF-Hub module.

    Args:
      url: URL pointing to the TF-Hub module to download and extract.
    """
    self._url = url
    # Time of the last progress message; throttles non-interactive logging.
    self._last_progress_msg_print_time = time.time()
    # Running total of bytes written out during extraction.
    self._total_bytes_downloaded = 0
    # Length of the longest progress message printed so far. Shorter messages
    # are padded to this width so "\r" fully overwrites the previous one.
    self._max_prog_str = 0

  def _print_download_progress_msg(self, msg, flush=False):
    """Prints a message about download progress either to the console or TF log.

    In interactive mode the message is written to stdout after a carriage
    return, padded to the longest message seen so far, so that it overwrites
    the previous progress line. Otherwise it is emitted with `logging.info`.

    Args:
      msg: Message to print.
      flush: Whether to terminate the progress line with newlines afterwards
        (only used in interactive mode).
    """
    if self._interactive_mode():
      # Print progress message to console overwriting previous progress
      # message.
      self._max_prog_str = max(self._max_prog_str, len(msg))
      sys.stdout.write("\r%-{}s".format(self._max_prog_str) % msg)
      sys.stdout.flush()
      if flush:
        print("\n")
    else:
      # Interactive progress tracking is disabled. Print progress to the
      # standard TF log.
      logging.info(msg)

  def _log_progress(self, bytes_downloaded):
    """Adds to the downloaded-bytes total and reports progress.

    A message is printed on every call in interactive mode. Otherwise it is
    printed only when more than 15 seconds have passed since the last one.

    Args:
      bytes_downloaded: Number of bytes downloaded.
    """
    self._total_bytes_downloaded += bytes_downloaded
    now = time.time()
    if (self._interactive_mode() or
        now - self._last_progress_msg_print_time > 15):
      self._print_download_progress_msg(
          "Downloading %s: %s" % (self._url,
                                  tf_utils.bytes_to_readable_str(
                                      self._total_bytes_downloaded, True)))
      self._last_progress_msg_print_time = now

  def _interactive_mode(self):
    """Returns True if interactive logging is enabled.

    Interactive logging is enabled by setting the TFHUB_DOWNLOAD_PROGRESS
    environment variable to any non-empty value.
    """
    return bool(os.getenv(_TFHUB_DOWNLOAD_PROGRESS, ""))

  def _extract_file(self, tgz, tarinfo, dst_path, buffer_size=10 << 20):
    """Extracts 'tarinfo' from 'tgz' and writes it to 'dst_path'.

    The member is copied in chunks of at most `buffer_size` bytes (10 MiB by
    default), and progress is reported after every chunk. Both the source and
    destination handles are closed even if reading or writing fails.
    """
    src = tgz.extractfile(tarinfo)
    try:
      with tf_v1.gfile.GFile(dst_path, "wb") as dst:
        while True:
          buf = src.read(buffer_size)
          if not buf:
            break
          dst.write(buf)
          self._log_progress(len(buf))
    finally:
      # try/finally rather than `with` keeps Python 2's tarfile member
      # objects supported.
      src.close()

  def download_and_uncompress(self, fileobj, dst_path):
    """Streams the content for the 'fileobj' and stores the result in dst_path.

    Args:
      fileobj: File handle pointing to .tar/.tar.gz content.
      dst_path: Absolute path where to store uncompressed data from 'fileobj'.

    Raises:
      ValueError: Unknown object encountered inside the TAR file, or a member
        path that would resolve outside 'dst_path'.
      IOError: 'fileobj' is not a readable tar archive.
    """
    try:
      # "r|*" reads the archive as a forward-only stream with transparent
      # compression detection, so 'fileobj' need not be seekable.
      with tarfile.open(mode="r|*", fileobj=fileobj) as tgz:
        for tarinfo in tgz:
          abs_target_path = _merge_relative_path(dst_path, tarinfo.name)

          if tarinfo.isfile():
            self._extract_file(tgz, tarinfo, abs_target_path)
          elif tarinfo.isdir():
            tf_v1.gfile.MakeDirs(abs_target_path)
          else:
            # We do not support symlinks and other uncommon objects.
            raise ValueError(
                "Unexpected object type in tar archive: %s" % tarinfo.type)

        total_size_str = tf_utils.bytes_to_readable_str(
            self._total_bytes_downloaded, True)
        self._print_download_progress_msg(
            "Downloaded %s, Total size: %s" % (self._url, total_size_str),
            flush=True)
    except tarfile.ReadError:
      raise IOError("%s does not appear to be a valid module." % self._url)
def _merge_relative_path(dst_path, rel_path):
"""Merge a relative tar file to a destination (which can be "gs://...")."""
# Convert rel_path to be relative and normalize it to remove ".", "..", "//",
# which are valid directories in fileystems like "gs://".
norm_rel_path = os.path.normpath(rel_path.lstrip("/"))
if norm_rel_path == ".":
return dst_path
# Check that the norm rel path does not starts with "..".
if norm_rel_path.startswith(".."):
raise ValueError("Relative path %r is invalid." % rel_path)
merged = os.path.join(dst_path, norm_rel_path)
# After merging verify that the merged path keeps the original dst_path.
if not merged.startswith(dst_path):
raise ValueError("Relative path %r is invalid. Failed to merge with %r." % (
rel_path, dst_path))
return merged
def _module_descriptor_file(module_dir):
"""Returns the name of the file containing descriptor for the 'module_dir'."""
return "{}.descriptor.txt".format(module_dir)
def _write_module_descriptor_file(handle, module_dir):
  """Writes a human-readable descriptor next to a downloaded module directory.

  The descriptor records the module handle, the download time, and the
  hostname and PID of the downloading process.

  Args:
    handle: Module name/handle.
    module_dir: Directory where a module was downloaded.
  """
  descriptor_path = _module_descriptor_file(module_dir)
  download_time = str(datetime.datetime.today())
  hostname = socket.gethostname()
  pid = os.getpid()
  descriptor_text = (
      "Module: %s\nDownload Time: %s\nDownloader Hostname: %s (PID:%d)" %
      (handle, download_time, hostname, pid))
  # The descriptor carries no semantic meaning. Another process may have
  # written it earlier (and crashed), so overwriting is allowed.
  tf_utils.atomic_write_string_to_file(
      descriptor_path, descriptor_text, overwrite=True)
def _lock_file_contents(task_uid):
"""Returns the content of the lock file."""
return "%s.%d.%s" % (socket.gethostname(), os.getpid(), task_uid)
def _lock_filename(module_dir):
  """Returns the lock file path guarding `module_dir`.

  The path is the absolute form of `module_dir` with ".lock" appended.
  """
  absolute_module_dir = tf_utils.absolute_path(module_dir)
  return absolute_module_dir + ".lock"
def _module_dir(lock_filename):
"""Returns module dir from a full 'lock_filename' path.
Args:
lock_filename: Name of the lock file, ends with .lock.
Raises:
ValueError: if lock_filename is ill specified.
"""
if not lock_filename.endswith(".lock"):
raise ValueError(
"Lock file name (%s) has to end with .lock." % lock_filename)
return lock_filename[0:-len(".lock")]
def _task_uid_from_lock_file(lock_filename):
  """Returns the task UID (the last "."-separated field) stored in a lock file."""
  contents = tf_utils.read_file_to_string(lock_filename)
  _, _, task_uid = contents.rpartition(".")
  return task_uid
def _temp_download_dir(module_dir, task_uid):
  """Returns the temporary directory a task downloads `module_dir` into.

  The name has the form "<absolute module_dir>.<task_uid>.tmp".
  """
  abs_module_dir = tf_utils.absolute_path(module_dir)
  return "%s.%s.tmp" % (abs_module_dir, task_uid)
def _dir_size(directory):
  """Returns the total size in bytes of all files under `directory`.

  Subdirectories are traversed recursively.
  """
  def _entry_size(name):
    # Size of one directory entry: recurse into directories, else file length.
    path = os.path.join(directory, name)
    info = tf_v1.gfile.Stat(path)
    if info.is_directory:
      return _dir_size(path)
    return info.length

  return sum(_entry_size(name)
             for name in tf_v1.gfile.ListDirectory(directory))
def _locked_tmp_dir_size(lock_filename):
  """Returns the size of the temp dir referenced by the given lock file.

  Returns 0 if that temp directory (or anything inside it) is not found.
  Errors from reading the lock file itself are not caught here.
  """
  owner_uid = _task_uid_from_lock_file(lock_filename)
  try:
    tmp_dir = _temp_download_dir(_module_dir(lock_filename), owner_uid)
    size = _dir_size(tmp_dir)
  except tf.errors.NotFoundError:
    size = 0
  return size
def _wait_for_lock_to_disappear(handle, lock_file, lock_file_timeout_sec):
  """Waits for the lock file to disappear.

  The lock file was created by another process that is performing a download
  into its own temporary directory. The name of this temp directory is
  "<absolute module_dir>.<uuid>.tmp", where <uuid> is the last "."-separated
  field of the lock file contents.

  The function returns once the lock file no longer exists, or after it
  deletes the lock file itself because the download looks abandoned. The lock
  is polled every 5 seconds.

  Args:
    handle: The location from which a module is being downloaded.
    lock_file: Lock file created by another process downloading this module.
    lock_file_timeout_sec: The amount of time to wait (in seconds) before we
                           can declare that the other download has been
                           abandoned. The download is declared abandoned if
                           neither the size of its temporary directory nor the
                           lock file contents changed between two checks
                           spaced more than 'lock_file_timeout_sec' apart.
  """
  locked_tmp_dir_size = 0
  locked_tmp_dir_size_check_time = time.time()
  # None never equals real lock file contents, so the first timeout check
  # normally just records a baseline. The lock can be stolen no earlier than
  # the second check.
  lock_file_content = None
  while tf_v1.gfile.Exists(lock_file):
    try:
      logging.log_every_n(
          logging.INFO,
          "Module '%s' already being downloaded by '%s'. Waiting.", 10,
          handle, tf_utils.read_file_to_string(lock_file))
      if (time.time() - locked_tmp_dir_size_check_time >
          lock_file_timeout_sec):
        # Check whether the holder of the current lock downloaded anything
        # in its temporary directory in the last 'lock_file_timeout_sec'.
        cur_locked_tmp_dir_size = _locked_tmp_dir_size(lock_file)
        cur_lock_file_content = tf_utils.read_file_to_string(lock_file)
        if (cur_locked_tmp_dir_size == locked_tmp_dir_size and
            cur_lock_file_content == lock_file_content):
          # No data was downloaded in the past 'lock_file_timeout_sec', and
          # the same lock is still in place. Steal the lock and proceed with
          # the local download.
          logging.warning("Deleting lock file %s due to inactivity.",
                          lock_file)
          tf_v1.gfile.Remove(lock_file)
          break
        # Progress was made (or the lock changed hands): record a new baseline.
        locked_tmp_dir_size = cur_locked_tmp_dir_size
        locked_tmp_dir_size_check_time = time.time()
        lock_file_content = cur_lock_file_content
    except tf.errors.NotFoundError:
      # Lock file or temp directory were deleted during check. Continue
      # to check whether download succeeded or we need to start our own
      # download.
      pass
    finally:
      # Runs on every iteration, including after the `break` above, so a
      # 5-second sleep also follows deleting an abandoned lock.
      time.sleep(5)
def atomic_download(handle,
                    download_fn,
                    module_dir,
                    lock_file_timeout_sec=10 * 60):
  """Returns the path to a Module directory for a given TF-Hub Module handle.

  Coordinates concurrent downloaders through a "<module_dir>.lock" file. The
  process that creates the lock downloads into its own temporary directory,
  then renames that directory to `module_dir`. Other processes wait for the
  lock to disappear and then reuse the non-empty `module_dir`.

  Args:
    handle: (string) Location of a TF-Hub Module.
    download_fn: Callback function that actually performs download. The callback
                 receives two arguments, handle and the location of a temporary
                 directory to download the content into.
    module_dir: Directory where to download the module files to.
    lock_file_timeout_sec: The amount of time we give the current holder of
                           the lock to make progress in downloading a module.
                           If no progress is made, the lock is revoked.

  Returns:
    A string containing the path to a TF-Hub Module directory.

  Raises:
    ValueError: presumably raised by `download_fn` if the Module is not found
      (this function does not raise it directly).
    tf.errors.OpError: file I/O failures raise the appropriate subtype.
  """
  lock_file = _lock_filename(module_dir)
  task_uid = uuid.uuid4().hex
  lock_contents = _lock_file_contents(task_uid)
  tmp_dir = _temp_download_dir(module_dir, task_uid)

  # Attempt to protect against cases of processes being cancelled with
  # KeyboardInterrupt by using a try/finally clause to remove the lock
  # and tmp_dir.
  try:
    while True:
      try:
        # With overwrite=False, the write fails while another process holds
        # the lock. That error is retried below after waiting for the lock.
        tf_utils.atomic_write_string_to_file(lock_file, lock_contents,
                                             overwrite=False)
        # Must test condition again, since another process could have created
        # the module and deleted the old lock file since last test.
        if (tf_v1.gfile.Exists(module_dir) and
            tf_v1.gfile.ListDirectory(module_dir)):
          # Lock file will be deleted in the finally-clause.
          return module_dir
        # An existing but empty module_dir is removed so the rename below can
        # put the freshly downloaded directory in its place.
        if tf_v1.gfile.Exists(module_dir):
          tf_v1.gfile.DeleteRecursively(module_dir)
        break  # Proceed to downloading the module.
      # These errors are believed to be permanent problems with the
      # module_dir that justify failing the download.
      except (tf.errors.NotFoundError,
              tf.errors.PermissionDeniedError,
              tf.errors.UnauthenticatedError,
              tf.errors.ResourceExhaustedError,
              tf.errors.InternalError,
              tf.errors.InvalidArgumentError,
              tf.errors.UnimplementedError):
        raise
      # All other errors are retried.
      # TODO(b/144424849): Retrying an AlreadyExistsError from the atomic write
      # should be good enough, but see discussion about misc filesystem types.
      # TODO(b/144475403): How atomic is the overwrite=False check?
      except tf.errors.OpError:
        pass

      # Wait for lock file to disappear.
      _wait_for_lock_to_disappear(handle, lock_file, lock_file_timeout_sec)
      # At this point we either deleted a lock or a lock got removed by the
      # owner or another process. Perform one more iteration of the while-loop:
      # we would either terminate due to tf_v1.gfile.Exists(module_dir), or
      # obtain a lock ourselves, or wait again for the lock to disappear.

    # Lock file acquired.
    logging.info("Downloading TF-Hub Module '%s'.", handle)
    tf_v1.gfile.MakeDirs(tmp_dir)
    download_fn(handle, tmp_dir)
    # Write module descriptor to capture information about which module was
    # downloaded by whom and when. The file is stored at the same level as the
    # directory in order to keep the content of the 'module_dir' exactly as it
    # was defined by the module publisher.
    #
    # Note: The descriptor is written purely to help the end-user to identify
    # which directory belongs to which module. The descriptor is not part of the
    # module caching protocol and no code in the TF-Hub library reads its
    # content.
    _write_module_descriptor_file(handle, module_dir)
    try:
      tf_v1.gfile.Rename(tmp_dir, module_dir)
      logging.info("Downloaded TF-Hub Module '%s'.", handle)
    except tf.errors.AlreadyExistsError:
      # Another process populated module_dir first; its copy is used and our
      # temporary download is discarded in the finally-clause.
      logging.warning("Module already exists in %s", module_dir)

  finally:
    try:
      # Temp directory is owned by the current process, remove it.
      tf_v1.gfile.DeleteRecursively(tmp_dir)
    except tf.errors.NotFoundError:
      pass
    try:
      contents = tf_utils.read_file_to_string(lock_file)
    except tf.errors.NotFoundError:
      contents = ""
    # Only remove the lock if it still holds this task's contents, so that a
    # lock re-acquired by another process is left alone.
    if contents == lock_contents:
      # Lock file exists and is owned by this process.
      try:
        tf_v1.gfile.Remove(lock_file)
      except tf.errors.NotFoundError:
        pass

  return module_dir
# Python 2/3-compatible abstract base. Assigning `__metaclass__` in a class
# body only takes effect on Python 2, so the metaclass is supplied through the
# base class instead; Resolver and its subclasses get ABCMeta on both versions.
_ResolverABCBase = abc.ABCMeta("_ResolverABCBase", (object,), {})


class Resolver(_ResolverABCBase):
  """Resolver base class: all resolvers inherit from this class.

  Subclasses must implement both `__call__` and `is_supported`. Instantiating
  a subclass that leaves either one abstract raises TypeError.
  """

  @abc.abstractmethod
  def __call__(self, handle):
    """Resolves a handle into a Module path.

    Args:
      handle: (string) the Module handle to resolve.

    Returns:
      A string representing the Module path.
    """
    pass

  @abc.abstractmethod
  def is_supported(self, handle):
    """Returns whether a handle is supported by this resolver.

    Args:
      handle: (string) the Module handle to resolve.

    Returns:
      True if the handle is properly formatted for this resolver.
      Note that a True return value does not indicate that the
      handle can be resolved, only that it is the correct format.
    """
    pass
class PathResolver(Resolver):
  """Resolves handles that are filesystem paths by returning them as-is."""

  def is_supported(self, handle):
    # PathResolver is the last Resolver in the chain, so it accepts every
    # handle and __call__ may always be invoked.
    del handle  # Unused.
    return True

  def __call__(self, handle):
    """Returns `handle` if it exists; raises IOError otherwise."""
    if tf_v1.gfile.Exists(handle):
      return handle
    raise IOError("%s does not exist." % handle)