
Commit ee17dc3

No dense layers in VAE
1 parent 4cf211e commit ee17dc3


3 files changed: 77 additions (+) and 62 deletions (-)


VariationalAutoEncoders/VAE.py

Lines changed: 64 additions & 61 deletions
@@ -2,8 +2,6 @@
 '''
 Auto-Encoding Variational Bayes - Kingma and Welling 2013
 
-This work is absolutely not an effort to reproduce exact results of the cited paper, nor I confine my implementations to the suggestion of the original authors.
-I have tried to implement my own limited understanding of the original paper in hope to get a better insight into their work.
 Use this code with no warranty and please respect the accompanying license.
 '''

@@ -13,68 +11,78 @@
 from tools_config import data_dir, expr_dir
 import os
 import matplotlib.pyplot as plt
-from tools_train import get_train_params, OneHot, vis_square
+from tools_train import get_train_params, OneHot, vis_square, count_model_params
 from datetime import datetime
 from tools_general import tf, np
 from tools_networks import deconv, conv, dense, clipped_crossentropy, dropout
 
-def create_VAE_E(Xin, labels, is_training, Cout=1, trainable=True, reuse=False, networktype='vaeE'):
+def create_VAE_E(Xin, is_training, latentW, latentC, reuse=False, networktype='vaeE'):
     '''Xin: batchsize * H * W * Cin
-    labels: batchsize * num_classes
     output1-2: batchsize * Cout'''
     with tf.variable_scope(networktype, reuse=reuse):
-        Eout = conv(Xin, is_training, kernel_w=4, stride=2, Cout=64, pad=1, trainable=trainable, act='reLu', norm='batchnorm', name='conv1')  # 14*14
-        Eout = conv(Eout, is_training, kernel_w=4, stride=2, Cout=128, pad=1, trainable=trainable, act='reLu', norm='batchnorm', name='conv2')  # 7*7
-        posteriorMu = dense(Eout, is_training, trainable=trainable, Cout=Cout, act=None, norm=None, name='dense_mean')
-        posteriorSigma = dense(Eout, is_training, trainable=trainable, Cout=Cout, act=None, norm=None, name='dense_var')
+        Eout = conv(Xin, is_training, kernel_w=4, stride=2, Cout=64, pad=1, act='reLu', norm='batchnorm', name='conv1')  # 14*14
+        Eout = conv(Eout, is_training, kernel_w=4, stride=2, Cout=128, pad=1, act='reLu', norm='batchnorm', name='conv2')  # 7*7
+
+        posteriorMu = conv(Eout, is_training, kernel_w=3, stride=1, Cout=latentC, pad=1, act=None, norm=None, name='conv_mu')
+        posteriorSigma = conv(Eout, is_training, kernel_w=3, stride=1, Cout=latentC, pad=1, act=None, norm=None, name='conv_sig')
+
+        posteriorMu = tf.reshape(posteriorMu, shape=[-1, latentW * latentW * latentC])
+        posteriorSigma = tf.reshape(posteriorSigma, shape=[-1, latentW * latentW * latentC])
+
     return posteriorMu, posteriorSigma
 
-def create_VAE_D(z, labels, is_training, Cout=1, trainable=True, reuse=False, networktype='vaeD'):
-    '''z : batchsize * latend_dim
-    labels: batchsize * num_classes
+def create_VAE_D(z, is_training, Cout, latentW, latentC, reuse=False, networktype='vaeD'):
+    '''input : batchsize * latentDim
     output: batchsize * 28 * 28 * 1'''
     with tf.variable_scope(networktype, reuse=reuse):
-        Gz = dense(z, is_training, Cout=4 * 4 * 256, act='reLu', norm='batchnorm', name='dense1')
-        Gz = tf.reshape(Gz, shape=[-1, 4, 4, 256])  # 4
-        Gz = deconv(Gz, is_training, kernel_w=5, stride=2, Cout=256, trainable=trainable, act='reLu', norm='batchnorm', name='deconv1')  # 11
-        Gz = deconv(Gz, is_training, kernel_w=5, stride=2, Cout=128, trainable=trainable, act='reLu', norm='batchnorm', name='deconv2')  # 25
-        Gz = deconv(Gz, is_training, kernel_w=4, stride=Cout, Cout=1, act=None, norm=None, name='deconv3')  # 28
-        Gz = tf.nn.sigmoid(Gz)
-    return Gz
+        print("Latent Space Dim = H=%d, W=%d, C=%d" % (latentW, latentW, latentC))
+        Gout = tf.reshape(z, shape=[-1, latentW, latentW, latentC])
+        Gout = deconv(Gout, is_training, kernel_w=4, stride=2, epf=2, Cout=128, act='reLu', norm='batchnorm', name='deconv1')  # 14
+        Gout = deconv(Gout, is_training, kernel_w=4, stride=2, epf=2, Cout=Cout, act=None, norm=None, name='deconv2')  # 28
+        Gout = tf.nn.sigmoid(Gout)
+    return Gout
 
-def create_vae_trainer(base_lr=1e-4, networktype='VAE', latendDim=100):
+def create_vae_trainer(base_lr=1e-4, networktype='VAE', Cout=1, latentW=7, latentC=2):
     '''Train a Variational AutoEncoder'''
     eps = 1e-5
 
     is_training = tf.placeholder(tf.bool, [], 'is_training')
 
-    inZ = tf.placeholder(tf.float32, [None, latendDim])
-    inL = tf.placeholder(tf.float32, [None, 10])
-    inX = tf.placeholder(tf.float32, [None, 28, 28, 1])
+    Zph = tf.placeholder(tf.float32, [None, latentW * latentW * latentC])
+    Xph = tf.placeholder(tf.float32, [None, 28, 28, 1])
 
-    posteriorMu, posteriorSigma = create_VAE_E(inX, inL, is_training, Cout=latendDim, trainable=True, reuse=False, networktype=networktype + '_vaeE')
-
-    Z = posteriorSigma * inZ + posteriorMu
-    Xrec = create_VAE_D(Z, inL, is_training, trainable=True, reuse=False, networktype=networktype + '_vaeD')
+    posteriorMu, posteriorSigma = create_VAE_E(Xph, is_training, latentW, latentC, reuse=False, networktype=networktype + '_E')
+    Z_op = posteriorSigma * Zph + posteriorMu
+    Xrec_op = create_VAE_D(Z_op, is_training, Cout, latentW, latentC, reuse=False, networktype=networktype + '_D')
 
     # E[log P(X|z)]
-    reconstruction_loss = tf.reduce_sum((inX - 1.0) * tf.log(1.0 - Xrec + eps) - inX * tf.log(Xrec + eps), reduction_indices=[1, 2, 3])
-    # D_KL(Q(z|X) || P(z|X))
-    KL_QZ = 0.5 * tf.reduce_sum(tf.exp(posteriorSigma) + tf.square(posteriorMu) - 1 - posteriorSigma, reduction_indices=1)
-
-    total_loss = tf.reduce_mean(reconstruction_loss + KL_QZ)
+    # rec_loss_op = tf.reduce_mean(tf.reduce_sum((Xph - 1.0) * tf.log(1.0 - Xrec_op + eps) - Xph * tf.log(Xrec_op + eps), reduction_indices=[1, 2, 3]))
+    rec_loss_op = tf.reduce_mean(tf.reduce_sum(tf.square(tf.subtract(Xph, Xrec_op)), reduction_indices=[1, 2, 3]))
+
+    # D_KL(Q(z|X) || P(z))
+    KL_loss = tf.reduce_mean(0.5 * tf.reduce_sum(tf.exp(posteriorSigma) + tf.square(posteriorMu) - 1 - posteriorSigma, reduction_indices=[1, ]))
 
-    vaetrain = tf.train.AdamOptimizer(learning_rate=base_lr, beta1=0.9).minimize(total_loss)
+    total_loss_op = tf.add(rec_loss_op, KL_loss)
+    train_step_op = tf.train.AdamOptimizer(learning_rate=base_lr, beta1=0.9).minimize(total_loss_op)
+
+    E_varlist = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=networktype + '_E')
+    D_varlist = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=networktype + '_D')
+    print('Total Trainable Variables Count in Encoder %2.3f M and in Decoder: %2.3f M.' % (count_model_params(E_varlist) / 1000000, count_model_params(D_varlist) / 1000000))
 
-    return vaetrain, total_loss, is_training, inZ, inX, inL, Xrec
+    return train_step_op, total_loss_op, rec_loss_op, KL_loss, is_training, Zph, Xph, Xrec_op
 
 if __name__ == '__main__':
     networktype = 'VAE_MNIST'
 
     batch_size = 128
     base_lr = 1e-5
     epochs = 200
-    latendDim = 2
+
+    Cout = 1
+
+    latentW = 7
+    latentC = 2
+    latendDim = latentW * latentW * latentC
 
     work_dir = expr_dir + '%s/%s/' % (networktype, datetime.strftime(datetime.today(), '%Y%m%d'))
     if not os.path.exists(work_dir): os.makedirs(work_dir)
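The hunk above is the heart of the commit: the encoder's mean and variance heads are now 3x3 convolutions over the 7x7 feature map, flattened into the latent vector, and the decoder reshapes that vector back to latentW x latentW x latentC before upsampling, so no dense layer remains on either side. A quick sketch of the shape arithmetic this layout relies on (illustrative only, not part of the commit):

# Illustrative shape bookkeeping for the new fully-convolutional latent layout.
H = 28                                   # MNIST input height/width
latentW = H // 2 // 2                    # two stride-2 convs: 28 -> 14 -> 7
latentC = 2                              # channels of the conv_mu / conv_sig heads
latendDim = latentW * latentW * latentC  # flattened latent length fed to the decoder
assert (latentW, latendDim) == (7, 98)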
@@ -84,45 +92,40 @@ def create_vae_trainer(base_lr=1e-4, networktype='VAE', latendDim=100):
     tf.reset_default_graph()
     sess = tf.InteractiveSession()
 
-    vaetrain, total_loss, is_training, inZ, inX, inL, Xrec = create_vae_trainer(base_lr, networktype=networktype, latendDim=latendDim)
+    train_step_op, total_loss_op, rec_loss_op, KL_loss, is_training, Zph, Xph, Xrec_op = create_vae_trainer(base_lr, networktype, Cout, latentW, latentC)
     tf.global_variables_initializer().run()
 
     var_list = [var for var in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) if (networktype.lower() in var.name.lower()) and ('adam' not in var.name.lower())]
-    saver = tf.train.Saver(var_list=var_list, max_to_keep=100)
+    saver = tf.train.Saver(var_list=var_list, max_to_keep=int(epochs * .1))
     # saver.restore(sess, expr_dir + 'ganMNIST/20170707/214_model.ckpt')
 
-    best_test_loss = np.inf
-
-    train_loss = np.zeros(max_iter)
-    test_loss = np.zeros(int(np.ceil(max_iter / test_int)))
-
-    Z_test = np.random.normal(size=[batch_size, latendDim], loc=0.0, scale=1.).astype(np.float32)
-    labels_test = OneHot(np.random.randint(10, size=[batch_size]), n=10)
+    best_test_total_loss = np.inf
+
+    train_loss = np.zeros([max_iter, 3])
+    test_loss = np.zeros([int(np.ceil(max_iter / test_int)), 3])
 
     for it in range(max_iter):
         Z = np.random.normal(size=[batch_size, latendDim], loc=0.0, scale=1.).astype(np.float32)
 
         if it % test_int == 0:  # Record summaries and test-set accuracy
-            accumulated_loss = 0.0
+            acc_loss = np.zeros([1, 3])
             for i_test in range(test_iter):
-                X, labels = data.test.next_batch(batch_size)
-
-                recloss = sess.run(total_loss, feed_dict={inX: X, inL: labels, inZ: Z, is_training: False})
-                accumulated_loss = np.add(accumulated_loss, recloss)
+                X, _ = data.test.next_batch(batch_size)
+                resloss = sess.run([total_loss_op, rec_loss_op, KL_loss], feed_dict={Xph: X, Zph: Z, is_training: False})
+                acc_loss = np.add(acc_loss, resloss)
 
-            test_loss[it // test_int] = np.divide(accumulated_loss, test_iter)
+            test_loss[it // test_int] = np.divide(acc_loss, test_iter)
 
-            print("Iteration #%4d, testing .... Test Loss = %f" % (it, test_loss[it // test_int]))
-            if test_loss[it // test_int] < best_test_loss:
-                best_test_loss = test_loss[it // test_int]
-                print('################ Best Results yet.[loss = %2.5f] saving results...' % best_test_loss)
-                vaeD_sample = sess.run(Xrec, feed_dict={inX: X, inL: labels_test, inZ: Z_test, is_training: False})
-                vis_square(vaeD_sample[:121], [11, 11], save_path=work_dir + 'Iter_%d.jpg' % it)
-                saver.save(sess, work_dir + "%.3d_model.ckpt" % it)
+            print("Iteration #%4d, testing .... Test Loss [total| rec| KL] = " % it, test_loss[it // test_int])
+            if test_loss[it // test_int, 0] < best_test_total_loss:
+                best_test_total_loss = test_loss[it // test_int, 0]
+                print('################ Best Results yet.[loss = %2.5f] saving results...' % best_test_total_loss)
+                vaeD_sample = sess.run(Xrec_op, feed_dict={Xph: X, Zph: Z, is_training: False})
+                vis_square(vaeD_sample[:121], [11, 11], save_path=work_dir + 'Epoch_%.3d_Iter_%d.jpg' % (data.train.epochs_completed, it))
+                saver.save(sess, work_dir + "%.3d_%.3d_model.ckpt" % (data.train.epochs_completed, it))
 
-        X, labels = data.train.next_batch(batch_size)
-        recloss, _ = sess.run([total_loss, vaetrain], feed_dict={inX: X, inL: labels, inZ: Z, is_training: True})
+        X, _ = data.train.next_batch(batch_size)
+        recloss, _ = sess.run([total_loss_op, train_step_op], feed_dict={Xph: X, Zph: Z, is_training: True})
 
         train_loss[it] = recloss
         if it % disp_int == 0: print("Iteration #%4d, Train Loss = %f" % (it, recloss))
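For readers cross-checking the loss terms above: the KL expression is the closed-form divergence between a diagonal Gaussian posterior and the standard normal prior, written as if posteriorSigma were a log-variance, while the sampling line Z_op = posteriorSigma * Zph + posteriorMu uses the same tensor directly as a scale on the injected noise. A minimal NumPy sketch of the textbook formulation, for reference only (not repo code; it follows the log-variance reading):

import numpy as np

def reparameterize(mu, logvar, rng=np.random):
    # z = mu + sigma * eps, with eps ~ N(0, I) and sigma = exp(0.5 * logvar)
    eps = rng.normal(size=mu.shape)
    return mu + np.exp(0.5 * logvar) * eps

def kl_to_standard_normal(mu, logvar):
    # Closed-form D_KL( N(mu, sigma^2) || N(0, I) ), summed over latent dimensions
    return 0.5 * np.sum(np.exp(logvar) + np.square(mu) - 1.0 - logvar, axis=1)

def squared_error_reconstruction(x, x_rec):
    # Per-sample squared error summed over H, W, C, matching rec_loss_op above
    return np.sum(np.square(x - x_rec), axis=(1, 2, 3))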

common/tools_networks.py

Lines changed: 1 addition & 1 deletion
@@ -182,4 +182,4 @@ def regularization(variables, regtype='L1', regcoef=0.1):
     else:
         raise('regularization type not detected!')
     print("Regularizing with type %s, coef %s for %d variables!" % (regtype, regcoef, len(variables)))
-    return tf.multiply(regcoef, regs)
+    return tf.multiply(regcoef, regs)

common/tools_train.py

Lines changed: 12 additions & 0 deletions
@@ -41,3 +41,15 @@ def vis_square(X, nh_nw, save_path=None):
         return save_path
     else:
         return img
+
+def count_model_params(variables=None):
+    if variables == None:
+        variables = tf.trainable_variables()
+    total_parameters = 0
+    for variable in variables:
+        shape = variable.get_shape()
+        variable_parametes = 1
+        for dim in shape:
+            variable_parametes *= dim.value
+        total_parameters += variable_parametes
+    return(total_parameters)
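The new helper simply multiplies out each variable's static shape and sums the results. A hedged usage sketch, mirroring how the VAE trainer above calls it (scope names assume networktype = 'VAE_MNIST', and the graph must already be built):

# Illustrative usage only, not part of the commit.
import tensorflow as tf
from tools_train import count_model_params

E_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='VAE_MNIST_E')
D_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='VAE_MNIST_D')
print('Encoder: %.3f M params, Decoder: %.3f M params'
      % (count_model_params(E_vars) / 1e6, count_model_params(D_vars) / 1e6))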
