Skip to content

Commit dd168ed

Browse files
committed
converter changes
1 parent fcabcb6 commit dd168ed

File tree

1 file changed

+85
-50
lines changed

1 file changed

+85
-50
lines changed

python/sparkdl/param/converters.py

Lines changed: 85 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -24,45 +24,15 @@
2424

2525
__all__ = ['SparkDLTypeConverters']
2626

27-
def _get_strict_tensor_name(_maybe_tnsr_name):
28-
assert isinstance(_maybe_tnsr_name, six.string_types), \
29-
"must provide a strict tensor name as input, but got {}".format(type(_maybe_tnsr_name))
30-
assert tfx.as_tensor_name(_maybe_tnsr_name) == _maybe_tnsr_name, \
31-
"input {} must be a valid tensor name".format(_maybe_tnsr_name)
32-
return _maybe_tnsr_name
33-
34-
def _try_convert_tf_tensor_mapping(value, is_key_tf_tensor=True):
35-
if isinstance(value, dict):
36-
strs_pair_seq = []
37-
for k, v in value.items():
38-
# Check if the non-tensor value is of string type
39-
_non_tnsr_str_val = v if is_key_tf_tensor else k
40-
if not isinstance(_non_tnsr_str_val, six.string_types):
41-
err_msg = 'expect string type for {}, but got {}'
42-
raise TypeError(err_msg.format(_non_tnsr_str_val, type(_non_tnsr_str_val)))
43-
44-
# Check if the tensor name is actually valid
45-
try:
46-
if is_key_tf_tensor:
47-
_pair = (_get_strict_tensor_name(k), v)
48-
else:
49-
_pair = (k, _get_strict_tensor_name(v))
50-
except Exception as exc:
51-
err_msg = "Can NOT convert {} (type {}) to tf.Tensor name: {}"
52-
_not_tf_op = k if is_key_tf_tensor else v
53-
raise TypeError(err_msg.format(_not_tf_op, type(_not_tf_op), exc))
54-
55-
strs_pair_seq.append(_pair)
56-
57-
return sorted(strs_pair_seq)
58-
59-
if is_key_tf_tensor:
60-
raise TypeError("Could not convert %s to tf.Tensor name to str mapping" % type(value))
61-
else:
62-
raise TypeError("Could not convert %s to str to tf.Tensor name mapping" % type(value))
63-
6427

6528
class SparkDLTypeConverters(object):
29+
"""
30+
.. note:: DeveloperApi
31+
32+
Factory methods for common type conversion functions for :py:func:`Param.typeConverter`.
33+
These methods are similar to :py:class:`spark.ml.param.TypeConverters`.
34+
They provide support for the `Params` types introduced in Spark Deep Learning Pipelines.
35+
"""
6636
@staticmethod
6737
def toTFGraph(value):
6838
if isinstance(value, tf.Graph):
@@ -72,49 +42,114 @@ def toTFGraph(value):
7242

7343
@staticmethod
7444
def asColumnToTensorNameMap(value):
75-
return _try_convert_tf_tensor_mapping(value, is_key_tf_tensor=False)
45+
"""
46+
Convert a value to a column name to :py:obj:`tf.Tensor` name mapping
47+
as a sorted list of string pairs, if possible.
48+
"""
49+
if isinstance(value, dict):
50+
strs_pair_seq = []
51+
for _maybe_col_name, _maybe_tnsr_name in value.items():
52+
# Check if the non-tensor value is of string type
53+
_col_name = _get_strict_col_name(_maybe_col_name)
54+
# Check if the tensor name is actually valid
55+
_tnsr_name = _get_strict_tensor_name(_maybe_tnsr_name)
56+
strs_pair_seq.append((_col_name, _tnsr_name))
57+
58+
return sorted(strs_pair_seq)
59+
60+
err_msg = "Could not convert [type {}] {} to column name to tf.Tensor name mapping"
61+
raise TypeError(err_msg.format(type(value), value))
7662

7763
@staticmethod
7864
def asTensorNameToColumnMap(value):
79-
return _try_convert_tf_tensor_mapping(value, is_key_tf_tensor=True)
65+
"""
66+
Convert a value to a :py:obj:`tf.Tensor` name to column name mapping
67+
as a sorted list of string pairs, if possible.
68+
"""
69+
if isinstance(value, dict):
70+
strs_pair_seq = []
71+
for _maybe_tnsr_name, _maybe_col_name in value.items():
72+
# Check if the non-tensor value is of string type
73+
_col_name = _get_strict_col_name(_maybe_col_name)
74+
# Check if the tensor name is actually valid
75+
_tnsr_name = _get_strict_tensor_name(_maybe_tnsr_name)
76+
strs_pair_seq.append((_tnsr_name, _col_name))
77+
78+
return sorted(strs_pair_seq)
79+
80+
err_msg = "Could not convert [type {}] {} to tf.Tensor name to column name mapping"
81+
raise TypeError(err_msg.format(type(value), value))
8082

8183
@staticmethod
8284
def toTFHParams(value):
85+
""" Convert a value to a :py:obj:`tf.contrib.training.HParams` object, if possible. """
8386
if isinstance(value, tf.contrib.training.HParams):
8487
return value
8588
else:
8689
raise TypeError("Could not convert %s to TensorFlow HParams" % type(value))
8790

8891
@staticmethod
8992
def toStringOrTFTensor(value):
93+
""" Convert a value to a str or a :py:obj:`tf.Tensor` object, if possible. """
9094
if isinstance(value, tf.Tensor):
9195
return value
92-
else:
93-
try:
94-
return TypeConverters.toString(value)
95-
except TypeError:
96-
raise TypeError("Could not convert %s to tensorflow.Tensor or str" % type(value))
96+
try:
97+
return TypeConverters.toString(value)
98+
except Exception as exc:
99+
err_msg = "Could not convert [type {}] {} to tf.Tensor or str. {}"
100+
raise TypeError(err_msg.format(type(value), value, exc))
97101

98102
@staticmethod
99103
def supportedNameConverter(supportedList):
104+
"""
105+
Create a converter that try to check if a value is part of the supported list.
106+
107+
:param supportedList: list, containing supported objects.
108+
:return: a converter that try to convert a value if it is part of the `supportedList`.
109+
"""
100110
def converter(value):
101111
if value in supportedList:
102112
return value
103-
else:
104-
raise TypeError("%s %s is not in the supported list." % type(value), str(value))
113+
err_msg = "[type {}] {} is not in the supported list: {}"
114+
raise TypeError(err_msg.format(type(value), str(value), supportedList))
105115

106116
return converter
107117

108118
@staticmethod
109119
def toKerasLoss(value):
120+
""" Convert a value to a name of Keras loss function, if possible """
110121
if kmutil.is_valid_loss_function(value):
111122
return value
112-
raise ValueError(
113-
"Named loss not supported in Keras: {} type({})".format(value, type(value)))
123+
err_msg = "Named loss not supported in Keras: [type {}] {}"
124+
raise ValueError(err_msg.format(type(value), value))
114125

115126
@staticmethod
116127
def toKerasOptimizer(value):
128+
""" Convert a value to a name of Keras optimizer, if possible """
117129
if kmutil.is_valid_optimizer(value):
118130
return value
119-
raise TypeError(
120-
"Named optimizer not supported in Keras: {} type({})".format(value, type(value)))
131+
err_msg = "Named optimizer not supported in Keras: [type {}] {}"
132+
raise TypeError(err_msg.format(type(value), value))
133+
134+
135+
def _get_strict_tensor_name(_maybe_tnsr_name):
136+
""" Check if the input is a valid tensor name """
137+
try:
138+
assert isinstance(_maybe_tnsr_name, six.string_types), \
139+
"must provide a strict tensor name as input, but got {}".format(type(_maybe_tnsr_name))
140+
assert tfx.as_tensor_name(_maybe_tnsr_name) == _maybe_tnsr_name, \
141+
"input {} must be a valid tensor name".format(_maybe_tnsr_name)
142+
except Exception as exc:
143+
err_msg = "Can NOT convert [type {}] {} to tf.Tensor name: {}"
144+
raise TypeError(err_msg.format(type(_maybe_tnsr_name), _maybe_tnsr_name, exc))
145+
else:
146+
return _maybe_tnsr_name
147+
148+
149+
def _get_strict_col_name(_maybe_col_name):
150+
""" Check if the given colunm name is a valid column name """
151+
# We only check if the column name candidate is a string type
152+
if not isinstance(_maybe_col_name, six.string_types):
153+
err_msg = 'expect string type but got type {} for {}'
154+
raise TypeError(err_msg.format(type(_maybe_col_name), _maybe_col_name))
155+
return _maybe_col_name

0 commit comments

Comments
 (0)