@@ -755,6 +755,52 @@ def get_block_index(lora_name: str, is_sdxl: bool = False) -> int:
755755 return block_idx
756756
757757
758+ def convert_diffusers_to_sai_if_needed (weights_sd ):
759+ # only supports U-Net LoRA modules
760+
761+ found_up_down_blocks = False
762+ for k in list (weights_sd .keys ()):
763+ if "down_blocks" in k :
764+ found_up_down_blocks = True
765+ break
766+ if "up_blocks" in k :
767+ found_up_down_blocks = True
768+ break
769+ if not found_up_down_blocks :
770+ return
771+
772+ from library .sdxl_model_util import make_unet_conversion_map
773+
774+ unet_conversion_map = make_unet_conversion_map ()
775+ unet_conversion_map = {hf .replace ("." , "_" )[:- 1 ]: sd .replace ("." , "_" )[:- 1 ] for sd , hf in unet_conversion_map }
776+
777+ # # add extra conversion
778+ # unet_conversion_map["up_blocks_1_upsamplers_0"] = "lora_unet_output_blocks_2_2_conv"
779+
780+ logger .info (f"Converting LoRA keys from Diffusers to SAI" )
781+ lora_unet_prefix = "lora_unet_"
782+ for k in list (weights_sd .keys ()):
783+ if not k .startswith (lora_unet_prefix ):
784+ continue
785+
786+ unet_module_name = k [len (lora_unet_prefix ) :].split ("." )[0 ]
787+
788+ # search for conversion: this is slow because the algorithm is O(n^2), but the number of keys is small
789+ for hf_module_name , sd_module_name in unet_conversion_map .items ():
790+ if hf_module_name in unet_module_name :
791+ new_key = (
792+ lora_unet_prefix
793+ + unet_module_name .replace (hf_module_name , sd_module_name )
794+ + k [len (lora_unet_prefix ) + len (unet_module_name ) :]
795+ )
796+ weights_sd [new_key ] = weights_sd .pop (k )
797+ found = True
798+ break
799+
800+ if not found :
801+ logger .warning (f"Key { k } is not found in unet_conversion_map" )
802+
803+
758804# Create network from weights for inference, weights are not loaded here (because can be merged)
759805def create_network_from_weights (multiplier , file , vae , text_encoder , unet , weights_sd = None , for_inference = False , ** kwargs ):
760806 # if unet is an instance of SdxlUNet2DConditionModel or subclass, set is_sdxl to True
@@ -768,6 +814,9 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh
768814 else :
769815 weights_sd = torch .load (file , map_location = "cpu" )
770816
817+ # if keys are Diffusers based, convert to SAI based
818+ convert_diffusers_to_sai_if_needed (weights_sd )
819+
771820 # get dim/alpha mapping
772821 modules_dim = {}
773822 modules_alpha = {}
0 commit comments