# Copyright 2018 The TensorFlow Hub Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Replicates TensorFlow utilities which are not part of the public API."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import time
import uuid

import tensorflow as tf


def read_file_to_string(filename):
  """Returns the entire contents of a file as a string.

  Args:
    filename: string, path to a file
  """
  return tf.gfile.GFile(filename, mode="r").read()


def atomic_write_string_to_file(filename, contents, overwrite):
  """Writes to `filename` atomically.

  This means that when `filename` appears in the filesystem, it will contain
  all of `contents`. With write_string_to_file, it is possible for the file
  to appear in the filesystem with `contents` only partially written.

  Accomplished by writing to a temp file and then renaming it.

  Args:
    filename: string, pathname for a file
    contents: string, contents that need to be written to the file
    overwrite: boolean, if false it's an error for `filename` to be occupied by
      an existing file.
  """
  temp_pathname = (tf.compat.as_bytes(filename) +
                   tf.compat.as_bytes(".tmp") +
                   tf.compat.as_bytes(uuid.uuid4().hex))
  with tf.gfile.GFile(temp_pathname, mode="w") as f:
    f.write(contents)
  try:
    tf.gfile.Rename(temp_pathname, filename, overwrite)
  except tf.errors.OpError:
    tf.gfile.Remove(temp_pathname)
    raise
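

# Example usage (illustrative only, not part of the original module; the path
# is a placeholder): write a file atomically, then read it back. The contents
# are first written to `<filename>.tmp<uuid>` and only renamed to `filename`
# once the write has completed.
#
#   atomic_write_string_to_file("/tmp/example.txt", "hello world",
#                               overwrite=True)
#   assert read_file_to_string("/tmp/example.txt") == "hello world"
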
# When we create a timestamped directory, there is a small chance that the
# directory already exists because another worker is also writing exports.
# In this case we just wait one second to get a new timestamp and try again.
# If this fails several times in a row, then something is seriously wrong.
MAX_DIRECTORY_CREATION_ATTEMPTS = 10


def get_timestamped_export_dir(export_dir_base):
  """Builds a path to a new subdirectory within the base directory.

  Each export is written into a new subdirectory named using the
  current time. This guarantees monotonically increasing version
  numbers even across multiple runs of the pipeline.

  The timestamp used is the number of seconds since epoch UTC.

  Args:
    export_dir_base: A string containing a directory to write the exported
      graph and checkpoints.

  Returns:
    The full path of the new subdirectory (which is not actually created yet).

  Raises:
    RuntimeError: if repeated attempts fail to obtain a unique timestamped
      directory name.
  """
  attempts = 0
  while attempts < MAX_DIRECTORY_CREATION_ATTEMPTS:
    export_timestamp = int(time.time())
    export_dir = os.path.join(
        tf.compat.as_bytes(export_dir_base),
        tf.compat.as_bytes(str(export_timestamp)))
    if not tf.gfile.Exists(export_dir):
      # Collisions are still possible (though extremely unlikely): this
      # directory is not actually created yet, but it will be almost
      # instantly on return from this function.
      return export_dir
    time.sleep(1)
    attempts += 1
    tf.logging.warn(
        "Export directory {} already exists; retrying (attempt {}/{})".format(
            export_dir, attempts, MAX_DIRECTORY_CREATION_ATTEMPTS))
  raise RuntimeError("Failed to obtain a unique export directory name after "
                     "{} attempts.".format(MAX_DIRECTORY_CREATION_ATTEMPTS))


def get_temp_export_dir(timestamped_export_dir):
  """Builds a directory name based on the argument but starting with 'temp-'.

  This relies on the fact that TensorFlow Serving ignores subdirectories of
  the base directory that can't be parsed as integers.

  Args:
    timestamped_export_dir: the name of the eventual export directory, e.g.
      /foo/bar/<timestamp>

  Returns:
    A sister directory prefixed with 'temp-', e.g. /foo/bar/temp-<timestamp>.
  """
  (dirname, basename) = os.path.split(timestamped_export_dir)
  temp_export_dir = os.path.join(
      tf.compat.as_bytes(dirname),
      tf.compat.as_bytes("temp-{}".format(basename)))
  return temp_export_dir
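

# Sketch of the write-then-rename export pattern these two helpers support
# (illustrative only, not shown in this module; "/tmp/exports" is a
# placeholder): the export is first written under the 'temp-' directory, which
# TensorFlow Serving ignores, and renamed to the timestamped name only once it
# is complete.
#
#   export_dir = get_timestamped_export_dir("/tmp/exports")
#   temp_dir = get_temp_export_dir(export_dir)
#   tf.gfile.MakeDirs(temp_dir)
#   # ... write the export (e.g. a SavedModel) into temp_dir ...
#   tf.gfile.Rename(temp_dir, export_dir)
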
# Note: This is written from scratch to mimic the pattern in:
# `tf.estimator.LatestExporter._garbage_collect_exports()`.
def garbage_collect_exports(export_dir_base, exports_to_keep):
  """Deletes older exports, retaining only a given number of the most recent.

  Export subdirectories are assumed to be named with monotonically increasing
  integers; the most recent are taken to be those with the largest values.

  Args:
    export_dir_base: the base directory under which each export is in a
      versioned subdirectory.
    exports_to_keep: Number of exports to keep. Older exports will be garbage
      collected. Set to None to disable.
  """
  if exports_to_keep is None:
    return
  version_paths = []  # List of tuples (version, path)
  for filename in tf.gfile.ListDirectory(export_dir_base):
    path = os.path.join(
        tf.compat.as_bytes(export_dir_base),
        tf.compat.as_bytes(filename))
    if len(filename) == 10 and filename.isdigit():
      version_paths.append((int(filename), path))
  oldest_version_path = sorted(version_paths)[:-exports_to_keep]
  for _, path in oldest_version_path:
    try:
      tf.gfile.DeleteRecursively(path)
    except tf.errors.NotFoundError as e:
      tf.logging.warn("Can not delete %s recursively: %s" % (path, e))


def bytes_to_readable_str(num_bytes, include_b=False):
  """Generate a human-readable string representing number of bytes.

  The units B, kB, MB and GB are used.

  Args:
    num_bytes: (`int` or None) Number of bytes.
    include_b: (`bool`) Include the letter B at the end of the unit.

  Returns:
    (`str`) A string representing the number of bytes in a human-readable way,
    including a unit at the end.
  """
  if num_bytes is None:
    return str(num_bytes)
  if num_bytes < 1024:
    result = "%d" % num_bytes
  elif num_bytes < 1048576:
    result = "%.2fk" % (num_bytes / float(1 << 10))
  elif num_bytes < 1073741824:
    result = "%.2fM" % (num_bytes / float(1 << 20))
  else:
    result = "%.2fG" % (num_bytes / float(1 << 30))
  if include_b:
    result += "B"
  return result
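

# Example values (illustrative, derived from the thresholds above):
#
#   bytes_to_readable_str(500)                      # -> "500"
#   bytes_to_readable_str(1024)                     # -> "1.00k"
#   bytes_to_readable_str(2 ** 20, include_b=True)  # -> "1.00MB"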