import json
import os
import os.path
import re

from PIL import Image
import h5py
import torch
import torch.utils.data as data
import torchvision.transforms as transforms

import config
import utils

def get_loader(train=False, val=False, test=False):
    """ Returns a data loader for the desired split """
    assert train + val + test == 1, 'need to set exactly one of {train, val, test} to True'
    split = VQA(
        utils.path_for(train=train, val=val, test=test, question=True),
        utils.path_for(train=train, val=val, test=test, answer=True),
        config.preprocessed_path,
        answerable_only=train,
    )
    loader = torch.utils.data.DataLoader(
        split,
        batch_size=config.batch_size,
        shuffle=train,  # only shuffle the data in training
        pin_memory=True,
        num_workers=config.data_workers,
        collate_fn=collate_fn,
    )
    return loader

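# A minimal usage sketch (illustrative only, assuming `config` points at the preprocessed
# features, question/answer jsons and vocabulary):
#
#   train_loader = get_loader(train=True)
#   for v, q, a, item, q_len in train_loader:
#       ...  # v: image features, q: padded question indices, a: soft answer scores
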
def collate_fn(batch):
    # put question lengths in descending order so that we can use packed sequences later
    batch.sort(key=lambda x: x[-1], reverse=True)
    return data.dataloader.default_collate(batch)
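# Because each batch is sorted by descending question length, the questions can be packed
# directly once embedded, e.g. (sketch only, `q_embedded` and `q_len` are illustrative names):
#
#   packed = torch.nn.utils.rnn.pack_padded_sequence(q_embedded, q_len, batch_first=True)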


class VQA(data.Dataset):
    """ VQA dataset, open-ended """
    def __init__(self, questions_path, answers_path, image_features_path, answerable_only=False):
        super(VQA, self).__init__()
        with open(questions_path, 'r') as fd:
            questions_json = json.load(fd)
        with open(answers_path, 'r') as fd:
            answers_json = json.load(fd)
        with open(config.vocabulary_path, 'r') as fd:
            vocab_json = json.load(fd)
        self._check_integrity(questions_json, answers_json)

        # vocab
        self.vocab = vocab_json
        self.token_to_index = self.vocab['question']
        self.answer_to_index = self.vocab['answer']

        # q and a
        self.questions = list(prepare_questions(questions_json))
        self.answers = list(prepare_answers(answers_json))
        self.questions = [self._encode_question(q) for q in self.questions]
        self.answers = [self._encode_answers(a) for a in self.answers]

        # v
        self.image_features_path = image_features_path
        self.coco_id_to_index = self._create_coco_id_to_index()
        self.coco_ids = [q['image_id'] for q in questions_json['questions']]

        # only use questions that have at least one answer?
        self.answerable_only = answerable_only
        if self.answerable_only:
            self.answerable = self._find_answerable()

    @property
    def max_question_length(self):
        if not hasattr(self, '_max_length'):
            self._max_length = max(map(len, self.questions))
        return self._max_length

    @property
    def num_tokens(self):
        return len(self.token_to_index) + 1  # add 1 for <unknown> token at index 0

    def _create_coco_id_to_index(self):
        """ Create a mapping from a COCO image id into the corresponding index into the h5 file """
        with h5py.File(self.image_features_path, 'r') as features_file:
            coco_ids = features_file['ids'][()]
        coco_id_to_index = {id: i for i, id in enumerate(coco_ids)}
        return coco_id_to_index

    def _check_integrity(self, questions, answers):
        """ Verify that we are using the correct data """
        qa_pairs = list(zip(questions['questions'], answers['annotations']))
        assert all(q['question_id'] == a['question_id'] for q, a in qa_pairs), 'Questions not aligned with answers'
        assert all(q['image_id'] == a['image_id'] for q, a in qa_pairs), 'Image id of question and answer don\'t match'
        assert questions['data_type'] == answers['data_type'], 'Mismatched data types'
        assert questions['data_subtype'] == answers['data_subtype'], 'Mismatched data subtypes'

    def _find_answerable(self):
        """ Create a list of indices into questions that will have at least one answer that is in the vocab """
        answerable = []
        for i, answers in enumerate(self.answers):
            answer_has_index = len(answers.nonzero()) > 0
            # store the indices of anything that is answerable
            if answer_has_index:
                answerable.append(i)
        return answerable

    def _encode_question(self, question):
        """ Turn a question into a vector of indices and a question length """
        vec = torch.zeros(self.max_question_length).long()
        for i, token in enumerate(question):
            index = self.token_to_index.get(token, 0)
            vec[i] = index
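        # note: index 0 doubles as both padding and the <unknown> token, so with a hypothetical
        # vocab {'what': 1, 'is': 2}, the tokens ['what', 'is', 'xyzzy'] would encode to
        # tensor([1, 2, 0, ...]) (padded to max_question_length) together with the length 3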
        return vec, len(question)

    def _encode_answers(self, answers):
        """ Turn an answer into a vector """
        # answer vec will be a vector of answer counts to determine which answers will contribute to the loss.
        # this should be multiplied with 0.1 * negative log-likelihoods that a model produces and then summed up
        # to get the loss that is weighted by how many humans gave that answer
        answer_vec = torch.zeros(len(self.answer_to_index))
        for answer in answers:
            index = self.answer_to_index.get(answer)
            if index is not None:
                answer_vec[index] += 1
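        # e.g. if 7 of a question's 10 human answers are "yes" and "yes" is in the answer vocab,
        # the entry at answer_to_index['yes'] ends up as 7; following the comment above, a loss of
        # the form (0.1 * answer_vec * nll).sum() would then weight "yes" seven times (sketch only)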
        return answer_vec

    def _load_image(self, image_id):
        """ Load an image """
        if not hasattr(self, 'features_file'):
            # Loading the h5 file has to be done here and not in __init__ because when the DataLoader
            # forks for multiple workers, every child would use the same file object and fail.
            # Having multiple readers using different file objects is fine though, so we just init in here.
            self.features_file = h5py.File(self.image_features_path, 'r')
        index = self.coco_id_to_index[image_id]
        dataset = self.features_file['features']
        img = dataset[index].astype('float32')
        return torch.from_numpy(img)

    def __getitem__(self, item):
        if self.answerable_only:
            # change of indices to only address answerable questions
            item = self.answerable[item]

        q, q_length = self.questions[item]
        a = self.answers[item]
        image_id = self.coco_ids[item]
        v = self._load_image(image_id)
        # since batches are re-ordered for PackedSequence's, the original question order is lost
        # we return `item` so that the order of (v, q, a) triples can be restored if desired
        # without shuffling in the dataloader, these will be in the order that they appear in the q and a json's.
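        # (sketch: if `idx` collects the returned items over an epoch as a LongTensor, predictions
        # gathered in batch order could be put back into dataset order with preds[idx.argsort()])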
        return v, q, a, item, q_length

    def __len__(self):
        if self.answerable_only:
            return len(self.answerable)
        else:
            return len(self.questions)


# this is used for normalizing questions
_special_chars = re.compile('[^a-z0-9 ]*')

# these try to emulate the original normalisation scheme for answers
_period_strip = re.compile(r'(?!<=\d)(\.)(?!\d)')
_comma_strip = re.compile(r'(\d)(,)(\d)')
_punctuation_chars = re.escape(r';/[]"{}()=+\_-><@`,?!')
# _punctuation_chars is already escaped above, so it can go into a character class directly
_punctuation = re.compile(r'([{}])'.format(_punctuation_chars))
_punctuation_with_a_space = re.compile(r'(?<= )([{0}])|([{0}])(?= )'.format(_punctuation_chars))


def prepare_questions(questions_json):
    """ Tokenize and normalize questions from a given question json in the usual VQA format. """
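    # e.g. 'What is on the table?' -> ['what', 'is', 'on', 'the', 'table']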
    questions = [q['question'] for q in questions_json['questions']]
    for question in questions:
        question = question.lower()[:-1]
        yield question.split(' ')


def prepare_answers(answers_json):
    """ Normalize answers from a given answer json in the usual VQA format. """
    answers = [[a['answer'] for a in ans_dict['answers']] for ans_dict in answers_json['annotations']]
    # The only normalisation that is applied to both machine generated answers as well as
    # ground truth answers is replacing most punctuation with space (see [0] and [1]).
    # Since potential machine generated answers are just taken from most common answers, applying the other
    # normalisations is not needed, assuming that the human answers are already normalized.
    # [0]: http://visualqa.org/evaluation.html
    # [1]: https://github.com/VT-vision-lab/VQA/blob/3849b1eae04a0ffd83f56ad6f70ebd0767e09e0f/PythonEvaluationTools/vqaEvaluation/vqaEval.py#L96

    def process_punctuation(s):
        # the original is somewhat broken, so things that look odd here might just be to mimic that behaviour
        # this version should be faster since we use re instead of repeated operations on str's
        if _punctuation.search(s) is None:
            return s
        s = _punctuation_with_a_space.sub('', s)
        if re.search(_comma_strip, s) is not None:
            s = s.replace(',', '')
        s = _punctuation.sub(' ', s)
        s = _period_strip.sub('', s)
        return s.strip()
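    # e.g. process_punctuation('yes, it is.') -> 'yes it is'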

    for answer_list in answers:
        yield list(map(process_punctuation, answer_list))


class CocoImages(data.Dataset):
    """ Dataset for MSCOCO images located in a folder on the filesystem """
    def __init__(self, path, transform=None):
        super(CocoImages, self).__init__()
        self.path = path
        self.id_to_filename = self._find_images()
        self.sorted_ids = sorted(self.id_to_filename.keys())  # used for deterministic iteration order
        print('found {} images in {}'.format(len(self), self.path))
        self.transform = transform

    def _find_images(self):
        id_to_filename = {}
        for filename in os.listdir(self.path):
            if not filename.endswith('.jpg'):
                continue
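            # filenames follow the MSCOCO pattern COCO_<split>_<zero-padded id>.jpg,
            # e.g. COCO_train2014_000000123456.jpg yields the image id 123456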
            id_and_extension = filename.split('_')[-1]
            id = int(id_and_extension.split('.')[0])
            id_to_filename[id] = filename
        return id_to_filename

    def __getitem__(self, item):
        id = self.sorted_ids[item]
        path = os.path.join(self.path, self.id_to_filename[id])
        img = Image.open(path).convert('RGB')

        if self.transform is not None:
            img = self.transform(img)
        return id, img

    def __len__(self):
        return len(self.sorted_ids)


class Composite(data.Dataset):
    """ Dataset that is a composite of several Dataset objects. Useful for combining splits of a dataset. """
    def __init__(self, *datasets):
        self.datasets = datasets

    def __getitem__(self, item):
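        # walk through the datasets in order, shifting the index down by each dataset's length,
        # e.g. with datasets of lengths 3 and 5, item 4 is served by self.datasets[1][1]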
        for d in self.datasets:
            if item < len(d):
                return d[item]
            item -= len(d)
        raise IndexError('Index too large for composite dataset')

    def __len__(self):
        return sum(map(len, self.datasets))