forked from tensorflow/hub
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmodule.py
383 lines (309 loc) · 15 KB
/
module.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
# Copyright 2018 The TensorFlow Hub Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TensorFlow Hub Module definition."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from tensorflow_hub import module_spec
from tensorflow_hub import native_module
from tensorflow_hub import tensor_info
def as_module_spec(spec):
  """Coerces `spec` into a `ModuleSpec`.

  Args:
    spec: Either a `module_spec.ModuleSpec` instance, which is returned
      unchanged, or a `str`, which is handed to
      `native_module.load_module_spec()`.

  Returns:
    The `ModuleSpec` that was passed in, or the result of
    `native_module.load_module_spec(spec)` for a string.

  Raises:
    ValueError: if `spec` is neither a `ModuleSpec` nor a `str`.
  """
  # Already a spec: nothing to load.
  if isinstance(spec, module_spec.ModuleSpec):
    return spec
  # Anything that is not a string at this point is unsupported.
  if not isinstance(spec, str):
    raise ValueError("Unknown module spec type: %r" % type(spec))
  return native_module.load_module_spec(spec)
# The Module class provides unified access to all ModuleSpec implementations
# and should not contain implementation-specific code (e.g. SavedModel code).
class Module(object):
  """Part of a TensorFlow model that can be transferred between models.

  A Module represents a part of a TensorFlow graph that can be exported to disk
  (based on the SavedModel format) and later re-loaded. A Module has a defined
  interface that allows it to be used in a replaceable way, with little or no
  knowledge of its internals and its serialization format. Example:

  ```python
  m = hub.Module("/tmp/text-embedding")
  embeddings = m(sentences)
  ```

  The module to instantiate is defined by its spec (a `ModuleSpec` or a
  path where to load it from) which contains the module weights, assets and
  signatures.

  During instantiation the Module adds the state (e.g. variables and tables ops)
  to the current graph. Afterwards, the method `__call__()` allows to apply the
  module `signatures` multiple times, which adds ops for the computation.

  A Module may provide different variants of its graph for different purposes
  (say, training or serving, which may behave differently, e.g., for batch
  normalization). Graph variants are identified by sets of string-valued tags.
  The graph variant used to create a module that is exported must define all the
  variables needed by any other graph variant that is subsequently used.

  To make it possible to easily replace a module with another, they all assume
  that they will be used with common TensorFlow conventions such as session
  initialization and restore, use of collections for variables, regularization
  losses and updates, etc.
  """

  def __init__(self, spec, trainable=False, name="module", tags=None):
    """Constructs a Module to be used in the current graph.

    This creates the module `state-graph` under an unused variable_scope based
    on `name`. During this call a Module will:

    - Add GLOBAL_VARIABLES under its scope. Those variables may be added to
      the TRAINABLE_VARIABLES collection (depending on `trainable` parameter)
      and to the MODEL_VARIABLES. The variables must be initialized before use,
      and can be checkpointed as usual.

    - Add ops to the INIT_TABLE_OPS collection, which must be run during session
      initialization and add constant tensors to ASSET_FILEPATHS that are
      needed during the execution of such ops.

    - Add tensors to the REGULARIZATION_LOSSES collection (depending on
      `trainable` parameter).

    Args:
      spec: A ModuleSpec defining the Module to instantiate or a path where
        to load a ModuleSpec from via `load_module_spec`.
      trainable: whether the Module is trainable. If False, no variables are
        added to TRAINABLE_VARIABLES collection, and no tensors are added to
        REGULARIZATION_LOSSES collection.
      name: A string, the variable scope name under which to create the Module.
        It will be uniquified and the equivalent name scope must be unused.
      tags: A set of strings specifying the graph variant to use. None (the
        default) selects the variant with the empty set of tags.

    Raises:
      RuntimeError: explaining the reason why it failed to instantiate the
        Module.
      ValueError: if the requested graph variant does not exist.
    """
    self._graph = tf.get_default_graph()
    self._spec = as_module_spec(spec)
    self._trainable = trainable

    self._tags = set(tags or [])
    if self._tags not in self._spec.get_tags():
      # Report the normalized tag set: the raw `tags` argument may be None,
      # and iterating over None would raise a TypeError that masks this error.
      raise ValueError(
          "No such graph variant: tags=%r" % (sorted(self._tags),))

    # The state scope is absolute and ends with "/", so the last path
    # component is empty and the module's own name is the one before it.
    abs_state_scope = _try_get_state_scope(name, mark_name_scope_used=False)
    scope_parts = abs_state_scope.split("/")
    self._name = scope_parts[-2]

    parent_parts = scope_parts[:-2]
    abs_parent_scope = "/".join(parent_parts) + "/" if parent_parts else ""

    # Create the implementation with the parent scope re-entered as the
    # current name scope.
    with tf.name_scope(abs_parent_scope):
      # pylint: disable=protected-access
      self._impl = self._spec._create_impl(
          name=self._name,
          trainable=self._trainable,
          tags=self._tags)
      # pylint: enable=protected-access

  def __call__(self, inputs=None,  # pylint: disable=invalid-name
               _sentinel=None, signature=None, as_dict=None):
    """Instantiates a module signature in the graph.

    Example calls:

    ```python
      # Use default signature with one input and default output.
      embeddings = m(["hello world", "good morning"])

      # Use "encode" signature with one input and default output.
      encodings = m(["hello world"], signature="encode")

      # Use default signature with input dict and output dict.
      dict_outputs = m({"text": [...], "lang": [...]}, as_dict=True)
    ```

    The method __call__() allows to create the graph ops that compute a
    signature outputs given the inputs and using this module instance state.
    Each signature can be applied multiple times with different inputs and they
    all share the same module state.

    A Module may define multiple signatures. Use `signature=<name>` to identify
    the specific signature to instantiate. If omitted or None, the default
    signature is used.

    A signature may define various outputs. Use `as_dict=True` to return a dict
    of all outputs. If omitted or False, the output named 'default' is
    returned.

    During this call a Module will:

    - Add ops in the current name scope to convert the inputs in tensors to feed
      to the signature.

    - Add ops to the UPDATE_OPS collection which depend on at least one of the
      provided inputs if the Module was constructed with `trainable=True`.

    - Add constant tensors to ASSET_FILEPATHS, even if those are not directly
      needed for the signature.

    Args:
      inputs: Inputs to the signature. A dict from input names to tensor
        values. If the signature only expects one input, one may pass
        a single value. If the signature has no inputs, it may be omitted.
      _sentinel: Used to prevent positional parameters besides `inputs`.
        Must be left as None.
      signature: A string with the signature name to apply. If none, the
        default signature is used.
      as_dict: A boolean indicating whether to return all the outputs
        of the signature as a dict or return only the default output.

    Returns:
      A tensor (single or sparse) if the signature defines a default output or
      a dict from strings (output names) to tensors if `as_dict=True` is used.

    Raises:
      TypeError: If there is a mismatch on arguments, inputs or outputs of
        the module signature; in particular if anything other than `inputs`
        is passed positionally.
      RuntimeError: If there are errors during creation of the signature graph.
    """
    # A second positional argument lands in `_sentinel`. Silently ignoring it
    # would apply the default signature instead of the one the caller meant.
    if _sentinel is not None:
      raise TypeError(
          "Module.__call__() accepts only `inputs` as a positional argument; "
          "pass `signature` and `as_dict` as keyword arguments.")

    if self._graph is not tf.get_default_graph():
      raise RuntimeError(
          "Module must be applied in the graph it was instantiated for.")

    # Resolve None to the concrete signature name before using it below.
    signature = self._impl.get_signature_name(signature)
    name = "%s_apply_%s" % (self._name, signature)

    dict_inputs = _prepare_dict_inputs(
        inputs, self._spec.get_input_info_dict(signature=signature,
                                               tags=self._tags))

    dict_outputs = self._impl.create_apply_graph(
        signature=signature,
        inputs=dict_inputs,
        name=name)
    return _prepare_outputs(dict_outputs, as_dict=as_dict)

  def get_signature_names(self):
    """Returns the module's signature names as an iterable of strings."""
    return self._spec.get_signature_names(tags=self._tags)

  def get_input_info_dict(self, signature=None):
    """Describes the inputs required by a signature.

    Args:
      signature: A string with the signature to get inputs information for.
        If None, the default signature is used if defined.

    Returns:
      The result of ModuleSpec.get_input_info_dict() for the given signature,
      and the graph variant selected by `tags` when this Module was initialized.

    Raises:
      KeyError: if there is no such signature.
    """
    return self._spec.get_input_info_dict(signature=signature, tags=self._tags)

  def get_output_info_dict(self, signature=None):
    """Describes the outputs provided by a signature.

    Args:
      signature: A string with the signature to get outputs information for.
        If None, the default signature is used if defined.

    Returns:
      The result of ModuleSpec.get_output_info_dict() for the given signature,
      and the graph variant selected by `tags` when this Module was initialized.

    Raises:
      KeyError: if there is no such signature.
    """
    return self._spec.get_output_info_dict(signature=signature, tags=self._tags)

  def export(self, path, session):
    """Exports the module with the variables from the session in `path`.

    Note that it is the module definition in the ModuleSpec used to create this
    module that gets exported. The session is only used to provide the value
    of variables.

    Args:
      path: path where to export the module to.
      session: session where to export the variables from.

    Raises:
      RuntimeError: if there is an issue during the export.
    """
    # Both the default graph and the session's graph must be the graph this
    # module was instantiated in.
    if self._graph is not tf.get_default_graph():
      raise RuntimeError("default graph differs from the graph where the "
                         "module was instantiated.")
    if self._graph is not session.graph:
      raise RuntimeError("session graph differs from the graph where the "
                         "module was instantiated.")
    self._impl.export(path, session)

  @property
  def variable_map(self):
    """Map from original variable names into tf.Variables (or lists of them).

    This map translates between variable names relative to the module and the
    corresponding Variable objects that have been created by instantiating it
    in the current graph (with the applicable scoping added). Each key in the
    map is a variable name as created by running the module's defining
    `module_fn` in the root scope of an empty graph. Each value in the map is
    a Variable object, or in case of partitioned variables a list of Variable
    objects.

    This property can be used with `tf.init_from_checkpoint` as `assignment_map`
    in order to restore a pre-trained checkpoint into a Module before calling
    `Module.export()`.

    Returns:
      A dict from the variable names in the Module to the instantiated
      tf.Variables or list of tf.Variables (if partitioned). The keys of this
      map are the same regardless of the scope of where the Module was
      instantiated.
    """
    return self._impl.variable_map
def _try_get_state_scope(name, mark_name_scope_used=True):
  """Returns a fresh variable/name scope for a module's state.

  In order to import a module into a given scope without major complications
  we require the scope to be empty. This function deals with deciding an unused
  scope where to define the module state. This is non trivial in cases where
  name_scope and variable_scopes are out of sync, e.g. tpus or re-entering
  scopes.

  Args:
    name: A string with the name of the module as supplied by the client.
    mark_name_scope_used: a boolean, indicating whether to mark the name
      scope of the returned value as used.

  Returns:
    The name of the freshly picked variable scope with a trailing "/".

  Raises:
    RuntimeError: if the name scope of the freshly created variable scope is
      already used.
  """
  # Re-enter the name scope that matches the current variable scope; the
  # trailing "/" is only appended when the variable scope name is non-empty.
  enclosing_vs_name = tf.get_variable_scope().name
  name_scope_prefix = (
      enclosing_vs_name + "/" if enclosing_vs_name else enclosing_vs_name)

  with tf.name_scope(name_scope_prefix):
    # Let TensorFlow pick an unused variable scope derived from `name`
    # without opening an auxiliary name scope for it.
    with tf.variable_scope(
        None, default_name=name, auxiliary_name_scope=False) as fresh_vs:
      state_scope = fresh_vs.name + "/"

    # The name scope the graph would hand out for `name` must coincide with
    # the variable scope picked above; optionally mark it as used.
    candidate_name_scope = tf.get_default_graph().unique_name(
        name, mark_name_scope_used) + "/"
    if candidate_name_scope != state_scope:
      raise RuntimeError(
          "variable_scope %s was unused but the corresponding "
          "name_scope was already taken." % state_scope)
  return state_scope
def _prepare_dict_inputs(inputs, tensor_info_map):
"""Converts from inputs into dict inputs.
This handles:
- converting of a single value into a dict with a single key
if the signature only has one expected input.
- converting all input values into tensors compatible with the
expected input tensor (dtype, shape).
- check sparse/non-sparse tensor types.
- check that exactly the needed inputs are given: i.e. no extra
args and no missing args.
Args:
inputs: inputs fed to Module.__call__().
tensor_info_map: A map from string to `tensor_info.ParsedTensorInfo`
describing the signature inputs.
Returns:
A dict of tensors to feed to the signature instantiation.
Raises:
TypeError: If it fails to convert the input values into a dict of tensors
to feed to the signature instantiation.
"""
if inputs is None:
dict_inputs = {}
elif isinstance(inputs, dict):
dict_inputs = inputs
elif len(tensor_info_map) == 1:
dict_inputs = {list(tensor_info_map.keys())[0]: inputs}
elif not tensor_info_map:
raise TypeError("Signature expects no inputs.")
else:
raise TypeError("Signature expects multiple inputs. Use a dict.")
# Finally convert a dict of values into a dict of tensors.
return tensor_info.make_compatible_dict(dict_inputs, tensor_info_map)
def _prepare_outputs(dict_outputs, as_dict):
"""Converts from dict outputs into the return value of Module.__call__().
Args:
dict_outputs: A dict output from applying a signature.
as_dict: A boolean indicating whether to return the outputs of the Module
as a dict or return the output named 'default.
Returns:
A tensor with the output named 'default' or a dict of output tensors if
`as_dict=True`.
Raises:
TypeError: If as_dict is False and there is no output named 'default'.
"""
if as_dict:
return dict_outputs
if "default" in dict_outputs:
return dict_outputs["default"]
else:
raise TypeError("There is no output named 'default'. Use as_dict=True.")