From 7051e1e3b5e217c7d3322afcf9c2c736d53fa8f7 Mon Sep 17 00:00:00 2001
From: zsdonghao
Date: Fri, 3 Feb 2017 13:31:18 +0000
Subject: [PATCH] try cnn_encoder_resnet for E

---
 model.py        | 54 -------------------------------------------------
 train_uim2im.py | 22 +++++++++++---------
 2 files changed, 13 insertions(+), 63 deletions(-)

diff --git a/model.py b/model.py
index 56c35fc2..26d3e62d 100755
--- a/model.py
+++ b/model.py
@@ -603,60 +603,6 @@ def stackG_256(inputs, net_rnn, is_train, reuse):
     # exit(network.outputs)
     return network, logits
 
-# def stackD_64(input_images, net_rnn_embed=None, is_train=True, reuse=False): # same as discriminator_txt2img
-#     # IMPLEMENTATION based on : https://github.com/paarthneekhara/text-to-image/blob/master/model.py
-#     # https://github.com/reedscot/icml2016/blob/master/main_cls_int.lua
-#     w_init = tf.random_normal_initializer(stddev=0.02)
-#     b_init = None # tf.constant_initializer(value=0.0)
-#     gamma_init=tf.random_normal_initializer(1., 0.02)
-#
-#     with tf.variable_scope("stackD", reuse=reuse):
-#         tl.layers.set_name_reuse(reuse)
-#
-#         net_in = InputLayer(input_images, name='stackD_input/images')
-#         net_h0 = Conv2d(net_in, df_dim, (5, 5), (2, 2), act=lambda x: tl.act.lrelu(x, 0.2),
-#                 padding='SAME', W_init=w_init, name='stackD_h0/conv2d') # (64, 32, 32, 64)
-#
-#         net_h1 = Conv2d(net_h0, df_dim*2, (5, 5), (2, 2), act=None,
-#                 padding='SAME', W_init=w_init, b_init=b_init, name='stackD_h1/conv2d')
-#         net_h1 = BatchNormLayer(net_h1, act=lambda x: tl.act.lrelu(x, 0.2),
-#                 is_train=is_train, gamma_init=gamma_init, name='stackD_h1/batchnorm') # (64, 16, 16, 128)
-#
-#         net_h2 = Conv2d(net_h1, df_dim*4, (5, 5), (2, 2), act=None,
-#                 padding='SAME', W_init=w_init, b_init=b_init, name='stackD_h2/conv2d')
-#         net_h2 = BatchNormLayer(net_h2, act=lambda x: tl.act.lrelu(x, 0.2),
-#                 is_train=is_train, gamma_init=gamma_init, name='stackD_h2/batchnorm') # (64, 8, 8, 256)
-#
-#         net_h3 = Conv2d(net_h2, df_dim*8, (5, 5), (2, 2), act=None,
-#                 padding='SAME', W_init=w_init, b_init=b_init, name='stackD_h3/conv2d')
-#         net_h3 = BatchNormLayer(net_h3, act=lambda x: tl.act.lrelu(x, 0.2),
-#                 is_train=is_train, gamma_init=gamma_init, name='stackD_h3/batchnorm') # (64, 4, 4, 512) paper 4.1: when the spatial dim of the D is 4x4, we replicate the description embedding spatially and perform a depth concatenation
-#
-#         if net_rnn_embed is not None:
-#             # paper : reduce the dim of description embedding in (seperate) FC layer followed by rectification
-#             net_reduced_text = DenseLayer(net_rnn_embed, n_units=t_dim,
-#                     act=lambda x: tl.act.lrelu(x, 0.2),
-#                     W_init=w_init, b_init=None, name='stackD_reduce_txt/dense')
-#             # net_reduced_text = net_rnn_embed # if reduce_txt in rnn_embed
-#             net_reduced_text.outputs = tf.expand_dims(net_reduced_text.outputs, 1) # you can use ExpandDimsLayer and TileLayer instead
-#             net_reduced_text.outputs = tf.expand_dims(net_reduced_text.outputs, 2)
-#             net_reduced_text.outputs = tf.tile(net_reduced_text.outputs, [1, 4, 4, 1], name='stackD_tiled_embeddings')
-#
-#             net_h3_concat = ConcatLayer([net_h3, net_reduced_text], concat_dim=3, name='stackD_h3_concat') # (64, 4, 4, 640)
-#             # net_h3_concat = net_h3 # no text info
-#             net_h3 = Conv2d(net_h3_concat, df_dim*8, (1, 1), (1, 1),
-#                     padding='SAME', W_init=w_init, b_init=b_init, name='stackD_h3/conv2d_2') # paper 4.1: perform 1x1 conv followed by rectification and a 4x4 conv to compute the final score from D
-#             net_h3 = BatchNormLayer(net_h3, act=lambda x: tl.act.lrelu(x, 0.2),
-#                     is_train=is_train, gamma_init=gamma_init, name='stackD_h3/batch_norm_2') # (64, 4, 4, 512)
-#         else:
-#             print("No text info will be used, i.e. normal DCGAN")
-#
-#         net_h4 = FlattenLayer(net_h3, name='stackD_h4/flatten') # (64, 8192)
-#         net_h4 = DenseLayer(net_h4, n_units=1, act=tf.identity,
-#                 W_init = w_init, name='stackD_h4/dense')
-#         logits = net_h4.outputs
-#         net_h4.outputs = tf.nn.sigmoid(net_h4.outputs) # (64, 1)
-#     return net_h4, logits
 
 def stackD_256(input_images, net_rnn_embed=None, is_train=True, reuse=False): # same as discriminator_txt2img
     """ 256x256 -> real fake """
diff --git a/train_uim2im.py b/train_uim2im.py
index 33aad194..95bd75f5 100755
--- a/train_uim2im.py
+++ b/train_uim2im.py
@@ -303,13 +303,14 @@ def main_train_imageEncoder():
     # E_256,        2000: 0.87   6000: 0.8   10000: 0.77   13172: 0.76
     # E_256, resid  57720: 0.72
     is_stackGAN = True          # use stackGAN and use E with 256x256x3 input
-    is_weighted_loss = False    # use weighted loss
+    is_weighted_loss = True     # use weighted loss
 
     if is_stackGAN:
         stackG = model.stackG_256
         stackD = model.stackD_256
-        # cnn_encoder = model.cnn_encoder_256
-        cnn_encoder = model.cnn_encoder             # if use DownSampling2dLayer
+        # cnn_encoder = model.cnn_encoder_256       # if 256 input
+        # cnn_encoder = model.cnn_encoder           # if 64 input
+        cnn_encoder = model.cnn_encoder_resnet
     else:
         cnn_encoder = model.cnn_encoder
 
@@ -522,13 +523,13 @@ def main_train_imageEncoder():
                 tl.files.save_npz(net_p.all_params, name=net_p_name, sess=sess)
                 # tl.files.save_npz(net_p.all_params, name=net_p_name + "_" + str(step), sess=sess)
 
-
 def main_translation():
     is_stackGAN = True  # use stackGAN and use E with 256x256x3 input, otherwise, 64x64x3 as input
     if is_stackGAN:
         stackG = model.stackG_256
         image_size = 64                             # for 64 input
         cnn_encoder = model.cnn_encoder             # for 64 input
+        cnn_encoder = model.cnn_encoder_resnet      # for 64 input
         # image_size = 256                          # for 256 input
         # cnn_encoder = model.cnn_encoder_256       # for 256 input
         # images_test = images_test_256             # for 256 input
@@ -619,11 +620,14 @@ def main_translation():
 
     b_caption = captions_ids_test[idexs]                                # for debug  sample_sentence = b_caption
     b_caption = tl.prepro.pad_sequences(b_caption, padding='post')      # for debug  sample_sentence = b_caption
-    # b_z = np.random.normal(loc=0.0, scale=1.0, size=(sample_size, z_dim)).astype(np.float32)   # use fake image
-    # b_images = sess.run(net_g2.outputs, feed_dict={                   # use fake image
-    #                 t_z : b_z,                                        # use fake image
-    #                 t_caption : b_caption,                            # use fake image
-    #                 })                                                # use fake image
+    b_z = np.random.normal(loc=0.0, scale=1.0, size=(sample_size, z_dim)).astype(np.float32)     # use fake image
+    b_images = sess.run(net_g2.outputs, feed_dict={                     # use fake image
+                    t_z : b_z,                                          # use fake image
+                    t_caption : b_caption,                              # use fake image
+                    })                                                  # use fake image
+    if is_stackGAN:
+        b_images = threading_data(b_images, imresize, size=[64, 64], interp='bilinear')
+
     b_images = threading_data(b_images, prepro_img, mode='translation')
     # sample_sentence = change_id(b_caption, color_ids, vocab.word_to_id("yellow"))
     sample_sentence = b_caption    # reconstruct from same sentences, test performance of reconstruction