Skip to content

Commit 6500f78

Browse files
[PEFT] Support low_cpu_mem_usage option for PEFT loading adapters (#33725)
* [PEFT] Support low_cpu_mem_usage for PEFT loading PEFT added support for low_cpu_mem_usage=True when loading adapters in huggingface/peft#1961. This feature is now available when installing PEFT v0.13.0. With this PR, this option is also supported when loading PEFT adapters directly into transformers models. Additionally, with this PR, huggingface/diffusers#9510 will be unblocked, which implements this option in diffusers. * Fix typo
1 parent bf0ffe3 commit 6500f78

File tree

2 files changed

+67
-2
lines changed

2 files changed

+67
-2
lines changed

src/transformers/integrations/peft.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,13 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import importlib
1415
import inspect
1516
import warnings
1617
from typing import Any, Dict, List, Optional, Union
1718

19+
from packaging import version
20+
1821
from ..utils import (
1922
check_peft_version,
2023
find_adapter_config_file,
@@ -77,6 +80,7 @@ def load_adapter(
7780
offload_index: Optional[int] = None,
7881
peft_config: Dict[str, Any] = None,
7982
adapter_state_dict: Optional[Dict[str, "torch.Tensor"]] = None,
83+
low_cpu_mem_usage: bool = False,
8084
adapter_kwargs: Optional[Dict[str, Any]] = None,
8185
) -> None:
8286
"""
@@ -129,12 +133,27 @@ def load_adapter(
129133
adapter_state_dict (`Dict[str, torch.Tensor]`, *optional*):
130134
The state dict of the adapter to load. This argument is used in case users directly pass PEFT state
131135
dicts
136+
low_cpu_mem_usage (`bool`, *optional*, defaults to `False`):
137+
Reduce memory usage while loading the PEFT adapter. This should also speed up the loading process.
138+
Requires PEFT version 0.13.0 or higher.
132139
adapter_kwargs (`Dict[str, Any]`, *optional*):
133140
Additional keyword arguments passed along to the `from_pretrained` method of the adapter config and
134141
`find_adapter_config_file` method.
135142
"""
136143
check_peft_version(min_version=MIN_PEFT_VERSION)
137144

145+
# peft only supports low_cpu_mem_usage starting from v0.13.0
146+
peft_load_kwargs = {}
147+
if low_cpu_mem_usage:
148+
min_version_lcmu = "0.13.0"
149+
if version.parse(importlib.metadata.version("peft")) >= version.parse(min_version_lcmu):
150+
peft_load_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
151+
else:
152+
raise ValueError(
153+
"The version of PEFT you are using does not support `low_cpu_mem_usage` yet, "
154+
f"please install PEFT >= {min_version_lcmu}."
155+
)
156+
138157
adapter_name = adapter_name if adapter_name is not None else "default"
139158
if adapter_kwargs is None:
140159
adapter_kwargs = {}
@@ -192,7 +211,7 @@ def load_adapter(
192211
)
193212

194213
# Create and add fresh new adapters into the model.
195-
inject_adapter_in_model(peft_config, self, adapter_name)
214+
inject_adapter_in_model(peft_config, self, adapter_name, **peft_load_kwargs)
196215

197216
if not self._hf_peft_config_loaded:
198217
self._hf_peft_config_loaded = True
@@ -211,7 +230,9 @@ def load_adapter(
211230
processed_adapter_state_dict[new_key] = value
212231

213232
# Load state dict
214-
incompatible_keys = set_peft_model_state_dict(self, processed_adapter_state_dict, adapter_name)
233+
incompatible_keys = set_peft_model_state_dict(
234+
self, processed_adapter_state_dict, adapter_name, **peft_load_kwargs
235+
)
215236

216237
if incompatible_keys is not None:
217238
# check only for unexpected keys

tests/peft_integration/test_peft_integration.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,13 @@
1212
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
15+
import importlib
1516
import os
1617
import tempfile
1718
import unittest
1819

1920
from huggingface_hub import hf_hub_download
21+
from packaging import version
2022

2123
from transformers import AutoModelForCausalLM, OPTForCausalLM
2224
from transformers.testing_utils import (
@@ -478,6 +480,48 @@ def test_peft_add_adapter_with_state_dict(self):
478480
# dummy generation
479481
_ = model.generate(input_ids=dummy_input)
480482

483+
def test_peft_add_adapter_with_state_dict_low_cpu_mem_usage(self):
484+
"""
485+
Check the usage of low_cpu_mem_usage, which is supported in PEFT >= 0.13.0
486+
"""
487+
from peft import LoraConfig
488+
489+
min_version_lcmu = "0.13.0"
490+
is_lcmu_supported = version.parse(importlib.metadata.version("peft")) >= version.parse(min_version_lcmu)
491+
492+
for model_id, peft_model_id in zip(self.transformers_test_model_ids, self.peft_test_model_ids):
493+
for transformers_class in self.transformers_test_model_classes:
494+
model = transformers_class.from_pretrained(model_id).to(torch_device)
495+
496+
peft_config = LoraConfig()
497+
state_dict_path = hf_hub_download(peft_model_id, "adapter_model.bin")
498+
dummy_state_dict = torch.load(state_dict_path)
499+
500+
# this should always work
501+
model.load_adapter(
502+
adapter_state_dict=dummy_state_dict, peft_config=peft_config, low_cpu_mem_usage=False
503+
)
504+
505+
if is_lcmu_supported:
506+
# if supported, this should not raise an error
507+
model.load_adapter(
508+
adapter_state_dict=dummy_state_dict,
509+
adapter_name="other",
510+
peft_config=peft_config,
511+
low_cpu_mem_usage=True,
512+
)
513+
# after loading, no meta device should be remaining
514+
self.assertFalse(any((p.device.type == "meta") for p in model.parameters()))
515+
else:
516+
err_msg = r"The version of PEFT you are using does not support `low_cpu_mem_usage` yet"
517+
with self.assertRaisesRegex(ValueError, err_msg):
518+
model.load_adapter(
519+
adapter_state_dict=dummy_state_dict,
520+
adapter_name="other",
521+
peft_config=peft_config,
522+
low_cpu_mem_usage=True,
523+
)
524+
481525
def test_peft_from_pretrained_hub_kwargs(self):
482526
"""
483527
Tests different combinations of PEFT model + from_pretrained + hub kwargs

0 commit comments

Comments
 (0)