[WIP] Multi-adapter saving support for PEFT
#26411
Changes from all commits
05bbcd7
b2e32a5
d772435
2b6c4b2
438eb67
bd402a1
e1686e0
87e304a
3dc62a3
9250ccf
c7ad2b4
9e93a26
f8e435f
d676d41
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import TYPE_CHECKING, Any, Dict, Optional
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

from ..utils import (
    check_peft_version,
@@ -245,20 +245,27 @@ def add_adapter(self, adapter_config, adapter_name: Optional[str] = None) -> None
        self.set_adapter(adapter_name)

    def set_adapter(self, adapter_name: str) -> None:
    def set_adapter(self, adapter_name: Union[List[str], str]) -> None:
        """
        If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
        official documentation: https://huggingface.co/docs/peft

        Sets a specific adapter by forcing the model to use that adapter and disabling the other adapters.

        Args:
            adapter_name (`str`):
                The name of the adapter to set.
            adapter_name (`Union[List[str], str]`):
                The name of the adapter to set. Can also be a list of strings to set multiple adapters.
        """
        check_peft_version(min_version=MIN_PEFT_VERSION)
        if not self._hf_peft_config_loaded:
            raise ValueError("No adapter loaded. Please load an adapter first.")
        elif isinstance(adapter_name, list):
            missing = set(adapter_name) - set(self.peft_config)
            if len(missing) > 0:
                raise ValueError(
                    f"Following adapter(s) could not be found: {', '.join(missing)}. Make sure you are passing the correct adapter name(s)."
                    f" The currently loaded adapters are: {list(self.peft_config.keys())}"
                )
        elif adapter_name not in self.peft_config:
            raise ValueError(
                f"Adapter with name {adapter_name} not found. Please pass the correct adapter name among {list(self.peft_config.keys())}"
@@ -270,7 +277,11 @@ def set_adapter(self, adapter_name: str) -> None:

        for _, module in self.named_modules():
            if isinstance(module, BaseTunerLayer):
                module.active_adapter = adapter_name
                # For backward compatibility with previous PEFT versions
                if hasattr(module, "set_adapter"):
                    module.set_adapter(adapter_name)
                else:
                    module.active_adapter = adapter_name
                _adapters_has_been_set = True

        if not _adapters_has_been_set:
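As a usage sketch of the new signature (the adapter names and checkpoint ids below are made up, and `load_adapter` refers to the existing PEFT-integration API):

    model.load_adapter("org/adapter-one", adapter_name="sft")
    model.load_adapter("org/adapter-two", adapter_name="dpo")

    model.set_adapter("sft")            # single adapter, as before
    model.set_adapter(["sft", "dpo"])   # new: activate several adapters at once
    print(model.active_adapters())      # -> ["sft", "dpo"]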
@@ -294,7 +305,11 @@ def set_adapter(self, adapter_name: str) -> None:

        for _, module in self.named_modules():
            if isinstance(module, BaseTunerLayer):
                module.disable_adapters = True
                # Recent versions of PEFT need to call `enable_adapters` instead
                if hasattr(module, "enable_adapters"):
                    module.enable_adapters(enabled=False)
                else:
                    module.disable_adapters = True

    def enable_adapters(self) -> None:
        """
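For context, a minimal usage sketch of the toggling these hooks support (the `inputs` variable is a placeholder):

    model.disable_adapters()                 # run the bare base model
    base_logits = model(**inputs).logits

    model.enable_adapters()                  # re-activate the loaded adapter(s)
    adapted_logits = model(**inputs).logits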
@@ -312,14 +327,22 @@ def enable_adapters(self) -> None:

        for _, module in self.named_modules():
            if isinstance(module, BaseTunerLayer):
                module.disable_adapters = False
                # Recent versions of PEFT need to call `enable_adapters` instead
                if hasattr(module, "enable_adapters"):
                    module.enable_adapters(enabled=True)
                else:
                    module.disable_adapters = False

    def active_adapter(self) -> str:
    def active_adapters(self) -> List[str]:
        """
        If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
        official documentation: https://huggingface.co/docs/peft

        Gets the current active adapter of the model.
        Gets the current active adapters of the model. In case of multi-adapter inference (combining multiple adapters
        for inference) returns the list of all active adapters so that users can deal with them accordingly.

        For previous PEFT versions (that do not support multi-adapter inference), `module.active_adapter` will return

Member
Suggested change
I think this statement is a bit confusing: this method always returns a list of str, right? The statement seems to relate to what older PEFT versions do under the hood, but that's an implementation detail and should not be mentioned in the docstring (but you could add it as a comment in the code).

        a single string.
        """
        check_peft_version(min_version=MIN_PEFT_VERSION)
@@ -333,7 +356,34 @@ def active_adapter(self) -> str:

        for _, module in self.named_modules():
            if isinstance(module, BaseTunerLayer):
                return module.active_adapter
                active_adapters = module.active_adapter
                break

        # For previous PEFT versions
        if isinstance(active_adapters, str):

Member
Is it possible to have no active adapter at all or is this prevented at some point? Otherwise, …

            active_adapters = [active_adapters]

        return active_adapters

    def active_adapter(self) -> str:
        """
        Gets the current active adapter of the model. In case of multi-adapter inference (combining multiple adapters
        for inference) returns the first active adapter - kept for backward compatibility.

        For newer versions of PEFT, users should use `model.active_adapters()` instead to get the list of active
        adapters.
        """

        active_adapters = self.active_adapters()

        if isinstance(active_adapters, list):
            logger.warning(
                "`active_adapter` will return the first adapter in case of multi-adapter inference. Make sure you know what you are doing."
                " You should use `model.active_adapters()` instead to get the list of active adapters."
            )
            active_adapters = active_adapters[0]

        return active_adapters
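A quick sketch of how the two accessors behave after this change (the adapter names are hypothetical):

    model.set_adapter(["sft", "dpo"])
    model.active_adapters()   # ["sft", "dpo"]
    model.active_adapter()    # "sft", with a warning pointing to `active_adapters()`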
    def get_adapter_state_dict(self, adapter_name: Optional[str] = None) -> dict:
        """
@@ -1983,6 +1983,7 @@ def save_pretrained(
            custom_object_save(self, save_directory, config=self.config)

        _hf_peft_config_loaded = getattr(model_to_save, "_hf_peft_config_loaded", False)
        peft_multi_adapter_state_dict = {}

        # Save the config
        if is_main_process:
@@ -1995,72 +1996,67 @@
            logger.info(
                "Detected adapters on the model, saving the model in the PEFT format, only adapter weights will be saved."
            )
            state_dict = model_to_save.get_adapter_state_dict()

            if save_peft_format:
                logger.info(
                    "To match the expected format of the PEFT library, all keys of the state dict of adapters will be pre-pended with `base_model.model`."
                )
                peft_state_dict = {}
                for key, value in state_dict.items():
                    peft_state_dict[f"base_model.model.{key}"] = value
                state_dict = peft_state_dict
            total_adapters = list(self.peft_config.keys())

            for adapter_name in total_adapters:
                adapter_state_dict = model_to_save.get_adapter_state_dict(adapter_name=adapter_name)

                if save_peft_format:
                    logger.info(
                        "To match the expected format of the PEFT library, all keys of the state dict of adapters will be pre-pended with `base_model.model`."
                    )
                    peft_state_dict = {}
                    for key, value in adapter_state_dict.items():
                        peft_state_dict[f"base_model.model.{key}"] = value
                    adapter_state_dict = peft_state_dict.copy()
                    # Free memory
                    del peft_state_dict

                current_peft_config = self.peft_config[adapter_name]
                peft_multi_adapter_state_dict[adapter_name] = adapter_state_dict

                # the default adapter is always saved on the root directory
                if adapter_name != "default":
                    current_peft_config.save_pretrained(os.path.join(save_directory, adapter_name))
                else:
                    current_peft_config.save_pretrained(save_directory)

            if len(peft_multi_adapter_state_dict.keys()) == 1:
                current_adapter = list(peft_multi_adapter_state_dict.keys())[0]
                state_dict = peft_multi_adapter_state_dict[current_adapter].copy()
                peft_multi_adapter_state_dict = None
Comment on lines +2025 to +2028

Member
How about this change, which I think makes the intent a bit more obvious and avoids changing the type of `peft_multi_adapter_state_dict`?

Suggested change
Not sure if the … Then, below, change:

- _peft_save_multi_adapter = _hf_peft_config_loaded and peft_multi_adapter_state_dict is not None
+ _peft_save_multi_adapter = _hf_peft_config_loaded and peft_multi_adapter_state_dict
            current_peft_config = self.peft_config[self.active_adapter()]
            current_peft_config.save_pretrained(save_directory)
        _peft_save_multi_adapter = _hf_peft_config_loaded and peft_multi_adapter_state_dict is not None

        # Save the model
        if state_dict is None:
        if state_dict is None and not _peft_save_multi_adapter:
            state_dict = model_to_save.state_dict()

        # Translate state_dict from smp to hf if saving with smp >= 1.10
        if IS_SAGEMAKER_MP_POST_1_10:
        if IS_SAGEMAKER_MP_POST_1_10 and not _peft_save_multi_adapter:
            for smp_to_hf, _ in smp.state.module_manager.translate_functions:
                state_dict = smp_to_hf(state_dict)

        # Handle the case where some state_dict keys shouldn't be saved
        if self._keys_to_ignore_on_save is not None:
            for ignore_key in self._keys_to_ignore_on_save:
                if ignore_key in state_dict.keys():
                    del state_dict[ignore_key]
            if not _peft_save_multi_adapter:
                for ignore_key in self._keys_to_ignore_on_save:
                    if ignore_key in state_dict.keys():
                        del state_dict[ignore_key]
            else:
                for adapter_name in peft_multi_adapter_state_dict:
                    for ignore_key in self._keys_to_ignore_on_save:
                        if ignore_key in peft_multi_adapter_state_dict[adapter_name].keys():
                            del peft_multi_adapter_state_dict[adapter_name][ignore_key]
Member
I asked ChatGPT to make this block more elegant, here is what it came up with:

if not _peft_save_multi_adapter:
    state_dict = {k: v for k, v in state_dict.items() if k not in self._keys_to_ignore_on_save}
else:
    peft_multi_adapter_state_dict = {
        adapter_name: {k: v for k, v in adapter.items() if k not in self._keys_to_ignore_on_save}
        for adapter_name, adapter in peft_multi_adapter_state_dict.items()
    }

WDYT? :)
        if safe_serialization:
            # Safetensors does not allow tensor aliasing.
            # We're going to remove aliases before saving
            ptrs = collections.defaultdict(list)
            for name, tensor in state_dict.items():
                ptrs[id_tensor_storage(tensor)].append(name)

            # These are all the pointers of shared tensors.
            shared_ptrs = {ptr: names for ptr, names in ptrs.items() if len(names) > 1}
            warn_names = set()
            for names in shared_ptrs.values():
                # Removing the keys which are declared as known duplicates on
                # load. This allows to make sure the name which is kept is consistent.
                if self._tied_weights_keys is not None:
                    found = 0
                    for name in sorted(names):
                        matches_pattern = any(re.search(pat, name) for pat in self._tied_weights_keys)
                        if matches_pattern and name in state_dict:
                            found += 1
                            if found < len(names):
                                del state_dict[name]

                # When not all duplicates have been cleaned, still remove those keys, but put a clear warning.
                # If the link between tensors was done at runtime then `from_pretrained` will not get
                # the key back leading to random tensor. A proper warning will be shown
                # during reload (if applicable), but since the file is not necessarily compatible with
                # the config, better show a proper warning.
                found = 0
                for name in names:
                    if name in state_dict:
                        found += 1
                        if found > 1:
                            del state_dict[name]
                            warn_names.add(name)
            if len(warn_names) > 0:
                logger.warning_once(
                    f"Removed shared tensor {warn_names} while saving. This should be OK, but check by verifying that you don't receive any warning while reloading",
                )
            if not _peft_save_multi_adapter:
                state_dict = self._post_process_safe_checkpoint(state_dict)
            else:
                for adapter_name in peft_multi_adapter_state_dict:
                    peft_multi_adapter_state_dict[adapter_name] = self._post_process_safe_checkpoint(
                        peft_multi_adapter_state_dict[adapter_name]
                    )

        # Shard the model if it is too big.
        if not _hf_peft_config_loaded:
@@ -2069,6 +2065,55 @@ def save_pretrained(
        else:
            weights_name = ADAPTER_SAFE_WEIGHTS_NAME if safe_serialization else ADAPTER_WEIGHTS_NAME

        if not _peft_save_multi_adapter:
            self._shard_and_save_checkpoints(
                save_directory,
                state_dict,
                weights_name,
                max_shard_size,
                safe_serialization,
                is_main_process,
                variant,
                save_function,
            )
        else:
            for adapter_name in peft_multi_adapter_state_dict:
                # The default adapter always needs to be saved on the root directory
                adapter_save_path = (
                    save_directory if adapter_name == "default" else os.path.join(save_directory, adapter_name)
                )

                self._shard_and_save_checkpoints(
                    adapter_save_path,
                    peft_multi_adapter_state_dict[adapter_name],
                    weights_name,
                    max_shard_size,
                    safe_serialization,
                    is_main_process,
                    variant,
                    save_function,
                )
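For illustration, a sketch of the on-disk layout this branch appears to produce for two loaded adapters, "default" and a hypothetical "summarization", assuming safetensors serialization:

    model.save_pretrained("out")
    # Expected layout (assumed example):
    # out/
    #     adapter_config.json          <- "default" adapter config, saved at the root
    #     adapter_model.safetensors    <- "default" adapter weights
    #     summarization/
    #         adapter_config.json
    #         adapter_model.safetensors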
        if push_to_hub:
            self._upload_modified_files(
                save_directory,
                repo_id,
                files_timestamps,
                commit_message=commit_message,
                token=token,
            )

    def _shard_and_save_checkpoints(
        self,
        save_directory,
        state_dict,
        weights_name,
        max_shard_size,
        safe_serialization,
        is_main_process,
        variant,
        save_function: Callable = torch.save,
    ):
        shards, index = shard_checkpoint(state_dict, max_shard_size=max_shard_size, weights_name=weights_name)

        # Clean the folder from a previous save
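As a reminder of what the new helper delegates to (a sketch; the exact shard file names depend on `weights_name` and `max_shard_size`):

    shards, index = shard_checkpoint(state_dict, max_shard_size="5GB", weights_name="model.safetensors")
    # For a small state dict: shards == {"model.safetensors": state_dict} and index is None.
    # For a large one: several "model-0000X-of-0000Y.safetensors" entries plus an index
    # mapping each weight name to the shard file that contains it.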
@@ -2116,15 +2161,47 @@ def save_pretrained(
                f"index located at {save_index_file}."
            )

        if push_to_hub:
            self._upload_modified_files(
                save_directory,
                repo_id,
                files_timestamps,
                commit_message=commit_message,
                token=token,
    def _post_process_safe_checkpoint(self, state_dict: Dict[str, Any]) -> Dict[str, Any]:
        # Safetensors does not allow tensor aliasing.
        # We're going to remove aliases before saving
        ptrs = collections.defaultdict(list)
        for name, tensor in state_dict.items():
            ptrs[id_tensor_storage(tensor)].append(name)

        # These are all the pointers of shared tensors.
        shared_ptrs = {ptr: names for ptr, names in ptrs.items() if len(names) > 1}
        warn_names = set()
        for names in shared_ptrs.values():
            # Removing the keys which are declared as known duplicates on
            # load. This allows to make sure the name which is kept is consistent.
            if self._tied_weights_keys is not None:
                found = 0
                for name in sorted(names):
                    matches_pattern = any(re.search(pat, name) for pat in self._tied_weights_keys)
                    if matches_pattern and name in state_dict:
                        found += 1
                        if found < len(names):
                            del state_dict[name]

            # When not all duplicates have been cleaned, still remove those keys, but put a clear warning.
            # If the link between tensors was done at runtime then `from_pretrained` will not get
            # the key back leading to random tensor. A proper warning will be shown
            # during reload (if applicable), but since the file is not necessarily compatible with
            # the config, better show a proper warning.
            found = 0
            for name in names:
                if name in state_dict:
                    found += 1
                    if found > 1:
                        del state_dict[name]
                        warn_names.add(name)
        if len(warn_names) > 0:
            logger.warning_once(
                f"Removed shared tensor {warn_names} while saving. This should be OK, but check by verifying that you don't receive any warning while reloading",
            )

        return state_dict
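To illustrate what the extracted helper guards against, a minimal sketch (the tensor names are invented):

    import torch

    w = torch.randn(10, 10)
    state_dict = {"lm_head.weight": w, "model.embed_tokens.weight": w}  # two keys, one storage
    # safetensors refuses to serialize aliased tensors, so one of the duplicate keys must be
    # dropped before saving; _post_process_safe_checkpoint does this by grouping names that
    # share a storage (via id_tensor_storage) and deleting the extras, with a warning.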
    def get_memory_footprint(self, return_buffers=True):
        r"""
        Get the memory footprint of a model. This will return the memory footprint of the current model in bytes.
Here, `.active_adapters` is a method; in PEFT, it's a `property`. Since we newly introduce this method in transformers, do we want to take the opportunity to make it a `property` for consistency? The downside is that `active_adapter` is a method here, not a `property`, so it would be inconsistent with that method. We could change it too, but that would be BC breaking.
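For reference, the property variant being discussed might look roughly like this (a sketch only; the PR as written keeps `active_adapters` a regular method, and call sites such as `save_pretrained` would need to drop the parentheses):

    @property
    def active_adapters(self) -> List[str]:
        check_peft_version(min_version=MIN_PEFT_VERSION)
        ...  # same body as above, just exposed as a property to mirror PEFT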