28
28
import sparkdl .utils .jvmapi as JVMAPI
29
29
import sparkdl .graph .utils as tfx
30
30
31
+ __all__ = ['TFImageTransformer' ]
32
+
33
+ IMAGE_INPUT_TENSOR_NAME = tfx .as_tensor_name (utils .IMAGE_INPUT_PLACEHOLDER_NAME )
34
+ USER_GRAPH_NAMESPACE = 'given'
35
+ NEW_OUTPUT_PREFIX = 'sdl_flattened'
36
+
31
37
class TFImageTransformer (Transformer , HasInputCol , HasOutputCol , HasOutputMode ):
32
38
"""
33
39
Applies the Tensorflow graph to the image column in DataFrame.
@@ -47,9 +53,6 @@ class TFImageTransformer(Transformer, HasInputCol, HasOutputCol, HasOutputMode):
47
53
since a new session is created inside this transformer.
48
54
"""
49
55
50
- USER_GRAPH_NAMESPACE = 'given'
51
- NEW_OUTPUT_PREFIX = 'sdl_flattened'
52
-
53
56
graph = Param (Params ._dummy (), "graph" , "A TensorFlow computation graph" ,
54
57
typeConverter = SparkDLTypeConverters .toTFGraph )
55
58
inputTensor = Param (Params ._dummy (), "inputTensor" ,
@@ -61,28 +64,28 @@ class TFImageTransformer(Transformer, HasInputCol, HasOutputCol, HasOutputMode):
61
64
62
65
@keyword_only
63
66
def __init__ (self , inputCol = None , outputCol = None , graph = None ,
64
- inputTensor = utils . IMAGE_INPUT_PLACEHOLDER_NAME , outputTensor = None ,
67
+ inputTensor = IMAGE_INPUT_TENSOR_NAME , outputTensor = None ,
65
68
outputMode = "vector" ):
66
69
"""
67
70
__init__(self, inputCol=None, outputCol=None, graph=None,
68
- inputTensor=utils.IMAGE_INPUT_PLACEHOLDER_NAME , outputTensor=None,
71
+ inputTensor=IMAGE_INPUT_TENSOR_NAME , outputTensor=None,
69
72
outputMode="vector")
70
73
"""
71
74
super (TFImageTransformer , self ).__init__ ()
72
- self ._setDefault (inputTensor = utils .IMAGE_INPUT_PLACEHOLDER_NAME )
73
- self ._setDefault (outputMode = "vector" )
74
75
kwargs = self ._input_kwargs
75
76
self .setParams (** kwargs )
76
77
77
78
@keyword_only
78
79
def setParams (self , inputCol = None , outputCol = None , graph = None ,
79
- inputTensor = utils . IMAGE_INPUT_PLACEHOLDER_NAME , outputTensor = None ,
80
+ inputTensor = IMAGE_INPUT_TENSOR_NAME , outputTensor = None ,
80
81
outputMode = "vector" ):
81
82
"""
82
83
setParams(self, inputCol=None, outputCol=None, graph=None,
83
- inputTensor=utils.IMAGE_INPUT_PLACEHOLDER_NAME , outputTensor=None,
84
+ inputTensor=IMAGE_INPUT_TENSOR_NAME , outputTensor=None,
84
85
outputMode="vector")
85
86
"""
87
+ self ._setDefault (inputTensor = IMAGE_INPUT_TENSOR_NAME )
88
+ self ._setDefault (outputMode = "vector" )
86
89
kwargs = self ._input_kwargs
87
90
return self ._set (** kwargs )
88
91
@@ -179,7 +182,7 @@ def _addReshapeLayers(self, tf_graph, dtype="uint8"):
179
182
# Add on the original graph
180
183
tf .import_graph_def (gdef , input_map = {input_tensor_name : image_reshaped_expanded },
181
184
return_elements = [self .getOutputTensor ().name ],
182
- name = self . USER_GRAPH_NAMESPACE )
185
+ name = USER_GRAPH_NAMESPACE )
183
186
184
187
# Flatten the output for tensorframes
185
188
output_node = g .get_tensor_by_name (self ._getOriginalOutputTensorName ())
@@ -198,10 +201,10 @@ def _stripGraph(self, tf_graph):
198
201
return g
199
202
200
203
def _getOriginalOutputTensorName (self ):
201
- return self . USER_GRAPH_NAMESPACE + '/' + self .getOutputTensor ().name
204
+ return USER_GRAPH_NAMESPACE + '/' + self .getOutputTensor ().name
202
205
203
206
def _getFinalOutputTensorName (self ):
204
- return self . NEW_OUTPUT_PREFIX + '_' + self .getOutputTensor ().name
207
+ return NEW_OUTPUT_PREFIX + '_' + self .getOutputTensor ().name
205
208
206
209
def _getFinalOutputOpName (self ):
207
210
return tfx .as_op_name (self ._getFinalOutputTensorName ())
0 commit comments