
Commit 88894c8: re-factor

markdtw committed Feb 8, 2019
1 parent 9ac023a commit 88894c8
Showing 10 changed files with 730 additions and 554 deletions.
86 changes: 44 additions & 42 deletions README.md
@@ -1,14 +1,15 @@
# 2017 VQA Challenge Winner (CVPR'17 Workshop)
-Pytorch implementation of [Tips and Tricks for Visual Question Answering: Learnings from the 2017 Challenge by Teney et al](https://arxiv.org/pdf/1708.02711.pdf).
+pytorch implementation of [Tips and Tricks for Visual Question Answering: Learnings from the 2017 Challenge by Teney et al](https://arxiv.org/pdf/1708.02711.pdf).

![Model architecture](https://i.imgur.com/phBHIqZ.png)

## Prerequisites
-- Python 2.7+
-- [NumPy](http://www.numpy.org/)
-- [PyTorch](http://pytorch.org/)
-- [tqdm](https://pypi.python.org/pypi/tqdm) (visualizing preprocessing progress only)
-- [nltk](http://www.nltk.org/install.html) (and [this](https://nlp.stanford.edu/software/tokenizer.shtml) to tokenize questions)
+- python 3.6+
+- numpy
+- [pytorch](http://pytorch.org/) 0.4
+- [tqdm](https://pypi.python.org/pypi/tqdm)
+- [nltk](http://www.nltk.org/install.html)
+- [pandas](https://pandas.pydata.org/)


## Data
@@ -18,51 +19,52 @@


## Preparation
-- For questions and answers, go to the `data/` folder and execute `preproc.py` directly.
-- You'll need to install the Stanford Tokenizer; follow the instructions on [their page](https://nlp.stanford.edu/software/tokenizer.shtml).
-- The tokenizing step may take up to 36 hours for the training questions (even on a Xeon E5 CPU); writing pure Java code to tokenize them should be a lot faster, since Python nltk calls the Java binding and Python is slow.
-- For image features, slightly modify [this code](https://github.com/peteanderson80/bottom-up-attention/blob/master/tools/read_tsv.py) to convert the tsv into a npy file `coco_features.npy` that contains a list of dictionaries, with image id as key and the feature (shape: 36, 2048) as value.
-- Download and extract GloVe to the `data/` folder as well.
-- Now we should be able to train; make sure the `data/` folder contains at least:
+- To download and extract vqav2, glove, and pretrained visual features:
+```bash
+bash scripts/download_extract.sh
+```
-  - glove.6B.300d.txt
-  - vqa_train_final.json
-  - coco_features.npy
-  - train_q_dict.p
-  - train_a_dict.p
+- To prepare data for training (see the sketch at the end of this section):
+```bash
+python scripts/preproc.py
+```
+- The structure of the `data/` directory should look like this:
+```
+- data/
+  - zips/
+    - v2_XXX...zip
+    - ...
+    - glove...zip
+    - trainval_36.zip
+  - glove/
+    - glove...txt
+    - ...
+  - v2_XXX.json
+  - ...
+  - trainval_resnet...tsv
+  (The above are files created after executing scripts/download_extract.sh)
+  - tokenizers/
+    - ...
+  - dict_ans.pkl
+  - dict_q.pkl
+  - glove_pretrained_300.npy
+  - train_qa.pkl
+  - val_qa.pkl
+  - train_vfeats.pkl
+  - val_vfeats.pkl
+  (The above are files created after executing scripts/preproc.py)
+```
-- (*Update*) For convenience, [here](https://drive.google.com/open?id=0B5j6QKJb0ztbYmVXT0hBUF91RHM) is a link to the tokenized questions `vqa_train_toked.json` and `vqa_val_toked.json`; make sure you run `data/preproc.py` to generate `vqa_train_final.json`, `train_q_dict.p`, etc.
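
For reference, here is a minimal sketch of the dictionary-building step that `scripts/preproc.py` performs. The function name, the `min_ans_freq` cutoff, and the input format are assumptions; only the pickle layouts (`dict_q.pkl`, `dict_ans.pkl`) mirror what `data_loader.py` loads.
```python
# Hypothetical sketch, not the repo's actual preproc.py.
import pickle
from collections import Counter

def build_dicts(tokenized_questions, answers, min_ans_freq=9):
    # tokenized_questions: list of token lists; answers: flat list of answer strings
    words = Counter(w for q in tokenized_questions for w in q)
    idx2word = sorted(words)
    word2idx = {w: i for i, w in enumerate(idx2word)}

    # keep only answers frequent enough to serve as output classes
    ans_counts = Counter(answers)
    idx2ans = [a for a, c in ans_counts.most_common() if c >= min_ans_freq]
    ans2idx = {a: i for i, a in enumerate(idx2ans)}

    with open('data/dict_q.pkl', 'wb') as f:
        pickle.dump((idx2word, word2idx), f)
    with open('data/dict_ans.pkl', 'wb') as f:
        pickle.dump((idx2ans, ans2idx), f)
    return word2idx, ans2idx
```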


## Train
Use default parameters:
```bash
-python main.py --train
+bash scripts/train.sh
```
Train from a previous checkpoint:
```bash
python main.py --train --modelpath=/path/to/saved.pth.tar
```
Check out tunable parameters:
```bash
python main.py
```
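
For reference, a minimal sketch of the checkpoint round-trip behind `--modelpath`; the `'epoch'` and `'state_dict'` keys are an assumption about the `.pth.tar` layout, not verified against `main.py`:
```python
import torch
import torch.nn as nn

model = nn.Linear(10, 2)  # stand-in for the real VQA model

# save during training
torch.save({'epoch': 0, 'state_dict': model.state_dict()}, 'saved.pth.tar')

# resume later
ckpt = torch.load('saved.pth.tar')
model.load_state_dict(ckpt['state_dict'])
start_epoch = ckpt['epoch'] + 1
```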

## Evaluate
```bash
python main.py --eval
```
This will generate `result.json` (validation set only); the format follows the [VQA evaluation format](http://www.visualqa.org/evaluation.html).
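
That format is a JSON list of question-id/answer records; a minimal writing sketch (the ids and answers below are made-up samples):
```python
import json

results = [
    {'question_id': 458752000, 'answer': 'yes'},  # made-up sample entries
    {'question_id': 458752001, 'answer': '2'},
]
with open('result.json', 'w') as f:
    json.dump(results, f)
```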


## Notes
-- The default classifier is a softmax classifier; a sigmoid multi-label classifier is also implemented, but I could not get it to train (see the sketch after this list).
-- Training for 50 epochs reaches around 64.42% training accuracy.
-- For the output classifier, I did not use the pretrained weights since they are hard to retrieve, so I followed *eq. 5* in the paper.
-- To prepare validation data you need to uncomment some lines of code in `data/preproc.py`.
-- `coco_features.npy` is a really big file (34 GB including train+val image features); you can split it and modify the data-loading mechanism in `loader.py`.
-- This code is tested with train = train and eval = val; no test data is included.
- Issues are welcome!
+- Huge re-factor (especially of the data preprocessing); tested with pytorch 0.4.1 and python 3.6.
+- Training for 20 epochs reaches around 50% training accuracy (the model seems buggy in my implementation).
+- After all the preprocessing, the `data/` directory may grow to 38 GB+.
+- Parts of `preproc.py` and `utils.py` are based on [this repo](https://github.com/hengyuan-hu/bottom-up-attention-vqa).
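
A minimal sketch of the two classifier heads mentioned in the notes, with placeholder dimensions and random stand-in targets (not the repo's actual model code):
```python
import torch
import torch.nn as nn

hidden_dim, n_answers = 512, 3000      # placeholder sizes
head = nn.Linear(hidden_dim, n_answers)

h = torch.randn(8, hidden_dim)         # fused question/image features
logits = head(h)

# Softmax head: a single most-common answer is the target class.
softmax_loss = nn.CrossEntropyLoss()(logits, torch.randint(0, n_answers, (8,)))

# Sigmoid multi-label head: independent probability per answer, trained
# against soft VQA accuracy scores in [0, 1] (cf. eq. 5 in the paper).
soft_targets = torch.rand(8, n_answers)
sigmoid_loss = nn.BCEWithLogitsLoss()(logits, soft_targets)
```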


## Resources
174 changes: 0 additions & 174 deletions data/preproc.py

This file was deleted.

79 changes: 79 additions & 0 deletions data_loader.py
@@ -0,0 +1,79 @@
from __future__ import division, print_function, absolute_import

import os
import pickle

import torch
from torch.utils.data import Dataset
import numpy as np
from tqdm import tqdm

class VQAv2(Dataset):

    def __init__(self, root, train, seqlen=14):
        """
        root (str): path to data directory
        train (bool): training or validation
        seqlen (int): maximum words in a question
        """
        if train:
            prefix = 'train'
        else:
            prefix = 'val'
        print("Loading preprocessed files... ({})".format(prefix))
        qas = pickle.load(open(os.path.join(root, prefix + '_qa.pkl'), 'rb'))
        idx2word, word2idx = pickle.load(open(os.path.join(root, 'dict_q.pkl'), 'rb'))
        idx2ans, ans2idx = pickle.load(open(os.path.join(root, 'dict_ans.pkl'), 'rb'))
        vfeats = pickle.load(open(os.path.join(root, prefix + '_vfeats.pkl'), 'rb'))

print("Setting up everything... ({})".format(prefix))
self.vqas = []
for qa in tqdm(qas):
que = np.ones(seqlen, dtype=np.int64) * len(word2idx)
for i, word in enumerate(qa['question_toked']):
if word in word2idx:
que[i] = word2idx[word]

ans = np.zeros(len(idx2ans), dtype=np.float32)
for a, s in qa['answer']:
ans[ans2idx[a]] = s

self.vqas.append({
'v': vfeats[qa['image_id']],
'q': que,
'a': ans,
'q_txt': qa['question'],
'a_txt': qa['answer']
})

    def __len__(self):
        return len(self.vqas)

    def __getitem__(self, idx):
        return self.vqas[idx]['v'], self.vqas[idx]['q'], self.vqas[idx]['a'], self.vqas[idx]['q_txt'], self.vqas[idx]['a_txt']

    @staticmethod
    def get_n_classes(fpath=os.path.join('data', 'dict_ans.pkl')):
        idx2ans, _ = pickle.load(open(fpath, 'rb'))
        return len(idx2ans)

    @staticmethod
    def get_vocab_size(fpath=os.path.join('data', 'dict_q.pkl')):
        idx2word, _ = pickle.load(open(fpath, 'rb'))
        return len(idx2word)


def prepare_data(args):

    train_loader = torch.utils.data.DataLoader(
        VQAv2(root=args.data_root, train=True),
        batch_size=args.batch_size, shuffle=True, num_workers=args.n_workers, pin_memory=args.pin_mem)

    val_loader = torch.utils.data.DataLoader(
        VQAv2(root=args.data_root, train=False),
        batch_size=args.vbatch_size, shuffle=False, num_workers=args.n_workers, pin_memory=args.pin_mem)

    vocab_size = VQAv2.get_vocab_size()
    num_classes = VQAv2.get_n_classes()
    return train_loader, val_loader, vocab_size, num_classes
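
A hypothetical usage sketch, assuming `scripts/preproc.py` has already produced the pickle files under `data/` (the `args` fields mirror those referenced in `prepare_data`):
```python
from argparse import Namespace

args = Namespace(data_root='data', batch_size=512, vbatch_size=512,
                 n_workers=4, pin_mem=True)
train_loader, val_loader, vocab_size, n_classes = prepare_data(args)

# one example straight from the underlying dataset
v, q, a, q_txt, a_txt = train_loader.dataset[0]
# v: (36, 2048) image features, q: (14,) word indices, a: (n_classes,) soft scores
```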
