-
Notifications
You must be signed in to change notification settings - Fork 4
/
ConvAutoencoder.py
318 lines (258 loc) · 12.4 KB
/
ConvAutoencoder.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
"""
TF Convolutional Autoencoder
Arash Saber Tehrani - May 2017
Reference: https://github.com/arashsaber/Deep-Convolutional-AutoEncoder
Modified David Yu - July 2018
Reference: https://github.com/MrDavidYu/TF_Convolutional_Autoencoder
Add ons:
1. Allows for custom .jpg input
2. Checkpoint save/restore
3. TensorBoard logs for input/output images
4. Input autorescaling
5. ReLU activation replaced by LeakyReLU
"""
import os
import re
import scipy.misc
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from glob import glob
# --- Training hyper-parameters and I/O paths ---
num_examples = 669  # total number of training images expected under data_dir
batch_size = 30
n_epochs = 1000
save_steps = 500 # Number of training batches between checkpoint saves
checkpoint_dir = "./ckpt/"
model_name = "ConvAutoEnc.model"
logs_dir = "./logs/run1/"  # TensorBoard event files are written here
# Fetch input data (faces/trees/imgs): collect every .jpg under data_dir
data_dir = "./data/celebG/"
data_path = os.path.join(data_dir, '*.jpg')
data = glob(data_path)
if len(data) == 0:
    # Fail fast: there is nothing to train on
    raise Exception("[!] No data found in '" + data_path+ "'")
'''
Some util functions from https://github.com/carpedm20/DCGAN-tensorflow
'''
def path_to_img(path, grayscale = False):
    """Load the image at *path* as a float ndarray.

    Args:
        path: Path to an image file readable by scipy.misc.imread.
        grayscale: If True, flatten the image to a single luminance channel.

    Returns:
        np.ndarray of floats — (H, W) when grayscale, else (H, W, C).
    """
    # Fix: np.float was deprecated in NumPy 1.20 and removed in 1.24; the
    # builtin float is the documented replacement (np.float was an alias for it).
    if (grayscale):
        return scipy.misc.imread(path, flatten = True).astype(float)
    else:
        return scipy.misc.imread(path).astype(float)
def center_crop(x, crop_h, crop_w,
                resize_h=64, resize_w=64):
    """Cut the central crop_h x crop_w window out of x, then resize it."""
    if crop_w is None:
        crop_w = crop_h
    height, width = x.shape[:2]
    top = int(round((height - crop_h)/2.))
    left = int(round((width - crop_w)/2.))
    window = x[top:top + crop_h, left:left + crop_w]
    return scipy.misc.imresize(window, [resize_h, resize_w])
def transform(image, input_height, input_width,
              resize_height=48, resize_width=48, crop=True):
    """Optionally center-crop, then resize and rescale an image to [-1, 1]."""
    if crop:
        resized = center_crop(image, input_height, input_width,
                              resize_height, resize_width)
    else:
        resized = scipy.misc.imresize(image, [resize_height, resize_width])
    # Map the uint8 pixel range [0, 255] onto [-1, 1]
    return np.array(resized)/127.5 - 1.
def autoresize(image_path, input_height, input_width,
               resize_height=48, resize_width=48,
               crop=True, grayscale=False):
    """Read an image from disk and normalize it into network-input form."""
    raw = path_to_img(image_path, grayscale)
    return transform(raw, input_height, input_width,
                     resize_height, resize_width, crop)
np.random.shuffle(data)
imread_img = path_to_img(data[0]) # test read an image
# Determine the channel count from the sample we already loaded instead of
# re-reading the same file from disk a second time (original called
# path_to_img(data[0]) again here).
if len(imread_img.shape) >= 3: # non-grayscale images carry a trailing channel axis
    c_dim = imread_img.shape[-1]
else:
    c_dim = 1
is_grayscale = (c_dim == 1)
'''
tf Graph Input
'''
# Fixed-size input batch: [batch, height, width, channels] = [?, 48, 48, 3]
x = tf.placeholder(tf.float32, [None, 48, 48, 3], name='InputData')
if __debug__:
    # Startup sanity logging; suppressed when Python runs with -O
    print("Reading input from:" + data_dir)
    print("Input image shape:" + str(imread_img.shape))
    print("Assigning input tensor of shape:" + str(x.shape))
    print("Writing checkpoints to:" + checkpoint_dir)
    print("Writing TensorBoard logs to:" + logs_dir)
# strides = [Batch, Height, Width, Channels] in default NHWC data_format. Batch and Channels
# must always be set to 1. If channels is set to 3, then we would increment the index for the
# color channel by 3 every time we convolve the filter. So this means we would only use one of
# the channels and skip the other two. If we change the Batch number then it means some images
# in the batch are skipped.
#
# To calculate the size of the output of CONV layer:
# OutWidth = (InWidth - FilterWidth + 2*Padding)/Stride + 1
def conv2d(input, name, kshape, strides=[1, 1, 1, 1]):
    """SAME-padded convolution + bias + LeakyReLU, scoped under `name`."""
    with tf.variable_scope(name):
        weights = tf.get_variable(name='w_' + name,
                                  shape=kshape,
                                  initializer=tf.contrib.layers.xavier_initializer(uniform=False))
        bias = tf.get_variable(name='b_' + name,
                               shape=[kshape[3]],
                               initializer=tf.contrib.layers.xavier_initializer(uniform=False))
        conv = tf.nn.conv2d(input, weights, strides=strides, padding='SAME')
        return tf.nn.leaky_relu(tf.nn.bias_add(conv, bias))
# tf.contrib.layers.conv2d_transpose, do not get confused with
# tf.layers.conv2d_transpose
def deconv2d(input, name, kshape, n_outputs, strides=[1, 1]):
    """Transposed convolution with LeakyReLU, via tf.contrib.layers."""
    with tf.variable_scope(name):
        return tf.contrib.layers.conv2d_transpose(
            input,
            num_outputs=n_outputs,
            kernel_size=kshape,
            stride=strides,
            padding='SAME',
            weights_initializer=tf.contrib.layers.xavier_initializer_conv2d(uniform=False),
            biases_initializer=tf.contrib.layers.xavier_initializer(uniform=False),
            activation_fn=tf.nn.leaky_relu)
# Input to maxpool: [BatchSize, Width1, Height1, Channels]
# Output of maxpool: [BatchSize, Width2, Height2, Channels]
#
# To calculate the size of the output of maxpool layer:
# OutWidth = (InWidth - FilterWidth)/Stride + 1
#
# The kernel kshape will typically be [1,2,2,1] for a general
# RGB image input of [batch_size,48,48,3]
# kshape is 1 for batch and channels because we don't want to take
# the maximum over multiple examples or over channels.
def maxpool2d(x, name, kshape=[1, 2, 2, 1], strides=[1, 2, 2, 1]):
    """Max-pooling with SAME padding (defaults halve the spatial dims)."""
    with tf.variable_scope(name):
        pooled = tf.nn.max_pool(x,
                                ksize=kshape,  # size of the pooling window
                                strides=strides,
                                padding='SAME')
        return pooled
def upsample(input, name, factor=[2,2]):
    """Bilinear upsampling of a NHWC tensor by the given (h, w) factor."""
    target_size = [int(input.shape[1] * factor[0]), int(input.shape[2] * factor[1])]
    with tf.variable_scope(name):
        return tf.image.resize_bilinear(input, size=target_size, align_corners=None, name=None)
def fullyConnected(input, name, output_size):
    """Flatten `input` and apply a dense layer with LeakyReLU activation."""
    with tf.variable_scope(name):
        # Total number of features per example once flattened
        flat_size = int(np.prod(input.shape[1:]))
        weights = tf.get_variable(name='w_'+name,
                                  shape=[flat_size, output_size],
                                  initializer=tf.contrib.layers.xavier_initializer(uniform=False))
        bias = tf.get_variable(name='b_'+name,
                               shape=[output_size],
                               initializer=tf.contrib.layers.xavier_initializer(uniform=False))
        flattened = tf.reshape(input, [-1, flat_size])
        return tf.nn.leaky_relu(tf.add(tf.matmul(flattened, weights), bias))
def dropout(input, name, keep_rate):
    """Dropout wrapper; `keep_rate` is the probability a unit is kept."""
    with tf.variable_scope(name):
        return tf.nn.dropout(input, keep_rate)
def ConvAutoEncoder(x, name, reuse=False):
    """Build the convolutional autoencoder graph.

    Args:
        x: Input placeholder of shape [batch, 48, 48, 3].
        name: Variable scope that owns all layer variables.
        reuse: If True, reuse existing variables (for the extra summary copies
            of the graph built in train_network).

    Returns:
        (input placeholder, reconstruction tensor [batch,48,48,3], scalar cost)
    """
    with tf.variable_scope(name) as scope:
        if reuse:
            scope.reuse_variables()
        input = tf.reshape(x, shape=[-1, 48, 48, 3])
        # kshape = [k_h, k_w, in_channels, out_channels]
        c1 = conv2d(input, name='c1', kshape=[7, 7, 3, 15])      # [48,48,3]  -> [48,48,15]
        p1 = maxpool2d(c1, name='p1')                            # [48,48,15] -> [24,24,15]
        do1 = dropout(p1, name='do1', keep_rate=0.75)
        c2 = conv2d(do1, name='c2', kshape=[5, 5, 15, 25])       # [24,24,15] -> [24,24,25]
        p2 = maxpool2d(c2, name='p2')                            # [24,24,25] -> [12,12,25]
        p2 = tf.reshape(p2, shape=[-1, 12*12*25])                # flatten to [12*12*25]
        fc1 = fullyConnected(p2, name='fc1', output_size=12*12*5)   # [12*12*25] -> [12*12*5]
        do2 = dropout(fc1, name='do2', keep_rate=0.75)
        fc2 = fullyConnected(do2, name='fc2', output_size=12*12*3)  # [12*12*5] -> [12*12*3]
        do3 = dropout(fc2, name='do3', keep_rate=0.75)
        fc3 = fullyConnected(do3, name='fc3', output_size=64)    # [12*12*3] -> [64] --> bottleneck layer
        # Decoding part
        fc4 = fullyConnected(fc3, name='fc4', output_size=12*12*3)  # [64] -> [12*12*3]
        do4 = dropout(fc4, name='do4', keep_rate=0.75)
        fc5 = fullyConnected(do4, name='fc5', output_size=12*12*5)  # [12*12*3] -> [12*12*5]
        do5 = dropout(fc5, name='do5', keep_rate=0.75)
        # NOTE(review): the decoder widens to 21x21 here rather than mirroring
        # the encoder's 12x12; the final dense layer below maps everything back
        # to 48*48*3 regardless, so the graph is consistent — but confirm the
        # 21 is intentional and not a leftover edit.
        fc6 = fullyConnected(do5, name='fc6', output_size=21*21*25) # [12*12*5] -> [21*21*25]
        do6 = dropout(fc6, name='do6', keep_rate=0.75)
        do6 = tf.reshape(do6, shape=[-1, 21, 21, 25])            # [21*21*25] -> [21,21,25]
        dc1 = deconv2d(do6, name='dc1', kshape=[5, 5],n_outputs=15) # [21,21,25] -> [21,21,15]
        up1 = upsample(dc1, name='up1', factor=[2, 2])           # [21,21,15] -> [42,42,15]
        dc2 = deconv2d(up1, name='dc2', kshape=[7, 7],n_outputs=3)  # [42,42,15] -> [42,42,3]
        up2 = upsample(dc2, name='up2', factor=[2, 2])           # [42,42,3]  -> [84,84,3]
        # Dense projection flattens [84,84,3] down to the 48*48*3 output vector
        output = fullyConnected(up2, name='output', output_size=48*48*3)
        with tf.variable_scope('cost'):
            # N.B. reduce_mean is a batch operation! finds the mean across the batch
            cost = tf.reduce_mean(tf.square(tf.subtract(output, tf.reshape(x,shape=[-1,48*48*3]))))
        return x, tf.reshape(output,shape=[-1,48,48,3]), cost # returning, input, output and cost
def train_network(x):
    """Train the autoencoder on the module-level `data` file list.

    Builds the graph inside a fresh session, restores the latest checkpoint
    from `checkpoint_dir` if one exists, then runs `n_epochs` epochs of
    minibatch Adam updates, writing TensorBoard summaries every batch and a
    checkpoint every `save_steps` batches.

    Args:
        x: The [None, 48, 48, 3] float32 input placeholder.
    """
    with tf.Session() as sess:
        _, _, cost = ConvAutoEncoder(x, 'ConvAutoEnc')
        with tf.variable_scope('opt'):
            optimizer = tf.train.AdamOptimizer().minimize(cost)
        # Create a summary to monitor cost tensor
        tf.summary.scalar("cost", cost)
        # Rebuild the graph with reuse=True so the image summaries read the
        # same variables as the training graph.
        tf.summary.image("face_input", ConvAutoEncoder(x, 'ConvAutoEnc', reuse=True)[0], max_outputs=4)
        tf.summary.image("face_output", ConvAutoEncoder(x, 'ConvAutoEnc', reuse=True)[1], max_outputs=4)
        merged_summary_op = tf.summary.merge_all() # Merge all summaries into a single op
        sess.run(tf.global_variables_initializer()) # memory allocation exceeded 10% issue
        # Model saver
        saver = tf.train.Saver()
        counter = 0 # Used for checkpointing
        success, restored_counter = restore(saver, sess)
        if success:
            counter = restored_counter
            print(">>> Restore successful")
        else:
            print(">>> No restore checkpoints detected")
        # create log writer object
        writer = tf.summary.FileWriter(logs_dir, graph=tf.get_default_graph())
        for epoch in range(n_epochs):
            avg_cost = 0
            # NOTE(review): batch count comes from the hard-coded num_examples,
            # not len(data) — confirm the two agree for the current dataset.
            n_batches = int(num_examples / batch_size)
            # Loop over all batches
            for i in range(n_batches):
                counter += 1
                print("epoch " + str(epoch) + " batch " + str(i))
                batch_files = data[i*batch_size:(i+1)*batch_size] # get the current batch of files
                # Load + normalize each file into a 48x48x3 array in [-1, 1]
                batch = [autoresize(batch_file,
                                    input_height=48,
                                    input_width=48,
                                    resize_height=48,
                                    resize_width=48,
                                    crop=True,
                                    grayscale=False) for batch_file in batch_files]
                batch_images = np.array(batch).astype(np.float32)
                # Get cost function from running optimizer
                _, c, summary = sess.run([optimizer, cost, merged_summary_op], feed_dict={x: batch_images})
                # Compute average loss
                avg_cost += c / n_batches
                writer.add_summary(summary, epoch * n_batches + i)
                if counter % save_steps == 0:
                    save(saver, counter, sess)
            # Display logs per epoch step
            print('Epoch', epoch + 1, ' / ', n_epochs, 'cost:', avg_cost)
        print('>>> Optimization Finished')
# Create checkpoint
def save(saver, step, session):
    """Write a checkpoint for `session` at the given global step."""
    print(">>> Saving to checkpoint, step:" + str(step))
    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)
    target = os.path.join(checkpoint_dir, model_name)
    saver.save(session, target, global_step=step)
# Restore from checkpoint
def restore(saver, session):
    """Restore the most recent checkpoint from `checkpoint_dir`, if any.

    Args:
        saver: tf.train.Saver bound to the current graph.
        session: Active tf.Session to restore variables into.

    Returns:
        (True, step) when a checkpoint was restored, where `step` is the
        trailing integer parsed from the checkpoint file name;
        (False, 0) when no checkpoint exists.
    """
    print(">>> Restoring from checkpoints...")
    checkpoint_state = tf.train.get_checkpoint_state(checkpoint_dir)
    if checkpoint_state and checkpoint_state.model_checkpoint_path:
        checkpoint_name = os.path.basename(checkpoint_state.model_checkpoint_path)
        saver.restore(session, os.path.join(checkpoint_dir, checkpoint_name))
        # Fix: the pattern must be a raw string — "\d" in a plain literal is an
        # invalid escape sequence (warning today, error in future Pythons).
        # The lookahead matches the last run of digits, i.e. the global-step
        # suffix that saver.save appends to the file name.
        counter = int(re.search(r"(\d+)(?!.*\d)", checkpoint_name).group(0))
        print(">>> Found restore checkpoint {}".format(checkpoint_name))
        return True, counter
    else:
        return False, 0
train_network(x)