Skip to content

Commit 66d44e9

Browse files
committed
(WIP) respond to PR comments
1 parent 692b0eb commit 66d44e9

File tree

4 files changed

+70
-63
lines changed

4 files changed

+70
-63
lines changed

python/sparkdl/graph/input.py

Lines changed: 16 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -29,43 +29,20 @@ class TFInputGraph(object):
2929
"""
3030

3131
# TODO: for (de-)serialization, the class should correspond to a ProtocolBuffer definition.
32-
def __init__(self, graph_def, input_mapping, output_mapping):
32+
def __init__(self, graph_def):
3333
# tf.GraphDef
3434
self.graph_def = graph_def
3535

36-
if isinstance(input_mapping, dict):
37-
input_mapping = list(input_mapping.items())
38-
assert isinstance(input_mapping, list), \
39-
"output mapping must be a list of strings, found type {}".format(type(input_mapping))
40-
self.input_mapping = sorted(input_mapping)
41-
42-
if isinstance(output_mapping, dict):
43-
output_mapping = list(output_mapping.items())
44-
assert isinstance(output_mapping, list), \
45-
"output mapping must be a list of strings, found type {}".format(type(output_mapping))
46-
self.output_mapping = sorted(output_mapping)
47-
48-
49-
def _get_params_from(gin_builder, input_mapping, output_mapping):
50-
gin = gin_builder.build(input_mapping, output_mapping)
51-
imap = dict(gin.input_mapping)
52-
assert len(imap) == len(gin.input_mapping)
53-
omap = dict(gin.output_mapping)
54-
assert len(omap) == len(gin.output_mapping)
55-
return gin.graph_def, imap, omap
56-
57-
5836
def get_params_from_checkpoint(checkpoint_dir, signature_def_key, input_mapping, output_mapping):
5937
assert signature_def_key is not None
6038
gin_builder = TFInputGraphBuilder.fromCheckpoint(checkpoint_dir, signature_def_key)
61-
return _get_params_from(gin_builder, input_mapping, output_mapping)
62-
39+
return gin_builder.build(input_mapping, output_mapping)
6340

6441
def get_params_from_saved_model(saved_model_dir, tag_set, signature_def_key, input_mapping,
6542
output_mapping):
6643
assert signature_def_key is not None
6744
gin_builder = TFInputGraphBuilder.fromSavedModel(saved_model_dir, tag_set, signature_def_key)
68-
return _get_params_from(gin_builder, input_mapping, output_mapping)
45+
return gin_builder.build(input_mapping, output_mapping)
6946

7047

7148
class TFInputGraphBuilder(object):
@@ -117,12 +94,19 @@ def build(self, input_mapping, output_mapping):
11794
tnsr = tnsr_or_sig
11895
fetches.append(tfx.get_tensor(graph, tnsr))
11996
tf_output_colname = tfx.op_name(graph, tnsr)
97+
# NOTE(phi-dbq): put the check here as it will be the entry point to construct
98+
# a `TFInputGraph` object.
99+
assert tf_output_colname not in _output_mapping, \
100+
"operation {} has multiple output tensors and ".format(tf_output_colname) + \
101+
"at least two of them are used in the output DataFrame. " + \
102+
"Operation names are used to name columns which leads to conflicts. " + \
103+
"You can apply `tf.identity` ops to each to avoid name conflicts."
120104
_output_mapping[tf_output_colname] = requested_colname
121105
output_mapping = _output_mapping
122106

123107
gdef = tfx.strip_and_freeze_until(fetches, graph, sess)
124108

125-
return TFInputGraph(gdef, input_mapping, output_mapping)
109+
return TFInputGraph(gdef), input_mapping, output_mapping
126110

127111
@classmethod
128112
def fromGraph(cls, graph):
@@ -167,6 +151,8 @@ def import_graph_fn(sess):
167151
ckpt_path = tf.train.latest_checkpoint(checkpoint_dir)
168152

169153
# NOTE(phi-dbq): we must manually load meta_graph_def to get the signature_def
154+
# the current `import_graph_def` function seems to ignore
155+
# any signature_def fields in a checkpoint's meta_graph_def.
170156
meta_graph_def = meta_graph_pb2.MetaGraphDef()
171157
with open("{}.meta".format(ckpt_path), 'rb') as fin:
172158
meta_graph_def.ParseFromString(fin.read())
@@ -177,7 +163,9 @@ def import_graph_fn(sess):
177163
sig_def = None
178164
if signature_def_key is not None:
179165
sig_def = meta_graph_def.signature_def[signature_def_key]
180-
# TODO: check if sig_def is valid
166+
assert sig_def, 'singnature_def_key {} provided, '.format(signature_def_key) + \
167+
'but failed to find it from the meta_graph_def ' + \
168+
'from checkpoint {}'.format(checkpoint_dir)
181169

182170
return sig_def
183171

python/sparkdl/transformers/param.py

Lines changed: 24 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,10 @@
2929
from sparkdl.graph.builder import GraphFunction, IsolatedSession
3030
import sparkdl.graph.utils as tfx
3131
from sparkdl.graph.input import TFInputGraph, TFInputGraphBuilder
32-
"""
33-
Copied from PySpark for backward compatibility. First in Apache Spark version 2.1.1.
34-
"""
3532

33+
########################################################
34+
# Copied from PySpark for backward compatibility. First in Apache Spark version 2.1.1.
35+
########################################################
3636

3737
def keyword_only(func):
3838
"""
@@ -98,34 +98,22 @@ def getOutputCol(self):
9898
return self.getOrDefault(self.outputCol)
9999

100100

101-
"""
102-
TensorFlow Specific Parameters
103-
"""
104-
101+
########################################################
102+
# New in sparkdl: TensorFlow Specific Parameters
103+
########################################################
105104

106105
class SparkDLTypeConverters(object):
107106
@staticmethod
108107
def toTFGraph(value):
109108
if isinstance(value, tf.Graph):
110109
return value
111-
elif isinstance(value, GraphFunction):
112-
with IsolatedSession() as issn:
113-
issn.importGraphFunction(value, prefix='')
114-
g = issn.graph
115-
return g
116110
else:
117111
raise TypeError("Could not convert %s to TensorFlow Graph" % type(value))
118112

119113
@staticmethod
120114
def toTFInputGraph(value):
121115
if isinstance(value, TFInputGraph):
122116
return value
123-
elif isinstance(value, TFInputGraphBuilder):
124-
return value
125-
elif isinstance(value, tf.Graph):
126-
return TFInputGraphBuilder.fromGraph(value)
127-
elif isinstance(value, tf.GraphDef):
128-
return TFInputGraphBuilder.fromGraphDef(value)
129117
else:
130118
raise TypeError("Could not convert %s to TFInputGraph" % type(value))
131119

@@ -171,9 +159,6 @@ def converter(value):
171159
return converter
172160

173161

174-
# New in sparkdl
175-
176-
177162
class HasOutputMapping(Params):
178163
"""
179164
Mixin for param outputMapping: ordered list of ('outputTensorOpName', 'outputColName') pairs
@@ -185,7 +170,11 @@ class HasOutputMapping(Params):
185170
typeConverter=SparkDLTypeConverters.asTensorToColumnMap)
186171

187172
def setOutputMapping(self, value):
188-
return self._set(outputMapping=value)
173+
# NOTE(phi-dbq): due to the nature of TensorFlow import modes, we can only derive the
174+
# serializable TFInputGraph object once the inputMapping and outputMapping
175+
# parameters are provided.
176+
raise NotImplementedError(
177+
"Please use the Transformer's constructor to assign `outputMapping` field.")
189178

190179
def getOutputMapping(self):
191180
return self.getOrDefault(self.outputMapping)
@@ -202,28 +191,36 @@ class HasInputMapping(Params):
202191
typeConverter=SparkDLTypeConverters.asColumnToTensorMap)
203192

204193
def setInputMapping(self, value):
205-
return self._set(inputMapping=value)
194+
# NOTE(phi-dbq): due to the nature of TensorFlow import modes, we can only derive the
195+
# serializable TFInputGraph object once the inputMapping and outputMapping
196+
# parameters are provided.
197+
raise NotImplementedError(
198+
"Please use the Transformer's constructor to assigne `inputMapping` field.")
206199

207200
def getInputMapping(self):
208201
return self.getOrDefault(self.inputMapping)
209202

210203

211204
class HasTFInputGraph(Params):
212205
"""
213-
Mixin for param tfGraph: the :class:`tf.Graph` object that represents a TensorFlow computation.
206+
Mixin for param tfInputGraph: a serializable object derived from a TensorFlow computation graph.
214207
"""
215208
tfInputGraph = Param(
216209
Params._dummy(),
217210
"tfInputGraph",
218-
"TensorFlow Graph object",
211+
"A serializable object derived from a TensorFlow computation graph",
219212
typeConverter=SparkDLTypeConverters.toTFInputGraph)
220213

221214
def __init__(self):
222215
super(HasTFInputGraph, self).__init__()
223216
self._setDefault(tfInputGraph=None)
224217

225218
def setTFInputGraph(self, value):
226-
return self._set(tfInputGraph=value)
219+
# NOTE(phi-dbq): due to the nature of TensorFlow import modes, we can only derive the
220+
# serializable TFInputGraph object once the inputMapping and outputMapping
221+
# parameters are provided.
222+
raise NotImplementedError(
223+
"Please use the Transformer's constructor to assign `tfInputGraph` field.")
227224

228225
def getTFInputGraph(self):
229226
return self.getOrDefault(self.tfInputGraph)
@@ -236,7 +233,7 @@ class HasTFHParams(Params):
236233
tfHParams = Param(
237234
Params._dummy(),
238235
"hparams",
239-
"instance of :class:`tf.contrib.training.HParams`",
236+
"instance of :class:`tf.contrib.training.HParams`, a key-value map-like object",
240237
typeConverter=SparkDLTypeConverters.toTFHParams)
241238

242239
def setTFHParams(self, value):

python/sparkdl/transformers/tf_tensor.py

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from pyspark.ml import Transformer
2222

2323
import sparkdl.graph.utils as tfx
24-
from sparkdl.graph.input import TFInputGraphBuilder
24+
from sparkdl.graph.input import TFInputGraph, TFInputGraphBuilder
2525
from sparkdl.transformers.param import (keyword_only, SparkDLTypeConverters, HasInputMapping,
2626
HasOutputMapping, HasTFInputGraph, HasTFHParams)
2727

@@ -57,15 +57,38 @@ def setParams(self, tfInputGraph=None, inputMapping=None, outputMapping=None, tf
5757
"""
5858
super(TFTransformer, self).__init__()
5959
kwargs = self._input_kwargs
60-
_maybe_gin = SparkDLTypeConverters.toTFInputGraph(tfInputGraph)
60+
# The set of parameters may come from one of the helper functions,
61+
# in which case type(_maybe_gin) is already TFInputGraph.
62+
_maybe_gin = tfInputGraph
63+
if isinstance(_maybe_gin, TFInputGraph):
64+
return self._set(**kwargs)
65+
66+
# Otherwise, `_maybe_gin` needs to be converted to TFInputGraph
67+
# We put all the conversion logic here rather than in SparkDLTypeConverters
6168
if isinstance(_maybe_gin, TFInputGraphBuilder):
62-
kwargs['tfInputGraph'] = _maybe_gin.build(inputMapping, outputMapping)
69+
gin = _maybe_gin
70+
elif isinstance(_maybe_gin, tf.Graph):
71+
gin = TFInputGraphBuilder.fromGraph(_maybe_gin)
72+
elif isinstance(_maybe_gin, tf.GraphDef):
73+
gin = TFInputGraphBuilder.fromGraphDef(_maybe_gin)
74+
else:
75+
raise TypeError("TFTransformer expect tfInputGraph convertible to TFInputGraph, " + \
76+
"but the given type {} cannot be converted, ".format(type(tfInputGraph)) + \
77+
"please provide `tf.Graph`, `tf.GraphDef` or use one of the " + \
78+
"`get_params_from_*` helper functions to build parameters")
79+
80+
gin, input_mapping, output_mapping = gin.build(inputMapping, outputMapping)
81+
kwargs['tfInputGraph'] = gin
82+
kwargs['inputMapping'] = input_mapping
83+
kwargs['outputMapping'] = output_mapping
84+
85+
# Further canonicalization, e.g. converting dict to sorted str pairs, happens here
6386
return self._set(**kwargs)
6487

6588
def _transform(self, dataset):
6689
gin = self.getTFInputGraph()
67-
input_mapping = gin.input_mapping
68-
output_mapping = gin.output_mapping
90+
input_mapping = self.getInputMapping()
91+
output_mapping = self.getOutputMapping()
6992

7093
graph = tf.Graph()
7194
with tf.Session(graph=graph):

python/tests/transformers/tf_tensor_test.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,6 @@ def test_build_from_tf_graph(self):
5454

5555
# Build the TensorFlow graph
5656
with tf.Session() as sess:
57-
#x = tf.placeholder(tf.float64, shape=[None, vec_size])
5857
x = tfs.block(analyzed_df, 'vec')
5958
z = tf.reduce_mean(x, axis=1)
6059
graph = sess.graph
@@ -68,9 +67,9 @@ def test_build_from_tf_graph(self):
6867

6968
# Apply the transform
7069
gin_from_graph = TFInputGraphBuilder.fromGraph(graph)
71-
for gin in [gin_from_graph, graph]:
70+
for gin_or_graph in [gin_from_graph, graph]:
7271
transfomer = TFTransformer(
73-
tfInputGraph=TFInputGraphBuilder.fromGraph(graph),
72+
tfInputGraph=gin_or_graph,
7473
inputMapping={
7574
'vec': x
7675
},

0 commit comments

Comments
 (0)