Skip to content

Commit

Permalink
TensorFlow Transformer Part-2 (#9)
Browse files Browse the repository at this point in the history
* update utils

* tests

* fix style

Using the following YAPF style
========================================================
based_on_style = pep8
ALIGN_CLOSING_BRACKET_WITH_VISUAL_INDENT=True
BLANK_LINE_BEFORE_NESTED_CLASS_OR_DEF=False
COLUMN_LIMIT=100
SPACE_BETWEEN_ENDING_COMMA_AND_CLOSING_BRACKET=False
SPLIT_ARGUMENTS_WHEN_COMMA_TERMINATED=True
SPLIT_BEFORE_FIRST_ARGUMENT=False
SPLIT_BEFORE_NAMED_ASSIGNS=False
SPLIT_PENALTY_AFTER_OPENING_BRACKET=30
USE_TABS=False
========================================================

* refactoring tfx API

* test refactoring

* PR comments

1. docs in graph/utils.py

* (wip) utils test

* a few more tests for utils

* test update cont'd

* PR comments

* PR comments

* PR comments

* TensorFlow Transformer Part-3 (#10)

* intro: TFInputGraph

* tests

* Merge branch 'tf-transformer-part1' into tf-transformer-part3

* and so there is no helper classes

* and into more pieces

* class & docs

* update docs

* refactoring tfx API

* update tfx utils usage

* one way to build these tests

* tests refactored

* test cases in a single class

THis will make things easier when we want to extend other base class functions.

* shuffle things around

Signed-off-by: Philip Yang <philip.yang@databricks.com>

* docs mostly

* yapf'd

* consolidate tempdir creation

* (wip) PR comments

* more tests

* change test generator module name

* TFTransformer Part-3 Test Refactor (#14)

* profiling

* tests

* renamed test

* removed original tests

* removed the profiler utils

* fixes indents

* imports

* added some tests

* added test

* fix test

* one more test

* PR comments

* TensorFlow Transformer Part-4 (#11)

* flat param API impl

* support input graph scenarios

* (WIP) new interface implementation

* docs and cleanup

* using tensorflow API instead of our utilities

* automatic type conversion

* cleanup

* PR comments

1. Move `InputGraph` to its module.

* (WIP) address comments

* (WIP) respond to PR comments

* test refactor

* (wip) consolidating params

* rebase upstream

* import params fix

* (wip) TFInputGraph impl

* (wip) moving to new API

* (wip) enable saved_model tests

* (wip) enable checkpoint test

* (wip) enable multiple tensor tests

* enable all tests

* optimize graph for inference

* allows setting TFInputGraph

* utilize test_input_graph for transformer tests

* enable all tests

Signed-off-by: Philip Yang <philip.yang@databricks.com>

* input graph

* docs

* tensor tests

* tensor test update

* TFTransformer Part-4 Test Refactor (#15)

* adding new tests

* remove original test design

* cleanup

* deleting original testing ideas

* PR comments
  • Loading branch information
phi-dbq authored Nov 22, 2017
1 parent a8531ec commit cf856db
Show file tree
Hide file tree
Showing 19 changed files with 1,297 additions and 119 deletions.
2 changes: 2 additions & 0 deletions python/docs/sparkdl.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@ Subpackages

.. toctree::

sparkdl.estimators
sparkdl.graph
sparkdl.image
sparkdl.param
sparkdl.transformers
sparkdl.udf
sparkdl.utils
Expand Down
8 changes: 5 additions & 3 deletions python/sparkdl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,17 @@
# limitations under the License.
#

from .graph.input import TFInputGraph
from .image.imageIO import imageSchema, imageType, readImages
from .transformers.keras_image import KerasImageFileTransformer
from .transformers.named_image import DeepImagePredictor, DeepImageFeaturizer
from .transformers.tf_image import TFImageTransformer
from .transformers.tf_tensor import TFTransformer
from .transformers.utils import imageInputPlaceholder


__all__ = [
'imageSchema', 'imageType', 'readImages',
'TFImageTransformer',
'DeepImagePredictor', 'DeepImageFeaturizer',
'KerasImageFileTransformer',
'TFImageTransformer', 'TFInputGraph', 'TFTransformer',
'DeepImagePredictor', 'DeepImageFeaturizer', 'KerasImageFileTransformer',
'imageInputPlaceholder']
18 changes: 9 additions & 9 deletions python/sparkdl/graph/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,19 +47,20 @@ def __init__(self, graph=None, using_keras=False):
self.graph = graph or tf.Graph()
self.sess = tf.Session(graph=self.graph)
if using_keras:
self.using_keras = True
self.keras_prev_sess = K.get_session()
else:
self.using_keras = False
self.keras_prev_sess = None

def __enter__(self):
self.sess.as_default()
self.sess.__enter__()
if self.keras_prev_sess is not None:
if self.using_keras:
K.set_session(self.sess)
return self

def __exit__(self, *args):
if self.keras_prev_sess is not None:
if self.using_keras:
K.set_session(self.keras_prev_sess)
self.sess.__exit__(*args)

Expand Down Expand Up @@ -87,8 +88,8 @@ def asGraphFunction(self, inputs, outputs, strip_and_freeze=True):
else:
gdef = self.graph.as_graph_def(add_shapes=True)
return GraphFunction(graph_def=gdef,
input_names=[tfx.validated_input(self.graph, elem) for elem in inputs],
output_names=[tfx.validated_output(self.graph, elem) for elem in outputs])
input_names=[tfx.validated_input(elem, self.graph) for elem in inputs],
output_names=[tfx.validated_output(elem, self.graph) for elem in outputs])

def importGraphFunction(self, gfn, input_map=None, prefix="GFN-IMPORT", **gdef_kargs):
"""
Expand Down Expand Up @@ -130,8 +131,8 @@ def importGraphFunction(self, gfn, input_map=None, prefix="GFN-IMPORT", **gdef_k
return_elements=gfn.output_names,
name=scope_name,
**gdef_kargs)
feeds = [tfx.get_tensor(self.graph, name) for name in input_names]
fetches = [tfx.get_tensor(self.graph, name) for name in output_names]
feeds = [tfx.get_tensor(name, self.graph) for name in input_names]
fetches = [tfx.get_tensor(name, self.graph) for name in output_names]
return (feeds, fetches)


Expand Down Expand Up @@ -233,7 +234,7 @@ def fromList(cls, functions):
_, first_gfn = functions[0]
feeds, _ = issn.importGraphFunction(first_gfn, prefix='')
for tnsr in feeds:
name = tfx.op_name(issn.graph, tnsr)
name = tfx.op_name(tnsr, issn.graph)
first_input_info.append((tnsr.dtype, tnsr.shape, name))
# TODO: make sure that this graph is not reused to prevent name conflict
# Report error if the graph is not manipulated by anyone else
Expand Down Expand Up @@ -268,4 +269,3 @@ def fromList(cls, functions):
gfn = issn.asGraphFunction(first_inputs, last_outputs)

return gfn

Loading

0 comments on commit cf856db

Please sign in to comment.