
Commit 3c849f2

Merge branch 'tf-transformer-part1' into tf-transformer-part2

2 parents: f0912fb + a8531ec

4 files changed: +44 -36 lines

python/sparkdl/param/converters.py

Lines changed: 39 additions & 31 deletions
@@ -13,10 +13,15 @@
 # limitations under the License.
 #
 
-# pylint: disable=wrong-spelling-in-docstring,invalid-name,import-error
+# pylint: disable=invalid-name,import-error
 
 """ SparkDLTypeConverters
-Type conversion utilities for definition Spark Deep Learning related MLlib `Params`.
+
+Type conversion utilities for defining MLlib `Params` used in Spark Deep Learning Pipelines.
+
+.. note:: We follow the convention of MLlib to name these utilities "converters",
+          but most of them act as type checkers that return the argument if it is
+          the desired type and raise `TypeError` otherwise.
 """
 
 import six
@@ -33,7 +38,7 @@ class SparkDLTypeConverters(object):
     """
     .. note:: DeveloperApi
 
-    Factory methods for type conversion functions for :py:func:`Param.typeConverter`.
+    Methods for type conversion functions for :py:func:`Param.typeConverter`.
     These methods are similar to :py:class:`spark.ml.param.TypeConverters`.
     They provide support for the `Params` types introduced in Spark Deep Learning Pipelines.
     """
@@ -50,19 +55,16 @@ def toTFGraph(value):
     @staticmethod
     def asColumnToTensorNameMap(value):
         """
-        Convert a value to a column name to :py:obj:`tf.Tensor` name mapping
-        as a sorted list of string pairs, if possible.
+        Convert a value to a column name to :py:class:`tf.Tensor` name mapping
+        as a sorted list (in lexicographical order) of string pairs, if possible.
         """
         if not isinstance(value, dict):
             err_msg = "Could not convert [type {}] {} to column name to tf.Tensor name mapping"
             raise TypeError(err_msg.format(type(value), value))
 
-        # Conversion logic after quick type check
         strs_pair_seq = []
         for _maybe_col_name, _maybe_tnsr_name in value.items():
-            # Check if the non-tensor value is of string type
            _check_is_str(_maybe_col_name)
-            # Check if the tensor name looks like a tensor name
            _check_is_tensor_name(_maybe_tnsr_name)
            strs_pair_seq.append((_maybe_col_name, _maybe_tnsr_name))
 
@@ -71,52 +73,55 @@ def asColumnToTensorNameMap(value):
     @staticmethod
     def asTensorNameToColumnMap(value):
         """
-        Convert a value to a :py:obj:`tf.Tensor` name to column name mapping
-        as a sorted list of string pairs, if possible.
+        Convert a value to a :py:class:`tf.Tensor` name to column name mapping
+        as a sorted list (in lexicographical order) of string pairs, if possible.
         """
         if not isinstance(value, dict):
             err_msg = "Could not convert [type {}] {} to tf.Tensor name to column name mapping"
             raise TypeError(err_msg.format(type(value), value))
 
-        # Conversion logic after quick type check
         strs_pair_seq = []
         for _maybe_tnsr_name, _maybe_col_name in value.items():
-            # Check if the non-tensor value is of string type
            _check_is_str(_maybe_col_name)
-            # Check if the tensor name looks like a tensor name
            _check_is_tensor_name(_maybe_tnsr_name)
            strs_pair_seq.append((_maybe_tnsr_name, _maybe_col_name))
 
         return sorted(strs_pair_seq)
 
     @staticmethod
     def toTFHParams(value):
-        """ Convert a value to a :py:obj:`tf.contrib.training.HParams` object, if possible. """
+        """
+        Check that the given value is a :py:class:`tf.contrib.training.HParams` object,
+        and return it. Raise an error otherwise.
+        """
         if not isinstance(value, tf.contrib.training.HParams):
             raise TypeError("Could not convert %s to TensorFlow HParams" % type(value))
 
         return value
 
     @staticmethod
     def toTFTensorName(value):
-        """ Convert a value to a :py:obj:`tf.Tensor` name, if possible. """
+        """
+        Check if a value is a valid :py:class:`tf.Tensor` name and return it.
+        Raise an error otherwise.
+        """
         if isinstance(value, tf.Tensor):
             return value.name
         try:
-            _maybe_tnsr_name = TypeConverters.toString(value)
-            _check_is_tensor_name(_maybe_tnsr_name)
-            return _maybe_tnsr_name
+            _check_is_tensor_name(value)
+            return value
         except Exception as exc:
             err_msg = "Could not convert [type {}] {} to tf.Tensor name. {}"
             raise TypeError(err_msg.format(type(value), value, exc))
 
     @staticmethod
-    def buildCheckList(supportedList):
+    def buildSupportedItemConverter(supportedList):
         """
-        Create a converter that try to check if a value is part of the supported list.
+        Create a "converter" that try to check if a value is part of the supported list of values.
 
         :param supportedList: list, containing supported objects.
-        :return: a converter that try to convert a value if it is part of the `supportedList`.
+        :return: a converter that try to check if a value is part of the `supportedList` and return it.
+                 Raise an error otherwise.
         """
 
         def converter(value):
@@ -131,7 +136,10 @@ def converter(value):
 
     @staticmethod
     def toKerasLoss(value):
-        """ Convert a value to a name of Keras loss function, if possible """
+        """
+        Check if a value is a valid Keras loss function name and return it.
+        Otherwise raise an error.
+        """
         # return early in for clarify as well as less indentation
         if not kmutil.is_valid_loss_function(value):
             err_msg = "Named loss not supported in Keras: [type {}] {}"
@@ -141,7 +149,10 @@ def toKerasLoss(value):
 
     @staticmethod
     def toKerasOptimizer(value):
-        """ Convert a value to a name of Keras optimizer, if possible """
+        """
+        Check if a value is a valid name of Keras optimizer and return it.
+        Otherwise raise an error.
+        """
         if not kmutil.is_valid_optimizer(value):
             err_msg = "Named optimizer not supported in Keras: [type {}] {}"
             raise TypeError(err_msg.format(type(value), value))
@@ -150,7 +161,7 @@ def toKerasOptimizer(value):
 
 
 def _check_is_tensor_name(_maybe_tnsr_name):
-    """ Check if the input is a valid tensor name """
+    """ Check if the input is a valid tensor name or raise a `TypeError` otherwise. """
     if not isinstance(_maybe_tnsr_name, six.string_types):
         err_msg = "expect tensor name to be of string type, but got [type {}]"
         raise TypeError(err_msg.format(type(_maybe_tnsr_name)))
@@ -164,13 +175,10 @@ def _check_is_tensor_name(_maybe_tnsr_name):
         err_msg = "Tensor name must be of type <op_name>:<index>, but got {}"
         raise TypeError(err_msg.format(_maybe_tnsr_name))
 
-    return _maybe_tnsr_name
-
 
-def _check_is_str(_maybe_col_name):
-    """ Check if the given colunm name is a valid column name """
+def _check_is_str(_maybe_str):
+    """ Check if the value is a valid string type or raise a `TypeError` otherwise. """
     # We only check if the column name candidate is a string type
-    if not isinstance(_maybe_col_name, six.string_types):
+    if not isinstance(_maybe_str, six.string_types):
         err_msg = 'expect string type but got type {} for {}'
-        raise TypeError(err_msg.format(type(_maybe_col_name), _maybe_col_name))
-    return _maybe_col_name
+        raise TypeError(err_msg.format(type(_maybe_str), _maybe_str))
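
As a quick illustration of how the renamed "checkers" behave, here is a minimal sketch based only on the docstrings and checks visible in this diff (the example values are hypothetical and run outside any Spark job):

    from sparkdl.param.converters import SparkDLTypeConverters

    # A dict is converted to a lexicographically sorted list of string pairs.
    pairs = SparkDLTypeConverters.asColumnToTensorNameMap(
        {"image_col": "input_img:0", "aux_col": "input_aux:0"})
    # pairs == [('aux_col', 'input_aux:0'), ('image_col', 'input_img:0')]

    # A tensor name must look like <op_name>:<index>; after this change the value
    # is checked and returned as-is rather than converted via TypeConverters.toString.
    SparkDLTypeConverters.toTFTensorName("input_img:0")

    try:
        SparkDLTypeConverters.toTFTensorName("input_img")  # missing the :<index> suffix
    except TypeError as exc:
        print(exc)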

python/sparkdl/param/image_params.py

Lines changed: 1 addition & 1 deletion
@@ -107,7 +107,7 @@ class HasOutputMode(Params):
                       "How the output column should be formatted. 'vector' for a 1-d MLlib " +
                       "Vector of floats. 'image' to format the output to work with the image " +
                       "tools in this package.",
-                      typeConverter=SparkDLTypeConverters.buildCheckList(OUTPUT_MODES))
+                      typeConverter=SparkDLTypeConverters.buildSupportedItemConverter(OUTPUT_MODES))
 
     def setOutputMode(self, value):
         return self._set(outputMode=value)
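
The renamed factory returns a converter that only accepts values from the supplied list, per its docstring. A minimal sketch of the intended usage, with a hypothetical stand-in for the package's OUTPUT_MODES constant:

    from sparkdl.param.converters import SparkDLTypeConverters

    OUTPUT_MODES = ["vector", "image"]  # stand-in values for illustration only
    converter = SparkDLTypeConverters.buildSupportedItemConverter(OUTPUT_MODES)

    converter("vector")      # in the supported list, returned as-is
    try:
        converter("tensor")  # not a supported output mode, raises TypeError
    except TypeError as exc:
        print(exc)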

python/sparkdl/param/shared_params.py

Lines changed: 1 addition & 1 deletion
@@ -232,7 +232,7 @@ class HasTFHParams(Params):
                    key-value object, storing parameters to be used to define the final
                    TensorFlow graph for the Transformer.
 
-                   Currently accepted values are:
+                   Currently used values are:
                    - `batch_size`: number of samples evaluated together in inference steps"""),
                    typeConverter=SparkDLTypeConverters.toTFHParams)
 
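
For context, `toTFHParams` only accepts `tf.contrib.training.HParams` instances; anything else is rejected. A short sketch, assuming a TensorFlow 1.x environment where `tf.contrib` is available:

    import tensorflow as tf
    from sparkdl.param.converters import SparkDLTypeConverters

    # `batch_size` is the key the docstring above says is currently used.
    hparams = tf.contrib.training.HParams(batch_size=64)
    SparkDLTypeConverters.toTFHParams(hparams)      # returned unchanged

    try:
        SparkDLTypeConverters.toTFHParams({"batch_size": 64})  # a plain dict is rejected
    except TypeError as exc:
        print(exc)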

python/sparkdl/transformers/named_image.py

Lines changed: 3 additions & 3 deletions
@@ -40,7 +40,7 @@ class DeepImagePredictor(Transformer, HasInputCol, HasOutputCol):
     """
 
     modelName = Param(Params._dummy(), "modelName", "A deep learning model name",
-                      typeConverter=SparkDLTypeConverters.buildCheckList(SUPPORTED_MODELS))
+                      typeConverter=SparkDLTypeConverters.buildSupportedItemConverter(SUPPORTED_MODELS))
     decodePredictions = Param(Params._dummy(), "decodePredictions",
                               "If true, output predictions in the (class, description, probability) format",
                               typeConverter=TypeConverters.toBoolean)
@@ -125,7 +125,7 @@ class DeepImageFeaturizer(Transformer, HasInputCol, HasOutputCol):
     """
 
     modelName = Param(Params._dummy(), "modelName", "A deep learning model name",
-                      typeConverter=SparkDLTypeConverters.buildCheckList(SUPPORTED_MODELS))
+                      typeConverter=SparkDLTypeConverters.buildSupportedItemConverter(SUPPORTED_MODELS))
 
     @keyword_only
     def __init__(self, inputCol=None, outputCol=None, modelName=None):
@@ -169,7 +169,7 @@ class _NamedImageTransformer(Transformer, HasInputCol, HasOutputCol):
     """
 
     modelName = Param(Params._dummy(), "modelName", "A deep learning model name",
-                      typeConverter=SparkDLTypeConverters.buildCheckList(SUPPORTED_MODELS))
+                      typeConverter=SparkDLTypeConverters.buildSupportedItemConverter(SUPPORTED_MODELS))
     featurize = Param(Params._dummy(), "featurize",
                       "If true, output features. If false, output predictions. Either way the output is a vector.",
                       typeConverter=TypeConverters.toBoolean)
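
From the caller's side nothing changes: `modelName` is still validated against SUPPORTED_MODELS when the Param is set. A minimal sketch of the effect (it assumes "InceptionV3" is among SUPPORTED_MODELS and that the transformer can be constructed with just these arguments):

    from sparkdl.transformers.named_image import DeepImagePredictor

    predictor = DeepImagePredictor(inputCol="image", outputCol="predicted_labels",
                                   modelName="InceptionV3")

    try:
        # A name outside SUPPORTED_MODELS is rejected by the converter when the Param is set.
        DeepImagePredictor(inputCol="image", outputCol="predicted_labels",
                           modelName="NotARealModel")
    except TypeError as exc:
        print(exc)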
