This is the official implementation for the CVPR 2022 paper On Guiding Visual Attention with Language Specification by Suzanne Petryk*, Lisa Dunlap*, Keyan Nasseri, Joseph Gonzalez, Trevor Darrell, and Anna Rohrbach.
If you find our code or paper useful, please cite:
@article{petryk2022gals,
title={On Guiding Visual Attention with Language Specification},
author={Petryk, Suzanne and Dunlap, Lisa and Nasseri, Keyan and Gonzalez, Joseph and Darrell, Trevor and Rohrbach, Anna},
journal={Conference on Computer Vision and Pattern Recognition (CVPR)},
year={2022}
}
To set up the environment with Conda:
conda env create -f env.yaml
conda activate gals
Or, with pip:
pip install -r requirements.txt
Please see the original dataset pages for further detail:
- Waterbirds
  - Waterbirds-95%: The original dataset page includes instructions on how to generate different biased splits. Please download Waterbirds-95% from the link for `waterbird_complete95_forest2water2` on their page.
  - Waterbirds-100%: We use the script from Sagawa & Koh et al. to generate the 100% biased split. For convenience, we supply this split of the dataset here: Waterbirds-100%.
- Food101: Please download the images from the original dataset page. We construct the 5-class Red Meat subset from these images.
- MSCOCO-ApparentGender: Please download the COCO 2014 train & validation images and annotations from the original dataset page. We base MSCOCO-ApparentGender on the dataset used in Women Also Snowboard (by Burns & Hendricks et al.). We slightly modify the training IDs but keep the same evaluation set. Please download the split files here: MSCOCO-ApparentGender.
The data is expected to be under the folder `./data`. More specifically, here is the suggested data file structure:
- `./data`
  - `waterbird_complete95_forest2water2/` (Waterbirds-95%)
  - `waterbird_1.0_forest2water2/` (Waterbirds-100%)
  - `food-101/` (Red Meat)
  - `COCO/`
    - `annotations/` (COCO annotations from the original dataset page)
    - `train2014/` (COCO images from the original dataset page)
    - `val2014/` (COCO images from the original dataset page)
    - `COCO_gender/` (ApparentGender files we provided)
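Before training, it can help to confirm that the folders are in place. The following is a minimal sketch using only the standard library; `EXPECTED_DIRS` mirrors the suggested structure (including the assumption that `COCO_gender/` sits inside `COCO/`), and the helper `missing_data_dirs` is hypothetical, not part of this repo. Only the datasets you plan to use need to be present.

```python
from pathlib import Path

# Sub-folders of ./data taken from the suggested layout (hypothetical checklist,
# not read from the repo's configs).
EXPECTED_DIRS = [
    "waterbird_complete95_forest2water2",  # Waterbirds-95%
    "waterbird_1.0_forest2water2",         # Waterbirds-100%
    "food-101",                            # Red Meat
    "COCO/annotations",
    "COCO/train2014",
    "COCO/val2014",
    "COCO/COCO_gender",                    # assumed location of the ApparentGender files
]


def missing_data_dirs(root="./data"):
    """Return the expected sub-folders that do not exist under `root`."""
    root = Path(root)
    return [d for d in EXPECTED_DIRS if not (root / d).is_dir()]


if __name__ == "__main__":
    missing = missing_data_dirs()
    if missing:
        print("Missing dataset folders:", ", ".join(missing))
    else:
        print("All expected dataset folders found.")
```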
Code structure:
- `main.py` is the entry point for model training.
- `configs/` contains `.yaml` configuration files for each dataset and model type.
- `extract_attention.py` is the script to precompute attention with CLIP ResNet50 GradCAM and CLIP ViT transformer attention.
- `approaches/` contains training code. `approaches/base.py` handles general training and is extended by model-specific approaches such as `approaches/abn.py`, or by approaches for datasets requiring extra evaluation, such as `approaches/coco_gender.py`.
- `datasets/` contains PyTorch dataset creation files.
- `models/` contains architectures for both vanilla and ABN ResNet50 classification models.
- `utils/` contains helper functions for general training, loss, and attention computation.
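The split between `approaches/base.py` and its subclasses follows a familiar override pattern. The sketch below is hypothetical: the class and method names (`BaseApproach`, `ABNApproach`, `compute_loss`, `train_step`) are invented for illustration, and plain floats stand in for tensors. It only shows how a model-specific approach can reuse the general training step while changing one piece.

```python
class BaseApproach:
    """Hypothetical stand-in for the general training logic in approaches/base.py."""

    def __init__(self, config):
        self.config = config

    def compute_loss(self, outputs):
        # General case: the classification loss alone.
        return outputs["cls_loss"]

    def train_step(self, outputs):
        # Shared training step; subclasses customize it by overriding compute_loss.
        return self.compute_loss(outputs)


class ABNApproach(BaseApproach):
    """Hypothetical model-specific subclass in the spirit of approaches/abn.py."""

    def compute_loss(self, outputs):
        # ABN also supervises its attention branch, so add that branch's loss.
        return super().compute_loss(outputs) + outputs["att_branch_loss"]
```

For example, `ABNApproach({}).train_step({"cls_loss": 1.5, "att_branch_loss": 0.5})` returns `2.0`. A dataset-specific approach would instead override an evaluation hook in the same way.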
This repo also expects the following additional folders:
- `./data`: contains the dataset folders.
- `./weights`: contains pretrained ImageNet ResNet50 weights for the ABN model, named `resnet50_abn_imagenet.pth.tar`. These weights are provided by Hiroshi Fukui & Tsubasa Hirakawa from their codebase. For convenience, you may also find the weights with the correct naming here.
We use Weights & Biases to log experiments. This requires the user to be logged in to a (free) W&B account. Instructions for setting up an account are available here.
Training models with GALS is a two-stage process:
- Generate and store attention per image
- Train model using attention
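The two stages can be sketched as follows. This is a minimal, hypothetical illustration on plain nested lists: `attention_fn` stands in for CLIP GradCAM or ViT attention, the cache is an in-memory dict rather than whatever files the repo writes, and `outside_mask_penalty` is a generic "model attention outside the highlighted region" term chosen as an assumption, not necessarily the objective GALS optimizes.

```python
def stage1_precompute(image_ids, attention_fn):
    """Stage 1: compute attention once per image and cache it (here, in a dict)."""
    return {image_id: attention_fn(image_id) for image_id in image_ids}


def outside_mask_penalty(model_attention, vl_attention, threshold=0.5):
    """Average model attention over pixels where the VL attention is below `threshold`.

    Both arguments are H x W nested lists with values in [0, 1]; the sum over
    low-VL-attention pixels is divided by the total number of pixels.
    """
    total, count = 0.0, 0
    for model_row, vl_row in zip(model_attention, vl_attention):
        for m, v in zip(model_row, vl_row):
            if v < threshold:
                total += m
            count += 1
    return total / count


def stage2_loss(cls_loss, image_id, model_attention, cache, weight=1.0):
    """Stage 2: look up the cached VL attention and add a weighted penalty."""
    return cls_loss + weight * outside_mask_penalty(model_attention, cache[image_id])
```

Precomputing in stage 1 means the vision-language model runs once per image rather than once per training step.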
Example commands for training networks with GALS, as well as the baselines in the paper, are below.
NOTE: To change `.yaml` configuration values on the command line, add text of the form `ATTRIBUTE.NESTED=new_value` to the end of the command. For example:
CUDA_VISIBLE_DEVICES=0 python main.py --config configs/waterbirds_100_gals.yaml DATA.BATCH_SIZE=96
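How `main.py` applies these overrides internally is not described here. As an illustration only, the hypothetical helper `apply_overrides` below shows how a dotted `ATTRIBUTE.NESTED=new_value` string maps onto a nested config dict (for instance, one loaded from a `.yaml` file).

```python
import ast


def apply_overrides(config, overrides):
    """Apply `ATTRIBUTE.NESTED=new_value` strings to a nested dict in place."""
    for item in overrides:
        key, sep, raw_value = item.partition("=")
        if not sep:
            raise ValueError(f"Expected KEY=VALUE, got {item!r}")
        try:
            # Interpret numbers, booleans, lists, etc.; fall back to a plain string.
            value = ast.literal_eval(raw_value)
        except (ValueError, SyntaxError):
            value = raw_value
        *parents, leaf = key.split(".")
        node = config
        for name in parents:
            node = node.setdefault(name, {})
        node[leaf] = value
    return config
```

For example, `apply_overrides(config, ["DATA.BATCH_SIZE=96"])` sets `config["DATA"]["BATCH_SIZE"]` to the integer `96`. Using `ast.literal_eval` keeps numeric values numeric without evaluating arbitrary code.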
Important files for attention extraction:
- `extract_attention.py` is the script to precompute attention with CLIP ResNet50 GradCAM and CLIP ViT transformer attention.
- `configs/` includes configuration files for precomputing attention (the relevant files end in `_attention.yaml`).
- `CLIP/` contains a slightly modified copy of OpenAI's CLIP repository. It is drawn directly from Hila Chefer's repository on computing transformer attention.
- `Visualize_VL_Attention.ipynb` visualizes attention so users can experiment with different VL models and language specifications.
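For reference, here is the standard Grad-CAM computation (Selvaraju et al.) written on plain nested lists. It is a sketch, not the code in `extract_attention.py`: in the CLIP ResNet50 setting the target score would presumably be the image-text similarity for the language specification (an assumption here), and the gradients would come from autograd rather than being passed in. Normalizing by the maximum is one common choice.

```python
def grad_cam(activations, gradients):
    """Grad-CAM for one image from a single conv layer.

    activations, gradients: K x H x W nested lists (channel, row, column), where
    gradients hold d(target score)/d(activations).
    Returns an H x W map scaled to [0, 1].
    """
    # 1) Channel weights: global-average-pool each gradient map over space.
    weights = [
        sum(sum(row) for row in grad_k) / (len(grad_k) * len(grad_k[0]))
        for grad_k in gradients
    ]
    height, width = len(activations[0]), len(activations[0][0])
    # 2) Weighted sum of activation maps, followed by ReLU.
    cam = [
        [
            max(0.0, sum(w * act_k[i][j] for w, act_k in zip(weights, activations)))
            for j in range(width)
        ]
        for i in range(height)
    ]
    # 3) Divide by the peak value so maps are comparable across images.
    peak = max(max(row) for row in cam)
    if peak > 0:
        cam = [[v / peak for v in row] for row in cam]
    return cam
```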
Sample command:
CUDA_VISIBLE_DEVICES=0 python extract_attention.py --config configs/coco_attention.yaml
Important files for model training:
- `approaches/base.py` contains most of the training code.
- `configs/` includes configuration files for model training.
The model configs include the hyperparameters and attention settings used to reproduce results in our paper.
An example command to train a model with GALS on Waterbirds-100%:
CUDA_VISIBLE_DEVICES=0,1,2 python main.py --name waterbirds100_gals --config configs/waterbirds_100_gals.yaml
The `--name` flag is used for Weights & Biases logging. You can add `--dryrun` to the command to run locally without uploading to the W&B server, which can be useful for debugging.
To evaluate a model on the test split for a given dataset, use the `--test_checkpoint` flag and provide a path to a trained checkpoint. For example, to evaluate a Waterbirds-95% GALS model with weights under a `trained_weights` directory:
CUDA_VISIBLE_DEVICES=0 python main.py --config configs/waterbirds_95_gals.yaml --test_checkpoint trained_weights/waterbirds_95_gals.ckpt
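The documented flags can be pictured with a small `argparse` sketch. This is hypothetical and not the parser in `main.py`: it only mirrors `--config`, `--name`, `--dryrun`, `--test_checkpoint` and the trailing `ATTRIBUTE.NESTED=new_value` overrides, and `choose_mode` is an invented helper showing that a checkpoint switches the run from training to test-split evaluation.

```python
import argparse


def build_parser():
    """Hypothetical parser mirroring the flags documented for main.py."""
    parser = argparse.ArgumentParser(description="Train or evaluate a model.")
    parser.add_argument("--config", required=True, help="Path to a .yaml config file.")
    parser.add_argument("--name", default=None, help="Run name for W&B logging.")
    parser.add_argument("--dryrun", action="store_true",
                        help="Run locally without uploading to W&B.")
    parser.add_argument("--test_checkpoint", default=None,
                        help="Evaluate this checkpoint on the test split instead of training.")
    # Trailing ATTRIBUTE.NESTED=new_value config overrides.
    parser.add_argument("overrides", nargs="*", help="Config overrides.")
    return parser


def choose_mode(args):
    """Evaluate when a checkpoint is given; otherwise train."""
    return "test" if args.test_checkpoint else "train"
```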
Note: For MSCOCO-ApparentGender, the Ratio Delta reported in our paper corresponds to `1 - test_ratio` in the output results.
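As a worked example of this conversion, a Ratio Delta of 0.160 (the GALS entry in the MSCOCO-ApparentGender results table) corresponds to a `test_ratio` of 0.840. The helper name `ratio_delta` is hypothetical; rounding to three decimals only matches the precision used in the table.

```python
def ratio_delta(test_ratio, digits=3):
    """Convert the `test_ratio` from the output results into the paper's Ratio Delta."""
    return round(1.0 - test_ratio, digits)
```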
In our paper, we report the mean and standard deviation over 10 trials. Below, we include a checkpoint from a single trial per experiment.
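Aggregating the 10 trials amounts to a mean and a standard deviation per metric. A minimal sketch with Python's `statistics` module follows; the helper name is hypothetical, and since the paper's choice between sample (n - 1) and population standard deviation is not stated here, the sample version is an assumption.

```python
import statistics


def summarize_trials(values):
    """Return (mean, standard deviation) of one metric across trials."""
    # statistics.stdev is the sample standard deviation (divides by n - 1)
    # and needs at least two values.
    return statistics.mean(values), statistics.stdev(values)
```

For example, with the dummy inputs `[1.0, 2.0, 3.0]` it returns `(2.0, 1.0)`.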
Waterbirds 100%
Method | Per Group Acc (%) | Worst Group Acc (%) |
---|---|---|
GALS | 80.67 | 57.00 |
Vanilla | 72.36 | 32.20 |
UpWeight | 72.22 | 37.29 |
ABN | 71.96 | 44.39 |
Waterbirds 95%
Method | Per Group Acc (%) | Worst Group Acc (%) |
---|---|---|
GALS | 89.03 | 79.91 |
Vanilla | 86.91 | 73.21 |
UpWeight | 87.51 | 76.48 |
ABN | 86.85 | 69.31 |
Red Meat (Food101)
Method | Acc (%) | Worst Group Acc (%) |
---|---|---|
GALS | 72.24 | 58.00 |
Vanilla | 69.20 | 48.80 |
ABN | 69.28 | 52.80 |
MSCOCO-ApparentGender
Method | Ratio Delta | Outcome Divergence |
---|---|---|
GALS | 0.160 | 0.022 |
Vanilla | 0.349 | 0.071 |
UpWeight | 0.272 | 0.040 |
ABN | 0.334 | 0.068 |
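The Waterbirds columns can be computed from predictions and group labels. This sketch assumes that "Per Group Acc" is the unweighted mean of the per-group accuracies and "Worst Group Acc" is the lowest one; in Waterbirds, a group is a (bird type, background) combination, as in Sagawa et al. The function names are hypothetical.

```python
from collections import defaultdict


def group_accuracies(predictions, labels, groups):
    """Accuracy within each group, e.g. Waterbirds (bird type, background) pairs."""
    correct, total = defaultdict(int), defaultdict(int)
    for pred, label, group in zip(predictions, labels, groups):
        total[group] += 1
        correct[group] += int(pred == label)
    return {g: correct[g] / total[g] for g in total}


def per_group_and_worst_group(predictions, labels, groups):
    """Return (unweighted mean of group accuracies, minimum group accuracy)."""
    accs = group_accuracies(predictions, labels, groups)
    return sum(accs.values()) / len(accs), min(accs.values())
```

Averaging over groups rather than over images keeps small minority groups from being swamped by the majority groups.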
We are very grateful to the following people, whose work we have taken or adapted code from throughout this repository:
- Hila Chefer, Shir Gur, Lior Wolf: https://github.com/hila-chefer/Transformer-MM-Explainability
- Kaylee Burns, Lisa Anne Hendricks, Kate Saenko, Trevor Darrell, Anna Rohrbach: https://github.com/kayburns/women-snowboard/tree/master/research/im2txt
- Vitali Petsiuk, Abir Das, Kate Saenko: https://github.com/eclique/RISE
- Shiori Sagawa, Pang Wei Koh, Tatsunori Hashimoto, and Percy Liang: https://github.com/kohpangwei/group_DRO
- Kazuto Nakashima: https://github.com/kazuto1011/grad-cam-pytorch
- Hiroshi Fukui, Tsubasa Hirakawa, Takayoshi Yamashita, Hironobu Fujiyoshi: https://github.com/machine-perception-robotics-group/attention_branch_network