PyTorch implementation for the paper:
Large-scale Pretraining for Visual Dialog: A Simple State-of-the-Art Baseline
Vishvak Murahari, Dhruv Batra, Devi Parikh, Abhishek Das
Prior work in visual dialog has focused on training deep neural models on the VisDial dataset in isolation, which has led to great progress, but is limiting and wasteful. In this work, following recent trends in representation learning for language, we introduce an approach to leverage pretraining on related large-scale vision-language datasets before transferring to visual dialog. Specifically, we adapt the recently proposed ViLBERT model for multi-turn visually-grounded conversation sequences. Our model is pretrained on the Conceptual Captions and Visual Question Answering datasets, and finetuned on VisDial with a VisDial-specific input representation and the masked language modeling and next sentence prediction objectives (as in BERT). Our best single model achieves state-of-the-art on Visual Dialog, outperforming prior published work (including model ensembles) by more than 1% absolute on NDCG and MRR.
This repository contains code for reproducing results with and without finetuning on dense annotations. All results are on v1.0 of the Visual Dialog dataset. We provide pretrained model weights and associated configs to run inference or train these models from scratch.
If you find this work useful in your research, please cite:
@article{visdial_bert,
title={Large-scale Pretraining for Visual Dialog: A Simple State-of-the-Art Baseline},
author={Vishvak Murahari and Dhruv Batra and Devi Parikh and Abhishek Das},
journal={arXiv preprint arXiv:1912.02379},
year={2019},
}
Our code is implemented in PyTorch (v1.0). To set up, do the following:
- Install Python 3.6
- Get the source:
git clone https://github.com/vmurahari3/visdial-bert.git visdial-bert
- Install requirements into the visdial-bert virtual environment, using Anaconda:
conda env create -f env.yml
Make both scripts in scripts/ executable:
chmod +x scripts/download_preprocessed.sh
chmod +x scripts/download_checkpoints.sh
Download preprocessed dataset and extracted features:
sh scripts/download_preprocessed.sh
To get these files from scratch:
python preprocessing/pre_process_visdial.py
However, we recommend downloading these files directly.
Download pre-trained checkpoints:
sh scripts/download_checkpoints.sh
After running the above scripts, all the pre-processed data is downloaded to data/visdial and the major pre-trained model checkpoints used in the paper are downloaded to checkpoints-release.
The commands below list the training arguments for the main model variants in the paper.
To train the base model (no finetuning on dense annotations):
python train.py -batch_size 80 -batch_multiply 1 -lr 2e-5 -image_lr 2e-5 -mask_prob 0.1 -sequences_per_image 2 -start_path checkpoints-release/vqa_pretrained_weights
To finetune the base model with dense annotations:
python dense_annotation_finetuning.py -batch_size 80 -batch_multiply 10 -lr 1e-4 -image_lr 1e-4 -nsp_loss_coeff 0 -mask_prob 0.1 -sequences_per_image 2 -start_path checkpoints-release/basemodel
To finetune the base model with dense annotations and the next sentence prediction (NSP) loss:
python dense_annotation_finetuning.py -batch_size 80 -batch_multiply 10 -lr 1e-4 -image_lr 1e-4 -nsp_loss_coeff 1 -mask_prob 0.1 -sequences_per_image 2 -start_path checkpoints-release/basemodel
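The commands above pair -batch_size 80 with -batch_multiply 1 or 10. This README does not define -batch_multiply; assuming it is the number of mini-batches whose gradients are accumulated before each optimizer step, the effective batch size would be the product of the two. The sketch below illustrates that assumption on a toy one-parameter model in plain Python. The helpers effective_batch_size, mean_grad and accumulated_grad are hypothetical and are not taken from this repository.

```python
# Hypothetical helpers illustrating gradient accumulation; not taken from train.py.

def effective_batch_size(batch_size, batch_multiply):
    """Examples contributing to one optimizer step, assuming accumulation."""
    return batch_size * batch_multiply


def mean_grad(w, batch):
    """Gradient of the mean of 0.5 * (w * x - y) ** 2 over a batch, w.r.t. w."""
    return sum((w * x - y) * x for x, y in batch) / len(batch)


def accumulated_grad(w, mini_batches):
    """Average the per-mini-batch gradients, as accumulation would before a step."""
    total = 0.0
    for batch in mini_batches:
        total += mean_grad(w, batch) / len(mini_batches)
    return total


data = [(float(x), 2.0 * x) for x in range(1, 9)]       # 8 toy (x, y) pairs
mini_batches = [data[i:i + 2] for i in range(0, 8, 2)]  # 4 mini-batches of 2

print(effective_batch_size(80, 10))
# With equally sized mini-batches, both gradients agree up to rounding error.
print(mean_grad(0.5, data), accumulated_grad(0.5, mini_batches))
```

Under this reading, accumulation trades wall-clock time for memory: each forward and backward pass only ever holds one mini-batch.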
NOTE: Dense annotation finetuning is currently only supported for 8-GPU training, primarily because of memory constraints. Computing the cross entropy loss over the 100 answer options at a dialog round requires all 100 dialog sequences to be in memory at once. However, only 80 sequences fit on 8 GPUs with ~12 GB RAM, so we select only 80 of the 100 options. Performance gets worse with fewer GPUs, because the number of answer options must be cut down further.
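To make the option sub-sampling in the note concrete, here is a minimal sketch in plain Python. It assumes the kept options are drawn at random while always retaining the ground-truth answer, and that the loss is a cross entropy between the softmax over option scores and the normalised dense relevance scores. Neither detail is stated in this README, and select_options and soft_cross_entropy are hypothetical names, so treat this as an illustration rather than the repository's implementation.

```python
import math
import random


def select_options(num_options, gt_index, keep, rng):
    """Pick `keep` option indices, always including the ground-truth answer.

    Random selection is an assumption; the README only says 80 options are kept.
    """
    others = [i for i in range(num_options) if i != gt_index]
    chosen = rng.sample(others, keep - 1) + [gt_index]
    return sorted(chosen)


def soft_cross_entropy(scores, relevance):
    """Cross entropy between softmax(scores) and relevance normalised to sum to 1."""
    total = sum(relevance)
    if total == 0:
        return 0.0  # no relevant option survived the selection
    top = max(scores)
    # Log of the softmax normaliser, computed stably by shifting by the max score.
    log_z = top + math.log(sum(math.exp(s - top) for s in scores))
    return -sum((r / total) * (s - log_z) for s, r in zip(scores, relevance))


rng = random.Random(0)
kept = select_options(num_options=100, gt_index=7, keep=80, rng=rng)

# Toy scores and dense relevance for all 100 options, then restricted to `kept`.
scores = [rng.uniform(-1.0, 1.0) for _ in range(100)]
relevance = [1.0 if i == 7 else rng.choice([0.0, 0.0, 0.5]) for i in range(100)]
loss = soft_cross_entropy([scores[i] for i in kept], [relevance[i] for i in kept])
print(len(kept), loss)
```

Fewer GPUs would mean a smaller `keep`, so more of the annotated options drop out of each loss computation.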
The following command generates a prediction file that can be submitted to the test server to get results on the test split.
python evaluate.py -n_gpus 8 -start_path <path to model> -save_name <name of model>
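For orientation, the sketch below shows how per-option scores could be turned into the ranks that a prediction file carries. The entry layout (image_id, round_id, ranks) is an assumption based on the public VisDial challenge format and is not specified in this README; scores_to_ranks and make_entry are hypothetical helpers, and the file written by evaluate.py may differ in detail.

```python
import json


def scores_to_ranks(scores):
    """Return 1-based ranks: the highest-scoring option gets rank 1."""
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    ranks = [0] * len(scores)
    for position, option in enumerate(order, start=1):
        ranks[option] = position
    return ranks


def make_entry(image_id, round_id, scores):
    """Build one prediction entry (field names are assumed, see the text above)."""
    return {
        "image_id": image_id,
        "round_id": round_id,
        "ranks": scores_to_ranks(scores),
    }


# Toy example with 3 options instead of the 100 used in VisDial.
entry = make_entry(image_id=1, round_id=10, scores=[0.1, 0.9, 0.5])
print(json.dumps([entry]))
```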
The metrics for the pretrained checkpoints should match the numbers reported in the paper; they are repeated below for convenience. These results are on v1.0 test-std.
| Checkpoint | Mean Rank | MRR | R@1 | R@5 | R@10 | NDCG |
|---|---|---|---|---|---|---|
| basemodel | 3.32 | 67.50 | 53.85 | 84.68 | 93.25 | 63.87 |
| basemodel + dense | 6.28 | 50.74 | 37.95 | 64.13 | 80.00 | 74.47 |
| basemodel + dense + nsp | 4.28 | 63.92 | 50.78 | 79.53 | 89.60 | 68.08 |
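As a reading aid for the table, the sketch below computes the same kinds of metrics from toy data: mean rank, MRR and recall at k (the R1, R5 and R10 columns, often written R@k) from the rank of the ground-truth answer, and NDCG from dense relevance scores. The NDCG variant here truncates at the number of options with non-zero relevance, which is our understanding of the VisDial convention rather than something stated in this README. The functions retrieval_metrics and ndcg are hypothetical and do not reproduce the official evaluation code.

```python
import math


def retrieval_metrics(gt_ranks, ks=(1, 5, 10)):
    """Mean rank, MRR and recall@k (the latter two in percent) from 1-based ranks."""
    n = len(gt_ranks)
    metrics = {
        "mean_rank": sum(gt_ranks) / n,
        "mrr": 100.0 * sum(1.0 / r for r in gt_ranks) / n,
    }
    for k in ks:
        metrics[f"r@{k}"] = 100.0 * sum(r <= k for r in gt_ranks) / n
    return metrics


def ndcg(ranks, relevance):
    """NDCG over the top k options, k = number of options with relevance > 0.

    ranks[i] is the 1-based predicted rank of option i, relevance[i] its dense score.
    """
    k = sum(1 for rel in relevance if rel > 0)
    if k == 0:
        return 0.0
    # Options ordered as the model ranked them, best first.
    order = sorted(range(len(ranks)), key=lambda i: ranks[i])
    dcg = sum(relevance[i] / math.log2(pos + 2) for pos, i in enumerate(order[:k]))
    ideal = sorted(relevance, reverse=True)[:k]
    idcg = sum(rel / math.log2(pos + 2) for pos, rel in enumerate(ideal))
    return dcg / idcg


toy = retrieval_metrics([1, 2, 4, 10])
print(toy)
print(ndcg([1, 2, 3], [1.0, 0.5, 0.0]), ndcg([3, 2, 1], [1.0, 0.5, 0.0]))
```

Because MRR and recall depend only on the single ground-truth answer while NDCG rewards every relevant option, a model can improve on one while regressing on the other, as the dense-finetuned rows in the table show.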
We use Visdom for all logging. Specify the visdom_server, visdom_port and enable_visdom arguments in options.py to use this feature.
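As a rough illustration of how such arguments might be declared, the following sketch uses argparse with the single-dash flag style seen in the commands in this README. Only the three argument names come from the text; the types, defaults and help strings are assumptions, build_parser is a hypothetical function, and the real options.py may be organised differently.

```python
import argparse


def build_parser():
    """Hypothetical parser; only the three argument names come from the README."""
    parser = argparse.ArgumentParser(description="Visdom logging options (sketch)")
    # Assumed default: a Visdom server running on the local machine.
    parser.add_argument("-visdom_server", type=str, default="http://127.0.0.1",
                        help="address of the Visdom server")
    # 8097 is the port Visdom listens on by default.
    parser.add_argument("-visdom_port", type=int, default=8097,
                        help="port of the Visdom server")
    # Assumed to be a 0/1 switch; logging stays off unless it is set to 1.
    parser.add_argument("-enable_visdom", type=int, default=0, choices=[0, 1],
                        help="set to 1 to send logs to Visdom")
    return parser


args = build_parser().parse_args(["-enable_visdom", "1", "-visdom_port", "8097"])
print(args.visdom_server, args.visdom_port, args.enable_visdom)
```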
Coming soon
Builds on Jiasen Lu's ViLBERT implementation.
BSD