@@ -1205,6 +1205,7 @@ def load_lora_weights(
             network_alphas=network_alphas,
             unet=self.unet,
             low_cpu_mem_usage=low_cpu_mem_usage,
+            adapter_name=adapter_name,
             _pipeline=self,
         )
         self.load_lora_into_text_encoder(
@@ -1515,7 +1516,7 @@ def _maybe_map_sgm_blocks_to_diffusers(cls, state_dict, unet_config, delimiter="

    @classmethod
    def load_lora_into_unet(
-        cls, state_dict, network_alphas, unet, low_cpu_mem_usage=None, _pipeline=None, adapter_name=None
+        cls, state_dict, network_alphas, unet, low_cpu_mem_usage=None, adapter_name=None, _pipeline=None
    ):
        """
        This will load the LoRA layers specified in `state_dict` into `unet`.
@@ -3005,7 +3006,9 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
    """This class overrides `LoraLoaderMixin` with LoRA loading/saving code that's specific to SDXL"""

    # Overrride to properly handle the loading and unloading of the additional text encoder.
-    def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
+    def load_lora_weights(
+        self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
+    ):
        """
        Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.unet` and
        `self.text_encoder`.
@@ -3023,6 +3026,9 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
        Parameters:
            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
                See [`~loaders.LoraLoaderMixin.lora_state_dict`].
+            adapter_name (`str`, *optional*):
+                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+                `default_{i}` where i is the total number of adapters being loaded.
            kwargs (`dict`, *optional*):
                See [`~loaders.LoraLoaderMixin.lora_state_dict`].
        """
@@ -3040,7 +3046,9 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
        if not is_correct_format:
            raise ValueError("Invalid LoRA checkpoint.")

-        self.load_lora_into_unet(state_dict, network_alphas=network_alphas, unet=self.unet, _pipeline=self)
+        self.load_lora_into_unet(
+            state_dict, network_alphas=network_alphas, unet=self.unet, adapter_name=adapter_name, _pipeline=self
+        )
        text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
        if len(text_encoder_state_dict) > 0:
            self.load_lora_into_text_encoder(
@@ -3049,6 +3057,7 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
                text_encoder=self.text_encoder,
                prefix="text_encoder",
                lora_scale=self.lora_scale,
+                adapter_name=adapter_name,
                _pipeline=self,
            )
@@ -3060,6 +3069,7 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
                text_encoder=self.text_encoder_2,
                prefix="text_encoder_2",
                lora_scale=self.lora_scale,
+                adapter_name=adapter_name,
                _pipeline=self,
            )
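
For context, a minimal usage sketch of the new `adapter_name` argument threaded through these hunks (the LoRA repository IDs and adapter names below are illustrative placeholders, not part of this change; the `default_{i}` fallback is the behavior documented in the docstring above):

    import torch
    from diffusers import StableDiffusionXLPipeline

    pipe = StableDiffusionXLPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
    )

    # Each call registers its LoRA under the given name, so several adapters
    # can be loaded side by side and referenced later. Omitting adapter_name
    # falls back to an auto-generated `default_{i}` name.
    pipe.load_lora_weights("user/pixel-art-lora", adapter_name="pixel")  # placeholder repo
    pipe.load_lora_weights("user/paper-cut-lora", adapter_name="papercut")  # placeholder repo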