Commit 7051e1e

try cnn_encoder_resnet for E

zsdonghao committed Feb 3, 2017
1 parent a7c19d0
Showing 2 changed files with 13 additions and 63 deletions.

model.py (0 additions, 54 deletions)
@@ -603,60 +603,6 @@ def stackG_256(inputs, net_rnn, is_train, reuse):
     # exit(network.outputs)
     return network, logits

-# def stackD_64(input_images, net_rnn_embed=None, is_train=True, reuse=False): # same as discriminator_txt2img
-# # IMPLEMENTATION based on : https://github.com/paarthneekhara/text-to-image/blob/master/model.py
-# # https://github.com/reedscot/icml2016/blob/master/main_cls_int.lua
-# w_init = tf.random_normal_initializer(stddev=0.02)
-# b_init = None # tf.constant_initializer(value=0.0)
-# gamma_init=tf.random_normal_initializer(1., 0.02)
-#
-# with tf.variable_scope("stackD", reuse=reuse):
-# tl.layers.set_name_reuse(reuse)
-#
-# net_in = InputLayer(input_images, name='stackD_input/images')
-# net_h0 = Conv2d(net_in, df_dim, (5, 5), (2, 2), act=lambda x: tl.act.lrelu(x, 0.2),
-# padding='SAME', W_init=w_init, name='stackD_h0/conv2d') # (64, 32, 32, 64)
-#
-# net_h1 = Conv2d(net_h0, df_dim*2, (5, 5), (2, 2), act=None,
-# padding='SAME', W_init=w_init, b_init=b_init, name='stackD_h1/conv2d')
-# net_h1 = BatchNormLayer(net_h1, act=lambda x: tl.act.lrelu(x, 0.2),
-# is_train=is_train, gamma_init=gamma_init, name='stackD_h1/batchnorm') # (64, 16, 16, 128)
-#
-# net_h2 = Conv2d(net_h1, df_dim*4, (5, 5), (2, 2), act=None,
-# padding='SAME', W_init=w_init, b_init=b_init, name='stackD_h2/conv2d')
-# net_h2 = BatchNormLayer(net_h2, act=lambda x: tl.act.lrelu(x, 0.2),
-# is_train=is_train, gamma_init=gamma_init, name='stackD_h2/batchnorm') # (64, 8, 8, 256)
-#
-# net_h3 = Conv2d(net_h2, df_dim*8, (5, 5), (2, 2), act=None,
-# padding='SAME', W_init=w_init, b_init=b_init, name='stackD_h3/conv2d')
-# net_h3 = BatchNormLayer(net_h3, act=lambda x: tl.act.lrelu(x, 0.2),
-# is_train=is_train, gamma_init=gamma_init, name='stackD_h3/batchnorm') # (64, 4, 4, 512) paper 4.1: when the spatial dim of the D is 4x4, we replicate the description embedding spatially and perform a depth concatenation
-#
-# if net_rnn_embed is not None:
-# # paper : reduce the dim of description embedding in (seperate) FC layer followed by rectification
-# net_reduced_text = DenseLayer(net_rnn_embed, n_units=t_dim,
-# act=lambda x: tl.act.lrelu(x, 0.2),
-# W_init=w_init, b_init=None, name='stackD_reduce_txt/dense')
-# # net_reduced_text = net_rnn_embed # if reduce_txt in rnn_embed
-# net_reduced_text.outputs = tf.expand_dims(net_reduced_text.outputs, 1) # you can use ExpandDimsLayer and TileLayer instead
-# net_reduced_text.outputs = tf.expand_dims(net_reduced_text.outputs, 2)
-# net_reduced_text.outputs = tf.tile(net_reduced_text.outputs, [1, 4, 4, 1], name='stackD_tiled_embeddings')
-#
-# net_h3_concat = ConcatLayer([net_h3, net_reduced_text], concat_dim=3, name='stackD_h3_concat') # (64, 4, 4, 640)
-# # net_h3_concat = net_h3 # no text info
-# net_h3 = Conv2d(net_h3_concat, df_dim*8, (1, 1), (1, 1),
-# padding='SAME', W_init=w_init, b_init=b_init, name='stackD_h3/conv2d_2') # paper 4.1: perform 1x1 conv followed by rectification and a 4x4 conv to compute the final score from D
-# net_h3 = BatchNormLayer(net_h3, act=lambda x: tl.act.lrelu(x, 0.2),
-# is_train=is_train, gamma_init=gamma_init, name='stackD_h3/batch_norm_2') # (64, 4, 4, 512)
-# else:
-# print("No text info will be used, i.e. normal DCGAN")
-#
-# net_h4 = FlattenLayer(net_h3, name='stackD_h4/flatten') # (64, 8192)
-# net_h4 = DenseLayer(net_h4, n_units=1, act=tf.identity,
-# W_init = w_init, name='stackD_h4/dense')
-# logits = net_h4.outputs
-# net_h4.outputs = tf.nn.sigmoid(net_h4.outputs) # (64, 1)
-# return net_h4, logits

 def stackD_256(input_images, net_rnn_embed=None, is_train=True, reuse=False): # same as discriminator_txt2img
     """ 256x256 -> real fake """
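The removed stackD_64 block above implements the trick from section 4.1 of the paper cited in its comments: once the discriminator's spatial map reaches 4x4, the reduced text embedding is replicated spatially and depth-concatenated with the image features. The deleted code does the replication by assigning to net_reduced_text.outputs directly; its own comment notes that ExpandDimsLayer and TileLayer can do the same. A minimal sketch of that layer-based variant, assuming model.py's usual imports (tensorflow as tf, tensorlayer as tl, from tensorlayer.layers import *) and the in-scope net_rnn_embed, net_h3, t_dim, and w_init; the layer names here are illustrative:

    # hypothetical layer-based variant of the replication in the deleted block
    net_reduced_text = DenseLayer(net_rnn_embed, n_units=t_dim,
            act=lambda x: tl.act.lrelu(x, 0.2),
            W_init=w_init, b_init=None, name='d_reduce_txt/dense')                 # (batch, t_dim)
    net_reduced_text = ExpandDimsLayer(net_reduced_text, axis=1, name='d_txt/expanddim1')
    net_reduced_text = ExpandDimsLayer(net_reduced_text, axis=2, name='d_txt/expanddim2')  # (batch, 1, 1, t_dim)
    net_reduced_text = TileLayer(net_reduced_text, multiples=[1, 4, 4, 1], name='d_txt/tile')  # (batch, 4, 4, t_dim)
    net_h3_concat = ConcatLayer([net_h3, net_reduced_text], concat_dim=3, name='d_h3/concat')  # (batch, 4, 4, 640) in the original comments

The 1x1 convolution that follows in the deleted code then mixes the concatenated channels before the final sigmoid score.
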
train_uim2im.py (13 additions, 9 deletions)
@@ -303,13 +303,14 @@ def main_train_imageEncoder():
     # E_256, 2000: 0.87 6000: 0.8 10000: 0.77 13172: 0.76
     # E_256, resid 57720: 0.72
     is_stackGAN = True # use stackGAN and use E with 256x256x3 input
-    is_weighted_loss = False # use weighted loss
+    is_weighted_loss = True # use weighted loss

     if is_stackGAN:
         stackG = model.stackG_256
         stackD = model.stackD_256
-        # cnn_encoder = model.cnn_encoder_256
-        cnn_encoder = model.cnn_encoder # if use DownSampling2dLayer
+        # cnn_encoder = model.cnn_encoder_256 # if 256 input
+        # cnn_encoder = model.cnn_encoder # if 64 input
+        cnn_encoder = model.cnn_encoder_resnet
     else:
         cnn_encoder = model.cnn_encoder

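model.cnn_encoder_resnet is defined in model.py and is not part of this diff; only the selection here changes. The step: loss notes at the top of this hunk (E_256 reaching 0.76 by step 13172 vs 0.72 for the resid run at 57720) look like the records motivating the switch. For orientation, a minimal sketch of what a residual encoder block looks like in this codebase's TensorLayer style; this is an assumed illustration, not the repository's actual definition, with df_dim, z_dim, and all names standing in for the real ones:

    # hypothetical sketch only; the real cnn_encoder_resnet lives in model.py
    # assumes model.py's imports: tensorflow as tf, tensorlayer as tl,
    # from tensorlayer.layers import *; df_dim and z_dim are module globals
    def cnn_encoder_resnet_sketch(inputs, is_train=True, reuse=False):
        w_init = tf.random_normal_initializer(stddev=0.02)
        gamma_init = tf.random_normal_initializer(1., 0.02)
        with tf.variable_scope("cnn_encoder", reuse=reuse):
            tl.layers.set_name_reuse(reuse)
            net_in = InputLayer(inputs, name='e_in/images')                    # (batch, 64, 64, 3)
            net_h0 = Conv2d(net_in, df_dim, (4, 4), (2, 2), act=lambda x: tl.act.lrelu(x, 0.2),
                    padding='SAME', W_init=w_init, name='e_h0/conv2d')         # (batch, 32, 32, df_dim)
            # residual block: two 3x3 convs whose output is added back to the block input
            net_r = Conv2d(net_h0, df_dim, (3, 3), (1, 1), act=None,
                    padding='SAME', W_init=w_init, name='e_res/conv2d_1')
            net_r = BatchNormLayer(net_r, act=lambda x: tl.act.lrelu(x, 0.2),
                    is_train=is_train, gamma_init=gamma_init, name='e_res/batchnorm_1')
            net_r = Conv2d(net_r, df_dim, (3, 3), (1, 1), act=None,
                    padding='SAME', W_init=w_init, name='e_res/conv2d_2')
            net_r = BatchNormLayer(net_r, act=None,
                    is_train=is_train, gamma_init=gamma_init, name='e_res/batchnorm_2')
            net_h1 = ElementwiseLayer([net_h0, net_r], combine_fn=tf.add, name='e_res/add')
            net_h1.outputs = tl.act.lrelu(net_h1.outputs, 0.2)                 # rectify after the skip connection
            net_h2 = FlattenLayer(net_h1, name='e_h2/flatten')
            net_out = DenseLayer(net_h2, n_units=z_dim, act=tf.identity,
                    W_init=w_init, name='e_out/dense')                         # (batch, z_dim)
        return net_out
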
@@ -522,13 +523,13 @@ def main_train_imageEncoder():
         tl.files.save_npz(net_p.all_params, name=net_p_name, sess=sess)
         # tl.files.save_npz(net_p.all_params, name=net_p_name + "_" + str(step), sess=sess)


 def main_translation():
     is_stackGAN = True # use stackGAN and use E with 256x256x3 input, otherwise, 64x64x3 as input
     if is_stackGAN:
         stackG = model.stackG_256
         image_size = 64 # for 64 input
-        cnn_encoder = model.cnn_encoder # for 64 input
+        cnn_encoder = model.cnn_encoder_resnet # for 64 input
         # image_size = 256 # for 256 input
         # cnn_encoder = model.cnn_encoder_256 # for 256 input
         # images_test = images_test_256 # for 256 input
@@ -619,11 +620,14 @@ def main_translation():
     b_caption = captions_ids_test[idexs] # for debug sample_sentence = b_caption
     b_caption = tl.prepro.pad_sequences(b_caption, padding='post') # for debug sample_sentence = b_caption

-    # b_z = np.random.normal(loc=0.0, scale=1.0, size=(sample_size, z_dim)).astype(np.float32) # use fake image
-    # b_images = sess.run(net_g2.outputs, feed_dict={ # use fake image
-    #                 t_z : b_z, # use fake image
-    #                 t_caption : b_caption, # use fake image
-    #                 }) # use fake image
+    b_z = np.random.normal(loc=0.0, scale=1.0, size=(sample_size, z_dim)).astype(np.float32) # use fake image
+    b_images = sess.run(net_g2.outputs, feed_dict={ # use fake image
+                    t_z : b_z, # use fake image
+                    t_caption : b_caption, # use fake image
+                    }) # use fake image
+    if is_stackGAN:
+        b_images = threading_data(b_images, imresize, size=[64, 64], interp='bilinear')
     b_images = threading_data(b_images, prepro_img, mode='translation')

     # sample_sentence = change_id(b_caption, color_ids, vocab.word_to_id("yellow"))
     sample_sentence = b_caption # reconstruct from same sentences, test performance of reconstruction
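
With is_stackGAN enabled, net_g2 (stackG_256) emits 256x256x3 samples while this path configures the encoder for 64x64x3 input (image_size = 64 above), so the fake batch is bilinearly downsized before prepro_img. A standalone illustration of that resize step, assuming scipy.misc.imresize and tensorlayer.prepro.threading_data as imported elsewhere in this repo (note that imresize returns uint8, which prepro_img is then expected to rescale):

    import numpy as np
    from scipy.misc import imresize
    from tensorlayer.prepro import threading_data

    # stand-in for a batch of generator outputs in [-1, 1]
    b_images = np.random.uniform(-1, 1, size=(16, 256, 256, 3)).astype(np.float32)
    b_images = threading_data(b_images, imresize, size=[64, 64], interp='bilinear')
    print(b_images.shape)  # (16, 64, 64, 3), now uint8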