TensorFlow Transformer Part-3 #10
Conversation
Force-pushed from e1593c4 to cd3aa8d (Compare)
Codecov Report
@@ Coverage Diff @@
## tf-transformer-part2 #10 +/- ##
========================================================
+ Coverage 83.13% 84.01% +0.87%
========================================================
Files 24 25 +1
Lines 1293 1376 +83
Branches 5 5
========================================================
+ Hits 1075 1156 +81
- Misses 218 220 +2
Continue to review full report at Codecov.
we'll need to iterate a bit here to simplify the structure and logic - let's start with the comments here.
python/sparkdl/graph/input.py
Outdated
@classmethod
def fromGraph(cls, graph, sess, feed_names, fetch_names):
    """
    Construct a TFInputGraphBuilder from a in memory tf.Graph object
nit: an in-memory
also, this returns a TFInputGraph not TFInputGraphBuilder right?
updated
python/sparkdl/graph/input.py
Outdated
    Construct a TFInputGraphBuilder from a in memory tf.Graph object
    """
    assert isinstance(graph, tf.Graph), \
        ('expect tf.Graph type but got', type(graph))
nit: expected
python/sparkdl/graph/input.py
Outdated
                 feed_names=None, fetch_names=None)

@classmethod
def _from_checkpoint_impl(cls,
nit: really don't need `_impl` at the end. `_from_checkpoint` is clear enough. same for `_from_saved_model_impl` below.
names are changed as part of the refactoring
python/sparkdl/graph/input.py
Outdated
                      feed_names=None,
                      fetch_names=None):
    """
    Construct a TFInputGraphBuilder from a model checkpoint
TFInputGraph
Changed doc and checked spelling with pylint + pyenchant
python/sparkdl/graph/input.py
Outdated
                      feed_names=None,
                      fetch_names=None):
    """
    Construct a TFInputGraphBuilder from a SavedModel
TFInputGraph
Changed doc and checked spelling with pylint + pyenchant
python/sparkdl/graph/input.py
Outdated
@classmethod
def fromCheckpointWithSignature(cls, checkpoint_dir, signature_def_key):
    assert signature_def_key is not None
    return cls._from_checkpoint_impl(checkpoint_dir,
style: the first two arguments should fit on the first line
YAPF'ed
python/sparkdl/graph/input.py
Outdated
        self.fetch_mapping[sigdef_key] = tnsr_name
        self.fetch_names.append(tnsr_name)

class _GinBuilder(object):
this object is never used twice in our use cases. let's make it a single function instead. then we don't have to worry about cleaning up state, e.g. the session.
Removed all classes and builder objects.
python/sparkdl/graph/input.py
Outdated
    return _GinBuilder(import_graph_fn).build(feed_names, fetch_names)

class _GinBuilderInfo(object):
this seems to be really more like _SigDefInfo - let's name it more clearly.
Removed all classes and builder objects.
python/sparkdl/graph/input.py
Outdated
# pylint: disable=protected-access,attribute-defined-outside-init
gin = TFInputGraph._new_obj_internal()
assert (feed_names is None) == (fetch_names is None)
must_have_sig_def = fetch_names is None
don't need this variable since it's only used once. just use `if fetch_names is None` where we use this variable below.
I think it is easier to reason with when assigned to a variable. For the time being `fetch_names is None` is the only requirement, but it might change in the future. And if it does change, one would only have to modify it here.
if it changes in the future we can use a variable.
Removed the boolean variable
python/sparkdl/graph/input.py
Outdated
        "Please do NOT construct TFInputGraph directly. Instead, use one of the helper functions")

@classmethod
def _new_obj_internal(cls):
safer to just have this take in all three member variables to set without None default values. the way we use it, we can definitely construct this way.
updated
python/sparkdl/graph/input.py
Outdated
import sparkdl.graph.utils as tfx

__all__ = ["TFInputGraph"]
I think this whole file has much more complexity than it really needs, given the aim of its content.
We do not need a builder pattern, based on looking at the other PRs. You also happen to expose these builder classes implicitly because you return them, but this is not transparent from `__all__`. This is a limitation of python that does not help for software engineering.
Once we drop that, all this code really needs to expose is one immutable data structure and six functions that are stateless. As we discussed, there is no need to create additional classes, static methods or other python features. Everything here can be done with regular functions. Here is what I think it should look like:
TFInputGraph = namedtuple("TFInputGraph", ["graph_def", "inputs", "outputs"])
"""
A frozen representation of a tensorflow graph, and some extra information to map spark columns to
the graph's input and output tensors. Users should not have to peek into its content, or make
any assumption about the content.
graph_def: a tensorflow GraphDef object
inputs: a dictionary of {string: string} where the key is XXX and the value is XXX
outputs: XXX
"""

def fromGraph(graph, sess, feed_names, fetch_names):
    """
    Builds an internal representation of a tensorflow graph, based on a tf.Graph object
    :param graph: XXX
    :param sess: XXX
    :param feed_names: XXX <- put the details here!
    :param fetch_names: XXX <- put the details here!
    :return: a TFInputGraph object
    """
    raise
... other public functions with full documentation
... all the other private functions. They can be lighter on documentation, but they should not require any sophisticated python features. No extra classes are required.
You should also take a look at named tuples for building them, but you should absolutely not need any of their more advanced features (for the time being).
https://docs.python.org/2/library/collections.html#collections.namedtuple
This is a significant rewrite, but it is going to dramatically improve the readability and the correctness in the face of future changes.
Using `classmethod` to create class instances is a well-accepted approach. It is commonly used in the CPython implementation of datetime. Classes created from `collections.namedtuple` are self-documenting, resembling Scala's case classes to some degree.
>>> ABC = namedtuple('ABC', 'a, b, c', verbose=True)
class ABC(tuple):
    'ABC(a, b, c)'
    __slots__ = ()
    _fields = ('a', 'b', 'c')
    def __new__(_cls, a, b, c):
        'Create new instance of ABC(a, b, c)'
        return _tuple.__new__(_cls, (a, b, c))
    @classmethod
    def _make(cls, iterable, new=tuple.__new__, len=len):
        'Make a new ABC object from a sequence or iterable'
        result = new(cls, iterable)
        if len(result) != 3:
            raise TypeError('Expected 3 arguments, got %d' % len(result))
        return result
    def __repr__(self):
        'Return a nicely formatted representation string'
        return 'ABC(a=%r, b=%r, c=%r)' % self
    def _asdict(self):
        'Return a new OrderedDict which maps field names to their values'
        return OrderedDict(zip(self._fields, self))
    def _replace(_self, **kwds):
        'Return a new ABC object replacing specified fields with new values'
        result = _self._make(map(kwds.pop, ('a', 'b', 'c'), _self))
        if kwds:
            raise ValueError('Got unexpected field names: %r' % kwds.keys())
        return result
    def __getnewargs__(self):
        'Return self as a plain tuple. Used by copy and pickle.'
        return tuple(self)
    __dict__ = _property(_asdict)
    def __getstate__(self):
        'Exclude the OrderedDict from pickling'
        pass
    a = _property(_itemgetter(0), doc='Alias for field number 0')
    b = _property(_itemgetter(1), doc='Alias for field number 1')
    c = _property(_itemgetter(2), doc='Alias for field number 2')
You can see that they provide a lot more functions than we actually need.
In addition, since the class inherits from `tuple`, the following should be noted.
>>> ABC = namedtuple('ABC', 'a, b, c'); isinstance(ABC(a=1, b=2, c=3), tuple)
True
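For readers unfamiliar with the pattern under discussion, here is a minimal sketch of a `classmethod` used as an alternative constructor; `Point` and `from_pair` are hypothetical names for illustration, not project code.

```python
# Hypothetical example of the classmethod-constructor pattern discussed above;
# `Point` and `from_pair` are illustrative names, not part of sparkdl.
class Point(object):
    def __init__(self, x, y):
        self.x = x
        self.y = y

    @classmethod
    def from_pair(cls, pair):
        # Alternative constructor, in the spirit of datetime.date.fromtimestamp:
        # normalize the input, then delegate to the regular constructor.
        x, y = pair
        return cls(x, y)

p = Point.from_pair((3, 4))
print(p.x, p.y)  # 3 4
```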
here are some initial comments on the test code flow
ref_feed = tfx.get_tensor(graph, self.input_op_name)
ref_fetch = tfx.get_tensor(graph, self.output_op_name)

def check_input_graph(tgt_gdef, test_idx):
just put the internals of the function inside the for-loop below since the function is not used anywhere else
Well, it is another layer of indentation and it is hard to track which is aligned to which outside an editor.
_ = tf.reduce_mean(x, axis=1, name=self.output_op_name)

gin = TFInputGraph.fromGraph(sess.graph, sess, self.feed_names, self.fetch_names)
self.input_graphs.append(gin)
let's just repurpose `_run_test_in_tf_session` to run the test directly inside of these functions without having to save the various graphs inside the test class. that way it's more clear which tests fail.
The testing function can be refactored in the same way that we did for PR-1. Let's see if we like that first and then we can apply the same changes here.
builder = tf.saved_model.builder.SavedModelBuilder(saved_model_dir)

with self._run_test_in_tf_session() as sess:
    # Model definition: begin
can all the tests here use the same graph built by calling a helper function? e.g.
def _build_graph():
    g = tf.Graph()
    with g.as_default():
        x = tf.placeholder(tf.float64, shape=[None, self.vec_size], name=self.input_op_name)
        w = tf.Variable(tf.random_normal([self.vec_size], dtype=tf.float64),
                        dtype=tf.float64, name='varW')
        z = tf.reduce_mean(x * w, axis=1, name=self.output_op_name)
    return g

...

def test_build_...():
    graph = _build_graph()
    with tf.Session(graph=graph) as sess:
        ...
Well, this is essentially what we are doing here, isn't it?
Testing functions also serve as example/documentation. Here having the whole workflow in the same place makes it easy for users to understand how to use our library.
Force-pushed from 342aab4 to e963d11 (Compare)
python/sparkdl/graph/input.py
Outdated
.. warning: This class should not be called by any user code.
"""

def __init__(self):
def __init__(self, graph_def, ...):
    self.graph_def = graph_def
    ...
python/sparkdl/graph/input.py
Outdated
An opaque object containing TensorFlow graph.
This object can be serialized.

.. warning: This class should not be called by any user code.
document fields here (mentioning they are implementation details that should not be relied on by users).
Expanded documentation.
This looks much more readable! I haven't reviewed the large docstring yet but have reviewed the code so that part is ready for you. Could you put a screenshot of the docs (at least for TFInputGraph) here for easier review (of the formatting etc).
"""

def __init__(self, graph_def, input_tensor_name_from_signature,
i wish we had better names for these maps... naming is so hard. sig_key_to_input_tensor_names, sig_key_to_output_tensor_names ? the problem is, "_from_" generally doesn't hint that it's a map.
I agree. Naming things is difficult, especially in this case. TF does not seem to have a proper name for "well_known_input_sig" yet. And `signature_def_key` refers to "well_known_prediction_signature" in our example. Maybe `input_signature_to_tensor_name` and `output_signature_to_tensor_name`?
Yeah, I think those are actually good, since they convey the meaning that it's a mapping from some signature thing to a tensor, and really we don't need the user to know exactly what the keys are -- users aren't expected to use these directly.
As long as the conversion methods (in part 4) make it clear what their inputs are.
Since this is an internal variable that users should not access directly, let's keep it as it is.
python/sparkdl/graph/input.py
Outdated
with tf.Session() as sess:
    graph = import_my_tensorflow_graph(...)
    TFInputGraph.fromGraph(graph, sess, ...)
input = TFInputGraph.fromGraph(graph, sess, ...)
python/sparkdl/graph/input.py
Outdated
graph = tf.Graph()
with tf.Session(graph=graph) as sess:
    tf.import_graph_def(graph_def, name='')
    gin = _build_with_feeds_fetches(sess=sess, graph=graph, feed_names=feed_names,
what's different if you return here instead of outside the session block? are there implications for the session? if it's the same, it'd be better to return here for simplicity.
Outside the `with tf.Session(graph=graph) as sess` context manager block, the session will be closed. In this case `_build_with_feeds_fetches` will fail.
Although if we return inside the block, it is unclear to me if the session will be closed after the return statement.
Looks like `with tf.Session(graph=graph) as sess: ... return` is equivalent to something like

try:
    sess = tf.Session(graph=graph).__enter__()
    ...
    return ...
finally:
    tf.Session.__exit__()

(not 100% sure about the exact syntax for the context manager calling `__enter__` and `__exit__`) so it should be okay to return inside. Essentially all the bookkeeping done by the with statement still happens.
+1, return inside.
I am fine with returning from inside. But I don't think there are any substantial differences for returning inside v.s. outside.
Let's change that, this is the style of this project.
From readability's perspective, returning outside signals that the returned value does not need the resource held by the context manager to operate. For a user who is unaware of the way context managers work, "return-outside" makes our intention clear to him/her.
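The behavior debated above can be checked without TensorFlow at all. Here is a minimal sketch with a toy context manager (all names are made up for illustration) showing that `__exit__` runs even when the function returns from inside the `with` block:

```python
# Toy context manager demonstrating that returning from inside a `with`
# block still triggers __exit__ (i.e. the session would be closed either way).
class FakeSession(object):
    def __init__(self):
        self.closed = False

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.closed = True
        return False  # do not swallow exceptions

def build_inside(sess_obj):
    with sess_obj as sess:
        return sess  # __exit__ fires before the caller receives the value

sess = FakeSession()
result = build_inside(sess)
print(result is sess, sess.closed)  # True True
```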
python/sparkdl/graph/input.py
Outdated
def _build_with_feeds_fetches(sess, graph, feed_names, fetch_names):
    # pylint: disable=protected-access,attribute-defined-outside-init
    assert (feed_names is not None) and (fetch_names is not None), \
I don't think this check is necessary, but if you must, I'd do it one by one so it's more obvious when the error is thrown:
assert (feed_names is not None), "must provide feed_names"
assert (fetch_names is not None), "must provide fetch names"
also, is it bad if they are empty? if so, should check for that and modify the error msg above to say "must provide non-empty ..."
python/sparkdl/graph/input.py
Outdated
def fromCheckpointWithSignature(cls, checkpoint_dir, signature_def_key):
    """
    Construct a TFInputGraph object from a checkpoint, using the embedded
    signature_def. Throw error if we cannot find an entry with the `signature_def_key`
nit: an error
    Notice that one should either provide the `signature_def_key` or provide both
    `feed_names` and `fetch_names`. Please set the unprovided values to None.

    :param signature_def_key: str, name of the mapping contained inside the `signature_def`
same here
def _from_checkpoint_impl(checkpoint_dir, signature_def_key, feed_names, fetch_names):
    """
honestly given the docs above we don't really need the docs for this function and `_from_saved..._impl` below. do these appear in the generated docs? not sure what we do for docs with "private" methods.
Methods whose name starts with '_' do not appear in the docs.
Ah, good to know, thanks for the clarification.
    :param fetch_names: list, names of the output tensors.
    """
    assert (feed_names is None) == (fetch_names is None), \
        'feed_names and fetch_names, if provided must appear together'
"must appear together" -> "must be both non-None"
@phi-dbq did you miss this comment?
python/sparkdl/graph/input.py
Outdated
def _build_with_sig_def(sess, graph, sig_def):
    # pylint: disable=protected-access,attribute-defined-outside-init
    assert sig_def, \
this assumes how sig_def came about, which this function shouldn't care about. you can just say "sig_def must not be None." and people can debug from there.
        feed_mapping[sigdef_key] = tnsr_name
        feed_names.append(tnsr_name)

    fetch_mapping = {}
we should add some tests that specifically check that these mappings are created correctly.
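One way such a test could look. This is a hedged sketch: the dict-based `sig_def_inputs` argument and the `build_feed_mapping` helper below are stand-ins for the real signature_def protobuf and the mapping loop in the diff, not actual sparkdl code.

```python
# Illustrative stand-in for the feed-mapping construction loop in the diff;
# a real test would read these (signature key, tensor name) pairs from a
# signature_def protobuf instead of a plain dict.
def build_feed_mapping(sig_def_inputs):
    feed_mapping = {}
    feed_names = []
    for sigdef_key, tnsr_name in sorted(sig_def_inputs.items()):
        feed_mapping[sigdef_key] = tnsr_name
        feed_names.append(tnsr_name)
    return feed_mapping, feed_names

mapping, names = build_feed_mapping({'input_col': 'x:0'})
assert mapping == {'input_col': 'x:0'}
assert names == ['x:0']
```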
Force-pushed from 202e7ea to ead1ed6 (Compare)
This will make things easier when we want to extend other base class functions.
Signed-off-by: Philip Yang <philip.yang@databricks.com>
Mostly done, apart from adding tests for mappings.
Force-pushed from 26c8f24 to d729528 (Compare)
@phi-dbq there are still some changes that would be good to do. I am happy to talk about them in person.
python/sparkdl/graph/input.py
Outdated
- :py:meth:`fromSavedModelWithSignature`

When the graph contains serving signatures in which a set of well-known names are associtated
typo
    :param signature_def_key: str, key (name) of the signature_def to use. It should be in
        the list of `signature_def` structures saved with the checkpoint.
    """
    assert signature_def_key is not None
why are you checking for some parameters but not others? This is python, there is only so much you can do.
We want to be vocal that `signature_def_key` must not be empty.
Ok.
python/sparkdl/graph/input.py
Outdated
if signature_def_key is not None:
    sig_def = meta_graph_def.signature_def[signature_def_key]
    gin = _build_with_sig_def(sess=sess, graph=graph, sig_def=sig_def)
return
python/sparkdl/graph/input.py
Outdated
    sig_def = meta_graph_def.signature_def[signature_def_key]
    gin = _build_with_sig_def(sess=sess, graph=graph, sig_def=sig_def)
else:
    gin = _build_with_feeds_fetches(sess=sess, graph=graph, feed_names=feed_names,
return
python/sparkdl/graph/input.py
Outdated
        feed_mapping[sigdef_key] = tnsr_name
        feed_names.append(tnsr_name)

    # TODO: IN-THIS-PR, test if these mappings are constructed correctly.
noting the TODO.
there is a (partial) test for that, let's remove the TODO.
@@ -167,7 +167,14 @@ def _check_is_tensor_name(_maybe_tnsr_name):
    raise TypeError(err_msg.format(type(_maybe_tnsr_name)))

# The check is taken from TensorFlow's NodeDef protocol buffer.
# https://github.com/tensorflow/tensorflow/blob/r1.3/tensorflow/core/framework/node_def.proto#L21-L25
# Each input is "node:src_output" with "node" being a string name and
yes, good point. In practice, most nodes have only one output, so I have not been too concerned about multiple outputs.
Ah, that was copied and pasted from the comments in TensorFlow's node_def.
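The `"node:src_output"` convention quoted above can be illustrated with a small parser. This is a sketch only; the regex below is not the project's actual `_check_is_tensor_name` implementation.

```python
import re

# Per the node_def.proto comment: each input is "node:src_output" where
# "node" is a string name and "src_output" is an output index; ":0" may
# be omitted. Illustrative helper, not the project's validation code.
_TENSOR_NAME_RE = re.compile(r'^([^:]+)(?::(\d+))?$')

def split_tensor_name(name):
    match = _TENSOR_NAME_RE.match(name)
    if match is None:
        raise ValueError('invalid tensor name: %r' % (name,))
    node, idx = match.groups()
    return node, int(idx) if idx is not None else 0

print(split_tensor_name('dense/BiasAdd:0'))  # ('dense/BiasAdd', 0)
print(split_tensor_name('x'))                # ('x', 0)
```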
#========================================================================
# Don't have to modify the content below

_TEST_CASES_GENERATORS = []
this is too complicated, and I do not see any reason that would warrant such complexity. Please write some normal tests, with `parametrized` for example.
self.output_mapping = {self.output_op_name: self.output_col}
self.fetch_names = [self.output_op_name + ':0']

@contextmanager
this is too complicated and if something goes wrong, you will be hard pressed to understand what is going on.
You should really generate some data, transform this data and compare it against some expected output, and clean up some stuff after that if necessary.
In fact, that is exactly what this class does. I checked randomly killing the tests and made sure that the stack traces are meaningful.
class TestGenBase(object):
    def __init__(self, vec_size=17, test_batch_size=231):
        # Testing data spec
        self.vec_size = vec_size
this class has an enormous amount of state and moving parts, I cannot keep track what is happening. Please encapsulate what you need to pass to the relevant pieces.
@thunterdb Let's walk through the testing code once you are back.
The initial version of the testing code was written to directly construct examples and run tests. But the complexity went up quickly once more testing examples were needed.
This version strives to keep graph construction and numerical-result comparison separate, making the code more concise while still allowing new example graphs to be added easily.
In the next PR, we show how one could inherit this class and change the testing logic, while still being able to use existing examples.
python/sparkdl/graph/input.py
Outdated
graph = tf.Graph()
with tf.Session(graph=graph) as sess:
    tf.import_graph_def(graph_def, name='')
    gin = _build_with_feeds_fetches(sess=sess, graph=graph, feed_names=feed_names,
I am fine with returning from inside. But I don't think there is any substantial difference between returning inside vs. outside.
:param signature_def_key: str, key (name) of the signature_def to use. It should be in
    the list of `signature_def` structures saved with the checkpoint.
"""
assert signature_def_key is not None
We want to be vocal that `signature_def_key` must not be empty.
@thunterdb @sueann I think the refactoring of the testing infrastructure is outside the scope of this task/PR. I would recommend moving that to another task/PR.
* profiling
* tests
* renamed test
* removed original tests
* removed the profiler utils
* fixes indents
* imports
* added some tests
* added test
* fix test
* one more test
@phi-dbq still a few small comments from the previous review.
"""


def __init__(self, graph_def, input_tensor_name_from_signature,
Since this is an internal variable that users should not access directly, let's keep it as it is.
python/sparkdl/graph/input.py
Outdated
graph = tf.Graph()
with tf.Session(graph=graph) as sess:
    tf.import_graph_def(graph_def, name='')
    gin = _build_with_feeds_fetches(sess=sess, graph=graph, feed_names=feed_names,
Let's change that, this is the style of this project.


def _from_checkpoint_impl(checkpoint_dir, signature_def_key, feed_names, fetch_names):
    """
Ah, good to know, thanks for the clarification.
:param signature_def_key: str, key (name) of the signature_def to use. It should be in
    the list of `signature_def` structures saved with the checkpoint.
"""
assert signature_def_key is not None
Ok.
@thunterdb I think I addressed your comments. Would you like to take another look? Thanks!
Looks good to me
* flat param API impl
* support input graph scenarios
* (WIP) new interface implementation
* docs and cleanup
* using tensorflow API instead of our utilities
* automatic type conversion
* cleanup
* PR comments
  1. Move `InputGraph` to its module.
* (WIP) address comments
* (WIP) respond to PR comments
* test refactor
* (wip) consolidating params
* rebase upstream
* import params fix
* (wip) TFInputGraph impl
* (wip) moving to new API
* (wip) enable saved_model tests
* (wip) enable checkpoint test
* (wip) enable multiple tensor tests
* enable all tests
* optimize graph for inference
* allows setting TFInputGraph
* utilize test_input_graph for transformer tests
* enable all tests

Signed-off-by: Philip Yang <philip.yang@databricks.com>

* input graph
* docs
* tensor tests
* tensor test update
* TFTransformer Part-4 Test Refactor (#15)
  * adding new tests
  * remove original test design
  * cleanup
  * deleting original testing ideas
  * PR comments
* update utils
* tests
* fix style

  Using the following YAPF style:

      based_on_style = pep8
      ALIGN_CLOSING_BRACKET_WITH_VISUAL_INDENT=True
      BLANK_LINE_BEFORE_NESTED_CLASS_OR_DEF=False
      COLUMN_LIMIT=100
      SPACE_BETWEEN_ENDING_COMMA_AND_CLOSING_BRACKET=False
      SPLIT_ARGUMENTS_WHEN_COMMA_TERMINATED=True
      SPLIT_BEFORE_FIRST_ARGUMENT=False
      SPLIT_BEFORE_NAMED_ASSIGNS=False
      SPLIT_PENALTY_AFTER_OPENING_BRACKET=30
      USE_TABS=False

* refactoring tfx API
* test refactoring
* PR comments
  1. docs in graph/utils.py
* (wip) utils test
* a few more tests for utils
* test update cont'd
* PR comments
* PR comments
* PR comments
* TensorFlow Transformer Part-3 (#10)
  * intro: TFInputGraph
  * tests
  * Merge branch 'tf-transformer-part1' into tf-transformer-part3
  * and so there is no helper classes
  * and into more pieces
  * class & docs
  * update docs
  * refactoring tfx API
  * update tfx utils usage
  * one way to build these tests
  * tests refactored
  * test cases in a single class
    This will make things easier when we want to extend other base class functions.
  * shuffle things around

  Signed-off-by: Philip Yang <philip.yang@databricks.com>

  * docs mostly
  * yapf'd
  * consolidate tempdir creation
  * (wip) PR comments
  * more tests
  * change test generator module name
* TFTransformer Part-3 Test Refactor (#14)
  * profiling
  * tests
  * renamed test
  * removed original tests
  * removed the profiler utils
  * fixes indents
  * imports
  * added some tests
  * added test
  * fix test
  * one more test
  * PR comments
* TensorFlow Transformer Part-4 (#11)
  * flat param API impl
  * support input graph scenarios
  * (WIP) new interface implementation
  * docs and cleanup
  * using tensorflow API instead of our utilities
  * automatic type conversion
  * cleanup
  * PR comments
    1. Move `InputGraph` to its module.
  * (WIP) address comments
  * (WIP) respond to PR comments
  * test refactor
  * (wip) consolidating params
  * rebase upstream
  * import params fix
  * (wip) TFInputGraph impl
  * (wip) moving to new API
  * (wip) enable saved_model tests
  * (wip) enable checkpoint test
  * (wip) enable multiple tensor tests
  * enable all tests
  * optimize graph for inference
  * allows setting TFInputGraph
  * utilize test_input_graph for transformer tests
  * enable all tests

  Signed-off-by: Philip Yang <philip.yang@databricks.com>

  * input graph
  * docs
  * tensor tests
  * tensor test update
  * TFTransformer Part-4 Test Refactor (#15)
    * adding new tests
    * remove original test design
    * cleanup
    * deleting original testing ideas
    * PR comments
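For reference, the YAPF options quoted in the commit message above would typically live in a `.style.yapf` file. This is a sketch, not a file from the repository; YAPF's config format expects a `[style]` section with lowercase option names:

```ini
[style]
based_on_style = pep8
align_closing_bracket_with_visual_indent = true
blank_line_before_nested_class_or_def = false
column_limit = 100
space_between_ending_comma_and_closing_bracket = false
split_arguments_when_comma_terminated = true
split_before_first_argument = false
split_before_named_assigns = false
split_penalty_after_opening_bracket = 30
use_tabs = false
```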
Introduces `TFInputGraph` and tests. `TFInputGraph` is used as the internal storage of the TensorFlow graph for the `TFTransformer`.