13
13
# limitations under the License.
14
14
#
15
15
16
- # pylint: disable=wrong-spelling-in-docstring, invalid-name,import-error
16
+ # pylint: disable=invalid-name,import-error
17
17
18
18
""" SparkDLTypeConverters
19
- Type conversion utilities for definition Spark Deep Learning related MLlib `Params`.
19
+
20
+ Type conversion utilities for defining MLlib `Params` used in Spark Deep Learning Pipelines.
21
+
22
+ .. note:: We follow the convention of MLlib to name these utilities "converters",
23
+ but most of them act as type checkers that return the argument if it is
24
+ the desired type and raise `TypeError` otherwise.
20
25
"""
21
26
22
27
import six
@@ -33,7 +38,7 @@ class SparkDLTypeConverters(object):
33
38
"""
34
39
.. note:: DeveloperApi
35
40
36
- Factory methods for type conversion functions for :py:func:`Param.typeConverter`.
41
+ Methods for type conversion functions for :py:func:`Param.typeConverter`.
37
42
These methods are similar to :py:class:`spark.ml.param.TypeConverters`.
38
43
They provide support for the `Params` types introduced in Spark Deep Learning Pipelines.
39
44
"""
@@ -50,19 +55,16 @@ def toTFGraph(value):
50
55
@staticmethod
51
56
def asColumnToTensorNameMap (value ):
52
57
"""
53
- Convert a value to a column name to :py:obj :`tf.Tensor` name mapping
54
- as a sorted list of string pairs, if possible.
58
+ Convert a value to a column name to :py:class :`tf.Tensor` name mapping
59
+ as a sorted list (in lexicographical order) of string pairs, if possible.
55
60
"""
56
61
if not isinstance (value , dict ):
57
62
err_msg = "Could not convert [type {}] {} to column name to tf.Tensor name mapping"
58
63
raise TypeError (err_msg .format (type (value ), value ))
59
64
60
- # Conversion logic after quick type check
61
65
strs_pair_seq = []
62
66
for _maybe_col_name , _maybe_tnsr_name in value .items ():
63
- # Check if the non-tensor value is of string type
64
67
_check_is_str (_maybe_col_name )
65
- # Check if the tensor name looks like a tensor name
66
68
_check_is_tensor_name (_maybe_tnsr_name )
67
69
strs_pair_seq .append ((_maybe_col_name , _maybe_tnsr_name ))
68
70
@@ -71,52 +73,55 @@ def asColumnToTensorNameMap(value):
71
73
@staticmethod
72
74
def asTensorNameToColumnMap (value ):
73
75
"""
74
- Convert a value to a :py:obj :`tf.Tensor` name to column name mapping
75
- as a sorted list of string pairs, if possible.
76
+ Convert a value to a :py:class :`tf.Tensor` name to column name mapping
77
+ as a sorted list (in lexicographical order) of string pairs, if possible.
76
78
"""
77
79
if not isinstance (value , dict ):
78
80
err_msg = "Could not convert [type {}] {} to tf.Tensor name to column name mapping"
79
81
raise TypeError (err_msg .format (type (value ), value ))
80
82
81
- # Conversion logic after quick type check
82
83
strs_pair_seq = []
83
84
for _maybe_tnsr_name , _maybe_col_name in value .items ():
84
- # Check if the non-tensor value is of string type
85
85
_check_is_str (_maybe_col_name )
86
- # Check if the tensor name looks like a tensor name
87
86
_check_is_tensor_name (_maybe_tnsr_name )
88
87
strs_pair_seq .append ((_maybe_tnsr_name , _maybe_col_name ))
89
88
90
89
return sorted (strs_pair_seq )
91
90
92
91
@staticmethod
93
92
def toTFHParams (value ):
94
- """ Convert a value to a :py:obj:`tf.contrib.training.HParams` object, if possible. """
93
+ """
94
+ Check that the given value is a :py:class:`tf.contrib.training.HParams` object,
95
+ and return it. Raise an error otherwise.
96
+ """
95
97
if not isinstance (value , tf .contrib .training .HParams ):
96
98
raise TypeError ("Could not convert %s to TensorFlow HParams" % type (value ))
97
99
98
100
return value
99
101
100
102
@staticmethod
101
103
def toTFTensorName (value ):
102
- """ Convert a value to a :py:obj:`tf.Tensor` name, if possible. """
104
+ """
105
+ Check if a value is a valid :py:class:`tf.Tensor` name and return it.
106
+ Raise an error otherwise.
107
+ """
103
108
if isinstance (value , tf .Tensor ):
104
109
return value .name
105
110
try :
106
- _maybe_tnsr_name = TypeConverters .toString (value )
107
- _check_is_tensor_name (_maybe_tnsr_name )
108
- return _maybe_tnsr_name
111
+ _check_is_tensor_name (value )
112
+ return value
109
113
except Exception as exc :
110
114
err_msg = "Could not convert [type {}] {} to tf.Tensor name. {}"
111
115
raise TypeError (err_msg .format (type (value ), value , exc ))
112
116
113
117
@staticmethod
114
- def buildCheckList (supportedList ):
118
+ def buildSupportedItemConverter (supportedList ):
115
119
"""
116
- Create a converter that try to check if a value is part of the supported list.
120
+ Create a " converter" that try to check if a value is part of the supported list of values .
117
121
118
122
:param supportedList: list, containing supported objects.
119
- :return: a converter that try to convert a value if it is part of the `supportedList`.
123
+ :return: a converter that try to check if a value is part of the `supportedList` and return it.
124
+ Raise an error otherwise.
120
125
"""
121
126
122
127
def converter (value ):
@@ -131,7 +136,10 @@ def converter(value):
131
136
132
137
@staticmethod
133
138
def toKerasLoss (value ):
134
- """ Convert a value to a name of Keras loss function, if possible """
139
+ """
140
+ Check if a value is a valid Keras loss function name and return it.
141
+ Otherwise raise an error.
142
+ """
135
143
# return early in for clarify as well as less indentation
136
144
if not kmutil .is_valid_loss_function (value ):
137
145
err_msg = "Named loss not supported in Keras: [type {}] {}"
@@ -141,7 +149,10 @@ def toKerasLoss(value):
141
149
142
150
@staticmethod
143
151
def toKerasOptimizer (value ):
144
- """ Convert a value to a name of Keras optimizer, if possible """
152
+ """
153
+ Check if a value is a valid name of Keras optimizer and return it.
154
+ Otherwise raise an error.
155
+ """
145
156
if not kmutil .is_valid_optimizer (value ):
146
157
err_msg = "Named optimizer not supported in Keras: [type {}] {}"
147
158
raise TypeError (err_msg .format (type (value ), value ))
@@ -150,7 +161,7 @@ def toKerasOptimizer(value):
150
161
151
162
152
163
def _check_is_tensor_name (_maybe_tnsr_name ):
153
- """ Check if the input is a valid tensor name """
164
+ """ Check if the input is a valid tensor name or raise a `TypeError` otherwise. """
154
165
if not isinstance (_maybe_tnsr_name , six .string_types ):
155
166
err_msg = "expect tensor name to be of string type, but got [type {}]"
156
167
raise TypeError (err_msg .format (type (_maybe_tnsr_name )))
@@ -164,13 +175,10 @@ def _check_is_tensor_name(_maybe_tnsr_name):
164
175
err_msg = "Tensor name must be of type <op_name>:<index>, but got {}"
165
176
raise TypeError (err_msg .format (_maybe_tnsr_name ))
166
177
167
- return _maybe_tnsr_name
168
-
169
178
170
- def _check_is_str (_maybe_col_name ):
171
- """ Check if the given colunm name is a valid column name """
179
+ def _check_is_str (_maybe_str ):
180
+ """ Check if the value is a valid string type or raise a `TypeError` otherwise. """
172
181
# We only check if the column name candidate is a string type
173
- if not isinstance (_maybe_col_name , six .string_types ):
182
+ if not isinstance (_maybe_str , six .string_types ):
174
183
err_msg = 'expect string type but got type {} for {}'
175
- raise TypeError (err_msg .format (type (_maybe_col_name ), _maybe_col_name ))
176
- return _maybe_col_name
184
+ raise TypeError (err_msg .format (type (_maybe_str ), _maybe_str ))
0 commit comments