70 changes: 60 additions & 10 deletions src/transformers/integrations/peft.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import TYPE_CHECKING, Any, Dict, Optional
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

from ..utils import (
check_peft_version,
@@ -245,20 +245,27 @@ def add_adapter(self, adapter_config, adapter_name: Optional[str] = None) -> None:

self.set_adapter(adapter_name)

def set_adapter(self, adapter_name: str) -> None:
def set_adapter(self, adapter_name: Union[List[str], str]) -> None:
"""
If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
official documentation: https://huggingface.co/docs/peft

Sets a specific adapter by forcing the model to use that adapter and disabling the other adapters.

Args:
adapter_name (`str`):
The name of the adapter to set.
adapter_name (`Union[List[str], str]`):
The name of the adapter to set. Can also be a list of strings to set multiple adapters at once.
"""
check_peft_version(min_version=MIN_PEFT_VERSION)
if not self._hf_peft_config_loaded:
raise ValueError("No adapter loaded. Please load an adapter first.")
elif isinstance(adapter_name, list):
missing = set(adapter_name) - set(self.peft_config)
if len(missing) > 0:
raise ValueError(
f"Following adapter(s) could not be found: {', '.join(missing)}. Make sure you are passing the correct adapter name(s)."
f" current loaded adapters are: {list(self.peft_config.keys())}"
)
elif adapter_name not in self.peft_config:
raise ValueError(
f"Adapter with name {adapter_name} not found. Please pass the correct adapter name among {list(self.peft_config.keys())}"
@@ -270,7 +277,11 @@ def set_adapter(self, adapter_name: str) -> None:

for _, module in self.named_modules():
if isinstance(module, BaseTunerLayer):
module.active_adapter = adapter_name
# For backward compatibility with previous PEFT versions
if hasattr(module, "set_adapter"):
module.set_adapter(adapter_name)
else:
module.active_adapter = adapter_name
_adapters_has_been_set = True

if not _adapters_has_been_set:
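
A minimal usage sketch of the new list form (not part of this diff; the checkpoint and adapter repo below are placeholders, and `peft` must be installed):

    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")
    model.load_adapter("ybelkada/opt-350m-lora", adapter_name="default")
    model.load_adapter("ybelkada/opt-350m-lora", adapter_name="adapter-2")

    # Previous behavior, still supported: activate a single adapter by name
    model.set_adapter("default")

    # New behavior: activate several adapters at once by passing a list;
    # requires a PEFT version whose tuner layers expose `set_adapter`
    model.set_adapter(["adapter-2", "default"])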
@@ -294,7 +305,11 @@ def disable_adapters(self) -> None:

for _, module in self.named_modules():
if isinstance(module, BaseTunerLayer):
module.disable_adapters = True
# Recent versions of PEFT need to call `enable_adapters` instead
if hasattr(module, "enable_adapters"):
module.enable_adapters(enabled=False)
else:
module.disable_adapters = True

def enable_adapters(self) -> None:
"""
@@ -312,14 +327,22 @@ def enable_adapters(self) -> None:

for _, module in self.named_modules():
if isinstance(module, BaseTunerLayer):
module.disable_adapters = False
# Recent versions of PEFT need to call `enable_adapters` instead
if hasattr(module, "enable_adapters"):
module.enable_adapters(enabled=True)
else:
module.disable_adapters = False
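
A short sketch of the enable/disable round-trip this shim supports, continuing the placeholder model from the sketch above (`dummy_input` is a stand-in batch of token ids):

    import torch

    dummy_input = torch.LongTensor([[0, 1, 2, 3]])

    # Run with all adapters turned off (base model behavior)
    model.disable_adapters()
    base_ids = model.generate(input_ids=dummy_input)

    # Turn the previously configured adapters back on
    model.enable_adapters()
    adapted_ids = model.generate(input_ids=dummy_input)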

def active_adapter(self) -> str:
def active_adapters(self) -> List[str]:
[Review comment] Here, `.active_adapters` is a method; in PEFT, it's a property. Since we are newly introducing this method in transformers, do we want to take the opportunity to make it a property, for consistency? The downside is that `active_adapter` is a method here, not a property, so it would then be inconsistent with that method. We could change it too, but that would be BC-breaking.

"""
If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
official documentation: https://huggingface.co/docs/peft

Gets the current active adapter of the model.
Gets the current active adapters of the model. In case of multi-adapter inference (combining multiple adapters
for inference), returns the list of all active adapters so that users can handle them accordingly.

For previous PEFT versions (that does not support multi-adapter inference), `module.active_adapter` will return
[Review comment]
Suggested change, replacing:
    For previous PEFT versions (that does not support multi-adapter inference), `module.active_adapter` will return
with:
    For previous PEFT versions (that do not support multi-adapter inference), `module.active_adapter` will return

I think this statement is a bit confusing: this method always returns a list of str, right? The statement seems to relate to what older PEFT versions do under the hood, but that's an implementation detail and should not be mentioned in the docstring (though it could be added as a comment in the code).

a single string.
"""
check_peft_version(min_version=MIN_PEFT_VERSION)

@@ -333,7 +356,34 @@ def active_adapter(self) -> str:

for _, module in self.named_modules():
if isinstance(module, BaseTunerLayer):
return module.active_adapter
active_adapters = module.active_adapter
break

# For previous PEFT versions
if isinstance(active_adapters, str):
[Review comment] Is it possible to have no active adapter at all, or is that prevented at some point? Otherwise, `active_adapters` could be undefined here.

active_adapters = [active_adapters]

return active_adapters

def active_adapter(self) -> str:
"""
Gets the current active adapter of the model. In case of multi-adapter inference (combining multiple adapters
for inference), returns the first active adapter; this method is kept for backward compatibility.

With newer PEFT versions, users should call `model.active_adapters()` instead to get the list of active
adapters.
"""

active_adapters = self.active_adapters()

if len(active_adapters) > 1:
    logger.warning(
        "`active_adapter` will return only the first adapter in case of multi-adapter inference. Use "
        "`model.active_adapters()` instead to get the full list of active adapters."
    )
active_adapters = active_adapters[0]

return active_adapters
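
A sketch of how the two accessors relate under multi-adapter inference, reusing the placeholder adapter names from above:

    model.set_adapter(["adapter-2", "default"])

    model.active_adapters()  # ["adapter-2", "default"]
    model.active_adapter()   # "adapter-2", plus a warning that only the first adapter is returned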

def get_adapter_state_dict(self, adapter_name: Optional[str] = None) -> dict:
"""
197 changes: 137 additions & 60 deletions src/transformers/modeling_utils.py
@@ -1983,6 +1983,7 @@ def save_pretrained(
custom_object_save(self, save_directory, config=self.config)

_hf_peft_config_loaded = getattr(model_to_save, "_hf_peft_config_loaded", False)
peft_multi_adapter_state_dict = {}

# Save the config
if is_main_process:
@@ -1995,72 +1996,67 @@
logger.info(
"Detected adapters on the model, saving the model in the PEFT format, only adapter weights will be saved."
)
state_dict = model_to_save.get_adapter_state_dict()

if save_peft_format:
logger.info(
"To match the expected format of the PEFT library, all keys of the state dict of adapters will be pre-pended with `base_model.model`."
)
peft_state_dict = {}
for key, value in state_dict.items():
peft_state_dict[f"base_model.model.{key}"] = value
state_dict = peft_state_dict
total_adapters = list(self.peft_config.keys())

for adapter_name in total_adapters:
adapter_state_dict = model_to_save.get_adapter_state_dict(adapter_name=adapter_name)

if save_peft_format:
logger.info(
"To match the expected format of the PEFT library, all keys of the state dict of adapters will be pre-pended with `base_model.model`."
)
peft_state_dict = {}
for key, value in adapter_state_dict.items():
peft_state_dict[f"base_model.model.{key}"] = value
adapter_state_dict = peft_state_dict.copy()
# Free memory
del peft_state_dict

current_peft_config = self.peft_config[adapter_name]
peft_multi_adapter_state_dict[adapter_name] = adapter_state_dict

# The default adapter is always saved in the root directory
if adapter_name != "default":
current_peft_config.save_pretrained(os.path.join(save_directory, adapter_name))
else:
current_peft_config.save_pretrained(save_directory)

if len(peft_multi_adapter_state_dict.keys()) == 1:
current_adapter = list(peft_multi_adapter_state_dict.keys())[0]
state_dict = peft_multi_adapter_state_dict[current_adapter].copy()
peft_multi_adapter_state_dict = None
[Review comment on lines +2025 to +2028]
How about this change, which I think makes the intent a bit more obvious and avoids changing the type of peft_multi_adapter_state_dict:

Suggested change, replacing:

    if len(peft_multi_adapter_state_dict.keys()) == 1:
        current_adapter = list(peft_multi_adapter_state_dict.keys())[0]
        state_dict = peft_multi_adapter_state_dict[current_adapter].copy()
        peft_multi_adapter_state_dict = None

with:

    if len(peft_multi_adapter_state_dict.keys()) == 1:
        current_adapter = list(peft_multi_adapter_state_dict.keys())[0]
        state_dict = peft_multi_adapter_state_dict.pop(current_adapter).copy()

Not sure if the .copy() is needed?

Then, the change below:

- _peft_save_multi_adapter = _hf_peft_config_loaded and peft_multi_adapter_state_dict is not None
+ _peft_save_multi_adapter = _hf_peft_config_loaded and peft_multi_adapter_state_dict


current_peft_config = self.peft_config[self.active_adapter()]
current_peft_config.save_pretrained(save_directory)
_peft_save_multi_adapter = _hf_peft_config_loaded and peft_multi_adapter_state_dict is not None

# Save the model
if state_dict is None:
if state_dict is None and not _peft_save_multi_adapter:
state_dict = model_to_save.state_dict()

# Translate state_dict from smp to hf if saving with smp >= 1.10
if IS_SAGEMAKER_MP_POST_1_10:
if IS_SAGEMAKER_MP_POST_1_10 and not _peft_save_multi_adapter:
for smp_to_hf, _ in smp.state.module_manager.translate_functions:
state_dict = smp_to_hf(state_dict)

# Handle the case where some state_dict keys shouldn't be saved
if self._keys_to_ignore_on_save is not None:
for ignore_key in self._keys_to_ignore_on_save:
if ignore_key in state_dict.keys():
del state_dict[ignore_key]
if not _peft_save_multi_adapter:
for ignore_key in self._keys_to_ignore_on_save:
if ignore_key in state_dict.keys():
del state_dict[ignore_key]
else:
for adapter_name in peft_multi_adapter_state_dict:
for ignore_key in self._keys_to_ignore_on_save:
if ignore_key in peft_multi_adapter_state_dict[adapter_name].keys():
del peft_multi_adapter_state_dict[adapter_name][ignore_key]
[Review comment]
I asked ChatGPT to make this block more elegant; here is what it came up with:

if not _peft_save_multi_adapter:
    state_dict = {k: v for k, v in state_dict.items() if k not in self._keys_to_ignore_on_save}
else:
    peft_multi_adapter_state_dict = {
        adapter_name: {k: v for k, v in adapter.items() if k not in self._keys_to_ignore_on_save}
        for adapter_name, adapter in peft_multi_adapter_state_dict.items()
    }

WDYT? :)

if safe_serialization:
# Safetensors does not allow tensor aliasing.
# We're going to remove aliases before saving
ptrs = collections.defaultdict(list)
for name, tensor in state_dict.items():
ptrs[id_tensor_storage(tensor)].append(name)

# These are all the pointers of shared tensors.
shared_ptrs = {ptr: names for ptr, names in ptrs.items() if len(names) > 1}
warn_names = set()
for names in shared_ptrs.values():
# Removing the keys which are declared as known duplicates on
# load. This allows to make sure the name which is kept is consistent.
if self._tied_weights_keys is not None:
found = 0
for name in sorted(names):
matches_pattern = any(re.search(pat, name) for pat in self._tied_weights_keys)
if matches_pattern and name in state_dict:
found += 1
if found < len(names):
del state_dict[name]

# When not all duplicates have been cleaned, still remove those keys, but put a clear warning.
# If the link between tensors was done at runtime then `from_pretrained` will not get
# the key back leading to random tensor. A proper warning will be shown
# during reload (if applicable), but since the file is not necessarily compatible with
# the config, better show a proper warning.
found = 0
for name in names:
if name in state_dict:
found += 1
if found > 1:
del state_dict[name]
warn_names.add(name)
if len(warn_names) > 0:
logger.warning_once(
f"Removed shared tensor {warn_names} while saving. This should be OK, but check by verifying that you don't receive any warning while reloading",
)
if not _peft_save_multi_adapter:
state_dict = self._post_process_safe_checkpoint(state_dict)
else:
for adapter_name in peft_multi_adapter_state_dict:
peft_multi_adapter_state_dict[adapter_name] = self._post_process_safe_checkpoint(
peft_multi_adapter_state_dict[adapter_name]
)

# Shard the model if it is too big.
if not _hf_peft_config_loaded:
@@ -2069,6 +2065,55 @@ def save_pretrained(
else:
weights_name = ADAPTER_SAFE_WEIGHTS_NAME if safe_serialization else ADAPTER_WEIGHTS_NAME

if not _peft_save_multi_adapter:
self._shard_and_save_checkpoints(
save_directory,
state_dict,
weights_name,
max_shard_size,
safe_serialization,
is_main_process,
variant,
save_function,
)
else:
for adapter_name in peft_multi_adapter_state_dict:
# The default adapter always needs to be saved in the root directory
adapter_save_path = (
save_directory if adapter_name == "default" else os.path.join(save_directory, adapter_name)
)

self._shard_and_save_checkpoints(
adapter_save_path,
peft_multi_adapter_state_dict[adapter_name],
weights_name,
max_shard_size,
safe_serialization,
is_main_process,
variant,
save_function,
)

if push_to_hub:
self._upload_modified_files(
save_directory,
repo_id,
files_timestamps,
commit_message=commit_message,
token=token,
)
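
A sketch of the on-disk layout this produces when two adapters are loaded, assuming the placeholder names from the sketches above; the "default" adapter lands in the root directory and every other adapter in a subdirectory named after it:

    model.save_pretrained("./multi-adapter-checkpoint")
    # ./multi-adapter-checkpoint/
    # ├── adapter_config.json          <- config of the "default" adapter
    # ├── adapter_model.safetensors    <- weights of the "default" adapter
    # └── adapter-2/
    #     ├── adapter_config.json
    #     └── adapter_model.safetensors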

def _shard_and_save_checkpoints(
self,
save_directory,
state_dict,
weights_name,
max_shard_size,
safe_serialization,
is_main_process,
variant,
save_function: Callable = torch.save,
):
shards, index = shard_checkpoint(state_dict, max_shard_size=max_shard_size, weights_name=weights_name)

# Clean the folder from a previous save
@@ -2116,15 +2161,47 @@ def save_pretrained(
f"index located at {save_index_file}."
)

if push_to_hub:
self._upload_modified_files(
save_directory,
repo_id,
files_timestamps,
commit_message=commit_message,
token=token,
def _post_process_safe_checkpoint(self, state_dict: Dict[str, Any]) -> Dict[str, Any]:
# Safetensors does not allow tensor aliasing.
# We're going to remove aliases before saving
ptrs = collections.defaultdict(list)
for name, tensor in state_dict.items():
ptrs[id_tensor_storage(tensor)].append(name)

# These are all the pointers of shared tensors.
shared_ptrs = {ptr: names for ptr, names in ptrs.items() if len(names) > 1}
warn_names = set()
for names in shared_ptrs.values():
# Remove the keys which are declared as known duplicates on load.
# This makes sure that the name which is kept is consistent.
if self._tied_weights_keys is not None:
found = 0
for name in sorted(names):
matches_pattern = any(re.search(pat, name) for pat in self._tied_weights_keys)
if matches_pattern and name in state_dict:
found += 1
if found < len(names):
del state_dict[name]

# When not all duplicates have been cleaned, still remove those keys, but put a clear warning.
# If the link between tensors was done at runtime then `from_pretrained` will not get
# the key back leading to random tensor. A proper warning will be shown
# during reload (if applicable), but since the file is not necessarily compatible with
# the config, better show a proper warning.
found = 0
for name in names:
if name in state_dict:
found += 1
if found > 1:
del state_dict[name]
warn_names.add(name)
if len(warn_names) > 0:
logger.warning_once(
f"Removed shared tensor {warn_names} while saving. This should be OK, but check by verifying that you don't receive any warning while reloading",
)

return state_dict
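
A self-contained sketch of the aliasing problem this helper guards against: safetensors refuses to serialize two entries that share storage, so one alias has to be dropped before saving (plain PyTorch, independent of this diff):

    import torch

    weight = torch.randn(4, 4)
    state_dict = {"encoder.weight": weight, "decoder.weight": weight}  # tied weights

    # Both names point at the same underlying storage...
    assert state_dict["encoder.weight"].data_ptr() == state_dict["decoder.weight"].data_ptr()

    # ...so one alias must be removed before safetensors serialization
    # (mirroring what the helper does for keys matched by `_tied_weights_keys`)
    state_dict.pop("decoder.weight")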

def get_memory_footprint(self, return_buffers=True):
r"""
Get the memory footprint of a model. This will return the memory footprint of the current model in bytes.
15 changes: 15 additions & 0 deletions tests/peft_integration/test_peft_integration.py
@@ -265,9 +265,11 @@ def test_peft_add_multi_adapter(self):
_ = model.generate(input_ids=dummy_input)

model.set_adapter("default")
self.assertTrue(model.active_adapters() == ["default"])
self.assertTrue(model.active_adapter() == "default")

model.set_adapter("adapter-2")
self.assertTrue(model.active_adapters() == ["adapter-2"])
self.assertTrue(model.active_adapter() == "adapter-2")

# Logits comparison
@@ -276,6 +278,19 @@
)
self.assertFalse(torch.allclose(logits_original_model, logits_adapter_2.logits, atol=1e-6, rtol=1e-6))

model.set_adapter(["adapter-2", "default"])
self.assertTrue(model.active_adapters() == ["adapter-2", "default"])
self.assertTrue(model.active_adapter() == "adapter-2")

logits_adapter_mixed = model(dummy_input)
self.assertFalse(
torch.allclose(logits_adapter_1.logits, logits_adapter_mixed.logits, atol=1e-6, rtol=1e-6)
)

self.assertFalse(
torch.allclose(logits_adapter_2.logits, logits_adapter_mixed.logits, atol=1e-6, rtol=1e-6)
)

@require_torch_gpu
def test_peft_from_pretrained_kwargs(self):
"""