Merged
Commits (29)
f4d938c  intro: TFInputGraph (phi-dbq, Sep 19, 2017)
cd3aa8d  tests (phi-dbq, Sep 19, 2017)
2e8f7a1  Merge branch 'tf-transformer-part1' into tf-transformer-part3 (phi-dbq, Sep 23, 2017)
e09027f  Merge branch 'tf-transformer-part1' into tf-transformer-part3 (phi-dbq, Sep 23, 2017)
40caace  and so there is no helper classes (phi-dbq, Sep 23, 2017)
e963d11  and into more pieces (phi-dbq, Sep 23, 2017)
ce60629  class & docs (phi-dbq, Sep 25, 2017)
e0cf2ff  update docs (phi-dbq, Sep 25, 2017)
202e7ea  refactoring tfx API (phi-dbq, Sep 29, 2017)
faf8cdd  Merge branch 'tf-transformer-part2' into tf-transformer-part3 (phi-dbq, Sep 29, 2017)
cf72beb  Merge branch 'tf-transformer-part2' into tf-transformer-part3 (phi-dbq, Sep 29, 2017)
20e2dbc  update tfx utils usage (phi-dbq, Sep 29, 2017)
20a5346  one way to build these tests (phi-dbq, Sep 29, 2017)
cf64708  tests refactored (phi-dbq, Sep 30, 2017)
c3b3a86  test cases in a single class (phi-dbq, Sep 30, 2017)
e47060f  shuffle things around (phi-dbq, Sep 30, 2017)
4e8f4e3  docs mostly (phi-dbq, Sep 30, 2017)
eaa5fa0  yapf'd (phi-dbq, Sep 30, 2017)
43d6583  consolidate tempdir creation (phi-dbq, Oct 2, 2017)
ee3acf1  Merge branch 'tf-transformer-part2' into tf-transformer-part3 (phi-dbq, Oct 3, 2017)
7f16396  Merge branch 'tf-transformer-part1' into tf-transformer-part3 (phi-dbq, Oct 3, 2017)
f5107ad  (wip) PR comments (phi-dbq, Oct 3, 2017)
ac681b0  more tests (phi-dbq, Oct 3, 2017)
4d173c5  change test generator module name (phi-dbq, Oct 3, 2017)
707697d  Merge branch 'tf-transformer-part2' into tf-transformer-part3 (phi-dbq, Oct 3, 2017)
fe719b2  Merge branch 'tf-transformer-part2' into tf-transformer-part3 (phi-dbq, Oct 5, 2017)
a39b6d3  TFTransformer Part-3 Test Refactor (#14) (thunterdb, Nov 18, 2017)
decdc8f  PR comments (phi-dbq, Nov 22, 2017)
479a9aa  TensorFlow Transformer Part-4 (#11) (phi-dbq, Nov 22, 2017)
2 changes: 2 additions & 0 deletions python/docs/sparkdl.rst
@@ -6,8 +6,10 @@ Subpackages

.. toctree::

sparkdl.estimators
sparkdl.graph
sparkdl.image
sparkdl.param
sparkdl.transformers
sparkdl.udf
sparkdl.utils
311 changes: 311 additions & 0 deletions python/sparkdl/graph/input.py
@@ -0,0 +1,311 @@
# Copyright 2017 Databricks, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from __future__ import absolute_import, division, print_function

import tensorflow as tf
from tensorflow.core.protobuf import meta_graph_pb2 # pylint: disable=no-name-in-module

import sparkdl.graph.utils as tfx

__all__ = ["TFInputGraph"]

Review comment:

I think this whole file has much more complexity than it really needs, given the aim of its content.

We do not need a builder pattern, based on looking at the other PRs. You also happen to expose these builder classes implicitly because you return them, but this is not transparent from __all__. This is a limitation of python that does not help for software engineering.

Once we drop that, all this code really needs to expose is one immutable data structure and six functions that are stateless. As we discussed, there is no need to create additional classes, static methods or other python features. Everything here can be done with regular functions. Here is what I think it should look like:

TFInputGraph = namedtuple("TFInputGraph", ["graph_def", "inputs", "outputs"])
"""
A frozen representation of a tensorflow graph, and some extra information to map spark columns to
the graph's input and output tensors. Users should not have to peek into its content, or make
any assumption about the content.

graph_def: a tensorflow GraphDef object
inputs: a dictionary of {string: string} where the key is XXX and the value is XXX
outputs: XXX
"""

def fromGraph(graph, sess, feed_names, fetch_names):
    """
    Builds an internal representation of a tensorflow graph, based on a tf.Graph object
    :param graph: XXX
    :param sess: XXX
    :param feed_names: XXX <- put the details here!
    :param fetch_names: XXX <- put the details here!
    :return: a TFInputGraph object
    """
    raise

... other public functions with full documentation

... all the other private functions. They can be lighter on documentation, but they should not require any sophisticated python features. No extra classes are required.

You should also take a look at named tuples for building them, but there you should absolutely not need any of the more advanced features for them (for the time being).
https://docs.python.org/2/library/collections.html#collections.namedtuple

This is a significant rewrite, but it is going to dramatically improve the readability and the correctness in the face of future changes.

Author reply (phi-dbq):

Using classmethod to create class instances is a well-accepted approach. It is commonly used in the CPython implementation of datetime.

Classes created from collections.namedtuple are self-documenting, resembling Scala's case classes to some degree.

>>> ABC = namedtuple('ABC', 'a, b, c', verbose=True)
class ABC(tuple):
    'ABC(a, b, c)'

    __slots__ = ()

    _fields = ('a', 'b', 'c')

    def __new__(_cls, a, b, c):
        'Create new instance of ABC(a, b, c)'
        return _tuple.__new__(_cls, (a, b, c))

    @classmethod
    def _make(cls, iterable, new=tuple.__new__, len=len):
        'Make a new ABC object from a sequence or iterable'
        result = new(cls, iterable)
        if len(result) != 3:
            raise TypeError('Expected 3 arguments, got %d' % len(result))
        return result

    def __repr__(self):
        'Return a nicely formatted representation string'
        return 'ABC(a=%r, b=%r, c=%r)' % self

    def _asdict(self):
        'Return a new OrderedDict which maps field names to their values'
        return OrderedDict(zip(self._fields, self))

    def _replace(_self, **kwds):
        'Return a new ABC object replacing specified fields with new values'
        result = _self._make(map(kwds.pop, ('a', 'b', 'c'), _self))
        if kwds:
            raise ValueError('Got unexpected field names: %r' % kwds.keys())
        return result

    def __getnewargs__(self):
        'Return self as a plain tuple.  Used by copy and pickle.'
        return tuple(self)

    __dict__ = _property(_asdict)

    def __getstate__(self):
        'Exclude the OrderedDict from pickling'
        pass

    a = _property(_itemgetter(0), doc='Alias for field number 0')

    b = _property(_itemgetter(1), doc='Alias for field number 1')

    c = _property(_itemgetter(2), doc='Alias for field number 2')

You can see that they provide a lot more functions than we actually need.
In addition, since the class inherits from tuple, the following should be noted.

>>> ABC = namedtuple('ABC', 'a, b, c'); isinstance(ABC(a=1, b=2, c=3), tuple)
True
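
Note also that equality is inherited from tuple, so instances compare equal to plain tuples and even to unrelated namedtuple types holding the same values:

>>> ABC(a=1, b=2, c=3) == (1, 2, 3)
True
>>> XYZ = namedtuple('XYZ', 'x, y, z'); ABC(a=1, b=2, c=3) == XYZ(x=1, y=2, z=3)
True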


# pylint: disable=invalid-name,wrong-spelling-in-comment,wrong-spelling-in-docstring

class TFInputGraph(object):
"""
An opaque object containing a TensorFlow graph.
This object can be serialized.

.. note:: We recommend constructing this object using one of the class constructor methods.

- :py:meth:`fromGraph`
- :py:meth:`fromGraphDef`
- :py:meth:`fromCheckpoint`
- :py:meth:`fromCheckpointWithSignature`
- :py:meth:`fromSavedModel`
- :py:meth:`fromSavedModelWithSignature`


When the graph contains serving signatures in which a set of well-known names are associated
with their corresponding raw tensor names in the graph, we extract and store them here.
For example, the TensorFlow saved model may contain the following structure,
so that end users can retrieve the input tensor via `well_known_input_sig` and
the output tensor via `well_known_output_sig` without knowing the actual tensor names a priori.

.. code-block:: python

sigdef: {'well_known_prediction_signature':
  inputs { key: "well_known_input_sig"
    value {
      name: "tnsrIn:0"
      dtype: DT_DOUBLE
      tensor_shape { dim { size: -1 } dim { size: 17 } }
    }
  }
  outputs { key: "well_known_output_sig"
    value {
      name: "tnsrOut:0"
      dtype: DT_DOUBLE
      tensor_shape { dim { size: -1 } }
    }
  }}


In this case, the class will internally store the mapping from signature names to tensor names.

.. code-block:: python

{'well_known_input_sig': 'tnsrIn:0'}
{'well_known_output_sig': 'tnsrOut:0'}


:param graph_def: :py:obj:`tf.GraphDef`, a serializable object containing the topology and
computation units of the TensorFlow graph. The graph object is prepared for
inference, i.e. the variables are converted to constants and operations like
BatchNormalization_ are converted to be independent of input batch.

.. _BatchNormalization: https://www.tensorflow.org/api_docs/python/tf/layers/batch_normalization

:param input_tensor_name_from_signature: dict, signature key names mapped to tensor names.
Please see the example above.
:param output_tensor_name_from_signature: dict, signature key names mapped to tensor names.
Please see the example above.
"""


def __init__(self, graph_def, input_tensor_name_from_signature,
Review comment:

i wish we had better names for these maps... naming is so hard. sig_key_to_input_tensor_names, sig_key_to_output_tensor_names ? the problem is, "_from_" generally doesn't hint that it's a map.

Author reply (phi-dbq):

I agree. Naming things is difficult, especially in this case.
TF does not seem to have a proper name for "well_known_input_sig" yet.
And signature_def_key refers to "well_known_prediction_signature" in our example.
Maybe input_signature_to_tensor_name and output_signature_to_tensor_name?

Review comment:

Yeah i think those are actually good, since it conveys the meaning that it's mapping from some signature thing to tensor, and really we don't need the user to know exactly what the keys are -- users aren't expected to use these directly.

Review comment:

As long as the conversion methods (in part 4) make it clear what their inputs are.

Review comment:

Since this is an internal variable that users should not access directly, let's keep it as it is.

output_tensor_name_from_signature):
self.graph_def = graph_def
self.input_tensor_name_from_signature = input_tensor_name_from_signature
self.output_tensor_name_from_signature = output_tensor_name_from_signature

@classmethod
def fromGraph(cls, graph, sess, feed_names, fetch_names):
"""
Construct a TFInputGraph from an in-memory `tf.Graph` object.
The graph might contain variables that are maintained in the provided session.
Thus we need an active session in which the graph's variables are initialized or
restored. We do not close the session. As a result, this constructor can be used
inside a standard TensorFlow session context.

.. code-block:: python

with tf.Session() as sess:
graph = import_my_tensorflow_graph(...)
input = TFInputGraph.fromGraph(graph, sess, ...)

:param graph: a :py:class:`tf.Graph` object containing the topology and computation units of
the TensorFlow graph.
:param sess: a :py:class:`tf.Session` object in which the graph's variables have been
initialized or restored.
:param feed_names: list, names of the input tensors.
:param fetch_names: list, names of the output tensors.
"""
return _build_with_feeds_fetches(sess=sess, graph=graph, feed_names=feed_names,
fetch_names=fetch_names)

@classmethod
def fromGraphDef(cls, graph_def, feed_names, fetch_names):
"""
Construct a TFInputGraph from a tf.GraphDef object.

:param graph_def: :py:class:`tf.GraphDef`, a serializable object containing the topology and
computation units of the TensorFlow graph.
:param feed_names: list, names of the input tensors.
:param fetch_names: list, names of the output tensors.
"""
assert isinstance(graph_def, tf.GraphDef), \
('expect tf.GraphDef type but got', type(graph_def))

graph = tf.Graph()
with tf.Session(graph=graph) as sess:
tf.import_graph_def(graph_def, name='')
return _build_with_feeds_fetches(sess=sess, graph=graph, feed_names=feed_names,
fetch_names=fetch_names)
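
For illustration, a minimal usage sketch (the frozen-graph file and tensor names here are hypothetical):

import tensorflow as tf
from sparkdl.graph.input import TFInputGraph

graph_def = tf.GraphDef()
with open('/tmp/my_frozen_graph.pb', 'rb') as fin:  # hypothetical frozen graph
    graph_def.ParseFromString(fin.read())
gin = TFInputGraph.fromGraphDef(graph_def, feed_names=['input:0'],
                                fetch_names=['output:0'])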

@classmethod
def fromCheckpoint(cls, checkpoint_dir, feed_names, fetch_names):
"""
Construct a TFInputGraph object from a checkpoint, ignoring the embedded
signature_def, if there is any.

:param checkpoint_dir: str, name of the directory containing the TensorFlow graph
training checkpoint.
:param feed_names: list, names of the input tensors.
:param fetch_names: list, names of the output tensors.
"""
return _from_checkpoint_impl(checkpoint_dir, signature_def_key=None, feed_names=feed_names,
fetch_names=fetch_names)
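
A hypothetical usage sketch, assuming a checkpoint directory written by tf.train.Saver and the given tensor names:

gin = TFInputGraph.fromCheckpoint('/tmp/model_ckpt_dir', feed_names=['input:0'],
                                  fetch_names=['output:0'])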

@classmethod
def fromCheckpointWithSignature(cls, checkpoint_dir, signature_def_key):
"""
Construct a TFInputGraph object from a checkpoint, using the embedded
signature_def. Throw an error if we cannot find an entry with the `signature_def_key`
inside the `signature_def`.

:param checkpoint_dir: str, name of the directory containing the TensorFlow graph
training checkpoint.
:param signature_def_key: str, key (name) of the signature_def to use. It should be in
the list of `signature_def` structures saved with the checkpoint.
"""
assert signature_def_key is not None

Review comment:

why are you checking for some parameters but not some others? This is python, there is only so much you can do.

Author reply (phi-dbq):

We want to be vocal that signature_def_key must not be empty.

Review comment:

Ok.

return _from_checkpoint_impl(checkpoint_dir, signature_def_key, feed_names=None,
fetch_names=None)

@classmethod
def fromSavedModel(cls, saved_model_dir, tag_set, feed_names, fetch_names):
"""
Construct a TFInputGraph object from a saved model (`tf.SavedModel`) directory.
Ignore the embedded signature_def, if there is any.

:param saved_model_dir: str, name of the directory containing the TensorFlow SavedModel.
:param tag_set: str, comma-separated tags identifying the meta_graph of the saved model
that we are interested in using.
:param feed_names: list, names of the input tensors.
:param fetch_names: list, names of the output tensors.
"""
return _from_saved_model_impl(saved_model_dir, tag_set, signature_def_key=None,
feed_names=feed_names, fetch_names=fetch_names)

@classmethod
def fromSavedModelWithSignature(cls, saved_model_dir, tag_set, signature_def_key):
"""
Construct a TFInputGraph object from a saved model (`tf.SavedModel`) directory,
using the embedded signature_def. Throw an error if we cannot find an entry with
the `signature_def_key` inside the `signature_def`.

:param saved_model_dir: str, name of the directory containing the TensorFlow SavedModel.
:param tag_set: str, comma-separated tags identifying the meta_graph of the saved model
that we are interested in using.
:param signature_def_key: str, key (name) of the signature_def to use. It should be in
the list of `signature_def` structures saved with the
TensorFlow `SavedModel`.
"""
assert signature_def_key is not None
return _from_saved_model_impl(saved_model_dir, tag_set, signature_def_key=signature_def_key,
feed_names=None, fetch_names=None)
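
For illustration, hypothetical usage of the two SavedModel constructors ('serve' and 'serving_default' are TensorFlow's conventional tag and signature names; the directory is made up):

# Using the embedded serving signature:
gin = TFInputGraph.fromSavedModelWithSignature('/tmp/exported_model', tag_set='serve',
                                               signature_def_key='serving_default')
# Or naming the tensors explicitly, ignoring any signature_def:
gin = TFInputGraph.fromSavedModel('/tmp/exported_model', tag_set='serve',
                                  feed_names=['input:0'], fetch_names=['output:0'])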


def _from_checkpoint_impl(checkpoint_dir, signature_def_key, feed_names, fetch_names):
"""
Review comment:

honestly given the docs above we don't really need the docs for this function and _from_saved..._impl below. do these appear in the generated docs? not sure what we do for docs with "private" methods.

Author reply (phi-dbq):

Methods whose name starts with '_' do not appear in the docs.

Review comment:

Ah, good to know, thanks for the clarification.

Construct a TFInputGraph from a model checkpoint.
Notice that one should either provide the `signature_def_key` or provide both

Review comment:

the fact that you need to provide an alternative is a very strong indication that you should break it into 2 functions, and that your logic is more complicated than it ought to be. From the sig key, you can extract the feed and fetch names / mapping and then call the same code path. If we had unlimited time and effort, I would ask you to do this change, but the current logic is understandable enough as it stands to debug it.

`feed_names` and `fetch_names`. Please set the unprovided values to None.

:param signature_def_key: str, name of the mapping contained inside the `signature_def`
Review comment:

same here

from which we retrieve the signature key to tensor names mapping.
:param feed_names: list, names of the input tensors.
:param fetch_names: list, names of the output tensors.
"""
assert (feed_names is None) == (fetch_names is None), \
'feed_names and fetch_names, if provided, must both be non-None.'
assert (feed_names is None) != (signature_def_key is None), \
'must either provide feed_names or signature_def_key'

graph = tf.Graph()
with tf.Session(graph=graph) as sess:
# Load checkpoint and import the graph
ckpt_path = tf.train.latest_checkpoint(checkpoint_dir)

# NOTE(phi-dbq): we must manually load meta_graph_def to get the signature_def
# the current `import_graph_def` function seems to ignore
# any signature_def fields in a checkpoint's meta_graph_def.
meta_graph_def = meta_graph_pb2.MetaGraphDef()
with open("{}.meta".format(ckpt_path), 'rb') as fin:
meta_graph_def.ParseFromString(fin.read())

saver = tf.train.import_meta_graph(meta_graph_def, clear_devices=True)
saver.restore(sess, ckpt_path)

if signature_def_key is not None:
sig_def = meta_graph_def.signature_def[signature_def_key]
return _build_with_sig_def(sess=sess, graph=graph, sig_def=sig_def)
else:
return _build_with_feeds_fetches(sess=sess, graph=graph, feed_names=feed_names,
fetch_names=fetch_names)

def _from_saved_model_impl(saved_model_dir, tag_set, signature_def_key, feed_names, fetch_names):
"""
Construct a TFInputGraph from a SavedModel.
Notice that one should either provide the `signature_def_key` or provide both
`feed_names` and `fetch_names`. Please set the unprovided values to None.

:param signature_def_key: str, name of the mapping contained inside the `signature_def`
from which we retrieve the signature key to tensor names mapping.
:param feed_names: list, names of the input tensors.
:param fetch_names: list, names of the output tensors.
"""
assert (feed_names is None) == (fetch_names is None), \
'feed_names and fetch_names, if provided, must both be non-None.'
assert (feed_names is None) != (signature_def_key is None), \
'must either provide feed_names or signature_def_key'

graph = tf.Graph()
with tf.Session(graph=graph) as sess:
tag_sets = tag_set.split(',')
meta_graph_def = tf.saved_model.loader.load(sess, tag_sets, saved_model_dir)

if signature_def_key is not None:
sig_def = tf.contrib.saved_model.get_signature_def_by_key(meta_graph_def,
signature_def_key)
return _build_with_sig_def(sess=sess, graph=graph, sig_def=sig_def)
else:
return _build_with_feeds_fetches(sess=sess, graph=graph, feed_names=feed_names,
fetch_names=fetch_names)


def _build_with_sig_def(sess, graph, sig_def):
# pylint: disable=protected-access
assert sig_def, 'signature_def must not be None'

with sess.as_default(), graph.as_default():
feed_mapping = {}
feed_names = []
for sigdef_key, tnsr_info in sig_def.inputs.items():
tnsr_name = tnsr_info.name
feed_mapping[sigdef_key] = tnsr_name
feed_names.append(tnsr_name)

fetch_mapping = {}
Review comment:

we should add some tests that specifically check that these mappings are created correctly.

fetch_names = []
for sigdef_key, tnsr_info in sig_def.outputs.items():
tnsr_name = tnsr_info.name
fetch_mapping[sigdef_key] = tnsr_name
fetch_names.append(tnsr_name)

for tnsr_name in feed_names:
assert tfx.get_op(tnsr_name, graph), \
'requested tensor {} but found none in graph {}'.format(tnsr_name, graph)
fetches = [tfx.get_tensor(tnsr_name, graph) for tnsr_name in fetch_names]
graph_def = tfx.strip_and_freeze_until(fetches, graph, sess)

return TFInputGraph(graph_def=graph_def, input_tensor_name_from_signature=feed_mapping,
output_tensor_name_from_signature=fetch_mapping)
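
Per the review suggestion above, a sketch of such a mapping test (hypothetical: it assumes saved_model_dir points to a SavedModel exported with the signature shown in the class docstring):

def test_signature_mappings():
    # saved_model_dir is assumed to contain 'well_known_prediction_signature'
    gin = TFInputGraph.fromSavedModelWithSignature(
        saved_model_dir, tag_set='serve',
        signature_def_key='well_known_prediction_signature')
    assert gin.input_tensor_name_from_signature == {'well_known_input_sig': 'tnsrIn:0'}
    assert gin.output_tensor_name_from_signature == {'well_known_output_sig': 'tnsrOut:0'}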


def _build_with_feeds_fetches(sess, graph, feed_names, fetch_names):
assert feed_names is not None, "must provide feed_names"
assert fetch_names is not None, "must provide fetch_names"

with sess.as_default(), graph.as_default():
for tnsr_name in feed_names:
assert tfx.get_op(tnsr_name, graph), \
'requested tensor {} but found none in graph {}'.format(tnsr_name, graph)
fetches = [tfx.get_tensor(tnsr_name, graph) for tnsr_name in fetch_names]
graph_def = tfx.strip_and_freeze_until(fetches, graph, sess)

return TFInputGraph(graph_def=graph_def, input_tensor_name_from_signature=None,
output_tensor_name_from_signature=None)
9 changes: 8 additions & 1 deletion python/sparkdl/param/converters.py
@@ -167,7 +167,14 @@ def _check_is_tensor_name(_maybe_tnsr_name):
raise TypeError(err_msg.format(type(_maybe_tnsr_name)))

# The check is taken from TensorFlow's NodeDef protocol buffer.
# https://github.com/tensorflow/tensorflow/blob/r1.3/tensorflow/core/framework/node_def.proto#L21-L25
# Each input is "node:src_output" with "node" being a string name and

Review comment:

yes, good point. In practice, most nodes have only one output, so I have not been too concerned about multiple outputs.

Author reply (phi-dbq):

Ah, that was copied-and-pasted from the comments in TensorFlow's node_def.

# "src_output" indicating which output tensor to use from "node". If
# "src_output" is 0 the ":0" suffix can be omitted. Regular inputs
# may optionally be followed by control inputs that have the format
# "^node".
# Reference:
# https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/node_def.proto
# https://stackoverflow.com/questions/36150834/how-does-tensorflow-name-tensors
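# For example, "dense/kernel:0" names output tensor 0 of node "dense/kernel";
# "dense/kernel" alone is shorthand for the same ":0" tensor, and "^init"
# denotes a control-input dependency on node "init".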
try:
_, src_idx = _maybe_tnsr_name.split(":")
_ = int(src_idx)