Skip to content

Commit

Permalink
[dattri] Add huggingface example (GPT-2 + WikiText-2) and related bug fix (#157)
Browse files Browse the repository at this point in the history

* add huggingface example and fix trak and task

* update

* add dropout example

* change according to comments
  • Loading branch information
TheaperDeng authored Dec 30, 2024
1 parent a73864f commit 2ecf87a
Show file tree
Hide file tree
Showing 10 changed files with 2,209 additions and 28 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/examples_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ jobs:
python examples/noisy_label_detection/trak_noisy_label.py --device cpu
python examples/pretrained_benchmark/influence_function_lds.py --device cpu
python examples/pretrained_benchmark/trak_loo.py --device cpu
sed -i 's/* 10/* 2/g' examples/pretrained_benchmark/trak_dropout_lds.py
python examples/pretrained_benchmark/trak_dropout_lds.py --device cpu
python examples/brittleness/mnist_lr_brittleness.py --method cg --device cpu
- name: Uninstall the package
run: |
Expand Down
50 changes: 37 additions & 13 deletions dattri/algorithm/trak.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,12 @@ def m(params, image_label_pair):
"""
self.task = task
self.norm_scaler = (
sum(p.numel() for p in self.task.get_param(ckpt_idx=0)[0]) ** 0.5
sum(
p.numel()
for _, p in self.task.get_model().named_parameters()
if p.requires_grad
)
** 0.5
)
self.projector_kwargs = DEFAULT_PROJECTOR_KWARGS
if projector_kwargs is not None:
Expand All @@ -84,6 +89,7 @@ def m(params, image_label_pair):
self.correct_probability_func = vmap(
correct_probability_func,
in_dims=(None, 0),
randomness="different",
)
self.full_train_dataloader = None

Expand Down Expand Up @@ -127,14 +133,22 @@ def cache(
desc="calculating gradient of training set...",
leave=False,
):
train_batch_data = tuple(data.to(self.device) for data in train_data)
# TODO: reorganize the data pre-grad processing.
if isinstance(train_data, (tuple, list)):
train_batch_data = tuple(
data.to(self.device) for data in train_data
)
else:
train_batch_data = train_data

grad_t = self.grad_loss_func(parameters, train_batch_data)
grad_t = torch.nan_to_num(grad_t)
grad_t /= self.norm_scaler
batch_size = grad_t.shape[0]
grad_p = (
random_project(
grad_t,
train_batch_data[0].shape[0],
batch_size,
**self.projector_kwargs,
)(grad_t, ensemble_id=ckpt_idx)
.clone()
Expand All @@ -143,7 +157,7 @@ def cache(
full_train_projected_grad.append(grad_p)
Q.append(
(
torch.ones(train_batch_data[0].shape[0]).to(self.device)
torch.ones(batch_size).to(self.device)
- self.correct_probability_func(
_unflatten_params(full_parameters, self.task.get_model()),
train_batch_data,
Expand All @@ -167,7 +181,7 @@ def cache(
self.inv_XTX_XT_list = inv_XTX_XT_list
self.Q = running_Q

def attribute( # noqa: PLR0912
def attribute( # noqa: PLR0912,PLR0915
self,
test_dataloader: torch.utils.data.DataLoader,
train_dataloader: Optional[torch.utils.data.DataLoader] = None,
Expand Down Expand Up @@ -239,20 +253,26 @@ def attribute( # noqa: PLR0912
desc="calculating gradient of training set...",
leave=False,
):
train_batch_data = tuple(
data.to(self.device) for data in train_data
)
# TODO: reorganize the data pre-grad processing.
if isinstance(train_data, (tuple, list)):
train_batch_data = tuple(
data.to(self.device) for data in train_data
)
else:
train_batch_data = train_data

grad_t = self.grad_loss_func(
parameters,
train_batch_data,
)
grad_t = torch.nan_to_num(grad_t)
grad_t /= self.norm_scaler
batch_size = grad_t.shape[0]

grad_p = (
random_project(
grad_t,
train_batch_data[0].shape[0],
batch_size,
**self.projector_kwargs,
)(grad_t, ensemble_id=ckpt_idx)
.clone()
Expand All @@ -261,7 +281,7 @@ def attribute( # noqa: PLR0912
train_projected_grad.append(grad_p)
Q.append(
(
torch.ones(train_batch_data[0].shape[0]).to(self.device)
torch.ones(batch_size).to(self.device)
- self.correct_probability_func(
_unflatten_params(
full_parameters,
Expand All @@ -282,15 +302,19 @@ def attribute( # noqa: PLR0912
desc="calculating gradient of test set...",
leave=False,
):
test_batch_data = tuple(data.to(self.device) for data in test_data)
# TODO: reorganize the data pre-grad processing.
if isinstance(test_data, (tuple, list)):
test_batch_data = tuple(data.to(self.device) for data in test_data)
else:
test_batch_data = test_data
grad_t = self.grad_target_func(parameters, test_batch_data)
grad_t = torch.nan_to_num(grad_t)
grad_t /= self.norm_scaler

batch_size = grad_t.shape[0]
grad_p = (
random_project(
grad_t,
test_batch_data[0].shape[0],
batch_size,
**self.projector_kwargs,
)(grad_t, ensemble_id=ckpt_idx)
.clone()
Expand Down
76 changes: 61 additions & 15 deletions dattri/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,25 @@
from dattri.func.utils import flatten_func, flatten_params, partial_param


def _default_checkpoint_load_func(
    model: nn.Module,
    checkpoint: Union[
        str,
        List[str],
        List[Dict[str, torch.Tensor]],
        Dict[str, torch.Tensor],
    ],
) -> nn.Module:
    """Load a single checkpoint into ``model`` and switch it to eval mode.

    Args:
        model (nn.Module): The model whose parameters will be replaced.
        checkpoint: Either a filesystem path (``str`` or a ``pathlib`` path)
            to a file produced by ``torch.save``, or an in-memory
            ``state_dict`` mapping parameter names to tensors.

    Returns:
        nn.Module: The same model instance with the checkpoint loaded and
            ``eval()`` applied (disables dropout / batch-norm updates).
    """
    # Local import: accept any pathlib flavor (PosixPath AND WindowsPath);
    # the original check against PosixPath alone broke on Windows.
    from pathlib import Path

    if isinstance(checkpoint, (str, Path)):
        checkpoint = torch.load(
            checkpoint,
            # Load tensors directly onto the device the model lives on.
            map_location=next(model.parameters()).device,
        )
    model.load_state_dict(checkpoint)
    model.eval()
    return model


class AttributionTask:
"""The abstraction of the attribution task information."""

Expand All @@ -32,6 +51,7 @@ def __init__(
Dict[str, torch.Tensor],
],
target_func: Optional[Callable] = None,
checkpoints_load_func: Optional[Callable] = None,
) -> None:
"""Initialize the AttributionTask.
Expand Down Expand Up @@ -75,8 +95,19 @@ def f(params, data):
loss = nn.CrossEntropyLoss()
yhat = torch.func.functional_call(model, params, image)
return loss(yhat, label)
```
checkpoints_load_func (Callable): The checkpoint load function.
The input is optional, if not provided, the checkpoint load
function will be a default one using model.load_state_dict.
The parameter is used for some models that have special
loading strategies, e.g., huggingface model.
A typical example for huggingface model is
```python
def checkpoints_load_func(model, checkpoint):
model = AutoModelForCausalLM.from_pretrained(checkpoint).cuda()
model.eval()
return model
```.
"""
self.model = model
if target_func is None:
Expand All @@ -90,6 +121,11 @@ def f(params, data):
self.original_target_func = target_func
self.target_func = flatten_func(self.model)(target_func)

if checkpoints_load_func is None:
self.checkpoints_load_func = _default_checkpoint_load_func
else:
self.checkpoints_load_func = checkpoints_load_func

if not isinstance(checkpoints, list):
self.checkpoints = [checkpoints]
else:
Expand All @@ -100,13 +136,21 @@ def f(params, data):
self.current_checkpoint_idx = None

# TODO: Make this more general, that is allow customized kwargs.
self.grad_loss_func = vmap(grad(self.loss_func), in_dims=(None, 1))
self.grad_loss_func = vmap(
grad(self.loss_func),
in_dims=(None, 1),
randomness="different",
)
self.grad_loss_func_kwargs = {
"in_dims": (None, 1),
"layer_name": None,
"ckpt_idx": None,
}
self.grad_target_func = vmap(grad(self.target_func), in_dims=(None, 1))
self.grad_target_func = vmap(
grad(self.target_func),
in_dims=(None, 1),
randomness="different",
)
self.grad_target_func_kwargs = {
"in_dims": (None, 1),
"layer_name": None,
Expand All @@ -123,20 +167,14 @@ def _load_checkpoints(self, ckpt_idx: int) -> None:
self.current_checkpoint_idx is None
or self.current_checkpoint_idx != ckpt_idx
):
if isinstance(self.checkpoints[ckpt_idx], (str, PosixPath)):
self.model.load_state_dict(
torch.load(
self.checkpoints[ckpt_idx],
map_location=next(self.model.parameters()).device,
),
)
else:
self.model.load_state_dict(self.checkpoints[ckpt_idx])
self.model = self.checkpoints_load_func(
self.model,
self.checkpoints[ckpt_idx],
)
self.current_checkpoint_idx = ckpt_idx
self.named_parameters = {
k: p for k, p in self.model.named_parameters() if p.requires_grad
}
self.model.eval()

@staticmethod
def _generate_param_layer_map(
Expand Down Expand Up @@ -212,7 +250,11 @@ def get_grad_target_func(
"ckpt_idx": ckpt_idx,
}
if self.grad_target_func_kwargs != grad_target_func_kwargs:
self.grad_target_func = vmap(grad(target_func), in_dims=in_dims)
self.grad_target_func = vmap(
grad(target_func),
in_dims=in_dims,
randomness="different",
)
self.grad_target_func_kwargs = grad_target_func_kwargs
return self.grad_target_func

Expand Down Expand Up @@ -294,7 +336,11 @@ def get_grad_loss_func(
"ckpt_idx": ckpt_idx,
}
if self.grad_loss_func_kwargs != loss_target_func_kwargs:
self.grad_loss_func = vmap(grad(loss_func), in_dims=in_dims)
self.grad_loss_func = vmap(
grad(loss_func),
in_dims=in_dims,
randomness="different",
)
self.grad_loss_func_kwargs = loss_target_func_kwargs
return self.grad_loss_func

Expand Down
87 changes: 87 additions & 0 deletions examples/pretrained_benchmark/trak_dropout_lds.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
import argparse
from pathlib import PosixPath

import torch
from torch import nn
from torch.utils.data import DataLoader

from dattri.algorithm.trak import TRAKAttributor
from dattri.benchmark.load import load_benchmark
from dattri.metric import lds
from dattri.task import AttributionTask
from dattri.model_util.dropout import activate_dropout

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--device", default="cpu", type=str)
    args = parser.parse_args()

    # Download the pre-trained benchmark: trained checkpoints, datasets,
    # samplers, and the LDS ground truth for the MNIST10 + MLP setting.
    model_details, groundtruth = load_benchmark(
        model="mlp",
        dataset="mnist",
        metric="lds",
    )

    # Single source of truth for the dropout rate used by the ensemble.
    dropout_prob = 0.1

    # Activate dropout so the model stays stochastic at attribution time.
    model = activate_dropout(model_details["model"], dropout_prob=dropout_prob)

    def dropout_checkpoint_load_func(model, checkpoint):
        """Load a checkpoint, then re-enable dropout.

        ``model.eval()`` turns dropout off, so ``activate_dropout`` must be
        re-applied after every load to keep each ensemble member stochastic.
        """
        if isinstance(checkpoint, (str, PosixPath)):
            checkpoint = torch.load(
                checkpoint,
                map_location=next(model.parameters()).device,
            )
        model.load_state_dict(checkpoint)
        model.eval()
        return activate_dropout(model, dropout_prob=dropout_prob)

    def f(params, data_target_pair):
        """TRAK loss for one sample: the margin ``log(p / (1 - p))``."""
        image, label = data_target_pair
        image_t = image.unsqueeze(0)
        label_t = label.unsqueeze(0)
        loss = nn.CrossEntropyLoss()
        yhat = torch.func.functional_call(model, params, image_t)
        logp = -loss(yhat, label_t)
        return logp - torch.log(1 - torch.exp(logp))

    def m(params, image_label_pair):
        """Probability of the correct label (TRAK's correctness term)."""
        image, label = image_label_pair
        image_t = image.unsqueeze(0)
        label_t = label.unsqueeze(0)
        loss = nn.CrossEntropyLoss()
        yhat = torch.func.functional_call(model, params, image_t)
        return torch.exp(-loss(yhat, label_t))

    # Reuse one trained checkpoint for every ensemble member; dropout
    # randomness makes each copy behave as an independent member.
    # NOTE: keep the literal "* 10" — CI rewrites it via sed to "* 2".
    task = AttributionTask(
        model=model.to(args.device),
        loss_func=f,
        checkpoints=[model_details["models_half"][0]] * 10,
        checkpoints_load_func=dropout_checkpoint_load_func,
    )

    attributor = TRAKAttributor(
        task=task,
        correct_probability_func=m,
        device=args.device,
    )

    # Cache the projected training-set gradients once; attribute() reuses them.
    with torch.no_grad():
        attributor.cache(
            DataLoader(
                model_details["train_dataset"],
                batch_size=5000,
                sampler=model_details["train_sampler"],
            ),
        )

    score = attributor.attribute(
        DataLoader(
            model_details["test_dataset"],
            batch_size=5000,
            sampler=model_details["test_sampler"],
        ),
    )

    # Report the mean LDS over test points, ignoring NaN entries.
    lds_score = lds(score, groundtruth)[0]
    print("lds:", torch.mean(lds_score[~torch.isnan(lds_score)]))
2 changes: 2 additions & 0 deletions examples/readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ This section contains examples using the pretrained checkpoints and pre-calculat

[Use pre-trained MNIST10 + LR benchmark setting and evaluate TRAK algorithm by LDS](./pretrained_benchmark/trak_lds.py)

[Use pre-trained MNIST10 + MLP benchmark setting and evaluate TRAK + dropout ensemble by LDS](./pretrained_benchmark/trak_dropout_lds.py)

## Estimate the brittleness

This section contains examples using attribution scores to estimate the brittleness of a model.
Expand Down
Loading

0 comments on commit 2ecf87a

Please sign in to comment.