Skip to content

Commit 6e46073

Browse files
committed
Merge branch 'tf-transformer-part1' into api-tf-transformer
2 parents a3517d6 + 323939a commit 6e46073

File tree

4 files changed

+137
-37
lines changed

4 files changed

+137
-37
lines changed

python/sparkdl/param/converters.py

Lines changed: 37 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
# limitations under the License.
1414
#
1515

16+
import six
17+
1618
import tensorflow as tf
1719

1820
from pyspark.ml.param import TypeConverters
@@ -21,6 +23,37 @@
2123
from sparkdl.graph.input import TFInputGraph
2224
import sparkdl.utils.keras_model as kmutil
2325

26+
__all__ = ['SparkDLTypeConverters']
27+
28+
def _try_convert_tf_tensor_mapping(value, is_key_tf_tensor=True):
29+
if isinstance(value, dict):
30+
strs_pair_seq = []
31+
for k, v in value.items():
32+
try:
33+
if is_key_tf_tensor:
34+
_pair = (tfx.as_tensor_name(k), v)
35+
else:
36+
_pair = (k, tfx.as_tensor_name(v))
37+
except:
38+
err_msg = "Can NOT convert {} (type {}) to tf.Tensor name"
39+
_not_tf_op = k if is_key_tf_tensor else v
40+
raise TypeError(err_msg.format(_not_tf_op, type(_not_tf_op)))
41+
42+
str_val = v if is_key_tf_tensor else k
43+
if not isinstance(str_val, six.string_types):
44+
err_msg = 'expect string type for {}, but got {}'
45+
raise TypeError(err_msg.format(str_val, type(str_val)))
46+
47+
strs_pair_seq.append(_pair)
48+
49+
return sorted(strs_pair_seq)
50+
51+
if is_key_tf_tensor:
52+
raise TypeError("Could not convert %s to tf.Tensor name to str mapping" % type(value))
53+
else:
54+
raise TypeError("Could not convert %s to str to tf.Tensor name mapping" % type(value))
55+
56+
2457
class SparkDLTypeConverters(object):
2558
@staticmethod
2659
def toTFGraph(value):
@@ -37,18 +70,12 @@ def toTFInputGraph(value):
3770
raise TypeError("Could not convert %s to TFInputGraph" % type(value))
3871

3972
@staticmethod
40-
def asColumnToTensorMap(value):
41-
if isinstance(value, dict):
42-
strs_pair_seq = [(k, tfx.as_op_name(v)) for k, v in value.items()]
43-
return sorted(strs_pair_seq)
44-
raise TypeError("Could not convert %s to TensorFlow Tensor" % type(value))
73+
def asColumnToTensorNameMap(value):
74+
return _try_convert_tf_tensor_mapping(value, is_key_tf_tensor=False)
4575

4676
@staticmethod
47-
def asTensorToColumnMap(value):
48-
if isinstance(value, dict):
49-
strs_pair_seq = [(tfx.as_op_name(k), v) for k, v in value.items()]
50-
return sorted(strs_pair_seq)
51-
raise TypeError("Could not convert %s to TensorFlow Tensor" % type(value))
77+
def asTensorNameToColumnMap(value):
78+
return _try_convert_tf_tensor_mapping(value, is_key_tf_tensor=True)
5279

5380
@staticmethod
5481
def toTFHParams(value):

python/sparkdl/param/shared_params.py

Lines changed: 16 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -22,17 +22,10 @@
2222
from functools import wraps
2323
import six
2424

25-
import keras
26-
import tensorflow as tf
27-
2825
from pyspark.ml.param import Param, Params, TypeConverters
2926

30-
from sparkdl.graph.builder import GraphFunction, IsolatedSession
31-
import sparkdl.graph.utils as tfx
3227
from sparkdl.graph.input import TFInputGraph
3328
from sparkdl.param.converters import SparkDLTypeConverters
34-
import sparkdl.utils.keras_model as kmutil
35-
3629

3730
########################################################
3831
# Copied from PySpark for backward compatibility.
@@ -204,11 +197,10 @@ class HasOutputMapping(Params):
204197
"""
205198
Mixin for param outputMapping: ordered list of ('outputTensorOpName', 'outputColName') pairs
206199
"""
207-
outputMapping = Param(
208-
Params._dummy(),
209-
"outputMapping",
210-
"Mapping output :class:`tf.Operation` names to DataFrame column names",
211-
typeConverter=SparkDLTypeConverters.asTensorToColumnMap)
200+
outputMapping = Param(Params._dummy(),
201+
"outputMapping",
202+
"Mapping output :class:`tf.Tensor` names to DataFrame column names",
203+
typeConverter=SparkDLTypeConverters.asTensorNameToColumnMap)
212204

213205
def setOutputMapping(self, value):
214206
# NOTE(phi-dbq): due to the nature of TensorFlow import modes, we can only derive the
@@ -225,11 +217,10 @@ class HasInputMapping(Params):
225217
"""
226218
Mixin for param inputMapping: ordered list of ('inputColName', 'inputTensorOpName') pairs
227219
"""
228-
inputMapping = Param(
229-
Params._dummy(),
230-
"inputMapping",
231-
"Mapping input DataFrame column names to :class:`tf.Operation` names",
232-
typeConverter=SparkDLTypeConverters.asColumnToTensorMap)
220+
inputMapping = Param(Params._dummy(),
221+
"inputMapping",
222+
"Mapping input DataFrame column names to :class:`tf.Tensor` names",
223+
typeConverter=SparkDLTypeConverters.asColumnToTensorNameMap)
233224

234225
def setInputMapping(self, value):
235226
# NOTE(phi-dbq): due to the nature of TensorFlow import modes, we can only derive the
@@ -246,11 +237,10 @@ class HasTFInputGraph(Params):
246237
"""
247238
Mixin for param tfInputGraph: a serializable object derived from a TensorFlow computation graph.
248239
"""
249-
tfInputGraph = Param(
250-
Params._dummy(),
251-
"tfInputGraph",
252-
"A serializable object derived from a TensorFlow computation graph",
253-
typeConverter=SparkDLTypeConverters.toTFInputGraph)
240+
tfInputGraph = Param(Params._dummy(),
241+
"tfInputGraph",
242+
"A serializable object derived from a TensorFlow computation graph",
243+
typeConverter=SparkDLTypeConverters.toTFInputGraph)
254244

255245
def __init__(self):
256246
super(HasTFInputGraph, self).__init__()
@@ -271,11 +261,10 @@ class HasTFHParams(Params):
271261
"""
272262
Mixin for TensorFlow model hyper-parameters
273263
"""
274-
tfHParams = Param(
275-
Params._dummy(),
276-
"hparams",
277-
"instance of :class:`tf.contrib.training.HParams`, a key-value map-like object",
278-
typeConverter=SparkDLTypeConverters.toTFHParams)
264+
tfHParams = Param(Params._dummy(),
265+
"hparams",
266+
"instance of :class:`tf.contrib.training.HParams`, a key-value map-like object",
267+
typeConverter=SparkDLTypeConverters.toTFHParams)
279268

280269
def setTFHParams(self, value):
281270
return self._set(tfHParam=value)

python/tests/param/__init__.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
#
2+
# Copyright 2017 Databricks, Inc.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
#

python/tests/param/params_test.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
# Copyright 2017 Databricks, Inc.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
#
15+
import sys
16+
17+
if sys.version_info[:2] <= (2, 6):
18+
try:
19+
import unittest2 as unittest
20+
except ImportError:
21+
sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier')
22+
sys.exit(1)
23+
else:
24+
import unittest
25+
26+
from sparkdl.param.converters import SparkDLTypeConverters as conv
27+
28+
class ParamsConverterTest(unittest.TestCase):
    """Unit tests for SparkDLTypeConverters' tensor-name mapping converters."""
    # pylint: disable=protected-access

    def test_tf_input_mapping_converter(self):
        """Column -> tensor-name mapping accepts both tensor names and op names."""
        valid_tnsr_input = {'colA': 'tnsrOpA:0',
                            'colB': 'tnsrOpB:0'}
        valid_op_input = {'colA': 'tnsrOpA',
                         'colB': 'tnsrOpB'}
        valid_input_mapping_result = [('colA', 'tnsrOpA:0'),
                                      ('colB', 'tnsrOpB:0')]

        for valid_input_mapping in [valid_op_input, valid_tnsr_input]:
            res = conv.asColumnToTensorNameMap(valid_input_mapping)
            self.assertEqual(valid_input_mapping_result, res)

    def test_tf_output_mapping_converter(self):
        """Tensor-name -> column mapping accepts both tensor names and op names."""
        valid_tnsr_output = {'tnsrOpA:0': 'colA',
                             'tnsrOpB:0': 'colB'}
        valid_op_output = {'tnsrOpA': 'colA',
                           'tnsrOpB': 'colB'}
        valid_output_mapping_result = [('tnsrOpA:0', 'colA'),
                                       ('tnsrOpB:0', 'colB')]

        for valid_output_mapping in [valid_tnsr_output, valid_op_output]:
            res = conv.asTensorNameToColumnMap(valid_output_mapping)
            self.assertEqual(valid_output_mapping_result, res)

    def test_invalid_input_mapping(self):
        """Every invalid value must be rejected by BOTH converters.

        BUG FIX: the original grouped several raising calls under a single
        ``assertRaises`` context; only the first statement inside the block ever
        executed (the rest were unreachable after the exception), so most cases
        were silently untested. Each call now gets its own ``assertRaises``.
        """
        # Wrong container type: only dict is accepted.
        for invalid in [['a1', 'b2'], ('c3', 'd4'), [('a', 1), ('b', 2)],
                        [('colA', 'tnsrA:0'), ('colB', 'tnsrB:0')],
                        [('tnsrA:0', 'colA'), ('tnsrB:0', 'colB')]]:
            with self.assertRaises(TypeError):
                conv.asColumnToTensorNameMap(invalid)
            with self.assertRaises(TypeError):
                conv.asTensorNameToColumnMap(invalid)

        # Wrong element type: the non-tensor side must be a string.
        with self.assertRaises(TypeError):
            conv.asTensorNameToColumnMap({1: 'a', 2.0: 'b'})
        with self.assertRaises(TypeError):
            conv.asColumnToTensorNameMap({'a': 1, 'b': 2.0})

0 commit comments

Comments
 (0)