
TensorFlow Transformer Part-4 #11


Merged (42 commits, Nov 22, 2017)

Commits (changes from all commits):
42c6e6e  flat param API impl (phi-dbq, Aug 8, 2017)
ecbefb9  support input graph scenarios (phi-dbq, Aug 25, 2017)
ab89bd2  (WIP) new interface implementation (phi-dbq, Sep 9, 2017)
8c7d72e  docs and cleanup (phi-dbq, Sep 9, 2017)
eb543c6  using tensorflow API instead of our utilities (phi-dbq, Sep 10, 2017)
4743bb9  automatic type conversion (phi-dbq, Sep 10, 2017)
622c788  cleanup (phi-dbq, Sep 10, 2017)
07f1cec  PR comments (phi-dbq, Sep 11, 2017)
692b0eb  (WIP) address comments (phi-dbq, Sep 12, 2017)
66d44e9  (WIP) respond to PR comments (phi-dbq, Sep 13, 2017)
9b3fe86  test refactor (phi-dbq, Sep 13, 2017)
8c32501  Merge remote-tracking branch 'upstream/master' into tf-1d-transformer (phi-dbq, Sep 16, 2017)
dbd9aaa  (wip) consolidating params (phi-dbq, Sep 16, 2017)
4572205  rebase upstream (phi-dbq, Sep 16, 2017)
1cc7591  import params fix (phi-dbq, Sep 16, 2017)
2fc6787  (wip) TFInputGraph impl (phi-dbq, Sep 16, 2017)
889df0a  (wip) moving to new API (phi-dbq, Sep 17, 2017)
86cd6d9  (wip) enable saved_model tests (phi-dbq, Sep 17, 2017)
ac09182  (wip) enable checkpoint test (phi-dbq, Sep 17, 2017)
6b22eed  (wip) enable multiple tensor tests (phi-dbq, Sep 17, 2017)
a3517d6  enable all tests (phi-dbq, Sep 17, 2017)
6e46073  Merge branch 'tf-transformer-part1' into api-tf-transformer (phi-dbq, Sep 18, 2017)
b232b3c  optimize graph for inference (phi-dbq, Sep 18, 2017)
97b25c6  Baseline (phi-dbq, Sep 21, 2017)
07c58e6  allows setting TFInputGraph (phi-dbq, Sep 21, 2017)
269ad15  utilize test_input_graph for transformer tests (phi-dbq, Sep 21, 2017)
84a8138  enable all tests (phi-dbq, Sep 21, 2017)
6e880ce  Merge branch 'tf-transformer-part3' into tf-transformer-part4 (phi-dbq, Sep 23, 2017)
883321e  input graph (phi-dbq, Sep 23, 2017)
c72444b  docs (phi-dbq, Sep 23, 2017)
89e2a1d  Merge branch 'tf-transformer-part3' into tf-transformer-part4 (phi-dbq, Sep 23, 2017)
6aa85b9  Merge branch 'tf-transformer-part3' into tf-transformer-part4 (phi-dbq, Sep 29, 2017)
85e0778  Merge branch 'tf-transformer-part3' into tf-transformer-part4 (phi-dbq, Oct 3, 2017)
22754c9  tensor tests (phi-dbq, Oct 3, 2017)
0144b8c  tensor test update (phi-dbq, Oct 3, 2017)
c6eb87c  Merge branch 'tf-transformer-part3' into tf-transformer-part4 (phi-dbq, Oct 3, 2017)
812f4d6  Merge branch 'tf-transformer-part3' into tf-transformer-part4 (phi-dbq, Oct 5, 2017)
47d497c  TFTransformer Part-4 Test Refactor (#15) (phi-dbq, Nov 18, 2017)
07cc335  deleting original testing ideas (phi-dbq, Nov 18, 2017)
925fc0d  Merge branch 'tf-transformer-part3' into tf-transformer-part4 (phi-dbq, Nov 18, 2017)
91b9379  Merge branch 'tf-transformer-part3' into tf-transformer-part4 (phi-dbq, Nov 22, 2017)
af95b74  PR comments (phi-dbq, Nov 22, 2017)
8 changes: 5 additions & 3 deletions python/sparkdl/__init__.py
@@ -13,15 +13,17 @@
# limitations under the License.
#

from .graph.input import TFInputGraph
from .image.imageIO import imageSchema, imageType, readImages
from .transformers.keras_image import KerasImageFileTransformer
from .transformers.named_image import DeepImagePredictor, DeepImageFeaturizer
from .transformers.tf_image import TFImageTransformer
from .transformers.tf_tensor import TFTransformer
from .transformers.utils import imageInputPlaceholder


__all__ = [
    'imageSchema', 'imageType', 'readImages',
    'TFImageTransformer',
    'DeepImagePredictor', 'DeepImageFeaturizer',
    'KerasImageFileTransformer',
    'TFImageTransformer', 'TFInputGraph', 'TFTransformer',
    'DeepImagePredictor', 'DeepImageFeaturizer', 'KerasImageFileTransformer',
    'imageInputPlaceholder']
7 changes: 4 additions & 3 deletions python/sparkdl/graph/builder.py
@@ -47,19 +47,20 @@ def __init__(self, graph=None, using_keras=False):
        self.graph = graph or tf.Graph()
        self.sess = tf.Session(graph=self.graph)
        if using_keras:
            self.using_keras = True
            self.keras_prev_sess = K.get_session()
        else:
            self.using_keras = False
            self.keras_prev_sess = None

    def __enter__(self):
        self.sess.as_default()
        self.sess.__enter__()
        if self.keras_prev_sess is not None:
        if self.using_keras:
            K.set_session(self.sess)
        return self

    def __exit__(self, *args):
        if self.keras_prev_sess is not None:
        if self.using_keras:
            K.set_session(self.keras_prev_sess)
        self.sess.__exit__(*args)
44 changes: 44 additions & 0 deletions python/sparkdl/graph/input.py
@@ -91,6 +91,50 @@ def __init__(self, graph_def, input_tensor_name_from_signature,
        self.input_tensor_name_from_signature = input_tensor_name_from_signature
        self.output_tensor_name_from_signature = output_tensor_name_from_signature

    def translateInputMapping(self, input_mapping):

[Reviewer] These two functions are either too much or not enough: either you should provide some tests and some doc examples, or not include them. Since they are not used elsewhere, let's put them in a separate PR for now.

[Author (phi-dbq)] Well, they are actually in our API design. Let me add some tests for these guys.

[Reviewer] Ok, I see it below.

"""
When the meta_graph contains signature_def, we expect users to provide
input and output mapping with respect to the tensor reference keys
embedded in the `signature_def`.

This function translates the input_mapping into the canonical format,
which maps input DataFrame column names to tensor names.

:param input_mapping: dict, DataFrame column name to tensor reference names
defined in the signature_def key.
"""
assert self.input_tensor_name_from_signature is not None
_input_mapping = {}
if isinstance(input_mapping, dict):
input_mapping = list(input_mapping.items())
assert isinstance(input_mapping, list)
for col_name, sig_key in input_mapping:
tnsr_name = self.input_tensor_name_from_signature[sig_key]
_input_mapping[col_name] = tnsr_name
return _input_mapping

    def translateOutputMapping(self, output_mapping):
        """
        When the meta_graph contains signature_def, we expect users to provide
        input and output mapping with respect to the tensor reference keys
        embedded in the `signature_def`.

        This function translates the output_mapping into the canonical format,
        which maps tensor names to output DataFrame column names.

        :param output_mapping: dict, mapping tensor reference keys defined in the signature_def
                               to output DataFrame column names.
        """
        assert self.output_tensor_name_from_signature is not None
        _output_mapping = {}
        if isinstance(output_mapping, dict):
            output_mapping = list(output_mapping.items())
        assert isinstance(output_mapping, list)
        for sig_key, col_name in output_mapping:
            tnsr_name = self.output_tensor_name_from_signature[sig_key]
            _output_mapping[tnsr_name] = col_name
        return _output_mapping

    @classmethod
    def fromGraph(cls, graph, sess, feed_names, fetch_names):
        """
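
For context, a minimal usage sketch of the two translation helpers. This is illustrative only: the signature keys and tensor names below mirror the tests further down in this PR, and `saved_model_dir`, `tag`, and `signature_def_key` are caller-supplied placeholders.

    # Assumes a SavedModel exported with signature keys 'well_known_input_sig'
    # and 'well_known_output_sig' (as in the tests in this PR).
    gin = TFInputGraph.fromSavedModelWithSignature(saved_model_dir, tag, signature_def_key)

    # Users write mappings against the signature keys ...
    input_mapping = gin.translateInputMapping({'inputCol': 'well_known_input_sig'})
    # ... and get back the canonical form, e.g. {'inputCol': 'input_tensor:0'}

    output_mapping = gin.translateOutputMapping({'well_known_output_sig': 'outputCol'})
    # e.g. {'output_tensor:0': 'outputCol'}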
2 changes: 1 addition & 1 deletion python/sparkdl/param/__init__.py
@@ -16,7 +16,7 @@
from sparkdl.param.shared_params import (
    keyword_only, HasInputCol, HasOutputCol, HasLabelCol,
    # TFTransformer Params
    HasInputMapping, HasOutputMapping, HasTFHParams,
    HasInputMapping, HasOutputMapping, HasTFInputGraph, HasTFHParams,
    # Keras Estimator Params
    HasKerasModel, HasKerasLoss, HasKerasOptimizer, HasOutputNodeName)
from sparkdl.param.converters import SparkDLTypeConverters
8 changes: 8 additions & 0 deletions python/sparkdl/param/converters.py
@@ -30,6 +30,7 @@

from pyspark.ml.param import TypeConverters

from sparkdl.graph.input import *
import sparkdl.utils.keras_model as kmutil

__all__ = ['SparkDLTypeConverters']
@@ -52,6 +53,13 @@ def toTFGraph(value):
            raise TypeError("Could not convert %s to tf.Graph" % type(value))
        return value

    @staticmethod
    def toTFInputGraph(value):
        if isinstance(value, TFInputGraph):
            return value
        else:
            raise TypeError("Could not convert %s to TFInputGraph" % type(value))

    @staticmethod
    def asColumnToTensorNameMap(value):
        """
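
The converter's contract is deliberately narrow; a short sketch of the intended behavior (`gin` is a placeholder for any TFInputGraph instance):

    from sparkdl.graph.input import TFInputGraph
    from sparkdl.param.converters import SparkDLTypeConverters

    # A genuine TFInputGraph passes through unchanged ...
    assert SparkDLTypeConverters.toTFInputGraph(gin) is gin
    # ... anything else is rejected at Param-setting time.
    try:
        SparkDLTypeConverters.toTFInputGraph({'not': 'a graph'})
    except TypeError as exc:
        print(exc)  # e.g. "Could not convert <type 'dict'> to TFInputGraph"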
32 changes: 28 additions & 4 deletions python/sparkdl/param/shared_params.py
@@ -19,9 +19,11 @@
"""
import textwrap
from functools import wraps
import six

from pyspark.ml.param import Param, Params, TypeConverters

from sparkdl.graph.input import TFInputGraph
from sparkdl.param.converters import SparkDLTypeConverters

########################################################
@@ -196,8 +198,9 @@ class HasOutputMapping(Params):
"""
Mixin for param outputMapping: ordered list of ('outputTensorOpName', 'outputColName') pairs
"""
outputMapping = Param(Params._dummy(), "outputMapping",
"Mapping output :class:`tf.Operation` names to DataFrame column names",
outputMapping = Param(Params._dummy(),
"outputMapping",
"Mapping output :class:`tf.Tensor` names to DataFrame column names",
typeConverter=SparkDLTypeConverters.asTensorNameToColumnMap)

def setOutputMapping(self, value):
Expand All @@ -211,8 +214,9 @@ class HasInputMapping(Params):
"""
Mixin for param inputMapping: ordered list of ('inputColName', 'inputTensorOpName') pairs
"""
inputMapping = Param(Params._dummy(), "inputMapping",
"Mapping input DataFrame column names to :class:`tf.Operation` names",
inputMapping = Param(Params._dummy(),
"inputMapping",
"Mapping input DataFrame column names to :class:`tf.Tensor` names",
typeConverter=SparkDLTypeConverters.asColumnToTensorNameMap)

def setInputMapping(self, value):
Expand All @@ -222,6 +226,26 @@ def getInputMapping(self):
        return self.getOrDefault(self.inputMapping)


class HasTFInputGraph(Params):
    """
    Mixin for param tfInputGraph: a serializable object derived from a TensorFlow computation graph.
    """
    tfInputGraph = Param(Params._dummy(),
                         "tfInputGraph",
                         "A serializable object derived from a TensorFlow computation graph",
                         typeConverter=SparkDLTypeConverters.toTFInputGraph)

    def __init__(self):
        super(HasTFInputGraph, self).__init__()
        self._setDefault(tfInputGraph=None)

    def setTFInputGraph(self, value):
        return self._set(tfInputGraph=value)

    def getTFInputGraph(self):
        return self.getOrDefault(self.tfInputGraph)


class HasTFHParams(Params):
    """
    Mixin for TensorFlow model hyper-parameters
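
Like the other mixins in this file, HasTFInputGraph gives any class inheriting it a validated getter/setter pair. A brief sketch; the class name is illustrative, `gin` stands for any TFInputGraph, and we assume pyspark's cooperative Params initialization runs the mixin's __init__:

    from pyspark.ml import Transformer

    class MyTFTransformer(Transformer, HasTFInputGraph):
        pass

    t = MyTFTransformer()
    assert t.getTFInputGraph() is None   # default installed by HasTFInputGraph.__init__
    t.setTFInputGraph(gin)               # routed through SparkDLTypeConverters.toTFInputGraph
    assert t.getTFInputGraph() is gin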
105 changes: 105 additions & 0 deletions python/sparkdl/transformers/tf_tensor.py
@@ -0,0 +1,105 @@
# Copyright 2017 Databricks, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from __future__ import absolute_import, division, print_function

import logging
import tensorflow as tf
from tensorflow.python.tools import optimize_for_inference_lib as infr_opt
import tensorframes as tfs

from pyspark.ml import Transformer

import sparkdl.graph.utils as tfx
from sparkdl.param import (keyword_only, HasInputMapping, HasOutputMapping,
                           HasTFInputGraph, HasTFHParams)

__all__ = ['TFTransformer']

logger = logging.getLogger('sparkdl')

class TFTransformer(Transformer, HasTFInputGraph, HasTFHParams, HasInputMapping, HasOutputMapping):
"""
Applies the TensorFlow graph to the array column in DataFrame.

Restrictions of the current API:

We assume that

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The transformer is expected to work on blocks of data at the same time.

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added this 3 lines below

- All the inputs of the graphs have a "minibatch" dimension (i.e. an unknown leading
dimension) in the tensor shapes.
- Input DataFrame has an array column where all elements have the same length
- The transformer is expected to work on blocks of data at the same time.
"""

    @keyword_only
    def __init__(self, tfInputGraph=None, inputMapping=None, outputMapping=None, tfHParms=None):
        """
        __init__(self, tfInputGraph=None, inputMapping=None, outputMapping=None, tfHParms=None)
        """
        super(TFTransformer, self).__init__()
        kwargs = self._input_kwargs
        self.setParams(**kwargs)

    @keyword_only
    def setParams(self, tfInputGraph=None, inputMapping=None, outputMapping=None, tfHParms=None):
        """
        setParams(self, tfInputGraph=None, inputMapping=None, outputMapping=None, tfHParms=None)
        """
        super(TFTransformer, self).__init__()
        kwargs = self._input_kwargs
        # Further canonicalization, e.g. converting dict to sorted str pairs, happens here
        return self._set(**kwargs)

    def _optimize_for_inference(self):
        """ Optimize the graph for inference """
        gin = self.getTFInputGraph()
        input_mapping = self.getInputMapping()
        output_mapping = self.getOutputMapping()
        input_node_names = [tfx.op_name(tnsr_name) for _, tnsr_name in input_mapping]
        output_node_names = [tfx.op_name(tnsr_name) for tnsr_name, _ in output_mapping]

        # NOTE(phi-dbq): Spark DataFrame assumes float64 as the default floating point type
        opt_gdef = infr_opt.optimize_for_inference(gin.graph_def,
                                                   input_node_names,
                                                   output_node_names,
                                                   # TODO: below is the place to change for
                                                   # the `float64` data type issue.
                                                   tf.float64.as_datatype_enum)
        return opt_gdef

[Reviewer] Great find! I could not find any guarantees about the stability of this function (it is part of a program). Do you know if it could become deprecated in the future?

[Author (phi-dbq)] The API looks pretty solid. I didn't see any explicit sign of this being deprecated in the near future.
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/optimize_for_inference_lib.py#L89-L104

    def _transform(self, dataset):
        graph_def = self._optimize_for_inference()
        input_mapping = self.getInputMapping()
        output_mapping = self.getOutputMapping()

        graph = tf.Graph()
        with tf.Session(graph=graph):
            analyzed_df = tfs.analyze(dataset)

            out_tnsr_op_names = [tfx.op_name(tnsr_name) for tnsr_name, _ in output_mapping]
            tf.import_graph_def(graph_def=graph_def, name='', return_elements=out_tnsr_op_names)

            feed_dict = dict((tfx.op_name(tnsr_name, graph), col_name)
                             for col_name, tnsr_name in input_mapping)
            fetches = [tfx.get_tensor(tnsr_op_name, graph) for tnsr_op_name in out_tnsr_op_names]

            out_df = tfs.map_blocks(fetches, analyzed_df, feed_dict=feed_dict)

            # We still have to rename output columns

[Reviewer] Indeed, we still need to do that, annoyingly.

            for tnsr_name, new_colname in output_mapping:
                old_colname = tfx.op_name(tnsr_name, graph)
                if old_colname != new_colname:
                    out_df = out_df.withColumnRenamed(old_colname, new_colname)

        return out_df
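
Putting the pieces together, a hedged end-to-end sketch of the new transformer. The toy graph, the column names, and the tensor-name format ('x:0', 'y:0') are illustrative assumptions, not taken verbatim from this PR:

    import tensorflow as tf
    from sparkdl import TFInputGraph, TFTransformer

    # A toy graph whose input has an unknown leading "minibatch" dimension,
    # per the restrictions in the class docstring above.
    graph = tf.Graph()
    with tf.Session(graph=graph) as sess:
        x = tf.placeholder(tf.float64, shape=[None, 10], name='x')
        _ = tf.reduce_mean(x, axis=1, name='y')
        gin = TFInputGraph.fromGraph(graph, sess, ['x'], ['y'])

    transformer = TFTransformer(
        tfInputGraph=gin,
        inputMapping={'features': 'x:0'},   # DataFrame column -> tensor name
        outputMapping={'y:0': 'preds'})     # tensor name -> DataFrame column

    # `df` must have an array column 'features' whose elements all have length 10.
    out_df = transformer.transform(df)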
40 changes: 36 additions & 4 deletions python/tests/graph/test_import.py
@@ -53,6 +53,34 @@ def gin_fun(session):
        gin = _build_graph_input(gin_fun)
        _check_input_novar(gin)

    def test_saved_model_iomap(self):
        with _make_temp_directory() as tmp_dir:
            saved_model_dir = os.path.join(tmp_dir, 'saved_model')
            graph = tf.Graph()
            with tf.Session(graph=graph) as sess, graph.as_default():
                _build_graph()
                _build_saved_model(sess, saved_model_dir)
                # Build the transformer from the exported serving model.
                # We are using signatures, thus must provide the keys.
                gin = TFInputGraph.fromSavedModelWithSignature(saved_model_dir, _serving_tag,
                                                               _serving_sigdef_key)

                _input_mapping_with_sigdef = {'inputCol': _tensor_input_signature}
                # Input mapping for the Transformer
                _translated_input_mapping = gin.translateInputMapping(_input_mapping_with_sigdef)
                _expected_input_mapping = {'inputCol': tfx.tensor_name(_tensor_input_name)}
                # Output mapping for the Transformer
                _output_mapping_with_sigdef = {_tensor_output_signature: 'outputCol'}
                _translated_output_mapping = gin.translateOutputMapping(_output_mapping_with_sigdef)
                _expected_output_mapping = {tfx.tensor_name(_tensor_output_name): 'outputCol'}

                err_msg = "signature based input mapping {} and output mapping {} " + \
                          "must be translated correctly into tensor name based mappings"
                assert _translated_input_mapping == _expected_input_mapping \
                    and _translated_output_mapping == _expected_output_mapping, \
                    err_msg.format(_translated_input_mapping, _translated_output_mapping)

    def test_saved_graph_novar(self):
        with _make_temp_directory() as tmp_dir:
            saved_model_dir = os.path.join(tmp_dir, 'saved_model')
@@ -118,6 +146,10 @@ def gin_fun(session):
_tensor_input_name_2 = "input_tensor_2"
# The name of the output tensor (scalar)
_tensor_output_name = "output_tensor"
# Input signature name
_tensor_input_signature = 'well_known_input_sig'
# Output signature name
_tensor_output_signature = 'well_known_output_sig'
# The name of the variable
_tensor_var_name = "variable"
# The size of the input tensor
@@ -135,8 +167,8 @@ def _build_checkpointed_model(session, tmp_dir):
    w = tfx.get_tensor(_tensor_var_name, session.graph)
    saver = tf.train.Saver(var_list=[w])
    _ = saver.save(session, ckpt_path_prefix, global_step=2702)
    sig_inputs = {'input_sig': tf.saved_model.utils.build_tensor_info(input_tensor)}
    sig_outputs = {'output_sig': tf.saved_model.utils.build_tensor_info(output_tensor)}
    sig_inputs = {_tensor_input_signature: tf.saved_model.utils.build_tensor_info(input_tensor)}
    sig_outputs = {_tensor_output_signature: tf.saved_model.utils.build_tensor_info(output_tensor)}
    serving_sigdef = tf.saved_model.signature_def_utils.build_signature_def(
        inputs=sig_inputs, outputs=sig_outputs)

@@ -163,8 +195,8 @@ def _build_saved_model(session, saved_model_dir):
    builder = tf.saved_model.builder.SavedModelBuilder(saved_model_dir)
    input_tensor = tfx.get_tensor(_tensor_input_name, session.graph)
    output_tensor = tfx.get_tensor(_tensor_output_name, session.graph)
    sig_inputs = {'input_sig': tf.saved_model.utils.build_tensor_info(input_tensor)}
    sig_outputs = {'output_sig': tf.saved_model.utils.build_tensor_info(output_tensor)}
    sig_inputs = {_tensor_input_signature: tf.saved_model.utils.build_tensor_info(input_tensor)}
    sig_outputs = {_tensor_output_signature: tf.saved_model.utils.build_tensor_info(output_tensor)}
    serving_sigdef = tf.saved_model.signature_def_utils.build_signature_def(
        inputs=sig_inputs, outputs=sig_outputs)
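
The tail of _build_saved_model is collapsed in this view. For reference, a hedged sketch of how such an export typically finishes with the standard SavedModelBuilder flow; the PR's exact remaining lines may differ:

    # Sketch, continuing from `builder` and `serving_sigdef` above:
    builder.add_meta_graph_and_variables(
        session, [_serving_tag],
        signature_def_map={_serving_sigdef_key: serving_sigdef})
    builder.save()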
17 changes: 14 additions & 3 deletions python/tests/tests.py
@@ -34,21 +34,32 @@ class PythonUnitTestCase(unittest.TestCase):
    # This class is created to avoid replicating this logic in various places.
    pass

class SparkDLTestCase(unittest.TestCase):

class TestSparkContext(object):
    @classmethod
    def setUpClass(cls):
    def setup_env(cls):
        cls.sc = SparkContext('local[*]', cls.__name__)
        cls.sql = SQLContext(cls.sc)
        cls.session = SparkSession.builder.getOrCreate()

    @classmethod
    def tearDownClass(cls):
    def tear_down_env(cls):
        cls.session.stop()
        cls.session = None
        cls.sc.stop()
        cls.sc = None
        cls.sql = None


class SparkDLTestCase(TestSparkContext, unittest.TestCase):

    @classmethod
    def setUpClass(cls):
        cls.setup_env()

    @classmethod
    def tearDownClass(cls):
        cls.tear_down_env()

    def assertDfHasCols(self, df, cols=[]):
        map(lambda c: self.assertIn(c, df.columns), cols)
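
The point of splitting TestSparkContext out of SparkDLTestCase is that the same setup/teardown helpers can back suites with other lifecycles. An illustrative sketch (not part of this PR), sharing one SparkContext across a whole test module:

    class _SharedEnv(TestSparkContext):
        pass

    def setUpModule():
        _SharedEnv.setup_env()

    def tearDownModule():
        _SharedEnv.tear_down_env()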