Skip to content

Commit

Permalink
DONE
Browse files Browse the repository at this point in the history
  • Loading branch information
zsdonghao committed Jan 19, 2017
1 parent 76f2ef0 commit e5b95c3
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 62 deletions.
6 changes: 3 additions & 3 deletions model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
## GAN for text to img =========================================================
batch_size = 64
vocab_size = 8000
word_embedding_size = 512 # paper said 1024 char-CNN-RNN
rnn_hidden_size = 256
keep_prob = 0.7
word_embedding_size = 256 # paper said 1024 char-CNN-RNN
rnn_hidden_size = 128#256
keep_prob = 1.0
z_dim = 100 # Noise dimension
t_dim = 128 # Text feature dimension # paper said 128
image_size = 64 # 64 x 64
Expand Down
21 changes: 16 additions & 5 deletions train_dcgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@

pp = pprint.PrettyPrinter()

os.system("mkdir samples")
"""
TensorLayer implementation of DCGAN to generate image.
Expand All @@ -46,6 +45,12 @@
flags.DEFINE_boolean("visualize", False, "True for visualizing, False for nothing [False]")
FLAGS = flags.FLAGS

os.system("mkdir samples")
os.system("mkdir checkpoint")
os.system("mkdir samples/"+FLAGS.dataset+"_dcgan")
os.system("mkdir checkpoint/"+FLAGS.dataset+"_dcgan")


def merge(images, size):
h, w = images.shape[1], images.shape[2]
img = np.zeros((h * size[0], w * size[1], 3))
Expand Down Expand Up @@ -97,9 +102,14 @@ def save_images(images, size, image_path):

## store all captions ids in list
captions_ids = []
for key, value in captions_dict.iteritems():
try: # python3
tmp = captions_dict.items()
except: # python3
tmp = captions_dict.iteritems()
for key, value in tmp:
# for key, value in captions_dict.iteritems():
for v in value:
captions_ids.append( [vocab.word_to_id(word) for word in nltk.tokenize.word_tokenize(v)] )
captions_ids.append( [vocab.word_to_id(word) for word in nltk.tokenize.word_tokenize(v) + [vocab.end_id]])
# print(v) # prominent purple stigma,petals are white inc olor
# print(captions_ids) # [[152, 19, 33, 15, 3, 8, 14, 719, 723]]
# exit()
Expand Down Expand Up @@ -180,7 +190,8 @@ def main(_):
.minimize(g_loss, var_list=g_vars)

sess=tf.Session()
sess.run(tf.initialize_all_variables())
# sess.run(tf.initialize_all_variables())
tl.layers.initialize_global_variables(sess)

# load checkpoints
print("[*] Loading checkpoints...")
Expand All @@ -204,7 +215,7 @@ def main(_):
for epoch in range(FLAGS.epoch):
idexs = get_random_int(min=0, max=n_captions-1, number=FLAGS.batch_size)
sample_images = images[np.floor(np.asarray(idexs).astype('float')/n_captions_per_image).astype('int')]
print("[*]Sample images updated!")
print("[*] Sample images updated!")

batch_idxs = int(n_images / FLAGS.batch_size)

Expand Down
86 changes: 33 additions & 53 deletions train_txt2im.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
from utils import *
from model import *


os.system("mkdir samples")
os.system("mkdir checkpoint")

""" Generative Adversarial Text to Image Synthesis
Expand Down Expand Up @@ -78,7 +78,7 @@
captions_ids = []
try: # python3
tmp = captions_dict.items()
except:
except: # python3
tmp = captions_dict.iteritems()
for key, value in tmp:
for v in value:
Expand Down Expand Up @@ -163,31 +163,6 @@
###======================== DEFIINE MODEL ===================================###
## define data augmentation method
from tensorlayer.prepro import *
def prepro_img(x, mode=None):
if mode=='train':
# rescale [0, 255] --> (-1, 1), random flip, crop, rotate
# paper 5.1: During mini-batch selection for training we randomly pick
# an image view (e.g. crop, flip) of the image and one of the captions
# flip, rotate, crop, resize : https://github.com/reedscot/icml2016/blob/master/data/donkey_folder_coco.lua
# flip : https://github.com/paarthneekhara/text-to-image/blob/master/Utils/image_processing.py
# x = flip_axis(x, axis=1, is_random=True)
# x = rotation(x, rg=16, is_random=True, fill_mode='nearest')
# x = crop(x, wrg=50, hrg=50, is_random=True)
# x = imresize(x, size=[64, 64], interp='bilinear', mode=None)
x = x / (255. / 2.)
x = x - 1.
elif mode=='rescale':
# rescale (-1, 1) --> (0, 1) for display
x = (x + 1.) / 2.
elif mode=='debug':
x = flip_axis(x, axis=1, is_random=False)
# x = rotation(x, rg=16, is_random=False, fill_mode='nearest')
# x = crop(x, wrg=50, hrg=50, is_random=True)
# x = imresize(x, size=[64, 64], interp='bilinear', mode=None)
x = x / 255.
else:
raise Exception("Not support : %s" % mode)
return x

## you may want to see how the data augmentation work
# save_images(images[:64], [8, 8], 'temp.png')
Expand Down Expand Up @@ -290,50 +265,54 @@ def prepro_img(x, mode=None):
if not os.path.exists(save_dir):
print("[!] Folder (%s) is not exist, creating it ..." % save_dir)
os.mkdir(save_dir)

# load the latest checkpoints
net_e_name = os.path.join(save_dir, 'net_e.npz')
net_c_name = os.path.join(save_dir, 'net_c.npz')
net_g_name = os.path.join(save_dir, 'net_g.npz')
net_d_name = os.path.join(save_dir, 'net_d.npz')
if not (os.path.exists(net_e_name) and os.path.exists(net_c_name)):
print("[!] Loading RNN and CNN checkpoints failed!")
else:
net_c_loaded_params = tl.files.load_npz(name=net_c_name)
net_e_loaded_params = tl.files.load_npz(name=net_e_name)
tl.files.assign_params(sess, net_c_loaded_params, net_cnn)
tl.files.assign_params(sess, net_e_loaded_params, net_rnn)
print("[*] Loading RNN and CNN checkpoints SUCCESS!")

if not (os.path.exists(net_g_name) and os.path.exists(net_d_name)):
print("[!] Loading G and D checkpoints failed!")
else:
net_g_loaded_params = tl.files.load_npz(name=net_g_name)
net_d_loaded_params = tl.files.load_npz(name=net_d_name)
tl.files.assign_params(sess, net_g_loaded_params, net_g)
tl.files.assign_params(sess, net_d_loaded_params, net_d)
print("[*] Loading G and D checkpoints SUCCESS!")
if False:
if not (os.path.exists(net_e_name) and os.path.exists(net_c_name)):
print("[!] Loading RNN and CNN checkpoints failed!")
else:
net_c_loaded_params = tl.files.load_npz(name=net_c_name)
net_e_loaded_params = tl.files.load_npz(name=net_e_name)
tl.files.assign_params(sess, net_c_loaded_params, net_cnn)
tl.files.assign_params(sess, net_e_loaded_params, net_rnn)
print("[*] Loading RNN and CNN checkpoints SUCCESS!")

if not (os.path.exists(net_g_name) and os.path.exists(net_d_name)):
print("[!] Loading G and D checkpoints failed!")
else:
net_g_loaded_params = tl.files.load_npz(name=net_g_name)
net_d_loaded_params = tl.files.load_npz(name=net_d_name)
tl.files.assign_params(sess, net_g_loaded_params, net_g)
tl.files.assign_params(sess, net_d_loaded_params, net_d)
print("[*] Loading G and D checkpoints SUCCESS!")

# sess=tf.Session()
# tl.ops.set_gpu_fraction(sess=sess, gpu_fraction=0.998)
# sess.run(tf.initialize_all_variables())

## seed for generation, z and sentence ids
sample_size = batch_size
sample_seed = np.random.uniform(low=-1, high=1, size=(sample_size, z_dim)).astype(np.float32) # paper said [0, 1]
sample_seed = np.random.normal(loc=0.0, scale=1.0, size=(sample_size, z_dim)).astype(np.float32)
# sample_seed = np.random.uniform(low=-1, high=1, size=(sample_size, z_dim)).astype(np.float32) # paper said [0, 1]
# sample_sentence = ["this white and yellow flower have thin white petals and a round yellow stamen", \
# "the flower has petals that are bright pinkish purple with white stigma"] * 32
# sample_sentence = ["these flowers have petals that start off white in color and end in a dark purple towards the tips"] * 32 + \
# ["bright droopy yellow petals with burgundy streaks and a yellow stigma"] * 32
# sample_sentence = ["these white flowers have petals that start off white in color and end in a white towards the tips",
# "this yellow petals with burgundy streaks and a yellow stigma"] * 32
sample_sentence = ["these white flowers have petals that start off white in color and end in a white towards the tips."] * int(sample_size/8) + \
["this yellow petals with burgundy streaks and a yellow stigma."] * int(sample_size/8) + \
["the flower shown has yellow anther red pistil and bright red petals."] * int(sample_size/8) + \
sample_sentence = ["the flower shown has yellow anther red pistil and bright red petals."] * int(sample_size/8) + \
["this flower has petals that are yellow, white and purple and has dark lines"] * int(sample_size/8) + \
["the petals on this flower are white with a yellow center"] * int(sample_size/8) + \
["this flower has a lot of small round pink petals."] * int(sample_size/8) + \
["this flower is orange in color, and has petals that are ruffled and rounded."] * int(sample_size/8) + \
["the flower has yellow petals and the center of it is brown."] * int(sample_size/8) + \
["this flower has petals that are yellow, white and purple and has dark lines"] * int(sample_size/8) + \
["this flower has petals that are blue and white."] * int(sample_size/8)
["this flower has petals that are blue and white."] * int(sample_size/8) +\
["these white flowers have petals that start off white in color and end in a white towards the tips."] * int(sample_size/8)

# sample_sentence = captions_ids_test[0:sample_size]
for i, sentence in enumerate(sample_sentence):
print("seed: %s" % sentence)
Expand All @@ -343,7 +322,7 @@ def prepro_img(x, mode=None):
sample_sentence = tl.prepro.pad_sequences(sample_sentence, padding='post')


n_epoch = 1000 # 600 when pre-trained rnn
n_epoch = 100 # 600 when pre-trained rnn
print_freq = 1
n_batch_epoch = int(n_images / batch_size)
for epoch in range(n_epoch):
Expand All @@ -364,7 +343,8 @@ def prepro_img(x, mode=None):
idexs2 = get_random_int(min=0, max=n_images_train-1, number=batch_size) # remove if DCGAN only
b_wrong_images = images_train[idexs2] # remove if DCGAN only
## get noise
b_z = np.random.uniform(low=-1, high=1, size=[batch_size, z_dim]).astype(np.float32) # paper said [0, 1], but [-1, 1] is better
b_z = np.random.normal(loc=0.0, scale=1.0, size=(sample_size, z_dim)).astype(np.float32)
# b_z = np.random.uniform(low=-1, high=1, size=[batch_size, z_dim]).astype(np.float32) # paper said [0, 1], but [-1, 1] is better
## check data
# print(np.min(b_real_images), np.max(b_real_images), b_real_images.shape) # [0, 1] (64, 64, 64, 3)
# for i, seq in enumerate(b_real_caption):
Expand All @@ -374,7 +354,7 @@ def prepro_img(x, mode=None):
# exit()

## updates text-to-image mapping
if epoch < 30:
if epoch < 10:
errE, _ = sess.run([e_loss, e_optim], feed_dict={
t_real_image : b_real_images,
t_wrong_image : b_wrong_images,
Expand Down
26 changes: 25 additions & 1 deletion utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,31 @@ def save_images(images, size, image_path):
return imsave(images, size, image_path)



def prepro_img(x, mode=None):
if mode=='train':
# rescale [0, 255] --> (-1, 1), random flip, crop, rotate
# paper 5.1: During mini-batch selection for training we randomly pick
# an image view (e.g. crop, flip) of the image and one of the captions
# flip, rotate, crop, resize : https://github.com/reedscot/icml2016/blob/master/data/donkey_folder_coco.lua
# flip : https://github.com/paarthneekhara/text-to-image/blob/master/Utils/image_processing.py
# x = flip_axis(x, axis=1, is_random=True)
# x = rotation(x, rg=16, is_random=True, fill_mode='nearest')
# x = crop(x, wrg=50, hrg=50, is_random=True)
# x = imresize(x, size=[64, 64], interp='bilinear', mode=None)
x = x / (255. / 2.)
x = x - 1.
elif mode=='rescale':
# rescale (-1, 1) --> (0, 1) for display
x = (x + 1.) / 2.
elif mode=='debug':
x = flip_axis(x, axis=1, is_random=False)
# x = rotation(x, rg=16, is_random=False, fill_mode='nearest')
# x = crop(x, wrg=50, hrg=50, is_random=True)
# x = imresize(x, size=[64, 64], interp='bilinear', mode=None)
x = x / 255.
else:
raise Exception("Not support : %s" % mode)
return x



Expand Down

0 comments on commit e5b95c3

Please sign in to comment.