TensorFlow Transformer Part-3 #10

Merged (29 commits) on Nov 22, 2017

Owner

@phi-dbq phi-dbq commented Sep 19, 2017

Introduces TFInputGraph and tests
TFInputGraph is used as the internal storage of the TensorFlow graph for the TFTransformer.

@phi-dbq phi-dbq force-pushed the tf-transformer-part3 branch from e1593c4 to cd3aa8d Compare September 19, 2017 22:24

codecov-io commented Sep 19, 2017

Codecov Report

Merging #10 into tf-transformer-part2 will increase coverage by 0.87%.
The diff coverage is 97.59%.

@@                   Coverage Diff                    @@
##           tf-transformer-part2      #10      +/-   ##
========================================================
+ Coverage                 83.13%   84.01%   +0.87%     
========================================================
  Files                        24       25       +1     
  Lines                      1293     1376      +83     
  Branches                      5        5              
========================================================
+ Hits                       1075     1156      +81     
- Misses                      218      220       +2

| Impacted Files | Coverage Δ | |
|---|---|---|
| python/sparkdl/param/converters.py | 82.66% <ø> (ø) | ⬆️ |
| python/sparkdl/graph/input.py | 97.59% <97.59%> (ø) | |

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Last update 63967b4...decdc8f.

@sueann sueann left a comment

we'll need to iterate a bit here to simplify the structure and logic - let's start with the comments here.

@classmethod
def fromGraph(cls, graph, sess, feed_names, fetch_names):
"""
Construct a TFInputGraphBuilder from a in memory tf.Graph object

nit: an in-memory

also, this returns a TFInputGraph not TFInputGraphBuilder right?

Owner Author

updated

Construct a TFInputGraphBuilder from a in memory tf.Graph object
"""
assert isinstance(graph, tf.Graph), \
('expect tf.Graph type but got', type(graph))

nit: expected

feed_names=None, fetch_names=None)

@classmethod
def _from_checkpoint_impl(cls,

nit: really don't need _impl at the end. _from_checkpoint is clear enough. same for _from_saved_model_impl below.

Owner Author

names are changed as part of the refactoring

feed_names=None,
fetch_names=None):
"""
Construct a TFInputGraphBuilder from a model checkpoint

TFInputGraph

Owner Author

Changed doc and checked spelling with pylint + pyenchant

feed_names=None,
fetch_names=None):
"""
Construct a TFInputGraphBuilder from a SavedModel

TFInputGraph

Owner Author

Changed doc and checked spelling with pylint + pyenchant

@classmethod
def fromCheckpointWithSignature(cls, checkpoint_dir, signature_def_key):
assert signature_def_key is not None
return cls._from_checkpoint_impl(checkpoint_dir,

style: the first two arguments should fit on the first line

Owner Author

YAPF'ed

self.fetch_mapping[sigdef_key] = tnsr_name
self.fetch_names.append(tnsr_name)

class _GinBuilder(object):

this object is never used twice in our use cases. let's make it a single function instead. then we don't have to worry about cleaning up state, e.g. the session.

Owner Author

Removed all classes and builder objects.

return _GinBuilder(import_graph_fn).build(feed_names, fetch_names)


class _GinBuilderInfo(object):

this seems to be really more like _SigDefInfo - let's name it more clearly.

Owner Author

Removed all classes and builder objects.

# pylint: disable=protected-access,attribute-defined-outside-init
gin = TFInputGraph._new_obj_internal()
assert (feed_names is None) == (fetch_names is None)
must_have_sig_def = fetch_names is None

don't need this variable since it's only used once. just use if fetch_names is None where we use this variable below.

Owner Author

I think it is easier to reason about when assigned to a variable. For the time being, fetch_names is None is the only requirement, but it might change in the future. If it does, one would only have to modify it here.

if it changes in the future we can use a variable.

Owner Author

Removed the boolean variable

"Please do NOT construct TFInputGraph directly. Instead, use one of the helper functions")

@classmethod
def _new_obj_internal(cls):

safer to just have this take in all three member variables to set without None default values. the way we use it, we can definitely construct this way.

Owner Author

updated


import sparkdl.graph.utils as tfx

__all__ = ["TFInputGraph"]

I think this whole file has much more complexity than it really needs, given the aim of its content.

We do not need a builder pattern, based on looking at the other PRs. You also happen to expose these builder classes implicitly because you return them, but this is not transparent from __all__. This is a limitation of python that does not help for software engineering.

Once we drop that, all this code really needs to expose is one immutable data structure and six functions that are stateless. As we discussed, there is no need to create additional classes, static methods or other python features. Everything here can be done with regular functions. Here is what I think it should look like:

TFInputGraph = namedtuple("TFInputGraph", ["graph_def", "inputs", "outputs"])
"""
A frozen representation of a tensorflow graph, and some extra information to map spark columns to
the graph's input and output tensors. Users should not have to peek into its content, or make
any assumption about the content.

graph_def: a tensorflow GraphDef object
inputs: a dictionary of {string: string} where the key is XXX and the value is XXX
outputs: XXX
"""

def fromGraph(graph, sess, feed_names, fetch_names):
    """
    Builds an internal representation of a tensorflow graph, based on a tf.Graph object
    :param graph: XXX
    :param sess: XXX
    :param feed_names: XXX <- put the details here!
    :param fetch_names: XXX <- put the details here!
    :return: a TFInputGraph object
    """
    raise NotImplementedError

... other public functions with full documentation

... all the other private functions. They can be lighter on documentation, but they should not require any sophisticated python features. No extra classes are required.

You should also take a look at named tuples for building them, though you should not need any of their more advanced features (for the time being).
https://docs.python.org/2/library/collections.html#collections.namedtuple

This is a significant rewrite, but it is going to dramatically improve the readability and the correctness in the face of future changes.
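
For illustration only, the proposal above could be sketched roughly as follows. The field names come from the comment; `make_input_graph` and the identity mappings are hypothetical placeholders (the comment leaves the dictionary contents as XXX), and a bytes literal stands in for a real GraphDef.

```python
from collections import namedtuple

# Field names follow the proposal above; everything else is an assumption.
TFInputGraph = namedtuple("TFInputGraph", ["graph_def", "inputs", "outputs"])


def make_input_graph(graph_def, feed_names, fetch_names):
    """Hypothetical stateless helper that packs already-frozen data."""
    # Placeholder mappings: the real key/value meaning is not given in the thread.
    inputs = {name: name for name in feed_names}
    outputs = {name: name for name in fetch_names}
    return TFInputGraph(graph_def=graph_def, inputs=inputs, outputs=outputs)


gin = make_input_graph(b"serialized-graph", ["x:0"], ["y:0"])

# Fields of a namedtuple cannot be rebound after construction.
try:
    gin.graph_def = b"other"
except AttributeError:
    mutated = False
else:
    mutated = True
```

Note that only the fields themselves are frozen; the dictionaries stored in `inputs` and `outputs` remain mutable objects.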

Owner Author

Using a classmethod to create class instances is a well-accepted approach. It is commonly used in the CPython implementation of datetime.

Classes created from collections.namedtuple are self-documenting, resembling Scala's case classes to some degree.

>>> ABC = namedtuple('ABC', 'a, b, c', verbose=True)
class ABC(tuple):
    'ABC(a, b, c)'

    __slots__ = ()

    _fields = ('a', 'b', 'c')

    def __new__(_cls, a, b, c):
        'Create new instance of ABC(a, b, c)'
        return _tuple.__new__(_cls, (a, b, c))

    @classmethod
    def _make(cls, iterable, new=tuple.__new__, len=len):
        'Make a new ABC object from a sequence or iterable'
        result = new(cls, iterable)
        if len(result) != 3:
            raise TypeError('Expected 3 arguments, got %d' % len(result))
        return result

    def __repr__(self):
        'Return a nicely formatted representation string'
        return 'ABC(a=%r, b=%r, c=%r)' % self

    def _asdict(self):
        'Return a new OrderedDict which maps field names to their values'
        return OrderedDict(zip(self._fields, self))

    def _replace(_self, **kwds):
        'Return a new ABC object replacing specified fields with new values'
        result = _self._make(map(kwds.pop, ('a', 'b', 'c'), _self))
        if kwds:
            raise ValueError('Got unexpected field names: %r' % kwds.keys())
        return result

    def __getnewargs__(self):
        'Return self as a plain tuple.  Used by copy and pickle.'
        return tuple(self)

    __dict__ = _property(_asdict)

    def __getstate__(self):
        'Exclude the OrderedDict from pickling'
        pass

    a = _property(_itemgetter(0), doc='Alias for field number 0')

    b = _property(_itemgetter(1), doc='Alias for field number 1')

    c = _property(_itemgetter(2), doc='Alias for field number 2')

You can see that they provide a lot more functions than we actually need.
In addition, since the class inherits from tuple, the following should be noted.

>>> ABC = namedtuple('ABC', 'a, b, c'); isinstance(ABC(a=1, b=2, c=3), tuple)
True
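
A small sketch of what inheriting from tuple implies in practice. This is plain Python and not specific to this PR; the variable names are only for illustration.

```python
from collections import namedtuple

ABC = namedtuple('ABC', 'a, b, c')
abc = ABC(a=1, b=2, c=3)

# A namedtuple instance passes every tuple check.
is_tuple = isinstance(abc, tuple)

# It compares equal to a plain tuple holding the same values.
equals_plain = (abc == (1, 2, 3))

# It supports positional unpacking and indexing, so callers may
# come to depend on field order rather than field names.
a, b, c = abc
first = abc[0]
```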

@sueann sueann left a comment

here are some initial comments on the test code flow

ref_feed = tfx.get_tensor(graph, self.input_op_name)
ref_fetch = tfx.get_tensor(graph, self.output_op_name)

def check_input_graph(tgt_gdef, test_idx):

just put the internals of the function inside the for-loop below since the function is not used anywhere else

Owner Author

Well, it is another layer of indentation and it is hard to track which is aligned to which outside an editor.

_ = tf.reduce_mean(x, axis=1, name=self.output_op_name)

gin = TFInputGraph.fromGraph(sess.graph, sess, self.feed_names, self.fetch_names)
self.input_graphs.append(gin)

let's just repurpose _run_test_in_tf_session to run the test directly inside of these functions without having to save the various graphs inside the test class. that way it's more clear what tests fail.

Owner Author

The testing function can be refactored in the same way that we did for PR-1
Let's see if we like that first and then we can apply the same changes here.

builder = tf.saved_model.builder.SavedModelBuilder(saved_model_dir)

with self._run_test_in_tf_session() as sess:
# Model definition: begin

can all the tests here use the same graph built by calling a helper function? e.g.

def _build_graph(self):
  # Build the shared test graph once; uses the test fixture's attributes.
  g = tf.Graph()
  with g.as_default():
    x = tf.placeholder(tf.float64, shape=[None, self.vec_size], name=self.input_op_name)
    w = tf.Variable(tf.random_normal([self.vec_size], dtype=tf.float64),
                    dtype=tf.float64, name='varW')
    z = tf.reduce_mean(x * w, axis=1, name=self.output_op_name)
  return g

...

def test_build_...(self):
  graph = self._build_graph()
  with tf.Session(graph=graph) as sess:
    ...

Owner Author

Well, this is essentially what we are doing here, isn't it?
Testing functions also serve as example/documentation. Here having the whole workflow in the same place makes it easy for users to understand how to use our library.

@phi-dbq phi-dbq force-pushed the tf-transformer-part3 branch from 342aab4 to e963d11 Compare September 23, 2017 04:39
.. warning: This class should not be called by any user code.
"""

def __init__(self):

def __init__(self, graph_def, ...)
  self.graph_def = graph_def
  ....

An opaque object containing TensorFlow graph.
This object can be serialized.

.. warning: This class should not be called by any user code.

document fields here (mentioning that they are implementation details that should not be relied on by users).

Owner Author

Expanded documentation.

Owner Author

phi-dbq commented Sep 26, 2017

The TFInputGraph implementation refactor is done. Test changes are pending on previous PR reviews.
@thunterdb @sueann

@sueann sueann left a comment

This looks much more readable! I haven't reviewed the large docstring yet but have reviewed the code so that part is ready for you. Could you put a screenshot of the docs (at least for TFInputGraph) here for easier review (of the formatting etc).

"""


def __init__(self, graph_def, input_tensor_name_from_signature,

i wish we had better names for these maps... naming is so hard. sig_key_to_input_tensor_names, sig_key_to_output_tensor_names ? the problem is, "_from_" generally doesn't hint that it's a map.

Owner Author

I agree. Naming things is difficult, especially in this case.
TF does not seem to have a proper name for "well_known_input_sig" yet.
And signature_def_key refers to "well_known_prediction_signature" in our example.
Maybe input_signature_to_tensor_name and output_signature_to_tensor_name?

Yeah i think those are actually good, since it conveys the meaning that it's mapping from some signature thing to tensor, and really we don't need the user to know exactly what the keys are -- users aren't expected to use these directly.

As long as the conversion methods (in part 4) make it clear what their inputs are.

Since this is an internal variable that users should not access directly, let's keep it as it is.


with tf.Session() as sess:
graph = import_my_tensorflow_graph(...)
TFInputGraph.fromGraph(graph, sess, ...)

input = TFInputGraph.fromGraph(graph, sess, ...)

graph = tf.Graph()
with tf.Session(graph=graph) as sess:
tf.import_graph_def(graph_def, name='')
gin = _build_with_feeds_fetches(sess=sess, graph=graph, feed_names=feed_names,

what's different if you return here instead of outside the session block? are there implications for the session? if it's the same, it'd be better to return here for simplicity.

Owner Author

Outside the with tf.Session(graph=graph) as sess context manager block, the session will be closed.
In this case _build_with_feeds_fetches will fail.

Owner Author

Although if we return inside the block, it is unclear to me if the session will be closed after the return statement.

Looks like with tf.Session(graph=graph) as sess: ... return is equivalent to something like

mgr = tf.Session(graph=graph)
sess = mgr.__enter__()
try:
  ...
  return ...
finally:
  mgr.__exit__(None, None, None)  # receives the exception details instead if the body raised

(not 100% sure about the exact syntax for the context manager calling __enter__ and __exit__)
so it should be okay to return inside. Essentially all the bookkeeping done by the with statement still happens.

https://www.python.org/dev/peps/pep-0343/
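
The behavior discussed here can be checked without TensorFlow. Below is a minimal sketch in which `TrackedResource` is a hypothetical stand-in for `tf.Session`: it only records whether `__exit__` has run, so it shows when cleanup happens relative to a `return` inside the `with` block.

```python
class TrackedResource(object):
    """Hypothetical stand-in for tf.Session that records whether it was closed."""

    def __init__(self):
        self.closed = False

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.closed = True
        return False  # do not swallow exceptions


def build_inside(resource):
    with resource as res:
        # The return value is evaluated while the resource is still open;
        # __exit__ runs afterwards, before control reaches the caller.
        return res.closed


resource = TrackedResource()
closed_during_return = build_inside(resource)
closed_after_return = resource.closed
```

So returning inside the block still closes the resource, and the returned value is computed while the resource is open.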

+1, return inside.

Owner Author

I am fine with returning from inside. But I don't think there are any substantial differences between returning inside vs. outside.

Let's change that, this is the style of this project.

Owner Author

From a readability perspective, returning outside signals that the returned value does not need the resource held by the context manager to operate.
For a user who is unaware of the way context managers work, "return-outside" makes our intention clear.


def _build_with_feeds_fetches(sess, graph, feed_names, fetch_names):
# pylint: disable=protected-access,attribute-defined-outside-init
assert (feed_names is not None) and (fetch_names is not None), \

I don't think this check is necessary, but if you must, I'd do it one by one so it's more obvious when the error is thrown:

assert (feed_names is not None), "must provide feed_names"
assert (fetch_names is not None), "must provide fetch_names"

also, is it bad if they are empty? if so, should check for that and modify the error msg above to say "must provide non-empty ..."
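
A possible shape for these checks, as a sketch only. `check_feeds_fetches` and `failure_message` are hypothetical helpers, and rejecting empty lists is an assumption, since the thread does not settle whether empty inputs are invalid.

```python
def check_feeds_fetches(feed_names, fetch_names):
    # One assert per argument, so the failing one is obvious from the message.
    assert feed_names is not None, "must provide feed_names"
    assert fetch_names is not None, "must provide fetch_names"
    # Assumption: empty lists are treated as invalid too.
    assert len(feed_names) > 0, "must provide non-empty feed_names"
    assert len(fetch_names) > 0, "must provide non-empty fetch_names"


def failure_message(feed_names, fetch_names):
    """Return the assertion message, or None when the arguments pass."""
    try:
        check_feeds_fetches(feed_names, fetch_names)
    except AssertionError as err:
        return str(err)
    return None
```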

def fromCheckpointWithSignature(cls, checkpoint_dir, signature_def_key):
"""
Construct a TFInputGraph object from a checkpoint, using the embedded
signature_def. Throw error if we cannot find an entry with the `signature_def_key`

nit: an error

Notice that one should either provide the `signature_def_key` or provide both
`feed_names` and `fetch_names`. Please set the unprovided values to None.

:param signature_def_key: str, name of the mapping contained inside the `signature_def`

same here



def _from_checkpoint_impl(checkpoint_dir, signature_def_key, feed_names, fetch_names):
"""

honestly given the docs above we don't really need the docs for this function and _from_saved..._impl below. do these appear in the generated docs? not sure what we do for docs with "private" methods.

Owner Author

Methods whose name starts with '_' do not appear in the docs.

Ah, good to know, thanks for the clarification.

:param fetch_names: list, names of the output tensors.
"""
assert (feed_names is None) == (fetch_names is None), \
'feed_names and fetch_names, if provided must appear together'

"must appear together" -> "must be both non-None"

@phi-dbq did you miss this comment?


def _build_with_sig_def(sess, graph, sig_def):
# pylint: disable=protected-access,attribute-defined-outside-init
assert sig_def, \

this assumes how sig_def came about, which this function shouldn't care about. you can just say "sig_def must not be None." and people can debug from there.

feed_mapping[sigdef_key] = tnsr_name
feed_names.append(tnsr_name)

fetch_mapping = {}

we should add some tests that specifically check that these mappings are created correctly.
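
One way such a check might look, sketched without TensorFlow. `TensorInfo` and `SigDef` are hypothetical stand-ins that carry only the fields used here, `build_mappings` is an illustrative reconstruction of the loop shown in the diff context rather than the PR's actual code, and the output key and tensor names are made up.

```python
from collections import namedtuple

# Hypothetical stand-ins for the protobuf messages; only the fields used below.
TensorInfo = namedtuple("TensorInfo", ["name"])
SigDef = namedtuple("SigDef", ["inputs", "outputs"])


def build_mappings(sig_def):
    """Map signature keys to tensor names and collect the names in key order."""
    feed_mapping, feed_names = {}, []
    for sigdef_key, tnsr_info in sorted(sig_def.inputs.items()):
        feed_mapping[sigdef_key] = tnsr_info.name
        feed_names.append(tnsr_info.name)

    fetch_mapping, fetch_names = {}, []
    for sigdef_key, tnsr_info in sorted(sig_def.outputs.items()):
        fetch_mapping[sigdef_key] = tnsr_info.name
        fetch_names.append(tnsr_info.name)

    return feed_mapping, feed_names, fetch_mapping, fetch_names


sig_def = SigDef(inputs={"well_known_input_sig": TensorInfo("tnsrIn:0")},
                 outputs={"well_known_output_sig": TensorInfo("tnsrOut:0")})
feed_mapping, feed_names, fetch_mapping, fetch_names = build_mappings(sig_def)
```

A real test would build the signature with TensorFlow and compare the resulting mappings against the names used when the model was saved.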

@phi-dbq phi-dbq force-pushed the tf-transformer-part2 branch from 202e7ea to ead1ed6 Compare September 29, 2017 16:48
Owner Author

@phi-dbq phi-dbq left a comment

Mostly done, apart from adding test for mappings.



@phi-dbq phi-dbq force-pushed the tf-transformer-part2 branch from 26c8f24 to d729528 Compare October 4, 2017 01:48
@thunterdb thunterdb left a comment

@phi-dbq there are still some changes that would be good to do. I am happy to talk about them in person.

- :py:meth:`fromSavedModelWithSignature`


When the graph contains serving signatures in which a set of well-known names are associtated

typo

graph = tf.Graph()
with tf.Session(graph=graph) as sess:
tf.import_graph_def(graph_def, name='')
gin = _build_with_feeds_fetches(sess=sess, graph=graph, feed_names=feed_names,

+1, return inside.

:param signature_def_key: str, key (name) of the signature_def to use. It should be in
the list of `signature_def` structures saved with the checkpoint.
"""
assert signature_def_key is not None

why are you checking for some parameters but not some others? This is python, there is only so much you can do.

Owner Author

We want to be explicit that signature_def_key must not be empty.

Ok.


if signature_def_key is not None:
sig_def = meta_graph_def.signature_def[signature_def_key]
gin = _build_with_sig_def(sess=sess, graph=graph, sig_def=sig_def)

return

sig_def = meta_graph_def.signature_def[signature_def_key]
gin = _build_with_sig_def(sess=sess, graph=graph, sig_def=sig_def)
else:
gin = _build_with_feeds_fetches(sess=sess, graph=graph, feed_names=feed_names,

return

feed_mapping[sigdef_key] = tnsr_name
feed_names.append(tnsr_name)

# TODO: IN-THIS-PR, test if these mappings are constructed correctly.

noting the TODO.

there is a (partial) test for that, let's remove the TODO.

@@ -167,7 +167,14 @@ def _check_is_tensor_name(_maybe_tnsr_name):
raise TypeError(err_msg.format(type(_maybe_tnsr_name)))

# The check is taken from TensorFlow's NodeDef protocol buffer.
# https://github.com/tensorflow/tensorflow/blob/r1.3/tensorflow/core/framework/node_def.proto#L21-L25
# Each input is "node:src_output" with "node" being a string name and

yes, good point. In practice, most nodes have only one output, so I have not been too concerned about multiple outputs.

Owner Author

Ah, that was copied and pasted from the comments in TensorFlow's node_def.
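
For readers unfamiliar with the "node:src_output" convention, here is a minimal sketch of splitting such a name. `split_tensor_name` is a hypothetical helper, not the `_check_is_tensor_name` function from the diff; it assumes node names contain no colon and treats a missing suffix as output 0.

```python
def split_tensor_name(tensor_name):
    """Split 'node:src_output' into (node, output_index).

    A name without the ':<index>' suffix is treated as output 0.
    """
    node, sep, idx = tensor_name.rpartition(":")
    if not sep:
        # No colon present, so the whole string is the node name.
        return tensor_name, 0
    return node, int(idx)
```

With this helper, "x:0" and "x" both refer to the first output of node "x", while "scope/z:2" refers to the third output of "scope/z".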

#========================================================================
# Don't have to modify the content below

_TEST_CASES_GENERATORS = []

this is too complicated, and I do not see any reason that would warrant such complexity. Please write some normal tests, parametrized for example.
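
As a rough idea of what a plain table-driven test could look like, here is a sketch that uses only the standard library. It assumes Python 3 (for `subTest`), and `mean_rows` is a hypothetical pure-Python stand-in for the `tf.reduce_mean(x, axis=1)` graph used in the tests; a third-party parametrization helper could replace the explicit loop.

```python
import io
import unittest


def mean_rows(rows):
    """Stand-in for the graph under test: the mean of each row."""
    return [sum(row) / len(row) for row in rows]


# Each case: (name, input rows, expected per-row means).
CASES = [
    ("single row", [[1.0, 3.0]], [2.0]),
    ("two rows", [[0.5, 0.5], [2.0, 4.0]], [0.5, 3.0]),
]


class MeanRowsTest(unittest.TestCase):
    def test_cases(self):
        for name, rows, expected in CASES:
            # subTest reports each failing case separately by name.
            with self.subTest(case=name):
                self.assertEqual(mean_rows(rows), expected)


suite = unittest.defaultTestLoader.loadTestsFromTestCase(MeanRowsTest)
result = unittest.TextTestRunner(stream=io.StringIO(), verbosity=0).run(suite)
```

Adding a new example then means adding one entry to `CASES`, with no shared state between cases.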

self.output_mapping = {self.output_op_name: self.output_col}
self.fetch_names = [self.output_op_name + ':0']

@contextmanager

this is too complicated and if something goes wrong, you will be hard pressed to understand what is going on.

You should really generate some data, transform this data and compare it against some expected output, and clean up some stuff after that if necessary.

Owner Author

In fact, that is exactly what this class does. I checked by randomly killing the tests and made sure that the stack traces are meaningful.

class TestGenBase(object):
def __init__(self, vec_size=17, test_batch_size=231):
# Testing data spec
self.vec_size = vec_size

this class has an enormous amount of state and moving parts, and I cannot keep track of what is happening. Please encapsulate what you need to pass to the relevant pieces.

Owner Author

@phi-dbq phi-dbq left a comment

@thunterdb Let's walk through the testing code once you are back.
The initial version of the testing code was written to directly construct examples and run tests. But the complexity went up quickly once more testing examples were needed.
This version strives to keep graph construction separate from the comparison of numerical results, making the code more concise while still allowing new example graphs to be added easily.
In the next PR, we show how one could inherit this class and change the testing logic, while still being able to use existing examples.

graph = tf.Graph()
with tf.Session(graph=graph) as sess:
tf.import_graph_def(graph_def, name='')
gin = _build_with_feeds_fetches(sess=sess, graph=graph, feed_names=feed_names,
Owner Author

I am fine with returning from inside, but I don't think there is any substantial difference between returning inside vs. outside.
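A small sketch of the point, assuming "inside" and "outside" refer to the `with` block in the snippet above. `fake_session` is a hypothetical stand-in for a session-like resource, not a TensorFlow API; either way the resource is released before the caller sees the result.

```python
from contextlib import contextmanager

events = []  # records the order in which the resource is opened and closed


@contextmanager
def fake_session():
    """Hypothetical stand-in for a session-like resource."""
    events.append("open")
    try:
        yield "sess"
    finally:
        events.append("close")


def build_return_inside():
    with fake_session() as sess:
        result = "built with " + sess
        return result  # the context manager still closes on the way out


def build_return_outside():
    with fake_session() as sess:
        result = "built with " + sess
    return result  # the context manager has already closed here
```

Both functions return the same value and produce the same open/close sequence, so the choice is mostly a matter of project style.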

:param signature_def_key: str, key (name) of the signature_def to use. It should be in
the list of `signature_def` structures saved with the checkpoint.
"""
assert signature_def_key is not None
Owner Author

We want to be vocal about the fact that signature_def_key must not be empty.

@@ -167,7 +167,14 @@ def _check_is_tensor_name(_maybe_tnsr_name):
raise TypeError(err_msg.format(type(_maybe_tnsr_name)))

# The check is taken from TensorFlow's NodeDef protocol buffer.
# https://github.com/tensorflow/tensorflow/blob/r1.3/tensorflow/core/framework/node_def.proto#L21-L25
# Each input is "node:src_output" with "node" being a string name and
Owner Author

Ah, that was copied and pasted from the comments in TensorFlow's node_def.
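For illustration, a check of the "node:src_output" form described in the quoted comment could look roughly like this. It is a minimal sketch, not the PR's `_check_is_tensor_name`: the function name `split_tensor_name` and its error messages are hypothetical, and it assumes the output index is always written explicitly.

```python
def split_tensor_name(name):
    """Split "node:src_output" into (node, index); raise if the form does not match."""
    if not isinstance(name, str):
        raise TypeError("expected str, got {}".format(type(name)))
    parts = name.split(":")
    # Exactly one colon, a non-empty node name, and a decimal output index.
    if len(parts) != 2 or not parts[0] or not parts[1].isdecimal():
        raise ValueError("not of the form node:src_output: {!r}".format(name))
    return parts[0], int(parts[1])


def is_tensor_name(name):
    """Return True when `name` has the node:src_output form, False otherwise."""
    try:
        split_tensor_name(name)
    except (TypeError, ValueError):
        return False
    return True
```

For example, `split_tensor_name("scope/x:0")` yields the node name and output index separately, while a bare operation name such as `"scope/x"` is rejected.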

Owner Author

phi-dbq commented Oct 11, 2017

@thunterdb @sueann I think the refactoring of testing infrastructure is outside the scope of this task/PR. I would recommend moving that to another task/PR.

@phi-dbq phi-dbq added this to the release 0.2 milestone Oct 12, 2017
* profiling

* tests

* renamed test

* removed original tests

* removed the profiler utils

* fixes indents

* imports

* added some tests

* added test

* fix test

* one more test

@thunterdb thunterdb left a comment

@phi-dbq still a few small comments from the previous review.

"""


def __init__(self, graph_def, input_tensor_name_from_signature,

Since this is an internal variable that users should not access directly, let's keep it as it is.

graph = tf.Graph()
with tf.Session(graph=graph) as sess:
tf.import_graph_def(graph_def, name='')
gin = _build_with_feeds_fetches(sess=sess, graph=graph, feed_names=feed_names,

Let's change that, this is the style of this project.



def _from_checkpoint_impl(checkpoint_dir, signature_def_key, feed_names, fetch_names):
"""

Ah, good to know, thanks for the clarification.

feed_mapping[sigdef_key] = tnsr_name
feed_names.append(tnsr_name)

# TODO: IN-THIS-PR, test if these mappings are constructed correctly.

There is a (partial) test for that; let's remove the TODO.
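A test of the kind mentioned here might look roughly like the following. This is a sketch under assumptions: `build_feed_mapping` is a hypothetical helper that mirrors the loop in the snippet above, and a plain dict stands in for the inputs of a signature_def.

```python
def build_feed_mapping(signature_inputs):
    """Map each signature key to its tensor name and collect the tensor names.

    `signature_inputs` is a hypothetical plain dict {signature key: tensor name}.
    """
    feed_mapping = {}
    feed_names = []
    # Sort by key so the order of feed_names is deterministic.
    for sigdef_key, tnsr_name in sorted(signature_inputs.items()):
        feed_mapping[sigdef_key] = tnsr_name
        feed_names.append(tnsr_name)
    return feed_mapping, feed_names


def test_feed_mapping():
    inputs = {"input_sig": "input_tensor:0", "aux_sig": "aux_tensor:0"}
    feed_mapping, feed_names = build_feed_mapping(inputs)
    # Every signature key maps to the tensor name it was declared with.
    assert feed_mapping == inputs
    # The names list and the mapping values describe the same tensors.
    assert set(feed_names) == set(feed_mapping.values())
    assert len(feed_names) == len(inputs)
```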

:param signature_def_key: str, key (name) of the signature_def to use. It should be in
the list of `signature_def` structures saved with the checkpoint.
"""
assert signature_def_key is not None

Ok.

Owner Author

phi-dbq commented Nov 22, 2017

@thunterdb I think I addressed your comments. Would you like to take another look? Thanks!

@thunterdb thunterdb left a comment

Looks good to me

* flat param API impl

* support input graph scenarios

* (WIP) new interface implementation

* docs and cleanup

* using tensorflow API instead of our utilities

* automatic type conversion

* cleanup

* PR comments

1. Move `InputGraph` to its module.

* (WIP) address comments

* (WIP) respond to PR comments

* test refactor

* (wip) consolidating params

* rebase upstream

* import params fix

* (wip) TFInputGraph impl

* (wip) moving to new API

* (wip) enable saved_model tests

* (wip) enable checkpoint test

* (wip) enable multiple tensor tests

* enable all tests

* optimize graph for inference

* allows setting TFInputGraph

* utilize test_input_graph for transformer tests

* enable all tests

Signed-off-by: Philip Yang <philip.yang@databricks.com>

* input graph

* docs

* tensor tests

* tensor test update

* TFTransformer Part-4 Test Refactor (#15)

* adding new tests

* remove original test design

* cleanup

* deleting original testing ideas

* PR comments
@phi-dbq phi-dbq merged commit de088a9 into tf-transformer-part2 Nov 22, 2017
phi-dbq added a commit that referenced this pull request Nov 22, 2017
* update utils

* tests

* fix style

Using the following YAPF style
========================================================
based_on_style = pep8
ALIGN_CLOSING_BRACKET_WITH_VISUAL_INDENT=True
BLANK_LINE_BEFORE_NESTED_CLASS_OR_DEF=False
COLUMN_LIMIT=100
SPACE_BETWEEN_ENDING_COMMA_AND_CLOSING_BRACKET=False
SPLIT_ARGUMENTS_WHEN_COMMA_TERMINATED=True
SPLIT_BEFORE_FIRST_ARGUMENT=False
SPLIT_BEFORE_NAMED_ASSIGNS=False
SPLIT_PENALTY_AFTER_OPENING_BRACKET=30
USE_TABS=False
========================================================

* refactoring tfx API

* test refactoring

* PR comments

1. docs in graph/utils.py

* (wip) utils test

* a few more tests for utils

* test update cont'd

* PR comments

* PR comments

* PR comments

* TensorFlow Transformer Part-3 (#10)

* intro: TFInputGraph

* tests

* Merge branch 'tf-transformer-part1' into tf-transformer-part3

* and so there is no helper classes

* and into more pieces

* class & docs

* update docs

* refactoring tfx API

* update tfx utils usage

* one way to build these tests

* tests refactored

* test cases in a single class

This will make things easier when we want to extend other base class functions.

* shuffle things around

Signed-off-by: Philip Yang <philip.yang@databricks.com>

* docs mostly

* yapf'd

* consolidate tempdir creation

* (wip) PR comments

* more tests

* change test generator module name

* TFTransformer Part-3 Test Refactor (#14)

* profiling

* tests

* renamed test

* removed original tests

* removed the profiler utils

* fixes indents

* imports

* added some tests

* added test

* fix test

* one more test

* PR comments

* TensorFlow Transformer Part-4 (#11)

* flat param API impl

* support input graph scenarios

* (WIP) new interface implementation

* docs and cleanup

* using tensorflow API instead of our utilities

* automatic type conversion

* cleanup

* PR comments

1. Move `InputGraph` to its module.

* (WIP) address comments

* (WIP) respond to PR comments

* test refactor

* (wip) consolidating params

* rebase upstream

* import params fix

* (wip) TFInputGraph impl

* (wip) moving to new API

* (wip) enable saved_model tests

* (wip) enable checkpoint test

* (wip) enable multiple tensor tests

* enable all tests

* optimize graph for inference

* allows setting TFInputGraph

* utilize test_input_graph for transformer tests

* enable all tests

Signed-off-by: Philip Yang <philip.yang@databricks.com>

* input graph

* docs

* tensor tests

* tensor test update

* TFTransformer Part-4 Test Refactor (#15)

* adding new tests

* remove original test design

* cleanup

* deleting original testing ideas

* PR comments