Skip to content

Commit f8ab1ca

Browse files
committed
(feat) tf_deconv_module 테스트코드 추가
1 parent e82b660 commit f8ab1ca

File tree

1 file changed

+310
-0
lines changed

1 file changed

+310
-0
lines changed

testcodes/test_deconv_modules.py

Lines changed: 310 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,310 @@
1+
# Copyright 2018 Jaewook Kang (jwkang10@gmail.com)
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
#
6+
# http://www.apache.org/licenses/LICENSE-2.0
7+
#
8+
# Unless required by applicable law or agreed to in writing, software
9+
# distributed under the License is distributed on an "AS IS" BASIS,
10+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
# See the License for the specific language governing permissions and
12+
# limitations under the License.
13+
# ===================================================================================
14+
# -*- coding: utf-8 -*-
15+
from __future__ import absolute_import
16+
from __future__ import division
17+
from __future__ import print_function
18+
19+
20+
import numpy as np
21+
import six
22+
from datetime import datetime
23+
from os import getcwd
24+
import sys
25+
sys.path.insert(0,getcwd())
26+
sys.path.insert(0,getcwd()+'/testcodes')
27+
28+
import tensorflow as tf
29+
import tensorflow.contrib.slim as slim
30+
from test_util import create_test_input
31+
32+
# module import
33+
from tf_deconv_module import get_nearest_neighbor_unpool2d_module
34+
from tf_deconv_module import get_transconv_unpool_module
35+
36+
37+
class ModuleEndpointName(object):
38+
39+
def __init__(self,deconv_type,input_shape,output_shape,layer_index=0):
40+
41+
input_shape = input_shape,
42+
output_shape = output_shape
43+
if deconv_type == 'conv2dtrans_unpool':
44+
self.name_list = ['unitest'+ str(layer_index) + '/conv2dtrans_unpool', 'conv2dtrans_unpool_out']
45+
46+
self.shape_dict = {self.name_list[0]:output_shape}
47+
48+
49+
50+
51+
class ModelTestConfig(object):
52+
53+
def __init__(self):
54+
55+
self.is_trainable = True
56+
self.unpool_weights_initializer = tf.contrib.layers.xavier_initializer()
57+
self.unpool_weights_regularizer = tf.contrib.layers.l2_regularizer(4E-5)
58+
self.unpool_biases_initializer = slim.init_ops.zeros_initializer()
59+
self.unpool_normalizer_fn = slim.batch_norm
60+
self.unpool_activation_fn = tf.nn.relu6
61+
62+
# batch_norm
63+
self.unpool_batch_norm_decay = 0.999
64+
self.unpool_batch_norm_fused = True
65+
66+
67+
68+
class ModuleTest(tf.test.TestCase):
69+
70+
def _get_deconv_module(self,inputs,
71+
unpool_rate,
72+
module_type,
73+
layer_index=0,
74+
scope=None,
75+
model_config=None):
76+
77+
scope = scope + str(layer_index)
78+
net = inputs
79+
80+
with tf.name_scope(name=scope,default_name='test_module',values=[inputs]):
81+
82+
if module_type == 'nearest_neighbor_unpool':
83+
net = get_nearest_neighbor_unpool2d_module(inputs=net,
84+
unpool_rate=unpool_rate,
85+
scope =scope)
86+
elif module_type == 'conv2dtrans_unpool':
87+
net = get_transconv_unpool_module(inputs=net,
88+
unpool_rate=unpool_rate,
89+
model_config=model_config,
90+
scope=scope)
91+
return net
92+
93+
94+
95+
96+
97+
def test_nearest_neighbor_unpool(self):
98+
99+
TEST_MODULE_NAME = 'nearest_neighbor_unpool'
100+
scope = 'unitest'
101+
102+
input_width = 2
103+
input_height = 2
104+
input_shape = [1, input_height,input_width,1]
105+
106+
107+
x = tf.to_float([[0, 1],
108+
[2, 3]])
109+
x = tf.reshape(x,shape=input_shape)
110+
111+
112+
y_unpool2_test1_expected = tf.to_float([[0,0,1,1],
113+
[0,0,1,1],
114+
[2,2,3,3],
115+
[2,2,3,3]])
116+
117+
y_unpool2_test1_expected = tf.reshape(y_unpool2_test1_expected,
118+
shape=[1,input_height*2,input_width*2,1])
119+
120+
y_unpool3_test1_expected = tf.to_float([[0, 0, 0, 1, 1, 1],
121+
[0, 0, 0, 1, 1, 1],
122+
[0, 0, 0, 1, 1, 1],
123+
[2, 2, 2, 3, 3, 3],
124+
[2, 2, 2, 3, 3, 3],
125+
[2, 2, 2, 3, 3, 3]])
126+
127+
y_unpool3_test1_expected = tf.reshape(y_unpool3_test1_expected,
128+
shape=[1,input_height*3,input_width*3,1])
129+
130+
y_unpool2_test1 = self._get_deconv_module(inputs=x,
131+
unpool_rate=2,
132+
module_type=TEST_MODULE_NAME,
133+
layer_index=0,
134+
scope=scope)
135+
136+
y_unpool3_test1 = self._get_deconv_module(inputs=x,
137+
unpool_rate=3,
138+
module_type=TEST_MODULE_NAME,
139+
layer_index=1,
140+
scope=scope)
141+
142+
with self.test_session() as sess:
143+
print('--------------------------------------------')
144+
print ('[tfTest] run test_nearest_neighbor_unpool()')
145+
sess.run(tf.global_variables_initializer())
146+
self.assertAllClose(y_unpool2_test1.eval(),y_unpool2_test1_expected.eval())
147+
self.assertAllClose(y_unpool3_test1.eval(),y_unpool3_test1_expected.eval())
148+
149+
# print ('[test1] Result of unpool rate2 = %s' % y_unpool2_test1.eval())
150+
# print ('[test1] Expected of unpool rate2 = %s' % y_unpool2_test1_expected.eval())
151+
# print ('[test1] Result of unpool rate3 = %s' % y_unpool3_test1.eval())
152+
# print ('[test1] Expected of unpool rate3 = %s' % y_unpool3_test1_expected.eval())
153+
154+
print ('[test1] input shape of x = %s'% x.get_shape().as_list())
155+
print ('[test1] output shape of y_unpool2 = %s' % y_unpool2_test1.get_shape().as_list())
156+
print ('[test1] output shape of y_unpool3 = %s' % y_unpool3_test1.get_shape().as_list())
157+
158+
159+
160+
161+
162+
163+
def test_transconv_unpool_name_shape(self):
164+
scope = 'unitest'
165+
166+
model_config = ModelTestConfig()
167+
TEST_MODULE_NAME = 'conv2dtrans_unpool'
168+
169+
with tf.name_scope(name=scope):
170+
input_width = 2
171+
input_height = 2
172+
input_shape = [None, input_height,input_width,1]
173+
unpool_rate = 3
174+
175+
expected_output_shape = [input_shape[0],
176+
input_shape[1]*unpool_rate,
177+
input_shape[2]*unpool_rate,
178+
input_shape[3]]
179+
180+
inputs = create_test_input(batchsize= input_shape[0],
181+
heightsize=input_shape[1],
182+
widthsize =input_shape[2],
183+
channelnum= input_shape[3])
184+
185+
186+
y_unpool2, midpoint= self._get_deconv_module(inputs=inputs,
187+
unpool_rate=unpool_rate,
188+
module_type=TEST_MODULE_NAME,
189+
model_config=model_config,
190+
scope=scope)
191+
192+
expected_midpoint = ModuleEndpointName(deconv_type=TEST_MODULE_NAME,
193+
input_shape=input_shape,
194+
output_shape=expected_output_shape)
195+
196+
197+
print('------------------------------------------------')
198+
print('[tfTest] run test_transconv_unpool_name_shape()')
199+
print('[tfTest] midpoint name and shape')
200+
print('[tfTest] unpool rate = %s' % unpool_rate)
201+
202+
self.assertItemsEqual(midpoint.keys(), expected_midpoint.name_list)
203+
204+
for name, shape in six.iteritems(expected_midpoint.shape_dict):
205+
print ('%s : shape = %s' %(name,shape))
206+
self.assertListEqual(midpoint[name].get_shape().as_list(),shape)
207+
208+
209+
210+
211+
def test_transconv_unknown_batchsize_shape(self):
212+
'''
213+
this func check the below test case:
214+
- when a module is built without specifying batch_norm size,
215+
check whether the model output has a proper batch_size given by an input
216+
'''
217+
scope = 'unitest'
218+
219+
model_config = ModelTestConfig()
220+
TEST_MODULE_NAME = 'conv2dtrans_unpool'
221+
batch_size = 1
222+
223+
input_width = 2
224+
input_height = 2
225+
input_shape = [None, input_height,input_width,1]
226+
unpool_rate = 3
227+
228+
module_graph = tf.Graph()
229+
with module_graph.as_default():
230+
inputs = create_test_input(batchsize= input_shape[0],
231+
heightsize=input_shape[1],
232+
widthsize =input_shape[2],
233+
channelnum= input_shape[3])
234+
235+
module_output, midpoint= self._get_deconv_module(inputs=inputs,
236+
unpool_rate=unpool_rate,
237+
module_type=TEST_MODULE_NAME,
238+
model_config=model_config,
239+
scope=scope)
240+
241+
expected_prefix = scope
242+
self.assertTrue(module_output.op.name.startswith(expected_prefix))
243+
self.assertListEqual(module_output.get_shape().as_list(),
244+
[None,
245+
input_shape[1] * unpool_rate,
246+
input_shape[2] * unpool_rate,
247+
input_shape[3]])
248+
249+
input_shape[0] = batch_size
250+
expected_output_shape = [input_shape[0],
251+
input_shape[1]*unpool_rate,
252+
input_shape[2]*unpool_rate,
253+
input_shape[3]]
254+
255+
256+
# which generate a sample image using np.arange()
257+
print('------------------------------------------------')
258+
print ('[tfTest] run test_transconv_unknown_batchsize_shape()')
259+
print('[tfTest] unpool rate = %s' % unpool_rate)
260+
261+
images = create_test_input( batchsize=input_shape[0],
262+
heightsize=input_shape[1],
263+
widthsize=input_shape[2],
264+
channelnum=input_shape[3])
265+
266+
# tensorboard graph summary =============
267+
now = datetime.utcnow().strftime("%Y%m%d%H%M%S")
268+
tb_logdir_path = getcwd() + '/tf_logs'
269+
tb_logdir = "{}/run-{}/".format(tb_logdir_path, now)
270+
271+
if not tf.gfile.Exists(tb_logdir_path):
272+
tf.gfile.MakeDirs(tb_logdir_path)
273+
274+
275+
# summary
276+
tb_summary_writer = tf.summary.FileWriter(logdir=tb_logdir)
277+
tb_summary_writer.add_graph(module_graph)
278+
tb_summary_writer.close()
279+
280+
281+
# write pbfile of graph_def
282+
savedir = getcwd() + '/pbfiles'
283+
if not tf.gfile.Exists(savedir):
284+
tf.gfile.MakeDirs(savedir)
285+
286+
pbfilename = TEST_MODULE_NAME + '.pb'
287+
pbtxtfilename = TEST_MODULE_NAME + '.pbtxt'
288+
289+
with self.test_session(graph=module_graph) as sess:
290+
sess.run(tf.global_variables_initializer())
291+
output = sess.run(module_output, {inputs: images.eval()})
292+
self.assertListEqual(list(output.shape),expected_output_shape)
293+
print ('[TfTest] output shape = %s' % list(output.shape))
294+
print ('[TfTest] expected_output_shape = %s' % expected_output_shape)
295+
296+
print("TF graph_def is saved in pb at %s." % savedir + pbfilename)
297+
tf.train.write_graph(graph_or_graph_def=sess.graph_def,
298+
logdir=savedir,
299+
name=pbfilename)
300+
tf.train.write_graph(graph_or_graph_def=sess.graph_def,
301+
logdir=savedir,
302+
name=pbtxtfilename,as_text=True)
303+
304+
305+
306+
307+
if __name__ == '__main__':
308+
tf.test.main()
309+
310+

0 commit comments

Comments
 (0)