-
Notifications
You must be signed in to change notification settings - Fork 0
TensorFlow Transformer Part-2 #9
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 2 commits
0c8c219
522279a
7287ab7
467480e
27f0617
93e659d
f8d7930
aef1661
ead1ed6
055ce14
2d48b32
742cdaf
f0912fb
3c849f2
d729528
0c2eda1
63967b4
de088a9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,8 +16,6 @@ | |
|
||
import logging | ||
import six | ||
import webbrowser | ||
from tempfile import NamedTemporaryFile | ||
|
||
import tensorflow as tf | ||
|
||
|
@@ -95,31 +93,49 @@ def get_tensor(graph, tfobj_or_name): | |
'cannot locate tensor {} in current graph'.format(_tensor_name) | ||
return tnsr | ||
|
||
def as_tensor_name(name): | ||
def as_tensor_name(tfobj_or_name): | ||
""" | ||
Derive tf.Tensor name from an op/tensor name. | ||
We do not check if the tensor exist (as no graph parameter is passed in). | ||
If the input is a name, we do not check if the tensor exist | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. "If the input is a TensorFlow object, or the graph is given, we also check that the tensor exists in the associated graph." |
||
(as no graph parameter is passed in). | ||
|
||
:param name: op name or tensor name | ||
:param tfobj_or_name: either a tf.Tensor, tf.Operation or a name to either | ||
""" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. :param graph: should go here? |
||
assert isinstance(name, six.string_types) | ||
name_parts = name.split(":") | ||
assert len(name_parts) <= 2, name_parts | ||
if len(name_parts) < 2: | ||
name += ":0" | ||
return name | ||
if isinstance(tfobj_or_name, six.string_types): | ||
# If input is a string, assume it is a name and infer the corresponding tensor name. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. no need for comment |
||
# WARNING: this depends on TensorFlow's tensor naming convention | ||
name = tfobj_or_name | ||
name_parts = name.split(":") | ||
assert len(name_parts) <= 2, name_parts | ||
if len(name_parts) < 2: | ||
name += ":0" | ||
return name | ||
elif hasattr(tfobj_or_name, 'graph'): | ||
tfobj = tfobj_or_name | ||
return get_tensor(tfobj.graph, tfobj).name | ||
else: | ||
raise TypeError('invalid tf.Tensor name query type {}'.format(type(tfobj_or_name))) | ||
|
||
def as_op_name(name): | ||
def as_op_name(tfobj_or_name): | ||
""" | ||
Derive tf.Operation name from an op/tensor name | ||
We do not check if the operation exist (as no graph parameter is passed in). | ||
Derive tf.Operation name from an op/tensor name. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. "name or object." |
||
If the input is a name, we do not check if the operation exist | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. "If the input is a TensorFlow object, or the graph is given, we also check that the tensor exists in the associated graph." |
||
(as no graph parameter is passed in). | ||
|
||
:param name: op name or tensor name | ||
:param tfobj_or_name: either a tf.Tensor, tf.Operation or a name to either | ||
""" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. :param graph: should go here? |
||
assert isinstance(name, six.string_types) | ||
name_parts = name.split(":") | ||
assert len(name_parts) <= 2, name_parts | ||
return name_parts[0] | ||
if isinstance(tfobj_or_name, six.string_types): | ||
# If input is a string, assume it is a name and infer the corresponding operation name. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. no need for comment |
||
# WARNING: this depends on TensorFlow's operation naming convention | ||
name = tfobj_or_name | ||
name_parts = name.split(":") | ||
assert len(name_parts) <= 2, name_parts | ||
return name_parts[0] | ||
elif hasattr(tfobj_or_name, 'graph'): | ||
tfobj = tfobj_or_name | ||
return get_op(tfobj.graph, tfobj).name | ||
else: | ||
raise TypeError('invalid tf.Operation name query type {}'.format(type(tfobj_or_name))) | ||
|
||
def op_name(graph, tfobj_or_name): | ||
""" | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
# Copyright 2017 Databricks, Inc. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
|
||
import tensorflow as tf | ||
|
||
import sparkdl.graph.utils as tfx | ||
|
||
from ..tests import PythonUnitTestCase | ||
|
||
class TFeXtensionGraphUtilsTest(PythonUnitTestCase): | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. each test function should test a particular aspect of a particular function in the tfx module, so it's easy to understand what exactly is being tested, and thus if a test fails it's immediately obvious what exactly failed. could you restructure so that is true? e.g. a few I see embedded in the existing tests below are:
If there is shared set-up code, you can make helper functions for them. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. pending on the result on test refactoring in PR-1 |
||
def test_infer_graph_element_names(self): | ||
for tnsr_idx in range(17): | ||
op_name = 'someOp' | ||
tnsr_name = '{}:{}'.format(op_name, tnsr_idx) | ||
self.assertEqual(op_name, tfx.as_op_name(tnsr_name)) | ||
self.assertEqual(tnsr_name, tfx.as_tensor_name(tnsr_name)) | ||
|
||
with self.assertRaises(TypeError): | ||
for wrong_value in [7, 1.2, tf.Graph()]: | ||
tfx.as_op_name(wrong_value) | ||
tfx.as_tensor_name(wrong_value) | ||
|
||
def test_get_graph_elements(self): | ||
op_name = 'someConstOp' | ||
tnsr_name = '{}:0'.format(op_name) | ||
tnsr = tf.constant(1427.08, name=op_name) | ||
graph = tnsr.graph | ||
|
||
self.assertEqual(op_name, tfx.as_op_name(tnsr)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. what's the difference between as_op_name and op_name ? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
self.assertEqual(op_name, tfx.op_name(graph, tnsr)) | ||
self.assertEqual(tnsr_name, tfx.as_tensor_name(tnsr)) | ||
self.assertEqual(tnsr_name, tfx.tensor_name(graph, tnsr)) | ||
self.assertEqual(tnsr, tfx.get_tensor(graph, tnsr)) | ||
self.assertEqual(tnsr.op, tfx.get_op(graph, tnsr)) | ||
self.assertEqual(graph, tfx.get_op(graph, tnsr).graph) | ||
self.assertEqual(graph, tfx.get_tensor(graph, tnsr).graph) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -29,6 +29,9 @@ | |
from pyspark.sql import SQLContext | ||
from pyspark.sql import SparkSession | ||
|
||
class PythonUnitTestCase(unittest.TestCase): | ||
# Just the plain test unittest.TestCase, but won't have to do import check | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. do you mean the unittest2 check above? if so make that explicit. |
||
pass | ||
|
||
class SparkDLTestCase(unittest.TestCase): | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Derive tf.Tensor name from a op/tensor name or object.