
Commit 3d92a09

Initial commit
0 parents  commit 3d92a09

File tree

11 files changed: +1177 −0 lines


README.md

Lines changed: 87 additions & 0 deletions
@@ -0,0 +1,87 @@
# Text to Image Synthesis using Skip-thought Vectors

## Description

This is a PyTorch implementation of the paper Generative Adversarial Text-to-Image Synthesis (http://arxiv.org/abs/1605.05396), using skip-thought vectors for the caption embedding. The implementation is based on DCGAN. Below is the model architecture, where the blue bars represent the skip-thought vectors for the captions.

[Figure: model architecture]
Image Source: Paper
## Setup and Installation

* Python==3.6.6
* PyTorch==0.4.0
* TorchVision==0.2.1
* Theano
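One possible way to set up the environment (a sketch assuming pip; availability of these exact versions may vary by platform):

```
pip install torch==0.4.0 torchvision==0.2.1 Theano
```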
## Dataset

* This model has been trained on the flowers dataset. Download the flower dataset from here[] and save the images in the Data folder as Data/flowers.

* Now download the corresponding captions from here[]. After extracting, copy the text_c10 folder and paste it into the Data folder as Data/text_c10.
## Skip-Thought Model

* Download the pretrained models and vocabulary for skip-thought vectors as per the instructions given below. Save the downloaded files in Data/skipthoughts.

* Some of the files are quite large (>2 GB), so make sure there is enough disk space available.

* Run the command below to download the skip-thought model and all other required files:

```
python download_skipthought.py
```
## Usage

* Data pre-processing:

```
python data_loader.py
```

* Training arguments:
  * dataset : Dataset used. Default = flowers
  * batch_size : Batch size. Default = 1
  * num_epochs : Number of epochs to train. Default = 200
  * img_size : Size of the image. Default = 64
  * z_dim : Latent variable dimension. Default = 100
  * text_embedding_dim : Embedding dimension of the caption. Default = 4800
  * reduced_text_dim : Reduced embedding dimension of the caption. Default = 1024
  * learning_rate : Learning rate. Default = 0.0002
  * beta1 : Hyperparameter of the Adam optimizer. Default = 0.5
  * beta2 : Hyperparameter of the Adam optimizer. Default = 0.999
  * l1_coeff : Coefficient for the L1 loss. Default = 50
  * resume_epoch : Epoch from which to resume training. Default = 1

* Train the model by running:

```
python main.py
```
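All arguments are optional. For example, to train with a larger batch for fewer epochs (the flag values here are illustrative; the flags themselves are defined in main.py):

```
python main.py --batch_size 64 --num_epochs 100
```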
* Test the model by giving a custom input caption:

```
python predict.py --text="Input caption to be used to generate the image"
```

The generated image will be saved to the Testing directory inside the Data folder, as Data/Testing.
## Model key-points

* Skip-Thought is an efficient model used for sentence embedding, based on the concept of word embeddings (word2vec or GloVe). It returns a NumPy array of dimension 4800, in which the first 2400 dimensions come from the uni-skip model and the last 2400 dimensions from the bi-skip model. We use the combine-skip vectors because, experimentally, they perform the best; a minimal encoding sketch follows.
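A minimal sketch of producing combine-skip embeddings, assuming the skipthoughts module from https://github.com/ryankiros/skip-thoughts (the same calls data_loader.py relies on) is importable:

```python
import skipthoughts

# Load the pretrained uni-skip and bi-skip models (downloaded by download_skipthought.py)
model = skipthoughts.load_model()

# Encode a caption into a combine-skip vector
vectors = skipthoughts.encode(model, ['a flower with long yellow petals'])
print(vectors.shape)  # (1, 4800): first 2400 dims are uni-skip, last 2400 bi-skip
```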
* The Text2Image model is a Generative Adversarial Network based model built on top of DCGAN. It consists of a Discriminator network and a Generator network.

* The Discriminator network classifies as fake not only the images produced by the Generator, but also real images that do not correspond to the correct caption. In short, fake examples fall into the following categories (see the loss sketch after this list):
  * Fake image + correct caption
  * Real image + incorrect caption
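A hedged sketch of the matching-aware discriminator objective implied above. The discriminator is assumed to return (probability, logit) as in nets/discriminator.py; the 0.5 weighting of the two fake terms follows the GAN-CLS paper and is an assumption, not taken from this repo:

```python
import torch
import torch.nn.functional as F

def discriminator_loss(disc, real_imgs, fake_imgs, wrong_imgs, caption_embed):
    real_p, _ = disc(real_imgs, caption_embed)    # real image + correct caption
    fake_p, _ = disc(fake_imgs, caption_embed)    # generated image + correct caption
    wrong_p, _ = disc(wrong_imgs, caption_embed)  # real image + incorrect caption

    ones = torch.ones_like(real_p)
    zeros = torch.zeros_like(real_p)

    # Real pairs are pushed toward 1; both kinds of fake pairs toward 0.
    return (F.binary_cross_entropy(real_p, ones)
            + 0.5 * (F.binary_cross_entropy(fake_p, zeros)
                     + F.binary_cross_entropy(wrong_p, zeros)))
```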
* Images are 64 x 64 in dimension.
## Generated Images

Following are some of the images generated by this model:

[A table of 5-6 images along with their captions]

## TODO

Implement the same model using an autoencoder for sentence embedding.

## References

* Generative Adversarial Text-to-Image Synthesis - http://arxiv.org/abs/1605.05396
* TensorFlow implementation - https://github.com/paarthneekhara/text-to-image
* Skip-Thought Model - https://github.com/ryankiros/skip-thoughts

## License

MIT

data_loader.py

Lines changed: 81 additions & 0 deletions
@@ -0,0 +1,81 @@
import os
import torch
import skipthoughts
import numpy as np
from PIL import Image
from tqdm import tqdm
from torch.utils.data import Dataset

# Each batch will have 3 things: the true image, its captions (5), and a false image
# (a real image, but one corresponding to an incorrect caption).
# The discriminator is trained so that true_img + caption counts as a real example and
# false_img + caption counts as a fake example.


class Text2ImageDataset(Dataset):

    def __init__(self, data_dir):
        self.data_dir = data_dir
        self.load_flower_dataset()

    def load_flower_dataset(self):
        # Builds two things: a list of image file names, and a dictionary of 5 captions
        # per image, with the image file name as the key and 5 values (captions) per key.

        print("------------------ Loading images ------------------")
        self.img_files = []
        for f in os.listdir(os.path.join(self.data_dir, 'flowers')):
            self.img_files.append(f)

        print('Total number of images : {}'.format(len(self.img_files)))

        print("------------------ Loading captions ----------------")
        self.img_captions = {}
        text_dir = os.path.join(self.data_dir, 'text_c10')
        for class_dir in tqdm(os.listdir(text_dir)):
            if 't7' not in class_dir:
                class_path = os.path.join(text_dir, class_dir)
                for cap_file in os.listdir(class_path):
                    if 'txt' in cap_file:
                        with open(os.path.join(class_path, cap_file)) as f:
                            captions = f.read().split('\n')
                        img_file = cap_file[:11] + '.jpg'
                        # 5 captions per image
                        self.img_captions[img_file] = captions[:5]

        print("--------------- Loading Skip-thought Model ---------------")
        model = skipthoughts.load_model()
        self.encoded_captions = {}

        print("------------ Encoding of image captions STARTED ------------")
        for img_file in self.img_captions:
            # skipthoughts.encode returns a numpy array of shape (5, 4800)
            self.encoded_captions[img_file] = skipthoughts.encode(model, self.img_captions[img_file])

        print("------------- Encoding of image captions DONE -------------")

    def read_image(self, image_file_name):
        image = Image.open(os.path.join(self.data_dir, 'flowers', image_file_name))
        # Resize so every image comes out with shape (64, 64, 3)
        image = np.array(image.resize((64, 64)))
        return image

    def get_false_img(self, index):
        # Pick a random image different from the one at `index`
        false_img_id = np.random.randint(len(self.img_files))
        if false_img_id != index:
            return self.img_files[false_img_id]

        return self.get_false_img(index)

    def __len__(self):
        return len(self.img_files)

    def __getitem__(self, index):
        sample = {}
        sample['true_imgs'] = torch.FloatTensor(self.read_image(self.img_files[index]))
        sample['false_imgs'] = torch.FloatTensor(self.read_image(self.get_false_img(index)))
        sample['true_embed'] = torch.FloatTensor(self.encoded_captions[self.img_files[index]])

        return sample
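A minimal usage sketch for this dataset class (assuming the Data directory layout described in the README and that the skip-thought files are already downloaded):

```python
from torch.utils.data import DataLoader
from data_loader import Text2ImageDataset

dataset = Text2ImageDataset('Data')
loader = DataLoader(dataset, batch_size=1, shuffle=True)

# Each sample bundles a matching image, a mismatched image, and the caption embeddings
batch = next(iter(loader))
print(batch['true_imgs'].shape)   # torch.Size([1, 64, 64, 3])
print(batch['true_embed'].shape)  # torch.Size([1, 5, 4800])
```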

download_skipthought.py

Lines changed: 12 additions & 0 deletions
@@ -0,0 +1,12 @@
import os

print('Downloading Skip-Thought Model ...........')
# Pretrained skip-thought model weights and vocabulary tables
os.system('wget http://www.cs.toronto.edu/~rkiros/models/dictionary.txt')
os.system('wget http://www.cs.toronto.edu/~rkiros/models/utable.npy')
os.system('wget http://www.cs.toronto.edu/~rkiros/models/btable.npy')
os.system('wget http://www.cs.toronto.edu/~rkiros/models/uni_skip.npz')
os.system('wget http://www.cs.toronto.edu/~rkiros/models/uni_skip.npz.pkl')
os.system('wget http://www.cs.toronto.edu/~rkiros/models/bi_skip.npz')
os.system('wget http://www.cs.toronto.edu/~rkiros/models/bi_skip.npz.pkl')

print('Download Completed ............')

imgs/net.jpeg

71.6 KB

main.py

Lines changed: 124 additions & 0 deletions
@@ -0,0 +1,124 @@
import os
import argparse

from train import GAN_CLS
from torch.utils.data import DataLoader
from data_loader import Text2ImageDataset


def check_dir(dir_name):
    if not os.path.exists(dir_name):
        os.makedirs(dir_name)
        print('{} created'.format(dir_name))


def check_args(args):
    # Make all directories if they don't exist

    # --checkpoint_dir
    check_dir(args.checkpoint_dir)

    # --sample_dir
    check_dir(args.sample_dir)

    # --log_dir
    check_dir(args.log_dir)

    # --final_model dir
    check_dir(args.final_model)

    # --num_epochs
    assert args.num_epochs > 0, 'Number of epochs must be greater than 0'

    # --batch_size
    assert args.batch_size > 0, 'Batch size must be greater than zero'

    # --z_dim
    assert args.z_dim > 0, 'Size of the noise vector must be greater than zero'

    return args


def main():

    parser = argparse.ArgumentParser()

    parser.add_argument_group('Dataset related arguments')
    parser.add_argument('--data_dir', type=str, default="Data",
                        help='Data directory')

    parser.add_argument('--dataset', type=str, default="flowers",
                        help='Dataset to train on')

    parser.add_argument_group('Model saving path and steps related arguments')
    parser.add_argument('--log_step', type=int, default=100,
                        help='Save INFO into logger after every x iterations')

    parser.add_argument('--sample_step', type=int, default=100,
                        help='Save generated image after every x iterations')

    parser.add_argument('--checkpoint_dir', type=str, default='checkpoints',
                        help='Directory in which to save model checkpoints')

    parser.add_argument('--sample_dir', type=str, default='sample',
                        help='Directory in which to save generated samples')

    parser.add_argument('--log_dir', type=str, default='logs',
                        help='Directory in which to save logs')

    parser.add_argument('--final_model', type=str, default='final_model',
                        help='Directory in which to save the final trained model')

    parser.add_argument_group('Model training related arguments')
    parser.add_argument('--num_epochs', type=int, default=200,
                        help='Total number of epochs to train')

    parser.add_argument('--batch_size', type=int, default=1,
                        help='Batch size')

    parser.add_argument('--img_size', type=int, default=64,
                        help='Size of the image')

    parser.add_argument('--z_dim', type=int, default=100,
                        help='Size of the latent variable (noise vector)')

    parser.add_argument('--text_embed_dim', type=int, default=4800,
                        help='Size of the embedding for the captions')

    parser.add_argument('--text_reduced_dim', type=int, default=1024,
                        help='Reduced dimension of the caption encoding')

    parser.add_argument('--learning_rate', type=float, default=0.0002,
                        help='Learning rate')

    parser.add_argument('--beta1', type=float, default=0.5,
                        help='Hyperparameter of the Adam optimizer')

    parser.add_argument('--beta2', type=float, default=0.999,
                        help='Hyperparameter of the Adam optimizer')

    parser.add_argument('--l1_coeff', type=float, default=50,
                        help='Coefficient for the L1 loss')

    parser.add_argument('--resume_epoch', type=int, default=1,
                        help='Epoch from which to resume training')

    args = parser.parse_args()

    check_args(args)

    dataset = Text2ImageDataset(args.data_dir)
    data_loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)

    gan = GAN_CLS(args, data_loader)

    gan.build_model()
    gan.train_model()


if __name__ == '__main__':
    main()

nets/__init__.py

Whitespace-only changes.

nets/discriminator.py

Lines changed: 80 additions & 0 deletions
@@ -0,0 +1,80 @@
import torch
import torch.nn as nn


class Discriminator(nn.Module):
    def __init__(self, batch_size, img_size, text_embed_dim, text_reduced_dim):
        super(Discriminator, self).__init__()

        self.batch_size = batch_size
        self.img_size = img_size
        self.in_channels = 3  # RGB images
        self.text_embed_dim = text_embed_dim
        self.text_reduced_dim = text_reduced_dim

        # Defining the discriminator network architecture
        self.d_net = nn.Sequential(
            nn.Conv2d(self.in_channels, 64, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 512, 4, 2, 1, bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True))

        # d_net output size = (batch_size, 512, 4, 4)
        # text.size() = (batch_size, text_embed_dim)

        # Linear layer to reduce the dimensionality of the caption embedding
        # from text_embed_dim to text_reduced_dim
        self.reduce_text = nn.Linear(self.text_embed_dim, self.text_reduced_dim)

        self.cat_net = nn.Sequential(
            nn.Conv2d(512 + self.text_reduced_dim, 512, 4, 2, 1, bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True))

        self.linear = nn.Linear(2 * 2 * 512, 1)

    def forward(self, image, text):
        """ Given an image and its caption embedding, predict whether the image
        is real or fake.

        Arguments
        ---------
        image : torch.FloatTensor
            image.size() = (batch_size, 3, 64, 64)

        text : torch.FloatTensor
            Output of the skip-thought embedding model for the caption.
            text.size() = (batch_size, text_embed_dim)

        Returns
        -------
        output : Probability of the image being real
        logit : Final score of the discriminator
        """

        d_net_out = self.d_net(image)               # (batch_size, 512, 4, 4)
        text_reduced = self.reduce_text(text)       # (batch_size, text_reduced_dim)
        text_reduced = text_reduced.unsqueeze(2)    # (batch_size, text_reduced_dim, 1)
        text_reduced = text_reduced.unsqueeze(3)    # (batch_size, text_reduced_dim, 1, 1)
        # Tile the caption encoding over the 4 x 4 spatial grid of d_net's output
        text_reduced = text_reduced.expand(-1, self.text_reduced_dim, 4, 4)

        # Concatenate along the channel dimension
        concat_out = torch.cat((d_net_out, text_reduced), 1)  # (batch_size, 512 + text_reduced_dim, 4, 4)

        cat_out = self.cat_net(concat_out)                    # (batch_size, 512, 2, 2)
        logit = self.linear(cat_out.view(-1, 2 * 2 * 512))    # (batch_size, 1)
        output = torch.sigmoid(logit)

        return output, logit
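A quick shape check for this module (a hedged sketch; the constructor values mirror the defaults in main.py and the inputs are random stand-ins):

```python
import torch
from nets.discriminator import Discriminator

disc = Discriminator(batch_size=1, img_size=64, text_embed_dim=4800, text_reduced_dim=1024)
imgs = torch.randn(1, 3, 64, 64)    # a random RGB image batch
embeds = torch.randn(1, 4800)       # a random caption embedding
prob, logit = disc(imgs, embeds)
print(prob.shape, logit.shape)      # torch.Size([1, 1]) torch.Size([1, 1])
```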
