Skip to content

Commit 518cf58

Browse files
committed
first wasserstein gan implementation
notice the different operations in the training loop w.r.t. the algorithm in the original paper
1 parent 266af84 commit 518cf58

File tree

1 file changed

+128
-0
lines changed
  • GenerativeAdversarialNetworks

1 file changed

+128
-0
lines changed

GenerativeAdversarialNetworks/WGAN.py

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
# -*- coding: utf-8 -*-
2+
'''
3+
Wasserstein GAN - Arjovsky et al. 2017
4+
5+
This work is absolutely not an effort to reproduce the exact results of the cited paper, nor do I confine my implementation to the suggestions of the original authors.
6+
I have tried to implement my own limited understanding of the original paper in the hope of gaining better insight into their work.
7+
Use this code with no warranty and please respect the accompanying license.
8+
'''
9+
10+
import sys
11+
sys.path.append('../common')
12+
13+
from tools_config import data_dir, expr_dir
14+
import os
15+
import matplotlib.pyplot as plt
16+
from tools_train import get_train_params, OneHot, vis_square
17+
from datetime import datetime
18+
from tools_general import tf, np
19+
from tools_networks import deconv, conv, dense, clipped_crossentropy, dropout
20+
21+
def concat_labels(X, labels):
    '''Append the label vector to a 4-D tensor as extra channels.

    When X has static rank 4, every label entry is broadcast over the first
    three axes of X and concatenated along axis 3. For any other rank X is
    returned unchanged and `labels` is ignored.

    X      : tensor; only rank-4 inputs are modified
    labels : tensor reshaped to [-1, 1, 1, n_labels]
             (presumably one-hot, batchsize * n_labels - confirm with caller)
    returns: X with n_labels extra channels on axis 3, or X itself
    '''
    if X.get_shape().ndims == 4:
        # Take the label width from the static shape instead of assuming 10;
        # keep 10 as the fallback so inputs with unknown static shape behave
        # exactly as before.
        label_shape = labels.get_shape()
        n_labels = label_shape.as_list()[-1] if label_shape.ndims else None
        if n_labels is None:
            n_labels = 10

        X_shape = tf.shape(X)
        labels_reshaped = tf.reshape(labels, [-1, 1, 1, n_labels])
        # Multiplying by a tensor of ones tiles the labels over axes 1 and 2.
        a = tf.ones([X_shape[0], X_shape[1], X_shape[2], n_labels])
        X = tf.concat([X, labels_reshaped * a], axis=3)
    return X
28+
29+
def create_gan_G(z, labels, is_training, Cout=1, trainable=True, reuse=False, networktype='ganG'):
    '''Build the conditional generator.

    input : z, batchsize * 100, and labels to make the generator conditional
    output: batchsize * 28 * 28 * Cout, squashed by a sigmoid

    is_training : forwarded to every dense/deconv layer
    Cout        : number of output channels of the generated image
    trainable   : forwarded to the deconv layers
    reuse       : passed to tf.variable_scope
    networktype : name of the variable scope holding the generator

    NOTE(review): dense1 is not given `trainable` because the signature of
    dense() is not visible here - confirm whether it should be forwarded too.
    '''
    with tf.variable_scope(networktype, reuse=reuse):
        z = tf.concat(axis=-1, values=[z, labels])
        Gout_op = dense(z, is_training, Cout=4 * 4 * 256, act='reLu', norm='batchnorm', name='dense1')
        Gout_op = tf.reshape(Gout_op, shape=[-1, 4, 4, 256])  # 4
        Gout_op = deconv(Gout_op, is_training, kernel_w=5, stride=2, Cout=256, trainable=trainable, act='reLu', norm='batchnorm', name='deconv1')  # 11
        Gout_op = deconv(Gout_op, is_training, kernel_w=5, stride=2, Cout=128, trainable=trainable, act='reLu', norm='batchnorm', name='deconv2')  # 25
        # The original passed stride=Cout, Cout=1, which only worked because
        # Cout defaults to 1. The stride is fixed at 1 (25 -> 28 with a kernel
        # of 4, per the size annotations) and Cout sets the channel count.
        Gout_op = deconv(Gout_op, is_training, kernel_w=4, stride=1, Cout=Cout, trainable=trainable, act=None, norm=None, name='deconv3')  # 28
        Gout_op = tf.nn.sigmoid(Gout_op)
    return Gout_op
41+
42+
def create_gan_D(xz, labels, is_training, trainable=True, reuse=False, networktype='ganD'):
    '''Build the conditional critic/discriminator.

    The labels are attached to the input with concat_labels, the result goes
    through four stride-2 'lrelu' convolutions and a final sigmoid.

    xz          : input images (real or generated)
    labels      : conditioning labels, see concat_labels
    is_training : forwarded to every conv layer
    trainable   : forwarded to every conv layer
    reuse       : passed to tf.variable_scope (True to share weights)
    networktype : name of the variable scope holding the critic
    returns     : sigmoid of the last conv output
    '''
    # (kernel_w, Cout, norm, name) for each conv layer, in order.
    # The author annotated the spatial sizes after the first three layers as
    # 12, 4 and 2 (presumably unpadded convs on 28x28 inputs - confirm).
    layer_specs = (
        (5, 128, None, 'conv1'),
        (5, 256, 'batchnorm', 'conv2'),
        (2, 256, 'batchnorm', 'conv3'),
        (2, 1, 'batchnorm', 'conv4'),
    )
    with tf.variable_scope(networktype, reuse=reuse):
        Dxz = concat_labels(xz, labels)
        for kernel_w, n_out, norm, layer_name in layer_specs:
            Dxz = conv(Dxz, is_training, kernel_w=kernel_w, stride=2, Cout=n_out, trainable=trainable, act='lrelu', norm=norm, name=layer_name)
        Dxz = tf.nn.sigmoid(Dxz)
    return Dxz
51+
52+
def create_dcgan_trainer(base_lr=1e-4, networktype='dcgan'):
    '''Build the graph used to train the Wasserstein GAN.

    base_lr     : learning rate of both RMSProp optimizers
    networktype : prefix of the generator ('_G') and critic ('_D') scopes
    returns     : (Gtrain_op, Dtrain_op, D_weights_clip_op, Gscore, Dscore,
                   is_training, Zph, Xph, Lph, Gout_op)
    '''
    G_scope = networktype + '_G'
    D_scope = networktype + '_D'

    def _collect_and_report(scope):
        # Gather every global variable under `scope` and print a summary.
        scoped_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=scope)
        print(len(scoped_vars), [var.name for var in scoped_vars])
        return scoped_vars

    # Inputs: training flag, noise, conditioning labels and real images.
    is_training = tf.placeholder(tf.bool, [], 'is_training')
    Zph = tf.placeholder(tf.float32, [None, 100])
    Lph = tf.placeholder(tf.float32, [None, 10])  # conditions the generated output
    Xph = tf.placeholder(tf.float32, [None, 28, 28, 1])

    # Generator, then the critic on fake images, then (sharing the same
    # variables through reuse=True) the critic on real images.
    Gout_op = create_gan_G(Zph, Lph, is_training, Cout=1, trainable=True, reuse=False, networktype=G_scope)
    fakeLogits_op = create_gan_D(Gout_op, Lph, is_training, trainable=True, reuse=False, networktype=D_scope)
    realLogits_op = create_gan_D(Xph, Lph, is_training, trainable=True, reuse=True, networktype=D_scope)

    ganG_var_list = _collect_and_report(G_scope)
    ganD_var_list = _collect_and_report(D_scope)

    # Both scores are minimised: the critic lowers real - fake, the
    # generator lowers the critic output on its samples.
    Dscore = tf.reduce_mean(realLogits_op - fakeLogits_op)
    Gscore = tf.reduce_mean(fakeLogits_op)

    # Clip every critic variable whose name contains '_W' to [-0.01, 0.01].
    D_weights_clip_op = [
        var.assign(tf.clip_by_value(var, -0.01, 0.01))
        for var in ganD_var_list
        if '_W' in var.name
    ]

    # RMSProp is used for both players (an Adam variant with beta1=0.5 was
    # tried earlier by the author).
    G_optimizer = tf.train.RMSPropOptimizer(learning_rate=base_lr, decay=0.9)
    Gtrain_op = G_optimizer.minimize(Gscore, var_list=ganG_var_list)
    D_optimizer = tf.train.RMSPropOptimizer(learning_rate=base_lr, decay=0.9)
    Dtrain_op = D_optimizer.minimize(Dscore, var_list=ganD_var_list)

    return Gtrain_op, Dtrain_op, D_weights_clip_op, Gscore, Dscore, is_training, Zph, Xph, Lph, Gout_op
85+
86+
if __name__ == '__main__':
    # Script entry point: trains the conditional WGAN and periodically dumps
    # sample grids and checkpoints into a dated experiment directory.
    networktype = 'WGAN_MNIST'

    batch_size = 128
    base_lr = 5e-5  # 1e-4
    epochs = 300

    work_dir = expr_dir + '%s/%s/' % (networktype, datetime.strftime(datetime.today(), '%Y%m%d'))
    if not os.path.exists(work_dir):
        os.makedirs(work_dir)

    data, max_iter, test_iter, test_int, disp_int = get_train_params(data_dir + '/' + networktype, batch_size, epochs=epochs, test_in_each_epoch=1, networktype=networktype)

    tf.reset_default_graph()
    sess = tf.InteractiveSession()

    Gtrain_op, Dtrain_op, D_weights_clip_op, Gscore, Dscore, is_training, Zph, Xph, Lph, Gout_op = create_dcgan_trainer(base_lr, networktype=networktype)
    tf.global_variables_initializer().run()

    # Checkpoint only the network variables. The filter used to skip just
    # 'adam' slots, but the trainer uses RMSProp, so its slot variables were
    # being written into every checkpoint; skip both optimizers' slots.
    optimizer_tags = ('adam', 'rmsprop')
    var_list = [
        var for var in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
        if networktype.lower() in var.name.lower()
        and not any(tag in var.name.lower() for tag in optimizer_tags)
    ]
    saver = tf.train.Saver(var_list=var_list, max_to_keep=1000)
    # saver.restore(sess, expr_dir + 'ganMNIST/20170707/214_model.ckpt')

    # Fixed noise and labels so the sample grids are comparable over time.
    Z_test = np.random.uniform(size=[batch_size, 100], low=-1., high=1.).astype(np.float32)
    labels_test = OneHot(np.random.randint(10, size=[batch_size]), n=10)

    k = 5  # critic updates per generator update

    # Explicit sentinels instead of probing vars() for the names later on.
    cur_Dscore = None
    cur_Gscore = None

    for it in range(1, max_iter):
        Z = np.random.uniform(size=[batch_size, 100], low=-1., high=1.).astype(np.float32)
        X, labels = data.train.next_batch(batch_size)

        # NOTE: the same (X, Z) batch is reused for all k critic updates;
        # this differs from the paper's algorithm and appears deliberate.
        for itD in range(k):
            cur_Dscore, _ = sess.run([Dscore, Dtrain_op], feed_dict={Xph: X, Zph: Z, Lph: labels, is_training: True})
            # Re-clip the critic weights after every critic update.
            sess.run(D_weights_clip_op)

        cur_Gscore, _ = sess.run([Gscore, Gtrain_op], feed_dict={Zph: Z, Lph: labels, is_training: True})

        if it % disp_int == 0:
            Gz_sample = sess.run(Gout_op, feed_dict={Zph: Z_test, Lph: labels_test, is_training: False})
            vis_square(Gz_sample[:121], [11, 11], save_path=work_dir + 'Iter_%d.jpg' % it)
            saver.save(sess, work_dir + "%.3d_model.ckpt" % it)
            if (cur_Dscore is not None) and (cur_Gscore is not None):
                print("Iteration #%4d, Train Gscore = %f, Dscore=%f" % (it, cur_Gscore, cur_Dscore))

0 commit comments

Comments
 (0)