This repository contains the official implementation of the paper "Interpretable Image Classification with Adaptive Prototype-based Vision Transformers" (NeurIPS 2024).
- Overview
- Prerequisites
- Docker Support
- Installation
- Dataset Preparation
- Training
- Analysis
- Model Zoo
- Citation
- Acknowledgments
ProtoViT is a novel approach that combines Vision Transformers with prototype-based learning to create interpretable image classification models. Our implementation provides both high accuracy and explainability through learned prototypes.
These packages should be enough to reproduce our results. For reference, we also include a requirements.txt generated from our conda environment.
- Python 3.8+
- PyTorch with CUDA support
- NumPy
- OpenCV (cv2)
- Augmentor
- timm==0.4.12 (Note: higher versions may require modifications to the ViT encoder)
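
Before training, it can help to confirm that the environment matches the list above. The following is a minimal sketch, not part of this repository: the helper `check_requirements` is hypothetical, and the distribution names (`torch`, `opencv-python`, `Augmentor`) are assumptions about how each package was installed. It only reads installed package metadata, so it does not verify that CUDA is usable.

```python
from importlib import metadata

# Assumed distribution names; adjust if your environment uses different ones
# (for example, opencv-python-headless instead of opencv-python).
REQUIRED = {
    "torch": None,
    "numpy": None,
    "opencv-python": None,
    "Augmentor": None,
    "timm": "0.4.12",  # the only version pinned in the list above
}


def check_requirements(required):
    """Return {package: problem} for each missing or mismatched package."""
    problems = {}
    for name, pinned in required.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            problems[name] = "not installed"
            continue
        if pinned is not None and installed != pinned:
            problems[name] = f"found {installed}, expected {pinned}"
    return problems


if __name__ == "__main__":
    for name, problem in check_requirements(REQUIRED).items():
        print(f"{name}: {problem}")
```

An empty result means every listed package was found and timm matches the pinned version.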
Recommended GPU configurations:
- 1× NVIDIA Quadro RTX 6000 (24GB) or
- 1× NVIDIA GeForce RTX 4090 (24GB) or
- 1× NVIDIA RTX A6000 (48GB)
We provide Docker support for easy deployment and reproducibility. The Dockerfile includes all necessary dependencies and CUDA support.
# Pull the pre-built image
docker pull ayushnangia16/protovit:v1
# Run the container
docker run --gpus all -it --rm -v /path/to/your/data:/app/datasets ayushnangia16/protovit:v1

# Build the Docker image
docker build -t protovit .
# Run the container
docker run --gpus all -it --rm -v /path/to/your/data:/app/datasets protovit

git clone https://github.com/Henrymachiyu/ProtoViT.git
cd ProtoViT
pip install -r requirements.txt

- Download CUB_200_2011.tgz
- Extract the dataset:

# Download the dataset CUB_200_2011.tgz from http://www.vision.caltech.edu/visipedia/CUB-200-2011.html
tar -xzf CUB_200_2011.tgz

- Process the dataset using our preprocessing tools:

# Preprocess CUB dataset (crops and splits data)
python preprocess_cub.py
# Augment training data
python augment_data.py
We also provide support for the Pinterest dataset:
- Prepare your Pinterest dataset in the following structure:

105_classes_pins_dataset/
├── class1/
├── class2/
└── ...

- Process the dataset:

# Preprocess Pinterest dataset
python preprocess_pins.py
# Augment training data
python augment_pins.py
- Configure settings in settings.py:
# Dataset paths
data_path = "./datasets/cub200_cropped/"
train_dir = data_path + "train_cropped_augmented/"
test_dir = data_path + "test_cropped/"
train_push_dir = data_path + "train_cropped/"

- Start training:

python main.py

The corresponding parameter settings for global and local analysis are saved in analysis_settings.py:
load_model_dir = 'saved model path'  # e.g. './saved_models/vgg19/003/'
load_model_name = 'model_name'  # e.g. '14finetuned0.9230.pth'
save_analysis_path = 'saved_dir_rt'
img_name = 'prototype_vis_file'  # e.g. 'img/'
test_data = "test_dir"
check_test_acc = False
check_list = ['list of test images']  # e.g. "163_Mercedes-Benz SL-Class Coupe 2009/03123.jpg"; may list several images

To produce the reasoning plots:
We analyze nearest prototypes for specific test images and retrieve model reasoning process for predictions. The visualization tools in prototype_visualization.py help create detailed visual explanations of the model's decision-making process.
# This script provides the results for the model's reasoning and local analysis
python local_analysis.py -gpuid 0

To produce the global analysis plots:
The following script finds the nearest patches for each prototype, to check that the prototypes are semantically consistent across samples in the train and test data:
python global_analysis.py -gpuid 0

To run the location misalignment experiment, you also need cleverhans:
pip install cleverhans

All the parameters used for reproducing our results on location misalignment are stored in adv_settings.py:
load_model_path = "."
test_dir = "./cub200_cropped/test_cropped"
model_output_dir = "."  # dir for saving all the results

To run the adversarial attack and retrieve the results:
cd ./spatial_alignment_test
python run_adv_test.py  # by default, the experiment runs over the entire test set

We provide checkpoints after projection and last-layer finetuning on the CUB-200-2011 dataset.
| Model Version | Backbone | Resolution | Accuracy | Checkpoint |
|---|---|---|---|---|
| ProtoViT-T | DeiT-Tiny | 224×224 | 83.36% | Download |
| ProtoViT-S | DeiT-Small | 224×224 | 85.30% | Download |
| ProtoViT-CaiT | CaiT_xxs24 | 224×224 | 86.02% | Download |
This implementation is based on timm and on the ProtoPNet repository and its variations. We thank the authors for their valuable work.
If you have any questions regarding the paper or the implementation, please don't hesitate to email us: chiyu.ma.gr@dartmouth.edu
Feel free to ⭐ the repo, contribute, or share it with others who might find it useful!
If you find this work helpful in your research, please also consider citing:
@article{ma2024interpretable,
title={Interpretable Image Classification with Adaptive Prototype-based Vision Transformers},
author={Ma, Chiyu and Donnelly, Jon and Liu, Wenjun and Vosoughi, Soroush and Rudin, Cynthia and Chen, Chaofan},
journal={arXiv preprint arXiv:2410.20722},
year={2024}
}

This project is licensed under the MIT License.


