From cf856db1155ff4ecc327954e05995d2cb4e09229 Mon Sep 17 00:00:00 2001 From: Philip Yang Date: Wed, 22 Nov 2017 09:04:04 -0800 Subject: [PATCH] TensorFlow Transformer Part-2 (#9) * update utils * tests * fix style Using the following YAPF style ======================================================== based_on_style = pep8 ALIGN_CLOSING_BRACKET_WITH_VISUAL_INDENT=True BLANK_LINE_BEFORE_NESTED_CLASS_OR_DEF=False COLUMN_LIMIT=100 SPACE_BETWEEN_ENDING_COMMA_AND_CLOSING_BRACKET=False SPLIT_ARGUMENTS_WHEN_COMMA_TERMINATED=True SPLIT_BEFORE_FIRST_ARGUMENT=False SPLIT_BEFORE_NAMED_ASSIGNS=False SPLIT_PENALTY_AFTER_OPENING_BRACKET=30 USE_TABS=False ======================================================== * refactoring tfx API * test refactoring * PR comments 1. docs in graph/utils.py * (wip) utils test * a few more tests for utils * test update cont'd * PR comments * PR comments * PR comments * TensorFlow Transformer Part-3 (#10) * intro: TFInputGraph * tests * Merge branch 'tf-transformer-part1' into tf-transformer-part3 * and so there are no helper classes * and into more pieces * class & docs * update docs * refactoring tfx API * update tfx utils usage * one way to build these tests * tests refactored * test cases in a single class This will make things easier when we want to extend other base class functions. * shuffle things around Signed-off-by: Philip Yang * docs mostly * yapf'd * consolidate tempdir creation * (wip) PR comments * more tests * change test generator module name * TFTransformer Part-3 Test Refactor (#14) * profiling * tests * renamed test * removed original tests * removed the profiler utils * fixes indents * imports * added some tests * added test * fix test * one more test * PR comments * TensorFlow Transformer Part-4 (#11) * flat param API impl * support input graph scenarios * (WIP) new interface implementation * docs and cleanup * using tensorflow API instead of our utilities * automatic type conversion * cleanup * PR comments 1. 
Move `InputGraph` to its module. * (WIP) address comments * (WIP) respond to PR comments * test refactor * (wip) consolidating params * rebase upstream * import params fix * (wip) TFInputGraph impl * (wip) moving to new API * (wip) enable saved_model tests * (wip) enable checkpoint test * (wip) enable multiple tensor tests * enable all tests * optimize graph for inference * allows setting TFInputGraph * utilize test_input_graph for transformer tests * enable all tests Signed-off-by: Philip Yang * input graph * docs * tensor tests * tensor test update * TFTransformer Part-4 Test Refactor (#15) * adding new tests * remove original test design * cleanup * deleting original testing ideas * PR comments --- python/docs/sparkdl.rst | 2 + python/sparkdl/__init__.py | 8 +- python/sparkdl/graph/builder.py | 18 +- python/sparkdl/graph/input.py | 355 ++++++++++++++++++ python/sparkdl/graph/tensorframes_udf.py | 14 +- python/sparkdl/graph/utils.py | 165 ++++---- python/sparkdl/param/__init__.py | 2 +- python/sparkdl/param/converters.py | 17 +- python/sparkdl/param/shared_params.py | 32 +- python/sparkdl/transformers/keras_image.py | 8 +- python/sparkdl/transformers/tf_image.py | 6 +- python/sparkdl/transformers/tf_tensor.py | 105 ++++++ python/tests/graph/test_builder.py | 18 +- python/tests/graph/test_import.py | 322 ++++++++++++++++ python/tests/graph/test_pieces.py | 4 +- python/tests/graph/test_utils.py | 174 +++++++++ python/tests/tests.py | 17 +- .../tests/transformers/tf_transformer_test.py | 146 +++++++ python/tests/udf/keras_sql_udf_test.py | 3 +- 19 files changed, 1297 insertions(+), 119 deletions(-) create mode 100644 python/sparkdl/graph/input.py create mode 100644 python/sparkdl/transformers/tf_tensor.py create mode 100644 python/tests/graph/test_import.py create mode 100644 python/tests/graph/test_utils.py create mode 100644 python/tests/transformers/tf_transformer_test.py diff --git a/python/docs/sparkdl.rst b/python/docs/sparkdl.rst index c92e60cc..bf0c86f8 
100644 --- a/python/docs/sparkdl.rst +++ b/python/docs/sparkdl.rst @@ -6,8 +6,10 @@ Subpackages .. toctree:: + sparkdl.estimators sparkdl.graph sparkdl.image + sparkdl.param sparkdl.transformers sparkdl.udf sparkdl.utils diff --git a/python/sparkdl/__init__.py b/python/sparkdl/__init__.py index aa15059a..06b91bc8 100644 --- a/python/sparkdl/__init__.py +++ b/python/sparkdl/__init__.py @@ -13,15 +13,17 @@ # limitations under the License. # +from .graph.input import TFInputGraph from .image.imageIO import imageSchema, imageType, readImages from .transformers.keras_image import KerasImageFileTransformer from .transformers.named_image import DeepImagePredictor, DeepImageFeaturizer from .transformers.tf_image import TFImageTransformer +from .transformers.tf_tensor import TFTransformer from .transformers.utils import imageInputPlaceholder + __all__ = [ 'imageSchema', 'imageType', 'readImages', - 'TFImageTransformer', - 'DeepImagePredictor', 'DeepImageFeaturizer', - 'KerasImageFileTransformer', + 'TFImageTransformer', 'TFInputGraph', 'TFTransformer', + 'DeepImagePredictor', 'DeepImageFeaturizer', 'KerasImageFileTransformer', 'imageInputPlaceholder'] diff --git a/python/sparkdl/graph/builder.py b/python/sparkdl/graph/builder.py index 86c3b3ce..a7d7122f 100644 --- a/python/sparkdl/graph/builder.py +++ b/python/sparkdl/graph/builder.py @@ -47,19 +47,20 @@ def __init__(self, graph=None, using_keras=False): self.graph = graph or tf.Graph() self.sess = tf.Session(graph=self.graph) if using_keras: + self.using_keras = True self.keras_prev_sess = K.get_session() else: + self.using_keras = False self.keras_prev_sess = None def __enter__(self): - self.sess.as_default() self.sess.__enter__() - if self.keras_prev_sess is not None: + if self.using_keras: K.set_session(self.sess) return self def __exit__(self, *args): - if self.keras_prev_sess is not None: + if self.using_keras: K.set_session(self.keras_prev_sess) self.sess.__exit__(*args) @@ -87,8 +88,8 @@ def asGraphFunction(self, 
inputs, outputs, strip_and_freeze=True): else: gdef = self.graph.as_graph_def(add_shapes=True) return GraphFunction(graph_def=gdef, - input_names=[tfx.validated_input(self.graph, elem) for elem in inputs], - output_names=[tfx.validated_output(self.graph, elem) for elem in outputs]) + input_names=[tfx.validated_input(elem, self.graph) for elem in inputs], + output_names=[tfx.validated_output(elem, self.graph) for elem in outputs]) def importGraphFunction(self, gfn, input_map=None, prefix="GFN-IMPORT", **gdef_kargs): """ @@ -130,8 +131,8 @@ def importGraphFunction(self, gfn, input_map=None, prefix="GFN-IMPORT", **gdef_k return_elements=gfn.output_names, name=scope_name, **gdef_kargs) - feeds = [tfx.get_tensor(self.graph, name) for name in input_names] - fetches = [tfx.get_tensor(self.graph, name) for name in output_names] + feeds = [tfx.get_tensor(name, self.graph) for name in input_names] + fetches = [tfx.get_tensor(name, self.graph) for name in output_names] return (feeds, fetches) @@ -233,7 +234,7 @@ def fromList(cls, functions): _, first_gfn = functions[0] feeds, _ = issn.importGraphFunction(first_gfn, prefix='') for tnsr in feeds: - name = tfx.op_name(issn.graph, tnsr) + name = tfx.op_name(tnsr, issn.graph) first_input_info.append((tnsr.dtype, tnsr.shape, name)) # TODO: make sure that this graph is not reused to prevent name conflict # Report error if the graph is not manipulated by anyone else @@ -268,4 +269,3 @@ def fromList(cls, functions): gfn = issn.asGraphFunction(first_inputs, last_outputs) return gfn - diff --git a/python/sparkdl/graph/input.py b/python/sparkdl/graph/input.py new file mode 100644 index 00000000..2dedc8ef --- /dev/null +++ b/python/sparkdl/graph/input.py @@ -0,0 +1,355 @@ +# Copyright 2017 Databricks, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from __future__ import absolute_import, division, print_function + +import tensorflow as tf +from tensorflow.core.protobuf import meta_graph_pb2 # pylint: disable=no-name-in-module + +import sparkdl.graph.utils as tfx + +__all__ = ["TFInputGraph"] + +# pylint: disable=invalid-name,wrong-spelling-in-comment,wrong-spelling-in-docstring + +class TFInputGraph(object): + """ + An opaque object containing TensorFlow graph. + This object can be serialized. + + .. note:: We recommend constructing this object using one of the class constructor methods. + + - :py:meth:`fromGraph` + - :py:meth:`fromGraphDef` + - :py:meth:`fromCheckpoint` + - :py:meth:`fromCheckpointWithSignature` + - :py:meth:`fromSavedModel` + - :py:meth:`fromSavedModelWithSignature` + + + When the graph contains serving signatures in which a set of well-known names are associated + with their corresponding raw tensor names in the graph, we extract and store them here. + For example, the TensorFlow saved model may contain the following structure, + so that end users can retrieve the input tensor via `well_known_input_sig` and + the output tensor via `well_known_output_sig` without knowing the actual tensor names a priori. + + .. 
code-block:: python + + sigdef: {'well_known_prediction_signature': + inputs { key: "well_known_input_sig" + value { + name: "tnsrIn:0" + dtype: DT_DOUBLE + tensor_shape { dim { size: -1 } dim { size: 17 } } + } + } + outputs { key: "well_known_output_sig" + value { + name: "tnsrOut:0" + dtype: DT_DOUBLE + tensor_shape { dim { size: -1 } } + } + }} + + + In this case, the class will internally store the mapping from signature names to tensor names. + + .. code-block:: python + + {'well_known_input_sig': 'tnsrIn:0'} + {'well_known_output_sig': 'tnsrOut:0'} + + + :param graph_def: :py:obj:`tf.GraphDef`, a serializable object containing the topology and + computation units of the TensorFlow graph. The graph object is prepared for + inference, i.e. the variables are converted to constants and operations like + BatchNormalization_ are converted to be independent of input batch. + + .. _BatchNormalization: https://www.tensorflow.org/api_docs/python/tf/layers/batch_normalization + + :param input_tensor_name_from_signature: dict, signature key names mapped to tensor names. + Please see the example above. + :param output_tensor_name_from_signature: dict, signature key names mapped to tensor names + Please see the example above. + """ + + + def __init__(self, graph_def, input_tensor_name_from_signature, + output_tensor_name_from_signature): + self.graph_def = graph_def + self.input_tensor_name_from_signature = input_tensor_name_from_signature + self.output_tensor_name_from_signature = output_tensor_name_from_signature + + def translateInputMapping(self, input_mapping): + """ + When the meta_graph contains signature_def, we expect users to provide + input and output mapping with respect to the tensor reference keys + embedded in the `signature_def`. + + This function translates the input_mapping into the canonical format, + which maps input DataFrame column names to tensor names. 
+ + :param input_mapping: dict, DataFrame column name to tensor reference names + defined in the signature_def key. + """ + assert self.input_tensor_name_from_signature is not None + _input_mapping = {} + if isinstance(input_mapping, dict): + input_mapping = list(input_mapping.items()) + assert isinstance(input_mapping, list) + for col_name, sig_key in input_mapping: + tnsr_name = self.input_tensor_name_from_signature[sig_key] + _input_mapping[col_name] = tnsr_name + return _input_mapping + + def translateOutputMapping(self, output_mapping): + """ + When the meta_graph contains signature_def, we expect users to provide + input and output mapping with respect to the tensor reference keys + embedded in the `signature_def`. + + This function translates the output_mapping into the canonical format, + which maps tensor names into input DataFrame column names. + + :param output_mapping: dict, tensor reference names defined in the signature_def keys + into the output DataFrame column names. + """ + assert self.output_tensor_name_from_signature is not None + _output_mapping = {} + if isinstance(output_mapping, dict): + output_mapping = list(output_mapping.items()) + assert isinstance(output_mapping, list) + for sig_key, col_name in output_mapping: + tnsr_name = self.output_tensor_name_from_signature[sig_key] + _output_mapping[tnsr_name] = col_name + return _output_mapping + + @classmethod + def fromGraph(cls, graph, sess, feed_names, fetch_names): + """ + Construct a TFInputGraph from a in memory `tf.Graph` object. + The graph might contain variables that are maintained in the provided session. + Thus we need an active session in which the graph's variables are initialized or + restored. We do not close the session. As a result, this constructor can be used + inside a standard TensorFlow session context. + + .. code-block:: python + + with tf.Session() as sess: + graph = import_my_tensorflow_graph(...) + input = TFInputGraph.fromGraph(graph, sess, ...) 
+ + :param graph: a :py:class:`tf.Graph` object containing the topology and computation units of + the TensorFlow graph. + :param feed_names: list, names of the input tensors. + :param fetch_names: list, names of the output tensors. + """ + return _build_with_feeds_fetches(sess=sess, graph=graph, feed_names=feed_names, + fetch_names=fetch_names) + + @classmethod + def fromGraphDef(cls, graph_def, feed_names, fetch_names): + """ + Construct a TFInputGraph from a tf.GraphDef object. + + :param graph_def: :py:class:`tf.GraphDef`, a serializable object containing the topology and + computation units of the TensorFlow graph. + :param feed_names: list, names of the input tensors. + :param fetch_names: list, names of the output tensors. + """ + assert isinstance(graph_def, tf.GraphDef), \ + ('expect tf.GraphDef type but got', type(graph_def)) + + graph = tf.Graph() + with tf.Session(graph=graph) as sess: + tf.import_graph_def(graph_def, name='') + return _build_with_feeds_fetches(sess=sess, graph=graph, feed_names=feed_names, + fetch_names=fetch_names) + + @classmethod + def fromCheckpoint(cls, checkpoint_dir, feed_names, fetch_names): + """ + Construct a TFInputGraph object from a checkpoint, ignore the embedded + signature_def, if there is any. + + :param checkpoint_dir: str, name of the directory containing the TensorFlow graph + training checkpoint. + :param feed_names: list, names of the input tensors. + :param fetch_names: list, names of the output tensors. + """ + return _from_checkpoint_impl(checkpoint_dir, signature_def_key=None, feed_names=feed_names, + fetch_names=fetch_names) + + @classmethod + def fromCheckpointWithSignature(cls, checkpoint_dir, signature_def_key): + """ + Construct a TFInputGraph object from a checkpoint, using the embedded + signature_def. Throw an error if we cannot find an entry with the `signature_def_key` + inside the `signature_def`. 
+ + :param checkpoint_dir: str, name of the directory containing the TensorFlow graph + training checkpoint. + :param signature_def_key: str, key (name) of the signature_def to use. It should be in + the list of `signature_def` structures saved with the checkpoint. + """ + assert signature_def_key is not None + return _from_checkpoint_impl(checkpoint_dir, signature_def_key, feed_names=None, + fetch_names=None) + + @classmethod + def fromSavedModel(cls, saved_model_dir, tag_set, feed_names, fetch_names): + """ + Construct a TFInputGraph object from a saved model (`tf.SavedModel`) directory. + Ignore the embedded signature_def, if there is any. + + :param saved_model_dir: str, name of the directory containing the TensorFlow graph + training checkpoint. + :param tag_set: str, name of the graph stored in this meta_graph of the saved model + that we are interested in using. + :param feed_names: list, names of the input tensors. + :param fetch_names: list, names of the output tensors. + """ + return _from_saved_model_impl(saved_model_dir, tag_set, signature_def_key=None, + feed_names=feed_names, fetch_names=fetch_names) + + @classmethod + def fromSavedModelWithSignature(cls, saved_model_dir, tag_set, signature_def_key): + """ + Construct a TFInputGraph object from a saved model (`tf.SavedModel`) directory, + using the embedded signature_def. Throw error if we cannot find an entry with + the `signature_def_key` inside the `signature_def`. + + :param saved_model_dir: str, name of the directory containing the TensorFlow graph + training checkpoint. + :param tag_set: str, name of the graph stored in this meta_graph of the saved model + that we are interested in using. + :param signature_def_key: str, key (name) of the signature_def to use. It should be in + the list of `signature_def` structures saved with the + TensorFlow `SavedModel`. 
+ """ + assert signature_def_key is not None + return _from_saved_model_impl(saved_model_dir, tag_set, signature_def_key=signature_def_key, + feed_names=None, fetch_names=None) + + +def _from_checkpoint_impl(checkpoint_dir, signature_def_key, feed_names, fetch_names): + """ + Construct a TFInputGraph from a model checkpoint. + Notice that one should either provide the `signature_def_key` or provide both + `feed_names` and `fetch_names`. Please set the unprovided values to None. + + :param signature_def_key: str, name of the mapping contained inside the `signature_def` + from which we retrieve the signature key to tensor names mapping. + :param feed_names: list, names of the input tensors. + :param fetch_names: list, names of the output tensors. + """ + assert (feed_names is None) == (fetch_names is None), \ + 'feed_names and fetch_names, if provided must be both non-None.' + assert (feed_names is None) != (signature_def_key is None), \ + 'must either provide feed_names or singnature_def_key' + + graph = tf.Graph() + with tf.Session(graph=graph) as sess: + # Load checkpoint and import the graph + ckpt_path = tf.train.latest_checkpoint(checkpoint_dir) + + # NOTE(phi-dbq): we must manually load meta_graph_def to get the signature_def + # the current `import_graph_def` function seems to ignore + # any signature_def fields in a checkpoint's meta_graph_def. 
+ meta_graph_def = meta_graph_pb2.MetaGraphDef() + with open("{}.meta".format(ckpt_path), 'rb') as fin: + meta_graph_def.ParseFromString(fin.read()) + + saver = tf.train.import_meta_graph(meta_graph_def, clear_devices=True) + saver.restore(sess, ckpt_path) + + if signature_def_key is not None: + sig_def = meta_graph_def.signature_def[signature_def_key] + return _build_with_sig_def(sess=sess, graph=graph, sig_def=sig_def) + else: + return _build_with_feeds_fetches(sess=sess, graph=graph, feed_names=feed_names, + fetch_names=fetch_names) + +def _from_saved_model_impl(saved_model_dir, tag_set, signature_def_key, feed_names, fetch_names): + """ + Construct a TFInputGraph from a SavedModel. + Notice that one should either provide the `signature_def_key` or provide both + `feed_names` and `fetch_names`. Please set the unprovided values to None. + + :param signature_def_key: str, name of the mapping contained inside the `signature_def` + from which we retrieve the signature key to tensor names mapping. + :param feed_names: list, names of the input tensors. + :param fetch_names: list, names of the output tensors. 
+ """ + assert (feed_names is None) == (fetch_names is None), \ + 'feed_names and fetch_names, if provided must appear together' + assert (feed_names is None) != (signature_def_key is None), \ + 'must either provide feed_names or singnature_def_key' + + graph = tf.Graph() + with tf.Session(graph=graph) as sess: + tag_sets = tag_set.split(',') + meta_graph_def = tf.saved_model.loader.load(sess, tag_sets, saved_model_dir) + + if signature_def_key is not None: + sig_def = tf.contrib.saved_model.get_signature_def_by_key(meta_graph_def, + signature_def_key) + return _build_with_sig_def(sess=sess, graph=graph, sig_def=sig_def) + else: + return _build_with_feeds_fetches(sess=sess, graph=graph, feed_names=feed_names, + fetch_names=fetch_names) + + +def _build_with_sig_def(sess, graph, sig_def): + # pylint: disable=protected-access + assert sig_def, 'signature_def must not be None' + + with sess.as_default(), graph.as_default(): + feed_mapping = {} + feed_names = [] + for sigdef_key, tnsr_info in sig_def.inputs.items(): + tnsr_name = tnsr_info.name + feed_mapping[sigdef_key] = tnsr_name + feed_names.append(tnsr_name) + + fetch_mapping = {} + fetch_names = [] + for sigdef_key, tnsr_info in sig_def.outputs.items(): + tnsr_name = tnsr_info.name + fetch_mapping[sigdef_key] = tnsr_name + fetch_names.append(tnsr_name) + + for tnsr_name in feed_names: + assert tfx.get_op(tnsr_name, graph), \ + 'requested tensor {} but found none in graph {}'.format(tnsr_name, graph) + fetches = [tfx.get_tensor(tnsr_name, graph) for tnsr_name in fetch_names] + graph_def = tfx.strip_and_freeze_until(fetches, graph, sess) + + return TFInputGraph(graph_def=graph_def, input_tensor_name_from_signature=feed_mapping, + output_tensor_name_from_signature=fetch_mapping) + + +def _build_with_feeds_fetches(sess, graph, feed_names, fetch_names): + assert feed_names is not None, "must provide feed_names" + assert fetch_names is not None, "must provide fetch names" + + with sess.as_default(), graph.as_default(): 
+ for tnsr_name in feed_names: + assert tfx.get_op(tnsr_name, graph), \ + 'requested tensor {} but found none in graph {}'.format(tnsr_name, graph) + fetches = [tfx.get_tensor(tnsr_name, graph) for tnsr_name in fetch_names] + graph_def = tfx.strip_and_freeze_until(fetches, graph, sess) + + return TFInputGraph(graph_def=graph_def, input_tensor_name_from_signature=None, + output_tensor_name_from_signature=None) diff --git a/python/sparkdl/graph/tensorframes_udf.py b/python/sparkdl/graph/tensorframes_udf.py index 54027b8d..aa1531b4 100644 --- a/python/sparkdl/graph/tensorframes_udf.py +++ b/python/sparkdl/graph/tensorframes_udf.py @@ -33,7 +33,7 @@ def makeGraphUDF(graph, udf_name, fetches, feeds_to_fields_map=None, blocked=Fal .. code-block:: python from sparkdl.graph.tensorframes_udf import makeUDF - + with IsolatedSession() as issn: x = tf.placeholder(tf.double, shape=[], name="input_x") z = tf.add(x, 3, name='z') @@ -45,7 +45,7 @@ def makeGraphUDF(graph, udf_name, fetches, feeds_to_fields_map=None, blocked=Fal df = spark.createDataFrame([Row(xCol=float(x)) for x in range(100)]) df.createOrReplaceTempView("my_float_table") - spark.sql("select my_tensorflow_udf(xCol) as zCol from my_float_table").show() + spark.sql("select my_tensorflow_udf(xCol) as zCol from my_float_table").show() :param graph: :py:class:`tf.Graph`, a TensorFlow Graph :param udf_name: str, name of the SQL UDF @@ -77,18 +77,18 @@ def makeGraphUDF(graph, udf_name, fetches, feeds_to_fields_map=None, blocked=Fal tfs.core._add_graph(graph, jvm_builder) # Obtain the fetches and their shapes - fetch_names = [tfx.tensor_name(graph, fetch) for fetch in fetches] - fetch_shapes = [tfx.get_shape(graph, fetch) for fetch in fetches] + fetch_names = [tfx.tensor_name(fetch, graph) for fetch in fetches] + fetch_shapes = [tfx.get_shape(fetch, graph) for fetch in fetches] # Traverse the graph nodes and obtain all the placeholders and their shapes placeholder_names = [] placeholder_shapes = [] for node in 
graph.as_graph_def(add_shapes=True).node: if len(node.input) == 0 and str(node.op) == 'Placeholder': - tnsr_name = tfx.tensor_name(graph, node.name) + tnsr_name = tfx.tensor_name(node.name, graph) tnsr = graph.get_tensor_by_name(tnsr_name) try: - tnsr_shape = tfx.get_shape(graph, tnsr) + tnsr_shape = tfx.get_shape(tnsr, graph) placeholder_names.append(tnsr_name) placeholder_shapes.append(tnsr_shape) except ValueError: @@ -98,7 +98,7 @@ def makeGraphUDF(graph, udf_name, fetches, feeds_to_fields_map=None, blocked=Fal jvm_builder.shape(fetch_names + placeholder_names, fetch_shapes + placeholder_shapes) jvm_builder.fetches(fetch_names) # Passing feeds to TensorFrames - placeholder_op_names = [tfx.op_name(graph, name) for name in placeholder_names] + placeholder_op_names = [tfx.op_name(name, graph) for name in placeholder_names] # Passing the graph input to DataFrame column mapping and additional placeholder names tfs.core._add_inputs(jvm_builder, feeds_to_fields_map, placeholder_op_names) diff --git a/python/sparkdl/graph/utils.py b/python/sparkdl/graph/utils.py index 45d8b065..64e093fe 100644 --- a/python/sparkdl/graph/utils.py +++ b/python/sparkdl/graph/utils.py @@ -16,8 +16,6 @@ import logging import six -import webbrowser -from tempfile import NamedTemporaryFile import tensorflow as tf @@ -35,14 +33,15 @@ def validated_graph(graph): """ - Check if the input is a valid tf.Graph + Check if the input is a valid :py:class:`tf.Graph` and return it. + Raise an error otherwise. 
- :param graph: tf.Graph, a TensorFlow Graph object + :param graph: :py:class:`tf.Graph`, a TensorFlow Graph object """ assert isinstance(graph, tf.Graph), 'must provide tf.Graph, but get {}'.format(type(graph)) return graph -def get_shape(graph, tfobj_or_name): +def get_shape(tfobj_or_name, graph): """ Return the shape of the tensor as a list @@ -50,38 +49,44 @@ def get_shape(graph, tfobj_or_name): :param tfobj_or_name: either a tf.Tensor, tf.Operation or a name to either """ graph = validated_graph(graph) - _shape = get_tensor(graph, tfobj_or_name).get_shape().as_list() + _shape = get_tensor(tfobj_or_name, graph).get_shape().as_list() return [-1 if x is None else x for x in _shape] -def get_op(graph, tfobj_or_name): +def get_op(tfobj_or_name, graph): """ - Get a tf.Operation object + Get a :py:class:`tf.Operation` object. - :param graph: tf.Graph, a TensorFlow Graph object - :param tfobj_or_name: either a tf.Tensor, tf.Operation or a name to either + :param tfobj_or_name: either a :py:class:`tf.Tensor`, :py:class:`tf.Operation` or + a name to either. + :param graph: a :py:class:`tf.Graph` object containing the operation. + By default the graph we don't require this argument to be provided. 
""" graph = validated_graph(graph) + _assert_same_graph(tfobj_or_name, graph) if isinstance(tfobj_or_name, tf.Operation): return tfobj_or_name name = tfobj_or_name if isinstance(tfobj_or_name, tf.Tensor): name = tfobj_or_name.name if not isinstance(name, six.string_types): - raise TypeError('invalid op request for {} of {}'.format(name, type(name))) - _op_name = as_op_name(name) + raise TypeError('invalid op request for [type {}] {}'.format(type(name), name)) + _op_name = op_name(name, graph=None) op = graph.get_operation_by_name(_op_name) - assert op is not None, \ - 'cannot locate op {} in current graph'.format(_op_name) + err_msg = 'cannot locate op {} in the current graph, got [type {}] {}' + assert isinstance(op, tf.Operation), err_msg.format(_op_name, type(op), op) return op -def get_tensor(graph, tfobj_or_name): +def get_tensor(tfobj_or_name, graph): """ - Get a tf.Tensor object + Get a :py:class:`tf.Tensor` object - :param graph: tf.Graph, a TensorFlow Graph object - :param tfobj_or_name: either a tf.Tensor, tf.Operation or a name to either + :param tfobj_or_name: either a :py:class:`tf.Tensor`, :py:class:`tf.Operation` or + a name to either. + :param graph: a :py:class:`tf.Graph` object containing the tensor. + By default the graph we don't require this argument to be provided. 
""" graph = validated_graph(graph) + _assert_same_graph(tfobj_or_name, graph) if isinstance(tfobj_or_name, tf.Tensor): return tfobj_or_name name = tfobj_or_name @@ -89,59 +94,71 @@ def get_tensor(graph, tfobj_or_name): name = tfobj_or_name.name if not isinstance(name, six.string_types): raise TypeError('invalid tensor request for {} of {}'.format(name, type(name))) - _tensor_name = as_tensor_name(name) + _tensor_name = tensor_name(name, graph=None) tnsr = graph.get_tensor_by_name(_tensor_name) - assert tnsr is not None, \ - 'cannot locate tensor {} in current graph'.format(_tensor_name) + err_msg = 'cannot locate tensor {} in the current graph, got [type {}] {}' + assert isinstance(tnsr, tf.Tensor), err_msg.format(_tensor_name, type(tnsr), tnsr) return tnsr -def as_tensor_name(name): - """ - Derive tf.Tensor name from an op/tensor name. - We do not check if the tensor exist (as no graph parameter is passed in). - - :param name: op name or tensor name - """ - assert isinstance(name, six.string_types) - name_parts = name.split(":") - assert len(name_parts) <= 2, name_parts - if len(name_parts) < 2: - name += ":0" - return name - -def as_op_name(name): - """ - Derive tf.Operation name from an op/tensor name - We do not check if the operation exist (as no graph parameter is passed in). 
- - :param name: op name or tensor name - """ - assert isinstance(name, six.string_types) - name_parts = name.split(":") - assert len(name_parts) <= 2, name_parts - return name_parts[0] - -def op_name(graph, tfobj_or_name): - """ - Get the name of a tf.Operation - - :param graph: tf.Graph, a TensorFlow Graph object - :param tfobj_or_name: either a tf.Tensor, tf.Operation or a name to either - """ - graph = validated_graph(graph) - return get_op(graph, tfobj_or_name).name - -def tensor_name(graph, tfobj_or_name): - """ - Get the name of a tf.Tensor - - :param graph: tf.Graph, a TensorFlow Graph object - :param tfobj_or_name: either a tf.Tensor, tf.Operation or a name to either - """ - graph = validated_graph(graph) - return get_tensor(graph, tfobj_or_name).name +def tensor_name(tfobj_or_name, graph=None): + """ + Derive the :py:class:`tf.Tensor` name from a :py:class:`tf.Operation` or :py:class:`tf.Tensor` + object, or its name. + If a name is provided and the graph is not, we will derive the tensor name based on + TensorFlow's naming convention. + If the input is a TensorFlow object, or the graph is given, we also check that + the tensor exists in the associated graph. + + :param tfobj_or_name: either a :py:class:`tf.Tensor`, :py:class:`tf.Operation` or + a name to either. + :param graph: a :py:class:`tf.Graph` object containing the tensor. + By default the graph we don't require this argument to be provided. + """ + if graph is not None: + return get_tensor(tfobj_or_name, graph).name + if isinstance(tfobj_or_name, six.string_types): + # If input is a string, assume it is a name and infer the corresponding tensor name. 
+ # WARNING: this depends on TensorFlow's tensor naming convention + name = tfobj_or_name + name_parts = name.split(":") + assert len(name_parts) <= 2, name_parts + if len(name_parts) < 2: + name += ":0" + return name + elif hasattr(tfobj_or_name, 'graph'): + return get_tensor(tfobj_or_name, tfobj_or_name.graph).name + else: + raise TypeError('invalid tf.Tensor name query type {}'.format(type(tfobj_or_name))) + +def op_name(tfobj_or_name, graph=None): + """ + Derive the :py:class:`tf.Operation` name from a :py:class:`tf.Operation` or + :py:class:`tf.Tensor` object, or its name. + If a name is provided and the graph is not, we will derive the operation name based on + TensorFlow's naming convention. + If the input is a TensorFlow object, or the graph is given, we also check that + the operation exists in the associated graph. + + :param tfobj_or_name: either a :py:class:`tf.Tensor`, :py:class:`tf.Operation` or + a name to either. + :param graph: a :py:class:`tf.Graph` object containing the operation. + By default the graph we don't require this argument to be provided. + """ + if graph is not None: + return get_op(tfobj_or_name, graph).name + if isinstance(tfobj_or_name, six.string_types): + # If input is a string, assume it is a name and infer the corresponding operation name. 
+ # WARNING: this depends on TensorFlow's operation naming convention + name = tfobj_or_name + name_parts = name.split(":") + assert len(name_parts) <= 2, name_parts + return name_parts[0] + elif hasattr(tfobj_or_name, 'graph'): + return get_op(tfobj_or_name, tfobj_or_name.graph).name + else: + raise TypeError('invalid tf.Operation name query type {}'.format(type(tfobj_or_name))) -def validated_output(graph, tfobj_or_name): +def validated_output(tfobj_or_name, graph): """ Validate and return the output names useable GraphFunction @@ -149,9 +166,9 @@ def validated_output(graph, tfobj_or_name): :param tfobj_or_name: either a tf.Tensor, tf.Operation or a name to either """ graph = validated_graph(graph) - return op_name(graph, tfobj_or_name) + return op_name(tfobj_or_name, graph) -def validated_input(graph, tfobj_or_name): +def validated_input(tfobj_or_name, graph): """ Validate and return the input names useable GraphFunction @@ -159,7 +176,7 @@ def validated_input(graph, tfobj_or_name): :param tfobj_or_name: either a tf.Tensor, tf.Operation or a name to either """ graph = validated_graph(graph) - name = op_name(graph, tfobj_or_name) + name = op_name(tfobj_or_name, graph) op = graph.get_operation_by_name(name) assert 'Placeholder' == op.type, \ ('input must be Placeholder, but get', op.type) @@ -186,7 +203,7 @@ def strip_and_freeze_until(fetches, graph, sess=None, return_graph=False): gdef_frozen = tf.graph_util.convert_variables_to_constants( sess, graph.as_graph_def(add_shapes=True), - [op_name(graph, tnsr) for tnsr in fetches]) + [op_name(tnsr, graph) for tnsr in fetches]) if should_close_session: sess.close() @@ -198,3 +215,9 @@ def strip_and_freeze_until(fetches, graph, sess=None, return_graph=False): return g else: return gdef_frozen + + +def _assert_same_graph(tfobj, graph): + if graph is not None and hasattr(tfobj, 'graph'): + err_msg = 'the graph of TensorFlow element {} != graph {}' + assert tfobj.graph == graph, err_msg.format(tfobj, graph) diff --git 
a/python/sparkdl/param/__init__.py b/python/sparkdl/param/__init__.py index a291a7d4..ca1a9121 100644 --- a/python/sparkdl/param/__init__.py +++ b/python/sparkdl/param/__init__.py @@ -16,7 +16,7 @@ from sparkdl.param.shared_params import ( keyword_only, HasInputCol, HasOutputCol, HasLabelCol, # TFTransformer Params - HasInputMapping, HasOutputMapping, HasTFHParams, + HasInputMapping, HasOutputMapping, HasTFInputGraph, HasTFHParams, # Keras Estimator Params HasKerasModel, HasKerasLoss, HasKerasOptimizer, HasOutputNodeName) from sparkdl.param.converters import SparkDLTypeConverters diff --git a/python/sparkdl/param/converters.py b/python/sparkdl/param/converters.py index a692a013..25a2e3a1 100644 --- a/python/sparkdl/param/converters.py +++ b/python/sparkdl/param/converters.py @@ -30,6 +30,7 @@ from pyspark.ml.param import TypeConverters +from sparkdl.graph.input import * import sparkdl.utils.keras_model as kmutil __all__ = ['SparkDLTypeConverters'] @@ -52,6 +53,13 @@ def toTFGraph(value): raise TypeError("Could not convert %s to tf.Graph" % type(value)) return value + @staticmethod + def toTFInputGraph(value): + if isinstance(value, TFInputGraph): + return value + else: + raise TypeError("Could not convert %s to TFInputGraph" % type(value)) + @staticmethod def asColumnToTensorNameMap(value): """ @@ -167,7 +175,14 @@ def _check_is_tensor_name(_maybe_tnsr_name): raise TypeError(err_msg.format(type(_maybe_tnsr_name))) # The check is taken from TensorFlow's NodeDef protocol buffer. - # https://github.com/tensorflow/tensorflow/blob/r1.3/tensorflow/core/framework/node_def.proto#L21-L25 + # Each input is "node:src_output" with "node" being a string name and + # "src_output" indicating which output tensor to use from "node". If + # "src_output" is 0 the ":0" suffix can be omitted. Regular inputs + # may optionally be followed by control inputs that have the format + # "^node". 
+ # Reference: + # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/node_def.proto + # https://stackoverflow.com/questions/36150834/how-does-tensorflow-name-tensors try: _, src_idx = _maybe_tnsr_name.split(":") _ = int(src_idx) diff --git a/python/sparkdl/param/shared_params.py b/python/sparkdl/param/shared_params.py index 432d618d..1116aa54 100644 --- a/python/sparkdl/param/shared_params.py +++ b/python/sparkdl/param/shared_params.py @@ -19,9 +19,11 @@ """ import textwrap from functools import wraps +import six from pyspark.ml.param import Param, Params, TypeConverters +from sparkdl.graph.input import TFInputGraph from sparkdl.param.converters import SparkDLTypeConverters ######################################################## @@ -196,8 +198,9 @@ class HasOutputMapping(Params): """ Mixin for param outputMapping: ordered list of ('outputTensorOpName', 'outputColName') pairs """ - outputMapping = Param(Params._dummy(), "outputMapping", - "Mapping output :class:`tf.Operation` names to DataFrame column names", + outputMapping = Param(Params._dummy(), + "outputMapping", + "Mapping output :class:`tf.Tensor` names to DataFrame column names", typeConverter=SparkDLTypeConverters.asTensorNameToColumnMap) def setOutputMapping(self, value): @@ -211,8 +214,9 @@ class HasInputMapping(Params): """ Mixin for param inputMapping: ordered list of ('inputColName', 'inputTensorOpName') pairs """ - inputMapping = Param(Params._dummy(), "inputMapping", - "Mapping input DataFrame column names to :class:`tf.Operation` names", + inputMapping = Param(Params._dummy(), + "inputMapping", + "Mapping input DataFrame column names to :class:`tf.Tensor` names", typeConverter=SparkDLTypeConverters.asColumnToTensorNameMap) def setInputMapping(self, value): @@ -222,6 +226,26 @@ def getInputMapping(self): return self.getOrDefault(self.inputMapping) +class HasTFInputGraph(Params): + """ + Mixin for param tfInputGraph: a serializable object derived from a TensorFlow 
computation graph. + """ + tfInputGraph = Param(Params._dummy(), + "tfInputGraph", + "A serializable object derived from a TensorFlow computation graph", + typeConverter=SparkDLTypeConverters.toTFInputGraph) + + def __init__(self): + super(HasTFInputGraph, self).__init__() + self._setDefault(tfInputGraph=None) + + def setTFInputGraph(self, value): + return self._set(tfInputGraph=value) + + def getTFInputGraph(self): + return self.getOrDefault(self.tfInputGraph) + + class HasTFHParams(Params): """ Mixin for TensorFlow model hyper-parameters diff --git a/python/sparkdl/transformers/keras_image.py b/python/sparkdl/transformers/keras_image.py index de10fc87..3c2762d9 100644 --- a/python/sparkdl/transformers/keras_image.py +++ b/python/sparkdl/transformers/keras_image.py @@ -76,14 +76,14 @@ def _transform(self, dataset): return transformer.transform(image_df).drop(self._loadedImageCol()) def _loadTFGraph(self): - with KSessionWrap() as (sess, g): + with KSessionWrap() as (sess, graph): assert K.backend() == "tensorflow", \ "Keras backend is not tensorflow but KerasImageTransformer only supports " + \ "tensorflow-backed Keras models." 
- with g.as_default(): + with graph.as_default(): K.set_learning_phase(0) # Testing phase model = load_model(self.getModelFile()) - out_op_name = tfx.op_name(g, model.output) + out_op_name = tfx.op_name(model.output, graph) self._inputTensor = model.input.name self._outputTensor = model.output.name - return tfx.strip_and_freeze_until([out_op_name], g, sess, return_graph=True) + return tfx.strip_and_freeze_until([out_op_name], graph, sess, return_graph=True) diff --git a/python/sparkdl/transformers/tf_image.py b/python/sparkdl/transformers/tf_image.py index 943af6e8..152a7fea 100644 --- a/python/sparkdl/transformers/tf_image.py +++ b/python/sparkdl/transformers/tf_image.py @@ -30,7 +30,7 @@ __all__ = ['TFImageTransformer'] -IMAGE_INPUT_TENSOR_NAME = tfx.as_tensor_name(utils.IMAGE_INPUT_PLACEHOLDER_NAME) +IMAGE_INPUT_TENSOR_NAME = tfx.tensor_name(utils.IMAGE_INPUT_PLACEHOLDER_NAME) USER_GRAPH_NAMESPACE = 'given' NEW_OUTPUT_PREFIX = 'sdl_flattened' @@ -136,7 +136,7 @@ def _transform(self, dataset): "__sdl_image_data") ) - tfs_output_name = tfx.op_name(final_graph, output_tensor) + tfs_output_name = tfx.op_name(output_tensor, final_graph) original_output_name = self._getOriginalOutputTensorName() output_shape = final_graph.get_tensor_by_name(original_output_name).shape output_mode = self.getOrDefault(self.outputMode) @@ -207,7 +207,7 @@ def _getFinalOutputTensorName(self): return NEW_OUTPUT_PREFIX + '_' + self.getOutputTensor().name def _getFinalOutputOpName(self): - return tfx.as_op_name(self._getFinalOutputTensorName()) + return tfx.op_name(self._getFinalOutputTensorName()) def _convertOutputToImage(self, df, tfs_output_col, output_shape): assert len(output_shape) == 4, str(output_shape) + " does not have 4 dimensions" diff --git a/python/sparkdl/transformers/tf_tensor.py b/python/sparkdl/transformers/tf_tensor.py new file mode 100644 index 00000000..7207f5f1 --- /dev/null +++ b/python/sparkdl/transformers/tf_tensor.py @@ -0,0 +1,105 @@ +# Copyright 2017 Databricks, 
Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from __future__ import absolute_import, division, print_function + +import logging +import tensorflow as tf +from tensorflow.python.tools import optimize_for_inference_lib as infr_opt +import tensorframes as tfs + +from pyspark.ml import Transformer + +import sparkdl.graph.utils as tfx +from sparkdl.param import (keyword_only, HasInputMapping, HasOutputMapping, + HasTFInputGraph, HasTFHParams) + +__all__ = ['TFTransformer'] + +logger = logging.getLogger('sparkdl') + +class TFTransformer(Transformer, HasTFInputGraph, HasTFHParams, HasInputMapping, HasOutputMapping): + """ + Applies the TensorFlow graph to the array column in DataFrame. + + Restrictions of the current API: + + We assume that + - All the inputs of the graphs have a "minibatch" dimension (i.e. an unknown leading + dimension) in the tensor shapes. + - Input DataFrame has an array column where all elements have the same length + - The transformer is expected to work on blocks of data at the same time. 
+ """ + + @keyword_only + def __init__(self, tfInputGraph=None, inputMapping=None, outputMapping=None, tfHParms=None): + """ + __init__(self, tfInputGraph=None, inputMapping=None, outputMapping=None, tfHParms=None) + """ + super(TFTransformer, self).__init__() + kwargs = self._input_kwargs + self.setParams(**kwargs) + + @keyword_only + def setParams(self, tfInputGraph=None, inputMapping=None, outputMapping=None, tfHParms=None): + """ + setParams(self, tfInputGraph=None, inputMapping=None, outputMapping=None, tfHParms=None) + """ + super(TFTransformer, self).__init__() + kwargs = self._input_kwargs + # Further conanonicalization, e.g. converting dict to sorted str pairs happens here + return self._set(**kwargs) + + def _optimize_for_inference(self): + """ Optimize the graph for inference """ + gin = self.getTFInputGraph() + input_mapping = self.getInputMapping() + output_mapping = self.getOutputMapping() + input_node_names = [tfx.op_name(tnsr_name) for _, tnsr_name in input_mapping] + output_node_names = [tfx.op_name(tnsr_name) for tnsr_name, _ in output_mapping] + + # NOTE(phi-dbq): Spark DataFrame assumes float64 as default floating point type + opt_gdef = infr_opt.optimize_for_inference(gin.graph_def, + input_node_names, + output_node_names, + # TODO: below is the place to change for + # the `float64` data type issue. 
+ tf.float64.as_datatype_enum) + return opt_gdef + + def _transform(self, dataset): + graph_def = self._optimize_for_inference() + input_mapping = self.getInputMapping() + output_mapping = self.getOutputMapping() + + graph = tf.Graph() + with tf.Session(graph=graph): + analyzed_df = tfs.analyze(dataset) + + out_tnsr_op_names = [tfx.op_name(tnsr_name) for tnsr_name, _ in output_mapping] + tf.import_graph_def(graph_def=graph_def, name='', return_elements=out_tnsr_op_names) + + feed_dict = dict((tfx.op_name(tnsr_name, graph), col_name) + for col_name, tnsr_name in input_mapping) + fetches = [tfx.get_tensor(tnsr_op_name, graph) for tnsr_op_name in out_tnsr_op_names] + + out_df = tfs.map_blocks(fetches, analyzed_df, feed_dict=feed_dict) + + # We still have to rename output columns + for tnsr_name, new_colname in output_mapping: + old_colname = tfx.op_name(tnsr_name, graph) + if old_colname != new_colname: + out_df = out_df.withColumnRenamed(old_colname, new_colname) + + return out_df diff --git a/python/tests/graph/test_builder.py b/python/tests/graph/test_builder.py index b0736896..93b3c9f5 100644 --- a/python/tests/graph/test_builder.py +++ b/python/tests/graph/test_builder.py @@ -78,15 +78,15 @@ def test_get_graph_elements(self): z = tf.add(x, 3, name='z') g = issn.graph - self.assertEqual(tfx.get_tensor(g, z), z) - self.assertEqual(tfx.get_tensor(g, x), x) - self.assertEqual(g.get_tensor_by_name("x:0"), tfx.get_tensor(g, x)) - self.assertEqual("x:0", tfx.tensor_name(g, x)) - self.assertEqual(g.get_operation_by_name("x"), tfx.get_op(g, x)) - self.assertEqual("x", tfx.op_name(g, x)) - self.assertEqual("z", tfx.op_name(g, z)) - self.assertEqual(tfx.tensor_name(g, z), "z:0") - self.assertEqual(tfx.tensor_name(g, x), "x:0") + self.assertEqual(tfx.get_tensor(z, g), z) + self.assertEqual(tfx.get_tensor(x, g), x) + self.assertEqual(g.get_tensor_by_name("x:0"), tfx.get_tensor(x, g)) + self.assertEqual("x:0", tfx.tensor_name(x, g)) + 
self.assertEqual(g.get_operation_by_name("x"), tfx.get_op(x, g)) + self.assertEqual("x", tfx.op_name(x, g)) + self.assertEqual("z", tfx.op_name(z, g)) + self.assertEqual(tfx.tensor_name(z, g), "z:0") + self.assertEqual(tfx.tensor_name(x, g), "x:0") def test_import_export_graph_function(self): """ Function import and export must be consistent """ diff --git a/python/tests/graph/test_import.py b/python/tests/graph/test_import.py new file mode 100644 index 00000000..36501568 --- /dev/null +++ b/python/tests/graph/test_import.py @@ -0,0 +1,322 @@ +# Copyright 2017 Databricks, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +from __future__ import absolute_import, division, print_function + +import contextlib +import shutil +import numpy as np +import os +import tensorflow as tf +import tempfile +import glob + +import sparkdl.graph.utils as tfx +from sparkdl.graph.input import TFInputGraph + + +class TestGraphImport(object): + def test_graph_novar(self): + gin = _build_graph_input(lambda session: + TFInputGraph.fromGraph(session.graph, session, [_tensor_input_name], + [_tensor_output_name])) + _check_input_novar(gin) + + def test_graphdef_novar(self): + gin = _build_graph_input(lambda session: + TFInputGraph.fromGraphDef(session.graph.as_graph_def(), + [_tensor_input_name], [_tensor_output_name])) + _check_input_novar(gin) + + def test_saved_model_novar(self): + with _make_temp_directory() as tmp_dir: + saved_model_dir = os.path.join(tmp_dir, 'saved_model') + + def gin_fun(session): + _build_saved_model(session, saved_model_dir) + # Build the transformer from exported serving model + # We are using signatures, thus must provide the keys + return TFInputGraph.fromSavedModelWithSignature(saved_model_dir, _serving_tag, + _serving_sigdef_key) + + gin = _build_graph_input(gin_fun) + _check_input_novar(gin) + + def test_saved_model_iomap(self): + with _make_temp_directory() as tmp_dir: + saved_model_dir = os.path.join(tmp_dir, 'saved_model') + graph = tf.Graph() + with tf.Session(graph=graph) as sess, graph.as_default(): + _build_graph() + _build_saved_model(sess, saved_model_dir) + # Build the transformer from exported serving model + # We are using signatures, thus must provide the keys + gin = TFInputGraph.fromSavedModelWithSignature(saved_model_dir, _serving_tag, + _serving_sigdef_key) + + _input_mapping_with_sigdef = {'inputCol': _tensor_input_signature} + # Input mapping for the Transformer + _translated_input_mapping = gin.translateInputMapping(_input_mapping_with_sigdef) + _expected_input_mapping = {'inputCol': tfx.tensor_name(_tensor_input_name)} + # Output mapping for the 
Transformer + _output_mapping_with_sigdef = {_tensor_output_signature: 'outputCol'} + _translated_output_mapping = gin.translateOutputMapping(_output_mapping_with_sigdef) + _expected_output_mapping = {tfx.tensor_name(_tensor_output_name): 'outputCol'} + + err_msg = "signature based input mapping {} and output mapping {} " + \ + "must be translated correctly into tensor name based mappings" + assert _translated_input_mapping == _expected_input_mapping \ + and _translated_output_mapping == _expected_output_mapping, \ + err_msg.format(_translated_input_mapping, _translated_output_mapping) + + + def test_saved_graph_novar(self): + with _make_temp_directory() as tmp_dir: + saved_model_dir = os.path.join(tmp_dir, 'saved_model') + + def gin_fun(session): + _build_saved_model(session, saved_model_dir) + return TFInputGraph.fromGraph(session.graph, session, [_tensor_input_name], [_tensor_output_name]) + + gin = _build_graph_input(gin_fun) + _check_input_novar(gin) + + def test_checkpoint_sig_var(self): + with _make_temp_directory() as tmp_dir: + def gin_fun(session): + _build_checkpointed_model(session, tmp_dir) + return TFInputGraph.fromCheckpointWithSignature(tmp_dir, _serving_sigdef_key) + + gin = _build_graph_input_var(gin_fun) + _check_input_novar(gin) + + def test_checkpoint_nosig_var(self): + with _make_temp_directory() as tmp_dir: + def gin_fun(session): + _build_checkpointed_model(session, tmp_dir) + return TFInputGraph.fromCheckpoint(tmp_dir, + [_tensor_input_name], [_tensor_output_name]) + + gin = _build_graph_input_var(gin_fun) + _check_input_novar(gin) + + def test_checkpoint_graph_var(self): + with _make_temp_directory() as tmp_dir: + def gin_fun(session): + _build_checkpointed_model(session, tmp_dir) + return TFInputGraph.fromGraph(session.graph, session, + [_tensor_input_name], [_tensor_output_name]) + + gin = _build_graph_input_var(gin_fun) + _check_input_novar(gin) + + def test_graphdef_novar_2(self): + gin = _build_graph_input_2(lambda session: + 
TFInputGraph.fromGraphDef(session.graph.as_graph_def(), + [_tensor_input_name], [_tensor_output_name])) + _check_output_2(gin, np.array([1, 2, 3]), np.array([2, 2, 2]), 1) + + def test_saved_graph_novar_2(self): + with _make_temp_directory() as tmp_dir: + saved_model_dir = os.path.join(tmp_dir, 'saved_model') + + def gin_fun(session): + _build_saved_model(session, saved_model_dir) + return TFInputGraph.fromGraph(session.graph, session, [_tensor_input_name], [_tensor_output_name]) + + gin = _build_graph_input_2(gin_fun) + _check_output_2(gin, np.array([1, 2, 3]), np.array([2, 2, 2]), 1) + +_serving_tag = "serving_tag" +_serving_sigdef_key = 'prediction_signature' +# The name of the input tensor +_tensor_input_name = "input_tensor" +# For testing graphs with 2 inputs +_tensor_input_name_2 = "input_tensor_2" +# The name of the output tensor (scalar) +_tensor_output_name = "output_tensor" +# Input signature name +_tensor_input_signature = 'well_known_input_sig' +# Output signature name +_tensor_output_signature = 'well_known_output_sig' +# The name of the variable +_tensor_var_name = "variable" +# The size of the input tensor +_tensor_size = 3 + + +def _build_checkpointed_model(session, tmp_dir): + """ + Writes a model checkpoint in the given directory. The graph is assumed to be generated + with _build_graph_var. 
+ """ + ckpt_path_prefix = os.path.join(tmp_dir, 'model_ckpt') + input_tensor = tfx.get_tensor(_tensor_input_name, session.graph) + output_tensor = tfx.get_tensor(_tensor_output_name, session.graph) + w = tfx.get_tensor(_tensor_var_name, session.graph) + saver = tf.train.Saver(var_list=[w]) + _ = saver.save(session, ckpt_path_prefix, global_step=2702) + sig_inputs = {_tensor_input_signature: tf.saved_model.utils.build_tensor_info(input_tensor)} + sig_outputs = {_tensor_output_signature: tf.saved_model.utils.build_tensor_info(output_tensor)} + serving_sigdef = tf.saved_model.signature_def_utils.build_signature_def( + inputs=sig_inputs, outputs=sig_outputs) + + # A rather contrived way to add signature def to a meta_graph + meta_graph_def = tf.train.export_meta_graph() + + # Find the meta_graph file (there should be only one) + _ckpt_meta_fpaths = glob.glob('{}/*.meta'.format(tmp_dir)) + assert len(_ckpt_meta_fpaths) == 1, \ + 'expected only one meta graph, but got {}'.format(','.join(_ckpt_meta_fpaths)) + ckpt_meta_fpath = _ckpt_meta_fpaths[0] + + # Add signature_def to the meta_graph and serialize it + # This will overwrite the existing meta_graph_def file + meta_graph_def.signature_def[_serving_sigdef_key].CopyFrom(serving_sigdef) + with open(ckpt_meta_fpath, mode='wb') as fout: + fout.write(meta_graph_def.SerializeToString()) + + +def _build_saved_model(session, saved_model_dir): + """ + Saves a model in a file. The graph is assumed to be generated with _build_graph_novar. 
+ """ + builder = tf.saved_model.builder.SavedModelBuilder(saved_model_dir) + input_tensor = tfx.get_tensor(_tensor_input_name, session.graph) + output_tensor = tfx.get_tensor(_tensor_output_name, session.graph) + sig_inputs = {_tensor_input_signature: tf.saved_model.utils.build_tensor_info(input_tensor)} + sig_outputs = {_tensor_output_signature: tf.saved_model.utils.build_tensor_info(output_tensor)} + serving_sigdef = tf.saved_model.signature_def_utils.build_signature_def( + inputs=sig_inputs, outputs=sig_outputs) + + builder.add_meta_graph_and_variables( + session, [_serving_tag], signature_def_map={_serving_sigdef_key: serving_sigdef}) + builder.save() + + +@contextlib.contextmanager +def _make_temp_directory(): + temp_dir = tempfile.mkdtemp() + try: + yield temp_dir + finally: + shutil.rmtree(temp_dir) + + +def _build_graph_input(gin_function): + """ + Makes a session and a default graph, loads the simple graph into it, and then calls + gin_function(session) to return the graph input object + """ + graph = tf.Graph() + with tf.Session(graph=graph) as s, graph.as_default(): + _build_graph() + return gin_function(s) + + +def _build_graph_input_2(gin_function): + """ + Makes a session and a default graph, loads the simple graph into it (graph_2), and then calls + gin_function(session) to return the graph input object + """ + graph = tf.Graph() + with tf.Session(graph=graph) as s, graph.as_default(): + _build_graph_2() + return gin_function(s) + + +def _build_graph_input_var(gin_function): + """ + Makes a session and a default graph, loads the simple graph into it that contains a variable, + and then calls gin_function(session) to return the graph input object + """ + graph = tf.Graph() + with tf.Session(graph=graph) as s, graph.as_default(): + _build_graph_var(s) + return gin_function(s) + + +def _build_graph(): + """ + Given a session (implicitly), adds nodes of computations + + It takes a vector input, with vec_size columns and returns an int32 scalar. 
+ """ + x = tf.placeholder(tf.int32, shape=[_tensor_size], name=_tensor_input_name) + _ = tf.reduce_max(x, name=_tensor_output_name) + + +def _build_graph_2(): + """ + Given a session (implicitly), adds nodes of computations with two inputs. + + It takes a vector input, with vec_size columns and returns an int32 scalar. + """ + x1 = tf.placeholder(tf.int32, shape=[_tensor_size], name=_tensor_input_name) + x2 = tf.placeholder(tf.int32, shape=[_tensor_size], name=_tensor_input_name_2) + # Make sure that the inputs are not used in a symmetric manner. + _ = tf.reduce_max(x1 - x2, name=_tensor_output_name) + + +def _build_graph_var(session): + """ + Given a session, adds nodes that include one variable. + """ + x = tf.placeholder(tf.int32, shape=[_tensor_size], name=_tensor_input_name) + w = tf.Variable(tf.ones(shape=[_tensor_size], dtype=tf.int32), name=_tensor_var_name) + _ = tf.reduce_max(x * w, name=_tensor_output_name) + session.run(w.initializer) + + +def _check_input_novar(gin): + """ + Tests that the graph from _build_graph has been serialized in the InputGraph object. + """ + _check_output(gin, np.array([1, 2, 3]), 3) + + +def _check_output(gin, tf_input, expected): + """ + Takes a TFInputGraph object (assumed to have the input and outputs of the given + names above) and compares the outcome against some expected outcome. + """ + graph = tf.Graph() + graph_def = gin.graph_def + with tf.Session(graph=graph) as sess: + tf.import_graph_def(graph_def, name="") + tgt_feed = tfx.get_tensor(_tensor_input_name, graph) + tgt_fetch = tfx.get_tensor(_tensor_output_name, graph) + # Run on the testing target + tgt_out = sess.run(tgt_fetch, feed_dict={tgt_feed: tf_input}) + # Working on integers, the calculation should be exact + assert np.all(tgt_out == expected), (tgt_out, expected) + + +# TODO: we could factorize with _check_output, but this is not worth the time doing it. 
+def _check_output_2(gin, tf_input1, tf_input2, expected): + """ + Takes a TFInputGraph object (assumed to have the input and outputs of the given + names above) and compares the outcome against some expected outcome. + """ + graph = tf.Graph() + graph_def = gin.graph_def + with tf.Session(graph=graph) as sess: + tf.import_graph_def(graph_def, name="") + tgt_feed1 = tfx.get_tensor(_tensor_input_name, graph) + tgt_feed2 = tfx.get_tensor(_tensor_input_name_2, graph) + tgt_fetch = tfx.get_tensor(_tensor_output_name, graph) + # Run on the testing target + tgt_out = sess.run(tgt_fetch, feed_dict={tgt_feed1: tf_input1, tgt_feed2: tf_input2}) + # Working on integers, the calculation should be exact + assert np.all(tgt_out == expected), (tgt_out, expected) diff --git a/python/tests/graph/test_pieces.py b/python/tests/graph/test_pieces.py index 1497d137..9d659265 100644 --- a/python/tests/graph/test_pieces.py +++ b/python/tests/graph/test_pieces.py @@ -55,7 +55,7 @@ def exec_gfn_spimg_decode(spimg_dict, img_dtype): gfn = gfac.buildSpImageConverter(img_dtype) with IsolatedSession() as issn: feeds, fetches = issn.importGraphFunction(gfn, prefix="") - feed_dict = dict((tnsr, spimg_dict[tfx.op_name(issn.graph, tnsr)]) for tnsr in feeds) + feed_dict = dict((tnsr, spimg_dict[tfx.op_name(tnsr, issn.graph)]) for tnsr in feeds) img_out = issn.run(fetches[0], feed_dict=feed_dict) return img_out @@ -159,7 +159,7 @@ def test_pipeline(self): with IsolatedSession() as issn: # Need blank import scope name so that spimg fields match the input names feeds, fetches = issn.importGraphFunction(piped_model, prefix="") - feed_dict = dict((tnsr, spimg_input_dict[tfx.op_name(issn.graph, tnsr)]) for tnsr in feeds) + feed_dict = dict((tnsr, spimg_input_dict[tfx.op_name(tnsr, issn.graph)]) for tnsr in feeds) preds_tgt = issn.run(fetches[0], feed_dict=feed_dict) # Uncomment the line below to see the graph # tfx.write_visualization_html(issn.graph, diff --git a/python/tests/graph/test_utils.py 
b/python/tests/graph/test_utils.py new file mode 100644 index 00000000..4847c9b1 --- /dev/null +++ b/python/tests/graph/test_utils.py @@ -0,0 +1,174 @@ +# Copyright 2017 Databricks, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from __future__ import absolute_import, division, print_function + +from collections import namedtuple +# Use this to create parameterized test cases +from parameterized import parameterized + +import tensorflow as tf + +import sparkdl.graph.utils as tfx + +from ..tests import PythonUnitTestCase + +TestCase = namedtuple('TestCase', ['data', 'description']) + + +def _gen_tensor_op_string_input_tests(): + op_name = 'someOp' + for tnsr_idx in [0, 1, 2, 3, 5, 8, 15, 17]: + tnsr_name = '{}:{}'.format(op_name, tnsr_idx) + yield TestCase(data=(op_name, tfx.op_name(tnsr_name)), + description='test tensor name to op name') + yield TestCase(data=(tnsr_name, tfx.tensor_name(tnsr_name)), + description='test tensor name to tensor name') + + +def _gen_invalid_tensor_or_op_input_with_wrong_types(): + for wrong_val in [7, 1.2, tf.Graph()]: + yield TestCase(data=wrong_val, description='wrong type {}'.format(type(wrong_val))) + + +def _gen_invalid_tensor_or_op_with_graph_pairing(): + tnsr = tf.constant(1427.08, name='someConstOp') + other_graph = tf.Graph() + op_name = tnsr.op.name + + # Test get_tensor and get_op with non-associated tensor/op and graph inputs + _comm_suffix = ' with wrong graph' + yield TestCase(data=lambda: tfx.get_op(tnsr, 
other_graph), + description='test get_op from tensor' + _comm_suffix) + yield TestCase(data=lambda: tfx.get_tensor(tnsr, other_graph), + description='test get_tensor from tensor' + _comm_suffix) + yield TestCase(data=lambda: tfx.get_op(tnsr.name, other_graph), + description='test get_op from tensor name' + _comm_suffix) + yield TestCase(data=lambda: tfx.get_tensor(tnsr.name, other_graph), + description='test get_tensor from tensor name' + _comm_suffix) + yield TestCase(data=lambda: tfx.get_op(tnsr.op, other_graph), + description='test get_op from op' + _comm_suffix) + yield TestCase(data=lambda: tfx.get_tensor(tnsr.op, other_graph), + description='test get_tensor from op' + _comm_suffix) + yield TestCase(data=lambda: tfx.get_op(op_name, other_graph), + description='test get_op from op name' + _comm_suffix) + yield TestCase(data=lambda: tfx.get_tensor(op_name, other_graph), + description='test get_tensor from op name' + _comm_suffix)
+
+
+def _gen_valid_tensor_op_input_combos():
+    """ Yield (expected, actual) pairs of names/graph elements derived from one constant tensor """
+    op_name = 'someConstOp'
+    tnsr_name = '{}:0'.format(op_name)
+    tnsr = tf.constant(1427.08, name=op_name)
+    graph = tnsr.graph
+    # Test for op_name
+    yield TestCase(data=(op_name, tfx.op_name(tnsr)),
+                   description='get op name from tensor (no graph)')
+    yield TestCase(data=(op_name, tfx.op_name(tnsr, graph)),
+                   description='get op name from tensor (with graph)')
+    yield TestCase(data=(op_name, tfx.op_name(tnsr_name)),
+                   description='get op name from tensor name (no graph)')
+    yield TestCase(data=(op_name, tfx.op_name(tnsr_name, graph)),
+                   description='get op name from tensor name (with graph)')
+    yield TestCase(data=(op_name, tfx.op_name(tnsr.op)),
+                   description='get op name from op (no graph)')
+    yield TestCase(data=(op_name, tfx.op_name(tnsr.op, graph)),
+                   description='get op name from op (with graph)')
+    yield TestCase(data=(op_name, tfx.op_name(op_name)),
+                   description='get op name from op name (no graph)')
+    yield TestCase(data=(op_name, tfx.op_name(op_name, graph)),
+                   description='get op name from op name (with graph)')
+
+    # Test for tensor_name
+    yield TestCase(data=(tnsr_name, tfx.tensor_name(tnsr)),
+                   description='get tensor name from tensor (no graph)')
+    yield TestCase(data=(tnsr_name, tfx.tensor_name(tnsr, graph)),
+                   description='get tensor name from tensor (with graph)')
+    yield TestCase(data=(tnsr_name, tfx.tensor_name(tnsr_name)),
+                   description='get tensor name from tensor name (no graph)')
+    yield TestCase(data=(tnsr_name, tfx.tensor_name(tnsr_name, graph)),
+                   description='get tensor name from tensor name (with graph)')
+    yield TestCase(data=(tnsr_name, tfx.tensor_name(tnsr.op)),
+                   description='get tensor name from op (no graph)')
+    yield TestCase(data=(tnsr_name, tfx.tensor_name(tnsr.op, graph)),
+                   description='get tensor name from op (with graph)')
+    yield TestCase(data=(tnsr_name, tfx.tensor_name(op_name)),
+                   description='get tensor name from op name (no graph)')
+    yield TestCase(data=(tnsr_name, tfx.tensor_name(op_name, graph)),
+                   description='get tensor name from op name (with graph)')
+
+    # Test for get_tensor
+    yield TestCase(data=(tnsr, tfx.get_tensor(tnsr, graph)),
+                   description='get tensor from tensor')
+    yield TestCase(data=(tnsr, tfx.get_tensor(tnsr_name, graph)),
+                   description='get tensor from tensor name')
+    yield TestCase(data=(tnsr, tfx.get_tensor(tnsr.op, graph)),
+                   description='get tensor from op')
+    yield TestCase(data=(tnsr, tfx.get_tensor(op_name, graph)),
+                   description='get tensor from op name')
+
+    # Test for get_op
+    yield TestCase(data=(tnsr.op, tfx.get_op(tnsr, graph)),
+                   description='get op from tensor')
+    yield TestCase(data=(tnsr.op, tfx.get_op(tnsr_name, graph)),
+                   description='get op from tensor name')
+    yield TestCase(data=(tnsr.op, tfx.get_op(tnsr.op, graph)),
+                   description='get op from op')
+    yield TestCase(data=(tnsr.op, tfx.get_op(op_name, graph)),
+                   description='test op from op name')
+
+
+class TFeXtensionGraphUtilsTest(PythonUnitTestCase):
+    @parameterized.expand(_gen_tensor_op_string_input_tests)
+    def test_valid_tensor_op_name_inputs(self, data, description):
+        """ Must get correct names from valid graph element names """
+        name_a, name_b = data
+        self.assertEqual(name_a, name_b, msg=description)
+
+    @parameterized.expand(_gen_invalid_tensor_or_op_input_with_wrong_types)
+    def test_invalid_tensor_name_inputs_with_wrong_types(self, data, description):
+        """ `tfx.tensor_name` must raise TypeError when provided wrong types """
+        with self.assertRaises(TypeError, msg=description):
+            tfx.tensor_name(data)
+
+    @parameterized.expand(_gen_invalid_tensor_or_op_input_with_wrong_types)
+    def test_invalid_op_name_inputs_with_wrong_types(self, data, description):
+        """ `tfx.op_name` must raise TypeError when provided wrong types """
+        with self.assertRaises(TypeError, msg=description):
+            tfx.op_name(data)
+
+    @parameterized.expand(_gen_invalid_tensor_or_op_input_with_wrong_types)
+    def test_invalid_op_inputs_with_wrong_types(self, data, description):
+        """ `tfx.get_op` (on a new graph) must raise TypeError when provided wrong types """
+        with self.assertRaises(TypeError, msg=description):
+            tfx.get_op(data, tf.Graph())
+
+    @parameterized.expand(_gen_invalid_tensor_or_op_input_with_wrong_types)
+    def test_invalid_tensor_inputs_with_wrong_types(self, data, description):
+        """ `tfx.get_tensor` (on a new graph) must raise TypeError when provided wrong types """
+        with self.assertRaises(TypeError, msg=description):
+            tfx.get_tensor(data, tf.Graph())
+
+    @parameterized.expand(_gen_valid_tensor_op_input_combos)
+    def test_valid_tensor_op_object_inputs(self, data, description):
+        """ Must get correct graph elements from valid graph elements or their names """
+        tfobj_or_name_a, tfobj_or_name_b = data
+        self.assertEqual(tfobj_or_name_a, tfobj_or_name_b, msg=description)
+
+    @parameterized.expand(_gen_invalid_tensor_or_op_with_graph_pairing)
+    def test_invalid_tensor_op_object_graph_pairing(self, data, description):
+        """ Calling `data` must fail when the tensor/op is not associated with the given graph """
+        with self.assertRaises((KeyError, AssertionError, TypeError), msg=description):
+            data()
diff --git a/python/tests/tests.py b/python/tests/tests.py index 9492a07b..4bf9d65d 100644 --- a/python/tests/tests.py +++ b/python/tests/tests.py @@ -34,21 +34,32 @@ class PythonUnitTestCase(unittest.TestCase): # This class is created to avoid replicating this logic in various places. pass -class SparkDLTestCase(unittest.TestCase): +class TestSparkContext(object): @classmethod - def setUpClass(cls): + def setup_env(cls): cls.sc = SparkContext('local[*]', cls.__name__) cls.sql = SQLContext(cls.sc) cls.session = SparkSession.builder.getOrCreate() @classmethod - def tearDownClass(cls): + def tear_down_env(cls): cls.session.stop() cls.session = None cls.sc.stop() cls.sc = None cls.sql = None + +class SparkDLTestCase(TestSparkContext, unittest.TestCase): + + @classmethod + def setUpClass(cls): + cls.setup_env() + + @classmethod + def tearDownClass(cls): + cls.tear_down_env() + def assertDfHasCols(self, df, cols = []): map(lambda c: self.assertIn(c, df.columns), cols)
diff --git a/python/tests/transformers/tf_transformer_test.py b/python/tests/transformers/tf_transformer_test.py
new file mode 100644
index 00000000..849a84d7
--- /dev/null
+++ b/python/tests/transformers/tf_transformer_test.py
@@ -0,0 +1,146 @@
+# Copyright 2017 Databricks, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from __future__ import absolute_import, division, print_function
+
+import numpy as np
+import tensorflow as tf
+
+from pyspark.sql.types import Row
+
+import tensorframes as tfs
+
+import sparkdl.graph.utils as tfx
+from sparkdl.graph.input import TFInputGraph
+from sparkdl.transformers.tf_tensor import TFTransformer
+
+from ..tests import SparkDLTestCase
+
+class TFTransformerTests(SparkDLTestCase):
+    def test_graph_novar(self):
+        """ Output of a `TFInputGraph.fromGraph` transformer must match a direct graph run """
+        transformer = _build_transformer(
+            lambda session: TFInputGraph.fromGraph(
+                session.graph, session, [_tensor_input_name], [_tensor_output_name]))
+        gin = transformer.getTFInputGraph()
+        local_features = _build_local_features()
+        expected = _get_expected_result(gin, local_features)
+        dataset = self.session.createDataFrame(local_features)
+        _check_transformer_output(transformer, dataset, expected)
+
+
+# The name of the input tensor
+_tensor_input_name = "input_tensor"
+# The name of the output tensor (scalar)
+_tensor_output_name = "output_tensor"
+# The size of the input tensor
+_tensor_size = 3
+# Input mapping for the Transformer
+_input_mapping = {'inputCol': tfx.tensor_name(_tensor_input_name)}
+# Output mapping for the Transformer
+_output_mapping = {tfx.tensor_name(_tensor_output_name): 'outputCol'}
+# Numerical threshold
+_all_close_tolerance = 1e-5
+
+
+def _build_transformer(gin_function):
+    """
+    Makes a session and a default graph, loads the simple graph into it, and then calls
+    gin_function(session) to build the :py:obj:`TFInputGraph` object.
+    Return the :py:obj:`TFTransformer` created from it.
+    """
+    graph = tf.Graph()
+    with tf.Session(graph=graph) as sess, graph.as_default():
+        _build_graph(sess)
+        gin = gin_function(sess)
+
+    return TFTransformer(tfInputGraph=gin,
+                         inputMapping=_input_mapping,
+                         outputMapping=_output_mapping)
+
+
+def _build_graph(sess):
+    """
+    Adds the test computation to the current default graph; the `sess` argument itself is unused.
+
+    Takes float64 rows with `_tensor_size` columns and returns the max of each row (float64).
+    """
+    x = tf.placeholder(tf.float64, shape=[None, _tensor_size], name=_tensor_input_name)
+    _ = tf.reduce_max(x, axis=1, name=_tensor_output_name)
+
+def _build_local_features():
+    """
+    Build 100 local rows: an `idx` plus `_tensor_size` random floats per `_input_mapping` column.
+    """
+    local_features = []
+    # Fixed seed, so that the generated features are the same on every run
+    np.random.seed(997)
+    for idx in range(100):
+        _dict = {'idx': idx}
+        for colname in _input_mapping:
+            _dict[colname] = np.random.randn(_tensor_size).tolist()
+
+        local_features.append(Row(**_dict))
+
+    return local_features
+
+def _get_expected_result(gin, local_features):
+    """
+    Run the graph of the :py:obj:`TFInputGraph` locally, row by row, for the expected results.
+    :param gin: a :py:obj:`TFInputGraph`
+    :param local_features: rows holding one value list per `_input_mapping` column
+    :return: expected results in a flat NumPy array, in the order of `local_features`
+    """
+    graph = tf.Graph()
+    with tf.Session(graph=graph) as sess, graph.as_default():
+        tf.import_graph_def(gin.graph_def, name='')
+
+        # Build the results
+        _results = []
+        # The output tensors are the same for every row: look them up once, outside the loop
+        fetches = [tfx.get_tensor(tnsr_name, graph) for tnsr_name in _output_mapping]
+        for row in local_features:
+            feed_dict = {}
+            for colname, tnsr_name in _input_mapping.items():
+                tnsr = tfx.get_tensor(tnsr_name, graph)
+                feed_dict[tnsr] = np.array(row[colname])[np.newaxis, :]
+
+            curr_res = sess.run(fetches, feed_dict=feed_dict)
+            _results.append(np.ravel(curr_res))
+
+        expected = np.hstack(_results)
+
+    return expected
+
+def _check_transformer_output(transformer, dataset, expected):
+    """
+    Transform the Spark `dataset` and assert the flattened output matches the `expected` array.
+    """
+    analyzed_df = tfs.analyze(dataset)
+    out_df = transformer.transform(analyzed_df)
+
+    # Collect transformed values
+    out_colnames = list(_output_mapping.values())
+    _results = []
+    for row in out_df.select(out_colnames).collect():
+        curr_res = [row[colname] for colname in out_colnames]
+        _results.append(np.ravel(curr_res))
+    out_tgt = np.hstack(_results)
+
+    # Compare shapes first: subtracting arrays of different shapes would raise or broadcast
+    assert expected.shape == out_tgt.shape, 'shape {} != {}'.format(expected.shape, out_tgt.shape)
+    max_diff = np.max(np.abs(expected - out_tgt))
+    err_msg = 'not close => shape {} != {}, max_diff {} > {}'.format(
+        expected.shape, out_tgt.shape, max_diff, _all_close_tolerance)
+    assert np.allclose(expected, out_tgt, atol=_all_close_tolerance), err_msg
diff --git a/python/tests/udf/keras_sql_udf_test.py b/python/tests/udf/keras_sql_udf_test.py index d1473b3c..5c67c854 100644 --- a/python/tests/udf/keras_sql_udf_test.py +++ b/python/tests/udf/keras_sql_udf_test.py @@ -66,7 +66,7 @@ def test_simple_keras_udf(self): makeGraphUDF(issn.graph, 'my_keras_model_udf', model.outputs, - {tfx.op_name(issn.graph, model.inputs[0]): 'image_col'}) + {tfx.op_name(model.inputs[0], issn.graph): 'image_col'}) # Run the training procedure # Export the graph in this IsolatedSession as a GraphFunction # gfn = issn.asGraphFunction(model.inputs, model.outputs) @@ -168,4 +168,3 @@ def test_map_blocks_sql_1(self): data2 = df2.collect() assert len(data2) == 5, data2 assert data2[0].z == 3.0, data2 -