Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions python/docs/sparkdl.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@ Subpackages

.. toctree::

sparkdl.estimators
sparkdl.graph
sparkdl.image
sparkdl.param
sparkdl.transformers
sparkdl.udf
sparkdl.utils
Expand Down
1 change: 1 addition & 0 deletions python/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ h5py>=2.7.0
keras==2.0.4 # NOTE: this package has only been tested with keras 2.0.4 and may not work with other releases
nose>=1.3.7 # for testing
numpy>=1.11.2
parameterized>=0.6.1 # for testing
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yet another model I did not know about, great finding and usage!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Interestingly, Google uses something like this in their testing framework https://github.com/abseil/abseil-py/blob/master/absl/testing/parameterized.py

pillow>=4.1.1,<4.2
pygments>=2.2.0
tensorflow==1.3.0
Expand Down
8 changes: 5 additions & 3 deletions python/sparkdl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,17 @@
# limitations under the License.
#

from .graph.input import TFInputGraph
from .image.imageIO import imageSchema, imageType, readImages
from .transformers.keras_image import KerasImageFileTransformer
from .transformers.named_image import DeepImagePredictor, DeepImageFeaturizer
from .transformers.tf_image import TFImageTransformer
from .transformers.tf_tensor import TFTransformer
from .transformers.utils import imageInputPlaceholder


__all__ = [
'imageSchema', 'imageType', 'readImages',
'TFImageTransformer',
'DeepImagePredictor', 'DeepImageFeaturizer',
'KerasImageFileTransformer',
'TFImageTransformer', 'TFInputGraph', 'TFTransformer',
'DeepImagePredictor', 'DeepImageFeaturizer', 'KerasImageFileTransformer',
'imageInputPlaceholder']
18 changes: 9 additions & 9 deletions python/sparkdl/graph/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,19 +47,20 @@ def __init__(self, graph=None, using_keras=False):
self.graph = graph or tf.Graph()
self.sess = tf.Session(graph=self.graph)
if using_keras:
self.using_keras = True
self.keras_prev_sess = K.get_session()
else:
self.using_keras = False
self.keras_prev_sess = None

def __enter__(self):
self.sess.as_default()
self.sess.__enter__()
if self.keras_prev_sess is not None:
if self.using_keras:
K.set_session(self.sess)
return self

def __exit__(self, *args):
if self.keras_prev_sess is not None:
if self.using_keras:
K.set_session(self.keras_prev_sess)
self.sess.__exit__(*args)

Expand Down Expand Up @@ -87,8 +88,8 @@ def asGraphFunction(self, inputs, outputs, strip_and_freeze=True):
else:
gdef = self.graph.as_graph_def(add_shapes=True)
return GraphFunction(graph_def=gdef,
input_names=[tfx.validated_input(self.graph, elem) for elem in inputs],
output_names=[tfx.validated_output(self.graph, elem) for elem in outputs])
input_names=[tfx.validated_input(elem, self.graph) for elem in inputs],
output_names=[tfx.validated_output(elem, self.graph) for elem in outputs])

def importGraphFunction(self, gfn, input_map=None, prefix="GFN-IMPORT", **gdef_kargs):
"""
Expand Down Expand Up @@ -130,8 +131,8 @@ def importGraphFunction(self, gfn, input_map=None, prefix="GFN-IMPORT", **gdef_k
return_elements=gfn.output_names,
name=scope_name,
**gdef_kargs)
feeds = [tfx.get_tensor(self.graph, name) for name in input_names]
fetches = [tfx.get_tensor(self.graph, name) for name in output_names]
feeds = [tfx.get_tensor(name, self.graph) for name in input_names]
fetches = [tfx.get_tensor(name, self.graph) for name in output_names]
return (feeds, fetches)


Expand Down Expand Up @@ -233,7 +234,7 @@ def fromList(cls, functions):
_, first_gfn = functions[0]
feeds, _ = issn.importGraphFunction(first_gfn, prefix='')
for tnsr in feeds:
name = tfx.op_name(issn.graph, tnsr)
name = tfx.op_name(tnsr, issn.graph)
first_input_info.append((tnsr.dtype, tnsr.shape, name))
# TODO: make sure that this graph is not reused to prevent name conflict
# Report error if the graph is not manipulated by anyone else
Expand Down Expand Up @@ -268,4 +269,3 @@ def fromList(cls, functions):
gfn = issn.asGraphFunction(first_inputs, last_outputs)

return gfn

Loading