
Commit 4a52267

removed conditional part
1 parent f8aa066 commit 4a52267
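The "conditional part" removed here is the class-label conditioning: the generator previously received a one-hot label concatenated to its latent vector z, and the discriminator received the same label tiled onto its input feature maps via the now-deleted concat_labels helper. For reference, a minimal TF 1.x sketch of that conditioning pattern is shown below; it is illustrative only, and the function and argument names are not taken from this repository.

import tensorflow as tf  # TF 1.x graph-mode style, matching the code in this diff

def condition_on_labels(z, x, labels, num_classes=10):
    """Illustrative sketch of the removed conditioning: append one-hot labels
    to the latent code and tile them onto the image tensor as extra channels."""
    # Generator input: [batch, latent_dim + num_classes]
    z_cond = tf.concat([z, labels], axis=-1)

    # Discriminator input: broadcast labels to [batch, H, W, num_classes]
    # and stack them behind the image channels.
    x_shape = tf.shape(x)
    labels_map = tf.reshape(labels, [-1, 1, 1, num_classes])
    ones = tf.ones([x_shape[0], x_shape[1], x_shape[2], num_classes])
    x_cond = tf.concat([x, labels_map * ones], axis=3)
    return z_cond, x_cond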

File tree: 1 file changed (+60 −62)

GenerativeAdversarialNetworks/DCGAN.py

Lines changed: 60 additions & 62 deletions
@@ -3,8 +3,6 @@
 Generative Adversarial Networks - Goodfellow et al
 Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks - Radford et al
 
-This work is absolutely not an effort to reproduce exact results of the cited paper, nor I confine my implementations to the suggestion of the original authors.
-I have tried to implement my own limited understanding of the original paper in hope to get a better insight into their work.
 Use this code with no warranty and please respect the accompanying license.
 '''
 

@@ -19,104 +17,104 @@
 from tools_general import tf, np
 from tools_networks import deconv, conv, dense, clipped_crossentropy, dropout
 
-def concat_labels(X, labels):
-    if X.get_shape().ndims == 4:
-        X_shape = tf.shape(X)
-        labels_reshaped = tf.reshape(labels, [-1, 1, 1, 10])
-        a = tf.ones([X_shape[0], X_shape[1], X_shape[2], 10])
-        X = tf.concat([X, labels_reshaped * a], axis=3)
-    return X
-
-def create_gan_G(z, labels, is_training, Cout=1, trainable=True, reuse=False, networktype='ganG'):
-    '''input : batchsize * 100 and labels to make the generator conditional
+from tensorflow.examples.tutorials.mnist import input_data
+
+def create_gan_G(z, is_training, Cout=1, trainable=True, reuse=False, networktype='ganG'):
+    '''input : batchsize * 100
     output: batchsize * 28 * 28 * 1'''
     with tf.variable_scope(networktype, reuse=reuse):
-        z = tf.concat(axis=-1, values=[z, labels])
-        Gz = dense(z, is_training, Cout=4 * 4 * 256, act='reLu', norm='batchnorm', name='dense1')
-        Gz = tf.reshape(Gz, shape=[-1, 4, 4, 256]) # 4
-        Gz = deconv(Gz, is_training, kernel_w=5, stride=2, Cout=256, trainable=trainable, act='reLu', norm='batchnorm', name='deconv1') # 11
-        Gz = deconv(Gz, is_training, kernel_w=5, stride=2, Cout=128, trainable=trainable, act='reLu', norm='batchnorm', name='deconv2') # 25
-        Gz = deconv(Gz, is_training, kernel_w=4, stride=1, Cout=Cout, act=None, norm=None, name='deconv3') # 28
-        Gz = tf.nn.sigmoid(Gz)
-        return Gz
+        Gout_op = dense(z, is_training, Cout=4 * 4 * 256, trainable=trainable, act='reLu', norm='batchnorm', name='dense1')
+        Gout_op = tf.reshape(Gout_op, shape=[-1, 4, 4, 256]) # 4
+        Gout_op = deconv(Gout_op, is_training, kernel_w=5, stride=2, Cout=256, trainable=trainable, act='reLu', norm='batchnorm', name='deconv1') # 11
+        Gout_op = deconv(Gout_op, is_training, kernel_w=5, stride=2, Cout=128, trainable=trainable, act='reLu', norm='batchnorm', name='deconv2') # 25
+        Gout_op = deconv(Gout_op, is_training, kernel_w=4, stride=1, Cout=Cout, trainable=trainable, act=None, norm=None, name='deconv3') # 28
+        Gout_op = tf.nn.sigmoid(Gout_op)
+        return Gout_op
 
-def create_gan_D(xz, labels, is_training, trainable=True, reuse=False, networktype='ganD'):
+def create_gan_D(xz, is_training, trainable=True, reuse=False, networktype='ganD'):
     with tf.variable_scope(networktype, reuse=reuse):
-        xz = concat_labels(xz, labels)
         Dxz = conv(xz, is_training, kernel_w=5, stride=2, Cout=128, trainable=trainable, act='lrelu', norm=None, name='conv1') # 12
         Dxz = conv(Dxz, is_training, kernel_w=5, stride=2, Cout=256, trainable=trainable, act='lrelu', norm='batchnorm', name='conv2') # 4
         Dxz = conv(Dxz, is_training, kernel_w=2, stride=2, Cout=256, trainable=trainable, act='lrelu', norm='batchnorm', name='conv3') # 2
-        Dxz = conv(Dxz, is_training, kernel_w=2, stride=2, Cout=1, trainable=trainable, act='lrelu', norm='batchnorm', name='conv4') # 2
+        Dxz = conv(Dxz, is_training, kernel_w=2, stride=2, Cout=1, trainable=trainable, act=None, norm='batchnorm', name='conv4') # 2
         Dxz = tf.nn.sigmoid(Dxz)
         return Dxz
 
-def create_dcgan_trainer(base_lr=1e-4, networktype='dcgan'):
+def create_dcgan_trainer(base_lr=1e-4, networktype='dcgan', latentDim=100):
     '''Train a Generative Adversarial Network'''
-    # with tf.name_scope('train_%s' % networktype):
     is_training = tf.placeholder(tf.bool, [], 'is_training')
 
-    inZ = tf.placeholder(tf.float32, [None, 100]) # tf.random_uniform(shape=[batch_size, 100], minval=-1., maxval=1., dtype=tf.float32)
-    inL = tf.placeholder(tf.float32, [None, 10]) # we want to condition the generated output on some parameters of the input
-    inX = tf.placeholder(tf.float32, [None, 28, 28, 1])
+    Zph = tf.placeholder(tf.float32, [None, latentDim]) # tf.random_uniform(shape=[batch_size, 100], minval=-1., maxval=1., dtype=tf.float32)
+    Xph = tf.placeholder(tf.float32, [None, 28, 28, 1])
 
-    Gz = create_gan_G(inZ, inL, is_training, Cout=1, trainable=True, reuse=False, networktype=networktype + '_G')
+    Gout_op = create_gan_G(Zph, is_training, Cout=1, trainable=True, reuse=False, networktype=networktype + '_G')
 
-    DGz = create_gan_D(Gz, inL, is_training, trainable=True, reuse=False, networktype=networktype + '_D')
-    Dx = create_gan_D(inX, inL, is_training, trainable=True, reuse=True, networktype=networktype + '_D')
+    fakeLogits = create_gan_D(Gout_op, is_training, trainable=True, reuse=False, networktype=networktype + '_D')
+    realLogits = create_gan_D(Xph, is_training, trainable=True, reuse=True, networktype=networktype + '_D')
 
-    ganG_var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=networktype + '_G')
-    print(len(ganG_var_list), [var.name for var in ganG_var_list])
+    G_varlist = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=networktype + '_G')
+    print(len(G_varlist), [var.name for var in G_varlist])
 
-    ganD_var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=networktype + '_D')
-    print(len(ganD_var_list), [var.name for var in ganD_var_list])
+    D_varlist = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=networktype + '_D')
+    print(len(D_varlist), [var.name for var in D_varlist])
 
-    Gscore = clipped_crossentropy(DGz, tf.ones_like(DGz))
-    Dscore = clipped_crossentropy(DGz, tf.zeros_like(DGz)) + clipped_crossentropy(Dx, tf.ones_like(Dx))
+    Gloss = clipped_crossentropy(fakeLogits, tf.ones_like(fakeLogits))
+    Dloss = clipped_crossentropy(fakeLogits, tf.zeros_like(fakeLogits)) + clipped_crossentropy(realLogits, tf.ones_like(realLogits))
 
-    Gtrain = tf.train.AdamOptimizer(learning_rate=base_lr, beta1=0.5).minimize(Gscore, var_list=ganG_var_list)
-    Dtrain = tf.train.AdamOptimizer(learning_rate=base_lr, beta1=0.5).minimize(Dscore, var_list=ganD_var_list)
+    Gtrain_op = tf.train.AdamOptimizer(learning_rate=base_lr, beta1=0.5).minimize(Gloss, var_list=G_varlist)
+    Dtrain_op = tf.train.AdamOptimizer(learning_rate=base_lr, beta1=0.5).minimize(Dloss, var_list=D_varlist)
 
-    return Gtrain, Dtrain, Gscore, Dscore, is_training, inZ, inX, inL, Gz
+    return Gtrain_op, Dtrain_op, Gloss, Dloss, is_training, Zph, Xph, Gout_op
 
 if __name__ == '__main__':
     networktype = 'DCGAN_MNIST'
 
     batch_size = 128
-    base_lr = 0.0002 # 1e-4
-    epochs = 30
+    base_lr = 2e-4
+    epochs = 1000
+    latentDim = 100
 
     work_dir = expr_dir + '%s/%s/' % (networktype, datetime.strftime(datetime.today(), '%Y%m%d'))
     if not os.path.exists(work_dir): os.makedirs(work_dir)
 
-    data, max_iter, test_iter, test_int, disp_int = get_train_params(data_dir + '/' + networktype, batch_size, epochs=epochs, test_in_each_epoch=1, networktype=networktype)
+    data = input_data.read_data_sets(data_dir + '/' + networktype, reshape=False)
+    disp_int = 2 * int(data.train.num_examples / batch_size) # every two epochs
 
     tf.reset_default_graph()
     sess = tf.InteractiveSession()
 
-    Gtrain, Dtrain, Gscore, Dscore, is_training, inZ, inX, inL, Gz = create_dcgan_trainer(base_lr, networktype=networktype)
+    Gtrain_op, Dtrain_op, Gloss, Dloss, is_training, Zph, Xph, Gout_op = create_dcgan_trainer(base_lr, networktype=networktype)
     tf.global_variables_initializer().run()
 
     var_list = [var for var in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) if (networktype.lower() in var.name.lower()) and ('adam' not in var.name.lower())]
-    saver = tf.train.Saver(var_list=var_list, max_to_keep = 1000)
+    saver = tf.train.Saver(var_list=var_list, max_to_keep=1000)
     # saver.restore(sess, expr_dir + 'ganMNIST/20170707/214_model.ckpt')
-
-    Z_test = np.random.uniform(size=[batch_size, 100], low=-1., high=1.).astype(np.float32)
-    labels_test = OneHot(np.random.randint(10, size=[batch_size]), n=10)
-
+
     k = 1
-
-    for it in range(1, max_iter):
-        Z = np.random.uniform(size=[batch_size, 100], low=-1., high=1.).astype(np.float32)
-        X, labels = data.train.next_batch(batch_size)
-
+    it = 0
+    disp_losses = False
+
+    while data.train.epochs_completed < epochs:
+        dtemploss = 0
+
         for itD in range(k):
-            cur_Dscore, _ = sess.run([Dscore, Dtrain], feed_dict={inX:X, inZ:Z, inL:labels, is_training:True})
+            it += 1
+            Z = np.random.uniform(size=[batch_size, latentDim], low=-1., high=1.).astype(np.float32)
+            X, _ = data.train.next_batch(batch_size)
+
+            cur_Dloss, _ = sess.run([Dloss, Dtrain_op], feed_dict={Xph:X, Zph:Z, is_training:True})
+            dtemploss += cur_Dloss
 
-        cur_Gscore, _ = sess.run([Gscore, Gtrain], feed_dict={inZ:Z, inL:labels, is_training:True})
+        if it % disp_int == 0: disp_losses = True
+
+        cur_Dloss = dtemploss / k
+
+        Z = np.random.uniform(size=[batch_size, latentDim], low=-1., high=1.).astype(np.float32)
+        cur_Gloss, _ = sess.run([Gloss, Gtrain_op], feed_dict={Zph:Z, is_training:True})
 
-        if it % disp_int == 0:
-            Gz_sample = sess.run(Gz, feed_dict={inZ: Z_test, inL: labels_test, is_training:False})
-            vis_square(Gz_sample[:121], [11, 11], save_path=work_dir + 'Iter_%d.jpg' % it)
-            saver.save(sess, work_dir + "%.3d_model.ckpt" % it)
-            if ('cur_Dscore' in vars()) and ('cur_Gscore' in vars()):
-                print("Iteration #%4d, Train Gscore = %f, Dscore=%f" % (it, cur_Gscore, cur_Dscore))
+        if disp_losses:
+            Gsample = sess.run(Gout_op, feed_dict={Zph: Z, is_training:False})
+            vis_square(Gsample[:121], [11, 11], save_path=work_dir + 'Epoch%.3d.jpg' % data.train.epochs_completed)
+            saver.save(sess, work_dir + "%.3d_model.ckpt" % data.train.epochs_completed)
+            print("Epoch #%.3d, Train Gloss = %f, Dloss=%f" % (data.train.epochs_completed, cur_Gloss, cur_Dloss))
+            disp_losses = False
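With the labels gone, sampling from a trained generator needs nothing but a uniform latent batch. Below is a rough usage sketch of how one might restore a checkpoint written by this script and draw unconditional samples; the import path, checkpoint path, and batch size are placeholder assumptions and are not part of this commit.

import numpy as np
import tensorflow as tf

from DCGAN import create_dcgan_trainer  # assumes DCGAN.py is importable as a module

networktype = 'DCGAN_MNIST'
batch_size = 121   # placeholder: a square count that would fit an 11 x 11 grid
latentDim = 100

tf.reset_default_graph()
sess = tf.InteractiveSession()

# Rebuild the same graph the trainer builds, then restore the saved weights.
_, _, _, _, is_training, Zph, Xph, Gout_op = create_dcgan_trainer(base_lr=2e-4, networktype=networktype)
tf.global_variables_initializer().run()

# Mirror the training script's variable filtering so the checkpoint matches.
var_list = [var for var in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
            if (networktype.lower() in var.name.lower()) and ('adam' not in var.name.lower())]
saver = tf.train.Saver(var_list=var_list)
saver.restore(sess, 'experiments/DCGAN_MNIST/20170707/020_model.ckpt')  # hypothetical path

# Unconditional sampling: only the latent placeholder is fed, no labels.
Z = np.random.uniform(size=[batch_size, latentDim], low=-1., high=1.).astype(np.float32)
samples = sess.run(Gout_op, feed_dict={Zph: Z, is_training: False})
print(samples.shape)  # (121, 28, 28, 1); values in [0, 1] from the final sigmoid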
