Skip to content

Commit 2d25c32

Browse files
committed
optimize graph for inference
1 parent 6e46073 commit 2d25c32

File tree

3 files changed

+23
-7
lines changed

3 files changed

+23
-7
lines changed

python/sparkdl/graph/builder.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -52,10 +52,9 @@ def __init__(self, graph=None, using_keras=False):
5252
self.keras_prev_sess = None
5353

5454
def __enter__(self):
55-
self.sess.as_default()
55+
#self.sess.as_default()
5656
self.sess.__enter__()
57-
if self.keras_prev_sess is not None:
58-
K.set_session(self.sess)
57+
K.set_session(self.sess)
5958
return self
6059

6160
def __exit__(self, *args):
@@ -268,4 +267,3 @@ def fromList(cls, functions):
268267
gfn = issn.asGraphFunction(first_inputs, last_outputs)
269268

270269
return gfn
271-

python/sparkdl/transformers/tf_tensor.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
import logging
1818
import tensorflow as tf
19+
from tensorflow.python.tools import optimize_for_inference_lib as infr_opt
1920
import tensorframes as tfs
2021

2122
from pyspark.ml import Transformer
@@ -60,17 +61,32 @@ def setParams(self, tfInputGraph=None, inputMapping=None, outputMapping=None, tf
6061
# Further conanonicalization, e.g. converting dict to sorted str pairs happens here
6162
return self._set(**kwargs)
6263

63-
def _transform(self, dataset):
64+
def _optimize_for_inference(self):
65+
""" Optimize the graph for inference """
6466
gin = self.getTFInputGraph()
6567
input_mapping = self.getInputMapping()
6668
output_mapping = self.getOutputMapping()
69+
input_node_names = [tfx.as_op_name(tnsr_name) for _, tnsr_name in input_mapping]
70+
output_node_names = [tfx.as_op_name(tnsr_name) for tnsr_name, _ in output_mapping]
71+
72+
# NOTE(phi-dbq): Spark DataFrame assumes float64 as default floating point type
73+
opt_gdef = infr_opt.optimize_for_inference(gin.graph_def,
74+
input_node_names,
75+
output_node_names,
76+
tf.float64.as_datatype_enum)
77+
return opt_gdef
78+
79+
def _transform(self, dataset):
80+
graph_def = self._optimize_for_inference()
81+
input_mapping = self.getInputMapping()
82+
output_mapping = self.getOutputMapping()
6783

6884
graph = tf.Graph()
6985
with tf.Session(graph=graph):
7086
analyzed_df = tfs.analyze(dataset)
7187

7288
out_tnsr_op_names = [tfx.as_op_name(tnsr_name) for tnsr_name, _ in output_mapping]
73-
tf.import_graph_def(graph_def=gin.graph_def, name='', return_elements=out_tnsr_op_names)
89+
tf.import_graph_def(graph_def=graph_def, name='', return_elements=out_tnsr_op_names)
7490

7591
feed_dict = dict((tfx.op_name(graph, tnsr_name), col_name)
7692
for col_name, tnsr_name in input_mapping)

python/tests/transformers/tf_tensor_test.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,8 +157,10 @@ def _run_test_in_tf_session(self):
157157
_results.append(np.ravel(curr_res))
158158
out_tgt = np.hstack(_results)
159159

160+
err_msg = 'not close => {} != {}, max_diff {}'
160161
self.assertTrue(np.allclose(out_ref, out_tgt),
161-
msg='not close => {} != {}'.format(out_ref.shape, out_tgt.shape))
162+
msg=err_msg.format(out_ref.shape, out_tgt.shape,
163+
np.max(np.abs(out_ref - out_tgt))))
162164

163165

164166
def test_build_from_tf_graph(self):

0 commit comments

Comments
 (0)