DOC TST Reproducibility of models using batch norm (#1734)
Fixes #1732

After loading a model that was trained with PEFT on a base model with
some kind of batch norm layer, the loaded model should produce the same
outputs as it did before saving. Right now, this does not happen.

The reason is that during training, the buffers for the running mean etc.
are updated, but they are not saved when calling save_pretrained on the
PeftModel instance. Normally, PEFT assumes that the base model parameters
are kept constant during training, which is not the case with batch norm.
Therefore, only the PEFT parameters are saved, on the assumption that when
the user loads the base model, all base model parameters are restored
exactly. As a result, the information in the buffers is lost completely.

The fix is to add the batch norm layers to modules_to_save. This fix is
now documented and tested.
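
As a minimal illustration (not part of this commit), the running stats of a
batch norm layer are buffers, not parameters, which is why they fall outside
the usual PEFT assumption that only adapter parameters change during training:

```python
import torch

bn = torch.nn.BatchNorm1d(4)
# the trainable state consists only of the affine parameters
print([name for name, _ in bn.named_parameters()])
# ['weight', 'bias']
# the running stats live in buffers, which the optimizer never touches
print([name for name, _ in bn.named_buffers()])
# ['running_mean', 'running_var', 'num_batches_tracked']
```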
BenjaminBossan authored May 22, 2024
1 parent bc6a999 commit 1fec231
Showing 3 changed files with 144 additions and 0 deletions.
22 changes: 22 additions & 0 deletions docs/source/developer_guides/troubleshooting.md
@@ -249,3 +249,25 @@ TunerModelStatus(
devices={'adapter-1': ['cpu'], 'adapter-2': ['cuda']},
)
```

## Reproducibility

### Models using batch norm

When loading a trained PEFT model where the base model uses batch norm (e.g. `torch.nn.BatchNorm1d` or `torch.nn.BatchNorm2d`), you may find that you cannot reproduce the exact same outputs. This is because the batch norm layers keep track of running stats during training, but these stats are not part of the PEFT checkpoint. Therefore, when you load the PEFT model, the running stats of the base model will be used (i.e. from before training with PEFT).
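
For illustration, here is a minimal sketch showing how the running stats
change during training even though no parameter of the layer is updated:

```python
import torch

bn = torch.nn.BatchNorm2d(3)
print(bn.running_mean)  # tensor([0., 0., 0.]) initially

bn.train()
_ = bn(torch.randn(8, 3, 4, 4))  # a forward pass in train mode updates the buffers
print(bn.running_mean)  # no longer all zeros, even though no optimizer step ran
```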

Depending on your use case, this may not be a big deal. If, however, you need your outputs to be 100% reproducible, you can achieve this by adding the batch norm layers to `modules_to_save`. Below is an example of this using ResNet and LoRA. Notice that we set `modules_to_save=["classifier", "normalization"]`. The `"classifier"` entry is needed because our task is image classification, and the `"normalization"` entry ensures that the batch norm layers are saved in the PEFT checkpoint.

```python
from transformers import AutoModelForImageClassification
from peft import LoraConfig, get_peft_model

model_id = "microsoft/resnet-18"
base_model = AutoModelForImageClassification.from_pretrained(model_id)
config = LoraConfig(
    target_modules=["convolution"],
    modules_to_save=["classifier", "normalization"],
)
model = get_peft_model(base_model, config)
```

Depending on the type of model you use, the batch norm layers could have different names than `"normalization"`, so please ensure that the name matches your model architecture.
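
If you are unsure what the layers are called, one way to find out is to iterate over the named modules of the base model and check for the batch norm types (a quick sketch, not specific to any architecture):

```python
import torch
from transformers import AutoModelForImageClassification

base_model = AutoModelForImageClassification.from_pretrained("microsoft/resnet-18")
bn_types = (torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d)
for name, module in base_model.named_modules():
    if isinstance(module, bn_types):
        print(name)  # for this model, the names end in "normalization"
```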
5 changes: 5 additions & 0 deletions src/peft/utils/save_and_load.py
@@ -71,6 +71,8 @@ def get_peft_model_state_dict(
    config = model.peft_config[adapter_name]
    if state_dict is None:
        state_dict = model.state_dict()

    # TUNER SPECIFIC CODE
    if config.peft_type in (PeftType.LORA, PeftType.ADALORA):
        # to_return = lora_state_dict(model, bias=model.peft_config.bias)
        # adapted from `https://github.com/microsoft/LoRA/blob/main/loralib/utils.py`
@@ -165,11 +167,13 @@ def get_peft_model_state_dict(
    else:
        raise ValueError(f"Unknown PEFT type passed: {config.peft_type}")

    # MODULES TO SAVE
    if getattr(model, "modules_to_save", None) is not None:
        for key, value in state_dict.items():
            if any(f"{module_name}.modules_to_save.{adapter_name}" in key for module_name in model.modules_to_save):
                to_return[key.replace("modules_to_save.", "")] = value

    # DEAL WITH EMBEDDINGS
    # check the common embedding layers in `target_modules` to reset `save_embedding_layers` if necessary
    is_embedding_in_target_modules = False
    if (
@@ -224,6 +228,7 @@ def get_peft_model_state_dict(
    elif save_embedding_layers:
        warnings.warn("Could not identify embedding layer(s) because the model is not a 🤗 transformers model.")

    # REMOVE ADAPTER NAME
    to_return = {k.replace(f".{adapter_name}", ""): v for k, v in to_return.items()}
    return to_return

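To illustrate the two key-rewriting steps above (MODULES TO SAVE and REMOVE ADAPTER NAME), here is a small sketch of how a state dict key for a saved batch norm buffer ends up in the checkpoint; the key is a made-up example, not taken from a real model:

```python
adapter_name = "default"
key = "base_model.model.normalization.modules_to_save.default.running_mean"

# the MODULES TO SAVE step strips the "modules_to_save." infix
key = key.replace("modules_to_save.", "")
# -> "base_model.model.normalization.default.running_mean"

# the REMOVE ADAPTER NAME step strips the adapter name suffix
key = key.replace(f".{adapter_name}", "")
print(key)  # -> "base_model.model.normalization.running_mean"
```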
117 changes: 117 additions & 0 deletions tests/test_vision_models.py
@@ -0,0 +1,117 @@
# Copyright 2024-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# This is not a full on test suite of vision models, since we already run many tests on dummy models with Conv2d layers
# and on stable diffusion models. Instead, this file contains specific tests for bugs that have been found in the past.
import gc

import pytest
import torch
from datasets import load_dataset
from safetensors.torch import load_file
from transformers import AutoImageProcessor, AutoModelForImageClassification

from peft import LoHaConfig, LoKrConfig, LoraConfig, OFTConfig, PeftModel, get_peft_model


CONFIGS = {
"lora": LoraConfig(target_modules=["convolution"], modules_to_save=["classifier", "normalization"]),
"loha": LoHaConfig(target_modules=["convolution"], modules_to_save=["classifier", "normalization"]),
"lokr": LoKrConfig(target_modules=["convolution"], modules_to_save=["classifier", "normalization"]),
"oft": OFTConfig(target_modules=["convolution"], modules_to_save=["classifier", "normalization"]),
# TODO: cannot use BOFT because some convolutional kernel dimensions are even (64) and others odd (147). There is no
# common denominator for the boft_block_size except 1, but using 1 results in an error in the fbd_cuda kernel:
# > Error in forward_fast_block_diag_cuda_kernel: an illegal memory access was encountered
# "boft": BOFTConfig(target_modules=["convolution"], modules_to_save=["classifier", "normalization"], boft_block_size=2),
}


class TestResnet:
    model_id = "microsoft/resnet-18"

    @pytest.fixture(autouse=True)
    def teardown(self):
        r"""
        Efficient mechanism to free GPU memory after each test. Based on
        https://github.com/huggingface/transformers/issues/21094
        """
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()

@pytest.fixture(scope="class")
def image_processor(self):
image_processor = AutoImageProcessor.from_pretrained(self.model_id)
return image_processor

@pytest.fixture(scope="class")
def data(self, image_processor):
dataset = load_dataset("huggingface/cats-image", trust_remote_code=True)
image = dataset["test"]["image"][0]
return image_processor(image, return_tensors="pt")

    @pytest.mark.parametrize("config", CONFIGS.values(), ids=CONFIGS.keys())
    def test_model_with_batchnorm_reproducibility(self, config, tmp_path, data):
        # see #1732
        torch.manual_seed(0)
        model = AutoModelForImageClassification.from_pretrained(self.model_id)
        model = get_peft_model(model, config)

        # record outputs before training
        model.eval()
        with torch.inference_mode():
            output_before = model(**data)
        model.train()

        # train the model
        optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
        batch_size = 4
        max_steps = 5 * batch_size
        labels = torch.zeros(1, 1000)
        labels[0, 283] = 1
        for i in range(0, max_steps, batch_size):
            optimizer.zero_grad()
            outputs = model(**data, labels=labels)
            loss = outputs.loss
            loss.backward()
            optimizer.step()

        # record outputs after training
        model.eval()
        with torch.inference_mode():
            output_after = model(**data)
        assert torch.isfinite(output_after.logits).all()
        atol, rtol = 1e-4, 1e-4
        # sanity check: model was updated
        assert not torch.allclose(output_before.logits, output_after.logits, atol=atol, rtol=rtol)

        # check saving the model and loading it
        model.save_pretrained(tmp_path)
        del model

        torch.manual_seed(0)
        model = AutoModelForImageClassification.from_pretrained(self.model_id)
        model = PeftModel.from_pretrained(model, tmp_path).eval()
        with torch.inference_mode():
            output_loaded = model(**data)
        assert torch.allclose(output_after.logits, output_loaded.logits, atol=atol, rtol=rtol)

        # ensure that the checkpoint file contains the buffers
        model_running_mean = len([k for k in model.state_dict().keys() if "running_mean" in k])
        state_dict = load_file(tmp_path / "adapter_model.safetensors")
        checkpoint_running_mean = len([k for k in state_dict.keys() if "running_mean" in k])
        # note that the model has twice as many "running_mean" keys, as there is one copy per
        # ModulesToSaveWrapper (original module plus saved copy), so we multiply by 2 to compare
        assert model_running_mean == checkpoint_running_mean * 2
