Skip to content

TensorFlow Transformer Part-2 #9

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 18 commits into from
Nov 22, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions python/docs/sparkdl.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@ Subpackages

.. toctree::

sparkdl.estimators
sparkdl.graph
sparkdl.image
sparkdl.param
sparkdl.transformers
sparkdl.udf
sparkdl.utils
Expand Down
8 changes: 5 additions & 3 deletions python/sparkdl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,17 @@
# limitations under the License.
#

from .graph.input import TFInputGraph
from .image.imageIO import imageSchema, imageType, readImages
from .transformers.keras_image import KerasImageFileTransformer
from .transformers.named_image import DeepImagePredictor, DeepImageFeaturizer
from .transformers.tf_image import TFImageTransformer
from .transformers.tf_tensor import TFTransformer
from .transformers.utils import imageInputPlaceholder


__all__ = [
'imageSchema', 'imageType', 'readImages',
'TFImageTransformer',
'DeepImagePredictor', 'DeepImageFeaturizer',
'KerasImageFileTransformer',
'TFImageTransformer', 'TFInputGraph', 'TFTransformer',
'DeepImagePredictor', 'DeepImageFeaturizer', 'KerasImageFileTransformer',
'imageInputPlaceholder']
18 changes: 9 additions & 9 deletions python/sparkdl/graph/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,19 +47,20 @@ def __init__(self, graph=None, using_keras=False):
self.graph = graph or tf.Graph()
self.sess = tf.Session(graph=self.graph)
if using_keras:
self.using_keras = True
self.keras_prev_sess = K.get_session()
else:
self.using_keras = False
self.keras_prev_sess = None

def __enter__(self):
self.sess.as_default()
self.sess.__enter__()
if self.keras_prev_sess is not None:
if self.using_keras:
K.set_session(self.sess)
return self

def __exit__(self, *args):
if self.keras_prev_sess is not None:
if self.using_keras:
K.set_session(self.keras_prev_sess)
self.sess.__exit__(*args)

Expand Down Expand Up @@ -87,8 +88,8 @@ def asGraphFunction(self, inputs, outputs, strip_and_freeze=True):
else:
gdef = self.graph.as_graph_def(add_shapes=True)
return GraphFunction(graph_def=gdef,
input_names=[tfx.validated_input(self.graph, elem) for elem in inputs],
output_names=[tfx.validated_output(self.graph, elem) for elem in outputs])
input_names=[tfx.validated_input(elem, self.graph) for elem in inputs],
output_names=[tfx.validated_output(elem, self.graph) for elem in outputs])

def importGraphFunction(self, gfn, input_map=None, prefix="GFN-IMPORT", **gdef_kargs):
"""
Expand Down Expand Up @@ -130,8 +131,8 @@ def importGraphFunction(self, gfn, input_map=None, prefix="GFN-IMPORT", **gdef_k
return_elements=gfn.output_names,
name=scope_name,
**gdef_kargs)
feeds = [tfx.get_tensor(self.graph, name) for name in input_names]
fetches = [tfx.get_tensor(self.graph, name) for name in output_names]
feeds = [tfx.get_tensor(name, self.graph) for name in input_names]
fetches = [tfx.get_tensor(name, self.graph) for name in output_names]
return (feeds, fetches)


Expand Down Expand Up @@ -233,7 +234,7 @@ def fromList(cls, functions):
_, first_gfn = functions[0]
feeds, _ = issn.importGraphFunction(first_gfn, prefix='')
for tnsr in feeds:
name = tfx.op_name(issn.graph, tnsr)
name = tfx.op_name(tnsr, issn.graph)
first_input_info.append((tnsr.dtype, tnsr.shape, name))
# TODO: make sure that this graph is not reused to prevent name conflict
# Report error if the graph is not manipulated by anyone else
Expand Down Expand Up @@ -268,4 +269,3 @@ def fromList(cls, functions):
gfn = issn.asGraphFunction(first_inputs, last_outputs)

return gfn

Loading