
Commit 892d1d3

Merge pull request #3 from huggingface/smangrul/fixes-peft-integration
fixes peft integration
2 parents ba6c180 + 32dd0d5 commit 892d1d3

File tree: 2 files changed, 15 additions & 5 deletions


src/diffusers/loaders.py

Lines changed: 13 additions & 3 deletions
```diff
@@ -1205,6 +1205,7 @@ def load_lora_weights(
             network_alphas=network_alphas,
             unet=self.unet,
             low_cpu_mem_usage=low_cpu_mem_usage,
+            adapter_name=adapter_name,
             _pipeline=self,
         )
         self.load_lora_into_text_encoder(
@@ -1515,7 +1516,7 @@ def _maybe_map_sgm_blocks_to_diffusers(cls, state_dict, unet_config, delimiter="
 
     @classmethod
     def load_lora_into_unet(
-        cls, state_dict, network_alphas, unet, low_cpu_mem_usage=None, _pipeline=None, adapter_name=None
+        cls, state_dict, network_alphas, unet, low_cpu_mem_usage=None, adapter_name=None, _pipeline=None
     ):
         """
         This will load the LoRA layers specified in `state_dict` into `unet`.
@@ -3005,7 +3006,9 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
     """This class overrides `LoraLoaderMixin` with LoRA loading/saving code that's specific to SDXL"""
 
     # Overrride to properly handle the loading and unloading of the additional text encoder.
-    def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
+    def load_lora_weights(
+        self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
+    ):
         """
         Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.unet` and
         `self.text_encoder`.
@@ -3023,6 +3026,9 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
         Parameters:
             pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
                 See [`~loaders.LoraLoaderMixin.lora_state_dict`].
+            adapter_name (`str`, *optional*):
+                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+                `default_{i}` where i is the total number of adapters being loaded.
             kwargs (`dict`, *optional*):
                 See [`~loaders.LoraLoaderMixin.lora_state_dict`].
         """
@@ -3040,7 +3046,9 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
         if not is_correct_format:
             raise ValueError("Invalid LoRA checkpoint.")
 
-        self.load_lora_into_unet(state_dict, network_alphas=network_alphas, unet=self.unet, _pipeline=self)
+        self.load_lora_into_unet(
+            state_dict, network_alphas=network_alphas, unet=self.unet, adapter_name=adapter_name, _pipeline=self
+        )
         text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
         if len(text_encoder_state_dict) > 0:
             self.load_lora_into_text_encoder(
@@ -3049,6 +3057,7 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
                 text_encoder=self.text_encoder,
                 prefix="text_encoder",
                 lora_scale=self.lora_scale,
+                adapter_name=adapter_name,
                 _pipeline=self,
            )
 
@@ -3060,6 +3069,7 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
                 text_encoder=self.text_encoder_2,
                 prefix="text_encoder_2",
                 lora_scale=self.lora_scale,
+                adapter_name=adapter_name,
                 _pipeline=self,
             )
 
```
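Taken together, the loaders.py changes thread a single `adapter_name` from the SDXL `load_lora_weights` entry point down into the UNet and both text encoders. A minimal usage sketch, assuming a PEFT-enabled environment; the LoRA repository names and weight file below are hypothetical:

```python
import torch
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

# Load a LoRA under an explicit adapter name; the same name is now forwarded
# to the UNet and to both text encoders.
pipe.load_lora_weights(
    "some-user/sdxl-style-lora",  # hypothetical repository
    weight_name="pytorch_lora_weights.safetensors",
    adapter_name="style",
)

# Omitting adapter_name falls back to `default_{i}`, where i is the total
# number of adapters being loaded.
pipe.load_lora_weights("some-user/sdxl-detail-lora")  # hypothetical repository
```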

src/diffusers/utils/peft_utils.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -163,9 +163,9 @@ def set_adapter_layers(model, enabled=True):
         if isinstance(module, BaseTunerLayer):
             # The recent version of PEFT needs to call `enable_adapters` instead
             if hasattr(module, "enable_adapters"):
-                module.enable_adapters(enabled=False)
+                module.enable_adapters(enabled=enabled)
             else:
-                module.disable_adapters = True
+                module.disable_adapters = not enabled
 
 
 def set_weights_and_activate_adapters(model, adapter_names, weights):
```
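The peft_utils.py change is a plain bug fix: `set_adapter_layers` previously ignored its `enabled` argument and always switched the adapter layers off. With the fix it toggles them in either direction, so the helper can be used as in the sketch below, continuing the example above where `pipe` already has PEFT-backed LoRA layers loaded:

```python
from diffusers.utils.peft_utils import set_adapter_layers

prompt = "a photo of an astronaut riding a horse"

# Temporarily switch the LoRA layers off to compare against the base model...
set_adapter_layers(pipe.unet, enabled=False)
base_image = pipe(prompt).images[0]

# ...then turn them back on. Before this fix, enabled=True would still have
# disabled the adapters.
set_adapter_layers(pipe.unet, enabled=True)
lora_image = pipe(prompt).images[0]
```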
