TensorFlow Transformer Part-4 #11
@@ -0,0 +1,105 @@
# Copyright 2017 Databricks, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from __future__ import absolute_import, division, print_function

import logging
import tensorflow as tf
from tensorflow.python.tools import optimize_for_inference_lib as infr_opt
import tensorframes as tfs

from pyspark.ml import Transformer

import sparkdl.graph.utils as tfx
from sparkdl.param import (keyword_only, HasInputMapping, HasOutputMapping,
                           HasTFInputGraph, HasTFHParams)

__all__ = ['TFTransformer']

logger = logging.getLogger('sparkdl')


class TFTransformer(Transformer, HasTFInputGraph, HasTFHParams, HasInputMapping, HasOutputMapping):
    """
    Applies the TensorFlow graph to an array column of a DataFrame.

    Restrictions of the current API:

    We assume that:

Review comment: Added these 3 lines below.

    - All the inputs of the graph have a "minibatch" dimension (i.e. an unknown leading
      dimension) in the tensor shapes.
    - The input DataFrame has an array column where all elements have the same length.
    - The transformer is expected to work on one block of data at a time.
    """

    @keyword_only
    def __init__(self, tfInputGraph=None, inputMapping=None, outputMapping=None, tfHParms=None):
        """
        __init__(self, tfInputGraph=None, inputMapping=None, outputMapping=None, tfHParms=None)
        """
        super(TFTransformer, self).__init__()
        kwargs = self._input_kwargs
        self.setParams(**kwargs)

    @keyword_only
    def setParams(self, tfInputGraph=None, inputMapping=None, outputMapping=None, tfHParms=None):
        """
        setParams(self, tfInputGraph=None, inputMapping=None, outputMapping=None, tfHParms=None)
        """
        super(TFTransformer, self).__init__()
        kwargs = self._input_kwargs
        # Further canonicalization, e.g. converting dict to sorted str pairs, happens here
        return self._set(**kwargs)

    def _optimize_for_inference(self):
        """ Optimize the graph for inference """
        gin = self.getTFInputGraph()
        input_mapping = self.getInputMapping()
        output_mapping = self.getOutputMapping()
        input_node_names = [tfx.op_name(tnsr_name) for _, tnsr_name in input_mapping]
        output_node_names = [tfx.op_name(tnsr_name) for tnsr_name, _ in output_mapping]

        # NOTE(phi-dbq): Spark DataFrame assumes float64 as the default floating point type
        opt_gdef = infr_opt.optimize_for_inference(gin.graph_def,
                                                   input_node_names,
                                                   output_node_names,
                                                   # TODO: below is the place to change for
                                                   # the `float64` data type issue.
                                                   tf.float64.as_datatype_enum)
        return opt_gdef

Review comment: Great find! I could not find any guarantees about the stability of this function (it is part of a standalone tool). Do you know if it could become deprecated in the future?
Reply: The API looks pretty solid. I didn't see any explicit sign of this being deprecated in the near future.

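To make the exchange above concrete, here is a minimal, self-contained sketch of what `optimize_for_inference` does, using the same TF 1.x tooling import as this file. The toy graph and its op names (`x`, `w`, `y`) are illustrative only and are not part of this PR.

```python
import tensorflow as tf
from tensorflow.python.tools import optimize_for_inference_lib as infr_opt

# A toy, already-frozen graph (constants only) with an unknown leading
# "minibatch" dimension on its input.
g = tf.Graph()
with g.as_default():
    x = tf.placeholder(tf.float64, shape=[None, 3], name='x')
    w = tf.constant([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=tf.float64, name='w')
    y = tf.identity(tf.matmul(x, w), name='y')

# Drops nodes that are not needed to compute the outputs and rebuilds the input
# placeholders with the requested dtype enum -- which is why the transformer
# passes tf.float64.as_datatype_enum to match Spark's default floating point type.
opt_gdef = infr_opt.optimize_for_inference(g.as_graph_def(),
                                           ['x'],   # input op names
                                           ['y'],   # output op names
                                           tf.float64.as_datatype_enum)
```
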
    def _transform(self, dataset):
        graph_def = self._optimize_for_inference()
        input_mapping = self.getInputMapping()
        output_mapping = self.getOutputMapping()

        graph = tf.Graph()
        with tf.Session(graph=graph):
            analyzed_df = tfs.analyze(dataset)

            out_tnsr_op_names = [tfx.op_name(tnsr_name) for tnsr_name, _ in output_mapping]
            tf.import_graph_def(graph_def=graph_def, name='', return_elements=out_tnsr_op_names)

            feed_dict = dict((tfx.op_name(tnsr_name, graph), col_name)
                             for col_name, tnsr_name in input_mapping)
            fetches = [tfx.get_tensor(tnsr_op_name, graph) for tnsr_op_name in out_tnsr_op_names]

            out_df = tfs.map_blocks(fetches, analyzed_df, feed_dict=feed_dict)

            # We still have to rename output columns
            for tnsr_name, new_colname in output_mapping:
                old_colname = tfx.op_name(tnsr_name, graph)
                if old_colname != new_colname:
                    out_df = out_df.withColumnRenamed(old_colname, new_colname)

        return out_df

Review comment (on the column-renaming step above): Indeed, we still need to do that, annoyingly.

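For readers skimming the diff, here is a hypothetical end-to-end usage sketch based only on what is visible above. The import path, the wrapper object passed as `tfInputGraph`, and the tensor names are assumptions rather than part of this PR; the mapping orientation (DataFrame column -> input tensor, output tensor -> new column) is inferred from `_transform`, and passing plain dicts relies on the canonicalization note in `setParams`.

```python
from pyspark.sql import SparkSession
from sparkdl import TFTransformer   # import path assumed; not shown in this diff

spark = SparkSession.builder.getOrCreate()

# An array column where every element has the same length (see the class docstring).
df = spark.createDataFrame(
    [([0.1, 0.2, 0.3],),
     ([0.4, 0.5, 0.6],)],
    ['features'])

gin = ...  # placeholder: building the TF input graph wrapper (exposing .graph_def) is not shown in this PR

transformer = TFTransformer(
    tfInputGraph=gin,
    inputMapping={'features': 'x:0'},      # column name -> input tensor name (hypothetical)
    outputMapping={'y:0': 'predictions'})  # output tensor name -> new column name (hypothetical)

out_df = transformer.transform(df)         # standard pyspark.ml Transformer entry point
```

The renaming loop at the end of `_transform` suggests that `tfs.map_blocks` names output columns after the fetched tensors' ops, hence the final `withColumnRenamed` pass to the requested column names.
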
Review comment: These two functions are either too much or not enough: either you should provide some tests and some doc examples, or not include them. Since they are not used elsewhere, let's put them in a separate PR for now.
Reply: Well, they are actually in our API design. Let me add some tests for these guys.
Review comment: Ok, I see it below.
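
As background on the `@keyword_only` / `_input_kwargs` pattern used by `__init__` and `setParams` in the diff above, here is a minimal sketch of how such a decorator typically works. It is modeled on `pyspark.keyword_only`; whether `sparkdl.param.keyword_only` is implemented exactly this way is an assumption.

```python
import functools

def keyword_only(func):
    """Force a method to be called with keyword arguments only, and stash those
    arguments on the instance so the wrapped method can read self._input_kwargs."""
    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        if args:
            raise TypeError("Method %s only takes keyword arguments." % func.__name__)
        self._input_kwargs = kwargs
        return func(self, **kwargs)
    return wrapper
```

With this in place, `TFTransformer(...)` captures the constructor keywords, and `setParams` simply forwards them to `Params._set`.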