We report counter-intuitive observations that theoretically unjustified design choices
for attributing diffusion models empirically outperform previous baselines
by a large margin.
Visualization of proponents and opponents on ArtBench-2 using TRAK and D-TRAK with different numbers of timesteps (10 or 100). For each sample of interest, the 5 most positively influential and the 3 most negatively influential training samples are shown, with the influence score below each sample.
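As a minimal sketch of how such a visualization can be assembled, the snippet below picks the highest-scoring and lowest-scoring training samples from a list of influence scores. The function name and the example scores are hypothetical and are not taken from this repository.

```python
from typing import List, Tuple

ScoredIndex = Tuple[int, float]


def proponents_and_opponents(
    scores: List[float], num_pos: int = 5, num_neg: int = 3
) -> Tuple[List[ScoredIndex], List[ScoredIndex]]:
    """Return (index, score) pairs for the most positive and most negative scores."""
    # Sort training indices by influence score, highest first.
    ranked = sorted(enumerate(scores), key=lambda pair: pair[1], reverse=True)
    proponents = ranked[:num_pos]       # most positive influence first
    opponents = ranked[::-1][:num_neg]  # most negative influence first
    return proponents, opponents


# Hypothetical influence scores of 10 training samples on one sample of interest.
example_scores = [0.12, -0.40, 0.95, 0.03, -0.07, 0.58, -0.91, 0.33, 0.21, -0.15]
pros, cons = proponents_and_opponents(example_scores)
print("proponents:", pros)
print("opponents:", cons)
```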
Counterfactual visualization on CIFAR-2 and on ArtBench-2.
See quickstart.ipynb to run data attribution directly on pre-trained diffusion models loaded from Hugging Face!
To get started, follow these steps:
- Clone the GitHub Repository: Begin by cloning the repository using the command:
git clone https://github.com/sail-sg/D-TRAK.git
- Set Up Python Environment: Ensure you have Python 3.8, for example by creating and activating a conda environment:
conda create -n dtrak python=3.8 -y
conda activate dtrak
- Install Dependencies: Install the necessary dependencies by running:
pip install -r requirements.txt
We provide the commands to run experiments on CIFAR-2. It is easy to transfer to other datasets.
- Data pre-processing:
cd CIFAR2
Run 00_EDA.ipynb to create dataset splits and subsets of the training set.
- Train a diffusion model and generate images:
bash scripts/run_train.sh 0 18888 5000-0.5
bash scripts/run_gen.sh 0 0 5000-0.5
- Construct the LDS benchmark:
Train 64 models corresponding to 64 subsets of the training set:
bash scripts/run_lds_val_sub.sh 0 18888 5000-0.5 0 63
Evaluate the model outputs on the validation set:
bash scripts/run_eval_lds_val_sub.sh 0 0 5000-0.5 idx_val.pkl 0 63
bash scripts/run_eval_lds_val_sub.sh 0 1 5000-0.5 idx_val.pkl 0 63
bash scripts/run_eval_lds_val_sub.sh 0 2 5000-0.5 idx_val.pkl 0 63
Evaluate the model outputs on the generation set:
bash scripts/run_eval_lds_val_sub.sh 0 0 5000-0.5 idx_gen.pkl 0 63
bash scripts/run_eval_lds_val_sub.sh 0 1 5000-0.5 idx_gen.pkl 0 63
bash scripts/run_eval_lds_val_sub.sh 0 2 5000-0.5 idx_gen.pkl 0 63
- Compute gradients:
We shard the training set into 5 parts of 1000 examples each.
Use the following commands to compute the gradients to be used for TRAK:
bash scripts/run_grad.sh 0 0 5000-0.5 idx-train.pkl 0 ddpm/checkpoint-8000 loss uniform 10 32768
bash scripts/run_grad.sh 0 0 5000-0.5 idx-train.pkl 1 ddpm/checkpoint-8000 loss uniform 10 32768
bash scripts/run_grad.sh 0 0 5000-0.5 idx-train.pkl 2 ddpm/checkpoint-8000 loss uniform 10 32768
bash scripts/run_grad.sh 0 0 5000-0.5 idx-train.pkl 3 ddpm/checkpoint-8000 loss uniform 10 32768
bash scripts/run_grad.sh 0 0 5000-0.5 idx-train.pkl 4 ddpm/checkpoint-8000 loss uniform 10 32768
bash scripts/run_grad.sh 0 0 5000-0.5 idx-val.pkl 0 ddpm/checkpoint-8000 loss uniform 10 32768
bash scripts/run_grad.sh 0 0 5000-0.5 idx-gen.pkl 0 ddpm/checkpoint-8000 loss uniform 10 32768
Use the following commands to compute the gradients to be used for D-TRAK:
bash scripts/run_grad.sh 0 0 5000-0.5 idx-train.pkl 0 ddpm/checkpoint-8000 mean-squared-l2-norm uniform 10 32768
bash scripts/run_grad.sh 0 0 5000-0.5 idx-train.pkl 1 ddpm/checkpoint-8000 mean-squared-l2-norm uniform 10 32768
bash scripts/run_grad.sh 0 0 5000-0.5 idx-train.pkl 2 ddpm/checkpoint-8000 mean-squared-l2-norm uniform 10 32768
bash scripts/run_grad.sh 0 0 5000-0.5 idx-train.pkl 3 ddpm/checkpoint-8000 mean-squared-l2-norm uniform 10 32768
bash scripts/run_grad.sh 0 0 5000-0.5 idx-train.pkl 4 ddpm/checkpoint-8000 mean-squared-l2-norm uniform 10 32768
bash scripts/run_grad.sh 0 0 5000-0.5 idx-val.pkl 0 ddpm/checkpoint-8000 mean-squared-l2-norm uniform 10 32768
bash scripts/run_grad.sh 0 0 5000-0.5 idx-gen.pkl 0 ddpm/checkpoint-8000 mean-squared-l2-norm uniform 10 32768
- Compute the TRAK/D-TRAK attributions and evaluate the LDS scores:
Run notebooks in methods/04_if.
The implementations of other baselines can also be found in methods.
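To make the LDS evaluation more concrete, here is a minimal sketch of how a linear datamodeling score is commonly computed for one sample of interest. The attribution-based prediction for each training subset is the sum of the scores of the training examples in that subset, and the LDS is the Spearman rank correlation between these predictions and the outputs measured on the models retrained on those subsets. All function names and the toy numbers are hypothetical; the notebooks in methods/04_if may differ in detail.

```python
from typing import List, Sequence


def average_ranks(values: Sequence[float]) -> List[float]:
    """Rank values from 1 to n, giving tied values the mean of their ranks."""
    order = sorted(range(len(values)), key=lambda i: values[i])
    ranks = [0.0] * len(values)
    start = 0
    while start < len(order):
        end = start
        # Extend the window over all entries tied with the current value.
        while end + 1 < len(order) and values[order[end + 1]] == values[order[start]]:
            end += 1
        mean_rank = (start + end) / 2 + 1
        for pos in range(start, end + 1):
            ranks[order[pos]] = mean_rank
        start = end + 1
    return ranks


def pearson(xs: Sequence[float], ys: Sequence[float]) -> float:
    """Pearson correlation; undefined (division by zero) if an input is constant."""
    n = len(xs)
    mean_x, mean_y = sum(xs) / n, sum(ys) / n
    cov = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    var_x = sum((x - mean_x) ** 2 for x in xs)
    var_y = sum((y - mean_y) ** 2 for y in ys)
    return cov / (var_x * var_y) ** 0.5


def spearman(xs: Sequence[float], ys: Sequence[float]) -> float:
    """Spearman rank correlation: Pearson correlation of the ranks."""
    return pearson(average_ranks(xs), average_ranks(ys))


def lds_for_sample(
    scores: Sequence[float],
    subsets: Sequence[Sequence[int]],
    outputs: Sequence[float],
) -> float:
    """LDS for one sample of interest.

    scores:  attribution score of each training example for this sample
    subsets: training indices used to retrain each subset model
    outputs: measured model output on this sample for each subset model
    """
    predicted = [sum(scores[i] for i in subset) for subset in subsets]
    return spearman(predicted, outputs)


# Toy example with 4 training examples and 4 subsets (hypothetical numbers).
toy_scores = [0.5, -0.2, 0.1, 0.4]
toy_subsets = [[0, 1], [1, 2], [2, 3], [0, 3]]
toy_outputs = [1.0, 0.2, 1.5, 2.0]
print(lds_for_sample(toy_scores, toy_subsets, toy_outputs))
```

In practice the score would be averaged over all samples of interest in the validation or generation set.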
- Data pre-processing
Run this notebook first to get the indices of those training examples to be removed.
- Retrain models after removing the top-influential training examples
bash scripts/run_counter.sh 0 18888 5000-0.5 0 59
- Generate images using the retrained models
Run 02_counter.ipynb
- Measure l2 distance
- Measure CLIP cosine similarity
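The two measurements in the last steps can be sketched as follows. This is only an illustration under stated assumptions: images are treated as flattened pixel vectors, and CLIP embeddings are assumed to have been computed beforehand by a separate model. The function names are hypothetical and do not come from this repository.

```python
import math
from typing import Sequence


def l2_distance(a: Sequence[float], b: Sequence[float]) -> float:
    """Euclidean distance between two equally sized vectors (e.g. flattened images)."""
    if len(a) != len(b):
        raise ValueError("vectors must have the same length")
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine similarity between two vectors (e.g. precomputed CLIP embeddings)."""
    if len(a) != len(b):
        raise ValueError("vectors must have the same length")
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        raise ValueError("cosine similarity is undefined for zero vectors")
    return dot / (norm_a * norm_b)


# Hypothetical flattened images from the original and the retrained model.
original_image = [0.0, 3.0]
retrained_image = [4.0, 0.0]
print("l2 distance:", l2_distance(original_image, retrained_image))

# Hypothetical precomputed embeddings of the same two images.
original_embedding = [1.0, 0.0]
retrained_embedding = [0.0, 1.0]
print("cosine similarity:", cosine_similarity(original_embedding, retrained_embedding))
```

Under this reading, a larger l2 distance or a lower cosine similarity between the image from the original model and the image from the retrained model suggests that the removed training examples had more influence on the generation.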
If you find this project useful in your research, please consider citing our paper:
@inproceedings{
zheng2023intriguing,
title={Intriguing Properties of Data Attribution on Diffusion Models},
author={Zheng, Xiaosen and Pang, Tianyu and Du, Chao and Jiang, Jing and Lin, Min},
booktitle={International Conference on Learning Representations (ICLR)},
year={2024},
}