forked from tensorflow/hub
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathexport.py
278 lines (228 loc) · 9.48 KB
/
export.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
# Copyright 2018 The TensorFlow Hub Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Exporter tool for TF-Hub text embedding modules.
This tool creates TF-Hub modules from embedding text files in which each line
holds a token followed by its whitespace-separated vector components:

  token1 1.0 2.0 3.0 4.0 5.0
  token2 2.0 3.0 4.0 5.0 6.0
  ...

Example use:

  python export.py --embedding_file=/tmp/embedding.txt --export_path=/tmp/module
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import os
import shutil
import sys
import tempfile
import numpy as np
import tensorflow as tf
import tensorflow_hub as hub
# Parsed command-line flags. Assigned from parser.parse_known_args() in the
# __main__ block and read by main().
FLAGS = None
# Name given to the embedding table variable in the module graph; export() uses
# the same name to find the variable in the module's variable_map.
EMBEDDINGS_VAR_NAME = "embeddings"
def parse_line(line):
  """Parses a line of a text embedding file.

  Args:
    line: (str) One line of the text embedding file: a token followed by its
      whitespace-separated embedding values.

  Returns:
    A token string and its embedding vector as a list of floats.

  Raises:
    ValueError: if the line is blank or a value cannot be converted to float.
  """
  columns = line.split()
  if not columns:
    # Without this check a blank line (e.g. a stray empty line at the end of
    # the file) fails with an opaque "pop from empty list" IndexError.
    raise ValueError("Cannot parse a blank embedding line: %r" % (line,))
  token = columns[0]
  values = [float(column) for column in columns[1:]]
  return token, values
def load(file_path, parse_line_fn):
  """Loads a text embedding into memory as a numpy matrix.

  Args:
    file_path: Path to the text embedding file.
    parse_line_fn: callback function to parse each file line; it must return a
      (token, embedding values) pair.

  Returns:
    A tuple of (list of vocabulary tokens, numpy matrix of embedding vectors).

  Raises:
    ValueError: if the embedding vectors in the file do not all have the same
      dimension.
  """
  vocabulary = []
  embeddings = []
  embeddings_dim = None
  # The context manager closes the file even if parsing raises.
  with tf.gfile.GFile(file_path) as embedding_file:
    for line in embedding_file:
      token, embedding = parse_line_fn(line)
      # Compare against None explicitly: a first row of dimension 0 is falsy
      # and would otherwise let the reference dimension be overwritten.
      if embeddings_dim is None:
        embeddings_dim = len(embedding)
      elif embeddings_dim != len(embedding):
        # Interpolate the message here; passing the values as extra ValueError
        # arguments would leave the "%d"/"%s" placeholders unformatted.
        raise ValueError(
            "Inconsistent embedding dimension detected, %d != %d for token %s"
            % (embeddings_dim, len(embedding), token))
      vocabulary.append(token)
      embeddings.append(embedding)
  return vocabulary, np.array(embeddings)
def make_module_spec(vocabulary_file, vocab_size, embeddings_dim,
                     num_oov_buckets, preprocess_text):
  """Builds a module spec that maps string input to embedding vectors.

  Without preprocessing, the module takes a 1-D batch of string tokens and
  looks each one up in the embedding table. With preprocessing, the module
  takes a 1-D batch of sentences, strips punctuation, splits on spaces and
  combines the embeddings of each sentence's tokens into one vector.

  Args:
    vocabulary_file: Text file where each line is a key in the vocabulary.
    vocab_size: The number of tokens contained in the vocabulary.
    embeddings_dim: The embedding dimension.
    num_oov_buckets: The number of out-of-vocabulary buckets.
    preprocess_text: Whether to preprocess the input tensor by removing
      punctuation and splitting on spaces.

  Returns:
    A module spec object used for constructing a TF-Hub module.
  """

  def build_table_and_lookup():
    """Creates the zero-initialized embedding variable and the vocab table."""
    table_shape = [vocab_size + num_oov_buckets, embeddings_dim]
    embedding_table = tf.get_variable(
        name=EMBEDDINGS_VAR_NAME,
        initializer=tf.zeros(table_shape),
        dtype=tf.float32)
    vocab_table = tf.contrib.lookup.index_table_from_file(
        vocabulary_file=vocabulary_file,
        num_oov_buckets=num_oov_buckets,
    )
    return embedding_table, vocab_table

  def token_lookup_fn():
    """Module function for per-token embedding lookups."""
    token_input = tf.placeholder(shape=[None], dtype=tf.string, name="tokens")
    embedding_table, vocab_table = build_table_and_lookup()
    token_ids = vocab_table.lookup(token_input)
    looked_up = tf.nn.embedding_lookup(params=embedding_table, ids=token_ids)
    hub.add_signature("default", {"tokens": token_input},
                      {"default": looked_up})

  def sentence_embedding_fn():
    """Module function embedding whole sentences after light preprocessing."""
    sentence_input = tf.placeholder(
        shape=[None], dtype=tf.string, name="sentences")
    # Minimal normalization: drop punctuation, then split on single spaces.
    without_punctuation = tf.regex_replace(
        input=sentence_input, pattern=r"\pP", rewrite="")
    words = tf.string_split(without_punctuation, " ")
    # A sentence that is empty (before or after normalization) yields an empty
    # row. Every input row must still produce an embedding, so empty rows are
    # filled with a default token.
    words, _ = tf.sparse_fill_empty_rows(words, "")
    # If every sentence was empty the SparseTensor has shape [?, 0]; after the
    # fill above the shape has to be reset so that it becomes [?, 1].
    words = tf.sparse_reset_shape(words)
    embedding_table, vocab_table = build_table_and_lookup()
    word_ids = tf.SparseTensor(
        indices=words.indices,
        values=vocab_table.lookup(words.values),
        dense_shape=words.dense_shape)
    sentence_vectors = tf.nn.embedding_lookup_sparse(
        params=embedding_table,
        sp_ids=word_ids,
        sp_weights=None,
        combiner="sqrtn")
    hub.add_signature("default", {"sentences": sentence_input},
                      {"default": sentence_vectors})

  selected_fn = sentence_embedding_fn if preprocess_text else token_lookup_fn
  return hub.create_module_spec(selected_fn)
def export(export_path, vocabulary, embeddings, num_oov_buckets,
           preprocess_text):
  """Exports a TF-Hub module that performs embedding lookups.

  The vocabulary is written to a temporary file that the module spec reads; the
  temporary directory is removed again whether or not the export succeeds.

  Args:
    export_path: Location to export the module.
    vocabulary: List of the N tokens in the vocabulary.
    embeddings: Numpy array of shape [N+K,M] the first N rows are the
      M dimensional embeddings for the respective tokens and the next K
      rows are for the K out-of-vocabulary buckets.
    num_oov_buckets: How many out-of-vocabulary buckets to add.
    preprocess_text: Whether to preprocess the input tensor by removing
      punctuation and splitting on spaces.
  """
  # Write temporary vocab file for module construction.
  tmpdir = tempfile.mkdtemp()
  # Everything after mkdtemp() runs inside the try block so the directory is
  # also cleaned up when writing the vocabulary or building the spec fails.
  try:
    vocabulary_file = os.path.join(tmpdir, "tokens.txt")
    with tf.gfile.GFile(vocabulary_file, "w") as f:
      f.write("\n".join(vocabulary))
    vocab_size = len(vocabulary)
    embeddings_dim = embeddings.shape[1]
    spec = make_module_spec(vocabulary_file, vocab_size, embeddings_dim,
                            num_oov_buckets, preprocess_text)

    with tf.Graph().as_default():
      m = hub.Module(spec)
      # The embeddings may be very large (e.g., larger than the 2GB serialized
      # Tensor limit). To avoid having them frozen as constant Tensors in the
      # graph we instead assign them through the placeholders and feed_dict
      # mechanism.
      p_embeddings = tf.placeholder(tf.float32)
      load_embeddings = tf.assign(m.variable_map[EMBEDDINGS_VAR_NAME],
                                  p_embeddings)

      with tf.Session() as sess:
        sess.run([load_embeddings], feed_dict={p_embeddings: embeddings})
        m.export(export_path, sess)
  finally:
    shutil.rmtree(tmpdir)
def maybe_append_oov_vectors(embeddings, num_oov_buckets):
  """Adds zero vectors for oov buckets if num_oov_buckets > 0.

  Since we are assigning zero vectors, adding more than one oov bucket is only
  meaningful if we perform fine-tuning.

  The array is extended in place; nothing is returned.

  Args:
    embeddings: 2-D numpy array of embeddings to extend.
    num_oov_buckets: Number of OOV buckets in the extended embedding.

  Raises:
    ValueError: if num_oov_buckets is negative.
  """
  if num_oov_buckets < 0:
    # A negative count would shrink the array in place and silently drop
    # vocabulary embeddings.
    raise ValueError(
        "num_oov_buckets must be non-negative, got %d" % num_oov_buckets)
  if num_oov_buckets == 0:
    return
  num_embeddings = np.shape(embeddings)[0]
  embedding_dim = np.shape(embeddings)[1]
  # In-place resize appends zero-filled rows; refcheck=False allows it even
  # though the caller still holds a reference to the array.
  embeddings.resize(
      [num_embeddings + num_oov_buckets, embedding_dim], refcheck=False)
def export_module_from_file(embedding_file, export_path, parse_line_fn,
                            num_oov_buckets, preprocess_text):
  """Reads an embedding text file and exports it as a TF-Hub module.

  Args:
    embedding_file: Path to the text embedding file to load.
    export_path: Location to export the module.
    parse_line_fn: callback function to parse each line of the file.
    num_oov_buckets: How many out-of-vocabulary buckets to add.
    preprocess_text: Whether the module should preprocess its input by
      removing punctuation and splitting on spaces.
  """
  # Step 1: read the pretrained vectors into memory.
  tokens, vectors = load(embedding_file, parse_line_fn)
  # Step 2: extend the matrix in place with rows for the OOV buckets.
  maybe_append_oov_vectors(vectors, num_oov_buckets)
  # Step 3: write the vectors out as a TF-Hub module.
  export(
      export_path=export_path,
      vocabulary=tokens,
      embeddings=vectors,
      num_oov_buckets=num_oov_buckets,
      preprocess_text=preprocess_text)
def main(_):
  """Runs the export using the values stored in the module-level FLAGS.

  Args:
    _: Unused argument list handed over by tf.app.run.
  """
  flags = FLAGS
  export_module_from_file(
      embedding_file=flags.embedding_file,
      export_path=flags.export_path,
      parse_line_fn=parse_line,
      num_oov_buckets=flags.num_oov_buckets,
      preprocess_text=flags.preprocess_text)
if __name__ == "__main__":

  def _parse_bool(text):
    """Converts a command-line string such as "true" or "0" to a bool.

    argparse's type=bool calls bool() on the raw string, which turns every
    non-empty value (including "False") into True; this parser interprets the
    text instead.
    """
    normalized = text.strip().lower()
    if normalized in ("true", "t", "yes", "y", "1"):
      return True
    if normalized in ("false", "f", "no", "n", "0", ""):
      return False
    raise argparse.ArgumentTypeError(
        "Expected a boolean value, got %r" % (text,))

  parser = argparse.ArgumentParser()
  # Both paths are required: without them the export can only fail later, at a
  # less obvious point inside the TensorFlow file and export calls.
  parser.add_argument(
      "--embedding_file",
      type=str,
      default=None,
      required=True,
      help="Path to file with embeddings.")
  parser.add_argument(
      "--export_path",
      type=str,
      default=None,
      required=True,
      help="Where to export the module.")
  # nargs="?" with const=True lets a bare "--preprocess_text" enable the flag,
  # while "--preprocess_text=true" / "--preprocess_text=false" keep working.
  parser.add_argument(
      "--preprocess_text",
      type=_parse_bool,
      nargs="?",
      const=True,
      default=False,
      help="Whether to preprocess the input tensor by removing punctuation and "
      "splitting on spaces. Use this if input is a dense tensor of untokenized "
      "sentences.")
  parser.add_argument(
      "--num_oov_buckets",
      type=int,
      default=1,
      help="How many OOV buckets to add.")
  # Flags this parser does not know are forwarded to tf.app.run untouched.
  FLAGS, unparsed = parser.parse_known_args()
  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)