This repository contains the code for the paper Vision Transformer with Deformable Attention (CVPR2022) [arXiv][video][poster][source code].
The official source code repository is here; it has since been revised, so its structure no longer matches the one described in the paper. The implementation in this repository is consistent with the paper, and some annotations have been added.
(a) Vision Transformer (ViT) has proven its superiority in many tasks thanks to its large or even global receptive field. However, this global attention leads to excessive computational costs. (b) Swin Transformer proposes shifted window attention, a more efficient sparse attention mechanism with linear computational complexity. Nevertheless, this hand-crafted attention pattern is likely to drop important features outside a window, and shifting windows impedes the growth of the receptive field, which limits the modeling of long-range dependencies. (c) DCN expands the receptive field of standard convolutions with offsets learned for each query. However, directly applying this technique to the Vision Transformer is non-trivial because of the quadratic space complexity and the training difficulties. (d) Deformable Attention (DAT) is proposed to model the relations among tokens effectively under the guidance of the important regions in the feature maps. This flexible scheme enables the self-attention module to focus on relevant regions and capture more informative features.
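To make the cost contrast between (a) and (b) concrete, the sketch below counts query-key pairs for global attention versus non-overlapping window attention. The 56x56 feature map and 7x7 window are illustrative assumptions, not values read from this repository's configs.

```python
# Illustrative count of query-key pairs for global vs. windowed attention.
# The 56x56 map and 7x7 window sizes are assumptions chosen for the example.

def global_attention_pairs(h: int, w: int) -> int:
    """Every token attends to every token: N * N pairs (quadratic in N = h * w)."""
    n = h * w
    return n * n


def window_attention_pairs(h: int, w: int, window: int) -> int:
    """Each token attends only inside its window: N * window^2 pairs (linear in N)."""
    assert h % window == 0 and w % window == 0, "feature map must tile into windows"
    n = h * w
    return n * window * window


if __name__ == "__main__":
    print(global_attention_pairs(56, 56))     # 9834496
    print(window_attention_pairs(56, 56, 7))  # 153664
```

Doubling the map side quadruples the windowed count but multiplies the global count by 16, which is the quadratic-versus-linear gap the comparison refers to.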
Deformable attention learns several groups of offsets for a grid of reference points and samples the deformed keys and values from these shifted locations. This lets the attention capture the most informative regions in the image. On this basis, we present the Deformable Attention Transformer (DAT), a general backbone model with deformable attention for both image classification and dense prediction tasks.
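The sampling step described above can be sketched in plain Python for a single query, a single head, and a single offset group. This is a simplified illustration, not this repository's implementation: the offsets are passed in directly instead of being learned, keys and values are the sampled features themselves (identity projections), positional bias is omitted, and the query length is assumed to equal the channel count.

```python
import math
from typing import List, Sequence, Tuple

# feature[y][x][channel]; plain nested lists keep the sketch dependency-free.
Feature = List[List[List[float]]]


def bilinear_sample(feat: Feature, y: float, x: float) -> List[float]:
    """Bilinearly interpolate feat at fractional (y, x); neighbors outside the map count as zero."""
    h, w, c = len(feat), len(feat[0]), len(feat[0][0])
    y0, x0 = math.floor(y), math.floor(x)
    out = [0.0] * c
    for yy, wy in ((y0, 1.0 - (y - y0)), (y0 + 1, y - y0)):
        for xx, wx in ((x0, 1.0 - (x - x0)), (x0 + 1, x - x0)):
            if 0 <= yy < h and 0 <= xx < w:
                for ch in range(c):
                    out[ch] += wy * wx * feat[yy][xx][ch]
    return out


def reference_grid(h: int, w: int, hk: int, wk: int) -> List[Tuple[float, float]]:
    """Uniform hk x wk grid of reference points placed at cell centers of an h x w map."""
    return [((i + 0.5) * h / hk - 0.5, (j + 0.5) * w / wk - 0.5)
            for i in range(hk) for j in range(wk)]


def deformable_attention(query: Sequence[float], feat: Feature,
                         refs: Sequence[Tuple[float, float]],
                         offsets: Sequence[Tuple[float, float]]) -> List[float]:
    """Single-query, single-head sketch: attend to features sampled at refs + offsets."""
    # Keys and values are the sampled features themselves (identity projections assumed).
    sampled = [bilinear_sample(feat, ry + dy, rx + dx)
               for (ry, rx), (dy, dx) in zip(refs, offsets)]
    scale = 1.0 / math.sqrt(len(query))
    logits = [scale * sum(q * k for q, k in zip(query, key)) for key in sampled]
    peak = max(logits)  # subtract the max for a numerically stable softmax
    weights = [math.exp(z - peak) for z in logits]
    total = sum(weights)
    return [sum(wt * v[ch] for wt, v in zip(weights, sampled)) / total
            for ch in range(len(query))]
```

For example, on a 4x4 map built as `[[[float(4 * y + x), 1.0] for x in range(4)] for y in range(4)]`, a single reference point at `(0.5, 0.5)` shifted by the offset `(0.5, 1.5)` lands exactly on pixel `(1, 2)`, so the output is that pixel's feature `[6.0, 1.0]`. Because the shifted points are fractional in general, bilinear interpolation is what keeps the sampled keys differentiable with respect to the offsets.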
The visualizations show the most important keys as orange circles, where larger circles indicate higher attention scores. The important keys cover the main parts of the objects, which demonstrates the effectiveness of DAT.
- NVIDIA GPU + CUDA 11.3
- Python 3.9 (>= 3.6; Anaconda is recommended)
- cudatoolkit == 11.3.1
- PyTorch == 1.11.0
- torchvision == 0.12.0
- numpy
- timm == 0.5.4
- einops
- PyYAML
- yacs
- termcolor
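Since several dependencies are pinned, a quick check of the installed versions can save a failed run. The sketch below is a hypothetical helper, not part of this repository; it uses `importlib.metadata`, which requires Python 3.8 or newer, and it only checks the pip-installable pins (`cudatoolkit` is a conda package and the other packages are unpinned).

```python
from importlib import metadata
from typing import Dict, Optional

# Versions pinned in the requirements list. cudatoolkit is a conda package and
# numpy, einops, PyYAML, yacs, termcolor are unpinned, so they are not checked.
PINS: Dict[str, str] = {"torch": "1.11.0", "torchvision": "0.12.0", "timm": "0.5.4"}


def installed_version(dist: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if it is missing."""
    try:
        return metadata.version(dist)
    except metadata.PackageNotFoundError:
        return None


def base_version(version: Optional[str]) -> Optional[str]:
    """Drop a local version label such as '+cu113' that CUDA wheels may carry."""
    return version.split("+", 1)[0] if version else version


def find_mismatches(pins: Dict[str, str],
                    installed: Dict[str, Optional[str]]) -> Dict[str, Optional[str]]:
    """Map each pinned package whose installed base version differs to the version found."""
    return {name: installed.get(name) for name, wanted in pins.items()
            if base_version(installed.get(name)) != wanted}


if __name__ == "__main__":
    found = {name: installed_version(name) for name in PINS}
    mismatches = find_mismatches(PINS, found)
    print("all pinned versions match" if not mismatches else f"mismatches: {mismatches}")
```

Stripping the `+cu113`-style suffix matters because CUDA builds of PyTorch can report a version like `1.11.0+cu113`, which should still count as matching the `1.11.0` pin.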
We provide the pretrained models in the tiny, small, and base versions of DAT, as listed below.
model | resolution | acc@1 (%) | config | pretrained weights |
---|---|---|---|---|
DAT-Tiny | 224x224 | 82.0 | config | GoogleDrive / TsinghuaCloud |
DAT-Small | 224x224 | 83.7 | config | GoogleDrive / TsinghuaCloud |
DAT-Base | 224x224 | 84.0 | config | GoogleDrive / TsinghuaCloud |
DAT-Base | 384x384 | 84.8 | config | GoogleDrive / TsinghuaCloud |
To evaluate a model, download the pretrained weights to your local machine and run the script evaluate.sh as follows:
bash evaluate.sh <gpu_nums> <path-to-config> <path-to-pretrained-weights>
For example, to evaluate the DAT-Tiny model (dat_tiny_in1k_224.pth) with 8 GPUs, run:
bash evaluate.sh 8 configs/dat_tiny.yaml dat_tiny_in1k_224.pth
The evaluation should print:
[2022-06-07 04:08:50 dat_tiny] (main.py 288): INFO * Acc@1 82.034 Acc@5 95.850
[2022-06-07 04:08:50 dat_tiny] (main.py 150): INFO Accuracy of the network on the 50000 test images: 82.0%
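When comparing several runs, it can help to pull the numbers out of the log automatically. The sketch below is a hypothetical helper that assumes the exact `* Acc@1 ... Acc@5 ...` format shown in the sample output; it is not part of main.py.

```python
import re
from typing import Optional, Tuple

# Matches the "* Acc@1 <top-1> Acc@5 <top-5>" summary line shown in the sample log;
# the pattern assumes that exact format.
ACC_PATTERN = re.compile(r"Acc@1\s+([\d.]+)\s+Acc@5\s+([\d.]+)")


def parse_accuracy(line: str) -> Optional[Tuple[float, float]]:
    """Return (top-1, top-5) accuracy in percent from a log line, or None if absent."""
    match = ACC_PATTERN.search(line)
    if match is None:
        return None
    return float(match.group(1)), float(match.group(2))
```

Applied to the DAT-Tiny summary line above, `parse_accuracy` returns `(82.034, 95.85)`; lines without the pattern, such as the "Accuracy of the network" line, return `None`.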
Outputs of the other models are:
[2022-06-07 04:19:42 dat_small] (main.py 288): INFO * Acc@1 83.686 Acc@5 96.392
[2022-06-07 04:19:42 dat_small] (main.py 150): INFO Accuracy of the network on the 50000 test images: 83.7%
[2022-06-07 04:24:35 dat_base] (main.py 288): INFO * Acc@1 84.028 Acc@5 96.686
[2022-06-07 04:24:35 dat_base] (main.py 150): INFO Accuracy of the network on the 50000 test images: 84.0%
[2022-06-07 06:43:07 dat_base_384] (main.py 288): INFO * Acc@1 84.754 Acc@5 96.982
[2022-06-07 06:43:07 dat_base_384] (main.py 150): INFO Accuracy of the network on the 50000 test images: 84.8%
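Acc@1 and Acc@5 in these logs are top-1 and top-5 accuracy: the percentage of evaluation images whose true label is the highest-scoring class, or among the five highest-scoring classes. A minimal plain-Python sketch of the metric follows; it illustrates the definition and is not the code main.py uses.

```python
from typing import Sequence


def topk_accuracy(logits: Sequence[Sequence[float]], labels: Sequence[int], k: int) -> float:
    """Percentage of samples whose true label is among the k highest-scoring classes."""
    assert len(logits) == len(labels) and len(labels) > 0
    hits = 0
    for scores, label in zip(logits, labels):
        # Indices of the k largest scores; ties are broken by the lower class index.
        topk = sorted(range(len(scores)), key=lambda i: (-scores[i], i))[:k]
        hits += int(label in topk)
    return 100.0 * hits / len(labels)
```

With `logits = [[0.1, 0.7, 0.2], [0.5, 0.3, 0.2], [0.2, 0.3, 0.5], [0.6, 0.1, 0.3]]` and `labels = [1, 1, 0, 2]`, top-1 accuracy is 25.0 and top-2 accuracy is 75.0. Top-k accuracy never decreases as k grows, which is why every Acc@5 above exceeds the matching Acc@1.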
To train a model from scratch, we provide a simple script, train.sh. For example, to train a model with 8 GPUs on a single node, use this command:
bash train.sh 8 <path-to-config> <experiment-tag>
We also provide a training script, train_slurm.sh, for training models on multiple machines with a larger batch size such as 4096:
bash train_slurm.sh 32 <path-to-config> <slurm-job-name>
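For planning a multi-machine job, the arithmetic is simple but easy to get wrong. The sketch below assumes, by analogy with evaluate.sh and train.sh, that the first argument of train_slurm.sh is the total number of GPUs, that the global batch is split evenly across them, and that each node has 8 GPUs; none of these is confirmed by the script itself.

```python
def per_gpu_batch_size(global_batch: int, num_gpus: int) -> int:
    """Split a global batch evenly across GPUs (assumes plain data parallelism)."""
    if global_batch % num_gpus != 0:
        raise ValueError(f"{global_batch} is not divisible by {num_gpus} GPUs")
    return global_batch // num_gpus


def nodes_needed(num_gpus: int, gpus_per_node: int = 8) -> int:
    """Number of machines required, rounding up; 8 GPUs per node is an assumption."""
    return -(-num_gpus // gpus_per_node)
```

Under these assumptions, a global batch of 4096 on 32 GPUs gives `per_gpu_batch_size(4096, 32) == 128` images per GPU, spread over `nodes_needed(32) == 4` machines.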
Remember to replace <path-to-imagenet> in the script files with your own ImageNet directory.
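If you keep several copies of the scripts, the substitution can be scripted. The helper below is a hypothetical convenience, not part of this repository; it rewrites files in place, so run it on the copies you intend to modify.

```python
from pathlib import Path
from typing import Iterable, List

# Placeholder string used in the script files.
PLACEHOLDER = "<path-to-imagenet>"


def replace_placeholder(text: str, imagenet_dir: str) -> str:
    """Return text with every occurrence of the placeholder replaced by imagenet_dir."""
    return text.replace(PLACEHOLDER, imagenet_dir)


def set_imagenet_path(scripts: Iterable[Path], imagenet_dir: str) -> List[Path]:
    """Rewrite each script in place; return the scripts that contained the placeholder."""
    changed = []
    for script in scripts:
        text = script.read_text()
        if PLACEHOLDER in text:
            script.write_text(replace_placeholder(text, imagenet_dir))
            changed.append(script)
    return changed
```

For example, `set_imagenet_path([Path("train.sh"), Path("train_slurm.sh")], "/path/to/imagenet")` updates whichever of the listed scripts still contain the placeholder and returns them; a second call returns an empty list because nothing is left to replace.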
We appreciate the work from the following repositories: