TensorFlow Transformer Part-3 #10
Changes from 28 commits
@@ -0,0 +1,311 @@
# Copyright 2017 Databricks, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from __future__ import absolute_import, division, print_function

import tensorflow as tf
from tensorflow.core.protobuf import meta_graph_pb2  # pylint: disable=no-name-in-module

import sparkdl.graph.utils as tfx

__all__ = ["TFInputGraph"]

# pylint: disable=invalid-name,wrong-spelling-in-comment,wrong-spelling-in-docstring

class TFInputGraph(object):
    """
    An opaque object containing a TensorFlow graph.
    This object can be serialized.

    .. note:: We recommend constructing this object using one of the class constructor methods.

              - :py:meth:`fromGraph`
              - :py:meth:`fromGraphDef`
              - :py:meth:`fromCheckpoint`
              - :py:meth:`fromCheckpointWithSignature`
              - :py:meth:`fromSavedModel`
              - :py:meth:`fromSavedModelWithSignature`


    When the graph contains serving signatures in which a set of well-known names are associated
    with their corresponding raw tensor names in the graph, we extract and store them here.
    For example, the TensorFlow saved model may contain the following structure,
    so that end users can retrieve the input tensor via `well_known_input_sig` and
    the output tensor via `well_known_output_sig` without knowing the actual tensor names a priori.

    .. code-block:: python

        sigdef: {'well_known_prediction_signature':
        inputs { key: "well_known_input_sig"
          value {
            name: "tnsrIn:0"
            dtype: DT_DOUBLE
            tensor_shape { dim { size: -1 } dim { size: 17 } }
            }
          }
        outputs { key: "well_known_output_sig"
          value {
            name: "tnsrOut:0"
            dtype: DT_DOUBLE
            tensor_shape { dim { size: -1 } }
            }
        }}


    In this case, the class will internally store the mapping from signature names to tensor names.

    .. code-block:: python

        {'well_known_input_sig': 'tnsrIn:0'}
        {'well_known_output_sig': 'tnsrOut:0'}


    :param graph_def: :py:obj:`tf.GraphDef`, a serializable object containing the topology and
                      computation units of the TensorFlow graph. The graph object is prepared for
                      inference, i.e. the variables are converted to constants and operations like
                      BatchNormalization_ are converted to be independent of input batch.

    .. _BatchNormalization: https://www.tensorflow.org/api_docs/python/tf/layers/batch_normalization

    :param input_tensor_name_from_signature: dict, signature key names mapped to tensor names.
                                             Please see the example above.
    :param output_tensor_name_from_signature: dict, signature key names mapped to tensor names.
                                              Please see the example above.
    """

    def __init__(self, graph_def, input_tensor_name_from_signature,

Review comment: I wish we had better names for these maps... naming is so hard. sig_key_to_input_tensor_names, sig_key_to_output_tensor_names? The problem is, "_from_" generally doesn't hint that it's a map.
Reply: I agree. Naming things is difficult, especially in this case.
Reply: Yeah, I think those are actually good, since it conveys the meaning that it's mapping from some signature thing to tensor, and really we don't need the user to know exactly what the keys are -- users aren't expected to use these directly.
Reply: As long as the conversion methods (in part 4) make it clear what their inputs are.
Reply: Since this is an internal variable that users should not access directly, let's keep it as it is.

                 output_tensor_name_from_signature):
        self.graph_def = graph_def
        self.input_tensor_name_from_signature = input_tensor_name_from_signature
        self.output_tensor_name_from_signature = output_tensor_name_from_signature

    @classmethod
    def fromGraph(cls, graph, sess, feed_names, fetch_names):
        """
        Construct a TFInputGraph from an in-memory `tf.Graph` object.
        The graph might contain variables that are maintained in the provided session.
        Thus we need an active session in which the graph's variables are initialized or
        restored. We do not close the session. As a result, this constructor can be used
        inside a standard TensorFlow session context.

        .. code-block:: python

            with tf.Session() as sess:
                graph = import_my_tensorflow_graph(...)
                input = TFInputGraph.fromGraph(graph, sess, ...)

        :param graph: a :py:class:`tf.Graph` object containing the topology and computation units of
                      the TensorFlow graph.
        :param feed_names: list, names of the input tensors.
        :param fetch_names: list, names of the output tensors.
        """
        return _build_with_feeds_fetches(sess=sess, graph=graph, feed_names=feed_names,
                                         fetch_names=fetch_names)

    @classmethod
    def fromGraphDef(cls, graph_def, feed_names, fetch_names):
        """
        Construct a TFInputGraph from a tf.GraphDef object.

        :param graph_def: :py:class:`tf.GraphDef`, a serializable object containing the topology and
                           computation units of the TensorFlow graph.
        :param feed_names: list, names of the input tensors.
        :param fetch_names: list, names of the output tensors.
        """
        assert isinstance(graph_def, tf.GraphDef), \
            ('expect tf.GraphDef type but got', type(graph_def))

        graph = tf.Graph()
        with tf.Session(graph=graph) as sess:
            tf.import_graph_def(graph_def, name='')
            return _build_with_feeds_fetches(sess=sess, graph=graph, feed_names=feed_names,
                                             fetch_names=fetch_names)

    @classmethod
    def fromCheckpoint(cls, checkpoint_dir, feed_names, fetch_names):
        """
        Construct a TFInputGraph object from a checkpoint, ignore the embedded
        signature_def, if there is any.

        :param checkpoint_dir: str, name of the directory containing the TensorFlow graph
                               training checkpoint.
        :param feed_names: list, names of the input tensors.
        :param fetch_names: list, names of the output tensors.
        """
        return _from_checkpoint_impl(checkpoint_dir, signature_def_key=None, feed_names=feed_names,
                                     fetch_names=fetch_names)

    @classmethod
    def fromCheckpointWithSignature(cls, checkpoint_dir, signature_def_key):
        """
        Construct a TFInputGraph object from a checkpoint, using the embedded
        signature_def. Throw an error if we cannot find an entry with the `signature_def_key`
        inside the `signature_def`.

        :param checkpoint_dir: str, name of the directory containing the TensorFlow graph
                               training checkpoint.
        :param signature_def_key: str, key (name) of the signature_def to use. It should be in
                                  the list of `signature_def` structures saved with the checkpoint.
        """
        assert signature_def_key is not None

Review comment: Why are you checking for some parameters but not some others? This is Python, there is only so much you can do.
Reply: We want to be vocal about
Reply: Ok.

        return _from_checkpoint_impl(checkpoint_dir, signature_def_key, feed_names=None,
                                     fetch_names=None)

    @classmethod
    def fromSavedModel(cls, saved_model_dir, tag_set, feed_names, fetch_names):
        """
        Construct a TFInputGraph object from a saved model (`tf.SavedModel`) directory.
        Ignore the embedded signature_def, if there is any.

        :param saved_model_dir: str, name of the directory containing the TensorFlow
                                `SavedModel`.
        :param tag_set: str, name of the graph stored in this meta_graph of the saved model
                        that we are interested in using.
        :param feed_names: list, names of the input tensors.
        :param fetch_names: list, names of the output tensors.
        """
        return _from_saved_model_impl(saved_model_dir, tag_set, signature_def_key=None,
                                      feed_names=feed_names, fetch_names=fetch_names)

    @classmethod
    def fromSavedModelWithSignature(cls, saved_model_dir, tag_set, signature_def_key):
        """
        Construct a TFInputGraph object from a saved model (`tf.SavedModel`) directory,
        using the embedded signature_def. Throw an error if we cannot find an entry with
        the `signature_def_key` inside the `signature_def`.

        :param saved_model_dir: str, name of the directory containing the TensorFlow
                                `SavedModel`.
        :param tag_set: str, name of the graph stored in this meta_graph of the saved model
                        that we are interested in using.
        :param signature_def_key: str, key (name) of the signature_def to use. It should be in
                                  the list of `signature_def` structures saved with the
                                  TensorFlow `SavedModel`.
        """
        assert signature_def_key is not None
        return _from_saved_model_impl(saved_model_dir, tag_set, signature_def_key=signature_def_key,
                                      feed_names=None, fetch_names=None)


def _from_checkpoint_impl(checkpoint_dir, signature_def_key, feed_names, fetch_names):
    """

Review comment: Honestly, given the docs above we don't really need the docs for this function and
Reply: Methods whose name starts with '_' do not appear in the docs.
Reply: Ah, good to know, thanks for the clarification.

    Construct a TFInputGraph from a model checkpoint.
    Notice that one should either provide the `signature_def_key` or provide both

Review comment: The fact that you need to provide an alternative is a very strong indication that you should break it into 2 functions, and that your logic is more complicated than it ought to be. From the sig key, you can extract the feed and fetch names / mapping and then call the same code path. If we had unlimited time and effort, I would ask you to do this change, but the current logic is understandable enough as it stands to debug it.

    `feed_names` and `fetch_names`. Please set the unprovided values to None.

    :param signature_def_key: str, name of the mapping contained inside the `signature_def`

Review comment: Same here.

                              from which we retrieve the signature key to tensor names mapping.
    :param feed_names: list, names of the input tensors.
    :param fetch_names: list, names of the output tensors.
    """
    assert (feed_names is None) == (fetch_names is None), \
        'feed_names and fetch_names, if provided must be both non-None.'
    assert (feed_names is None) != (signature_def_key is None), \
        'must either provide feed_names or signature_def_key'

    graph = tf.Graph()
    with tf.Session(graph=graph) as sess:
        # Load checkpoint and import the graph
        ckpt_path = tf.train.latest_checkpoint(checkpoint_dir)

        # NOTE(phi-dbq): we must manually load meta_graph_def to get the signature_def
        #                the current `import_graph_def` function seems to ignore
        #                any signature_def fields in a checkpoint's meta_graph_def.
        meta_graph_def = meta_graph_pb2.MetaGraphDef()
        with open("{}.meta".format(ckpt_path), 'rb') as fin:
            meta_graph_def.ParseFromString(fin.read())

        saver = tf.train.import_meta_graph(meta_graph_def, clear_devices=True)
        saver.restore(sess, ckpt_path)

        if signature_def_key is not None:
            sig_def = meta_graph_def.signature_def[signature_def_key]
            return _build_with_sig_def(sess=sess, graph=graph, sig_def=sig_def)
        else:
            return _build_with_feeds_fetches(sess=sess, graph=graph, feed_names=feed_names,
                                             fetch_names=fetch_names)

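The review suggestion in the docstring thread above (resolve the signature key into feed and fetch names first, then share one code path) could be sketched roughly as follows. This is only an illustration under stated assumptions: `FakeTensorInfo`, `FakeSignatureDef`, `names_from_signature`, `build_from_names`, and `build_from_signature` are hypothetical names that appear nowhere in the PR, a plain dict stands in for the `signature_def` map, and the graph-freezing step is replaced by a tuple so the example runs without TensorFlow.

```python
from collections import namedtuple

# Hypothetical stand-ins for the protobuf types (not part of sparkdl or TensorFlow).
FakeTensorInfo = namedtuple("FakeTensorInfo", ["name"])
FakeSignatureDef = namedtuple("FakeSignatureDef", ["inputs", "outputs"])


def names_from_signature(sig_def):
    """Return (feed_mapping, fetch_mapping): signature keys mapped to tensor names."""
    feed_mapping = {key: info.name for key, info in sig_def.inputs.items()}
    fetch_mapping = {key: info.name for key, info in sig_def.outputs.items()}
    return feed_mapping, fetch_mapping


def build_from_names(feed_names, fetch_names, feed_mapping=None, fetch_mapping=None):
    """Single shared code path; a real version would strip and freeze the graph here."""
    if feed_names is None or fetch_names is None:
        raise ValueError("feed_names and fetch_names are both required")
    return (sorted(feed_names), sorted(fetch_names), feed_mapping, fetch_mapping)


def build_from_signature(signature_defs, signature_def_key):
    """Resolve the signature key to tensor names, then reuse build_from_names."""
    if signature_def_key not in signature_defs:
        raise KeyError("no signature_def named {!r}".format(signature_def_key))
    feed_mapping, fetch_mapping = names_from_signature(signature_defs[signature_def_key])
    return build_from_names(list(feed_mapping.values()), list(fetch_mapping.values()),
                            feed_mapping, fetch_mapping)


signature_defs = {
    "well_known_prediction_signature": FakeSignatureDef(
        inputs={"well_known_input_sig": FakeTensorInfo("tnsrIn:0")},
        outputs={"well_known_output_sig": FakeTensorInfo("tnsrOut:0")},
    )
}
result = build_from_signature(signature_defs, "well_known_prediction_signature")
```

With this shape, neither function needs the "exactly one of these arguments must be None" assertions, because each entry point takes only the arguments it uses.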
def _from_saved_model_impl(saved_model_dir, tag_set, signature_def_key, feed_names, fetch_names):
    """
    Construct a TFInputGraph from a SavedModel.
    Notice that one should either provide the `signature_def_key` or provide both
    `feed_names` and `fetch_names`. Please set the unprovided values to None.

    :param signature_def_key: str, name of the mapping contained inside the `signature_def`
                              from which we retrieve the signature key to tensor names mapping.
    :param feed_names: list, names of the input tensors.
    :param fetch_names: list, names of the output tensors.
    """
    assert (feed_names is None) == (fetch_names is None), \
        'feed_names and fetch_names, if provided must appear together'
    assert (feed_names is None) != (signature_def_key is None), \
        'must either provide feed_names or signature_def_key'

    graph = tf.Graph()
    with tf.Session(graph=graph) as sess:
        tag_sets = tag_set.split(',')
        meta_graph_def = tf.saved_model.loader.load(sess, tag_sets, saved_model_dir)

        if signature_def_key is not None:
            sig_def = tf.contrib.saved_model.get_signature_def_by_key(meta_graph_def,
                                                                      signature_def_key)
            return _build_with_sig_def(sess=sess, graph=graph, sig_def=sig_def)
        else:
            return _build_with_feeds_fetches(sess=sess, graph=graph, feed_names=feed_names,
                                             fetch_names=fetch_names)


def _build_with_sig_def(sess, graph, sig_def):
    # pylint: disable=protected-access
    assert sig_def, 'signature_def must not be None'

    with sess.as_default(), graph.as_default():
        feed_mapping = {}
        feed_names = []
        for sigdef_key, tnsr_info in sig_def.inputs.items():
            tnsr_name = tnsr_info.name
            feed_mapping[sigdef_key] = tnsr_name
            feed_names.append(tnsr_name)

        fetch_mapping = {}

Review comment: We should add some tests that specifically check that these mappings are created correctly.

        fetch_names = []
        for sigdef_key, tnsr_info in sig_def.outputs.items():
            tnsr_name = tnsr_info.name
            fetch_mapping[sigdef_key] = tnsr_name
            fetch_names.append(tnsr_name)

        for tnsr_name in feed_names:
            assert tfx.get_op(tnsr_name, graph), \
                'requested tensor {} but found none in graph {}'.format(tnsr_name, graph)
        fetches = [tfx.get_tensor(tnsr_name, graph) for tnsr_name in fetch_names]
        graph_def = tfx.strip_and_freeze_until(fetches, graph, sess)

    return TFInputGraph(graph_def=graph_def, input_tensor_name_from_signature=feed_mapping,
                        output_tensor_name_from_signature=fetch_mapping)


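A test along the lines the reviewer asks for in `_build_with_sig_def` might look like the sketch below. It is a minimal illustration, not the project's test suite: `TensorInfoStub`, `SigDefStub`, and `collect_mappings` are hypothetical names, and `collect_mappings` only mirrors the two mapping loops so the check can run without a TensorFlow session or graph.

```python
import unittest
from collections import namedtuple

# Hypothetical stand-ins for a signature_def and its TensorInfo entries.
TensorInfoStub = namedtuple("TensorInfoStub", ["name"])
SigDefStub = namedtuple("SigDefStub", ["inputs", "outputs"])


def collect_mappings(sig_def):
    """Mirror the two loops in _build_with_sig_def without touching a graph."""
    feed_mapping, feed_names = {}, []
    for sigdef_key, tnsr_info in sig_def.inputs.items():
        feed_mapping[sigdef_key] = tnsr_info.name
        feed_names.append(tnsr_info.name)
    fetch_mapping, fetch_names = {}, []
    for sigdef_key, tnsr_info in sig_def.outputs.items():
        fetch_mapping[sigdef_key] = tnsr_info.name
        fetch_names.append(tnsr_info.name)
    return feed_mapping, feed_names, fetch_mapping, fetch_names


class SignatureMappingTest(unittest.TestCase):
    def setUp(self):
        self.sig_def = SigDefStub(
            inputs={"well_known_input_sig": TensorInfoStub("tnsrIn:0")},
            outputs={"well_known_output_sig": TensorInfoStub("tnsrOut:0")},
        )

    def test_mappings_match_signature(self):
        feed_mapping, feed_names, fetch_mapping, fetch_names = collect_mappings(self.sig_def)
        self.assertEqual(feed_mapping, {"well_known_input_sig": "tnsrIn:0"})
        self.assertEqual(fetch_mapping, {"well_known_output_sig": "tnsrOut:0"})
        # The name lists must contain exactly the mapped tensor names.
        self.assertEqual(sorted(feed_names), sorted(feed_mapping.values()))
        self.assertEqual(sorted(fetch_names), sorted(fetch_mapping.values()))


suite = unittest.defaultTestLoader.loadTestsFromTestCase(SignatureMappingTest)
outcome = unittest.TextTestRunner(verbosity=0).run(suite)
```

A test against the real class would instead build a small graph, save it with a signature, load it through `fromSavedModelWithSignature`, and compare the two stored dicts against the signature used at save time.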
def _build_with_feeds_fetches(sess, graph, feed_names, fetch_names):
    assert feed_names is not None, "must provide feed_names"
    assert fetch_names is not None, "must provide fetch names"

    with sess.as_default(), graph.as_default():
        for tnsr_name in feed_names:
            assert tfx.get_op(tnsr_name, graph), \
                'requested tensor {} but found none in graph {}'.format(tnsr_name, graph)
        fetches = [tfx.get_tensor(tnsr_name, graph) for tnsr_name in fetch_names]
        graph_def = tfx.strip_and_freeze_until(fetches, graph, sess)

    return TFInputGraph(graph_def=graph_def, input_tensor_name_from_signature=None,
                        output_tensor_name_from_signature=None)
@@ -167,7 +167,14 @@ def _check_is_tensor_name(_maybe_tnsr_name):
        raise TypeError(err_msg.format(type(_maybe_tnsr_name)))

    # The check is taken from TensorFlow's NodeDef protocol buffer.
    #   https://github.com/tensorflow/tensorflow/blob/r1.3/tensorflow/core/framework/node_def.proto#L21-L25
    # Each input is "node:src_output" with "node" being a string name and

Review comment: Yes, good point. In practice, most nodes have only one output, so I have not been too concerned about multiple outputs.
Reply: Ah, that was copied-and-pasted from comments in TensorFlow node_def.

    # "src_output" indicating which output tensor to use from "node". If
    # "src_output" is 0 the ":0" suffix can be omitted. Regular inputs
    # may optionally be followed by control inputs that have the format
    # "^node".
    # Reference:
    # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/node_def.proto
    # https://stackoverflow.com/questions/36150834/how-does-tensorflow-name-tensors
    try:
        _, src_idx = _maybe_tnsr_name.split(":")
        _ = int(src_idx)

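The naming convention quoted in the comment block above ("node:src_output", an optional ":0" suffix, and "^node" for control inputs) can be illustrated with a small parser. This is a sketch only: `split_node_input` is a hypothetical helper, not the function in this diff, and it is more lenient than the visible `split(":")` check because it also accepts a name with the ":0" suffix omitted.

```python
def split_node_input(node_input):
    """Split a NodeDef-style input string into (node, src_output, is_control).

    Hypothetical helper for illustration; it is not part of sparkdl.
    """
    if node_input.startswith("^"):
        # Control input: "^node" names a node, not one of its output tensors.
        return node_input[1:], None, True
    node, sep, src_output = node_input.partition(":")
    if not sep:
        # The ":0" suffix may be omitted when src_output is 0.
        return node, 0, False
    # int() raises ValueError when the suffix is not an integer.
    return node, int(src_output), False


parsed = [split_node_input(s) for s in ("tnsrIn:0", "tnsrIn", "relu:1", "^init")]
```

Whether a bare node name should count as a valid tensor name is a design choice; the check in the diff appears to require the explicit index.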
Review comment: I think this whole file has much more complexity than it really needs, given the aim of its content.
We do not need a builder pattern, based on looking at the other PRs. You also happen to expose these builder classes implicitly because you return them, but this is not transparent from `__all__`. This is a limitation of Python that does not help for software engineering.
Once we drop that, all this code really needs to expose is one immutable data structure and six functions that are stateless. As we discussed, there is no need to create additional classes, static methods or other Python features. Everything here can be done with regular functions. Here is what I think it should look like:
You should also take a look at named tuples for building them, but there you should absolutely not need any of the more advanced features for them (for the time being).
https://docs.python.org/2/library/collections.html#collections.namedtuple
This is a significant rewrite, but it is going to dramatically improve the readability and the correctness in the face of future changes.
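The reviewer's own snippet is not part of this page capture. As a rough idea of the shape being described (one immutable record plus stateless module-level functions), a sketch might look like the following. Everything here is an assumption for illustration: `InputGraph`, `from_names`, and `from_mappings` are hypothetical names, and a bytes literal stands in for a real `tf.GraphDef`.

```python
from collections import namedtuple

# Hypothetical immutable record with the same three fields as TFInputGraph.
InputGraph = namedtuple(
    "InputGraph",
    ["graph_def", "input_tensor_name_from_signature", "output_tensor_name_from_signature"],
)


def from_names(graph_def, feed_names, fetch_names):
    """Stateless constructor for the case without signature mappings.

    A real version would use the names to strip and freeze the graph.
    """
    if feed_names is None or fetch_names is None:
        raise ValueError("feed_names and fetch_names are both required")
    return InputGraph(graph_def, None, None)


def from_mappings(graph_def, feed_mapping, fetch_mapping):
    """Stateless constructor that stores copies of the signature mappings."""
    return InputGraph(graph_def, dict(feed_mapping), dict(fetch_mapping))


record = from_mappings(b"serialized-graph",
                       {"well_known_input_sig": "tnsrIn:0"},
                       {"well_known_output_sig": "tnsrOut:0"})

# Fields of a namedtuple instance cannot be reassigned.
try:
    record.graph_def = b"other"
    mutated = True
except AttributeError:
    mutated = False
```

Note that the immutability is shallow: the fields cannot be rebound, but a dict stored in a field can still be modified in place.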
Reply: Using `classmethod` to create class instances is a well-accepted approach. It is commonly used in the CPython implementation of datetime.
Classes created from `collections.namedtuple` are self-documenting, resembling Scala's case classes to some degree.
You can see that they provide a lot more functions than we actually need.
In addition, since the class inherits from `tuple`, the following should be noted.
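The list that followed this reply is not included in the capture. Independently of what the author listed, the two general points (classmethods as alternative constructors, and tuple behavior inherited by namedtuple classes) can be illustrated with standard-library behavior only. `Pair` and `Other` below are hypothetical types made up for the example.

```python
from collections import namedtuple
from datetime import date

# Alternative constructor implemented as a classmethod in the standard library.
first_day = date.fromordinal(1)  # ordinal 1 is January 1 of year 1

# Hypothetical record types for illustration.
Pair = namedtuple("Pair", ["graph_def", "mapping"])
Other = namedtuple("Other", ["a", "b"])

pair = Pair("g", {"k": "t"})

is_tuple = isinstance(pair, tuple)                      # namedtuple classes subclass tuple
equals_plain_tuple = pair == ("g", {"k": "t"})          # compares by position, not by type
equals_other_type = pair == Other("g", {"k": "t"})      # unrelated namedtuple, same contents
graph_def, mapping = pair                               # supports unpacking and indexing

# Hashing falls back to tuple hashing, so a dict-valued field makes it unhashable.
try:
    hash(pair)
    hashable = True
except TypeError:
    hashable = False
```

These inherited behaviors are the usual reason to weigh a namedtuple against a small regular class when positional equality or iteration would be surprising to callers.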