24
24
25
25
__all__ = ['SparkDLTypeConverters' ]
26
26
27
- def _get_strict_tensor_name (_maybe_tnsr_name ):
28
- assert isinstance (_maybe_tnsr_name , six .string_types ), \
29
- "must provide a strict tensor name as input, but got {}" .format (type (_maybe_tnsr_name ))
30
- assert tfx .as_tensor_name (_maybe_tnsr_name ) == _maybe_tnsr_name , \
31
- "input {} must be a valid tensor name" .format (_maybe_tnsr_name )
32
- return _maybe_tnsr_name
33
-
34
- def _try_convert_tf_tensor_mapping (value , is_key_tf_tensor = True ):
35
- if isinstance (value , dict ):
36
- strs_pair_seq = []
37
- for k , v in value .items ():
38
- # Check if the non-tensor value is of string type
39
- _non_tnsr_str_val = v if is_key_tf_tensor else k
40
- if not isinstance (_non_tnsr_str_val , six .string_types ):
41
- err_msg = 'expect string type for {}, but got {}'
42
- raise TypeError (err_msg .format (_non_tnsr_str_val , type (_non_tnsr_str_val )))
43
-
44
- # Check if the tensor name is actually valid
45
- try :
46
- if is_key_tf_tensor :
47
- _pair = (_get_strict_tensor_name (k ), v )
48
- else :
49
- _pair = (k , _get_strict_tensor_name (v ))
50
- except Exception as exc :
51
- err_msg = "Can NOT convert {} (type {}) to tf.Tensor name: {}"
52
- _not_tf_op = k if is_key_tf_tensor else v
53
- raise TypeError (err_msg .format (_not_tf_op , type (_not_tf_op ), exc ))
54
-
55
- strs_pair_seq .append (_pair )
56
-
57
- return sorted (strs_pair_seq )
58
-
59
- if is_key_tf_tensor :
60
- raise TypeError ("Could not convert %s to tf.Tensor name to str mapping" % type (value ))
61
- else :
62
- raise TypeError ("Could not convert %s to str to tf.Tensor name mapping" % type (value ))
63
-
64
27
65
28
class SparkDLTypeConverters (object ):
29
+ """
30
+ .. note:: DeveloperApi
31
+
32
+ Factory methods for common type conversion functions for :py:func:`Param.typeConverter`.
33
+ These methods are similar to :py:class:`spark.ml.param.TypeConverters`.
34
+ They provide support for the `Params` types introduced in Spark Deep Learning Pipelines.
35
+ """
66
36
@staticmethod
67
37
def toTFGraph (value ):
68
38
if isinstance (value , tf .Graph ):
@@ -72,49 +42,114 @@ def toTFGraph(value):
72
42
73
43
@staticmethod
74
44
def asColumnToTensorNameMap (value ):
75
- return _try_convert_tf_tensor_mapping (value , is_key_tf_tensor = False )
45
+ """
46
+ Convert a value to a column name to :py:obj:`tf.Tensor` name mapping
47
+ as a sorted list of string pairs, if possible.
48
+ """
49
+ if isinstance (value , dict ):
50
+ strs_pair_seq = []
51
+ for _maybe_col_name , _maybe_tnsr_name in value .items ():
52
+ # Check if the non-tensor value is of string type
53
+ _col_name = _get_strict_col_name (_maybe_col_name )
54
+ # Check if the tensor name is actually valid
55
+ _tnsr_name = _get_strict_tensor_name (_maybe_tnsr_name )
56
+ strs_pair_seq .append ((_col_name , _tnsr_name ))
57
+
58
+ return sorted (strs_pair_seq )
59
+
60
+ err_msg = "Could not convert [type {}] {} to column name to tf.Tensor name mapping"
61
+ raise TypeError (err_msg .format (type (value ), value ))
76
62
77
63
@staticmethod
78
64
def asTensorNameToColumnMap (value ):
79
- return _try_convert_tf_tensor_mapping (value , is_key_tf_tensor = True )
65
+ """
66
+ Convert a value to a :py:obj:`tf.Tensor` name to column name mapping
67
+ as a sorted list of string pairs, if possible.
68
+ """
69
+ if isinstance (value , dict ):
70
+ strs_pair_seq = []
71
+ for _maybe_tnsr_name , _maybe_col_name in value .items ():
72
+ # Check if the non-tensor value is of string type
73
+ _col_name = _get_strict_col_name (_maybe_col_name )
74
+ # Check if the tensor name is actually valid
75
+ _tnsr_name = _get_strict_tensor_name (_maybe_tnsr_name )
76
+ strs_pair_seq .append ((_tnsr_name , _col_name ))
77
+
78
+ return sorted (strs_pair_seq )
79
+
80
+ err_msg = "Could not convert [type {}] {} to tf.Tensor name to column name mapping"
81
+ raise TypeError (err_msg .format (type (value ), value ))
80
82
81
83
@staticmethod
82
84
def toTFHParams (value ):
85
+ """ Convert a value to a :py:obj:`tf.contrib.training.HParams` object, if possible. """
83
86
if isinstance (value , tf .contrib .training .HParams ):
84
87
return value
85
88
else :
86
89
raise TypeError ("Could not convert %s to TensorFlow HParams" % type (value ))
87
90
88
91
@staticmethod
89
92
def toStringOrTFTensor (value ):
93
+ """ Convert a value to a str or a :py:obj:`tf.Tensor` object, if possible. """
90
94
if isinstance (value , tf .Tensor ):
91
95
return value
92
- else :
93
- try :
94
- return TypeConverters . toString ( value )
95
- except TypeError :
96
- raise TypeError ("Could not convert %s to tensorflow.Tensor or str" % type (value ))
96
+ try :
97
+ return TypeConverters . toString ( value )
98
+ except Exception as exc :
99
+ err_msg = "Could not convert [type {}] {} to tf.Tensor or str. {}"
100
+ raise TypeError (err_msg . format ( type (value ), value , exc ))
97
101
98
102
@staticmethod
99
103
def supportedNameConverter (supportedList ):
104
+ """
105
+ Create a converter that try to check if a value is part of the supported list.
106
+
107
+ :param supportedList: list, containing supported objects.
108
+ :return: a converter that try to convert a value if it is part of the `supportedList`.
109
+ """
100
110
def converter (value ):
101
111
if value in supportedList :
102
112
return value
103
- else :
104
- raise TypeError ("%s %s is not in the supported list." % type (value ), str (value ))
113
+ err_msg = "[type {}] {} is not in the supported list: {}"
114
+ raise TypeError (err_msg . format ( type (value ), str (value ), supportedList ))
105
115
106
116
return converter
107
117
108
118
@staticmethod
109
119
def toKerasLoss (value ):
120
+ """ Convert a value to a name of Keras loss function, if possible """
110
121
if kmutil .is_valid_loss_function (value ):
111
122
return value
112
- raise ValueError (
113
- "Named loss not supported in Keras: {} type({})" .format (value , type (value )))
123
+ err_msg = "Named loss not supported in Keras: [type {}] {}"
124
+ raise ValueError ( err_msg .format (type (value ), value ))
114
125
115
126
@staticmethod
116
127
def toKerasOptimizer (value ):
128
+ """ Convert a value to a name of Keras optimizer, if possible """
117
129
if kmutil .is_valid_optimizer (value ):
118
130
return value
119
- raise TypeError (
120
- "Named optimizer not supported in Keras: {} type({})" .format (value , type (value )))
131
+ err_msg = "Named optimizer not supported in Keras: [type {}] {}"
132
+ raise TypeError (err_msg .format (type (value ), value ))
133
+
134
+
135
+ def _get_strict_tensor_name (_maybe_tnsr_name ):
136
+ """ Check if the input is a valid tensor name """
137
+ try :
138
+ assert isinstance (_maybe_tnsr_name , six .string_types ), \
139
+ "must provide a strict tensor name as input, but got {}" .format (type (_maybe_tnsr_name ))
140
+ assert tfx .as_tensor_name (_maybe_tnsr_name ) == _maybe_tnsr_name , \
141
+ "input {} must be a valid tensor name" .format (_maybe_tnsr_name )
142
+ except Exception as exc :
143
+ err_msg = "Can NOT convert [type {}] {} to tf.Tensor name: {}"
144
+ raise TypeError (err_msg .format (type (_maybe_tnsr_name ), _maybe_tnsr_name , exc ))
145
+ else :
146
+ return _maybe_tnsr_name
147
+
148
+
149
+ def _get_strict_col_name (_maybe_col_name ):
150
+ """ Check if the given colunm name is a valid column name """
151
+ # We only check if the column name candidate is a string type
152
+ if not isinstance (_maybe_col_name , six .string_types ):
153
+ err_msg = 'expect string type but got type {} for {}'
154
+ raise TypeError (err_msg .format (type (_maybe_col_name ), _maybe_col_name ))
155
+ return _maybe_col_name
0 commit comments