# -*- coding: utf-8 -*-
'''
Wasserstein GAN - Arjovsky et al. 2017

This work is not an attempt to reproduce the exact results of the cited paper, nor do I confine the implementation to the original authors' suggestions.
I have tried to implement my own limited understanding of the original paper, in the hope of gaining better insight into their work.
Use this code with no warranty, and please respect the accompanying license.
'''
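
# Key departures from a standard (DC)GAN, following the paper:
#   * The discriminator becomes a "critic" f with an unbounded scalar output
#     (no sigmoid); the gap between its mean scores on real and fake samples
#     estimates the Wasserstein-1 distance via Kantorovich-Rubinstein duality:
#         W(P_r, P_g) = sup_{||f||_L <= 1} E_x[f(x)] - E_z[f(G(z))]
#   * The critic's weights are clipped to a small box (here [-0.01, 0.01]) to
#     crudely enforce the Lipschitz constraint.
#   * The critic is updated several times (n_critic) per generator update.
#   * RMSProp is used instead of a momentum-based optimizer such as Adam.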

import sys
sys.path.append('../common')

from tools_config import data_dir, expr_dir
import os
from tools_train import get_train_params, OneHot, vis_square
from datetime import datetime
from tools_general import tf, np
from tools_networks import deconv, conv, dense

def concat_labels(X, labels):
    '''Tile the one-hot labels over the spatial dimensions of X and append them
    as extra feature maps, so the critic can condition on the class.'''
    if X.get_shape().ndims == 4:
        X_shape = tf.shape(X)
        labels_reshaped = tf.reshape(labels, [-1, 1, 1, 10])
        a = tf.ones([X_shape[0], X_shape[1], X_shape[2], 10])
        X = tf.concat([X, labels_reshaped * a], axis=3)
    return X

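# For example, a [batch, 28, 28, 1] image tensor combined with [batch, 10]
# one-hot labels yields a [batch, 28, 28, 11] input for the critic below.
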
def create_gan_G(z, labels, is_training, Cout=1, trainable=True, reuse=False, networktype='ganG'):
    '''input : batchsize * 100 noise and one-hot labels to make the generator conditional
       output: batchsize * 28 * 28 * 1 images'''
    with tf.variable_scope(networktype, reuse=reuse):
        z = tf.concat(axis=-1, values=[z, labels])
        Gout_op = dense(z, is_training, Cout=4 * 4 * 256, act='reLu', norm='batchnorm', name='dense1')
        Gout_op = tf.reshape(Gout_op, shape=[-1, 4, 4, 256])  # 4x4
        Gout_op = deconv(Gout_op, is_training, kernel_w=5, stride=2, Cout=256, trainable=trainable, act='reLu', norm='batchnorm', name='deconv1')  # 11x11
        Gout_op = deconv(Gout_op, is_training, kernel_w=5, stride=2, Cout=128, trainable=trainable, act='reLu', norm='batchnorm', name='deconv2')  # 25x25
        Gout_op = deconv(Gout_op, is_training, kernel_w=4, stride=1, Cout=Cout, trainable=trainable, act=None, norm=None, name='deconv3')  # 28x28
        Gout_op = tf.nn.sigmoid(Gout_op)  # pixels in (0, 1) to match MNIST scaling
    return Gout_op

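# The spatial sizes annotated above assume 'VALID'-style transposed convolutions
# in the deconv helper, i.e. out = (in - 1) * stride + kernel:
# 4 -> (4-1)*2+5 = 11 -> (11-1)*2+5 = 25 -> (25-1)*1+4 = 28.
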
def create_gan_D(xz, labels, is_training, trainable=True, reuse=False, networktype='ganD'):
    '''Critic: maps a (label-conditioned) image to an unbounded scalar score.
       Unlike a standard GAN discriminator, a WGAN critic has no sigmoid output.'''
    with tf.variable_scope(networktype, reuse=reuse):
        xz = concat_labels(xz, labels)
        Dxz = conv(xz, is_training, kernel_w=5, stride=2, Cout=128, trainable=trainable, act='lrelu', norm=None, name='conv1')  # 12x12
        Dxz = conv(Dxz, is_training, kernel_w=5, stride=2, Cout=256, trainable=trainable, act='lrelu', norm='batchnorm', name='conv2')  # 4x4
        Dxz = conv(Dxz, is_training, kernel_w=2, stride=2, Cout=256, trainable=trainable, act='lrelu', norm='batchnorm', name='conv3')  # 2x2
        Dxz = conv(Dxz, is_training, kernel_w=2, stride=2, Cout=1, trainable=trainable, act=None, norm=None, name='conv4')  # 1x1 linear score
    return Dxz

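# Likewise, the critic's sizes assume 'VALID' convolutions in the conv helper,
# out = floor((in - kernel) / stride) + 1: 28 -> 12 -> 4 -> 2 -> 1, i.e. a
# single scalar score per image.
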
def create_wgan_trainer(base_lr=1e-4, networktype='dcgan'):
    '''Build the ops for training a conditional Wasserstein GAN.'''

    is_training = tf.placeholder(tf.bool, [], 'is_training')

    Zph = tf.placeholder(tf.float32, [None, 100])  # noise
    Lph = tf.placeholder(tf.float32, [None, 10])  # one-hot labels, to condition the generated output on the class
    Xph = tf.placeholder(tf.float32, [None, 28, 28, 1])  # real images

    Gout_op = create_gan_G(Zph, Lph, is_training, Cout=1, trainable=True, reuse=False, networktype=networktype + '_G')

    fakeScore_op = create_gan_D(Gout_op, Lph, is_training, trainable=True, reuse=False, networktype=networktype + '_D')
    realScore_op = create_gan_D(Xph, Lph, is_training, trainable=True, reuse=True, networktype=networktype + '_D')

    ganG_var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=networktype + '_G')
    print(len(ganG_var_list), [var.name for var in ganG_var_list])

    ganD_var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=networktype + '_D')
    print(len(ganD_var_list), [var.name for var in ganD_var_list])

    # Sign convention: minimizing Dscore trains the critic to score fakes higher
    # than reals, so -Dscore = mean(fake) - mean(real) is the critic's estimate
    # of the (scaled) Wasserstein distance; the generator then minimizes the fake
    # scores to close the gap. This mirrors the paper's convention but is equivalent.
    Dscore = tf.reduce_mean(realScore_op - fakeScore_op)
    Gscore = tf.reduce_mean(fakeScore_op)

    # Crudely enforce the Lipschitz constraint by clipping the critic's weights.
    # Note: the paper clips *all* critic parameters; here only the weight matrices
    # (variables with '_W' in their name, per the tools_networks naming) are clipped.
    D_weights = [var for var in ganD_var_list if '_W' in var.name]
    D_weights_clip_op = [var.assign(tf.clip_by_value(var, -0.01, 0.01)) for var in D_weights]

    # The paper recommends a non-momentum-based optimizer; RMSProp is used instead of Adam.
    # Gtrain_op = tf.train.AdamOptimizer(learning_rate=base_lr, beta1=0.5).minimize(Gscore, var_list=ganG_var_list)
    # Dtrain_op = tf.train.AdamOptimizer(learning_rate=base_lr, beta1=0.5).minimize(Dscore, var_list=ganD_var_list)
    Gtrain_op = tf.train.RMSPropOptimizer(learning_rate=base_lr, decay=0.9).minimize(Gscore, var_list=ganG_var_list)
    Dtrain_op = tf.train.RMSPropOptimizer(learning_rate=base_lr, decay=0.9).minimize(Dscore, var_list=ganD_var_list)

    return Gtrain_op, Dtrain_op, D_weights_clip_op, Gscore, Dscore, is_training, Zph, Xph, Lph, Gout_op

if __name__ == '__main__':
    networktype = 'WGAN_MNIST'

    batch_size = 128
    base_lr = 5e-5  # 1e-4
    epochs = 300

    work_dir = expr_dir + '%s/%s/' % (networktype, datetime.strftime(datetime.today(), '%Y%m%d'))
    if not os.path.exists(work_dir): os.makedirs(work_dir)

    data, max_iter, test_iter, test_int, disp_int = get_train_params(data_dir + '/' + networktype, batch_size, epochs=epochs, test_in_each_epoch=1, networktype=networktype)

    tf.reset_default_graph()
    sess = tf.InteractiveSession()

    Gtrain_op, Dtrain_op, D_weights_clip_op, Gscore, Dscore, is_training, Zph, Xph, Lph, Gout_op = create_wgan_trainer(base_lr, networktype=networktype)
    tf.global_variables_initializer().run()

    # Exclude optimizer slot variables (RMSProp/Adam) from the checkpoint.
    var_list = [var for var in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) if (networktype.lower() in var.name.lower()) and ('adam' not in var.name.lower()) and ('rmsprop' not in var.name.lower())]
    saver = tf.train.Saver(var_list=var_list, max_to_keep=1000)
    # saver.restore(sess, expr_dir + 'ganMNIST/20170707/214_model.ckpt')

    # Fixed noise and labels, reused at every visualization step for comparability.
    Z_test = np.random.uniform(size=[batch_size, 100], low=-1., high=1.).astype(np.float32)
    labels_test = OneHot(np.random.randint(10, size=[batch_size]), n=10)

    k = 5  # n_critic: critic updates per generator update (the paper's default)

    for it in range(1, max_iter):
        Z = np.random.uniform(size=[batch_size, 100], low=-1., high=1.).astype(np.float32)
        X, labels = data.train.next_batch(batch_size)

        # Train the critic k times, clipping its weights after every step.
        # (The paper samples a fresh minibatch for each critic step; here the
        # same batch is reused across the k steps.)
        for itD in range(k):
            cur_Dscore, _ = sess.run([Dscore, Dtrain_op], feed_dict={Xph: X, Zph: Z, Lph: labels, is_training: True})
            sess.run(D_weights_clip_op)

        # Then train the generator once.
        cur_Gscore, _ = sess.run([Gscore, Gtrain_op], feed_dict={Zph: Z, Lph: labels, is_training: True})

        if it % disp_int == 0:
            Gz_sample = sess.run(Gout_op, feed_dict={Zph: Z_test, Lph: labels_test, is_training: False})
            vis_square(Gz_sample[:121], [11, 11], save_path=work_dir + 'Iter_%d.jpg' % it)
            saver.save(sess, work_dir + "%.3d_model.ckpt" % it)
            if ('cur_Dscore' in vars()) and ('cur_Gscore' in vars()):
                print("Iteration #%4d, Train Gscore = %f, Dscore = %f" % (it, cur_Gscore, cur_Dscore))
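
# A sketch of sampling class-conditional digits from a saved checkpoint after
# training (the checkpoint file name below is hypothetical):
#     saver.restore(sess, work_dir + '300_model.ckpt')
#     Z = np.random.uniform(-1., 1., size=[121, 100]).astype(np.float32)
#     labels = OneHot(np.repeat(np.arange(10), 13)[:121], n=10)
#     digits = sess.run(Gout_op, feed_dict={Zph: Z, Lph: labels, is_training: False})
#     vis_square(digits, [11, 11], save_path=work_dir + 'samples.jpg')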