Skip to content

Commit 66507f4

Browse files
committed
tf_image inputTensor default setter bug-fix
1 parent 76e9fb9 commit 66507f4

File tree

3 files changed

+30
-27
lines changed

3 files changed

+30
-27
lines changed

python/sparkdl/param/converters.py

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -92,14 +92,15 @@ def toTFHParams(value):
9292

9393
@staticmethod
9494
def toTFTensorName(value):
95-
""" Convert a value to a str or a :py:obj:`tf.Tensor` object, if possible. """
95+
""" Convert a value to a :py:obj:`tf.Tensor` name, if possible. """
9696
if isinstance(value, tf.Tensor):
9797
return value.name
9898
try:
99-
_check_is_tensor_name(value)
100-
return TypeConverters.toString(value)
99+
_maybe_tnsr_name = TypeConverters.toString(value)
100+
_check_is_tensor_name(_maybe_tnsr_name)
101+
return _maybe_tnsr_name
101102
except Exception as exc:
102-
err_msg = "Could not convert [type {}] {} to tf.Tensor or str. {}"
103+
err_msg = "Could not convert [type {}] {} to tf.Tensor name. {}"
103104
raise TypeError(err_msg.format(type(value), value, exc))
104105

105106
@staticmethod
@@ -142,19 +143,19 @@ def toKerasOptimizer(value):
142143

143144
def _check_is_tensor_name(_maybe_tnsr_name):
144145
""" Check if the input is a valid tensor name """
145-
try:
146-
assert isinstance(_maybe_tnsr_name, six.string_types), \
147-
"must provide a strict tensor name as input, but got {}".format(type(_maybe_tnsr_name))
146+
assert isinstance(_maybe_tnsr_name, six.string_types), \
147+
"expect tensor name to be of string type, but got [type {}]".format(type(_maybe_tnsr_name))
148148

149-
# The check is taken from TensorFlow's NodeDef protocol buffer.
150-
# https://github.com/tensorflow/tensorflow/blob/r1.3/tensorflow/core/framework/node_def.proto#L21-L25
149+
# The check is taken from TensorFlow's NodeDef protocol buffer.
150+
# https://github.com/tensorflow/tensorflow/blob/r1.3/tensorflow/core/framework/node_def.proto#L21-L25
151+
try:
151152
_, src_idx = _maybe_tnsr_name.split(":")
152153
_ = int(src_idx)
153-
except Exception as exc:
154-
err_msg = "Can NOT convert [type {}] {} to tf.Tensor name: {}"
155-
raise TypeError(err_msg.format(type(_maybe_tnsr_name), _maybe_tnsr_name, exc))
156-
else:
157-
return _maybe_tnsr_name
154+
except Exception:
155+
err_msg = "Tensor name must be of type <op_name>:<index>, but got {}"
156+
raise TypeError(err_msg.format(_maybe_tnsr_name))
157+
158+
return _maybe_tnsr_name
158159

159160

160161
def _check_is_str(_maybe_col_name):

python/sparkdl/transformers/keras_applications.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,4 +109,3 @@ def _testKerasModel(self, include_top):
109109
"InceptionV3": InceptionV3Model,
110110
"Xception": XceptionModel
111111
}
112-

python/sparkdl/transformers/tf_image.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,12 @@
2828
import sparkdl.utils.jvmapi as JVMAPI
2929
import sparkdl.graph.utils as tfx
3030

31+
__all__ = ['TFImageTransformer']
32+
33+
IMAGE_INPUT_TENSOR_NAME = tfx.as_tensor_name(utils.IMAGE_INPUT_PLACEHOLDER_NAME)
34+
USER_GRAPH_NAMESPACE = 'given'
35+
NEW_OUTPUT_PREFIX = 'sdl_flattened'
36+
3137
class TFImageTransformer(Transformer, HasInputCol, HasOutputCol, HasOutputMode):
3238
"""
3339
Applies the Tensorflow graph to the image column in DataFrame.
@@ -47,9 +53,6 @@ class TFImageTransformer(Transformer, HasInputCol, HasOutputCol, HasOutputMode):
4753
since a new session is created inside this transformer.
4854
"""
4955

50-
USER_GRAPH_NAMESPACE = 'given'
51-
NEW_OUTPUT_PREFIX = 'sdl_flattened'
52-
5356
graph = Param(Params._dummy(), "graph", "A TensorFlow computation graph",
5457
typeConverter=SparkDLTypeConverters.toTFGraph)
5558
inputTensor = Param(Params._dummy(), "inputTensor",
@@ -61,28 +64,28 @@ class TFImageTransformer(Transformer, HasInputCol, HasOutputCol, HasOutputMode):
6164

6265
@keyword_only
6366
def __init__(self, inputCol=None, outputCol=None, graph=None,
64-
inputTensor=utils.IMAGE_INPUT_PLACEHOLDER_NAME, outputTensor=None,
67+
inputTensor=IMAGE_INPUT_TENSOR_NAME, outputTensor=None,
6568
outputMode="vector"):
6669
"""
6770
__init__(self, inputCol=None, outputCol=None, graph=None,
68-
inputTensor=utils.IMAGE_INPUT_PLACEHOLDER_NAME, outputTensor=None,
71+
inputTensor=IMAGE_INPUT_TENSOR_NAME, outputTensor=None,
6972
outputMode="vector")
7073
"""
7174
super(TFImageTransformer, self).__init__()
72-
self._setDefault(inputTensor=utils.IMAGE_INPUT_PLACEHOLDER_NAME)
73-
self._setDefault(outputMode="vector")
7475
kwargs = self._input_kwargs
7576
self.setParams(**kwargs)
7677

7778
@keyword_only
7879
def setParams(self, inputCol=None, outputCol=None, graph=None,
79-
inputTensor=utils.IMAGE_INPUT_PLACEHOLDER_NAME, outputTensor=None,
80+
inputTensor=IMAGE_INPUT_TENSOR_NAME, outputTensor=None,
8081
outputMode="vector"):
8182
"""
8283
setParams(self, inputCol=None, outputCol=None, graph=None,
83-
inputTensor=utils.IMAGE_INPUT_PLACEHOLDER_NAME, outputTensor=None,
84+
inputTensor=IMAGE_INPUT_TENSOR_NAME, outputTensor=None,
8485
outputMode="vector")
8586
"""
87+
self._setDefault(inputTensor=IMAGE_INPUT_TENSOR_NAME)
88+
self._setDefault(outputMode="vector")
8689
kwargs = self._input_kwargs
8790
return self._set(**kwargs)
8891

@@ -179,7 +182,7 @@ def _addReshapeLayers(self, tf_graph, dtype="uint8"):
179182
# Add on the original graph
180183
tf.import_graph_def(gdef, input_map={input_tensor_name: image_reshaped_expanded},
181184
return_elements=[self.getOutputTensor().name],
182-
name=self.USER_GRAPH_NAMESPACE)
185+
name=USER_GRAPH_NAMESPACE)
183186

184187
# Flatten the output for tensorframes
185188
output_node = g.get_tensor_by_name(self._getOriginalOutputTensorName())
@@ -198,10 +201,10 @@ def _stripGraph(self, tf_graph):
198201
return g
199202

200203
def _getOriginalOutputTensorName(self):
201-
return self.USER_GRAPH_NAMESPACE + '/' + self.getOutputTensor().name
204+
return USER_GRAPH_NAMESPACE + '/' + self.getOutputTensor().name
202205

203206
def _getFinalOutputTensorName(self):
204-
return self.NEW_OUTPUT_PREFIX + '_' + self.getOutputTensor().name
207+
return NEW_OUTPUT_PREFIX + '_' + self.getOutputTensor().name
205208

206209
def _getFinalOutputOpName(self):
207210
return tfx.as_op_name(self._getFinalOutputTensorName())

0 commit comments

Comments
 (0)