'''
Auto-Encoding Variational Bayes - Kingma and Welling 2013

- This work is absolutely not an effort to reproduce the exact results of the cited paper, nor do I confine my implementation to the suggestions of the original authors.
- I have tried to implement my own limited understanding of the original paper, in the hope of getting a better insight into their work.

Use this code with no warranty and please respect the accompanying license.
'''

from tools_config import data_dir, expr_dir
import os
import matplotlib.pyplot as plt
- from tools_train import get_train_params, OneHot, vis_square
+ from tools_train import get_train_params, OneHot, vis_square, count_model_params
from datetime import datetime
from tools_general import tf, np
from tools_networks import deconv, conv, dense, clipped_crossentropy, dropout

- def create_VAE_E(Xin, labels, is_training, Cout=1, trainable=True, reuse=False, networktype='vaeE'):
+ def create_VAE_E(Xin, is_training, latentW, latentC, reuse=False, networktype='vaeE'):
    '''Xin: batchsize * H * W * Cin
-    labels: batchsize * num_classes
    output1-2: batchsize * Cout'''
    with tf.variable_scope(networktype, reuse=reuse):
-        Eout = conv(Xin, is_training, kernel_w=4, stride=2, Cout=64, pad=1, trainable=trainable, act='reLu', norm='batchnorm', name='conv1')  # 14*14
-        Eout = conv(Eout, is_training, kernel_w=4, stride=2, Cout=128, pad=1, trainable=trainable, act='reLu', norm='batchnorm', name='conv2')  # 7*7
-        posteriorMu = dense(Eout, is_training, trainable=trainable, Cout=Cout, act=None, norm=None, name='dense_mean')
-        posteriorSigma = dense(Eout, is_training, trainable=trainable, Cout=Cout, act=None, norm=None, name='dense_var')
+        Eout = conv(Xin, is_training, kernel_w=4, stride=2, Cout=64, pad=1, act='reLu', norm='batchnorm', name='conv1')  # 14*14
+        Eout = conv(Eout, is_training, kernel_w=4, stride=2, Cout=128, pad=1, act='reLu', norm='batchnorm', name='conv2')  # 7*7
+
+        posteriorMu = conv(Eout, is_training, kernel_w=3, stride=1, Cout=latentC, pad=1, act=None, norm=None, name='conv_mu')
+        posteriorSigma = conv(Eout, is_training, kernel_w=3, stride=1, Cout=latentC, pad=1, act=None, norm=None, name='conv_sig')
+
+        posteriorMu = tf.reshape(posteriorMu, shape=[-1, latentW * latentW * latentC])
+        posteriorSigma = tf.reshape(posteriorSigma, shape=[-1, latentW * latentW * latentC])
+
    return posteriorMu, posteriorSigma
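
The reworked encoder head replaces the dense mean/variance layers with 3x3 convolutions over the 7x7 feature map and flattens both outputs to vectors of length latentW * latentW * latentC. A quick standalone sanity check of those sizes (illustrative only; it assumes the 28x28 MNIST inputs and the latentW=7, latentC=2 defaults used in the __main__ block further down):

    H = W = 28                          # MNIST input resolution
    H, W = H // 2, W // 2               # conv1: kernel 4, stride 2, pad 1 -> 14x14
    H, W = H // 2, W // 2               # conv2: kernel 4, stride 2, pad 1 -> 7x7
    latentW, latentC = 7, 2
    assert (H, W) == (latentW, latentW)
    print(latentW * latentW * latentC)  # 98: length of posteriorMu / posteriorSigma per sample
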
-
- def create_VAE_D(z, labels, is_training, Cout=1, trainable=True, reuse=False, networktype='vaeD'):
-    '''z: batchsize * latend_dim
-    labels: batchsize * num_classes
+
+ def create_VAE_D(z, is_training, Cout, latentW, latentC, reuse=False, networktype='vaeD'):
+    '''input: batchsize * latentDim
    output: batchsize * 28 * 28 * 1'''
    with tf.variable_scope(networktype, reuse=reuse):
-        Gz = dense(z, is_training, Cout=4 * 4 * 256, act='reLu', norm='batchnorm', name='dense1')
-        Gz = tf.reshape(Gz, shape=[-1, 4, 4, 256])  # 4
-        Gz = deconv(Gz, is_training, kernel_w=5, stride=2, Cout=256, trainable=trainable, act='reLu', norm='batchnorm', name='deconv1')  # 11
-        Gz = deconv(Gz, is_training, kernel_w=5, stride=2, Cout=128, trainable=trainable, act='reLu', norm='batchnorm', name='deconv2')  # 25
-        Gz = deconv(Gz, is_training, kernel_w=4, stride=Cout, Cout=1, act=None, norm=None, name='deconv3')  # 28
-        Gz = tf.nn.sigmoid(Gz)
-        return Gz
+        print("Latent Space Dim = H=%d, W=%d, C=%d" % (latentW, latentW, latentC))
+        Gout = tf.reshape(z, shape=[-1, latentW, latentW, latentC])
+        Gout = deconv(Gout, is_training, kernel_w=4, stride=2, epf=2, Cout=128, act='reLu', norm='batchnorm', name='deconv1')  # 14
+        Gout = deconv(Gout, is_training, kernel_w=4, stride=2, epf=2, Cout=Cout, act=None, norm=None, name='deconv2')  # 28
+        Gout = tf.nn.sigmoid(Gout)
+        return Gout

- def create_vae_trainer(base_lr=1e-4, networktype='VAE', latendDim=100):
+ def create_vae_trainer(base_lr=1e-4, networktype='VAE', Cout=1, latentW=7, latentC=2):
    '''Train a Variational AutoEncoder'''
    eps = 1e-5

    is_training = tf.placeholder(tf.bool, [], 'is_training')

-    inZ = tf.placeholder(tf.float32, [None, latendDim])
-    inL = tf.placeholder(tf.float32, [None, 10])
-    inX = tf.placeholder(tf.float32, [None, 28, 28, 1])
+    Zph = tf.placeholder(tf.float32, [None, latentW * latentW * latentC])
+    Xph = tf.placeholder(tf.float32, [None, 28, 28, 1])

-    posteriorMu, posteriorSigma = create_VAE_E(inX, inL, is_training, Cout=latendDim, trainable=True, reuse=False, networktype=networktype + '_vaeE')
-
-    Z = posteriorSigma * inZ + posteriorMu
-    Xrec = create_VAE_D(Z, inL, is_training, trainable=True, reuse=False, networktype=networktype + '_vaeD')
+    posteriorMu, posteriorSigma = create_VAE_E(Xph, is_training, latentW, latentC, reuse=False, networktype=networktype + '_E')
+    Z_op = posteriorSigma * Zph + posteriorMu
+    Xrec_op = create_VAE_D(Z_op, is_training, Cout, latentW, latentC, reuse=False, networktype=networktype + '_D')
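
The Z_op line is the reparameterization trick: rather than sampling inside the graph, the externally supplied noise Zph (drawn from N(0, I) in the training loop below) is scaled and shifted by the encoder outputs, so gradients can flow back into posteriorMu and posteriorSigma. A minimal numpy sketch of the same computation (illustrative only; posteriorSigma is applied here directly as a per-dimension scale):

    import numpy as np
    mu = np.array([0.5, -1.0])       # encoder mean for one sample
    sigma = np.array([0.2, 0.8])     # encoder scale for one sample
    eps = np.random.normal(size=2)   # the role played by the Zph placeholder
    z = sigma * eps + mu             # same form as Z_op = posteriorSigma * Zph + posteriorMu
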

    # E[log P(X|z)]
-    reconstruction_loss = tf.reduce_sum((inX - 1.0) * tf.log(1.0 - Xrec + eps) - inX * tf.log(Xrec + eps), reduction_indices=[1, 2, 3])
-    # D_KL(Q(z|X) || P(z|X))
-    KL_QZ = 0.5 * tf.reduce_sum(tf.exp(posteriorSigma) + tf.square(posteriorMu) - 1 - posteriorSigma, reduction_indices=1)
-
-    total_loss = tf.reduce_mean(reconstruction_loss + KL_QZ)
+    # rec_loss_op = tf.reduce_mean(tf.reduce_sum((Xph - 1.0) * tf.log(1.0 - Xrec_op + eps) - Xph * tf.log(Xrec_op + eps), reduction_indices=[1, 2, 3]))
+    rec_loss_op = tf.reduce_mean(tf.reduce_sum(tf.square(tf.subtract(Xph, Xrec_op)), reduction_indices=[1, 2, 3]))
+
+    # D_KL(Q(z|X) || P(z))
+    KL_loss = tf.reduce_mean(0.5 * tf.reduce_sum(tf.exp(posteriorSigma) + tf.square(posteriorMu) - 1 - posteriorSigma, reduction_indices=[1]))
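
Two notes on the objective as now written. The squared-error reconstruction term corresponds (up to constants) to a Gaussian decoder likelihood, while the commented-out line above it is the Bernoulli cross-entropy form often used for binarized MNIST. The KL term is the standard closed-form D_KL(N(mu, sigma^2) || N(0, I)) with posteriorSigma read as a log-variance; under that reading the sampling step earlier would conventionally scale the noise by exp(0.5 * posteriorSigma). A small numpy sketch of both terms, for one pixel and one latent dimension (illustrative only):

    import numpy as np
    mu, logvar = 0.5, -1.2
    kl = 0.5 * (np.exp(logvar) + mu ** 2 - 1.0 - logvar)   # same expression as KL_loss above
    x, xrec = 1.0, 0.8                                     # target pixel vs. reconstruction
    rec_mse = (x - xrec) ** 2                              # the term used in rec_loss_op
    rec_bce = -(x * np.log(xrec + 1e-5) + (1 - x) * np.log(1 - xrec + 1e-5))  # the commented-out alternative
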

-    vaetrain = tf.train.AdamOptimizer(learning_rate=base_lr, beta1=0.9).minimize(total_loss)
+    total_loss_op = tf.add(rec_loss_op, KL_loss)
+    train_step_op = tf.train.AdamOptimizer(learning_rate=base_lr, beta1=0.9).minimize(total_loss_op)
+
+    E_varlist = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=networktype + '_E')
+    D_varlist = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=networktype + '_D')
+    print('Total Trainable Variables Count in Encoder %2.3f M and in Decoder: %2.3f M.' % (count_model_params(E_varlist) / 1000000, count_model_params(D_varlist) / 1000000))

-    return vaetrain, total_loss, is_training, inZ, inX, inL, Xrec
+    return train_step_op, total_loss_op, rec_loss_op, KL_loss, is_training, Zph, Xph, Xrec_op
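
As wired above, Xrec_op always drives the decoder from the encoder's posterior for the fed-in images, so the pictures saved in the loop below are reconstructions rather than samples from the prior. If prior samples were wanted as well, the decoder variables could be reused on Zph directly; a sketch under that assumption (not part of the original code, added only as an illustration):

    Xgen_op = create_VAE_D(Zph, is_training, Cout, latentW, latentC, reuse=True, networktype=networktype + '_D')
    # at evaluation time:
    # Xgen = sess.run(Xgen_op, feed_dict={Zph: np.random.normal(size=[batch_size, latentW * latentW * latentC]),
    #                                     is_training: False})
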

if __name__ == '__main__':
    networktype = 'VAE_MNIST'

    batch_size = 128
    base_lr = 1e-5
    epochs = 200
-    latendDim = 2
+
+    Cout = 1
+
+    latentW = 7
+    latentC = 2
+    latendDim = latentW * latentW * latentC

    work_dir = expr_dir + '%s/%s/' % (networktype, datetime.strftime(datetime.today(), '%Y%m%d'))
    if not os.path.exists(work_dir): os.makedirs(work_dir)
@@ -84,45 +92,40 @@ def create_vae_trainer(base_lr=1e-4, networktype='VAE', latendDim=100):
    tf.reset_default_graph()
    sess = tf.InteractiveSession()

-    vaetrain, total_loss, is_training, inZ, inX, inL, Xrec = create_vae_trainer(base_lr, networktype=networktype, latendDim=latendDim)
+    train_step_op, total_loss_op, rec_loss_op, KL_loss, is_training, Zph, Xph, Xrec_op = create_vae_trainer(base_lr, networktype, Cout, latentW, latentC)
    tf.global_variables_initializer().run()

-
    var_list = [var for var in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) if (networktype.lower() in var.name.lower()) and ('adam' not in var.name.lower())]
-    saver = tf.train.Saver(var_list=var_list, max_to_keep=100)
+    saver = tf.train.Saver(var_list=var_list, max_to_keep=int(epochs * .1))
    # saver.restore(sess, expr_dir + 'ganMNIST/20170707/214_model.ckpt')

-    best_test_loss = np.inf
-
-    train_loss = np.zeros(max_iter)
-    test_loss = np.zeros(int(np.ceil(max_iter / test_int)))
+    best_test_total_loss = np.inf

-    Z_test = np.random.normal(size=[batch_size, latendDim], loc=0.0, scale=1.).astype(np.float32)
-    labels_test = OneHot(np.random.randint(10, size=[batch_size]), n=10)
-
+    train_loss = np.zeros([max_iter, 3])
+    test_loss = np.zeros([int(np.ceil(max_iter / test_int)), 3])
+
    for it in range(max_iter):
        Z = np.random.normal(size=[batch_size, latendDim], loc=0.0, scale=1.).astype(np.float32)

        if it % test_int == 0:  # Record summaries and test-set accuracy
-            accumulated_loss = 0.0
+            acc_loss = np.zeros([1, 3])
            for i_test in range(test_iter):
-                X, labels = data.test.next_batch(batch_size)
-
-                recloss = sess.run(total_loss, feed_dict={inX: X, inL: labels, inZ: Z, is_training: False})
-                accumulated_loss = np.add(accumulated_loss, recloss)
+                X, _ = data.test.next_batch(batch_size)
+                resloss = sess.run([total_loss_op, rec_loss_op, KL_loss], feed_dict={Xph: X, Zph: Z, is_training: False})
+                acc_loss = np.add(acc_loss, resloss)

-            test_loss[it // test_int] = np.divide(accumulated_loss, test_iter)
+            test_loss[it // test_int] = np.divide(acc_loss, test_iter)

-            print("Iteration #%4d, testing .... Test Loss = %f " % (it, test_loss[it // test_int]))
-            if test_loss[it // test_int] < best_test_loss:
-                best_test_loss = test_loss[it // test_int]
-                print('################ Best Results yet.[loss = %2.5f] saving results...' % best_test_loss)
-                vaeD_sample = sess.run(Xrec, feed_dict={inX: X, inL: labels_test, inZ: Z_test, is_training: False})
-                vis_square(vaeD_sample[:121], [11, 11], save_path=work_dir + 'Iter_%d.jpg' % it)
-                saver.save(sess, work_dir + "%.3d_model.ckpt" % it)
+            print("Iteration #%4d, testing .... Test Loss [total| rec| KL] = " % it, test_loss[it // test_int])
+            if test_loss[it // test_int, 0] < best_test_total_loss:
+                best_test_total_loss = test_loss[it // test_int, 0]
+                print('################ Best Results yet.[loss = %2.5f] saving results...' % best_test_total_loss)
+                vaeD_sample = sess.run(Xrec_op, feed_dict={Xph: X, Zph: Z, is_training: False})
+                vis_square(vaeD_sample[:121], [11, 11], save_path=work_dir + 'Epoch_%.3d_Iter_%d.jpg' % (data.train.epochs_completed, it))
+                saver.save(sess, work_dir + "%.3d_%.3d_model.ckpt" % (data.train.epochs_completed, it))

-        X, labels = data.train.next_batch(batch_size)
-        recloss, _ = sess.run([total_loss, vaetrain], feed_dict={inX: X, inL: labels, inZ: Z, is_training: True})
+        X, _ = data.train.next_batch(batch_size)
+        recloss, _ = sess.run([total_loss_op, train_step_op], feed_dict={Xph: X, Zph: Z, is_training: True})

        train_loss[it] = recloss
        if it % disp_int == 0: print("Iteration #%4d, Train Loss = %f" % (it, recloss))