Skip to content

Commit cd3aa8d

Browse files
committed
tests
1 parent f4d938c commit cd3aa8d

File tree

2 files changed

+223
-22
lines changed

2 files changed

+223
-22
lines changed

python/sparkdl/graph/input.py

Lines changed: 0 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -41,28 +41,6 @@ def _new_obj_internal(cls):
4141
obj.output_tensor_name_from_signature = None
4242
return obj
4343

44-
def translateInputMapping(self, input_mapping):
    """
    Rewrite a column-to-signature-key mapping as a column-to-tensor-name mapping.

    :param input_mapping: dict, or list of `(column_name, signature_key)` pairs.
    :return: dict from each column name to the tensor name stored under its
             signature key in `self.input_tensor_name_from_signature`.
    """
    sig_to_tensor = self.input_tensor_name_from_signature
    assert sig_to_tensor is not None
    # Dicts are flattened to (column, signature key) pairs; lists are used as given.
    pairs = list(input_mapping.items()) if isinstance(input_mapping, dict) else input_mapping
    assert isinstance(pairs, list)
    return {column: sig_to_tensor[key] for column, key in pairs}
54-
55-
def translateOutputMapping(self, output_mapping):
    """
    Rewrite a signature-key-to-column mapping as a tensor-name-to-column mapping.

    :param output_mapping: dict, or list of `(signature_key, column_name)` pairs.
    :return: dict from the tensor name stored under each signature key in
             `self.output_tensor_name_from_signature` to its column name.
    """
    sig_to_tensor = self.output_tensor_name_from_signature
    assert sig_to_tensor is not None
    # Dicts are flattened to (signature key, column) pairs; lists are used as given.
    pairs = list(output_mapping.items()) if isinstance(output_mapping, dict) else output_mapping
    assert isinstance(pairs, list)
    return {sig_to_tensor[key]: column for key, column in pairs}
65-
6644
@classmethod
6745
def fromGraph(cls, graph, sess, feed_names, fetch_names):
6846
"""
Lines changed: 223 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,223 @@
1+
# Copyright 2017 Databricks, Inc.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
#
15+
from __future__ import absolute_import, division, print_function
16+
17+
from contextlib import contextmanager
18+
from glob import glob
19+
import os
20+
import shutil
21+
import tempfile
22+
23+
import numpy as np
24+
import tensorflow as tf
25+
26+
from sparkdl.graph.input import *
27+
import sparkdl.graph.utils as tfx
28+
29+
from ..tests import PythonUnitTestCase
30+
31+
32+
class TFInputGraphTest(PythonUnitTestCase):
    """
    Check that `TFInputGraph` objects built through different entry points
    compute the same outputs as the graph they were built from.

    Each test builds a small reference model inside `_run_test_in_tf_session`
    and appends the `TFInputGraph` objects it creates to `self.input_graphs`.
    When the `with` block exits normally, every collected graph is imported
    next to the reference graph and both are run on the same random inputs.
    """

    # How many random batches each collected graph is compared on, and how
    # many rows each of those batches has.
    NUM_CHECK_ROUNDS = 10
    CHECK_BATCH_SIZE = 31

    def setUp(self):
        self.vec_size = 23
        self.num_samples = 107

        self.input_col = 'dfInputCol'
        self.input_op_name = 'tnsrOpIn'
        self.output_col = 'dfOutputCol'
        self.output_op_name = 'tnsrOpOut'

        # Populates input_mapping, output_mapping, feed_names and fetch_names.
        self.setup_iomap(replica=1)

        self.input_graphs = []
        self.test_case_results = []
        # Build a temporary directory, which might or might not be used by the test
        self.model_output_root = tempfile.mkdtemp()

    def tearDown(self):
        shutil.rmtree(self.model_output_root, ignore_errors=True)

    def setup_iomap(self, replica=1):
        """
        (Re)build the column/tensor mappings and the feed/fetch name lists.

        :param replica: when greater than 1, one input and one output entry is
                        generated per replica, each suffixed `_replicaNNN`;
                        otherwise a single un-suffixed input/output pair is used.
        """
        self.input_mapping = {}
        self.feed_names = []
        self.output_mapping = {}
        self.fetch_names = []

        if replica > 1:
            for i in range(replica):
                colname = '{}_replica{:03d}'.format(self.input_col, i)
                tnsr_op_name = '{}_replica{:03d}'.format(self.input_op_name, i)
                self.input_mapping[colname] = tnsr_op_name
                self.feed_names.append(tnsr_op_name + ':0')

                colname = '{}_replica{:03d}'.format(self.output_col, i)
                tnsr_op_name = '{}_replica{:03d}'.format(self.output_op_name, i)
                self.output_mapping[tnsr_op_name] = colname
                self.fetch_names.append(tnsr_op_name + ':0')
        else:
            self.input_mapping[self.input_col] = self.input_op_name
            self.feed_names.append(self.input_op_name + ':0')
            self.output_mapping[self.output_op_name] = self.output_col
            self.fetch_names.append(self.output_op_name + ':0')

    @contextmanager
    def _run_test_in_tf_session(self):
        """
        [THIS IS NOT A TEST]: encapsulate general test workflow

        Yields a session bound to a fresh graph that is also the default graph.
        After the caller's `with` body finishes without raising, every entry of
        `self.input_graphs` is imported into that graph under its own name scope
        and compared against the reference tensors named `self.input_op_name`
        and `self.output_op_name`.
        """
        # Build the TensorFlow graph
        graph = tf.Graph()
        with tf.Session(graph=graph) as sess, graph.as_default():
            # Build test graph and transformers from here
            yield sess

            ref_feed = tfx.get_tensor(graph, self.input_op_name)
            ref_fetch = tfx.get_tensor(graph, self.output_op_name)

            def check_input_graph(tgt_gdef, test_idx):
                # A distinct name scope per graph keeps the imported ops from
                # clashing with the reference ops or with each other.
                namespace = 'TEST_TGT_NS{:03d}'.format(test_idx)
                tf.import_graph_def(tgt_gdef, name=namespace)
                tgt_feed = tfx.get_tensor(graph, '{}/{}'.format(namespace, self.input_op_name))
                tgt_fetch = tfx.get_tensor(graph, '{}/{}'.format(namespace, self.output_op_name))

                for _ in range(self.NUM_CHECK_ROUNDS):
                    local_data = np.random.randn(self.CHECK_BATCH_SIZE, self.vec_size)
                    ref_out = sess.run(ref_fetch, feed_dict={ref_feed: local_data})
                    tgt_out = sess.run(tgt_fetch, feed_dict={tgt_feed: local_data})
                    self.assertTrue(
                        np.allclose(ref_out, tgt_out),
                        msg='input graph #{} disagrees with the reference graph'.format(test_idx))

            for test_idx, input_graph in enumerate(self.input_graphs):
                check_input_graph(input_graph.graph_def, test_idx)

    def _build_weighted_mean_model(self, sess):
        """
        [THIS IS NOT A TEST]: add `mean(x * w, axis=1)` to the default graph.

        The placeholder is named `self.input_op_name`, the output
        `self.output_op_name` and the variable `varW`; the variable is
        initialized in `sess` before returning.

        :return: tuple `(x, w, z)` of placeholder, variable and output tensor.
        """
        x = tf.placeholder(tf.float64, shape=[None, self.vec_size], name=self.input_op_name)
        w = tf.Variable(tf.random_normal([self.vec_size], dtype=tf.float64),
                        dtype=tf.float64, name='varW')
        z = tf.reduce_mean(x * w, axis=1, name=self.output_op_name)
        sess.run(w.initializer)
        return x, w, z

    def _build_serving_sigdef(self, x, z):
        """
        [THIS IS NOT A TEST]: build a signature_def exposing `x` under the key
        'input_sig' and `z` under the key 'output_sig'.
        """
        return tf.saved_model.signature_def_utils.build_signature_def(
            inputs={'input_sig': tf.saved_model.utils.build_tensor_info(x)},
            outputs={'output_sig': tf.saved_model.utils.build_tensor_info(z)})

    def test_build_from_tf_graph(self):
        """ Build TFInputGraph from tf.Graph """
        with self._run_test_in_tf_session() as sess:
            # Begin building graph
            x = tf.placeholder(tf.float64, shape=[None, self.vec_size], name=self.input_op_name)
            _ = tf.reduce_mean(x, axis=1, name=self.output_op_name)

            gin = TFInputGraph.fromGraph(sess.graph, sess, self.feed_names, self.fetch_names)
            self.input_graphs.append(gin)
            # End building graph

    def test_build_from_saved_model(self):
        """ Build TFInputGraph from saved model """
        # Setup saved model export directory
        saved_model_root = self.model_output_root
        saved_model_dir = os.path.join(saved_model_root, 'saved_model')
        serving_tag = "serving_tag"
        serving_sigdef_key = 'prediction_signature'
        builder = tf.saved_model.builder.SavedModelBuilder(saved_model_dir)

        with self._run_test_in_tf_session() as sess:
            x, _, z = self._build_weighted_mean_model(sess)
            serving_sigdef = self._build_serving_sigdef(x, z)

            builder.add_meta_graph_and_variables(
                sess,
                [serving_tag],
                signature_def_map={serving_sigdef_key: serving_sigdef})
            builder.save()

            # Build the input graph from the exported serving model
            # We are using signatures, thus must provide the keys
            gin = TFInputGraph.fromSavedModelWithSignature(
                saved_model_dir, serving_tag, serving_sigdef_key)
            self.input_graphs.append(gin)

            # Build the input graph from the exported serving model
            # We are not using signatures, thus must provide tensor/operation names
            gin = TFInputGraph.fromSavedModel(
                saved_model_dir, serving_tag, self.feed_names, self.fetch_names)
            self.input_graphs.append(gin)

            gin = TFInputGraph.fromGraph(
                sess.graph, sess, self.feed_names, self.fetch_names)
            self.input_graphs.append(gin)

    def test_build_from_checkpoint(self):
        """ Build TFInputGraph from a model checkpoint """
        model_ckpt_dir = self.model_output_root
        ckpt_path_prefix = os.path.join(model_ckpt_dir, 'model_ckpt')
        serving_sigdef_key = 'prediction_signature'

        with self._run_test_in_tf_session() as sess:
            x, w, z = self._build_weighted_mean_model(sess)
            saver = tf.train.Saver(var_list=[w])
            saver.save(sess, ckpt_path_prefix, global_step=2702)

            # Prepare the signature_def
            serving_sigdef = self._build_serving_sigdef(x, z)

            # A rather contrived way to add signature def to a meta_graph
            meta_graph_def = tf.train.export_meta_graph()

            # Find the meta_graph file (there should be only one)
            _ckpt_meta_fpaths = glob(os.path.join(model_ckpt_dir, '*.meta'))
            self.assertEqual(len(_ckpt_meta_fpaths), 1, msg=','.join(_ckpt_meta_fpaths))
            ckpt_meta_fpath = _ckpt_meta_fpaths[0]

            # Add signature_def to the meta_graph and serialize it
            # This will overwrite the existing meta_graph_def file
            meta_graph_def.signature_def[serving_sigdef_key].CopyFrom(serving_sigdef)
            with open(ckpt_meta_fpath, mode='wb') as fout:
                fout.write(meta_graph_def.SerializeToString())

            # Build the input graph from the checkpoint
            # We are using signatures, thus must provide the keys
            gin = TFInputGraph.fromCheckpointWithSignature(
                model_ckpt_dir, serving_sigdef_key)
            self.input_graphs.append(gin)

            # Input graph without using signature_def
            gin = TFInputGraph.fromCheckpoint(model_ckpt_dir, self.feed_names, self.fetch_names)
            self.input_graphs.append(gin)

            gin = TFInputGraph.fromGraph(
                sess.graph, sess, self.feed_names, self.fetch_names)
            self.input_graphs.append(gin)

0 commit comments

Comments
 (0)