dattri: A Library for Efficient Data Attribution
dattri is a PyTorch library for developing, benchmarking, and deploying efficient data attribution algorithms. You may use dattri to
- Deploy existing data attribution methods to PyTorch models
  - e.g., Influence Function, TracIn, RPS, TRAK, ...
- Develop new data attribution methods with efficient implementations of low-level utility functions
  - e.g., Hessian (HVP/IHVP), Fisher Information Matrix (IFVP), random projection, dropout ensembling, ...
- Benchmark data attribution methods with standard benchmark settings
  - e.g., MNIST-10+LR/MLP, CIFAR-10/2+ResNet-9, MAESTRO + Music Transformer, Shakespeare + nanoGPT, ...
pip install dattri
If you want to use fast_jl to accelerate random projection, you can install the version with fast_jl by
pip install dattri[fast_jl]
Note
We highly recommend running dattri on a CUDA-capable device, especially for large models or datasets.
Note
CUDA is required to install and use the fast_jl version, dattri[fast_jl], which accelerates random projection. Random projection is mainly used in TRAKAttributor. Please use pip<23 and torch<2.3 due to known issues with the fast_jl library.
You do not have to follow the exact steps in this section, but the following is a verified environment setup that may help you avoid most installation issues.
conda create -n dattri python=3.10
conda activate dattri
conda install -c "nvidia/label/cuda-11.8.0" cuda-toolkit
pip3 install torch==2.1.0 --index-url https://download.pytorch.org/whl/cu118
pip install dattri[fast_jl]
You can apply different data attribution methods to PyTorch models. You only need to define:
- the loss function used for model training (it is used as the target function to be attributed if no other target function is provided).
- one or more trained model checkpoints.
- the data loaders for training samples and test samples (e.g., train_loader, test_loader).
- (optional) the target function to be attributed, if it differs from the loss function.
The following example uses IFAttributorCG and AttributionTask to apply data attribution to a PyTorch model.
More examples can be found here.
import torch
from torch import nn
from dattri.algorithm import IFAttributorCG
from dattri.task import AttributionTask
from dattri.benchmark.datasets.mnist import train_mnist_lr, create_mnist_dataset
from dattri.benchmark.utils import SubsetSampler
# load MNIST and use the first 1000 training samples and the first 100 test samples
dataset_train, dataset_test = create_mnist_dataset("./data")
train_loader = torch.utils.data.DataLoader(
    dataset_train,
    batch_size=1000,
    sampler=SubsetSampler(range(1000)),
)
test_loader = torch.utils.data.DataLoader(
    dataset_test,
    batch_size=100,
    sampler=SubsetSampler(range(100)),
)
# train a logistic regression (LR) model
model = train_mnist_lr(train_loader)
# loss as a function of (model parameters, (input, label));
# it is also used as the target function since none is given separately
def f(params, data_target_pair):
    x, y = data_target_pair
    loss = nn.CrossEntropyLoss()
    yhat = torch.func.functional_call(model, params, x)
    return loss(yhat, y)
task = AttributionTask(loss_func=f,
                       model=model,
                       checkpoints=model.state_dict())
# influence function attributor using the CG solver for the IHVP
attributor = IFAttributorCG(
    task=task,
    max_iter=10,
    regularization=1e-2
)
attributor.cache(train_loader)
with torch.no_grad():
    score = attributor.attribute(train_loader, test_loader)
The Hessian-vector product (HVP) and the inverse-Hessian-vector product (IHVP) are widely used in data attribution methods. dattri provides efficient implementations of these operators based on torch.func. This example shows how to use the conjugate gradient (CG) implementation of the IHVP.
import torch
from dattri.func.hessian import ihvp_cg, ihvp_at_x_cg
def f(x, param):
    return torch.sin(x / param).sum()
x = torch.randn(2)
param = torch.randn(1)
v = torch.randn(5, 2)  # a batch of 5 vectors, each with the same dimension as x
# ihvp_cg method
# argnums=0: the Hessian is taken w.r.t. the first argument of f (x),
# which plays the role of the model parameters
ihvp_func = ihvp_cg(f, argnums=0, max_iter=2)
ihvp_result_1 = ihvp_func((x, param), v)  # both (x, param) and v as the inputs
# ihvp_at_x_cg method: (x, param) is cached
ihvp_at_x_func = ihvp_at_x_cg(f, x, param, argnums=0, max_iter=2)
ihvp_result_2 = ihvp_at_x_func(v)  # only v as the input
# the above two give the same result
assert torch.allclose(ihvp_result_1, ihvp_result_2)
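To make concrete what a CG-based IHVP solver computes, here is a minimal NumPy sketch of conjugate gradient that solves H u = v while touching H only through Hessian-vector products. It is not dattri's implementation: the explicit 5x5 matrix H, the helper name conjugate_gradient, and the tolerance are illustrative assumptions. The small diagonal term added to H plays a role similar to a damping (regularization) term.

```python
import numpy as np

def conjugate_gradient(hvp, v, max_iter=100, tol=1e-10):
    """Approximately solve H x = v for a symmetric positive-definite H,
    using only Hessian-vector products hvp(u) = H @ u."""
    x = np.zeros_like(v)
    r = v - hvp(x)          # residual
    p = r.copy()            # search direction
    rs_old = r @ r
    for _ in range(max_iter):
        if np.sqrt(rs_old) < tol:
            break
        Hp = hvp(p)
        alpha = rs_old / (p @ Hp)      # step size along p
        x = x + alpha * p
        r = r - alpha * Hp
        rs_new = r @ r
        p = r + (rs_new / rs_old) * p  # next H-conjugate direction
        rs_old = rs_new
    return x

rng = np.random.default_rng(0)
A = rng.standard_normal((5, 5))
H = A @ A.T + 0.1 * np.eye(5)   # SPD stand-in "Hessian" with a damping term
v = rng.standard_normal(5)

# CG only ever sees the matrix through this callable
ihvp = conjugate_gradient(lambda u: H @ u, v, max_iter=50)
print(ihvp)
```

In practice the Hessian is never formed explicitly; the hvp callable would be computed with automatic differentiation (e.g., via torch.func), which is why CG scales to models with many parameters.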
It has been shown (the Johnson-Lindenstrauss lemma) that high-dimensional vectors retain most of their relative information, such as pairwise distances, when randomly projected to a much lower dimension. To reduce computational cost, random projection is widely used in data attribution methods. The following example uses random_project. The implementation leverages fast_jl when it is installed.
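import torch
from dattri.func.random_projection import random_project
# placeholder input of shape (batch_size, feature_dim), e.g., per-sample gradients
tensor = torch.randn(16, 10000)
# initialize the projector based on users' needs
project_func = random_project(tensor, tensor.size(0), proj_dim=512)
# obtain projected tensors (torch.full_like requires a fill value, so project the tensor directly)
projected_tensor = project_func(tensor)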
Typically, tensor is the gradient of the loss/target function and has a large dimension (i.e., the number of model parameters).
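As a rough illustration of why projection preserves relative information, the NumPy sketch below projects stand-in "gradients" with a dense Gaussian matrix scaled by 1/sqrt(proj_dim). This is a plain Johnson-Lindenstrauss-style projection under our own assumptions (the helper make_gaussian_projector and all sizes are hypothetical); it is not the fast_jl kernel or dattri's random_project.

```python
import numpy as np

def make_gaussian_projector(feature_dim, proj_dim, seed=0):
    """Return a function mapping (batch, feature_dim) -> (batch, proj_dim)
    with a fixed dense Gaussian matrix whose entries have variance 1/proj_dim."""
    rng = np.random.default_rng(seed)
    P = rng.standard_normal((feature_dim, proj_dim)) / np.sqrt(proj_dim)
    return lambda features: features @ P

rng = np.random.default_rng(1)
grads = rng.standard_normal((8, 5000))   # stand-in for 8 per-sample gradients
project = make_gaussian_projector(grads.shape[1], proj_dim=512)
projected = project(grads)

# ratio of squared norms after / before projection, one per sample
ratios = (projected ** 2).sum(axis=1) / (grads ** 2).sum(axis=1)
print(projected.shape, ratios.round(3))
```

Squared norms are preserved in expectation: for a fixed vector, the ratio above follows chi-squared(proj_dim) / proj_dim, so its standard deviation is about sqrt(2 / proj_dim), roughly 0.06 for proj_dim=512. Reusing the same seed gives the same projector, so training and test gradients land in the same projected space.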
Recent studies have found that ensemble methods can significantly improve the performance of data attribution; dropout ensemble is one such method. You can prepare your model with
from dattri.model_util.dropout import activate_dropout
# initialize a torch.nn.Module model
model = MLP()
# (option 1) activate all dropout layers
model = activate_dropout(model, dropout_prob=0.2)
# (option 2) activate specific dropout layers
# here "dropout1" and "dropout2" are the names of dropout layers within the model
model = activate_dropout(model, ["dropout1", "dropout2"], dropout_prob=0.2)
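To illustrate the ensembling idea, here is a NumPy sketch for a toy linear model: Grad-Dot scores are computed under several independently sampled dropout masks and then averaged. The details are our assumptions, not dattri's: dropout is applied to the input features with inverted-dropout scaling, one mask is shared by train and test samples within each ensemble member, and the helper names masked_grads and dropout_ensemble_scores are hypothetical.

```python
import numpy as np

def masked_grads(w, X, y, mask, keep_prob):
    """Per-sample gradients of 0.5 * (x @ w - y)^2 w.r.t. w, where the input
    features x go through a fixed dropout mask with inverted-dropout scaling."""
    Xd = X * (mask / keep_prob)           # dropped-out, rescaled inputs, shape (n, d)
    residual = Xd @ w - y                 # prediction error per sample, shape (n,)
    return residual[:, None] * Xd         # gradient of each sample's loss, shape (n, d)

def dropout_ensemble_scores(w, X_tr, y_tr, X_te, y_te,
                            dropout_prob=0.2, n_masks=10, seed=0):
    """Average Grad-Dot scores over n_masks independently sampled dropout masks."""
    rng = np.random.default_rng(seed)
    keep_prob = 1.0 - dropout_prob
    scores = np.zeros((len(X_tr), len(X_te)))
    for _ in range(n_masks):
        # one mask per ensemble member, shared by train and test samples
        mask = (rng.random(w.shape[0]) < keep_prob).astype(float)
        g_tr = masked_grads(w, X_tr, y_tr, mask, keep_prob)
        g_te = masked_grads(w, X_te, y_te, mask, keep_prob)
        scores += g_tr @ g_te.T           # Grad-Dot scores for this member
    return scores / n_masks               # ensemble average, shape (n_train, n_test)

# usage on a toy regression problem
rng = np.random.default_rng(42)
X_tr, X_te = rng.standard_normal((20, 5)), rng.standard_normal((3, 5))
w_true = rng.standard_normal(5)
y_tr = X_tr @ w_true + 0.1 * rng.standard_normal(20)
y_te = X_te @ w_true + 0.1 * rng.standard_normal(3)
w = np.linalg.lstsq(X_tr, y_tr, rcond=None)[0]   # the "trained" model
scores = dropout_ensemble_scores(w, X_tr, y_tr, X_te, y_te)
print(scores.shape)
```

With dropout_prob=0 and a single mask, the sketch reduces to ordinary Grad-Dot on the undropped model; the ensemble only changes the scores through the randomness of the masks.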
We have implemented most of the state-of-the-art methods. The categories and reference papers of the algorithms are listed in the following table.
Family | Algorithms |
---|---|
IF | Explicit, CG, LiSSA, Arnoldi, DataInf, EK-FAC |
TracIn | TracInCP, Grad-Dot, Grad-Cos |
RPS | RPS-L2 |
TRAK | TRAK |
Shapley Value | KNN-Shapley |
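Several of the gradient-based algorithms above share one structure: a score compares a training sample's gradient with a test sample's gradient, optionally through a damped inverse Hessian. The NumPy sketch below spells out textbook forms of Explicit IF, Grad-Dot, Grad-Cos, and TracInCP on precomputed per-sample gradient matrices. It is an illustration under our own assumptions (function names, the damping value, and the sign convention are ours; papers and libraries differ, e.g. influence functions are often written with a leading minus sign), not dattri's implementation.

```python
import numpy as np

def grad_dot(G_tr, G_te):
    """Inner products between train and test gradients, shape (n_train, n_test)."""
    return G_tr @ G_te.T

def grad_cos(G_tr, G_te, eps=1e-12):
    """Cosine similarity between train and test gradients."""
    G_tr_n = G_tr / (np.linalg.norm(G_tr, axis=1, keepdims=True) + eps)
    G_te_n = G_te / (np.linalg.norm(G_te, axis=1, keepdims=True) + eps)
    return G_tr_n @ G_te_n.T

def if_explicit(G_tr, G_te, H, damping=1e-2):
    """g_train^T (H + damping * I)^{-1} g_test, solving instead of inverting."""
    Z = np.linalg.solve(H + damping * np.eye(H.shape[0]), G_te.T)
    return G_tr @ Z

def tracin_cp(grads_per_ckpt, lrs):
    """Sum of learning-rate-weighted Grad-Dot scores over saved checkpoints;
    grads_per_ckpt is a list of (G_tr, G_te) pairs, one per checkpoint."""
    return sum(lr * grad_dot(G_tr, G_te) for (G_tr, G_te), lr in zip(grads_per_ckpt, lrs))

# toy linear regression: per-sample gradients of 0.5 * (x @ w - y)^2 w.r.t. w
rng = np.random.default_rng(0)
X_tr, X_te = rng.standard_normal((6, 3)), rng.standard_normal((2, 3))
w = rng.standard_normal(3)
y_tr, y_te = rng.standard_normal(6), rng.standard_normal(2)
G_tr = (X_tr @ w - y_tr)[:, None] * X_tr
G_te = (X_te @ w - y_te)[:, None] * X_te
H = X_tr.T @ X_tr / len(X_tr)   # Hessian of the mean training loss
print(if_explicit(G_tr, G_te, H).shape)
```

Solving the linear system rather than forming H^{-1} is the same idea the CG, LiSSA, and Arnoldi variants push further, since an explicit Hessian is only affordable for small models.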
- Leave-one-out (LOO) correlation
- Linear datamodeling score (LDS)
- Area under the ROC curve (AUC) for noisy label detection
- Brittleness test for checking flipped labels
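For reference, here is a minimal NumPy sketch of two of these metrics under their common definitions. LDS: for random training subsets, predict each test output additively as the sum of the attribution scores of the training samples in the subset, then take the Spearman rank correlation with the outputs of models actually retrained on those subsets, averaged over test samples. AUC: treat a per-training-sample score (e.g., self-influence) as a detector of mislabeled samples. The function names and argument layouts are our own assumptions and may not match dattri's metric API.

```python
import numpy as np

def rankdata(a):
    """Ranks 0..n-1; assumes continuous scores (ties broken by order)."""
    order = np.argsort(a)
    ranks = np.empty(len(a))
    ranks[order] = np.arange(len(a))
    return ranks

def spearman(a, b):
    """Spearman correlation = Pearson correlation of the ranks."""
    return np.corrcoef(rankdata(a), rankdata(b))[0, 1]

def lds(scores, subset_masks, actual_outputs):
    """scores: (n_train, n_test); subset_masks: (n_subsets, n_train) of 0/1;
    actual_outputs: (n_subsets, n_test) outputs of models retrained on each subset."""
    predicted = subset_masks @ scores    # additive prediction, (n_subsets, n_test)
    per_test = [spearman(predicted[:, j], actual_outputs[:, j])
                for j in range(scores.shape[1])]
    return float(np.mean(per_test))

def auc(scores, is_noisy):
    """Probability that a random noisy sample scores higher than a random clean one
    (ties count one half)."""
    pos, neg = scores[is_noisy], scores[~is_noisy]
    greater = (pos[:, None] > neg[None, :]).mean()
    ties = (pos[:, None] == neg[None, :]).mean()
    return float(greater + 0.5 * ties)

# usage with synthetic stand-ins for scores and retrained-model outputs
rng = np.random.default_rng(0)
scores = rng.standard_normal((50, 4))                 # (n_train, n_test)
masks = (rng.random((30, 50)) < 0.5).astype(float)    # 30 random half-size subsets
actual = masks @ scores + 0.1 * rng.standard_normal((30, 4))
print(lds(scores, masks, actual))
```

LOO correlation follows the same pattern as LDS, except that each "subset" removes a single training sample and the correlation is computed between the scores and the observed change in the test output.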
Dataset | Model | Task | Sample Size (train, test) | Parameter Size | Metric | Data Source |
---|---|---|---|---|---|---|
MNIST-10 | LR | Image Classification | (5000,500) | 7840 | LOO/LDS/AUC | link |
MNIST-10 | MLP | Image Classification | (5000,500) | 0.11M | LOO/LDS/AUC | link |
CIFAR-2 | ResNet-9 | Image Classification | (5000,500) | 4.83M | LDS | link |
CIFAR-10 | ResNet-9 | Image Classification | (5000,500) | 4.83M | AUC | link |
MAESTRO | Music Transformer | Music Generation | (5000,178) | 13.3M | LDS | link |
Shakespeare | nanoGPT | Text Generation | (3921,435) | 10.7M | LDS | link |
- More (larger) benchmark settings to come
  - ImageNet + ResNet-18
  - TinyStories + nanoGPT
  - Comparison with other libraries
- More algorithms and low-level utility functions to come
  - KNN filter
  - TF-IDF filter
  - RelativeIF
  - In-Run Shapley
- Better documentation
  - Quick start Colab notebooks
@inproceedings{deng2024dattri,
author = {Deng, Junwei and Li, Ting-Wei and Zhang, Shiyuan and Liu, Shixuan and Pan, Yijun and Huang, Hao and Wang, Xinhe and Hu, Pingbang and Zhang, Xingjian and Ma, Jiaqi W},
title = {dattri: A Library for Efficient Data Attribution},
booktitle = {Advances in Neural Information Processing Systems},
volume = {37},
year = {2024}
}