Writer: Hiroshi Fukui
Maintainer: Tsubasa Hirakawa
- 15 Apr 2020: Add Docker environment (torch 0.4.0) [0.4.0 implementation]
- 02 Sep 2022: Update PyTorch version (torch 1.12.0)
This repository contains the source code of the Attention Branch Network (ABN) for image classification. ABN extends the top-down visual explanation model by introducing a branch structure with an attention mechanism. By generating attention maps in the forward pass, ABN improves both the classification performance of a CNN and its visual explanation at the same time.
[CVF open access] [arXiv paper]
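The core attention step can be illustrated with a minimal NumPy sketch. It assumes the residual attention form described in the ABN paper, where the perception branch receives g'(x) = (1 + M(x)) · g(x), with g(x) the feature map from the feature extractor and M(x) the attention map from the attention branch. The helpers `attention_map` and `apply_attention`, the single bias-free 1x1 projection, and the array shapes are hypothetical simplifications for illustration, not this repository's PyTorch code.

```python
import numpy as np


def sigmoid(z: np.ndarray) -> np.ndarray:
    """Element-wise logistic function, squashing values into (0, 1)."""
    return 1.0 / (1.0 + np.exp(-z))


def attention_map(branch_features: np.ndarray, weights: np.ndarray) -> np.ndarray:
    """Hypothetical attention head: collapse K attention-branch channels of
    shape (K, H, W) into one (H, W) map via a 1x1 projection and a sigmoid."""
    # A bias-free 1x1 convolution is a weighted sum over the channel axis.
    logits = np.tensordot(weights, branch_features, axes=([0], [0]))
    return sigmoid(logits)


def apply_attention(features: np.ndarray, att_map: np.ndarray) -> np.ndarray:
    """Residual attention g'(x) = (1 + M(x)) * g(x).
    The (H, W) map is broadcast over all C channels of the (C, H, W) features."""
    return (1.0 + att_map)[np.newaxis, :, :] * features


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    feats = rng.standard_normal((64, 8, 8))   # g(x): C x H x W feature map
    branch = rng.standard_normal((10, 8, 8))  # K = 10 attention-branch maps
    weights = rng.standard_normal(10)         # hypothetical 1x1 projection weights
    m = attention_map(branch, weights)
    out = apply_attention(feats, m)
    print(m.shape, out.shape)  # (8, 8) (64, 8, 8)
```

Because 1 + M(x) lies between 1 and 2, the attention map only scales features up; features in low-attention regions pass through nearly unchanged instead of being zeroed out.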
If you find this repository useful, please cite the following references.
```bibtex
@inproceedings{fukui2018cvpr,
author = {Hiroshi Fukui and Tsubasa Hirakawa and Takayoshi Yamashita and Hironobu Fujiyoshi},
title = {Attention Branch Network: Learning of Attention Mechanism for Visual Explanation},
booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
year = {2019},
pages = {10705--10714}
}

@article{fukui2018arxiv,
author = {Hiroshi Fukui and Tsubasa Hirakawa and Takayoshi Yamashita and Hironobu Fujiyoshi},
title = {Attention Branch Network: Learning of Attention Mechanism for Visual Explanation},
journal = {arXiv preprint arXiv:1812.10025},
year = {2018}
}
```
Our source code is based on https://github.com/bearpaw/pytorch-classification/, which is implemented with PyTorch. We are grateful to the author!
We have updated the PyTorch version from 0.4.0 (the original implementation) to 1.12.0.
The required versions are as follows:
- PyTorch: 1.12.0
- torchvision: 0.13.0
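To confirm that an existing environment matches these versions before running the scripts, a small check such as the following may help. It is a minimal sketch that assumes the packages are installed under their pip distribution names `torch` and `torchvision`, requires an exact match on the release number, and ignores local build suffixes such as `+cu113`; the helper names are hypothetical.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, Iterable, Optional, Tuple

# Versions listed in this README.
REQUIRED = {"torch": "1.12.0", "torchvision": "0.13.0"}


def installed_versions(packages: Iterable[str]) -> Dict[str, Optional[str]]:
    """Look up installed distribution versions; None means not installed."""
    found: Dict[str, Optional[str]] = {}
    for pkg in packages:
        try:
            found[pkg] = version(pkg)
        except PackageNotFoundError:
            found[pkg] = None
    return found


def version_mismatches(
    required: Dict[str, str], installed: Dict[str, Optional[str]]
) -> Dict[str, Tuple[str, Optional[str]]]:
    """Return {package: (required, installed)} for every package that differs.
    A local build tag such as '1.12.0+cu113' is compared as '1.12.0'."""
    mismatches: Dict[str, Tuple[str, Optional[str]]] = {}
    for pkg, wanted in required.items():
        have = installed.get(pkg)
        release = have.split("+", 1)[0] if have else None
        if release != wanted:
            mismatches[pkg] = (wanted, have)
    return mismatches


if __name__ == "__main__":
    problems = version_mismatches(REQUIRED, installed_versions(REQUIRED))
    print(problems or "Environment matches the required versions.")
```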
We provide a Docker environment for ABN, which lets you quickly set up and run the scripts. For more details, please see docker/README.md.
Example training commands are as follows:
```bash
# CIFAR-100 dataset
python3 cifar.py -a resnet --dataset cifar100 --depth 110 --epochs 300 --schedule 150 225 --gamma 0.1 --wd 1e-4 --checkpoint checkpoints/cifar100/resnet-110 --gpu-id 0,1
# ImageNet dataset
python3 imagenet.py -a resnet152 --data ../../dataset/imagenet_data/ --epochs 90 --schedule 31 61 --gamma 0.1 -c checkpoints/imagenet/resnet152 --gpu-id 4,5,6,7 --test-batch 100
```
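In these commands, `--schedule` lists the epochs at which the learning rate is decayed and `--gamma` is the decay factor. The sketch below shows the resulting step schedule, assuming (as in the upstream pytorch-classification code this repository is based on) that the learning rate is multiplied by `gamma` at the start of each listed epoch, with epochs counted from 0. The base learning rate of 0.1 used in the example is an assumption; check the script's default learning rate. The function name `lr_at_epoch` is hypothetical.

```python
from typing import Sequence


def lr_at_epoch(epoch: int, base_lr: float, schedule: Sequence[int], gamma: float) -> float:
    """Step decay: multiply base_lr by gamma once for each milestone already reached.

    Epochs are 0-indexed; the decay takes effect at the start of a milestone epoch.
    """
    decays = sum(1 for milestone in schedule if epoch >= milestone)
    return base_lr * gamma ** decays


if __name__ == "__main__":
    # CIFAR-100 example: --epochs 300 --schedule 150 225 --gamma 0.1 (base LR assumed 0.1)
    for epoch in (0, 149, 150, 224, 225, 299):
        print(epoch, lr_at_epoch(epoch, 0.1, [150, 225], 0.1))
```

Under these assumptions, the CIFAR-100 run uses a learning rate of 0.1 for epochs 0 to 149, 0.01 for epochs 150 to 224, and 0.001 for epochs 225 to 299 (up to floating-point rounding); the ImageNet command decays at epochs 31 and 61 instead.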
To evaluate a trained model, add `--evaluate` and point `--resume` to a saved checkpoint, for example:
```bash
# CIFAR-100 dataset
python3 cifar.py -a resnet --dataset cifar100 --depth 110 --epochs 300 --schedule 150 225 --gamma 0.1 --wd 1e-4 --checkpoint checkpoints/cifar100/resnet-110 --gpu-id 0,1 --evaluate --resume checkpoints/cifar100/resnet-110/model_best.pth.tar
# ImageNet dataset
python3 imagenet.py -a resnet152 --data ../../../../dataset/imagenet_data/ --epochs 90 --schedule 31 61 --gamma 0.1 -c checkpoints/imagenet/resnet152 --gpu-id 4,5,6 --test-batch 10 --evaluate --resume checkpoints/imagenet/resnet152/model_best.pth.tar
```
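The evaluation commands restore weights through `--resume`. If you instead want to load a checkpoint such as `model_best.pth.tar` in your own script, the sketch below may help. It assumes, as in the upstream pytorch-classification code, that the checkpoint is a dict whose `state_dict` entry was saved from a model wrapped in `torch.nn.DataParallel`, so each parameter name carries a `module.` prefix that an unwrapped model will not expect. The helper name `strip_module_prefix` is hypothetical, and the PyTorch calls are left as comments so the helper itself runs without PyTorch.

```python
from typing import Dict, Mapping, TypeVar

V = TypeVar("V")


def strip_module_prefix(state_dict: Mapping[str, V], prefix: str = "module.") -> Dict[str, V]:
    """Return a new dict with a leading DataParallel prefix removed from each key.
    Keys without the prefix are kept as-is; key order is preserved."""
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


# Hypothetical usage with PyTorch (building `model` depends on this repository's code):
# import torch
# ckpt = torch.load("checkpoints/cifar100/resnet-110/model_best.pth.tar", map_location="cpu")
# model.load_state_dict(strip_module_prefix(ckpt["state_dict"]))
```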
To train other models, please see TRAINING.md.
We have published the ABN model files for ResNet-family models on the CIFAR100 and ImageNet2012 datasets. Their top-1 errors, compared with the original models without ABN, are listed below.
CIFAR100

Model | top-1 error (ABN) (%) | top-1 error (original) (%) |
---|---|---|
ResNet110 | 22.5 | 24.1 |
DenseNet | 21.6 | 22.5 |
Wide ResNet | 18.1 | 18.9 |
ResNeXt | 17.7 | 18.3 |

ImageNet2012

Model | top-1 error (ABN) (%) | top-1 error (original) (%) |
---|---|---|
ResNet50 | 23.1 | 24.1 |
ResNet101 | 21.8 | 22.5 |
ResNet152 | 21.4 | 22.2 |
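Because a lower top-1 error is better, the gain from ABN is the original error minus the ABN error. The short script below computes these reductions, in percentage points, directly from the numbers reported in the tables above; the dictionary and function names are illustrative only.

```python
# Top-1 error (%) as reported in the tables above: (ABN, original).
CIFAR100 = {
    "ResNet110": (22.5, 24.1),
    "DenseNet": (21.6, 22.5),
    "Wide ResNet": (18.1, 18.9),
    "ResNeXt": (17.7, 18.3),
}
IMAGENET2012 = {
    "ResNet50": (23.1, 24.1),
    "ResNet101": (21.8, 22.5),
    "ResNet152": (21.4, 22.2),
}


def error_reduction(table):
    """Original error minus ABN error, in percentage points (positive = ABN is better)."""
    return {model: round(orig - abn, 1) for model, (abn, orig) in table.items()}


if __name__ == "__main__":
    print(error_reduction(CIFAR100))
    # {'ResNet110': 1.6, 'DenseNet': 0.9, 'Wide ResNet': 0.8, 'ResNeXt': 0.6}
    print(error_reduction(IMAGENET2012))
    # {'ResNet50': 1.0, 'ResNet101': 0.7, 'ResNet152': 0.8}
```

In every listed configuration, ABN lowers the top-1 error, by 0.6 to 1.6 points.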