|
16 | 16 |
|
17 | 17 | import logging
|
18 | 18 | import tensorflow as tf
|
| 19 | +from tensorflow.python.tools import optimize_for_inference_lib as infr_opt |
19 | 20 | import tensorframes as tfs
|
20 | 21 |
|
21 | 22 | from pyspark.ml import Transformer
|
@@ -60,17 +61,32 @@ def setParams(self, tfInputGraph=None, inputMapping=None, outputMapping=None, tf
|
60 | 61 | # Further conanonicalization, e.g. converting dict to sorted str pairs happens here
|
61 | 62 | return self._set(**kwargs)
|
62 | 63 |
|
63 |
| - def _transform(self, dataset): |
| 64 | + def _optimize_for_inference(self): |
| 65 | + """ Optimize the graph for inference """ |
64 | 66 | gin = self.getTFInputGraph()
|
65 | 67 | input_mapping = self.getInputMapping()
|
66 | 68 | output_mapping = self.getOutputMapping()
|
| 69 | + input_node_names = [tfx.as_op_name(tnsr_name) for _, tnsr_name in input_mapping] |
| 70 | + output_node_names = [tfx.as_op_name(tnsr_name) for tnsr_name, _ in output_mapping] |
| 71 | + |
| 72 | + # NOTE(phi-dbq): Spark DataFrame assumes float64 as default floating point type |
| 73 | + opt_gdef = infr_opt.optimize_for_inference(gin.graph_def, |
| 74 | + input_node_names, |
| 75 | + output_node_names, |
| 76 | + tf.float64.as_datatype_enum) |
| 77 | + return opt_gdef |
| 78 | + |
| 79 | + def _transform(self, dataset): |
| 80 | + graph_def = self._optimize_for_inference() |
| 81 | + input_mapping = self.getInputMapping() |
| 82 | + output_mapping = self.getOutputMapping() |
67 | 83 |
|
68 | 84 | graph = tf.Graph()
|
69 | 85 | with tf.Session(graph=graph):
|
70 | 86 | analyzed_df = tfs.analyze(dataset)
|
71 | 87 |
|
72 | 88 | out_tnsr_op_names = [tfx.as_op_name(tnsr_name) for tnsr_name, _ in output_mapping]
|
73 |
| - tf.import_graph_def(graph_def=gin.graph_def, name='', return_elements=out_tnsr_op_names) |
| 89 | + tf.import_graph_def(graph_def=graph_def, name='', return_elements=out_tnsr_op_names) |
74 | 90 |
|
75 | 91 | feed_dict = dict((tfx.op_name(graph, tnsr_name), col_name)
|
76 | 92 | for col_name, tnsr_name in input_mapping)
|
|
0 commit comments