TFTransformer Part-3 Test Refactor #14
Conversation
Codecov Report
@@             Coverage Diff              @@
##   tf-transformer-part3     #14   +/- ##
==========================================
- Coverage         84.19%   84.04%  -0.15%
==========================================
  Files                25       25
  Lines              1379     1379
  Branches              5        5
==========================================
- Hits               1161     1159      -2
- Misses              218      220      +2

Continue to review full report at Codecov.
Thanks, @thunterdb! I haven't had the opportunity to take a deep look yet.
@phi-dbq good point about assets, I did not know about these. I agree that they should be covered eventually. For now, though, the models we have seen in action do not have them, and we can eventually add tests specific to asset-laden models.
Some initial comments.
from sparkdl.graph.input import TFInputGraph


class TestGraphImport(object):
Shall we use unittest.TestCase, so as to be compatible with the rest of the code base? Thanks!
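For illustration, a minimal sketch of what the suggestion amounts to (the class and test names here are stand-ins, not the actual test file):

```python
import unittest

class TestGraphImport(unittest.TestCase):
    """Same tests, but derived from unittest.TestCase so the class is
    picked up by the standard runner like the rest of the code base."""

    def test_placeholder_example(self):
        # Stand-in assertion; the real tests would build a TensorFlow
        # graph and check TFInputGraph construction.
        self.assertEqual(1 + 1, 2)
```

A `unittest.TestCase` subclass also gets `setUp`/`tearDown` hooks and the `assert*` helpers for free, which matters once resource cleanup is involved.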
It takes a vector input, with vec_size columns and returns an int32 scalar.
"""
x = tf.placeholder(tf.int32, shape=[_tensor_size], name=_tensor_input_name)
(Context: test requirements)
As we discussed earlier in a TensorFrames PR, we need to test graphs with float Tensors. This is because Spark uses double for vector types while TensorFlow is very strict about Tensor typing; mismatched types typically surface as pages of Catalyst error output that are hard to debug.
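The dtype discipline at issue can be shown with numpy alone (TensorFlow is not needed to make the point; the variable names are illustrative):

```python
import numpy as np

# Spark ML vectors are double precision (float64), while TensorFlow
# graphs are commonly built around float32 tensors. The cast has to be
# made explicit somewhere, or the mismatch surfaces far downstream.
spark_style = np.array([1.0, 2.0, 3.0], dtype=np.float64)  # what Spark hands over
graph_dtype = np.float32                                   # what the graph expects

assert spark_style.dtype != graph_dtype    # the silent-mismatch risk
fed = spark_style.astype(graph_dtype)      # the explicit, safe cast
assert fed.dtype == np.float32
```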
It takes a vector input, with vec_size columns and returns an int32 scalar.
"""
x = tf.placeholder(tf.int32, shape=[_tensor_size], name=_tensor_input_name)
_ = tf.reduce_max(x, name=_tensor_output_name)
We may also need to test more complex graphs to make sure the floating point computation does not differ too much. There have been cases where people spent a lot of time figuring out why different frameworks give different results for the same deep learning model.
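One way such a test could compare results: evaluate the same op at two precisions (standing in for two frameworks) and assert closeness with a tolerance rather than bitwise equality, which is too strict across implementations:

```python
import numpy as np

x = np.linspace(-1.0, 1.0, 7, dtype=np.float32)
ref = np.tanh(x.astype(np.float64))   # "reference" framework, float64
got = np.tanh(x).astype(np.float64)   # "other" framework, float32

# Bitwise equality would be brittle; tolerances catch real divergence
# while tolerating last-ulp differences.
assert np.allclose(got, ref, rtol=1e-5, atol=1e-6)
```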
_serving_tag = "serving_tag"
_serving_sigdef_key = 'prediction_signature'
# The name of the input tensor
_tensor_input_name = "input_tensor"
(Context: test requirements)
As Andy Feng has commented before, we need to test support for graphs with multiple inputs and multiple outputs.
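A plain-Python sketch (no TensorFlow) of the shape such a test would take: a graph with two named inputs and two named outputs, exercised by feeding all inputs and fetching all outputs in one call. All tensor names here are hypothetical:

```python
def run_graph(feeds):
    """Stand-in for session.run with multiple feeds and fetches."""
    a, b = feeds["input_a:0"], feeds["input_b:0"]
    return {"sum:0": a + b, "diff:0": a - b}

# A multi-input/multi-output case must check every output in one pass,
# not one input/output pair at a time.
fetched = run_graph({"input_a:0": 5.0, "input_b:0": 2.0})
assert fetched == {"sum:0": 7.0, "diff:0": 3.0}
```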
Here is the original comment:
databricks#39 (comment)
We ended up implementing support for multi-input-multi-output graphs in TFInputGraph.
builder.save()


@contextlib.contextmanager
(Context: testing infra)
In general, we should use fixtures for handling resources.
python/tests/graph/test_import.py
Outdated
finally:
    shutil.rmtree(temp_dir)

def _build_graph_input(gin_function):
(Context: design)
I am wondering, would we call this pattern dependency injection?
(sync'ed offline)
This pattern is used both in the proposed change and in the original test framework, though in the latter case it was implemented as a context manager. We both agreed that this pattern allows a sufficient amount of code to be shared, which is good.
_tensor_size = 3


def _build_checkpointed_model(session, tmp_dir):
(Context: design)
We actually modify the graph in the session by assigning signatures to it.
I am just trying to see if this is okay in terms of immutability.
(sync'ed offline)
Mutating the TensorFlow session/graph is acceptable, since we cannot modify the library we are using. Our goal is to limit the amount of mutability on our side.
gin = _build_graph_input_var(gin_fun)
_check_input_novar(gin)

def test_checkpoint_nosig_var(self):
(Context: test requirements)
This is not needed. Checkpoints are only useful for training, where the model contains variables.
python/tests/graph/test_import.py
Outdated
_check_input_novar(gin)

# TODO: we probably do not need this test
def test_saved_graph_novar(self):
(Context: test requirements)
This is actually useful. SavedModels are designed for inference/serving. It should be named test_saved_graph_novar.
Some additional comments on the design.
gin = TFInputGraph.fromGraph(sess.graph, sess, self.feed_names, self.fetch_names)
self.register(gin=gin, description='saved model with graph')

def build_from_checkpoint(self):
(From offline discussion)
@thunterdb commented that the tests are trying to satisfy more requirements than are needed.
To be specific, I was trying to design the test framework to test multiple different TensorFlow graphs, which is an unnecessary requirement.
# The basic stage contains the opaque :py:obj:`TFInputGraph` objects.
# Any derived class that overrides the :py:obj:`build_input_graphs`
# method will populate this field.
self.input_graphs = []
(sync'ed offline)
@thunterdb mentioned that the main complaint was that the fields input_graphs and test_cases are modified during test case construction, which might hurt the code's readability. This is, unfortunately, a workaround forced by limitations of Python's unittest framework.
[_tensor_input_name], [_tensor_output_name]))
_check_input_novar(gin)

def test_saved_model_novar(self):
(Context: design)
One issue with the current design is that, in order to test two types of graphs, we essentially doubled the amount of code (and actually missed test_graphdef_var, test_graph_var, etc.).
python/tests/graph/test_import.py
Outdated
_ = tf.reduce_max(x, name=_tensor_output_name)


def _build_graph_var(session):
(Context: design)
Readability-wise, a reader has to understand the workflow:
=> _build_graph_xxx (builds the original TensorFlow graph)
=> _build_graph_input_xxx (sets the graph up under an active TensorFlow session)
   => gin_function (specified in each individual test case), where _build_checkpointed_model or _build_saved_model injects additional data into the graph
=> _check_input_xxx (provides input and expected output for the TensorFlow graph)
=> _check_output_xxx (executes the TensorFlow graph and checks the result)
As core library developers, we understand this (it took some effort for me). I am not sure it is easy to understand for other contributors.
(Context: design)
In addition, each graph type represented by _build_graph_xxx needs every function ending in _xxx replicated for it. Although, as we have discussed, given the time constraints these concerns are deliberately set aside as non-essential.
As @thunterdb and I have discussed, I will merge this into my branch.
* intro: TFInputGraph
* tests
* Merge branch 'tf-transformer-part1' into tf-transformer-part3
* and so there is no helper classes
* and into more pieces
* class & docs
* update docs
* refactoring tfx API
* update tfx utils usage
* one way to build these tests
* tests refactored
* test cases in a single class
  This will make things easier when we want to extend other base class functions.
* shuffle things around
  Signed-off-by: Philip Yang <philip.yang@databricks.com>
* docs mostly
* yapf'd
* consolidate tempdir creation
* (wip) PR comments
* more tests
* change test generator module name
* TFTransformer Part-3 Test Refactor (#14)
* profiling
* tests
* renamed test
* removed original tests
* removed the profiler utils
* fixes indents
* imports
* added some tests
* added test
* fix test
* one more test
* PR comments
* TensorFlow Transformer Part-4 (#11)
* flat param API impl
* support input graph scenarios
* (WIP) new interface implementation
* docs and cleanup
* using tensorflow API instead of our utilities
* automatic type conversion
* cleanup
* PR comments
  1. Move `InputGraph` to its module.
* (WIP) address comments
* (WIP) respond to PR comments
* test refactor
* (wip) consolidating params
* rebase upstream
* import params fix
* (wip) TFInputGraph impl
* (wip) moving to new API
* (wip) enable saved_model tests
* (wip) enable checkpoint test
* (wip) enable multiple tensor tests
* enable all tests
* optimize graph for inference
* allows setting TFInputGraph
* utilize test_input_graph for transformer tests
* enable all tests
  Signed-off-by: Philip Yang <philip.yang@databricks.com>
* input graph
* docs
* tensor tests
* tensor test update
* TFTransformer Part-4 Test Refactor (#15)
* adding new tests
* remove original test design
* cleanup
* deleting original testing ideas
* PR comments
* update utils
* tests
* fix style
  Using the following YAPF style:
  ========================================================
  based_on_style = pep8
  ALIGN_CLOSING_BRACKET_WITH_VISUAL_INDENT=True
  BLANK_LINE_BEFORE_NESTED_CLASS_OR_DEF=False
  COLUMN_LIMIT=100
  SPACE_BETWEEN_ENDING_COMMA_AND_CLOSING_BRACKET=False
  SPLIT_ARGUMENTS_WHEN_COMMA_TERMINATED=True
  SPLIT_BEFORE_FIRST_ARGUMENT=False
  SPLIT_BEFORE_NAMED_ASSIGNS=False
  SPLIT_PENALTY_AFTER_OPENING_BRACKET=30
  USE_TABS=False
  ========================================================
* refactoring tfx API
* test refactoring
* PR comments
  1. docs in graph/utils.py
* (wip) utils test
* a few more tests for utils
* test update cont'd
* PR comments
* PR comments
* PR comments
* TensorFlow Transformer Part-3 (#10)
* intro: TFInputGraph
* tests
* Merge branch 'tf-transformer-part1' into tf-transformer-part3
* and so there is no helper classes
* and into more pieces
* class & docs
* update docs
* refactoring tfx API
* update tfx utils usage
* one way to build these tests
* tests refactored
* test cases in a single class
  This will make things easier when we want to extend other base class functions.
* shuffle things around
  Signed-off-by: Philip Yang <philip.yang@databricks.com>
* docs mostly
* yapf'd
* consolidate tempdir creation
* (wip) PR comments
* more tests
* change test generator module name
* TFTransformer Part-3 Test Refactor (#14)
* profiling
* tests
* renamed test
* removed original tests
* removed the profiler utils
* fixes indents
* imports
* added some tests
* added test
* fix test
* one more test
* PR comments
* TensorFlow Transformer Part-4 (#11)
* flat param API impl
* support input graph scenarios
* (WIP) new interface implementation
* docs and cleanup
* using tensorflow API instead of our utilities
* automatic type conversion
* cleanup
* PR comments
  1. Move `InputGraph` to its module.
* (WIP) address comments
* (WIP) respond to PR comments
* test refactor
* (wip) consolidating params
* rebase upstream
* import params fix
* (wip) TFInputGraph impl
* (wip) moving to new API
* (wip) enable saved_model tests
* (wip) enable checkpoint test
* (wip) enable multiple tensor tests
* enable all tests
* optimize graph for inference
* allows setting TFInputGraph
* utilize test_input_graph for transformer tests
* enable all tests
  Signed-off-by: Philip Yang <philip.yang@databricks.com>
* input graph
* docs
* tensor tests
* tensor test update
* TFTransformer Part-4 Test Refactor (#15)
* adding new tests
* remove original test design
* cleanup
* deleting original testing ideas
* PR comments
@phi-dbq here is an alternative architecture for the tests. It does not exactly match all the tests that you wrote, but it should cover all the important use cases that you outlined.
Some points to note:
Please comment on the PR, so that we can merge it into your branch.