# Investigating and Simplifying Masking-based Saliency Methods for Model Interpretability

This repository contains code for running and replicating the experiments from [Investigating and Simplifying Masking-based Saliency Methods for Model Interpretability](https://arxiv.org/abs/2010.09750). It is a modified fork of [Classifier-Agnostic Saliency Map Extraction](https://github.com/kondiz/casme), and contains code originally forked from the [ImageNet training in PyTorch](https://github.com/pytorch/examples/tree/master/imagenet) example.

<p align="center">
  <br>
  <img src="./saliency_overview.png"/>
  <br>
</p>

*(A) Overview of the training setup for our final model. The masker is trained to maximize masked-in classification accuracy and masked-out prediction entropy.*
*(B) Masker architecture. The masker takes as input the hidden activations of different layers of the ResNet-50 and produces a mask of the same resolution as the input image.*
*(C) Few-shot training of the masker. Performance drops only slightly when the masker is trained on far fewer examples than in the full training procedure.*

## Software requirements

- This repository requires Python 3.7 or later.
- Experiments were run with the following library versions:

```
pytorch==1.4.0
torchvision==0.5.0
opencv-python==4.1.2.30
beautifulsoup4==4.8.1
tqdm==4.35.0
pandas==0.24.2
scikit-learn==0.20.2
scipy==1.3.0
```

In addition, `git clone https://github.com/zphang/zutils` and add it to your `PYTHONPATH`.
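
A minimal sketch of that step (the `~/code` checkout location below is just a placeholder):

```bash
# Clone zutils and put the checkout on your PYTHONPATH (location is arbitrary)
mkdir -p ~/code && cd ~/code
git clone https://github.com/zphang/zutils
export PYTHONPATH=~/code/zutils:$PYTHONPATH
```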


**Additional requirements** (a consolidated setup sketch follows this list)

- If you want to use the PxAP metric from [Evaluating Weakly Supervised Object Localization Methods Right](https://arxiv.org/abs/2007.04178):
  - `git clone https://github.com/clovaai/wsolevaluation` and add it to your `PYTHONPATH`. To avoid potential naming conflicts, add it to the front of your `PYTHONPATH`.
  - `pip install munch` (as well as any other requirements listed [here](https://github.com/clovaai/wsolevaluation#3-code-dependencies))
- If you want to run the Grad-CAM and Guided-backprop saliency methods:
  - `pip install torchray`, or `git clone https://github.com/facebookresearch/TorchRay` and add it to your `PYTHONPATH`
- If you want to use the CA-GAN infiller from [Generative Image Inpainting with Contextual Attention](https://arxiv.org/abs/1801.07892):
  - `git clone https://github.com/daa233/generative-inpainting-pytorch` and add it to your `PYTHONPATH`
  - Download the linked [pretrained model](https://drive.google.com/drive/folders/123-F7eSAXJzgztYj2im2egsuk4mY5LWa) for PyTorch, and set the environment variable `CA_MODEL_PATH` to point to it
- If you want to use the DFNet infiller from [https://arxiv.org/abs/1904.08060](https://arxiv.org/abs/1904.08060):
  - `git clone https://github.com/hughplay/DFNet` and add it to your `PYTHONPATH`
  - Download the linked [pretrained model](https://github.com/hughplay/DFNet#testing) for PyTorch, and set the environment variable `DFNET_MODEL_PATH` to point to it. Use the Places2 model.
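
A consolidated sketch of the optional setup above, assuming the same placeholder `~/code` directory and placeholder paths for the downloaded infiller checkpoints:

```bash
cd ~/code

# PxAP evaluation: clone wsolevaluation and prepend it to PYTHONPATH to avoid naming conflicts
git clone https://github.com/clovaai/wsolevaluation
pip install munch
export PYTHONPATH=~/code/wsolevaluation:$PYTHONPATH

# Grad-CAM / Guided-backprop baselines
pip install torchray

# CA-GAN infiller
git clone https://github.com/daa233/generative-inpainting-pytorch
export PYTHONPATH=~/code/generative-inpainting-pytorch:$PYTHONPATH
export CA_MODEL_PATH=/path/to/ca_gan_checkpoint   # placeholder: downloaded pretrained model

# DFNet infiller
git clone https://github.com/hughplay/DFNet
export PYTHONPATH=~/code/DFNet:$PYTHONPATH
export DFNET_MODEL_PATH=/path/to/dfnet_places2_checkpoint   # placeholder: downloaded Places2 model
```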


## Data requirements

- The ImageNet dataset should be stored at `IMAGENET_PATH` and set up in the usual way (separate `train` and `val` folders with 1000 subfolders each). See [this repo](https://github.com/facebook/fb.resnet.torch/blob/master/INSTALL.md#download-the-imagenet-dataset) for detailed instructions on how to download and set up the dataset (a sketch of the assumed environment variables follows this list).
- ImageNet bounding box annotations should be in an `IMAGENET_ANN` directory containing 50000 files named `ILSVRC2012_val_<id>.xml`, where `<id>` is the validation image id (for example `ILSVRC2012_val_00050000.xml`). This can be obtained by simply unzipping [the official validation bounding box annotations archive](http://www.image-net.org/challenges/LSVRC/2012/dd31405981ef5f776aa17412e1f0c112/ILSVRC2012_bbox_val_v3.tgz) into the `IMAGENET_ANN` directory.
- Bounding box annotations for parts of the training set can be downloaded from [here](http://image-net.org/Annotation/Annotation.tar.gz). These will be used for our Train-Validation set.
- If you want to use the PxAP metric from [Evaluating Weakly Supervised Object Localization Methods Right](https://arxiv.org/abs/2007.04178):
  - Download the relevant datasets as described [here](https://github.com/clovaai/wsolevaluation#2-dataset-downloading-and-license)
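
A minimal sketch of the environment variables assumed in the rest of this README (all paths are placeholders):

```bash
# Placeholder paths; point these at your local copies
export IMAGENET_PATH=/path/to/imagenet     # contains train/ and val/
export IMAGENET_ANN=/path/to/imagenet_ann  # contains the validation bounding box XML files

# One way to populate IMAGENET_ANN from the official archive
# (check the resulting layout; the preprocessing step below reads ${IMAGENET_ANN}/val)
mkdir -p ${IMAGENET_ANN}
tar -xzf ILSVRC2012_bbox_val_v3.tgz -C ${IMAGENET_ANN}
```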
| 57 | + |
| 58 | +## Running the code |
| 59 | + |
| 60 | +We will assume that experiments will be run in the following folder: |
| 61 | + |
| 62 | +```bash |
| 63 | +export EXP_DIR=/path/to/experiments |
| 64 | +``` |
| 65 | + |
| 66 | +### Data Preparation |
| 67 | +To facilitate easy subsetting and label shuffling for the ImageNet training set, we write a JSON files containing the paths to the example images, and their corresponding labels. These will be consumed by a modified ImageNet PyTorch Dataset. |
| 68 | + |
| 69 | +Run the following command: |
| 70 | + |
| 71 | +```bash |
| 72 | +python casme/tasks/imagenet/preproc.py \ |
| 73 | + --train_path ${IMAGENET_PATH}/train \ |
| 74 | + --val_path ${IMAGENET_PATH}/val \ |
| 75 | + --val_annotation_path ${IMAGENET_ANN}/val \ |
| 76 | + --output_base_path ${EXP_DIR}/metadata |
| 77 | +``` |
| 78 | + |
| 79 | +This script does several things: |
| 80 | + |
| 81 | +- Packages the ImageNet Train and Validation image data and labels into metadata JSON files (`train.json`, `val.json`) |
| 82 | +- Splits the train data into Train-Train and Train-Validation subsets (`train_train.json`, `train_val.json`) |
| 83 | +- Generates a shuffled version of the Train JSON (`train_shuffle.json`) for DRT Sanity Check |
| 84 | +- Also packages the bounding box annotation for the Validation set into JSON files (`val_bboxes.json`) |
| 85 | +- Optionally, to use bounding boxes for the Train-Validation set, unzip the downloaded data from [here](http://image-net.org/Annotation/Annotation.tar.gz), and provided an additional argument `--extended_annot_base_path`. (`train_val_bboxes.json`) |
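
A sketch of that optional invocation, producing `train_val_bboxes.json` in addition to the other files (the annotation extraction directory is a placeholder):

```bash
python casme/tasks/imagenet/preproc.py \
    --train_path ${IMAGENET_PATH}/train \
    --val_path ${IMAGENET_PATH}/val \
    --val_annotation_path ${IMAGENET_ANN}/val \
    --extended_annot_base_path /path/to/extracted/Annotation \
    --output_base_path ${EXP_DIR}/metadata
```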
| 86 | + |
| 87 | +### Training |
| 88 | + |
| 89 | +To train a FIX or CA model, you can run: |
| 90 | + |
| 91 | +```bash |
| 92 | +python train_casme.py \ |
| 93 | + --train_json ${EXP_DIR}/metadata/train.json \ |
| 94 | + --val_json ${EXP_DIR}/metadata/val.json \ |
| 95 | + --ZZsrc ./assets/fix.json \ |
| 96 | + --masker_use_layers 3,4 \ |
| 97 | + --output_path ${EXP_DIR}/runs/ \ |
| 98 | + --epochs 60 --lrde 20 \ |
| 99 | + --name fix |
| 100 | + |
| 101 | +python train_casme.py \ |
| 102 | + --train_json ${EXP_DIR}/metadata/train.json \ |
| 103 | + --val_json ${EXP_DIR}/metadata/val.json \ |
| 104 | + --ZZsrc ./assets/ca.json \ |
| 105 | + --masker_use_layers 3,4 \ |
| 106 | + --output_path ${EXP_DIR}/runs/ \ |
| 107 | + --epochs 60 --lrde 20 \ |
| 108 | + --name ca |
| 109 | +``` |
| 110 | + |
| 111 | +- The `--ZZsrc` arguments provide JSON files with additional options for the command-line interface. `./assets/fix.json` and `./assets/ca.json` contain options and final hyper-parameters chosen for the FIX and CA models in the paper. |
| 112 | +- We also only use the 4th and 5th layers from the classifier in the masker model. |
| 113 | +- `--train_json` and `--val_json` point to the JSON files containing the paths to the example images, and their corresponding labels, described above. |
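
The same entry point can presumably also be pointed at the other metadata files produced by the preprocessing step. As a sketch, a DRT sanity-check run using the shuffled-label Train JSON (only `--train_json` and `--name` differ from the commands above; the run name is arbitrary):

```bash
python train_casme.py \
    --train_json ${EXP_DIR}/metadata/train_shuffle.json \
    --val_json ${EXP_DIR}/metadata/val.json \
    --ZZsrc ./assets/ca.json \
    --masker_use_layers 3,4 \
    --output_path ${EXP_DIR}/runs/ \
    --epochs 60 --lrde 20 \
    --name ca_shuffle
```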

### Evaluation

To evaluate the model on the WSOL metrics and the Saliency Metric, run:

```bash
python casme/tasks/imagenet/score_bboxes.py \
    --val_json ${EXP_DIR}/metadata/val.json \
    --mode casme \
    --bboxes_path ${EXP_DIR}/metadata/val_bboxes.json \
    --casm_path ${EXP_DIR}/runs/ca/epoch_XXX.chk \
    --output_path ${EXP_DIR}/runs/ca/metrics/scores.json
```

where `epoch_XXX.chk` corresponds to the model checkpoint you want to evaluate. Change the `--val_json` and `--bboxes_path` paths to evaluate on the Train-Validation or Validation sets respectively (see the sketch below). Note that the mode should be `casme` regardless of whether you are using FIX or CA models.
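
For example, a sketch of scoring the same checkpoint on the Train-Validation set, assuming `train_val.json` and `train_val_bboxes.json` were generated by the preprocessing step:

```bash
python casme/tasks/imagenet/score_bboxes.py \
    --val_json ${EXP_DIR}/metadata/train_val.json \
    --mode casme \
    --bboxes_path ${EXP_DIR}/metadata/train_val_bboxes.json \
    --casm_path ${EXP_DIR}/runs/ca/epoch_XXX.chk \
    --output_path ${EXP_DIR}/runs/ca/metrics/train_val_scores.json
```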

The output JSON looks something like this:
```
{
  "F1": 0.6201832851563015,
  "F1a": 0.5816041554785251,
  "OM": 0.48426,
  "LE": 0.35752,
  "SM": 0.523097248590095,
  "SM1": -0.5532185246243142,
  "SM2": -1.076315772478443,
  "top1": 75.222,
  "top5": 92.488,
  "sm_acc": 74.124,
  "binarized": 0.4486632848739624,
  "avg_mask": 0.44638757080078123,
  "std_mask": 0.1815464876794815,
  "entropy": 0.034756517103545534,
  "tv": 0.006838996527194977
}
```

- OM, LE, F1, SM and `avg_mask` correspond to the respective columns in Table 1.
- For a given image, an F1 score is computed for each of its bounding boxes. F1 takes the max and F1a takes the mean F1 score over all boxes in the image, and each is then averaged over all images in the dataset.
- SM1 and SM2 refer to the first and second terms of the Saliency Metric formulation. `sm_acc` is the top-1 accuracy under the crop-and-scale transformation used for the Saliency Metric.
- `top1` and `top5` are the accuracies of the classifier.
- `binarized` is the average of the binarized mask pixels over the whole dataset, and `std_mask` is the standard deviation of the continuous mask pixels over the dataset.
- `tv` is the total variation of the masks, and `entropy` is the entropy of predictions for masked images.

To evaluate the model on PxAP, run:

```bash
python casme/tasks/imagenet/wsoleval.py \
    --cam_loader casme \
    --casm_base_path ${EXP_DIR}/runs/ca/epoch_XXX.chk \
    --casme_load_mode specific \
    --dataset OpenImages \
    --dataset_split test \
    --dataset_path ${WSOLEVAL_PATH}/dataset \
    --metadata_path ${WSOLEVAL_PATH}/metadata \
    --output_base_path ${EXP_DIR}/runs/ca/metrics/scores.json
```

where `WSOLEVAL_PATH` is the location to which [wsolevaluation](https://github.com/clovaai/wsolevaluation) has been cloned, after running its dataset downloading scripts.
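
For instance, if wsolevaluation was cloned as in the setup sketch above (placeholder path):

```bash
# Placeholder; point at your wsolevaluation checkout
export WSOLEVAL_PATH=~/code/wsolevaluation
```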

## Pretrained Checkpoints

- [fix.chk](https://drive.google.com/file/d/1m4-oHYZYalk4VKcs_GDh-9UB657PcMSw/view?usp=sharing) corresponds to our best-performing FIX model (Row I of Table 1).
- [ca.chk](https://drive.google.com/file/d/1RPoTtj4RWtx8QsJ9RokFi5r6Wd4y805v/view?usp=sharing) corresponds to our best-performing CA model (Row J of Table 1).
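
A downloaded checkpoint can presumably be scored with the same evaluation script as above; a sketch for `ca.chk` (the download location and output directory are placeholders):

```bash
python casme/tasks/imagenet/score_bboxes.py \
    --val_json ${EXP_DIR}/metadata/val.json \
    --mode casme \
    --bboxes_path ${EXP_DIR}/metadata/val_bboxes.json \
    --casm_path /path/to/downloads/ca.chk \
    --output_path ${EXP_DIR}/runs/ca_pretrained/metrics/scores.json
```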

## Reference

If you found this code useful, please cite [the following paper](https://arxiv.org/abs/2010.09750):

Jason Phang, Jungkyu Park, Krzysztof J. Geras. **"Investigating and Simplifying Masking-based Saliency Methods for Model Interpretability."** *arXiv preprint arXiv:2010.09750 (2020).*
```
@article{phang2020investigating,
  title={Investigating and Simplifying Masking-based Saliency Methods for Model Interpretability},
  author={Phang, Jason and Park, Jungkyu and Geras, Krzysztof J},
  journal={arXiv preprint arXiv:2010.09750},
  year={2020}
}
```