Skip to content

TensorFlow Transformer Part-2 #9

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 18 commits into from
Nov 22, 2017
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 35 additions & 19 deletions python/sparkdl/graph/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@

import logging
import six
import webbrowser
from tempfile import NamedTemporaryFile

import tensorflow as tf

Expand Down Expand Up @@ -95,31 +93,49 @@ def get_tensor(graph, tfobj_or_name):
'cannot locate tensor {} in current graph'.format(_tensor_name)
return tnsr

def as_tensor_name(name):
def as_tensor_name(tfobj_or_name):
"""
Derive tf.Tensor name from an op/tensor name.
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Derive tf.Tensor name from a op/tensor name or object.

We do not check if the tensor exist (as no graph parameter is passed in).
If the input is a name, we do not check if the tensor exist
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"If the input is a TensorFlow object, or the graph is given, we also check that the tensor exists in the associated graph."

(as no graph parameter is passed in).

:param name: op name or tensor name
:param tfobj_or_name: either a tf.Tensor, tf.Operation or a name to either
"""
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

:param graph: should go here?

assert isinstance(name, six.string_types)
name_parts = name.split(":")
assert len(name_parts) <= 2, name_parts
if len(name_parts) < 2:
name += ":0"
return name
if isinstance(tfobj_or_name, six.string_types):
# If input is a string, assume it is a name and infer the corresponding tensor name.
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no need for comment

# WARNING: this depends on TensorFlow's tensor naming convention
name = tfobj_or_name
name_parts = name.split(":")
assert len(name_parts) <= 2, name_parts
if len(name_parts) < 2:
name += ":0"
return name
elif hasattr(tfobj_or_name, 'graph'):
tfobj = tfobj_or_name
return get_tensor(tfobj.graph, tfobj).name
else:
raise TypeError('invalid tf.Tensor name query type {}'.format(type(tfobj_or_name)))

def as_op_name(name):
def as_op_name(tfobj_or_name):
"""
Derive tf.Operation name from an op/tensor name
We do not check if the operation exist (as no graph parameter is passed in).
Derive tf.Operation name from an op/tensor name.
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"name or object."

If the input is a name, we do not check if the operation exist
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"If the input is a TensorFlow object, or the graph is given, we also check that the tensor exists in the associated graph."

(as no graph parameter is passed in).

:param name: op name or tensor name
:param tfobj_or_name: either a tf.Tensor, tf.Operation or a name to either
"""
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

:param graph: should go here?

assert isinstance(name, six.string_types)
name_parts = name.split(":")
assert len(name_parts) <= 2, name_parts
return name_parts[0]
if isinstance(tfobj_or_name, six.string_types):
# If input is a string, assume it is a name and infer the corresponding operation name.
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no need for comment

# WARNING: this depends on TensorFlow's operation naming convention
name = tfobj_or_name
name_parts = name.split(":")
assert len(name_parts) <= 2, name_parts
return name_parts[0]
elif hasattr(tfobj_or_name, 'graph'):
tfobj = tfobj_or_name
return get_op(tfobj.graph, tfobj).name
else:
raise TypeError('invalid tf.Operation name query type {}'.format(type(tfobj_or_name)))

def op_name(graph, tfobj_or_name):
"""
Expand Down
49 changes: 49 additions & 0 deletions python/tests/graph/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# Copyright 2017 Databricks, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import tensorflow as tf

import sparkdl.graph.utils as tfx

from ..tests import PythonUnitTestCase

class TFeXtensionGraphUtilsTest(PythonUnitTestCase):

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

each test function should test a particular aspect of a particular function in the tfx module, so it's easy to understand what exactly is being tested, and thus if a test fails it's immediately obvious what exactly failed. could you restructure so that is true? e.g. a few I see embedded in the existing tests below are:

def test_op_name():
  # test that op_name from an op name equals input

def test_op_name_from_tensor():
  # test that op_name from a tensor is correct

...

If there is shared set-up code, you can make helper functions for them.

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pending on the result on test refactoring in PR-1

def test_infer_graph_element_names(self):
for tnsr_idx in range(17):
op_name = 'someOp'
tnsr_name = '{}:{}'.format(op_name, tnsr_idx)
self.assertEqual(op_name, tfx.as_op_name(tnsr_name))
self.assertEqual(tnsr_name, tfx.as_tensor_name(tnsr_name))

with self.assertRaises(TypeError):
for wrong_value in [7, 1.2, tf.Graph()]:
tfx.as_op_name(wrong_value)
tfx.as_tensor_name(wrong_value)

def test_get_graph_elements(self):
op_name = 'someConstOp'
tnsr_name = '{}:0'.format(op_name)
tnsr = tf.constant(1427.08, name=op_name)
graph = tnsr.graph

self.assertEqual(op_name, tfx.as_op_name(tnsr))
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what's the difference between as_op_name and op_name ?

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

op_name takes a graph and a tfobj_or_name argument, and check if the latter is part of the former.
as_op_name returns a valid tensor name if the input is a string. Otherwise, it uses the graph of the tfobj and call op_name.
Thus op_name is a bit more strict than as_op_name.

self.assertEqual(op_name, tfx.op_name(graph, tnsr))
self.assertEqual(tnsr_name, tfx.as_tensor_name(tnsr))
self.assertEqual(tnsr_name, tfx.tensor_name(graph, tnsr))
self.assertEqual(tnsr, tfx.get_tensor(graph, tnsr))
self.assertEqual(tnsr.op, tfx.get_op(graph, tnsr))
self.assertEqual(graph, tfx.get_op(graph, tnsr).graph)
self.assertEqual(graph, tfx.get_tensor(graph, tnsr).graph)
3 changes: 3 additions & 0 deletions python/tests/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@
from pyspark.sql import SQLContext
from pyspark.sql import SparkSession

class PythonUnitTestCase(unittest.TestCase):
# Just the plain test unittest.TestCase, but won't have to do import check
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do you mean the unittest2 check above? if so make that explicit.

pass

class SparkDLTestCase(unittest.TestCase):

Expand Down