Skip to content

Overview of All Changes in TFTransformer PRs #12

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 103 commits into
base: master
Choose a base branch
from
Open
Changes from 2 commits
Commits
Show all changes
103 commits
Select commit Hold shift + click to select a range
42c6e6e
flat param API impl
phi-dbq Aug 8, 2017
ecbefb9
support input graph scenarios
phi-dbq Aug 25, 2017
ab89bd2
(WIP) new interface implementation
phi-dbq Sep 9, 2017
8c7d72e
docs and cleanup
phi-dbq Sep 9, 2017
eb543c6
using tensorflow API instead of our utilities
phi-dbq Sep 10, 2017
4743bb9
automatic type conversion
phi-dbq Sep 10, 2017
622c788
cleanup
phi-dbq Sep 10, 2017
07f1cec
PR comments
phi-dbq Sep 11, 2017
692b0eb
(WIP) address comments
phi-dbq Sep 12, 2017
66d44e9
(WIP) respond to PR comments
phi-dbq Sep 13, 2017
9b3fe86
test refactor
phi-dbq Sep 13, 2017
8c32501
Merge remote-tracking branch 'upstream/master' into tf-1d-transformer
phi-dbq Sep 16, 2017
dbd9aaa
(wip) consolidating params
phi-dbq Sep 16, 2017
4572205
rebase upstream
phi-dbq Sep 16, 2017
1cc7591
import params fix
phi-dbq Sep 16, 2017
2fc6787
(wip) TFInputGraph impl
phi-dbq Sep 16, 2017
889df0a
(wip) moving to new API
phi-dbq Sep 17, 2017
86cd6d9
(wip) enable saved_model tests
phi-dbq Sep 17, 2017
ac09182
(wip) enable checkpoint test
phi-dbq Sep 17, 2017
6b22eed
(wip) enable multiple tensor tests
phi-dbq Sep 17, 2017
a3517d6
enable all tests
phi-dbq Sep 17, 2017
457a4c2
params and converters
phi-dbq Sep 18, 2017
323939a
tests
phi-dbq Sep 18, 2017
6e46073
Merge branch 'tf-transformer-part1' into api-tf-transformer
phi-dbq Sep 18, 2017
b232b3c
optimize graph for inference
phi-dbq Sep 18, 2017
d921366
more tests
phi-dbq Sep 19, 2017
0c8c219
update utils
phi-dbq Sep 19, 2017
522279a
tests
phi-dbq Sep 19, 2017
f4d938c
intro: TFInputGraph
phi-dbq Sep 19, 2017
cd3aa8d
tests
phi-dbq Sep 19, 2017
97b25c6
Baseline
phi-dbq Sep 21, 2017
07c58e6
allows setting TFInputGraph
phi-dbq Sep 21, 2017
269ad15
utilize test_input_graph for transformer tests
phi-dbq Sep 21, 2017
84a8138
enable all tests
phi-dbq Sep 21, 2017
7287ab7
fix style
phi-dbq Sep 21, 2017
467480e
Merge branch 'tf-transformer-part1' into tf-transformer-part2
phi-dbq Sep 21, 2017
4f11374
fix style
phi-dbq Sep 21, 2017
27f0617
Merge branch 'tf-transformer-part1' into tf-transformer-part2
phi-dbq Sep 21, 2017
7b6ec3a
autogen test cases
phi-dbq Sep 21, 2017
561f8e7
test refactoring
phi-dbq Sep 22, 2017
7248517
further refactor
phi-dbq Sep 23, 2017
2e8f7a1
Merge branch 'tf-transformer-part1' into tf-transformer-part3
phi-dbq Sep 23, 2017
e09027f
Merge branch 'tf-transformer-part1' into tf-transformer-part3
phi-dbq Sep 23, 2017
40caace
and so there is no helper classes
phi-dbq Sep 23, 2017
93e659d
Merge branch 'tf-transformer-part1' into tf-transformer-part2
phi-dbq Sep 23, 2017
6e880ce
Merge branch 'tf-transformer-part3' into tf-transformer-part4
phi-dbq Sep 23, 2017
883321e
input graph
phi-dbq Sep 23, 2017
c72444b
docs
phi-dbq Sep 23, 2017
e963d11
and into more pieces
phi-dbq Sep 23, 2017
89e2a1d
Merge branch 'tf-transformer-part3' into tf-transformer-part4
phi-dbq Sep 23, 2017
f7a7d38
update converter and test cases
phi-dbq Sep 25, 2017
ce60629
class & docs
phi-dbq Sep 25, 2017
e0cf2ff
update docs
phi-dbq Sep 25, 2017
fcabcb6
using `parameterized` to simplify testing logic
phi-dbq Sep 25, 2017
77b3906
converter changes
phi-dbq Sep 26, 2017
76e9fb9
PR comments
phi-dbq Sep 28, 2017
66507f4
tf_image inputTensor default setter bug-fix
phi-dbq Sep 29, 2017
d239a5a
use type error, always
phi-dbq Sep 29, 2017
5947c9c
doc updates
phi-dbq Sep 29, 2017
f8d7930
Merge branch 'tf-transformer-part1' into tf-transformer-part2
phi-dbq Sep 29, 2017
202e7ea
refactoring tfx API
phi-dbq Sep 29, 2017
faf8cdd
Merge branch 'tf-transformer-part2' into tf-transformer-part3
phi-dbq Sep 29, 2017
6aa85b9
Merge branch 'tf-transformer-part3' into tf-transformer-part4
phi-dbq Sep 29, 2017
aef1661
refactoring tfx API
phi-dbq Sep 29, 2017
ead1ed6
test refactoring
phi-dbq Sep 29, 2017
cf72beb
Merge branch 'tf-transformer-part2' into tf-transformer-part3
phi-dbq Sep 29, 2017
20e2dbc
update tfx utils usage
phi-dbq Sep 29, 2017
20a5346
one way to build these tests
phi-dbq Sep 29, 2017
cf64708
tests refactored
phi-dbq Sep 30, 2017
c3b3a86
test cases in a single class
phi-dbq Sep 30, 2017
e47060f
shuffle things around
phi-dbq Sep 30, 2017
4e8f4e3
docs mostly
phi-dbq Sep 30, 2017
eaa5fa0
yapf'd
phi-dbq Sep 30, 2017
43d6583
consolidate tempdir creation
phi-dbq Oct 2, 2017
8b75d44
Address PR comments
phi-dbq Oct 2, 2017
055ce14
PR comments
phi-dbq Oct 2, 2017
2d48b32
(wip) utils test
phi-dbq Oct 2, 2017
742cdaf
a few more tests for utils
phi-dbq Oct 2, 2017
f0912fb
test update cont'd
phi-dbq Oct 2, 2017
ee3acf1
Merge branch 'tf-transformer-part2' into tf-transformer-part3
phi-dbq Oct 3, 2017
7f16396
Merge branch 'tf-transformer-part1' into tf-transformer-part3
phi-dbq Oct 3, 2017
f5107ad
(wip) PR comments
phi-dbq Oct 3, 2017
ac681b0
more tests
phi-dbq Oct 3, 2017
4d173c5
change test generator module name
phi-dbq Oct 3, 2017
85e0778
Merge branch 'tf-transformer-part3' into tf-transformer-part4
phi-dbq Oct 3, 2017
22754c9
tensor tests
phi-dbq Oct 3, 2017
0144b8c
tensor test update
phi-dbq Oct 3, 2017
a8531ec
buildCheckList name change and doc fixup
phi-dbq Oct 3, 2017
3c849f2
Merge branch 'tf-transformer-part1' into tf-transformer-part2
phi-dbq Oct 3, 2017
707697d
Merge branch 'tf-transformer-part2' into tf-transformer-part3
phi-dbq Oct 3, 2017
c6eb87c
Merge branch 'tf-transformer-part3' into tf-transformer-part4
phi-dbq Oct 3, 2017
d729528
PR comments
phi-dbq Oct 4, 2017
0c2eda1
PR comments
phi-dbq Oct 4, 2017
63967b4
PR comments
phi-dbq Oct 5, 2017
fe719b2
Merge branch 'tf-transformer-part2' into tf-transformer-part3
phi-dbq Oct 5, 2017
812f4d6
Merge branch 'tf-transformer-part3' into tf-transformer-part4
phi-dbq Oct 5, 2017
47d497c
TFTransformer Part-4 Test Refactor (#15)
phi-dbq Nov 18, 2017
a39b6d3
TFTransformer Part-3 Test Refactor (#14)
thunterdb Nov 18, 2017
07cc335
deleting original testing ideas
phi-dbq Nov 18, 2017
925fc0d
Merge branch 'tf-transformer-part3' into tf-transformer-part4
phi-dbq Nov 18, 2017
decdc8f
PR comments
phi-dbq Nov 22, 2017
91b9379
Merge branch 'tf-transformer-part3' into tf-transformer-part4
phi-dbq Nov 22, 2017
af95b74
PR comments
phi-dbq Nov 22, 2017
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
290 changes: 290 additions & 0 deletions python/tests/graph/test_import.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,290 @@
# Copyright 2017 Databricks, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from __future__ import absolute_import, division, print_function

import contextlib
import shutil
import numpy as np
import os
import tensorflow as tf
import tempfile
import glob

import sparkdl.graph.utils as tfx
from sparkdl.graph.input import TFInputGraph


class TestGraphImport(object):
def test_graph_novar(self):
gin = _build_graph_input(lambda session:
TFInputGraph.fromGraph(session.graph, session, [_tensor_input_name],
[_tensor_output_name]))
_check_input_novar(gin)

def test_graphdef_novar(self):
gin = _build_graph_input(lambda session:
TFInputGraph.fromGraphDef(session.graph.as_graph_def(),
[_tensor_input_name], [_tensor_output_name]))
_check_input_novar(gin)

def test_saved_model_novar(self):
with _make_temp_directory() as tmp_dir:
saved_model_dir = os.path.join(tmp_dir, 'saved_model')

def gin_fun(session):
_build_saved_model(session, saved_model_dir)
# Build the transformer from exported serving model
# We are using signatures, thus must provide the keys
return TFInputGraph.fromSavedModelWithSignature(saved_model_dir, _serving_tag,
_serving_sigdef_key)

gin = _build_graph_input(gin_fun)
_check_input_novar(gin)

def test_saved_graph_novar(self):
with _make_temp_directory() as tmp_dir:
saved_model_dir = os.path.join(tmp_dir, 'saved_model')

def gin_fun(session):
_build_saved_model(session, saved_model_dir)
return TFInputGraph.fromGraph(session.graph, session, [_tensor_input_name], [_tensor_output_name])

gin = _build_graph_input(gin_fun)
_check_input_novar(gin)

def test_checkpoint_sig_var(self):
with _make_temp_directory() as tmp_dir:
def gin_fun(session):
_build_checkpointed_model(session, tmp_dir)
return TFInputGraph.fromCheckpointWithSignature(tmp_dir, _serving_sigdef_key)

gin = _build_graph_input_var(gin_fun)
_check_input_novar(gin)

def test_checkpoint_nosig_var(self):
with _make_temp_directory() as tmp_dir:
def gin_fun(session):
_build_checkpointed_model(session, tmp_dir)
return TFInputGraph.fromCheckpoint(tmp_dir,
[_tensor_input_name], [_tensor_output_name])

gin = _build_graph_input_var(gin_fun)
_check_input_novar(gin)

def test_checkpoint_graph_var(self):
with _make_temp_directory() as tmp_dir:
def gin_fun(session):
_build_checkpointed_model(session, tmp_dir)
return TFInputGraph.fromGraph(session.graph, session,
[_tensor_input_name], [_tensor_output_name])

gin = _build_graph_input_var(gin_fun)
_check_input_novar(gin)

def test_graphdef_novar_2(self):
gin = _build_graph_input_2(lambda session:
TFInputGraph.fromGraphDef(session.graph.as_graph_def(),
[_tensor_input_name], [_tensor_output_name]))
_check_output_2(gin, np.array([1, 2, 3]), np.array([2, 2, 2]), 1)

def test_saved_graph_novar_2(self):
with _make_temp_directory() as tmp_dir:
saved_model_dir = os.path.join(tmp_dir, 'saved_model')

def gin_fun(session):
_build_saved_model(session, saved_model_dir)
return TFInputGraph.fromGraph(session.graph, session, [_tensor_input_name], [_tensor_output_name])

gin = _build_graph_input_2(gin_fun)
_check_output_2(gin, np.array([1, 2, 3]), np.array([2, 2, 2]), 1)

_serving_tag = "serving_tag"
_serving_sigdef_key = 'prediction_signature'
# The name of the input tensor
_tensor_input_name = "input_tensor"
# For testing graphs with 2 inputs
_tensor_input_name_2 = "input_tensor_2"
# The name of the output tensor (scalar)
_tensor_output_name = "output_tensor"
# The name of the variable
_tensor_var_name = "variable"
# The size of the input tensor
_tensor_size = 3


def _build_checkpointed_model(session, tmp_dir):
"""
Writes a model checkpoint in the given directory. The graph is assumed to be generated
with _build_graph_var.
"""
ckpt_path_prefix = os.path.join(tmp_dir, 'model_ckpt')
input_tensor = tfx.get_tensor(_tensor_input_name, session.graph)
output_tensor = tfx.get_tensor(_tensor_output_name, session.graph)
w = tfx.get_tensor(_tensor_var_name, session.graph)
saver = tf.train.Saver(var_list=[w])
_ = saver.save(session, ckpt_path_prefix, global_step=2702)
sig_inputs = {'input_sig': tf.saved_model.utils.build_tensor_info(input_tensor)}
sig_outputs = {'output_sig': tf.saved_model.utils.build_tensor_info(output_tensor)}
serving_sigdef = tf.saved_model.signature_def_utils.build_signature_def(
inputs=sig_inputs, outputs=sig_outputs)

# A rather contrived way to add signature def to a meta_graph
meta_graph_def = tf.train.export_meta_graph()

# Find the meta_graph file (there should be only one)
_ckpt_meta_fpaths = glob.glob('{}/*.meta'.format(tmp_dir))
assert len(_ckpt_meta_fpaths) == 1, \
'expected only one meta graph, but got {}'.format(','.join(_ckpt_meta_fpaths))
ckpt_meta_fpath = _ckpt_meta_fpaths[0]

# Add signature_def to the meta_graph and serialize it
# This will overwrite the existing meta_graph_def file
meta_graph_def.signature_def[_serving_sigdef_key].CopyFrom(serving_sigdef)
with open(ckpt_meta_fpath, mode='wb') as fout:
fout.write(meta_graph_def.SerializeToString())


def _build_saved_model(session, saved_model_dir):
"""
Saves a model in a file. The graph is assumed to be generated with _build_graph_novar.
"""
builder = tf.saved_model.builder.SavedModelBuilder(saved_model_dir)
input_tensor = tfx.get_tensor(_tensor_input_name, session.graph)
output_tensor = tfx.get_tensor(_tensor_output_name, session.graph)
sig_inputs = {'input_sig': tf.saved_model.utils.build_tensor_info(input_tensor)}
sig_outputs = {'output_sig': tf.saved_model.utils.build_tensor_info(output_tensor)}
serving_sigdef = tf.saved_model.signature_def_utils.build_signature_def(
inputs=sig_inputs, outputs=sig_outputs)

builder.add_meta_graph_and_variables(
session, [_serving_tag], signature_def_map={_serving_sigdef_key: serving_sigdef})
builder.save()


@contextlib.contextmanager
def _make_temp_directory():
temp_dir = tempfile.mkdtemp()
try:
yield temp_dir
finally:
shutil.rmtree(temp_dir)


def _build_graph_input(gin_function):
"""
Makes a session and a default graph, loads the simple graph into it, and then calls
gin_function(session) to return the graph input object
"""
graph = tf.Graph()
with tf.Session(graph=graph) as s, graph.as_default():
_build_graph()
return gin_function(s)


def _build_graph_input_2(gin_function):
"""
Makes a session and a default graph, loads the simple graph into it (graph_2), and then calls
gin_function(session) to return the graph input object
"""
graph = tf.Graph()
with tf.Session(graph=graph) as s, graph.as_default():
_build_graph_2()
return gin_function(s)


def _build_graph_input_var(gin_function):
"""
Makes a session and a default graph, loads the simple graph into it that contains a variable,
and then calls gin_function(session) to return the graph input object
"""
graph = tf.Graph()
with tf.Session(graph=graph) as s, graph.as_default():
_build_graph_var(s)
return gin_function(s)


def _build_graph():
"""
Given a session (implicitly), adds nodes of computations

It takes a vector input, with vec_size columns and returns an int32 scalar.
"""
x = tf.placeholder(tf.int32, shape=[_tensor_size], name=_tensor_input_name)
_ = tf.reduce_max(x, name=_tensor_output_name)


def _build_graph_2():
"""
Given a session (implicitly), adds nodes of computations with two inputs.

It takes a vector input, with vec_size columns and returns an int32 scalar.
"""
x1 = tf.placeholder(tf.int32, shape=[_tensor_size], name=_tensor_input_name)
x2 = tf.placeholder(tf.int32, shape=[_tensor_size], name=_tensor_input_name_2)
# Make sure that the inputs are not used in a symmetric manner.
_ = tf.reduce_max(x1 - x2, name=_tensor_output_name)


def _build_graph_var(session):
"""
Given a session, adds nodes that include one variable.
"""
x = tf.placeholder(tf.int32, shape=[_tensor_size], name=_tensor_input_name)
w = tf.Variable(tf.ones(shape=[_tensor_size], dtype=tf.int32), name=_tensor_var_name)
_ = tf.reduce_max(x * w, name=_tensor_output_name)
session.run(w.initializer)


def _check_input_novar(gin):
"""
Tests that the graph from _build_graph has been serialized in the InputGraph object.
"""
_check_output(gin, np.array([1, 2, 3]), 3)


def _check_output(gin, tf_input, expected):
"""
Takes a TFInputGraph object (assumed to have the input and outputs of the given
names above) and compares the outcome against some expected outcome.
"""
graph = tf.Graph()
graph_def = gin.graph_def
with tf.Session(graph=graph) as sess:
tf.import_graph_def(graph_def, name="")
tgt_feed = tfx.get_tensor(_tensor_input_name, graph)
tgt_fetch = tfx.get_tensor(_tensor_output_name, graph)
# Run on the testing target
tgt_out = sess.run(tgt_fetch, feed_dict={tgt_feed: tf_input})
# Working on integers, the calculation should be exact
assert np.all(tgt_out == expected), (tgt_out, expected)


# TODO: we could factorize with _check_output, but this is not worth the time doing it.
def _check_output_2(gin, tf_input1, tf_input2, expected):
"""
Takes a TFInputGraph object (assumed to have the input and outputs of the given
names above) and compares the outcome against some expected outcome.
"""
graph = tf.Graph()
graph_def = gin.graph_def
with tf.Session(graph=graph) as sess:
tf.import_graph_def(graph_def, name="")
tgt_feed1 = tfx.get_tensor(_tensor_input_name, graph)
tgt_feed2 = tfx.get_tensor(_tensor_input_name_2, graph)
tgt_fetch = tfx.get_tensor(_tensor_output_name, graph)
# Run on the testing target
tgt_out = sess.run(tgt_fetch, feed_dict={tgt_feed1: tf_input1, tgt_feed2: tf_input2})
# Working on integers, the calculation should be exact
assert np.all(tgt_out == expected), (tgt_out, expected)