Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,14 @@
help='dir dataset to download or/and load images')
parser.add_argument('--data_split', default='train', type=str,
help='Options: (default) train | val | test')
parser.add_argument('--arch', '-a', default='resnet152',
parser.add_argument('--arch', '-a', default='fbresnet152',
choices=convnets.model_names,
help='model architecture: ' +
' | '.join(convnets.model_names) +
' (default: fbresnet152)')
parser.add_argument('--workers', default=4, type=int,
parser.add_argument('--workers', default=4, type=int,
help='number of data loading workers (default: 4)')
parser.add_argument('--batch_size', '-b', default=80, type=int,
parser.add_argument('--batch_size', '-b', default=80, type=int,
help='mini-batch size (default: 80)')
parser.add_argument('--mode', default='both', type=str,
help='Options: att | noatt | (default) both')
Expand All @@ -56,7 +56,7 @@ def main():
if args.dataset == 'coco':
if 'coco' not in args.dir_data:
raise ValueError('"coco" string not in dir_data')
dataset = datasets.COCOImages(args.data_split, dict(dir=args.dir_data),
dataset = datasets.COCOImages(args.data_split, dict(dir=args.dir_data),
transform=transforms.Compose([
transforms.Scale(args.size),
transforms.CenterCrop(args.size),
Expand All @@ -68,7 +68,7 @@ def main():
raise ValueError('train split is required for vgenome')
if 'vgenome' not in args.dir_data:
raise ValueError('"vgenome" string not in dir_data')
dataset = datasets.VisualGenomeImages(args.data_split, dict(dir=args.dir_data),
dataset = datasets.VisualGenomeImages(args.data_split, dict(dir=args.dir_data),
transform=transforms.Compose([
transforms.Scale(args.size),
transforms.CenterCrop(args.size),
Expand Down Expand Up @@ -122,7 +122,7 @@ def extract(data_loader, model, path_file, mode):

nb_regions = output_att.size(2) * output_att.size(3)
output_noatt = output_att.sum(3).sum(2).div(nb_regions).view(-1, 2048)

batch_size = output_att.size(0)
if mode == 'both' or mode == 'att':
hdf5_att[idx:idx+batch_size] = output_att.data.cpu().numpy()
Expand All @@ -141,7 +141,7 @@ def extract(data_loader, model, path_file, mode):
i, len(data_loader),
batch_time=batch_time,
data_time=data_time,))

hdf5_file.close()

# Saving image names in the same order as extraction
Expand All @@ -154,4 +154,4 @@ def extract(data_loader, model, path_file, mode):


if __name__ == '__main__':
main()
main()
3 changes: 2 additions & 1 deletion options/vqa/mutan_att_trainval.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@ vqa:
samplingans: True
coco:
dir: data/coco
arch: fbresnet152torch
arch: fbresnet152
mode: att
size: 448
model:
arch: MutanAtt
dim_v: 2048
Expand Down
9 changes: 6 additions & 3 deletions vqa/datasets/coco.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,18 +23,21 @@ class COCOImages(AbstractImagesDataset):
def __init__(self, data_split, opt, transform=None, loader=default_loader):
self.split_name = split_name(data_split)
super(COCOImages, self).__init__(data_split, opt, transform, loader)
self.dir_split = os.path.join(self.dir_raw, self.split_name)
self.dir_split = self.get_dir_data()
self.dataset = ImagesFolder(self.dir_split, transform=self.transform, loader=self.loader)
self.name_to_index = self._load_name_to_index()

def get_dir_data(self):
return os.path.join(self.dir_raw, self.split_name)

def _raw(self):
if self.data_split in ['train', 'val']:
os.system('wget http://msvocds.blob.core.windows.net/coco2014/{}.zip -P {}'.format(self.split_name, self.dir_raw))
elif self.data_split == 'test':
os.execute('wget http://msvocds.blob.core.windows.net/coco2015/test2015.zip -P '+self.dir_raw)
os.system('wget http://msvocds.blob.core.windows.net/coco2015/test2015.zip -P '+self.dir_raw)
else:
assert False, 'data_split {} not exists'.format(self.data_split)
os.execute('unzip '+os.path.join(self.dir_raw, self.split_name+'.zip')+' -d '+self.dir_raw)
os.system('unzip '+os.path.join(self.dir_raw, self.split_name+'.zip')+' -d '+self.dir_raw)

def _load_name_to_index(self):
self.name_to_index = {name:index for index, name in enumerate(self.dataset.imgs)}
Expand Down
11 changes: 7 additions & 4 deletions vqa/datasets/images.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def __init__(self, root, transform=None, loader=default_loader):

def __getitem__(self, index):
item = {}
item['name'] = self.imgs[index]
item['name'] = self.imgs[index]
item['path'] = os.path.join(self.root, item['name'])
if self.loader is not None:
item['visual'] = self.loader(item['path'])
Expand All @@ -57,11 +57,14 @@ def __init__(self, data_split, opt, transform=None, loader=default_loader):
self.opt = opt
self.transform = transform
self.loader = loader

self.dir_raw = os.path.join(self.opt['dir'], 'raw')
if not os.path.exists(self.dir_raw):

if not os.path.exists(self.get_dir_data()):
self._raw()

def get_dir_data(self):
return self.dir_raw

def get_by_name(self, image_name):
index = self.name_to_index[image_name]
return self[index]
Expand All @@ -73,4 +76,4 @@ def __getitem__(self, index):
raise NotImplementedError

def __len__(self):
raise NotImplementedError
raise NotImplementedError
32 changes: 21 additions & 11 deletions vqa/models/convnets.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import torch
import torch.nn as nn
import torchvision.models as pytorch_models

import sys
sys.path.append('vqa/external/pretrained-models.pytorch')
import pretrainedmodels as torch7_models
Expand All @@ -21,7 +20,21 @@
def factory(opt, cuda=True, data_parallel=True):
opt = copy.copy(opt)

# forward_* will be handled better in a future release
class WrapperModule(nn.Module):
def __init__(self, net, forward_fn):
super(WrapperModule, self).__init__()
self.net = net
self.forward_fn = forward_fn

def forward(self, x):
return self.forward_fn(self.net, x)

def __getattr__(self, attr):
try:
return super(WrapperModule, self).__getattr__(attr)
except AttributeError:
return getattr(self.net, attr)

def forward_resnet(self, x):
x = self.conv1(x)
x = self.bn1(x)
Expand Down Expand Up @@ -58,26 +71,23 @@ def forward_resnext(self, x):
if opt['arch'] in pytorch_resnet_names:
model = pytorch_models.__dict__[opt['arch']](pretrained=True)

convnet = model # ugly hack in case of DataParallel wrapping
model.forward = lambda x: forward_resnet(convnet, x)
model = WrapperModule(model, forward_resnet) # ugly hack in case of DataParallel wrapping

elif opt['arch'] == 'fbresnet152':
model = torch7_models.__dict__[opt['arch']](num_classes=1000,
pretrained='imagenet')

convnet = model # ugly hack in case of DataParallel wrapping
model.forward = lambda x: forward_resnet(convnet, x)
model = WrapperModule(model, forward_resnet) # ugly hack in case of DataParallel wrapping

elif opt['arch'] in torch7_resnet_names:
model = torch7_models.__dict__[opt['arch']](num_classes=1000,
pretrained='imagenet')

convnet = model # ugly hack in case of DataParallel wrapping
model.forward = lambda x: forward_resnext(convnet, x)

model = WrapperModule(model, forward_resnext) # ugly hack in case of DataParallel wrapping

else:
raise ValueError

if data_parallel:
model = nn.DataParallel(model).cuda()
if not cuda:
Expand All @@ -86,4 +96,4 @@ def forward_resnext(self, x):
if cuda:
model.cuda()

return model
return model