-
Notifications
You must be signed in to change notification settings - Fork 0
TensorFlow Transformer Part-2 #9
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Codecov Report
@@ Coverage Diff @@
## tf-transformer-part1 #9 +/- ##
========================================================
+ Coverage 82.66% 83.13% +0.47%
========================================================
Files 24 24
Lines 1281 1293 +12
Branches 5 5
========================================================
+ Hits 1059 1075 +16
+ Misses 222 218 -4
Continue to review full report at Codecov.
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you add a description to the PR so I can understand why this change is needed?
python/tests/tests.py
Outdated
@@ -29,6 +29,9 @@ | |||
from pyspark.sql import SQLContext | |||
from pyspark.sql import SparkSession | |||
|
|||
class PythonUnitTestCase(unittest.TestCase): | |||
# Just the plain test unittest.TestCase, but won't have to do import check |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
do you mean the unittest2 check above? if so make that explicit.
python/tests/graph/test_utils.py
Outdated
tnsr = tf.constant(1427.08, name=op_name) | ||
graph = tnsr.graph | ||
|
||
self.assertEqual(op_name, tfx.as_op_name(tnsr)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
what's the difference between as_op_name and op_name ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
op_name
takes a graph and a tfobj_or_name
argument, and check if the latter is part of the former.
as_op_name
returns a valid tensor name if the input is a string. Otherwise, it uses the graph of the tfobj
and call op_name
.
Thus op_name
is a bit more strict than as_op_name
.
python/tests/graph/test_utils.py
Outdated
from ..tests import PythonUnitTestCase | ||
|
||
class TFeXtensionGraphUtilsTest(PythonUnitTestCase): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
each test function should test a particular aspect of a particular function in the tfx module, so it's easy to understand what exactly is being tested, and thus if a test fails it's immediately obvious what exactly failed. could you restructure so that is true? e.g. a few I see embedded in the existing tests below are:
def test_op_name():
# test that op_name from an op name equals input
def test_op_name_from_tensor():
# test that op_name from a tensor is correct
...
If there is shared set-up code, you can make helper functions for them.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
pending on the result on test refactoring in PR-1
Using the following YAPF style ======================================================== based_on_style = pep8 ALIGN_CLOSING_BRACKET_WITH_VISUAL_INDENT=True BLANK_LINE_BEFORE_NESTED_CLASS_OR_DEF=False COLUMN_LIMIT=100 SPACE_BETWEEN_ENDING_COMMA_AND_CLOSING_BRACKET=False SPLIT_ARGUMENTS_WHEN_COMMA_TERMINATED=True SPLIT_BEFORE_FIRST_ARGUMENT=False SPLIT_BEFORE_NAMED_ASSIGNS=False SPLIT_PENALTY_AFTER_OPENING_BRACKET=30 USE_TABS=False ========================================================
7287ab7
to
4f11374
Compare
Can you say in the PR description which PR this is used in? Trying to understand how it is used (outside of tests). |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
i think i understand op_name and as_op_name better now. will we break a lot of things if we changed this function to be def op_name(tfobj_or_name, graph=None) and rolled in as_op_name into it? Same for tensor_name.
then we have only one function, with the assumption that if the graph is given, we'll verify that the tensor / op name is legit.
having all these functions that are almost the same is confusing for the user.
7ae89e4
to
76e9fb9
Compare
06280b2
to
202e7ea
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Consolidated a bunch of (as_|get_)?(tensor|op)(_name)?
functions as per the comments.
202e7ea
to
ead1ed6
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
looks pretty good. thanks for consolidating the op_name tensor_name functions! mostly small comments below.
python/sparkdl/graph/utils.py
Outdated
""" | ||
Derive tf.Tensor name from an op/tensor name. | ||
We do not check if the tensor exist (as no graph parameter is passed in). | ||
If the input is a name, we do not check if the tensor exist |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
"If the input is a TensorFlow object, or the graph is given, we also check that the tensor exists in the associated graph."
python/sparkdl/graph/utils.py
Outdated
tnsr = graph.get_tensor_by_name(_tensor_name) | ||
assert tnsr is not None, \ | ||
'cannot locate tensor {} in current graph'.format(_tensor_name) | ||
return tnsr | ||
|
||
def as_tensor_name(name): | ||
def tensor_name(tfobj_or_name, graph=None): | ||
""" | ||
Derive tf.Tensor name from an op/tensor name. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Derive tf.Tensor name from a op/tensor name or object.
python/sparkdl/graph/utils.py
Outdated
:param tfobj_or_name: either a tf.Tensor, tf.Operation or a name to either | ||
""" | ||
graph = validated_graph(graph) | ||
return get_op(graph, tfobj_or_name).name | ||
# If `graph` is provided, directly get the graph operation |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
no need for comment
python/sparkdl/graph/utils.py
Outdated
# If `graph` is provided, directly get the graph operation | ||
if graph is not None: | ||
return get_tensor(tfobj_or_name, graph).name | ||
# If `graph` is absent, check if other cases |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
no need for comment
python/sparkdl/graph/utils.py
Outdated
""" | ||
Get the name of a tf.Tensor | ||
Derive tf.Operation name from an op/tensor name. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
"name or object."
python/tests/graph/test_utils.py
Outdated
|
||
class TFeXtensionGraphUtilsTest(PythonUnitTestCase): | ||
@parameterized.expand(_gen_graph_elems_names) | ||
def test_valid_graph_element_names(self, data, description): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
test_valid_tensor_op_string_inputs
is more clear.
python/tests/graph/test_utils.py
Outdated
yield TestCase(data=(tnsr_name, tfx.tensor_name(tnsr, graph)), | ||
description='get tensor name from tensor (with graph)') | ||
yield TestCase(data=(tnsr, tfx.get_tensor(tnsr, graph)), | ||
description='get tensor from the same tensor (with graph)') |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
the four cases (this and below) require graph
as input, so we don't need to add "(with graph)" in the description. having it makes it seem like there should be "(no graph)" versions.
python/tests/graph/test_utils.py
Outdated
tnsr = tf.constant(1427.08, name=op_name) | ||
graph = tnsr.graph | ||
|
||
yield TestCase(data=(op_name, tfx.op_name(tnsr)), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
since there are so many test cases here, let's add comments in between sections, e.g.
# Tests for op_name
yield ..
...
# Tests for tensor_name
...
# Tests for get_op
...
# Tests for get_tensor
...
# ...
python/tests/graph/test_utils.py
Outdated
description='get tensor from the same tensor (with graph)') | ||
yield TestCase(data=(tnsr.op, tfx.get_op(tnsr, graph)), | ||
description='get op from tensor (with graph)') | ||
yield TestCase(data=(graph, tfx.get_op(tnsr, graph).graph), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why would we be worried about this not being true? is there some complex logic in get_op
we can mess up?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
One basic contract for the get_op
and get_tensor
functions is that the returned tensor or operation should belong to the graph that we pass as the second argument.
Otherwise, there could be subtle errors that are hard to debug. For example, we might feed data to a tensor/placeholder on the wrong graph.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That's a statement about handling bad input. It'd make more sense to test that it raises error if the graph and the op are not compatible, no?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That sounds like a statement about handling bad input. We should also test that these raise errors if the graph and the op are not compatible.
python/tests/graph/test_utils.py
Outdated
|
||
yield TestCase(data=(op_name, tfx.op_name(tnsr)), | ||
description='get op name from tensor (no graph)') | ||
yield TestCase(data=(op_name, tfx.op_name(tnsr, graph)), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we should have tests for getting op_name and tesor_name from op objects as well
1. docs in graph/utils.py
@sueann I updated the PR regarding the last round of comments. |
python/tests/graph/test_utils.py
Outdated
description='get tensor from the same tensor (with graph)') | ||
yield TestCase(data=(tnsr.op, tfx.get_op(tnsr, graph)), | ||
description='get op from tensor (with graph)') | ||
yield TestCase(data=(graph, tfx.get_op(tnsr, graph).graph), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That's a statement about handling bad input. It'd make more sense to test that it raises error if the graph and the op are not compatible, no?
python/tests/graph/test_utils.py
Outdated
for wrong_val in [7, 1.2, tf.Graph()]: | ||
yield TestCase(data=wrong_val, description='wrong type {}'.format(type(wrong_val))) | ||
|
||
|
||
def _gen_graph_elems(): | ||
def _gen_valid_tensor_op_objects(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
now it seems to be testing all input combos. if so, rename to _gen_valid_tensor_op_input_combos
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
just a few comments
python/tests/graph/test_utils.py
Outdated
description='get tensor from the same tensor (with graph)') | ||
yield TestCase(data=(tnsr.op, tfx.get_op(tnsr, graph)), | ||
description='get op from tensor (with graph)') | ||
yield TestCase(data=(graph, tfx.get_op(tnsr, graph).graph), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That sounds like a statement about handling bad input. We should also test that these raise errors if the graph and the op are not compatible.
c53f0a7
to
26c8f24
Compare
26c8f24
to
d729528
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
almost there! just a few nits. thanks!
python/sparkdl/graph/utils.py
Outdated
@@ -71,8 +72,7 @@ def get_op(tfobj_or_name, graph): | |||
raise TypeError('invalid op request for [type {}] {}'.format(type(name), name)) | |||
_op_name = op_name(name, graph=None) | |||
op = graph.get_operation_by_name(_op_name) | |||
assert op is not None, \ | |||
'cannot locate op {} in current graph'.format(_op_name) | |||
assert isinstance(op, tf.Operation), 'expect tf.Operation, but got {}'.format(type(op)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think the previous error message was more clear. the current one sounds a little bit like you expected the user to give us the operation object, not our internal invocation of get_operation_by_name
.
python/sparkdl/graph/utils.py
Outdated
|
||
|
||
def _assert_same_graph(tfobj, graph): | ||
if graph is None or not hasattr(tfobj, 'graph'): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
simpler (no need for the extra return statement):
if graph is not None and hasattr(tfobj, 'graph'):
assert ...
python/tests/graph/test_utils.py
Outdated
|
||
# Test get_tensor and get_op returns tensor or op contained in the same graph | ||
yield TestCase(data=lambda: tfx.get_op(tnsr, other_graph), | ||
description='test graph from getting op fron tensor') |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
'test get_op from tensor with wrong graph' is more clear (similarly for other testcases in this method)
python/tests/graph/test_utils.py
Outdated
other_graph = tf.Graph() | ||
op_name = tnsr.op.name | ||
|
||
# Test get_tensor and get_op returns tensor or op contained in the same graph |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
'Test get_tensor and get_op with non-associated tensor/op and graph inputs'
python/tests/graph/test_utils.py
Outdated
yield TestCase(data=(tnsr.op, tfx.get_op(op_name, graph)), | ||
description='test op from op name') | ||
|
||
# Test get_tensor and get_op returns tensor or op contained in the same graph |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
these tests don't seem necessary now that we have the asserts you've most recently added (we invoke the same function calls in other tests and they will fail if the asserts fail). let's remove them to save on testing and reading time. i like this new way of having the asserts in the methods much better than the tests here, since what we're checking here are really implementation details.
6947e62
to
0c2eda1
Compare
python/tests/graph/test_utils.py
Outdated
@@ -128,24 +129,6 @@ def _gen_valid_tensor_op_input_combos(): | |||
yield TestCase(data=(tnsr.op, tfx.get_op(op_name, graph)), | |||
description='test op from op name') | |||
|
|||
# Test get_tensor and get_op returns tensor or op contained in the same graph |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Removed these tests as they are covered by other test cases.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Updated regarding latest PR comments.
python/tests/graph/test_utils.py
Outdated
@@ -47,23 +47,24 @@ def _gen_invalid_tensor_or_op_with_graph_pairing(): | |||
other_graph = tf.Graph() | |||
op_name = tnsr.op.name | |||
|
|||
# Test get_tensor and get_op returns tensor or op contained in the same graph | |||
# Test get_tensor and get_op with non-associated tensor/op and graph inputs | |||
_comm_suffix = ' with non-associated tensor/op and graph inputs' |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Renamed the description to clarify the intention of each test case.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
' with wrong graph' seems good enough here and more readable.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
just to be clear, my comment was for _comm_suffix
. the long way to say it for the comment is good since there are now two ways the same thing is described to the user - so it should decrease uncertainty if one of them doesn't quite make sense to the reader.
python/sparkdl/graph/utils.py
Outdated
@@ -216,7 +218,6 @@ def strip_and_freeze_until(fetches, graph, sess=None, return_graph=False): | |||
|
|||
|
|||
def _assert_same_graph(tfobj, graph): | |||
if graph is None or not hasattr(tfobj, 'graph'): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Removed the first return statement.
python/sparkdl/graph/utils.py
Outdated
@@ -72,12 +72,13 @@ def get_op(tfobj_or_name, graph): | |||
raise TypeError('invalid op request for [type {}] {}'.format(type(name), name)) | |||
_op_name = op_name(name, graph=None) | |||
op = graph.get_operation_by_name(_op_name) | |||
assert isinstance(op, tf.Operation), 'expect tf.Operation, but got {}'.format(type(op)) | |||
err_msg = 'cannot locate op {} in the current graph, got [type {}] {}' | |||
assert isinstance(op, tf.Operation), err_msg.format(_op_name, type(op), op) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Changed error message to better align to the intended usage of the function rather than focusing locally on this assertion statement.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lgtm. a few nits for you to fix, but will approve (no further review required).
python/tests/graph/test_utils.py
Outdated
@@ -47,23 +47,24 @@ def _gen_invalid_tensor_or_op_with_graph_pairing(): | |||
other_graph = tf.Graph() | |||
op_name = tnsr.op.name | |||
|
|||
# Test get_tensor and get_op returns tensor or op contained in the same graph | |||
# Test get_tensor and get_op with non-associated tensor/op and graph inputs | |||
_comm_suffix = ' with non-associated tensor/op and graph inputs' |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
' with wrong graph' seems good enough here and more readable.
python/tests/graph/test_utils.py
Outdated
yield TestCase(data=lambda: tfx.get_op(tnsr, other_graph), | ||
description='test graph from getting op fron tensor') | ||
description='test get_op with from tensor' + _comm_suffix) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
remove "with"
python/tests/graph/test_utils.py
Outdated
yield TestCase(data=lambda: tfx.get_op(tnsr.name, other_graph), | ||
description='test graph from getting op fron tensor name') | ||
description='test get_op fron tensor name' + _comm_suffix) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fron -> from
e033795
to
63967b4
Compare
* intro: TFInputGraph * tests * Merge branch 'tf-transformer-part1' into tf-transformer-part3 * and so there is no helper classes * and into more pieces * class & docs * update docs * refactoring tfx API * update tfx utils usage * one way to build these tests * tests refactored * test cases in a single class THis will make things easier when we want to extend other base class functions. * shuffle things around Signed-off-by: Philip Yang <philip.yang@databricks.com> * docs mostly * yapf'd * consolidate tempdir creation * (wip) PR comments * more tests * change test generator module name * TFTransformer Part-3 Test Refactor (#14) * profiling * tests * renamed test * removed original tests * removed the profiler utils * fixes indents * imports * added some tests * added test * fix test * one more test * PR comments * TensorFlow Transformer Part-4 (#11) * flat param API impl * support input graph scenarios * (WIP) new interface implementation * docs and cleanup * using tensorflow API instead of our utilities * automatic type conversion * cleanup * PR comments 1. Move `InputGraph` to its module. * (WIP) address comments * (WIP) respond to PR comments * test refactor * (wip) consolidating params * rebase upstream * import params fix * (wip) TFInputGraph impl * (wip) moving to new API * (wip) enable saved_model tests * (wip) enable checkpoint test * (wip) enable multiple tensor tests * enable all tests * optimize graph for inference * allows setting TFInputGraph * utilize test_input_graph for transformer tests * enable all tests Signed-off-by: Philip Yang <philip.yang@databricks.com> * input graph * docs * tensor tests * tensor test update * TFTransformer Part-4 Test Refactor (#15) * adding new tests * remove original test design * cleanup * deleting original testing ideas * PR comments
Utilities to infer tensor and operation names:
as_tensor_name
andas_op_name
.They implement the following functionalities.
{tf.Tensor, tf.Operation, tf.Tensor name, tf.Operation name} |-> name of a tf.Operation
{tf.Tensor, tf.Operation, tf.Tensor name, tf.Operation name} |-> name of a tf.Tensor
These two methods differ from existing ones in that if the input is already a string type, they will simply go and infer the name of the graph element (either tensor or operation).
These are useful when working outside the context of a
tf.Graph
object.Existing methods such:
tensor_name
,op_name
,get_tensor
andget_op
all require atf.Graph
object as input.