@@ -33,120 +33,131 @@ class SparkDLTypeConverters(object):
33
33
These methods are similar to :py:class:`spark.ml.param.TypeConverters`.
34
34
They provide support for the `Params` types introduced in Spark Deep Learning Pipelines.
35
35
"""
36
+
36
37
@staticmethod
37
38
def toTFGraph (value ):
38
- if isinstance (value , tf .Graph ):
39
- return value
40
- else :
41
- raise TypeError ("Could not convert %s to TensorFlow Graph" % type (value ))
39
+ if not isinstance (value , tf .Graph ):
40
+ raise TypeError ("Could not convert %s to tf.Graph" % type (value ))
41
+ return value
42
42
43
43
@staticmethod
44
44
def asColumnToTensorNameMap (value ):
45
45
"""
46
46
Convert a value to a column name to :py:obj:`tf.Tensor` name mapping
47
47
as a sorted list of string pairs, if possible.
48
48
"""
49
- if isinstance (value , dict ):
50
- strs_pair_seq = []
51
- for _maybe_col_name , _maybe_tnsr_name in value .items ():
52
- # Check if the non-tensor value is of string type
53
- _col_name = _get_strict_col_name (_maybe_col_name )
54
- # Check if the tensor name is actually valid
55
- _tnsr_name = _get_strict_tensor_name (_maybe_tnsr_name )
56
- strs_pair_seq .append ((_col_name , _tnsr_name ))
49
+ if not isinstance (value , dict ):
50
+ err_msg = "Could not convert [type {}] {} to column name to tf.Tensor name mapping"
51
+ raise TypeError (err_msg .format (type (value ), value ))
57
52
58
- return sorted (strs_pair_seq )
53
+ # Convertion logic after quick type check
54
+ strs_pair_seq = []
55
+ for _maybe_col_name , _maybe_tnsr_name in value .items ():
56
+ # Check if the non-tensor value is of string type
57
+ _check_is_str (_maybe_col_name )
58
+ # Check if the tensor name looks like a tensor name
59
+ _check_is_tensor_name (_maybe_tnsr_name )
60
+ strs_pair_seq .append ((_maybe_col_name , _maybe_tnsr_name ))
59
61
60
- err_msg = "Could not convert [type {}] {} to column name to tf.Tensor name mapping"
61
- raise TypeError (err_msg .format (type (value ), value ))
62
+ return sorted (strs_pair_seq )
62
63
63
64
@staticmethod
64
65
def asTensorNameToColumnMap (value ):
65
66
"""
66
67
Convert a value to a :py:obj:`tf.Tensor` name to column name mapping
67
68
as a sorted list of string pairs, if possible.
68
69
"""
69
- if isinstance (value , dict ):
70
- strs_pair_seq = []
71
- for _maybe_tnsr_name , _maybe_col_name in value .items ():
72
- # Check if the non-tensor value is of string type
73
- _col_name = _get_strict_col_name (_maybe_col_name )
74
- # Check if the tensor name is actually valid
75
- _tnsr_name = _get_strict_tensor_name (_maybe_tnsr_name )
76
- strs_pair_seq .append ((_tnsr_name , _col_name ))
70
+ if not isinstance (value , dict ):
71
+ err_msg = "Could not convert [type {}] {} to tf.Tensor name to column name mapping"
72
+ raise TypeError (err_msg .format (type (value ), value ))
77
73
78
- return sorted (strs_pair_seq )
74
+ # Convertion logic after quick type check
75
+ strs_pair_seq = []
76
+ for _maybe_tnsr_name , _maybe_col_name in value .items ():
77
+ # Check if the non-tensor value is of string type
78
+ _check_is_str (_maybe_col_name )
79
+ # Check if the tensor name looks like a tensor name
80
+ _check_is_tensor_name (_maybe_tnsr_name )
81
+ strs_pair_seq .append ((_maybe_tnsr_name , _maybe_col_name ))
79
82
80
- err_msg = "Could not convert [type {}] {} to tf.Tensor name to column name mapping"
81
- raise TypeError (err_msg .format (type (value ), value ))
83
+ return sorted (strs_pair_seq )
82
84
83
85
@staticmethod
84
86
def toTFHParams (value ):
85
87
""" Convert a value to a :py:obj:`tf.contrib.training.HParams` object, if possible. """
86
- if isinstance (value , tf .contrib .training .HParams ):
87
- return value
88
- else :
88
+ if not isinstance (value , tf .contrib .training .HParams ):
89
89
raise TypeError ("Could not convert %s to TensorFlow HParams" % type (value ))
90
90
91
+ return value
92
+
91
93
@staticmethod
92
- def toStringOrTFTensor (value ):
94
+ def toTFTensorName (value ):
93
95
""" Convert a value to a str or a :py:obj:`tf.Tensor` object, if possible. """
94
96
if isinstance (value , tf .Tensor ):
95
- return value
97
+ return value . name
96
98
try :
99
+ _check_is_tensor_name (value )
97
100
return TypeConverters .toString (value )
98
101
except Exception as exc :
99
102
err_msg = "Could not convert [type {}] {} to tf.Tensor or str. {}"
100
103
raise TypeError (err_msg .format (type (value ), value , exc ))
101
104
102
105
@staticmethod
103
- def supportedNameConverter (supportedList ):
106
+ def buildCheckList (supportedList ):
104
107
"""
105
108
Create a converter that try to check if a value is part of the supported list.
106
109
107
110
:param supportedList: list, containing supported objects.
108
111
:return: a converter that try to convert a value if it is part of the `supportedList`.
109
112
"""
113
+
110
114
def converter (value ):
111
- if value in supportedList :
112
- return value
113
- err_msg = "[type {}] {} is not in the supported list: {}"
114
- raise TypeError (err_msg .format (type (value ), str (value ), supportedList ))
115
+ if value not in supportedList :
116
+ err_msg = "[type {}] {} is not in the supported list: {}"
117
+ raise TypeError (err_msg .format (type (value ), str (value ), supportedList ))
118
+
119
+ return value
115
120
116
121
return converter
117
122
118
123
@staticmethod
119
124
def toKerasLoss (value ):
120
125
""" Convert a value to a name of Keras loss function, if possible """
121
- if kmutil .is_valid_loss_function (value ):
122
- return value
123
- err_msg = "Named loss not supported in Keras: [type {}] {}"
124
- raise ValueError (err_msg .format (type (value ), value ))
126
+ # return early in for clarify as well as less indentation
127
+ if not kmutil .is_valid_loss_function (value ):
128
+ err_msg = "Named loss not supported in Keras: [type {}] {}"
129
+ raise ValueError (err_msg .format (type (value ), value ))
130
+
131
+ return value
125
132
126
133
@staticmethod
127
134
def toKerasOptimizer (value ):
128
135
""" Convert a value to a name of Keras optimizer, if possible """
129
- if kmutil .is_valid_optimizer (value ):
130
- return value
131
- err_msg = "Named optimizer not supported in Keras: [type {}] {}"
132
- raise TypeError (err_msg .format (type (value ), value ))
136
+ if not kmutil .is_valid_optimizer (value ):
137
+ err_msg = "Named optimizer not supported in Keras: [type {}] {}"
138
+ raise TypeError (err_msg .format (type (value ), value ))
139
+
140
+ return value
133
141
134
142
135
- def _get_strict_tensor_name (_maybe_tnsr_name ):
143
+ def _check_is_tensor_name (_maybe_tnsr_name ):
136
144
""" Check if the input is a valid tensor name """
137
145
try :
138
146
assert isinstance (_maybe_tnsr_name , six .string_types ), \
139
147
"must provide a strict tensor name as input, but got {}" .format (type (_maybe_tnsr_name ))
140
- assert tfx .as_tensor_name (_maybe_tnsr_name ) == _maybe_tnsr_name , \
141
- "input {} must be a valid tensor name" .format (_maybe_tnsr_name )
148
+
149
+ # The check is taken from TensorFlow's NodeDef protocol buffer.
150
+ # https://github.com/tensorflow/tensorflow/blob/r1.3/tensorflow/core/framework/node_def.proto#L21-L25
151
+ _ , src_idx = _maybe_tnsr_name .split (":" )
152
+ _ = int (src_idx )
142
153
except Exception as exc :
143
154
err_msg = "Can NOT convert [type {}] {} to tf.Tensor name: {}"
144
155
raise TypeError (err_msg .format (type (_maybe_tnsr_name ), _maybe_tnsr_name , exc ))
145
156
else :
146
157
return _maybe_tnsr_name
147
158
148
159
149
- def _get_strict_col_name (_maybe_col_name ):
160
+ def _check_is_str (_maybe_col_name ):
150
161
""" Check if the given colunm name is a valid column name """
151
162
# We only check if the column name candidate is a string type
152
163
if not isinstance (_maybe_col_name , six .string_types ):
0 commit comments