Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TFTransformer Part-3 Test Refactor #14

Merged
merged 11 commits into from
Nov 18, 2017
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
removed the profiler utils
  • Loading branch information
thunterdb committed Nov 9, 2017
commit 0eb56e731552388ae675758f29b4e2216160f47e
13 changes: 1 addition & 12 deletions python/tests/graph/test_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,16 +35,11 @@

from ..tests import SparkDLTestCase
from ..transformers.image_utils import _getSampleJPEGDir, getSampleImagePathsDF
from ..utils import do_cprofile


class GraphFunctionWithIsolatedSessionTest(SparkDLTestCase):

def test_tf_consistency(self):
self._test_tf_consistency()

@do_cprofile
def _test_tf_consistency(self):
""" Should get the same graph as running pure tf """

x_val = 2702.142857
Expand Down Expand Up @@ -75,7 +70,6 @@ def _test_tf_consistency(self):
# should be the same as that in the one exported directly from TensorFlow session
self.assertEqual(str(gfn.graph_def), str(gdef_ref))

@do_cprofile
def test_get_graph_elements(self):
""" Fetching graph elements by names and other graph elements """

Expand All @@ -94,7 +88,6 @@ def test_get_graph_elements(self):
self.assertEqual(tfx.tensor_name(z, g), "z:0")
self.assertEqual(tfx.tensor_name(x, g), "x:0")

@do_cprofile
def test_import_export_graph_function(self):
""" Function import and export must be consistent """

Expand All @@ -111,12 +104,8 @@ def test_import_export_graph_function(self):
self.assertEqual(gfn_tgt.output_names, gfn_ref.output_names)
self.assertEqual(str(gfn_tgt.graph_def), str(gfn_ref.graph_def))

def test_keras_consistency(self):
self._test_keras_consistency()
assert False

@do_cprofile
def _test_keras_consistency(self):
def test_keras_consistency(self):
""" Exported model in Keras should get same result as original """

img_fpaths = glob(os.path.join(_getSampleJPEGDir(), '*.jpg'))
Expand Down
19 changes: 0 additions & 19 deletions python/tests/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,22 +12,3 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#

import cProfile

"""
This function was copied from here:
https://zapier.com/engineering/profiling-python-boss/
"""
def do_cprofile(func):
def profiled_func(*args, **kwargs):
profile = cProfile.Profile()
try:
profile.enable()
result = func(*args, **kwargs)
profile.disable()
return result
finally:
profile.print_stats()
profile.dump_stats("/tmp/stats.txt")
return profiled_func