Skip to content

Commit 7ae89e4

Browse files
committed
PR comments
converters simpiliciation
1 parent 77b3906 commit 7ae89e4

File tree

4 files changed

+67
-62
lines changed

4 files changed

+67
-62
lines changed

python/sparkdl/param/converters.py

Lines changed: 59 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -33,120 +33,131 @@ class SparkDLTypeConverters(object):
3333
These methods are similar to :py:class:`spark.ml.param.TypeConverters`.
3434
They provide support for the `Params` types introduced in Spark Deep Learning Pipelines.
3535
"""
36+
3637
@staticmethod
3738
def toTFGraph(value):
38-
if isinstance(value, tf.Graph):
39-
return value
40-
else:
41-
raise TypeError("Could not convert %s to TensorFlow Graph" % type(value))
39+
if not isinstance(value, tf.Graph):
40+
raise TypeError("Could not convert %s to tf.Graph" % type(value))
41+
return value
4242

4343
@staticmethod
4444
def asColumnToTensorNameMap(value):
4545
"""
4646
Convert a value to a column name to :py:obj:`tf.Tensor` name mapping
4747
as a sorted list of string pairs, if possible.
4848
"""
49-
if isinstance(value, dict):
50-
strs_pair_seq = []
51-
for _maybe_col_name, _maybe_tnsr_name in value.items():
52-
# Check if the non-tensor value is of string type
53-
_col_name = _get_strict_col_name(_maybe_col_name)
54-
# Check if the tensor name is actually valid
55-
_tnsr_name = _get_strict_tensor_name(_maybe_tnsr_name)
56-
strs_pair_seq.append((_col_name, _tnsr_name))
49+
if not isinstance(value, dict):
50+
err_msg = "Could not convert [type {}] {} to column name to tf.Tensor name mapping"
51+
raise TypeError(err_msg.format(type(value), value))
5752

58-
return sorted(strs_pair_seq)
53+
# Convertion logic after quick type check
54+
strs_pair_seq = []
55+
for _maybe_col_name, _maybe_tnsr_name in value.items():
56+
# Check if the non-tensor value is of string type
57+
_check_is_str(_maybe_col_name)
58+
# Check if the tensor name looks like a tensor name
59+
_check_is_tensor_name(_maybe_tnsr_name)
60+
strs_pair_seq.append((_maybe_col_name, _maybe_tnsr_name))
5961

60-
err_msg = "Could not convert [type {}] {} to column name to tf.Tensor name mapping"
61-
raise TypeError(err_msg.format(type(value), value))
62+
return sorted(strs_pair_seq)
6263

6364
@staticmethod
6465
def asTensorNameToColumnMap(value):
6566
"""
6667
Convert a value to a :py:obj:`tf.Tensor` name to column name mapping
6768
as a sorted list of string pairs, if possible.
6869
"""
69-
if isinstance(value, dict):
70-
strs_pair_seq = []
71-
for _maybe_tnsr_name, _maybe_col_name in value.items():
72-
# Check if the non-tensor value is of string type
73-
_col_name = _get_strict_col_name(_maybe_col_name)
74-
# Check if the tensor name is actually valid
75-
_tnsr_name = _get_strict_tensor_name(_maybe_tnsr_name)
76-
strs_pair_seq.append((_tnsr_name, _col_name))
70+
if not isinstance(value, dict):
71+
err_msg = "Could not convert [type {}] {} to tf.Tensor name to column name mapping"
72+
raise TypeError(err_msg.format(type(value), value))
7773

78-
return sorted(strs_pair_seq)
74+
# Convertion logic after quick type check
75+
strs_pair_seq = []
76+
for _maybe_tnsr_name, _maybe_col_name in value.items():
77+
# Check if the non-tensor value is of string type
78+
_check_is_str(_maybe_col_name)
79+
# Check if the tensor name looks like a tensor name
80+
_check_is_tensor_name(_maybe_tnsr_name)
81+
strs_pair_seq.append((_maybe_tnsr_name, _maybe_col_name))
7982

80-
err_msg = "Could not convert [type {}] {} to tf.Tensor name to column name mapping"
81-
raise TypeError(err_msg.format(type(value), value))
83+
return sorted(strs_pair_seq)
8284

8385
@staticmethod
8486
def toTFHParams(value):
8587
""" Convert a value to a :py:obj:`tf.contrib.training.HParams` object, if possible. """
86-
if isinstance(value, tf.contrib.training.HParams):
87-
return value
88-
else:
88+
if not isinstance(value, tf.contrib.training.HParams):
8989
raise TypeError("Could not convert %s to TensorFlow HParams" % type(value))
9090

91+
return value
92+
9193
@staticmethod
92-
def toStringOrTFTensor(value):
94+
def toTFTensorName(value):
9395
""" Convert a value to a str or a :py:obj:`tf.Tensor` object, if possible. """
9496
if isinstance(value, tf.Tensor):
95-
return value
97+
return value.name
9698
try:
99+
_check_is_tensor_name(value)
97100
return TypeConverters.toString(value)
98101
except Exception as exc:
99102
err_msg = "Could not convert [type {}] {} to tf.Tensor or str. {}"
100103
raise TypeError(err_msg.format(type(value), value, exc))
101104

102105
@staticmethod
103-
def supportedNameConverter(supportedList):
106+
def buildCheckList(supportedList):
104107
"""
105108
Create a converter that try to check if a value is part of the supported list.
106109
107110
:param supportedList: list, containing supported objects.
108111
:return: a converter that try to convert a value if it is part of the `supportedList`.
109112
"""
113+
110114
def converter(value):
111-
if value in supportedList:
112-
return value
113-
err_msg = "[type {}] {} is not in the supported list: {}"
114-
raise TypeError(err_msg.format(type(value), str(value), supportedList))
115+
if value not in supportedList:
116+
err_msg = "[type {}] {} is not in the supported list: {}"
117+
raise TypeError(err_msg.format(type(value), str(value), supportedList))
118+
119+
return value
115120

116121
return converter
117122

118123
@staticmethod
119124
def toKerasLoss(value):
120125
""" Convert a value to a name of Keras loss function, if possible """
121-
if kmutil.is_valid_loss_function(value):
122-
return value
123-
err_msg = "Named loss not supported in Keras: [type {}] {}"
124-
raise ValueError(err_msg.format(type(value), value))
126+
# return early in for clarify as well as less indentation
127+
if not kmutil.is_valid_loss_function(value):
128+
err_msg = "Named loss not supported in Keras: [type {}] {}"
129+
raise ValueError(err_msg.format(type(value), value))
130+
131+
return value
125132

126133
@staticmethod
127134
def toKerasOptimizer(value):
128135
""" Convert a value to a name of Keras optimizer, if possible """
129-
if kmutil.is_valid_optimizer(value):
130-
return value
131-
err_msg = "Named optimizer not supported in Keras: [type {}] {}"
132-
raise TypeError(err_msg.format(type(value), value))
136+
if not kmutil.is_valid_optimizer(value):
137+
err_msg = "Named optimizer not supported in Keras: [type {}] {}"
138+
raise TypeError(err_msg.format(type(value), value))
139+
140+
return value
133141

134142

135-
def _get_strict_tensor_name(_maybe_tnsr_name):
143+
def _check_is_tensor_name(_maybe_tnsr_name):
136144
""" Check if the input is a valid tensor name """
137145
try:
138146
assert isinstance(_maybe_tnsr_name, six.string_types), \
139147
"must provide a strict tensor name as input, but got {}".format(type(_maybe_tnsr_name))
140-
assert tfx.as_tensor_name(_maybe_tnsr_name) == _maybe_tnsr_name, \
141-
"input {} must be a valid tensor name".format(_maybe_tnsr_name)
148+
149+
# The check is taken from TensorFlow's NodeDef protocol buffer.
150+
# https://github.com/tensorflow/tensorflow/blob/r1.3/tensorflow/core/framework/node_def.proto#L21-L25
151+
_, src_idx = _maybe_tnsr_name.split(":")
152+
_ = int(src_idx)
142153
except Exception as exc:
143154
err_msg = "Can NOT convert [type {}] {} to tf.Tensor name: {}"
144155
raise TypeError(err_msg.format(type(_maybe_tnsr_name), _maybe_tnsr_name, exc))
145156
else:
146157
return _maybe_tnsr_name
147158

148159

149-
def _get_strict_col_name(_maybe_col_name):
160+
def _check_is_str(_maybe_col_name):
150161
""" Check if the given colunm name is a valid column name """
151162
# We only check if the column name candidate is a string type
152163
if not isinstance(_maybe_col_name, six.string_types):

python/sparkdl/param/image_params.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ class HasOutputMode(Params):
107107
"How the output column should be formatted. 'vector' for a 1-d MLlib " +
108108
"Vector of floats. 'image' to format the output to work with the image " +
109109
"tools in this package.",
110-
typeConverter=SparkDLTypeConverters.supportedNameConverter(OUTPUT_MODES))
110+
typeConverter=SparkDLTypeConverters.buildCheckList(OUTPUT_MODES))
111111

112112
def setOutputMode(self, value):
113113
return self._set(outputMode=value)

python/sparkdl/transformers/named_image.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ class DeepImagePredictor(Transformer, HasInputCol, HasOutputCol):
4040
"""
4141

4242
modelName = Param(Params._dummy(), "modelName", "A deep learning model name",
43-
typeConverter=SparkDLTypeConverters.supportedNameConverter(SUPPORTED_MODELS))
43+
typeConverter=SparkDLTypeConverters.buildCheckList(SUPPORTED_MODELS))
4444
decodePredictions = Param(Params._dummy(), "decodePredictions",
4545
"If true, output predictions in the (class, description, probability) format",
4646
typeConverter=TypeConverters.toBoolean)

python/sparkdl/transformers/tf_image.py

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -54,10 +54,10 @@ class TFImageTransformer(Transformer, HasInputCol, HasOutputCol, HasOutputMode):
5454
typeConverter=SparkDLTypeConverters.toTFGraph)
5555
inputTensor = Param(Params._dummy(), "inputTensor",
5656
"A TensorFlow tensor object or name representing the input image",
57-
typeConverter=SparkDLTypeConverters.toStringOrTFTensor)
57+
typeConverter=SparkDLTypeConverters.toTFTensorName)
5858
outputTensor = Param(Params._dummy(), "outputTensor",
5959
"A TensorFlow tensor object or name representing the output",
60-
typeConverter=SparkDLTypeConverters.toStringOrTFTensor)
60+
typeConverter=SparkDLTypeConverters.toTFTensorName)
6161

6262
@keyword_only
6363
def __init__(self, inputCol=None, outputCol=None, graph=None,
@@ -99,18 +99,12 @@ def getGraph(self):
9999
return self.getOrDefault(self.graph)
100100

101101
def getInputTensor(self):
102-
tensor_or_name = self.getOrDefault(self.inputTensor)
103-
if isinstance(tensor_or_name, tf.Tensor):
104-
return tensor_or_name
105-
else:
106-
return self.getGraph().get_tensor_by_name(tensor_or_name)
102+
tensor_name = self.getOrDefault(self.inputTensor)
103+
return self.getGraph().get_tensor_by_name(tensor_name)
107104

108105
def getOutputTensor(self):
109-
tensor_or_name = self.getOrDefault(self.outputTensor)
110-
if isinstance(tensor_or_name, tf.Tensor):
111-
return tensor_or_name
112-
else:
113-
return self.getGraph().get_tensor_by_name(tensor_or_name)
106+
tensor_name = self.getOrDefault(self.outputTensor)
107+
return self.getGraph().get_tensor_by_name(tensor_name)
114108

115109
def _transform(self, dataset):
116110
graph = self.getGraph()

0 commit comments

Comments
 (0)