Commit 9c890c3

Public release

13 files changed: +907 -0 lines changed

.gitignore

Lines changed: 3 additions & 0 deletions

__pycache__
*.pyc
vqa

.gitmodules

Lines changed: 3 additions & 0 deletions

[submodule "resnet"]
	path = resnet
	url = https://github.com/Cyanogenoid/pytorch-resnet

README.md

Lines changed: 56 additions & 0 deletions

# Strong baseline for visual question answering

This is a re-implementation of Vahid Kazemi and Ali Elqursh's paper [Show, Ask, Attend, and Answer: A Strong Baseline For Visual Question Answering][0] in [PyTorch][1].

The paper shows that with a relatively simple model, using only common building blocks in Deep Learning, you can get better accuracies than the majority of previously published work on the popular [VQA v1][2] dataset.

This repository is intended to provide a straightforward implementation of the paper for other researchers to build on.
The results closely match the reported ones, since the majority of details should be exactly the same as in the paper. (Thanks to the authors for answering my questions about some details!)
This implementation seems to consistently converge to about 0.1% better results, though I am not aware of which implementation difference causes this.

A fully trained model (convergence shown below) is [available for download][5].

![Graph of convergence of implementation versus paper results](http://i.imgur.com/moWYEm8.png)


## Running the model

- Clone this repository with:
```
git clone https://github.com/Cyanogenoid/pytorch-vqa --recursive
```
- Set the paths to your downloaded [questions, answers, and MS COCO images][4] in `config.py`.
  - `qa_path` should contain the files `OpenEnded_mscoco_train2014_questions.json`, `OpenEnded_mscoco_val2014_questions.json`, `mscoco_train2014_annotations.json`, and `mscoco_val2014_annotations.json`.
  - `train_path`, `val_path`, and `test_path` should contain the train, validation, and test `.jpg` images respectively.
- Pre-process the images (93 GiB of free disk space required for the float16 features) using the [ResNet152 weights ported from Caffe][3], and build the question and answer vocabularies, with:
```
python preprocess-images.py
python preprocess-vocab.py
```
- Train the model in `model.py` with:
```
python train.py
```
  This will alternate between one epoch of training on the train split and one epoch of validation on the validation split, printing the current training progress to stdout and saving logs in the `logs` directory.
  The logs contain the name of the model, training statistics, the contents of `config.py`, model weights, evaluation information (per-question answer and accuracy), and the question and answer vocabularies; a sketch of loading such a log directly follows this list.
- During training (which takes a while), plot the training progress with:
```
python view-log.py <path to .pth log>
```
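
Independently of `view-log.py`, a saved log can also be opened directly with `torch.load`; a minimal sketch (the exact key names are an assumption here, so check `train.py` for the ones it actually writes):

```python
import sys

import torch

# a log written by train.py is a plain dict saved with torch.save
log = torch.load(sys.argv[1], map_location='cpu')

# print the keys that are actually present before relying on any of them
print(list(log.keys()))
print(log.get('config'))  # snapshot of config.py at training time, if stored under this name
```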

## Python 3 dependencies (tested on Python 3.6.2)

- torch
- torchvision
- h5py
- tqdm


[0]: https://arxiv.org/abs/1704.03162
[1]: https://github.com/pytorch/pytorch
[2]: http://visualqa.org/
[3]: https://github.com/ruotianluo/pytorch-resnet
[4]: http://visualqa.org/vqa_v1_download.html
[5]: https://github.com/Cyanogenoid/pytorch-vqa/releases

config.py

Lines changed: 25 additions & 0 deletions

# paths
qa_path = 'vqa'  # directory containing the question and annotation jsons
train_path = 'mscoco/train2014'  # directory of training images
val_path = 'mscoco/val2014'  # directory of validation images
test_path = 'mscoco/test2015'  # directory of test images
preprocessed_path = '/ssd/resnet-14x14.h5'  # path where preprocessed features are saved to and loaded from
vocabulary_path = 'vocab.json'  # path where the used vocabularies for questions and answers are saved to

task = 'OpenEnded'
dataset = 'mscoco'

# preprocess config
preprocess_batch_size = 64
image_size = 448  # scale shorter end of image to this size and centre crop
output_size = image_size // 32  # size of the feature maps after processing through a network
output_features = 2048  # number of feature maps thereof
central_fraction = 0.875  # only take this much of the centre when scaling and centre cropping

# training config
epochs = 50
batch_size = 128
initial_lr = 1e-3  # default Adam lr
lr_halflife = 50000  # in iterations
data_workers = 8
max_answers = 3000
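
As a quick sanity check that these settings and the preprocessing step line up, the feature file at `preprocessed_path` can be inspected directly. A minimal sketch, assuming preprocessing has already been run and using the dataset names that `data.py` reads back (`'ids'` and `'features'`):

```python
import h5py

import config

# peek at the file written by preprocess-images.py
with h5py.File(config.preprocessed_path, 'r') as f:
    print(f['ids'].shape)       # one COCO image id per preprocessed image
    print(f['features'].shape)  # expected (N, output_features, output_size, output_size),
                                # i.e. (N, 2048, 14, 14) with the values above
```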

data.py

Lines changed: 256 additions & 0 deletions

import json
import os
import os.path
import re

from PIL import Image
import h5py
import torch
import torch.utils.data as data
import torchvision.transforms as transforms

import config
import utils


def get_loader(train=False, val=False, test=False):
    """ Returns a data loader for the desired split """
    assert train + val + test == 1, 'need to set exactly one of {train, val, test} to True'
    split = VQA(
        utils.path_for(train=train, val=val, test=test, question=True),
        utils.path_for(train=train, val=val, test=test, answer=True),
        config.preprocessed_path,
        answerable_only=train,
    )
    loader = torch.utils.data.DataLoader(
        split,
        batch_size=config.batch_size,
        shuffle=train,  # only shuffle the data in training
        pin_memory=True,
        num_workers=config.data_workers,
        collate_fn=collate_fn,
    )
    return loader


def collate_fn(batch):
    # put question lengths in descending order so that we can use packed sequences later
    batch.sort(key=lambda x: x[-1], reverse=True)
    return data.dataloader.default_collate(batch)


class VQA(data.Dataset):
    """ VQA dataset, open-ended """
    def __init__(self, questions_path, answers_path, image_features_path, answerable_only=False):
        super(VQA, self).__init__()
        with open(questions_path, 'r') as fd:
            questions_json = json.load(fd)
        with open(answers_path, 'r') as fd:
            answers_json = json.load(fd)
        with open(config.vocabulary_path, 'r') as fd:
            vocab_json = json.load(fd)
        self._check_integrity(questions_json, answers_json)

        # vocab
        self.vocab = vocab_json
        self.token_to_index = self.vocab['question']
        self.answer_to_index = self.vocab['answer']

        # q and a
        self.questions = list(prepare_questions(questions_json))
        self.answers = list(prepare_answers(answers_json))
        self.questions = [self._encode_question(q) for q in self.questions]
        self.answers = [self._encode_answers(a) for a in self.answers]

        # v
        self.image_features_path = image_features_path
        self.coco_id_to_index = self._create_coco_id_to_index()
        self.coco_ids = [q['image_id'] for q in questions_json['questions']]

        # only use questions that have at least one answer?
        self.answerable_only = answerable_only
        if self.answerable_only:
            self.answerable = self._find_answerable()

    @property
    def max_question_length(self):
        if not hasattr(self, '_max_length'):
            self._max_length = max(map(len, self.questions))
        return self._max_length

    @property
    def num_tokens(self):
        return len(self.token_to_index) + 1  # add 1 for <unknown> token at index 0

    def _create_coco_id_to_index(self):
        """ Create a mapping from a COCO image id into the corresponding index into the h5 file """
        with h5py.File(self.image_features_path, 'r') as features_file:
            coco_ids = features_file['ids'][()]
        coco_id_to_index = {id: i for i, id in enumerate(coco_ids)}
        return coco_id_to_index

    def _check_integrity(self, questions, answers):
        """ Verify that we are using the correct data """
        qa_pairs = list(zip(questions['questions'], answers['annotations']))
        assert all(q['question_id'] == a['question_id'] for q, a in qa_pairs), 'Questions not aligned with answers'
        assert all(q['image_id'] == a['image_id'] for q, a in qa_pairs), 'Image id of question and answer don\'t match'
        assert questions['data_type'] == answers['data_type'], 'Mismatched data types'
        assert questions['data_subtype'] == answers['data_subtype'], 'Mismatched data subtypes'

    def _find_answerable(self):
        """ Create a list of indices into questions that will have at least one answer that is in the vocab """
        answerable = []
        for i, answers in enumerate(self.answers):
            answer_has_index = len(answers.nonzero()) > 0
            # store the indices of anything that is answerable
            if answer_has_index:
                answerable.append(i)
        return answerable

    def _encode_question(self, question):
        """ Turn a question into a vector of indices and a question length """
        vec = torch.zeros(self.max_question_length).long()
        for i, token in enumerate(question):
            index = self.token_to_index.get(token, 0)
            vec[i] = index
        return vec, len(question)

    def _encode_answers(self, answers):
        """ Turn an answer into a vector """
        # answer vec will be a vector of answer counts to determine which answers will contribute to the loss.
        # this should be multiplied with 0.1 * negative log-likelihoods that a model produces and then summed up
        # to get the loss that is weighted by how many humans gave that answer
        answer_vec = torch.zeros(len(self.answer_to_index))
        for answer in answers:
            index = self.answer_to_index.get(answer)
            if index is not None:
                answer_vec[index] += 1
        return answer_vec

    def _load_image(self, image_id):
        """ Load an image """
        if not hasattr(self, 'features_file'):
            # Loading the h5 file has to be done here and not in __init__ because when the DataLoader
            # forks into multiple workers, every child would use the same file object and fail.
            # Having multiple readers using different file objects is fine though, so we just init here.
            self.features_file = h5py.File(self.image_features_path, 'r')
        index = self.coco_id_to_index[image_id]
        dataset = self.features_file['features']
        img = dataset[index].astype('float32')
        return torch.from_numpy(img)

    def __getitem__(self, item):
        if self.answerable_only:
            # remap indices so that only answerable questions are addressed
            item = self.answerable[item]

        q, q_length = self.questions[item]
        a = self.answers[item]
        image_id = self.coco_ids[item]
        v = self._load_image(image_id)
        # since batches are re-ordered for PackedSequence's, the original question order is lost
        # we return `item` so that the order of (v, q, a) triples can be restored if desired
        # without shuffling in the dataloader, these will be in the order that they appear in the q and a json's.
        return v, q, a, item, q_length

    def __len__(self):
        if self.answerable_only:
            return len(self.answerable)
        else:
            return len(self.questions)


# this is used for normalizing questions
_special_chars = re.compile('[^a-z0-9 ]*')

# these try to emulate the original normalisation scheme for answers
_period_strip = re.compile(r'(?!<=\d)(\.)(?!\d)')
_comma_strip = re.compile(r'(\d)(,)(\d)')
_punctuation_chars = re.escape(r';/[]"{}()=+\_-><@`,?!')
_punctuation = re.compile(r'([{}])'.format(re.escape(_punctuation_chars)))
_punctuation_with_a_space = re.compile(r'(?<= )([{0}])|([{0}])(?= )'.format(_punctuation_chars))


def prepare_questions(questions_json):
    """ Tokenize and normalize questions from a given question json in the usual VQA format. """
    questions = [q['question'] for q in questions_json['questions']]
    for question in questions:
        question = question.lower()[:-1]
        yield question.split(' ')


def prepare_answers(answers_json):
    """ Normalize answers from a given answer json in the usual VQA format. """
    answers = [[a['answer'] for a in ans_dict['answers']] for ans_dict in answers_json['annotations']]
    # The only normalisation that is applied to both machine generated answers as well as
    # ground truth answers is replacing most punctuation with space (see [0] and [1]).
    # Since potential machine generated answers are just taken from most common answers, applying the other
    # normalisations is not needed, assuming that the human answers are already normalized.
    # [0]: http://visualqa.org/evaluation.html
    # [1]: https://github.com/VT-vision-lab/VQA/blob/3849b1eae04a0ffd83f56ad6f70ebd0767e09e0f/PythonEvaluationTools/vqaEvaluation/vqaEval.py#L96

    def process_punctuation(s):
        # the original is somewhat broken, so things that look odd here might just be to mimic that behaviour
        # this version should be faster since we use re instead of repeated operations on str's
        if _punctuation.search(s) is None:
            return s
        s = _punctuation_with_a_space.sub('', s)
        if re.search(_comma_strip, s) is not None:
            s = s.replace(',', '')
        s = _punctuation.sub(' ', s)
        s = _period_strip.sub('', s)
        return s.strip()

    for answer_list in answers:
        yield list(map(process_punctuation, answer_list))


class CocoImages(data.Dataset):
    """ Dataset for MSCOCO images located in a folder on the filesystem """
    def __init__(self, path, transform=None):
        super(CocoImages, self).__init__()
        self.path = path
        self.id_to_filename = self._find_images()
        self.sorted_ids = sorted(self.id_to_filename.keys())  # used for deterministic iteration order
        print('found {} images in {}'.format(len(self), self.path))
        self.transform = transform

    def _find_images(self):
        id_to_filename = {}
        for filename in os.listdir(self.path):
            if not filename.endswith('.jpg'):
                continue
            id_and_extension = filename.split('_')[-1]
            id = int(id_and_extension.split('.')[0])
            id_to_filename[id] = filename
        return id_to_filename

    def __getitem__(self, item):
        id = self.sorted_ids[item]
        path = os.path.join(self.path, self.id_to_filename[id])
        img = Image.open(path).convert('RGB')

        if self.transform is not None:
            img = self.transform(img)
        return id, img

    def __len__(self):
        return len(self.sorted_ids)


class Composite(data.Dataset):
    """ Dataset that is a composite of several Dataset objects. Useful for combining splits of a dataset. """
    def __init__(self, *datasets):
        self.datasets = datasets

    def __getitem__(self, item):
        for d in self.datasets:
            if item < len(d):
                return d[item]
            item -= len(d)
        else:
            raise IndexError('Index too large for composite dataset')

    def __len__(self):
        return sum(map(len, self.datasets))
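
Two of the comments in this file refer to mechanics that live on the model and training side rather than here: the descending sort in `collate_fn` is there so the padded questions can be fed to `pack_padded_sequence`, and the answer-count vectors from `_encode_answers` are meant to weight `0.1 *` the negative log-likelihoods. A minimal illustrative sketch of both, assuming the preprocessing from the README has been run; the embedding size and the random log-probabilities are placeholders, not part of the actual model in `model.py`:

```python
import torch
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence

import data

loader = data.get_loader(val=True)
v, q, a, item, q_length = next(iter(loader))

# collate_fn sorted the batch by question length (descending), which is the
# order that pack_padded_sequence expects for the padded question tensor `q`
embedding = torch.nn.Embedding(loader.dataset.num_tokens, 300, padding_idx=0)  # 300 is a placeholder size
packed_questions = pack_padded_sequence(embedding(q), q_length, batch_first=True)

# the soft-target loss described in _encode_answers: `a` holds per-answer human
# counts, so weighting 0.1 * the negative log-likelihoods by those counts and
# summing gives a loss weighted by how many humans gave each answer
# (the log-probabilities here are random stand-ins for real model output)
log_probs = F.log_softmax(torch.randn_like(a), dim=1)
loss = (-log_probs * 0.1 * a).sum(dim=1).mean()
print(loss.item())
```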

logs/.dummy

Whitespace-only changes.
