@@ -1084,7 +1084,7 @@ def lora_state_dict(
1084
1084
# Map SDXL blocks correctly.
1085
1085
if unet_config is not None :
1086
1086
# use unet config to remap block numbers
1087
- state_dict = cls ._map_sgm_blocks_to_diffusers (state_dict , unet_config )
1087
+ state_dict = cls ._maybe_map_sgm_blocks_to_diffusers (state_dict , unet_config )
1088
1088
state_dict , network_alphas = cls ._convert_kohya_lora_to_diffusers (state_dict )
1089
1089
1090
1090
return state_dict , network_alphas
@@ -1121,24 +1121,41 @@ def _best_guess_weight_name(cls, pretrained_model_name_or_path_or_dict, file_ext
1121
1121
return weight_name
1122
1122
1123
1123
@classmethod
1124
- def _map_sgm_blocks_to_diffusers (cls , state_dict , unet_config , delimiter = "_" , block_slice_pos = 5 ):
1125
- is_all_unet = all (k .startswith ("lora_unet" ) for k in state_dict )
1124
+ def _maybe_map_sgm_blocks_to_diffusers (cls , state_dict , unet_config , delimiter = "_" , block_slice_pos = 5 ):
1125
+ # 1. get all state_dict_keys
1126
+ all_keys = state_dict .keys ()
1127
+ sgm_patterns = ["input_blocks" , "middle_block" , "output_blocks" ]
1128
+
1129
+ # 2. check if needs remapping, if not return original dict
1130
+ is_in_sgm_format = False
1131
+ for key in all_keys :
1132
+ if any (p in key for p in sgm_patterns ):
1133
+ is_in_sgm_format = True
1134
+ break
1135
+
1136
+ if not is_in_sgm_format :
1137
+ return state_dict
1138
+
1139
+ # 3. Else remap from SGM patterns
1126
1140
new_state_dict = {}
1127
1141
inner_block_map = ["resnets" , "attentions" , "upsamplers" ]
1128
1142
1129
1143
# Retrieves # of down, mid and up blocks
1130
1144
input_block_ids , middle_block_ids , output_block_ids = set (), set (), set ()
1131
- for layer in state_dict :
1132
- if "text" not in layer :
1145
+
1146
+ for layer in all_keys :
1147
+ if "text" in layer :
1148
+ new_state_dict [layer ] = state_dict .pop (layer )
1149
+ else :
1133
1150
layer_id = int (layer .split (delimiter )[:block_slice_pos ][- 1 ])
1134
- if "input_blocks" in layer :
1151
+ if sgm_patterns [ 0 ] in layer :
1135
1152
input_block_ids .add (layer_id )
1136
- elif "middle_block" in layer :
1153
+ elif sgm_patterns [ 1 ] in layer :
1137
1154
middle_block_ids .add (layer_id )
1138
- elif "output_blocks" in layer :
1155
+ elif sgm_patterns [ 2 ] in layer :
1139
1156
output_block_ids .add (layer_id )
1140
1157
else :
1141
- raise ValueError ("Checkpoint not supported" )
1158
+ raise ValueError (f "Checkpoint not supported because layer { layer } not supported. " )
1142
1159
1143
1160
input_blocks = {
1144
1161
layer_id : [key for key in state_dict if f"input_blocks{ delimiter } { layer_id } " in key ]
@@ -1201,12 +1218,8 @@ def _map_sgm_blocks_to_diffusers(cls, state_dict, unet_config, delimiter="_", bl
1201
1218
)
1202
1219
new_state_dict [new_key ] = state_dict .pop (key )
1203
1220
1204
- if is_all_unet and len (state_dict ) > 0 :
1221
+ if len (state_dict ) > 0 :
1205
1222
raise ValueError ("At this point all state dict entries have to be converted." )
1206
- else :
1207
- # Remaining is the text encoder state dict.
1208
- for k , v in state_dict .items ():
1209
- new_state_dict .update ({k : v })
1210
1223
1211
1224
return new_state_dict
1212
1225
0 commit comments