basic_vqa variant (not working)
adhiraj2001 committed Apr 17, 2023
1 parent 6d1f0e0 commit 9a037e2
Showing 6 changed files with 238 additions and 70 deletions.
3 changes: 2 additions & 1 deletion .gitignore
@@ -8,5 +8,6 @@ datasets/
models/
logs/

__pycache__
**/__pycache__
**/.vscode
tmp
48 changes: 34 additions & 14 deletions data_loader.py
@@ -55,27 +55,47 @@ def __len__(self):
return len(self.vqa)


def get_loader(input_dir, input_vqa_train, input_vqa_valid, max_qst_length, max_num_ans, batch_size, num_workers):
def get_loader(input_dir, input_vqa_train, input_vqa_valid, max_qst_length, max_num_ans, batch_size, num_workers, subset=None):

transform = {
phase: transforms.Compose([transforms.ToTensor(),
transforms.Normalize((0.485, 0.456, 0.406),
(0.229, 0.224, 0.225))])
for phase in ['train', 'valid']}

vqa_dataset = {
'train': VqaDataset(
input_dir=input_dir,
input_vqa=input_vqa_train,
max_qst_length=max_qst_length,
max_num_ans=max_num_ans,
transform=transform['train']),
'valid': VqaDataset(
input_dir=input_dir,
input_vqa=input_vqa_valid,
max_qst_length=max_qst_length,
max_num_ans=max_num_ans,
transform=transform['valid'])}
if subset is None:
vqa_dataset = {
'train': VqaDataset(
input_dir=input_dir,
input_vqa=input_vqa_train,
max_qst_length=max_qst_length,
max_num_ans=max_num_ans,
transform=transform['train']),
'valid': VqaDataset(
input_dir=input_dir,
input_vqa=input_vqa_valid,
max_qst_length=max_qst_length,
max_num_ans=max_num_ans,
transform=transform['valid'])}
else:
vqa_dataset = {
'train': torch.utils.data.Subset(
VqaDataset(
input_dir=input_dir,
input_vqa=input_vqa_train,
max_qst_length=max_qst_length,
max_num_ans=max_num_ans,
transform=transform['train']),
range(subset)),
'valid': torch.utils.data.Subset(
VqaDataset(
input_dir=input_dir,
input_vqa=input_vqa_valid,
max_qst_length=max_qst_length,
max_num_ans=max_num_ans,
transform=transform['valid']),
range(subset))
}

data_loader = {
phase: torch.utils.data.DataLoader(
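As a usage sketch (the paths and preprocessed file names below are placeholders, not taken from this commit): when subset is given, each split is wrapped in torch.utils.data.Subset holding only the first subset samples, and the underlying VqaDataset, with its vocabularies, is reached through the wrapper's .dataset attribute.

    from data_loader import get_loader

    data_loader = get_loader(
        input_dir='./datasets',        # placeholder path
        input_vqa_train='train.npy',   # placeholder preprocessed file
        input_vqa_valid='valid.npy',   # placeholder preprocessed file
        max_qst_length=30,
        max_num_ans=10,
        batch_size=64,
        num_workers=4,
        subset=1000)                   # keep only the first 1000 samples per split

    train_subset = data_loader['train'].dataset    # torch.utils.data.Subset
    qst_vocab = train_subset.dataset.qst_vocab     # the VqaDataset sits one level down
    print(len(train_subset), qst_vocab.vocab_size)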
117 changes: 117 additions & 0 deletions models_2.py
@@ -0,0 +1,117 @@
import torch
import torch.nn as nn
import torchvision.models as models


class ImgEncoder(nn.Module):

def __init__(self, embed_size):
"""(1) Load the pretrained model as you want.
cf) one needs to check structure of model using 'print(model)'
to remove last fc layer from the model.
(2) Replace final fc layer (score values from the ImageNet)
with new fc layer (image feature).
(3) Normalize feature vector.
"""
super(ImgEncoder, self).__init__()

# model = models.vgg19(pretrained=True)
model = models.vgg19(weights=models.VGG19_Weights.DEFAULT)

modules = list(model.features.children())[:-1]
# modules = list(model.features.children())[:-2]
self.features = nn.Sequential(*modules) # remove last maxpool layer

# self.classifier = model.classifier

in_features = self.features[-2].out_channels # out_channels (512) of the last kept conv layer
# in_features = model.features[-5].out_channels # input size of feature vector

self.fc = nn.Linear(in_features, embed_size)
# self.fc = nn.Sequential(
# nn.Linear(in_features, embed_size),
# nn.Tanh())

def forward(self, image):
"""Extract feature vector from image vector.
"""

with torch.no_grad():
img_feature = self.features(image) # [batch_size, 512, 14, 14] for a 224x224 input

## Flatten the spatial grid into 196 regions (keep the channel dimension)
img_feature = img_feature.view(-1, 512, 196).transpose(1, 2) # [batch_size, 196, 512]

# with torch.no_grad():
# img_feature = self.classifier(img_feature)

img_feature = self.fc(img_feature) # [batch_size, 196, embed_size]

## Normalizing output
# l2_norm = img_feature.norm(p=2, dim=1, keepdim=True).detach()
# l2_norm = img_feature.norm(p=2, dim=2, keepdim=True).detach()

# img_feature = img_feature.div(l2_norm) # l2-normalized feature vector

return img_feature


class QstEncoder(nn.Module):

def __init__(self, qst_vocab_size, word_embed_size, embed_size, num_layers, hidden_size):

super(QstEncoder, self).__init__()
self.word2vec = nn.Embedding(qst_vocab_size, word_embed_size)
self.tanh = nn.Tanh()
self.lstm = nn.LSTM(word_embed_size, hidden_size, num_layers)
self.fc = nn.Linear(2*num_layers*hidden_size, embed_size) # 2 for hidden and cell states

def forward(self, question):

qst_vec = self.word2vec(question) # [batch_size, max_qst_length=30, word_embed_size=300]
qst_vec = self.tanh(qst_vec)
qst_vec = qst_vec.transpose(0, 1) # [max_qst_length=30, batch_size, word_embed_size=300]
_, (hidden, cell) = self.lstm(qst_vec)                        # [num_layers=2, batch_size, hidden_size=512]
qst_feature = torch.cat((hidden, cell), 2)                    # [num_layers=2, batch_size, 2*hidden_size=1024]
qst_feature = qst_feature.transpose(0, 1) # [batch_size, num_layers=2, 2*hidden_size=1024]
qst_feature = qst_feature.reshape(qst_feature.size()[0], -1) # [batch_size, 2*num_layers*hidden_size=2048]

qst_feature = self.tanh(qst_feature)
qst_feature = self.fc(qst_feature) # [batch_size, embed_size]

return qst_feature


class VqaModel(nn.Module):

def __init__(self, embed_size, qst_vocab_size, ans_vocab_size, word_embed_size, num_layers, hidden_size):

super(VqaModel, self).__init__()
self.img_encoder = ImgEncoder(embed_size)
self.qst_encoder = QstEncoder(qst_vocab_size, word_embed_size, embed_size, num_layers, hidden_size)

self.tanh = nn.Tanh()
self.dropout = nn.Dropout(0.5)

self.fc1 = nn.Linear(embed_size, ans_vocab_size)
self.fc2 = nn.Linear(ans_vocab_size, ans_vocab_size)

def forward(self, img, qst):

img_feature = self.img_encoder(img)                     # [batch_size, 196, embed_size]

# qst_feature = self.qst_encoder(qst) # [batch_size, 196, embed_size]
qst_feature = self.qst_encoder(qst).unsqueeze(dim=1)    # [batch_size, 1, embed_size]

# combined_feature = torch.mul(img_feature, qst_feature) # [batch_size, embed_size]
combined_feature = (img_feature + qst_feature).sum(dim=1) # [batch_size, embed_size]

combined_feature = self.tanh(combined_feature)
combined_feature = self.dropout(combined_feature)

combined_feature = self.fc1(combined_feature) # [batch_size, ans_vocab_size=1000]
combined_feature = self.tanh(combined_feature)

combined_feature = self.dropout(combined_feature)
combined_feature = self.fc2(combined_feature) # [batch_size, ans_vocab_size=1000]

return combined_feature
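A quick smoke test of the model above, as a sketch: the hyperparameters are illustrative and chosen to match the shapes annotated in the comments, and it assumes in_features resolves to the 512 output channels of the last kept conv layer. A 224x224 image gives a 14x14 = 196-region VGG19 feature map, the question encoding is broadcast over those regions, and the output has one logit per answer.

    import torch
    from models_2 import VqaModel

    # note: instantiating ImgEncoder downloads the pretrained VGG19 weights
    model = VqaModel(embed_size=1024, qst_vocab_size=10000, ans_vocab_size=1000,
                     word_embed_size=300, num_layers=2, hidden_size=512)
    model.eval()                                         # disable dropout for the check

    img = torch.randn(2, 3, 224, 224)                    # dummy image batch
    qst = torch.randint(0, 10000, (2, 30))               # dummy questions, max_qst_length=30
    with torch.no_grad():
        out = model(img, qst)
    print(out.shape)                                     # expected: torch.Size([2, 1000])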
1 change: 1 addition & 0 deletions rsync_ignore.txt
@@ -1,6 +1,7 @@
**/.git

**/datasets
**/datasets.tar
*.pth

**/__pycache__
51 changes: 35 additions & 16 deletions train.py
@@ -6,7 +6,9 @@
import torch.optim as optim
from torch.optim import lr_scheduler
from data_loader import get_loader
from models import VqaModel

# from models import VqaModel
from models_2 import VqaModel

## To avoid CUDA out-of-memory errors (if this doesn't help, try reducing the batch size)
torch.cuda.empty_cache()
@@ -26,29 +28,43 @@ def main(args):
max_qst_length=args.max_qst_length,
max_num_ans=args.max_num_ans,
batch_size=args.batch_size,
num_workers=args.num_workers)

qst_vocab_size = data_loader['train'].dataset.qst_vocab.vocab_size
ans_vocab_size = data_loader['train'].dataset.ans_vocab.vocab_size
ans_unk_idx = data_loader['train'].dataset.ans_vocab.unk2idx
num_workers=args.num_workers,
subset=args.subset)

if isinstance(data_loader['train'].dataset, torch.utils.data.Subset):
qst_vocab_size = data_loader['train'].dataset.dataset.qst_vocab.vocab_size
ans_vocab_size = data_loader['train'].dataset.dataset.ans_vocab.vocab_size
ans_unk_idx = data_loader['train'].dataset.dataset.ans_vocab.unk2idx
else:
qst_vocab_size = data_loader['train'].dataset.qst_vocab.vocab_size
ans_vocab_size = data_loader['train'].dataset.ans_vocab.vocab_size
ans_unk_idx = data_loader['train'].dataset.ans_vocab.unk2idx

model = VqaModel(
embed_size=args.embed_size,
qst_vocab_size=qst_vocab_size,
ans_vocab_size=ans_vocab_size,
word_embed_size=args.word_embed_size,
num_layers=args.num_layers,
hidden_size=args.hidden_size).to(device)

# hidden_size=args.hidden_size).to(device)
hidden_size=args.hidden_size)

params = list(model.img_encoder.fc.parameters()) \
+ list(model.qst_encoder.parameters()) \
+ list(model.fc1.parameters()) \
+ list(model.fc2.parameters())

# params = list(model.module.img_encoder.fc.parameters()) \
# + list(model.module.qst_encoder.parameters()) \
# + list(model.module.fc1.parameters()) \
# + list(model.module.fc2.parameters())

## Data Parallel for larger batch size
model = nn.DataParallel(model)
if torch.cuda.device_count() > 1:
print("Using", torch.cuda.device_count(), "GPUs.")
# dim = 0 [40, xxx] -> [10, ...], [10, ...], [10, ...], [10, ...] on 4 GPUs
model = nn.DataParallel(model)

model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params, lr=args.learning_rate)
@@ -88,8 +104,7 @@ def main(args):
if phase == 'train':
loss.backward()
optimizer.step()
scheduler.step()


# Evaluation metric of 'multiple choice'
# Exp1: our model prediction to '<unk>' IS accepted as the answer.
# Exp2: our model prediction to '<unk>' is NOT accepted as the answer.
@@ -103,6 +118,10 @@ def main(args):
print('| {} SET | Epoch [{:02d}/{:02d}], Step [{:04d}/{:04d}], Loss: {:.4f}'
.format(phase.upper(), epoch+1, args.num_epochs, batch_idx, int(batch_step_size), loss.item()))

# Update the learning rate.
if phase == 'train':
scheduler.step()

# Print the average loss and accuracy in an epoch.
epoch_loss = running_loss / batch_step_size
epoch_acc_exp1 = running_corr_exp1.double() / len(data_loader[phase].dataset) # multiple choice
@@ -168,9 +187,6 @@ def main(args):
parser.add_argument('--gamma', type=float, default=0.1,
help='multiplicative factor of learning rate decay.')

# parser.add_argument('--num_epochs', type=int, default=30,
# help='number of epochs.')

parser.add_argument('--num_epochs', type=int, default=30,
help='number of epochs.')

@@ -180,7 +196,10 @@ def main(args):
parser.add_argument('--num_workers', type=int, default=8,
help='number of processes working on cpu.')

parser.add_argument('--save_step', type=int, default=10,
parser.add_argument('--subset', type=int, default=None,
help='subset size of dataset.')

parser.add_argument('--save_step', type=int, default=1,
help='save step of model.')

args = parser.parse_args()
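The accuracy bookkeeping behind the Exp1/Exp2 comments in train.py is collapsed in this view; below is a hedged sketch of how such a multiple-choice metric can be computed (variable names and values are illustrative, not necessarily those of train.py): a prediction counts as correct if it matches any of the annotated answers, and Exp2 additionally never credits a '<unk>' prediction.

    import torch

    ans_vocab_size, ans_unk_idx, batch_size = 1000, 0, 4              # illustrative values
    prediction = torch.randn(batch_size, ans_vocab_size)              # model output logits
    multi_choice = [torch.randint(0, ans_vocab_size, (batch_size,))   # 10 annotated answers
                    for _ in range(10)]

    _, pred_exp1 = torch.max(prediction, dim=1)            # Exp1: '<unk>' predictions accepted
    pred_exp2 = pred_exp1.clone()
    pred_exp2[pred_exp2 == ans_unk_idx] = -9999            # Exp2: '<unk>' can never match an answer

    hits_exp1 = torch.stack([ans == pred_exp1 for ans in multi_choice]).any(dim=0)
    hits_exp2 = torch.stack([ans == pred_exp2 for ans in multi_choice]).any(dim=0)
    print(hits_exp1.sum().item(), hits_exp2.sum().item())  # per-batch correct counts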
