TensorFlow Transformer Part-4 #11
1. Move `InputGraph` to its module.
e46eedb
to
2722db2
Compare
Signed-off-by: Philip Yang <philip.yang@databricks.com>
2722db2
to
84a8138
Compare
Codecov Report
@@ Coverage Diff @@
## tf-transformer-part3 #11 +/- ##
========================================================
+ Coverage 84.01% 85.59% +1.58%
========================================================
Files 25 29 +4
Lines 1376 1673 +297
Branches 5 15 +10
========================================================
+ Hits 1156 1432 +276
- Misses 220 241 +21
52410c9
to
ca951b0
Compare
ca951b0
to
0144b8c
Compare
* adding new tests * remove original test design * cleanup
@phi-dbq this looks good. I just have a few comments that will be quick to address.
@@ -92,6 +91,50 @@ def __init__(self, graph_def, input_tensor_name_from_signature,
        self.input_tensor_name_from_signature = input_tensor_name_from_signature
        self.output_tensor_name_from_signature = output_tensor_name_from_signature

    def translateInputMapping(self, input_mapping):
These two functions are either too much or not enough: either provide some tests and some doc examples, or don't include them. Since they are not used elsewhere, let's put them in a separate PR for now.
Well, they are actually in our API design. Let me add some tests for these guys.
Ok I see it below.
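For reference, a minimal sketch of what a doc example and test for `translateInputMapping` might look like. The method body is not shown in this thread, so the behavior below (looking up each column's signature key in `input_tensor_name_from_signature`) is an assumption inferred from the attribute names in the hunk, and `FakeInputGraph` is a hypothetical stand-in rather than the real `TFInputGraph`.

```python
class FakeInputGraph(object):
    """Hypothetical stand-in for TFInputGraph; only the mapping logic is sketched."""

    def __init__(self, input_tensor_name_from_signature):
        # Maps a signature key (e.g. 'input_sig') to a tensor name (e.g. 'x:0').
        self.input_tensor_name_from_signature = input_tensor_name_from_signature

    def translateInputMapping(self, input_mapping):
        """Assumed behavior: turn {column: signature_key} into {column: tensor_name}."""
        if self.input_tensor_name_from_signature is None:
            raise ValueError("graph was not built from a signature")
        return {col: self.input_tensor_name_from_signature[sig_key]
                for col, sig_key in dict(input_mapping).items()}


gin = FakeInputGraph({'input_sig': 'input_tensor:0'})
translated = gin.translateInputMapping({'features_col': 'input_sig'})
```

A test along these lines would only need to check that each DataFrame column ends up paired with the tensor name registered under its signature key.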
for idx in range(100):
    _dict = {'idx': idx}
    for colname, _ in _input_mapping.items():
        _dict[colname] = np.random.randn(_tensor_size).tolist()
let's use something deterministic instead
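One way to make this data deterministic is to draw it from a locally seeded generator. The sketch below uses only the standard library so it runs without NumPy; with NumPy, `np.random.RandomState(seed).randn(n)` would play the same role. The helper name `make_rows` and the seed value are illustrative assumptions, not part of the PR.

```python
import random


def make_rows(input_mapping, tensor_size, num_rows=100, seed=42):
    """Build test rows from a seeded generator so every run sees the same data."""
    rng = random.Random(seed)  # local generator; global random state is untouched
    rows = []
    for idx in range(num_rows):
        row = {'idx': idx}
        for colname in input_mapping:
            row[colname] = [rng.gauss(0.0, 1.0) for _ in range(tensor_size)]
        rows.append(row)
    return rows


rows_a = make_rows({'col_a': 'x:0'}, tensor_size=3)
rows_b = make_rows({'col_a': 'x:0'}, tensor_size=3)
```

Because both calls start from the same seed, `rows_a` and `rows_b` are identical, so a failing test can be reproduced exactly.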
from ..tests import SparkDLTestCase


class TFTransformerTests(SparkDLTestCase):
    def test_graph_novar(self):
great test, we will be able to add some more pretty easily after that
from pyspark.ml import Transformer

import sparkdl.graph.utils as tfx
from sparkdl.graph.input import TFInputGraph
unused
import sparkdl.graph.utils as tfx
from sparkdl.graph.input import TFInputGraph
from sparkdl.param import (keyword_only, SparkDLTypeConverters, HasInputMapping,
`SparkDLTypeConverters` is unused
Restrictions of the current API:

We assume that
- All graphs have a "minibatch" dimension (i.e. an unknown leading
All the inputs of the graph
Restrictions of the current API:

We assume that
The transformer is expected to work on blocks of data at the same time.
Added these 3 lines below.
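The "minibatch" restriction discussed above can be illustrated with a small check. This is a sketch only: it assumes shapes are given as Python lists with `None` for unknown dimensions (the form `TensorShape.as_list()` produces), and both helper names are hypothetical rather than part of sparkdl.

```python
def has_minibatch_dim(shape):
    """True if the leading dimension is unknown (None), i.e. a batch axis."""
    return len(shape) > 0 and shape[0] is None


def check_inputs(input_shapes):
    """Return the names of inputs that violate the block-processing assumption."""
    return sorted(name for name, shape in input_shapes.items()
                  if not has_minibatch_dim(shape))


# 'x:0' accepts a block of rows at once; 'y:0' has a fixed shape and does not.
bad = check_inputs({'x:0': [None, 10], 'y:0': [10]})
```

An input without an unknown leading dimension cannot be fed a whole block of rows, which is why the docstring states the assumption for all inputs of the graph.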
output_node_names = [tfx.op_name(tnsr_name) for tnsr_name, _ in output_mapping]

# NOTE(phi-dbq): Spark DataFrame assumes float64 as default floating point type
opt_gdef = infr_opt.optimize_for_inference(gin.graph_def,
Great find! I could not find any guarantees about the stability of this function (it is part of a program). Do you know if it could become deprecated in the future?
The API looks pretty solid. I didn't see any explicit sign of this being deprecated in the near future.
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/optimize_for_inference_lib.py#L89-L104
opt_gdef = infr_opt.optimize_for_inference(gin.graph_def,
                                           input_node_names,
                                           output_node_names,
                                           tf.float64.as_datatype_enum)
With respect to your note, put a comment here: this is the place where we will eventually have to make changes.
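`optimize_for_inference` takes node (op) names rather than tensor names, which is why the hunk maps each tensor name through `tfx.op_name` first. TensorFlow tensor names have the form `<op_name>:<output_index>`, so the conversion amounts to dropping the numeric suffix. The function below is a plain-Python sketch of that idea, not sparkdl's actual `tfx.op_name`, and the sample mapping is made up for illustration.

```python
def op_name(tensor_name):
    """Strip the ':<index>' output suffix, e.g. 'dense/out:0' -> 'dense/out'."""
    name, sep, index = tensor_name.rpartition(':')
    # Only treat the suffix as an output index if it is numeric.
    if sep and index.isdigit():
        return name
    return tensor_name


# Hypothetical (tensor_name, output_column) pairs, as in the hunk above.
output_mapping = [('dense/out:0', 'prediction')]
output_node_names = [op_name(tnsr_name) for tnsr_name, _ in output_mapping]
```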
out_df = tfs.map_blocks(fetches, analyzed_df, feed_dict=feed_dict)

# We still have to rename output columns
Indeed, we still need to do that, annoyingly.
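The renaming step can be sketched as computing (old, new) column pairs from the output mapping. This assumes the block-mapping step names each output column after the fetched op, which is not confirmed in this thread; `rename_plan` and the sample mapping are hypothetical.

```python
def rename_plan(output_mapping):
    """Return (old_column, new_column) pairs, skipping no-op renames."""
    plan = []
    for tnsr_name, new_colname in output_mapping:
        # Assumed: the output column is named after the op, i.e. the tensor
        # name without its ':<index>' suffix.
        old_colname = tnsr_name.split(':')[0]
        if old_colname != new_colname:
            plan.append((old_colname, new_colname))
    return plan


plan = rename_plan([('out_a:0', 'prediction'), ('score:0', 'score')])
```

With PySpark, each pair in the plan would then be applied with `DataFrame.withColumnRenamed(old, new)`.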
6687eab
to
95cf964
Compare
95cf964
to
af95b74
Compare
@thunterdb I addressed the last round of comments. Would you like to take a look?
Looks good to me.
* intro: TFInputGraph * tests * Merge branch 'tf-transformer-part1' into tf-transformer-part3 * and so there is no helper classes * and into more pieces * class & docs * update docs * refactoring tfx API * update tfx utils usage * one way to build these tests * tests refactored * test cases in a single class THis will make things easier when we want to extend other base class functions. * shuffle things around Signed-off-by: Philip Yang <philip.yang@databricks.com> * docs mostly * yapf'd * consolidate tempdir creation * (wip) PR comments * more tests * change test generator module name * TFTransformer Part-3 Test Refactor (#14) * profiling * tests * renamed test * removed original tests * removed the profiler utils * fixes indents * imports * added some tests * added test * fix test * one more test * PR comments * TensorFlow Transformer Part-4 (#11) * flat param API impl * support input graph scenarios * (WIP) new interface implementation * docs and cleanup * using tensorflow API instead of our utilities * automatic type conversion * cleanup * PR comments 1. Move `InputGraph` to its module. * (WIP) address comments * (WIP) respond to PR comments * test refactor * (wip) consolidating params * rebase upstream * import params fix * (wip) TFInputGraph impl * (wip) moving to new API * (wip) enable saved_model tests * (wip) enable checkpoint test * (wip) enable multiple tensor tests * enable all tests * optimize graph for inference * allows setting TFInputGraph * utilize test_input_graph for transformer tests * enable all tests Signed-off-by: Philip Yang <philip.yang@databricks.com> * input graph * docs * tensor tests * tensor test update * TFTransformer Part-4 Test Refactor (#15) * adding new tests * remove original test design * cleanup * deleting original testing ideas * PR comments
* update utils * tests * fix style Using the following YAPF style ======================================================== based_on_style = pep8 ALIGN_CLOSING_BRACKET_WITH_VISUAL_INDENT=True BLANK_LINE_BEFORE_NESTED_CLASS_OR_DEF=False COLUMN_LIMIT=100 SPACE_BETWEEN_ENDING_COMMA_AND_CLOSING_BRACKET=False SPLIT_ARGUMENTS_WHEN_COMMA_TERMINATED=True SPLIT_BEFORE_FIRST_ARGUMENT=False SPLIT_BEFORE_NAMED_ASSIGNS=False SPLIT_PENALTY_AFTER_OPENING_BRACKET=30 USE_TABS=False ======================================================== * refactoring tfx API * test refactoring * PR comments 1. docs in graph/utils.py * (wip) utils test * a few more tests for utils * test update cont'd * PR comments * PR comments * PR comments * TensorFlow Transformer Part-3 (#10) * intro: TFInputGraph * tests * Merge branch 'tf-transformer-part1' into tf-transformer-part3 * and so there is no helper classes * and into more pieces * class & docs * update docs * refactoring tfx API * update tfx utils usage * one way to build these tests * tests refactored * test cases in a single class THis will make things easier when we want to extend other base class functions. * shuffle things around Signed-off-by: Philip Yang <philip.yang@databricks.com> * docs mostly * yapf'd * consolidate tempdir creation * (wip) PR comments * more tests * change test generator module name * TFTransformer Part-3 Test Refactor (#14) * profiling * tests * renamed test * removed original tests * removed the profiler utils * fixes indents * imports * added some tests * added test * fix test * one more test * PR comments * TensorFlow Transformer Part-4 (#11) * flat param API impl * support input graph scenarios * (WIP) new interface implementation * docs and cleanup * using tensorflow API instead of our utilities * automatic type conversion * cleanup * PR comments 1. Move `InputGraph` to its module. 
* (WIP) address comments * (WIP) respond to PR comments * test refactor * (wip) consolidating params * rebase upstream * import params fix * (wip) TFInputGraph impl * (wip) moving to new API * (wip) enable saved_model tests * (wip) enable checkpoint test * (wip) enable multiple tensor tests * enable all tests * optimize graph for inference * allows setting TFInputGraph * utilize test_input_graph for transformer tests * enable all tests Signed-off-by: Philip Yang <philip.yang@databricks.com> * input graph * docs * tensor tests * tensor test update * TFTransformer Part-4 Test Refactor (#15) * adding new tests * remove original test design * cleanup * deleting original testing ideas * PR comments
The implementation of `TFTransformer` based on previous steps.