Codebase of our ICML'22 paper "Visual Attention Emerges from Recurrent Sparse Reconstruction".
Install PyTorch 1.7.0+ and torchvision 0.8.1+ from the official website.
requirements.txt lists all the dependencies:
pip install -r requirements.txt
In addition, please install the magickwand library:
apt-get install libmagickwand-dev
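To quickly verify that the ImageMagick bindings can locate the library (a minimal check, assuming the Wand Python package from requirements.txt is what uses libmagickwand):
python -c "from wand.image import Image; print('magickwand OK')"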
Take RVT-Ti with VARS-D as an example. We use a single node with 8 GPUs for training:
python -m torch.distributed.launch --nproc_per_node=8 --master_port 12345 main.py --model rvt_tiny --data-path path/to/imagenet --output_dir output/here --num_workers 8 --batch-size 128 --attention vars_d
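With 8 processes and --batch-size 128 per process (assuming the batch size is per GPU, as in DeiT-style training scripts), the effective global batch size is 8 × 128 = 1024.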
We provide pretrained weights for VARS-D and VARS-SD.
To train models with different scales or different attention algorithms, please change the arguments --model and --attention.
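For example, to train RVT-Ti with VARS-SD instead, the same launch command can be reused with a different --attention value (a sketch; we assume the VARS-SD variant is selected with vars_sd, mirroring vars_d):
python -m torch.distributed.launch --nproc_per_node=8 --master_port 12345 main.py --model rvt_tiny --data-path path/to/imagenet --output_dir output/here --num_workers 8 --batch-size 128 --attention vars_sd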
To evaluate a pretrained model on the ImageNet validation set:
python main.py --model rvt_tiny --data-path path/to/imagenet --eval --resume path/to/checkpoint --attention vars_d
To enable robustness evaluation, please add one of --inc_path /path/to/imagenet-c, --ina_path /path/to/imagenet-a, --inr_path /path/to/imagenet-r, or --insk_path /path/to/imagenet-sketch to test on ImageNet-C, ImageNet-A, ImageNet-R, or ImageNet-Sketch, respectively.
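For example, ImageNet-C evaluation could be run by appending the corresponding flag to the evaluation command above (a sketch, assuming ImageNet-C is extracted at the given path):
python main.py --model rvt_tiny --data-path path/to/imagenet --eval --resume path/to/checkpoint --attention vars_d --inc_path /path/to/imagenet-c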
If you want to test the accuracy under adversarial attacks, please add --fgsm_test or --pgd_test.
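For instance, an FGSM robustness test could look like the following (same assumptions as above):
python main.py --model rvt_tiny --data-path path/to/imagenet --eval --resume path/to/checkpoint --attention vars_d --fgsm_test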
This codebase is built upon the official code of "Towards Robust Vision Transformer".
If you find this code helpful, please consider citing our work:
@article{shi2022visual,
  title={Visual Attention Emerges from Recurrent Sparse Reconstruction},
  author={Shi, Baifeng and Song, Yale and Joshi, Neel and Darrell, Trevor and Wang, Xin},
  journal={arXiv preprint arXiv:2204.10962},
  year={2022}
}