Generative Adversarial Networks - Goodfellow et al.
Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks - Radford et al.

- This work is absolutely not an effort to reproduce the exact results of the cited papers, nor do I confine my implementation to the suggestions of the original authors.
- I have tried to implement my own limited understanding of the original papers in the hope of gaining better insight into their work.

Use this code with no warranty and please respect the accompanying license.
'''

from tools_general import tf, np
from tools_networks import deconv, conv, dense, clipped_crossentropy, dropout
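# Note: tools_general and tools_networks are this repository's local helper
# modules; conv/deconv/dense are assumed to be thin wrappers around the
# corresponding TensorFlow ops with optional batchnorm and activation baked in.
# (The collapsed hunk above presumably holds the remaining imports used below,
# e.g. os, datetime, vis_square, expr_dir and data_dir.)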

- def concat_labels(X, labels):
-     if X.get_shape().ndims == 4:
-         X_shape = tf.shape(X)
-         labels_reshaped = tf.reshape(labels, [-1, 1, 1, 10])
-         a = tf.ones([X_shape[0], X_shape[1], X_shape[2], 10])
-         X = tf.concat([X, labels_reshaped * a], axis=3)
-     return X
-
- def create_gan_G(z, labels, is_training, Cout=1, trainable=True, reuse=False, networktype='ganG'):
-     '''input: batchsize * 100 and labels to make the generator conditional
+ from tensorflow.examples.tutorials.mnist import input_data
+
+ def create_gan_G(z, is_training, Cout=1, trainable=True, reuse=False, networktype='ganG'):
+     '''input: batchsize * 100
    output: batchsize * 28 * 28 * 1'''
    with tf.variable_scope(networktype, reuse=reuse):
-         z = tf.concat(axis=-1, values=[z, labels])
-         Gz = dense(z, is_training, Cout=4 * 4 * 256, act='reLu', norm='batchnorm', name='dense1')
-         Gz = tf.reshape(Gz, shape=[-1, 4, 4, 256])  # 4
-         Gz = deconv(Gz, is_training, kernel_w=5, stride=2, Cout=256, trainable=trainable, act='reLu', norm='batchnorm', name='deconv1')  # 11
-         Gz = deconv(Gz, is_training, kernel_w=5, stride=2, Cout=128, trainable=trainable, act='reLu', norm='batchnorm', name='deconv2')  # 25
-         Gz = deconv(Gz, is_training, kernel_w=4, stride=1, Cout=Cout, act=None, norm=None, name='deconv3')  # 28
-         Gz = tf.nn.sigmoid(Gz)
-         return Gz
+         Gout_op = dense(z, is_training, Cout=4 * 4 * 256, trainable=trainable, act='reLu', norm='batchnorm', name='dense1')
+         Gout_op = tf.reshape(Gout_op, shape=[-1, 4, 4, 256])  # 4
+         Gout_op = deconv(Gout_op, is_training, kernel_w=5, stride=2, Cout=256, trainable=trainable, act='reLu', norm='batchnorm', name='deconv1')  # 11
+         Gout_op = deconv(Gout_op, is_training, kernel_w=5, stride=2, Cout=128, trainable=trainable, act='reLu', norm='batchnorm', name='deconv2')  # 25
+         Gout_op = deconv(Gout_op, is_training, kernel_w=4, stride=1, Cout=Cout, trainable=trainable, act=None, norm=None, name='deconv3')  # 28
+         Gout_op = tf.nn.sigmoid(Gout_op)
+         return Gout_op
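+         # Note on the generator above: z (batchsize x 100) is projected to
+         # 4*4*256 units, reshaped to a 4x4 feature map, and upsampled by
+         # fractionally-strided convolutions. The trailing size comments
+         # (# 4, # 11, # 25, # 28) are consistent with 'VALID' padding in the
+         # deconv helper, i.e. out = (in - 1) * stride + kernel_w:
+         #   (4 - 1) * 2 + 5 = 11,  (11 - 1) * 2 + 5 = 25,  (25 - 1) * 1 + 4 = 28.
+         # The final sigmoid maps outputs into [0, 1], matching MNIST pixel range.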

- def create_gan_D(xz, labels, is_training, trainable=True, reuse=False, networktype='ganD'):
+ def create_gan_D(xz, is_training, trainable=True, reuse=False, networktype='ganD'):
    with tf.variable_scope(networktype, reuse=reuse):
-         xz = concat_labels(xz, labels)
        Dxz = conv(xz, is_training, kernel_w=5, stride=2, Cout=128, trainable=trainable, act='lrelu', norm=None, name='conv1')  # 12
        Dxz = conv(Dxz, is_training, kernel_w=5, stride=2, Cout=256, trainable=trainable, act='lrelu', norm='batchnorm', name='conv2')  # 4
        Dxz = conv(Dxz, is_training, kernel_w=2, stride=2, Cout=256, trainable=trainable, act='lrelu', norm='batchnorm', name='conv3')  # 2
-         Dxz = conv(Dxz, is_training, kernel_w=2, stride=2, Cout=1, trainable=trainable, act='lrelu', norm='batchnorm', name='conv4')  # 2
+         Dxz = conv(Dxz, is_training, kernel_w=2, stride=2, Cout=1, trainable=trainable, act=None, norm='batchnorm', name='conv4')  # 1
        Dxz = tf.nn.sigmoid(Dxz)
        return Dxz
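        # Note on the discriminator above: with 'VALID' padding (the convention
        # the generator's size comments already imply), spatial size shrinks as
        # floor((in - kernel_w) / stride) + 1: 28 -> 12 -> 4 -> 2, and the final
        # 2x2/stride-2 conv likely reduces that to a single unit per image (the
        # original '# 2' comment on conv4 appears stale). Since conv4 has
        # act=None and is followed by tf.nn.sigmoid, the returned tensor holds
        # probabilities in (0, 1), not raw logits, despite the fakeLogits /
        # realLogits names used at the call sites below.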

- def create_dcgan_trainer(base_lr=1e-4, networktype='dcgan'):
+ def create_dcgan_trainer(base_lr=1e-4, networktype='dcgan', latentDim=100):
    '''Train a Generative Adversarial Network'''
-     # with tf.name_scope('train_%s' % networktype):
    is_training = tf.placeholder(tf.bool, [], 'is_training')

-     inZ = tf.placeholder(tf.float32, [None, 100])  # tf.random_uniform(shape=[batch_size, 100], minval=-1., maxval=1., dtype=tf.float32)
-     inL = tf.placeholder(tf.float32, [None, 10])  # we want to condition the generated output on some parameters of the input
-     inX = tf.placeholder(tf.float32, [None, 28, 28, 1])
+     Zph = tf.placeholder(tf.float32, [None, latentDim])  # tf.random_uniform(shape=[batch_size, 100], minval=-1., maxval=1., dtype=tf.float32)
+     Xph = tf.placeholder(tf.float32, [None, 28, 28, 1])

-     Gz = create_gan_G(inZ, inL, is_training, Cout=1, trainable=True, reuse=False, networktype=networktype + '_G')
+     Gout_op = create_gan_G(Zph, is_training, Cout=1, trainable=True, reuse=False, networktype=networktype + '_G')

-     DGz = create_gan_D(Gz, inL, is_training, trainable=True, reuse=False, networktype=networktype + '_D')
-     Dx = create_gan_D(inX, inL, is_training, trainable=True, reuse=True, networktype=networktype + '_D')
+     fakeLogits = create_gan_D(Gout_op, is_training, trainable=True, reuse=False, networktype=networktype + '_D')
+     realLogits = create_gan_D(Xph, is_training, trainable=True, reuse=True, networktype=networktype + '_D')

-     ganG_var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=networktype + '_G')
-     print(len(ganG_var_list), [var.name for var in ganG_var_list])
+     G_varlist = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=networktype + '_G')
+     print(len(G_varlist), [var.name for var in G_varlist])

-     ganD_var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=networktype + '_D')
-     print(len(ganD_var_list), [var.name for var in ganD_var_list])
+     D_varlist = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=networktype + '_D')
+     print(len(D_varlist), [var.name for var in D_varlist])

-     Gscore = clipped_crossentropy(DGz, tf.ones_like(DGz))
-     Dscore = clipped_crossentropy(DGz, tf.zeros_like(DGz)) + clipped_crossentropy(Dx, tf.ones_like(Dx))
+     Gloss = clipped_crossentropy(fakeLogits, tf.ones_like(fakeLogits))
+     Dloss = clipped_crossentropy(fakeLogits, tf.zeros_like(fakeLogits)) + clipped_crossentropy(realLogits, tf.ones_like(realLogits))
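+     # These targets give the standard non-saturating GAN objective
+     # (Goodfellow et al.): D is pushed toward D(x) = 1 on real data and
+     # D(G(z)) = 0 on samples, i.e. Dloss ~ -[log D(x) + log(1 - D(G(z)))],
+     # while G is trained against flipped labels, Gloss ~ -log D(G(z)).
+     # clipped_crossentropy is assumed to clip its inputs away from 0 and 1
+     # before taking logs, since the D outputs are already sigmoid probabilities.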

-     Gtrain = tf.train.AdamOptimizer(learning_rate=base_lr, beta1=0.5).minimize(Gscore, var_list=ganG_var_list)
-     Dtrain = tf.train.AdamOptimizer(learning_rate=base_lr, beta1=0.5).minimize(Dscore, var_list=ganD_var_list)
+     Gtrain_op = tf.train.AdamOptimizer(learning_rate=base_lr, beta1=0.5).minimize(Gloss, var_list=G_varlist)
+     Dtrain_op = tf.train.AdamOptimizer(learning_rate=base_lr, beta1=0.5).minimize(Dloss, var_list=D_varlist)
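+     # beta1 = 0.5 (rather than Adam's default 0.9) follows the DCGAN paper's
+     # recommendation (Radford et al.). Restricting each optimizer to its own
+     # var_list means a G step cannot move D's weights and vice versa.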

-     return Gtrain, Dtrain, Gscore, Dscore, is_training, inZ, inX, inL, Gz
+     return Gtrain_op, Dtrain_op, Gloss, Dloss, is_training, Zph, Xph, Gout_op

if __name__ == '__main__':
    networktype = 'DCGAN_MNIST'

    batch_size = 128
-     base_lr = 0.0002  # 1e-4
-     epochs = 30
+     base_lr = 2e-4
+     epochs = 1000
+     latentDim = 100

    work_dir = expr_dir + '%s/%s/' % (networktype, datetime.strftime(datetime.today(), '%Y%m%d'))
    if not os.path.exists(work_dir): os.makedirs(work_dir)

-     data, max_iter, test_iter, test_int, disp_int = get_train_params(data_dir + '/' + networktype, batch_size, epochs=epochs, test_in_each_epoch=1, networktype=networktype)
+     data = input_data.read_data_sets(data_dir + '/' + networktype, reshape=False)
+     disp_int = 2 * int(data.train.num_examples / batch_size)  # every two epochs
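+     # disp_int is measured in discriminator iterations: num_examples / batch_size
+     # steps make up one epoch, and with k = 1 below that is one `it` per
+     # minibatch, so snapshots and loss printouts fire roughly every second epoch.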

    tf.reset_default_graph()
    sess = tf.InteractiveSession()

-     Gtrain, Dtrain, Gscore, Dscore, is_training, inZ, inX, inL, Gz = create_dcgan_trainer(base_lr, networktype=networktype)
+     Gtrain_op, Dtrain_op, Gloss, Dloss, is_training, Zph, Xph, Gout_op = create_dcgan_trainer(base_lr, networktype=networktype, latentDim=latentDim)
    tf.global_variables_initializer().run()

    var_list = [var for var in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) if (networktype.lower() in var.name.lower()) and ('adam' not in var.name.lower())]
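    # Keep only this network's variables and drop the optimizer's slot
    # variables (which carry 'Adam' in their names), so checkpoints hold only
    # model weights and stay small.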
-     saver = tf.train.Saver(var_list=var_list, max_to_keep=1000)
+     saver = tf.train.Saver(var_list=var_list, max_to_keep=1000)
    # saver.restore(sess, expr_dir + 'ganMNIST/20170707/214_model.ckpt')
-
-     Z_test = np.random.uniform(size=[batch_size, 100], low=-1., high=1.).astype(np.float32)
-     labels_test = OneHot(np.random.randint(10, size=[batch_size]), n=10)
-
+
    k = 1
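    # k is the number of discriminator updates per generator update
    # (the k of Algorithm 1 in Goodfellow et al.); here a single D step
    # is taken before each G step.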
-
-     for it in range(1, max_iter):
-         Z = np.random.uniform(size=[batch_size, 100], low=-1., high=1.).astype(np.float32)
-         X, labels = data.train.next_batch(batch_size)
-
+     it = 0
+     disp_losses = False
+
+     while data.train.epochs_completed < epochs:
+         dtemploss = 0
+
        for itD in range(k):
-             cur_Dscore, _ = sess.run([Dscore, Dtrain], feed_dict={inX: X, inZ: Z, inL: labels, is_training: True})
+             it += 1
+             Z = np.random.uniform(size=[batch_size, latentDim], low=-1., high=1.).astype(np.float32)
+             X, _ = data.train.next_batch(batch_size)
+
+             cur_Dloss, _ = sess.run([Dloss, Dtrain_op], feed_dict={Xph: X, Zph: Z, is_training: True})
+             dtemploss += cur_Dloss

-         cur_Gscore, _ = sess.run([Gscore, Gtrain], feed_dict={inZ: Z, inL: labels, is_training: True})
+         if it % disp_int == 0: disp_losses = True
+
+         cur_Dloss = dtemploss / k
+
+         Z = np.random.uniform(size=[batch_size, latentDim], low=-1., high=1.).astype(np.float32)
+         cur_Gloss, _ = sess.run([Gloss, Gtrain_op], feed_dict={Zph: Z, is_training: True})
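+         # The reported D loss is the average over the k discriminator steps,
+         # and a fresh noise batch is drawn for the G step rather than reusing
+         # the one from the last D update; since Gtrain_op only touches
+         # G_varlist, this step leaves the discriminator's weights unchanged.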

-         if it % disp_int == 0:
-             Gz_sample = sess.run(Gz, feed_dict={inZ: Z_test, inL: labels_test, is_training: False})
-             vis_square(Gz_sample[:121], [11, 11], save_path=work_dir + 'Iter_%d.jpg' % it)
-             saver.save(sess, work_dir + "%.3d_model.ckpt" % it)
-             if ('cur_Dscore' in vars()) and ('cur_Gscore' in vars()):
-                 print("Iteration #%4d, Train Gscore = %f, Dscore=%f" % (it, cur_Gscore, cur_Dscore))
+         if disp_losses:
+             Gsample = sess.run(Gout_op, feed_dict={Zph: Z, is_training: False})
+             vis_square(Gsample[:121], [11, 11], save_path=work_dir + 'Epoch%.3d.jpg' % data.train.epochs_completed)
+             saver.save(sess, work_dir + "%.3d_model.ckpt" % data.train.epochs_completed)
+             print("Epoch #%.3d, Train Gloss = %f, Dloss=%f" % (data.train.epochs_completed, cur_Gloss, cur_Dloss))
+             disp_losses = False